Merge afc branch into dev, prioritizing afc changes and migrating database async modifications from dev
This commit is contained in:
11
bot.py
11
bot.py
@@ -35,6 +35,7 @@ script_dir = os.path.dirname(os.path.abspath(__file__))
|
|||||||
os.chdir(script_dir)
|
os.chdir(script_dir)
|
||||||
logger.info(f"已设置工作目录为: {script_dir}")
|
logger.info(f"已设置工作目录为: {script_dir}")
|
||||||
|
|
||||||
|
|
||||||
# 检查并创建.env文件
|
# 检查并创建.env文件
|
||||||
def ensure_env_file():
|
def ensure_env_file():
|
||||||
"""确保.env文件存在,如果不存在则从模板创建"""
|
"""确保.env文件存在,如果不存在则从模板创建"""
|
||||||
@@ -45,6 +46,7 @@ def ensure_env_file():
|
|||||||
if template_env.exists():
|
if template_env.exists():
|
||||||
logger.info("未找到.env文件,正在从模板创建...")
|
logger.info("未找到.env文件,正在从模板创建...")
|
||||||
import shutil
|
import shutil
|
||||||
|
|
||||||
shutil.copy(template_env, env_file)
|
shutil.copy(template_env, env_file)
|
||||||
logger.info("已从template/template.env创建.env文件")
|
logger.info("已从template/template.env创建.env文件")
|
||||||
logger.warning("请编辑.env文件,将EULA_CONFIRMED设置为true并配置其他必要参数")
|
logger.warning("请编辑.env文件,将EULA_CONFIRMED设置为true并配置其他必要参数")
|
||||||
@@ -52,6 +54,7 @@ def ensure_env_file():
|
|||||||
logger.error("未找到.env文件和template.env模板文件")
|
logger.error("未找到.env文件和template.env模板文件")
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
|
|
||||||
# 确保环境文件存在
|
# 确保环境文件存在
|
||||||
ensure_env_file()
|
ensure_env_file()
|
||||||
|
|
||||||
@@ -131,9 +134,9 @@ async def graceful_shutdown():
|
|||||||
def check_eula():
|
def check_eula():
|
||||||
"""检查EULA和隐私条款确认状态 - 环境变量版(类似Minecraft)"""
|
"""检查EULA和隐私条款确认状态 - 环境变量版(类似Minecraft)"""
|
||||||
# 检查环境变量中的EULA确认
|
# 检查环境变量中的EULA确认
|
||||||
eula_confirmed = os.getenv('EULA_CONFIRMED', '').lower()
|
eula_confirmed = os.getenv("EULA_CONFIRMED", "").lower()
|
||||||
|
|
||||||
if eula_confirmed == 'true':
|
if eula_confirmed == "true":
|
||||||
logger.info("EULA已通过环境变量确认")
|
logger.info("EULA已通过环境变量确认")
|
||||||
return
|
return
|
||||||
|
|
||||||
@@ -149,8 +152,8 @@ def check_eula():
|
|||||||
try:
|
try:
|
||||||
load_dotenv(override=True) # 重新加载.env文件
|
load_dotenv(override=True) # 重新加载.env文件
|
||||||
|
|
||||||
eula_confirmed = os.getenv('EULA_CONFIRMED', '').lower()
|
eula_confirmed = os.getenv("EULA_CONFIRMED", "").lower()
|
||||||
if eula_confirmed == 'true':
|
if eula_confirmed == "true":
|
||||||
confirm_logger.info("EULA确认成功,感谢您的同意")
|
confirm_logger.info("EULA确认成功,感谢您的同意")
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|||||||
@@ -25,7 +25,7 @@
|
|||||||
{
|
{
|
||||||
"type": "action",
|
"type": "action",
|
||||||
"name": "set_emoji_like",
|
"name": "set_emoji_like",
|
||||||
"description": "为消息设置表情回应"
|
"description": "为某条已经存在的消息添加‘贴表情’回应(类似点赞),而不是发送新消息。当用户明确要求‘贴表情’时使用。"
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"features": [
|
"features": [
|
||||||
|
|||||||
@@ -45,7 +45,7 @@ class SetEmojiLikeAction(BaseAction):
|
|||||||
|
|
||||||
# === 基本信息(必须填写)===
|
# === 基本信息(必须填写)===
|
||||||
action_name = "set_emoji_like"
|
action_name = "set_emoji_like"
|
||||||
action_description = "为一个已存在的消息添加点赞或表情回应(也叫‘贴表情’)"
|
action_description = "为某条已经存在的消息添加‘贴表情’回应(类似点赞),而不是发送新消息。可以在觉得某条消息非常有趣、值得赞同或者需要特殊情感回应时主动使用。"
|
||||||
activation_type = ActionActivationType.ALWAYS # 消息接收时激活(?)
|
activation_type = ActionActivationType.ALWAYS # 消息接收时激活(?)
|
||||||
chat_type_allow = ChatType.GROUP
|
chat_type_allow = ChatType.GROUP
|
||||||
parallel_action = True
|
parallel_action = True
|
||||||
|
|||||||
@@ -20,16 +20,17 @@ files_to_update = [
|
|||||||
"src/mais4u/mais4u_chat/s4u_mood_manager.py",
|
"src/mais4u/mais4u_chat/s4u_mood_manager.py",
|
||||||
"src/plugin_system/core/tool_use.py",
|
"src/plugin_system/core/tool_use.py",
|
||||||
"src/chat/memory_system/memory_activator.py",
|
"src/chat/memory_system/memory_activator.py",
|
||||||
"src/chat/utils/smart_prompt.py"
|
"src/chat/utils/smart_prompt.py",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def update_prompt_imports(file_path):
|
def update_prompt_imports(file_path):
|
||||||
"""更新文件中的Prompt导入"""
|
"""更新文件中的Prompt导入"""
|
||||||
if not os.path.exists(file_path):
|
if not os.path.exists(file_path):
|
||||||
print(f"文件不存在: {file_path}")
|
print(f"文件不存在: {file_path}")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
with open(file_path, 'r', encoding='utf-8') as f:
|
with open(file_path, "r", encoding="utf-8") as f:
|
||||||
content = f.read()
|
content = f.read()
|
||||||
|
|
||||||
# 替换导入语句
|
# 替换导入语句
|
||||||
@@ -38,7 +39,7 @@ def update_prompt_imports(file_path):
|
|||||||
|
|
||||||
if old_import in content:
|
if old_import in content:
|
||||||
new_content = content.replace(old_import, new_import)
|
new_content = content.replace(old_import, new_import)
|
||||||
with open(file_path, 'w', encoding='utf-8') as f:
|
with open(file_path, "w", encoding="utf-8") as f:
|
||||||
f.write(new_content)
|
f.write(new_content)
|
||||||
print(f"已更新: {file_path}")
|
print(f"已更新: {file_path}")
|
||||||
return True
|
return True
|
||||||
@@ -46,6 +47,7 @@ def update_prompt_imports(file_path):
|
|||||||
print(f"无需更新: {file_path}")
|
print(f"无需更新: {file_path}")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
"""主函数"""
|
"""主函数"""
|
||||||
updated_count = 0
|
updated_count = 0
|
||||||
@@ -55,5 +57,6 @@ def main():
|
|||||||
|
|
||||||
print(f"\n更新完成!共更新了 {updated_count} 个文件")
|
print(f"\n更新完成!共更新了 {updated_count} 个文件")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
@@ -1,460 +0,0 @@
|
|||||||
import asyncio
|
|
||||||
import time
|
|
||||||
import traceback
|
|
||||||
import math
|
|
||||||
import random
|
|
||||||
from typing import Dict, Any, Tuple
|
|
||||||
|
|
||||||
from src.chat.utils.timer_calculator import Timer
|
|
||||||
from src.common.logger import get_logger
|
|
||||||
from src.config.config import global_config
|
|
||||||
from src.chat.planner_actions.planner import ActionPlanner
|
|
||||||
from src.chat.planner_actions.action_modifier import ActionModifier
|
|
||||||
from src.person_info.person_info import get_person_info_manager
|
|
||||||
from src.plugin_system.apis import database_api, generator_api
|
|
||||||
from src.plugin_system.base.component_types import ChatMode
|
|
||||||
from src.mais4u.constant_s4u import ENABLE_S4U
|
|
||||||
from src.chat.chat_loop.hfc_utils import send_typing, stop_typing
|
|
||||||
from .hfc_context import HfcContext
|
|
||||||
from .response_handler import ResponseHandler
|
|
||||||
from .cycle_tracker import CycleTracker
|
|
||||||
|
|
||||||
# 日志记录器
|
|
||||||
logger = get_logger("hfc.processor")
|
|
||||||
|
|
||||||
|
|
||||||
class CycleProcessor:
|
|
||||||
"""
|
|
||||||
循环处理器类,负责处理单次思考循环的逻辑。
|
|
||||||
"""
|
|
||||||
def __init__(self, context: HfcContext, response_handler: ResponseHandler, cycle_tracker: CycleTracker):
|
|
||||||
"""
|
|
||||||
初始化循环处理器
|
|
||||||
|
|
||||||
Args:
|
|
||||||
context: HFC聊天上下文对象,包含聊天流、能量值等信息
|
|
||||||
response_handler: 响应处理器,负责生成和发送回复
|
|
||||||
cycle_tracker: 循环跟踪器,负责记录和管理每次思考循环的信息
|
|
||||||
"""
|
|
||||||
self.context = context
|
|
||||||
self.response_handler = response_handler
|
|
||||||
self.cycle_tracker = cycle_tracker
|
|
||||||
self.action_planner = ActionPlanner(chat_id=self.context.stream_id, action_manager=self.context.action_manager)
|
|
||||||
self.action_modifier = ActionModifier(
|
|
||||||
action_manager=self.context.action_manager, chat_id=self.context.stream_id
|
|
||||||
)
|
|
||||||
|
|
||||||
self.log_prefix = self.context.log_prefix
|
|
||||||
|
|
||||||
async def _send_and_store_reply(
|
|
||||||
self,
|
|
||||||
response_set,
|
|
||||||
loop_start_time,
|
|
||||||
action_message,
|
|
||||||
cycle_timers: Dict[str, float],
|
|
||||||
thinking_id,
|
|
||||||
actions,
|
|
||||||
) -> Tuple[Dict[str, Any], str, Dict[str, float]]:
|
|
||||||
"""
|
|
||||||
发送并存储回复信息
|
|
||||||
|
|
||||||
Args:
|
|
||||||
response_set: 回复内容集合
|
|
||||||
loop_start_time: 循环开始时间
|
|
||||||
action_message: 动作消息
|
|
||||||
cycle_timers: 循环计时器
|
|
||||||
thinking_id: 思考ID
|
|
||||||
actions: 动作列表
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Tuple[Dict[str, Any], str, Dict[str, float]]: 循环信息, 回复文本, 循环计时器
|
|
||||||
"""
|
|
||||||
# 发送回复
|
|
||||||
with Timer("回复发送", cycle_timers):
|
|
||||||
reply_text = await self.response_handler.send_response(response_set, loop_start_time, action_message)
|
|
||||||
|
|
||||||
# 存储reply action信息
|
|
||||||
person_info_manager = get_person_info_manager()
|
|
||||||
|
|
||||||
# 获取 platform,如果不存在则从 chat_stream 获取,如果还是 None 则使用默认值
|
|
||||||
platform = action_message.get("chat_info_platform")
|
|
||||||
if platform is None:
|
|
||||||
platform = getattr(self.context.chat_stream, "platform", "unknown")
|
|
||||||
|
|
||||||
# 获取用户信息并生成回复提示
|
|
||||||
person_id = person_info_manager.get_person_id(
|
|
||||||
platform,
|
|
||||||
action_message.get("chat_info_user_id", ""),
|
|
||||||
)
|
|
||||||
person_info = await person_info_manager.get_values(person_id, ["person_name"])
|
|
||||||
person_name = person_info.get("person_name")
|
|
||||||
action_prompt_display = f"你对{person_name}进行了回复:{reply_text}"
|
|
||||||
|
|
||||||
# 存储动作信息到数据库
|
|
||||||
await database_api.store_action_info(
|
|
||||||
chat_stream=self.context.chat_stream,
|
|
||||||
action_build_into_prompt=False,
|
|
||||||
action_prompt_display=action_prompt_display,
|
|
||||||
action_done=True,
|
|
||||||
thinking_id=thinking_id,
|
|
||||||
action_data={"reply_text": reply_text},
|
|
||||||
action_name="reply",
|
|
||||||
)
|
|
||||||
|
|
||||||
# 构建循环信息
|
|
||||||
loop_info: Dict[str, Any] = {
|
|
||||||
"loop_plan_info": {
|
|
||||||
"action_result": actions,
|
|
||||||
},
|
|
||||||
"loop_action_info": {
|
|
||||||
"action_taken": True,
|
|
||||||
"reply_text": reply_text,
|
|
||||||
"command": "",
|
|
||||||
"taken_time": time.time(),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
return loop_info, reply_text, cycle_timers
|
|
||||||
|
|
||||||
async def observe(self, interest_value: float = 0.0) -> str:
|
|
||||||
"""
|
|
||||||
观察和处理单次思考循环的核心方法
|
|
||||||
|
|
||||||
Args:
|
|
||||||
interest_value: 兴趣值
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
str: 动作类型
|
|
||||||
|
|
||||||
功能说明:
|
|
||||||
- 开始新的思考循环并记录计时
|
|
||||||
- 修改可用动作并获取动作列表
|
|
||||||
- 根据聊天模式和提及情况决定是否跳过规划器
|
|
||||||
- 执行动作规划或直接回复
|
|
||||||
- 根据动作类型分发到相应的处理方法
|
|
||||||
"""
|
|
||||||
action_type = "no_action"
|
|
||||||
reply_text = "" # 初始化reply_text变量,避免UnboundLocalError
|
|
||||||
|
|
||||||
# 使用sigmoid函数将interest_value转换为概率
|
|
||||||
# 当interest_value为0时,概率接近0(使用Focus模式)
|
|
||||||
# 当interest_value很高时,概率接近1(使用Normal模式)
|
|
||||||
def calculate_normal_mode_probability(interest_val: float) -> float:
|
|
||||||
"""
|
|
||||||
计算普通模式的概率
|
|
||||||
|
|
||||||
Args:
|
|
||||||
interest_val: 兴趣值
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
float: 概率
|
|
||||||
"""
|
|
||||||
# 使用sigmoid函数,调整参数使概率分布更合理
|
|
||||||
# 当interest_value = 0时,概率约为0.1
|
|
||||||
# 当interest_value = 1时,概率约为0.5
|
|
||||||
# 当interest_value = 2时,概率约为0.8
|
|
||||||
# 当interest_value = 3时,概率约为0.95
|
|
||||||
k = 2.0 # 控制曲线陡峭程度
|
|
||||||
x0 = 1.0 # 控制曲线中心点
|
|
||||||
return 1.0 / (1.0 + math.exp(-k * (interest_val - x0)))
|
|
||||||
|
|
||||||
# 计算普通模式概率
|
|
||||||
normal_mode_probability = (
|
|
||||||
calculate_normal_mode_probability(interest_value)
|
|
||||||
* 0.5
|
|
||||||
/ global_config.chat.get_current_talk_frequency(self.context.stream_id)
|
|
||||||
)
|
|
||||||
|
|
||||||
# 根据概率决定使用哪种模式
|
|
||||||
if random.random() < normal_mode_probability:
|
|
||||||
mode = ChatMode.NORMAL
|
|
||||||
logger.info(
|
|
||||||
f"{self.log_prefix} 基于兴趣值 {interest_value:.2f},概率 {normal_mode_probability:.2f},选择Normal planner模式"
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
mode = ChatMode.FOCUS
|
|
||||||
logger.info(
|
|
||||||
f"{self.log_prefix} 基于兴趣值 {interest_value:.2f},概率 {normal_mode_probability:.2f},选择Focus planner模式"
|
|
||||||
)
|
|
||||||
|
|
||||||
# 开始新的思考循环
|
|
||||||
cycle_timers, thinking_id = self.cycle_tracker.start_cycle()
|
|
||||||
logger.info(f"{self.log_prefix} 开始第{self.context.cycle_counter}次思考")
|
|
||||||
|
|
||||||
if ENABLE_S4U and self.context.chat_stream and self.context.chat_stream.user_info:
|
|
||||||
await send_typing(self.context.chat_stream.user_info.user_id)
|
|
||||||
|
|
||||||
loop_start_time = time.time()
|
|
||||||
|
|
||||||
# 第一步:动作修改
|
|
||||||
with Timer("动作修改", cycle_timers):
|
|
||||||
try:
|
|
||||||
await self.action_modifier.modify_actions()
|
|
||||||
available_actions = self.context.action_manager.get_using_actions()
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"{self.log_prefix} 动作修改失败: {e}")
|
|
||||||
available_actions = {}
|
|
||||||
|
|
||||||
# 规划动作
|
|
||||||
from src.plugin_system.core.event_manager import event_manager
|
|
||||||
from src.plugin_system import EventType
|
|
||||||
|
|
||||||
result = await event_manager.trigger_event(
|
|
||||||
EventType.ON_PLAN, permission_group="SYSTEM", stream_id=self.context.chat_stream
|
|
||||||
)
|
|
||||||
if result and not result.all_continue_process():
|
|
||||||
raise UserWarning(f"插件{result.get_summary().get('stopped_handlers', '')}于规划前中断了内容生成")
|
|
||||||
with Timer("规划器", cycle_timers):
|
|
||||||
actions, _ = await self.action_planner.plan(mode=mode)
|
|
||||||
|
|
||||||
async def execute_action(action_info):
|
|
||||||
"""执行单个动作的通用函数"""
|
|
||||||
try:
|
|
||||||
if action_info["action_type"] == "no_action":
|
|
||||||
return {"action_type": "no_action", "success": True, "reply_text": "", "command": ""}
|
|
||||||
if action_info["action_type"] == "no_reply":
|
|
||||||
# 直接处理no_reply逻辑,不再通过动作系统
|
|
||||||
reason = action_info.get("reasoning", "选择不回复")
|
|
||||||
logger.info(f"{self.log_prefix} 选择不回复,原因: {reason}")
|
|
||||||
|
|
||||||
# 存储no_reply信息到数据库
|
|
||||||
await database_api.store_action_info(
|
|
||||||
chat_stream=self.context.chat_stream,
|
|
||||||
action_build_into_prompt=False,
|
|
||||||
action_prompt_display=reason,
|
|
||||||
action_done=True,
|
|
||||||
thinking_id=thinking_id,
|
|
||||||
action_data={"reason": reason},
|
|
||||||
action_name="no_reply",
|
|
||||||
)
|
|
||||||
|
|
||||||
return {"action_type": "no_reply", "success": True, "reply_text": "", "command": ""}
|
|
||||||
elif action_info["action_type"] != "reply" and action_info["action_type"] != "no_action":
|
|
||||||
# 记录并执行普通动作
|
|
||||||
reason = action_info.get("reasoning", f"执行动作 {action_info['action_type']}")
|
|
||||||
logger.info(f"{self.log_prefix} 决定执行动作 '{action_info['action_type']}',内心思考: {reason}")
|
|
||||||
with Timer("动作执行", cycle_timers):
|
|
||||||
success, reply_text, command = await self._handle_action(
|
|
||||||
action_info["action_type"],
|
|
||||||
reason, # 使用已获取的reason
|
|
||||||
action_info["action_data"],
|
|
||||||
cycle_timers,
|
|
||||||
thinking_id,
|
|
||||||
action_info["action_message"],
|
|
||||||
)
|
|
||||||
return {
|
|
||||||
"action_type": action_info["action_type"],
|
|
||||||
"success": success,
|
|
||||||
"reply_text": reply_text,
|
|
||||||
"command": command,
|
|
||||||
}
|
|
||||||
else:
|
|
||||||
# 生成回复
|
|
||||||
try:
|
|
||||||
reason = action_info.get("reasoning", "决定进行回复")
|
|
||||||
logger.info(f"{self.log_prefix} 决定进行回复,内心思考: {reason}")
|
|
||||||
success, response_set, _ = await generator_api.generate_reply(
|
|
||||||
chat_stream=self.context.chat_stream,
|
|
||||||
reply_message=action_info["action_message"],
|
|
||||||
available_actions=available_actions,
|
|
||||||
enable_tool=global_config.tool.enable_tool,
|
|
||||||
request_type="chat.replyer",
|
|
||||||
from_plugin=False,
|
|
||||||
read_mark=action_info.get("action_message", {}).get("time", 0.0),
|
|
||||||
)
|
|
||||||
if not success or not response_set:
|
|
||||||
logger.info(
|
|
||||||
f"对 {action_info['action_message'].get('processed_plain_text')} 的回复生成失败"
|
|
||||||
)
|
|
||||||
return {"action_type": "reply", "success": False, "reply_text": "", "loop_info": None}
|
|
||||||
except asyncio.CancelledError:
|
|
||||||
logger.debug(f"{self.log_prefix} 并行执行:回复生成任务已被取消")
|
|
||||||
return {"action_type": "reply", "success": False, "reply_text": "", "loop_info": None}
|
|
||||||
|
|
||||||
# 发送并存储回复
|
|
||||||
loop_info, reply_text, cycle_timers_reply = await self._send_and_store_reply(
|
|
||||||
response_set,
|
|
||||||
loop_start_time,
|
|
||||||
action_info["action_message"],
|
|
||||||
cycle_timers,
|
|
||||||
thinking_id,
|
|
||||||
actions,
|
|
||||||
)
|
|
||||||
return {"action_type": "reply", "success": True, "reply_text": reply_text, "loop_info": loop_info}
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"{self.log_prefix} 执行动作时出错: {e}")
|
|
||||||
logger.error(f"{self.log_prefix} 错误信息: {traceback.format_exc()}")
|
|
||||||
return {
|
|
||||||
"action_type": action_info["action_type"],
|
|
||||||
"success": False,
|
|
||||||
"reply_text": "",
|
|
||||||
"loop_info": None,
|
|
||||||
"error": str(e),
|
|
||||||
}
|
|
||||||
|
|
||||||
# 分离 reply 动作和其他动作
|
|
||||||
reply_actions = [a for a in actions if a.get("action_type") == "reply"]
|
|
||||||
other_actions = [a for a in actions if a.get("action_type") != "reply"]
|
|
||||||
|
|
||||||
reply_loop_info = None
|
|
||||||
reply_text_from_reply = ""
|
|
||||||
other_actions_results = []
|
|
||||||
|
|
||||||
# 1. 首先串行执行所有 reply 动作(通常只有一个)
|
|
||||||
if reply_actions:
|
|
||||||
logger.info(f"{self.log_prefix} 正在执行文本回复...")
|
|
||||||
for action in reply_actions:
|
|
||||||
action_message = action.get("action_message")
|
|
||||||
if not action_message:
|
|
||||||
logger.warning(f"{self.log_prefix} reply 动作缺少 action_message,跳过")
|
|
||||||
continue
|
|
||||||
|
|
||||||
# 检查是否是空的DatabaseMessages对象
|
|
||||||
if hasattr(action_message, 'chat_info') and hasattr(action_message.chat_info, 'user_info'):
|
|
||||||
target_user_id = action_message.chat_info.user_info.user_id
|
|
||||||
else:
|
|
||||||
# 如果是字典格式,使用原来的方式
|
|
||||||
target_user_id = action_message.get("chat_info_user_id", "")
|
|
||||||
|
|
||||||
if not target_user_id:
|
|
||||||
logger.warning(f"{self.log_prefix} reply 动作的 action_message 缺少用户ID,跳过")
|
|
||||||
continue
|
|
||||||
|
|
||||||
if target_user_id == global_config.bot.qq_account and not global_config.chat.allow_reply_self:
|
|
||||||
logger.warning("选取的reply的目标为bot自己,跳过reply action")
|
|
||||||
continue
|
|
||||||
result = await execute_action(action)
|
|
||||||
if isinstance(result, Exception):
|
|
||||||
logger.error(f"{self.log_prefix} 回复动作执行异常: {result}")
|
|
||||||
continue
|
|
||||||
if result.get("success"):
|
|
||||||
reply_loop_info = result.get("loop_info")
|
|
||||||
reply_text_from_reply = result.get("reply_text", "")
|
|
||||||
else:
|
|
||||||
logger.warning(f"{self.log_prefix} 回复动作执行失败")
|
|
||||||
|
|
||||||
# 2. 然后并行执行所有其他动作
|
|
||||||
if other_actions:
|
|
||||||
logger.info(f"{self.log_prefix} 正在执行附加动作: {[a.get('action_type') for a in other_actions]}")
|
|
||||||
other_action_tasks = [asyncio.create_task(execute_action(action)) for action in other_actions]
|
|
||||||
results = await asyncio.gather(*other_action_tasks, return_exceptions=True)
|
|
||||||
for i, result in enumerate(results):
|
|
||||||
if isinstance(result, BaseException):
|
|
||||||
logger.error(f"{self.log_prefix} 附加动作执行异常: {result}")
|
|
||||||
continue
|
|
||||||
other_actions_results.append(result)
|
|
||||||
|
|
||||||
# 构建最终的循环信息
|
|
||||||
if reply_loop_info:
|
|
||||||
loop_info = reply_loop_info
|
|
||||||
# 将其他动作的结果合并到loop_info中
|
|
||||||
if "other_actions" not in loop_info["loop_action_info"]:
|
|
||||||
loop_info["loop_action_info"]["other_actions"] = []
|
|
||||||
loop_info["loop_action_info"]["other_actions"].extend(other_actions_results)
|
|
||||||
reply_text = reply_text_from_reply
|
|
||||||
else:
|
|
||||||
# 没有回复信息,构建纯动作的loop_info
|
|
||||||
# 即使没有回复,也要正确处理其他动作
|
|
||||||
final_action_taken = any(res.get("success", False) for res in other_actions_results)
|
|
||||||
final_reply_text = " ".join(res.get("reply_text", "") for res in other_actions_results if res.get("reply_text"))
|
|
||||||
final_command = " ".join(res.get("command", "") for res in other_actions_results if res.get("command"))
|
|
||||||
|
|
||||||
loop_info = {
|
|
||||||
"loop_plan_info": {
|
|
||||||
"action_result": actions,
|
|
||||||
},
|
|
||||||
"loop_action_info": {
|
|
||||||
"action_taken": final_action_taken,
|
|
||||||
"reply_text": final_reply_text,
|
|
||||||
"command": final_command,
|
|
||||||
"taken_time": time.time(),
|
|
||||||
"other_actions": other_actions_results,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
reply_text = final_reply_text
|
|
||||||
|
|
||||||
# 停止正在输入状态
|
|
||||||
if ENABLE_S4U:
|
|
||||||
await stop_typing()
|
|
||||||
|
|
||||||
# 结束循环
|
|
||||||
self.context.chat_instance.cycle_tracker.end_cycle(loop_info, cycle_timers)
|
|
||||||
self.context.chat_instance.cycle_tracker.print_cycle_info(cycle_timers)
|
|
||||||
|
|
||||||
action_type = actions[0]["action_type"] if actions else "no_action"
|
|
||||||
return action_type
|
|
||||||
|
|
||||||
async def _handle_action(
|
|
||||||
self, action, reasoning, action_data, cycle_timers, thinking_id, action_message
|
|
||||||
) -> tuple[bool, str, str]:
|
|
||||||
"""
|
|
||||||
处理具体的动作执行
|
|
||||||
|
|
||||||
Args:
|
|
||||||
action: 动作名称
|
|
||||||
reasoning: 执行理由
|
|
||||||
action_data: 动作数据
|
|
||||||
cycle_timers: 循环计时器
|
|
||||||
thinking_id: 思考ID
|
|
||||||
action_message: 动作消息
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
tuple: (执行是否成功, 回复文本, 命令文本)
|
|
||||||
|
|
||||||
功能说明:
|
|
||||||
- 创建对应的动作处理器
|
|
||||||
- 执行动作并捕获异常
|
|
||||||
- 返回执行结果供上级方法整合
|
|
||||||
"""
|
|
||||||
if not self.context.chat_stream:
|
|
||||||
return False, "", ""
|
|
||||||
try:
|
|
||||||
# 创建动作处理器
|
|
||||||
action_handler = self.context.action_manager.create_action(
|
|
||||||
action_name=action,
|
|
||||||
action_data=action_data,
|
|
||||||
reasoning=reasoning,
|
|
||||||
cycle_timers=cycle_timers,
|
|
||||||
thinking_id=thinking_id,
|
|
||||||
chat_stream=self.context.chat_stream,
|
|
||||||
log_prefix=self.context.log_prefix,
|
|
||||||
action_message=action_message,
|
|
||||||
)
|
|
||||||
if not action_handler:
|
|
||||||
# 动作处理器创建失败,尝试回退机制
|
|
||||||
logger.warning(f"{self.context.log_prefix} 创建动作处理器失败: {action},尝试回退方案")
|
|
||||||
|
|
||||||
# 获取当前可用的动作
|
|
||||||
available_actions = self.context.action_manager.get_using_actions()
|
|
||||||
fallback_action = None
|
|
||||||
|
|
||||||
# 回退优先级:reply > 第一个可用动作
|
|
||||||
if "reply" in available_actions:
|
|
||||||
fallback_action = "reply"
|
|
||||||
elif available_actions:
|
|
||||||
fallback_action = list(available_actions.keys())[0]
|
|
||||||
|
|
||||||
if fallback_action and fallback_action != action:
|
|
||||||
logger.info(f"{self.context.log_prefix} 使用回退动作: {fallback_action}")
|
|
||||||
action_handler = self.context.action_manager.create_action(
|
|
||||||
action_name=fallback_action,
|
|
||||||
action_data=action_data,
|
|
||||||
reasoning=f"原动作'{action}'不可用,自动回退。{reasoning}",
|
|
||||||
cycle_timers=cycle_timers,
|
|
||||||
thinking_id=thinking_id,
|
|
||||||
chat_stream=self.context.chat_stream,
|
|
||||||
log_prefix=self.context.log_prefix,
|
|
||||||
action_message=action_message,
|
|
||||||
)
|
|
||||||
|
|
||||||
if not action_handler:
|
|
||||||
logger.error(f"{self.context.log_prefix} 回退方案也失败,无法创建任何动作处理器")
|
|
||||||
return False, "", ""
|
|
||||||
|
|
||||||
# 执行动作
|
|
||||||
success, reply_text = await action_handler.handle_action()
|
|
||||||
return success, reply_text, ""
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"{self.context.log_prefix} 处理{action}时出错: {e}")
|
|
||||||
traceback.print_exc()
|
|
||||||
return False, "", ""
|
|
||||||
@@ -1,114 +0,0 @@
|
|||||||
import time
|
|
||||||
from typing import Dict, Any, Tuple
|
|
||||||
|
|
||||||
from src.common.logger import get_logger
|
|
||||||
from src.chat.chat_loop.hfc_utils import CycleDetail
|
|
||||||
from .hfc_context import HfcContext
|
|
||||||
|
|
||||||
logger = get_logger("hfc")
|
|
||||||
|
|
||||||
|
|
||||||
class CycleTracker:
|
|
||||||
def __init__(self, context: HfcContext):
|
|
||||||
"""
|
|
||||||
初始化循环跟踪器
|
|
||||||
|
|
||||||
Args:
|
|
||||||
context: HFC聊天上下文对象
|
|
||||||
|
|
||||||
功能说明:
|
|
||||||
- 负责跟踪和记录每次思考循环的详细信息
|
|
||||||
- 管理循环的开始、结束和信息存储
|
|
||||||
"""
|
|
||||||
self.context = context
|
|
||||||
|
|
||||||
def start_cycle(self, is_proactive: bool = False) -> Tuple[Dict[str, float], str]:
|
|
||||||
"""
|
|
||||||
开始新的思考循环
|
|
||||||
|
|
||||||
Args:
|
|
||||||
is_proactive: 标记这个循环是否由主动思考发起
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
tuple: (循环计时器字典, 思考ID字符串)
|
|
||||||
|
|
||||||
功能说明:
|
|
||||||
- 增加循环计数器
|
|
||||||
- 创建新的循环详情对象
|
|
||||||
- 生成唯一的思考ID
|
|
||||||
- 初始化循环计时器
|
|
||||||
"""
|
|
||||||
if not is_proactive:
|
|
||||||
self.context.cycle_counter += 1
|
|
||||||
|
|
||||||
cycle_id = self.context.cycle_counter if not is_proactive else f"{self.context.cycle_counter}.p"
|
|
||||||
self.context.current_cycle_detail = CycleDetail(cycle_id)
|
|
||||||
self.context.current_cycle_detail.thinking_id = f"tid{str(round(time.time(), 2))}"
|
|
||||||
cycle_timers = {}
|
|
||||||
return cycle_timers, self.context.current_cycle_detail.thinking_id
|
|
||||||
|
|
||||||
def end_cycle(self, loop_info: Dict[str, Any], cycle_timers: Dict[str, float]):
|
|
||||||
"""
|
|
||||||
结束当前思考循环
|
|
||||||
|
|
||||||
Args:
|
|
||||||
loop_info: 循环信息,包含规划和动作信息
|
|
||||||
cycle_timers: 循环计时器,记录各阶段耗时
|
|
||||||
|
|
||||||
功能说明:
|
|
||||||
- 设置循环详情的完整信息
|
|
||||||
- 将当前循环加入历史记录
|
|
||||||
- 记录计时器和结束时间
|
|
||||||
- 打印循环统计信息
|
|
||||||
"""
|
|
||||||
if self.context.current_cycle_detail:
|
|
||||||
self.context.current_cycle_detail.set_loop_info(loop_info)
|
|
||||||
self.context.history_loop.append(self.context.current_cycle_detail)
|
|
||||||
self.context.current_cycle_detail.timers = cycle_timers
|
|
||||||
self.context.current_cycle_detail.end_time = time.time()
|
|
||||||
self.print_cycle_info(cycle_timers)
|
|
||||||
|
|
||||||
def print_cycle_info(self, cycle_timers: Dict[str, float]):
|
|
||||||
"""
|
|
||||||
打印循环统计信息
|
|
||||||
|
|
||||||
Args:
|
|
||||||
cycle_timers: 循环计时器字典
|
|
||||||
|
|
||||||
功能说明:
|
|
||||||
- 格式化各阶段的耗时信息
|
|
||||||
- 计算总体循环持续时间
|
|
||||||
- 输出详细的性能统计日志
|
|
||||||
- 显示选择的动作类型
|
|
||||||
"""
|
|
||||||
if not self.context.current_cycle_detail:
|
|
||||||
return
|
|
||||||
|
|
||||||
timer_strings = []
|
|
||||||
for name, elapsed in cycle_timers.items():
|
|
||||||
formatted_time = f"{elapsed * 1000:.2f}毫秒" if elapsed < 1 else f"{elapsed:.2f}秒"
|
|
||||||
timer_strings.append(f"{name}: {formatted_time}")
|
|
||||||
|
|
||||||
# 获取动作类型,兼容新旧格式
|
|
||||||
# 获取动作类型
|
|
||||||
action_type = "未知动作"
|
|
||||||
if self.context.current_cycle_detail:
|
|
||||||
loop_plan_info = self.context.current_cycle_detail.loop_plan_info
|
|
||||||
actions = loop_plan_info.get("action_result")
|
|
||||||
|
|
||||||
if isinstance(actions, list) and actions:
|
|
||||||
# 从actions列表中提取所有action_type
|
|
||||||
action_types = [a.get("action_type", "未知") for a in actions]
|
|
||||||
action_type = ", ".join(action_types)
|
|
||||||
elif isinstance(actions, dict):
|
|
||||||
# 兼容旧格式
|
|
||||||
action_type = actions.get("action_type", "未知动作")
|
|
||||||
|
|
||||||
|
|
||||||
if self.context.current_cycle_detail.end_time and self.context.current_cycle_detail.start_time:
|
|
||||||
duration = self.context.current_cycle_detail.end_time - self.context.current_cycle_detail.start_time
|
|
||||||
logger.info(
|
|
||||||
f"{self.context.log_prefix} 第{self.context.current_cycle_detail.cycle_id}次思考,"
|
|
||||||
f"耗时: {duration:.1f}秒, "
|
|
||||||
f"选择动作: {action_type}" + (f"\n详情: {'; '.join(timer_strings)}" if timer_strings else "")
|
|
||||||
)
|
|
||||||
@@ -1,162 +0,0 @@
|
|||||||
import asyncio
|
|
||||||
import time
|
|
||||||
from typing import Optional
|
|
||||||
from src.common.logger import get_logger
|
|
||||||
from src.config.config import global_config
|
|
||||||
from .hfc_context import HfcContext
|
|
||||||
from src.chat.chat_loop.sleep_manager import sleep_manager
|
|
||||||
logger = get_logger("hfc")
|
|
||||||
|
|
||||||
|
|
||||||
class EnergyManager:
|
|
||||||
def __init__(self, context: HfcContext):
|
|
||||||
"""
|
|
||||||
初始化能量管理器
|
|
||||||
|
|
||||||
Args:
|
|
||||||
context: HFC聊天上下文对象
|
|
||||||
|
|
||||||
功能说明:
|
|
||||||
- 管理聊天机器人的能量值系统
|
|
||||||
- 根据聊天模式自动调整能量消耗
|
|
||||||
- 控制能量值的衰减和记录
|
|
||||||
"""
|
|
||||||
self.context = context
|
|
||||||
self._energy_task: Optional[asyncio.Task] = None
|
|
||||||
self.last_energy_log_time = 0
|
|
||||||
self.energy_log_interval = 90
|
|
||||||
|
|
||||||
async def start(self):
|
|
||||||
"""
|
|
||||||
启动能量管理器
|
|
||||||
|
|
||||||
功能说明:
|
|
||||||
- 检查运行状态,避免重复启动
|
|
||||||
- 创建能量循环异步任务
|
|
||||||
- 设置任务完成回调
|
|
||||||
- 记录启动日志
|
|
||||||
"""
|
|
||||||
if self.context.running and not self._energy_task:
|
|
||||||
self._energy_task = asyncio.create_task(self._energy_loop())
|
|
||||||
self._energy_task.add_done_callback(self._handle_energy_completion)
|
|
||||||
logger.info(f"{self.context.log_prefix} 能量管理器已启动")
|
|
||||||
|
|
||||||
async def stop(self):
|
|
||||||
"""
|
|
||||||
停止能量管理器
|
|
||||||
|
|
||||||
功能说明:
|
|
||||||
- 取消正在运行的能量循环任务
|
|
||||||
- 等待任务完全停止
|
|
||||||
- 记录停止日志
|
|
||||||
"""
|
|
||||||
if self._energy_task and not self._energy_task.done():
|
|
||||||
self._energy_task.cancel()
|
|
||||||
await asyncio.sleep(0)
|
|
||||||
logger.info(f"{self.context.log_prefix} 能量管理器已停止")
|
|
||||||
|
|
||||||
async def _energy_loop(self):
|
|
||||||
"""
|
|
||||||
能量与睡眠压力管理的主循环
|
|
||||||
|
|
||||||
功能说明:
|
|
||||||
- 每10秒执行一次能量更新
|
|
||||||
- 根据群聊配置设置固定的聊天模式和能量值
|
|
||||||
- 在自动模式下根据聊天模式进行能量衰减
|
|
||||||
- NORMAL模式每次衰减0.3,FOCUS模式每次衰减0.6
|
|
||||||
- 确保能量值不低于0.3的最小值
|
|
||||||
"""
|
|
||||||
while self.context.running:
|
|
||||||
await asyncio.sleep(10)
|
|
||||||
|
|
||||||
if not self.context.chat_stream:
|
|
||||||
continue
|
|
||||||
|
|
||||||
# 判断当前是否为睡眠时间
|
|
||||||
is_sleeping = sleep_manager.SleepManager().is_sleeping()
|
|
||||||
|
|
||||||
if is_sleeping:
|
|
||||||
# 睡眠中:减少睡眠压力
|
|
||||||
decay_per_10s = global_config.sleep_system.sleep_pressure_decay_rate / 6
|
|
||||||
self.context.sleep_pressure -= decay_per_10s
|
|
||||||
self.context.sleep_pressure = max(self.context.sleep_pressure, 0)
|
|
||||||
self._log_sleep_pressure_change("睡眠压力释放")
|
|
||||||
self.context.save_context_state()
|
|
||||||
else:
|
|
||||||
# 清醒时:处理能量衰减
|
|
||||||
is_group_chat = self.context.chat_stream.group_info is not None
|
|
||||||
if is_group_chat:
|
|
||||||
self.context.energy_value = 25
|
|
||||||
|
|
||||||
await asyncio.sleep(12)
|
|
||||||
self.context.energy_value -= 0.5
|
|
||||||
self.context.energy_value = max(self.context.energy_value, 0.3)
|
|
||||||
|
|
||||||
self._log_energy_change("能量值衰减")
|
|
||||||
self.context.save_context_state()
|
|
||||||
|
|
||||||
def _should_log_energy(self) -> bool:
|
|
||||||
"""
|
|
||||||
判断是否应该记录能量变化日志
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
bool: 如果距离上次记录超过间隔时间则返回True
|
|
||||||
|
|
||||||
功能说明:
|
|
||||||
- 控制能量日志的记录频率,避免日志过于频繁
|
|
||||||
- 默认间隔90秒记录一次详细日志
|
|
||||||
- 其他时间使用调试级别日志
|
|
||||||
"""
|
|
||||||
current_time = time.time()
|
|
||||||
if current_time - self.last_energy_log_time >= self.energy_log_interval:
|
|
||||||
self.last_energy_log_time = current_time
|
|
||||||
return True
|
|
||||||
return False
|
|
||||||
|
|
||||||
def increase_sleep_pressure(self):
|
|
||||||
"""
|
|
||||||
在执行动作后增加睡眠压力
|
|
||||||
"""
|
|
||||||
increment = global_config.sleep_system.sleep_pressure_increment
|
|
||||||
self.context.sleep_pressure += increment
|
|
||||||
self.context.sleep_pressure = min(self.context.sleep_pressure, 100.0) # 设置一个100的上限
|
|
||||||
self._log_sleep_pressure_change("执行动作,睡眠压力累积")
|
|
||||||
self.context.save_context_state()
|
|
||||||
|
|
||||||
def _log_energy_change(self, action: str, reason: str = ""):
|
|
||||||
"""
|
|
||||||
记录能量变化日志
|
|
||||||
|
|
||||||
Args:
|
|
||||||
action: 能量变化的动作描述
|
|
||||||
reason: 可选的变化原因
|
|
||||||
|
|
||||||
功能说明:
|
|
||||||
- 根据时间间隔决定使用info还是debug级别的日志
|
|
||||||
- 格式化能量值显示(保留一位小数)
|
|
||||||
- 可选择性地包含变化原因
|
|
||||||
"""
|
|
||||||
if self._should_log_energy():
|
|
||||||
log_message = f"{self.context.log_prefix} {action},当前能量值:{self.context.energy_value:.1f}"
|
|
||||||
if reason:
|
|
||||||
log_message = (
|
|
||||||
f"{self.context.log_prefix} {action},{reason},当前能量值:{self.context.energy_value:.1f}"
|
|
||||||
)
|
|
||||||
logger.info(log_message)
|
|
||||||
else:
|
|
||||||
log_message = f"{self.context.log_prefix} {action},当前能量值:{self.context.energy_value:.1f}"
|
|
||||||
if reason:
|
|
||||||
log_message = (
|
|
||||||
f"{self.context.log_prefix} {action},{reason},当前能量值:{self.context.energy_value:.1f}"
|
|
||||||
)
|
|
||||||
logger.debug(log_message)
|
|
||||||
|
|
||||||
def _log_sleep_pressure_change(self, action: str):
|
|
||||||
"""
|
|
||||||
记录睡眠压力变化日志
|
|
||||||
"""
|
|
||||||
# 使用与能量日志相同的频率控制
|
|
||||||
if self._should_log_energy():
|
|
||||||
logger.info(f"{self.context.log_prefix} {action},当前睡眠压力:{self.context.sleep_pressure:.1f}")
|
|
||||||
else:
|
|
||||||
logger.debug(f"{self.context.log_prefix} {action},当前睡眠压力:{self.context.sleep_pressure:.1f}")
|
|
||||||
@@ -1,574 +0,0 @@
|
|||||||
import asyncio
|
|
||||||
import time
|
|
||||||
import traceback
|
|
||||||
import random
|
|
||||||
from typing import Optional, List, Dict, Any
|
|
||||||
from collections import deque
|
|
||||||
|
|
||||||
from src.common.logger import get_logger
|
|
||||||
from src.config.config import global_config
|
|
||||||
from src.person_info.relationship_builder_manager import relationship_builder_manager
|
|
||||||
from src.chat.express.expression_learner import expression_learner_manager
|
|
||||||
from src.chat.chat_loop.sleep_manager.sleep_manager import SleepManager, SleepState
|
|
||||||
|
|
||||||
from .hfc_context import HfcContext
|
|
||||||
from .energy_manager import EnergyManager
|
|
||||||
from .proactive.proactive_thinker import ProactiveThinker
|
|
||||||
from .cycle_processor import CycleProcessor
|
|
||||||
from .response_handler import ResponseHandler
|
|
||||||
from .cycle_tracker import CycleTracker
|
|
||||||
from .sleep_manager.wakeup_manager import WakeUpManager
|
|
||||||
from .proactive.events import ProactiveTriggerEvent
|
|
||||||
|
|
||||||
logger = get_logger("hfc")
|
|
||||||
|
|
||||||
|
|
||||||
class HeartFChatting:
|
|
||||||
def __init__(self, chat_id: str):
|
|
||||||
"""
|
|
||||||
初始化心跳聊天管理器
|
|
||||||
|
|
||||||
Args:
|
|
||||||
chat_id: 聊天ID标识符
|
|
||||||
|
|
||||||
功能说明:
|
|
||||||
- 创建聊天上下文和所有子管理器
|
|
||||||
- 初始化循环跟踪器、响应处理器、循环处理器等核心组件
|
|
||||||
- 设置能量管理器、主动思考器和普通模式处理器
|
|
||||||
- 初始化聊天模式并记录初始化完成日志
|
|
||||||
"""
|
|
||||||
self.context = HfcContext(chat_id)
|
|
||||||
self.context.new_message_queue = asyncio.Queue()
|
|
||||||
self._processing_lock = asyncio.Lock()
|
|
||||||
|
|
||||||
self.cycle_tracker = CycleTracker(self.context)
|
|
||||||
self.response_handler = ResponseHandler(self.context)
|
|
||||||
self.cycle_processor = CycleProcessor(self.context, self.response_handler, self.cycle_tracker)
|
|
||||||
self.energy_manager = EnergyManager(self.context)
|
|
||||||
self.proactive_thinker = ProactiveThinker(self.context, self.cycle_processor)
|
|
||||||
self.wakeup_manager = WakeUpManager(self.context)
|
|
||||||
self.sleep_manager = SleepManager()
|
|
||||||
|
|
||||||
# 将唤醒度管理器设置到上下文中
|
|
||||||
self.context.wakeup_manager = self.wakeup_manager
|
|
||||||
self.context.energy_manager = self.energy_manager
|
|
||||||
self.context.sleep_manager = self.sleep_manager
|
|
||||||
# 将HeartFChatting实例设置到上下文中,以便其他组件可以调用其方法
|
|
||||||
self.context.chat_instance = self
|
|
||||||
|
|
||||||
self._loop_task: Optional[asyncio.Task] = None
|
|
||||||
self._proactive_monitor_task: Optional[asyncio.Task] = None
|
|
||||||
|
|
||||||
# 记录最近3次的兴趣度
|
|
||||||
self.recent_interest_records: deque = deque(maxlen=3)
|
|
||||||
self._initialize_chat_mode()
|
|
||||||
logger.info(f"{self.context.log_prefix} HeartFChatting 初始化完成")
|
|
||||||
|
|
||||||
def _initialize_chat_mode(self):
|
|
||||||
"""
|
|
||||||
初始化聊天模式
|
|
||||||
|
|
||||||
功能说明:
|
|
||||||
- 检测是否为群聊环境
|
|
||||||
- 根据全局配置设置强制聊天模式
|
|
||||||
- 在focus模式下设置能量值为35
|
|
||||||
- 在normal模式下设置能量值为15
|
|
||||||
- 如果是auto模式则保持默认设置
|
|
||||||
"""
|
|
||||||
is_group_chat = self.context.chat_stream.group_info is not None if self.context.chat_stream else False
|
|
||||||
if is_group_chat and global_config.chat.group_chat_mode != "auto":
|
|
||||||
self.context.energy_value = 25
|
|
||||||
|
|
||||||
async def start(self):
|
|
||||||
"""
|
|
||||||
启动心跳聊天系统
|
|
||||||
|
|
||||||
功能说明:
|
|
||||||
- 检查是否已经在运行,避免重复启动
|
|
||||||
- 初始化关系构建器和表达学习器
|
|
||||||
- 启动能量管理器和主动思考器
|
|
||||||
- 创建主聊天循环任务并设置完成回调
|
|
||||||
- 记录启动完成日志
|
|
||||||
"""
|
|
||||||
if self.context.running:
|
|
||||||
return
|
|
||||||
self.context.running = True
|
|
||||||
|
|
||||||
self.context.relationship_builder = relationship_builder_manager.get_or_create_builder(self.context.stream_id)
|
|
||||||
self.context.expression_learner = await expression_learner_manager.get_expression_learner(self.context.stream_id)
|
|
||||||
|
|
||||||
# 启动主动思考监视器
|
|
||||||
if global_config.chat.enable_proactive_thinking:
|
|
||||||
self._proactive_monitor_task = asyncio.create_task(self._proactive_monitor_loop())
|
|
||||||
self._proactive_monitor_task.add_done_callback(self._handle_proactive_monitor_completion)
|
|
||||||
logger.info(f"{self.context.log_prefix} 主动思考监视器已启动")
|
|
||||||
|
|
||||||
await self.wakeup_manager.start()
|
|
||||||
|
|
||||||
self._loop_task = asyncio.create_task(self._main_chat_loop())
|
|
||||||
self._loop_task.add_done_callback(self._handle_loop_completion)
|
|
||||||
logger.info(f"{self.context.log_prefix} HeartFChatting 启动完成")
|
|
||||||
|
|
||||||
async def add_message(self, message: Dict[str, Any]):
|
|
||||||
"""从外部接收新消息并放入队列"""
|
|
||||||
await self.context.new_message_queue.put(message)
|
|
||||||
|
|
||||||
async def stop(self):
|
|
||||||
"""
|
|
||||||
停止心跳聊天系统
|
|
||||||
|
|
||||||
功能说明:
|
|
||||||
- 检查是否正在运行,避免重复停止
|
|
||||||
- 设置运行状态为False
|
|
||||||
- 停止能量管理器和主动思考器
|
|
||||||
- 取消主聊天循环任务
|
|
||||||
- 记录停止完成日志
|
|
||||||
"""
|
|
||||||
if not self.context.running:
|
|
||||||
return
|
|
||||||
self.context.running = False
|
|
||||||
|
|
||||||
# 停止主动思考监视器
|
|
||||||
if self._proactive_monitor_task and not self._proactive_monitor_task.done():
|
|
||||||
self._proactive_monitor_task.cancel()
|
|
||||||
await asyncio.sleep(0)
|
|
||||||
logger.info(f"{self.context.log_prefix} 主动思考监视器已停止")
|
|
||||||
|
|
||||||
await self.wakeup_manager.stop()
|
|
||||||
|
|
||||||
if self._loop_task and not self._loop_task.done():
|
|
||||||
self._loop_task.cancel()
|
|
||||||
await asyncio.sleep(0)
|
|
||||||
logger.info(f"{self.context.log_prefix} HeartFChatting 已停止")
|
|
||||||
|
|
||||||
def _handle_loop_completion(self, task: asyncio.Task):
|
|
||||||
"""
|
|
||||||
处理主循环任务完成
|
|
||||||
|
|
||||||
Args:
|
|
||||||
task: 完成的异步任务对象
|
|
||||||
|
|
||||||
功能说明:
|
|
||||||
- 处理任务异常完成的情况
|
|
||||||
- 区分正常停止和异常终止
|
|
||||||
- 记录相应的日志信息
|
|
||||||
- 处理取消任务的情况
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
if exception := task.exception():
|
|
||||||
logger.error(f"{self.context.log_prefix} HeartFChatting: 脱离了聊天(异常): {exception}")
|
|
||||||
logger.error(traceback.format_exc())
|
|
||||||
else:
|
|
||||||
logger.info(f"{self.context.log_prefix} HeartFChatting: 脱离了聊天 (外部停止)")
|
|
||||||
except asyncio.CancelledError:
|
|
||||||
logger.info(f"{self.context.log_prefix} HeartFChatting: 结束了聊天")
|
|
||||||
|
|
||||||
def _handle_proactive_monitor_completion(self, task: asyncio.Task):
|
|
||||||
"""
|
|
||||||
处理主动思考监视器任务完成
|
|
||||||
|
|
||||||
Args:
|
|
||||||
task: 完成的异步任务对象
|
|
||||||
|
|
||||||
功能说明:
|
|
||||||
- 处理任务异常完成的情况
|
|
||||||
- 记录任务正常结束或被取消的日志
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
if exception := task.exception():
|
|
||||||
logger.error(f"{self.context.log_prefix} 主动思考监视器异常: {exception}")
|
|
||||||
else:
|
|
||||||
logger.info(f"{self.context.log_prefix} 主动思考监视器正常结束")
|
|
||||||
except asyncio.CancelledError:
|
|
||||||
logger.info(f"{self.context.log_prefix} 主动思考监视器被取消")
|
|
||||||
|
|
||||||
async def _proactive_monitor_loop(self):
|
|
||||||
"""
|
|
||||||
主动思考监视器循环
|
|
||||||
|
|
||||||
功能说明:
|
|
||||||
- 定期检查是否需要进行主动思考
|
|
||||||
- 计算聊天沉默时间,并与动态思考间隔比较
|
|
||||||
- 当沉默时间超过阈值时,触发主动思考
|
|
||||||
- 处理思考过程中的异常
|
|
||||||
"""
|
|
||||||
while self.context.running:
|
|
||||||
await asyncio.sleep(15)
|
|
||||||
|
|
||||||
if not self._should_enable_proactive_thinking():
|
|
||||||
continue
|
|
||||||
|
|
||||||
current_time = time.time()
|
|
||||||
silence_duration = current_time - self.context.last_message_time
|
|
||||||
target_interval = self._get_dynamic_thinking_interval()
|
|
||||||
|
|
||||||
if silence_duration >= target_interval:
|
|
||||||
try:
|
|
||||||
formatted_time = self._format_duration(silence_duration)
|
|
||||||
event = ProactiveTriggerEvent(
|
|
||||||
source="silence_monitor",
|
|
||||||
reason=f"聊天已沉默 {formatted_time}",
|
|
||||||
metadata={"silence_duration": silence_duration},
|
|
||||||
)
|
|
||||||
await self.proactive_thinker.think(event)
|
|
||||||
self.context.last_message_time = current_time
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"{self.context.log_prefix} 主动思考触发执行出错: {e}")
|
|
||||||
logger.error(traceback.format_exc())
|
|
||||||
|
|
||||||
def _should_enable_proactive_thinking(self) -> bool:
|
|
||||||
"""
|
|
||||||
判断是否应启用主动思考
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
bool: 如果应启用主动思考则返回True,否则返回False
|
|
||||||
|
|
||||||
功能说明:
|
|
||||||
- 检查全局配置和特定聊天设置
|
|
||||||
- 支持按群聊和私聊分别配置
|
|
||||||
- 支持白名单模式,只在特定聊天中启用
|
|
||||||
"""
|
|
||||||
if not self.context.chat_stream:
|
|
||||||
return False
|
|
||||||
|
|
||||||
is_group_chat = self.context.chat_stream.group_info is not None
|
|
||||||
|
|
||||||
if is_group_chat and not global_config.chat.proactive_thinking_in_group:
|
|
||||||
return False
|
|
||||||
if not is_group_chat and not global_config.chat.proactive_thinking_in_private:
|
|
||||||
return False
|
|
||||||
|
|
||||||
stream_parts = self.context.stream_id.split(":")
|
|
||||||
current_chat_identifier = f"{stream_parts}:{stream_parts}" if len(stream_parts) >= 2 else self.context.stream_id
|
|
||||||
|
|
||||||
enable_list = getattr(
|
|
||||||
global_config.chat,
|
|
||||||
"proactive_thinking_enable_in_groups" if is_group_chat else "proactive_thinking_enable_in_private",
|
|
||||||
[],
|
|
||||||
)
|
|
||||||
return not enable_list or current_chat_identifier in enable_list
|
|
||||||
|
|
||||||
def _get_dynamic_thinking_interval(self) -> float:
|
|
||||||
"""
|
|
||||||
获取动态思考间隔时间
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
float: 思考间隔秒数
|
|
||||||
|
|
||||||
功能说明:
|
|
||||||
- 尝试从timing_utils导入正态分布间隔函数
|
|
||||||
- 根据配置计算动态间隔,增加随机性
|
|
||||||
- 在无法导入或计算出错时,回退到固定的间隔
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
from src.utils.timing_utils import get_normal_distributed_interval
|
|
||||||
|
|
||||||
base_interval = global_config.chat.proactive_thinking_interval
|
|
||||||
delta_sigma = getattr(global_config.chat, "delta_sigma", 120)
|
|
||||||
|
|
||||||
if base_interval <= 0:
|
|
||||||
base_interval = abs(base_interval)
|
|
||||||
if delta_sigma < 0:
|
|
||||||
delta_sigma = abs(delta_sigma)
|
|
||||||
|
|
||||||
if base_interval == 0 and delta_sigma == 0:
|
|
||||||
return 300
|
|
||||||
if delta_sigma == 0:
|
|
||||||
return base_interval
|
|
||||||
|
|
||||||
sigma_percentage = delta_sigma / base_interval if base_interval > 0 else delta_sigma / 1000
|
|
||||||
return get_normal_distributed_interval(base_interval, sigma_percentage, 1, 86400, use_3sigma_rule=True)
|
|
||||||
|
|
||||||
except ImportError:
|
|
||||||
logger.warning(f"{self.context.log_prefix} timing_utils不可用,使用固定间隔")
|
|
||||||
return max(300, abs(global_config.chat.proactive_thinking_interval))
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"{self.context.log_prefix} 动态间隔计算出错: {e},使用固定间隔")
|
|
||||||
return max(300, abs(global_config.chat.proactive_thinking_interval))
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _format_duration(seconds: float) -> str:
|
|
||||||
"""
|
|
||||||
格式化时长为可读字符串
|
|
||||||
|
|
||||||
Args:
|
|
||||||
seconds: 时长秒数
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
str: 格式化后的字符串 (例如 "1小时2分3秒")
|
|
||||||
"""
|
|
||||||
hours = int(seconds // 3600)
|
|
||||||
minutes = int((seconds % 3600) // 60)
|
|
||||||
secs = int(seconds % 60)
|
|
||||||
parts = []
|
|
||||||
if hours > 0:
|
|
||||||
parts.append(f"{hours}小时")
|
|
||||||
if minutes > 0:
|
|
||||||
parts.append(f"{minutes}分")
|
|
||||||
if secs > 0 or not parts:
|
|
||||||
parts.append(f"{secs}秒")
|
|
||||||
return "".join(parts)
|
|
||||||
|
|
||||||
async def _main_chat_loop(self):
|
|
||||||
"""
|
|
||||||
主聊天循环
|
|
||||||
|
|
||||||
功能说明:
|
|
||||||
- 持续运行聊天处理循环
|
|
||||||
- 只有在有新消息时才进行思考循环
|
|
||||||
- 无新消息时等待新消息到达(由主动思考系统单独处理主动发言)
|
|
||||||
- 处理取消和异常情况
|
|
||||||
- 在异常时尝试重新启动循环
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
while self.context.running:
|
|
||||||
has_new_messages = await self._loop_body()
|
|
||||||
|
|
||||||
if has_new_messages:
|
|
||||||
# 有新消息时,继续快速检查是否还有更多消息
|
|
||||||
await asyncio.sleep(1)
|
|
||||||
else:
|
|
||||||
# 无新消息时,等待较长时间再检查
|
|
||||||
# 这里只是为了定期检查系统状态,不进行思考循环
|
|
||||||
# 真正的新消息响应依赖于消息到达时的通知
|
|
||||||
await asyncio.sleep(1.0)
|
|
||||||
|
|
||||||
except asyncio.CancelledError:
|
|
||||||
logger.info(f"{self.context.log_prefix} 麦麦已关闭聊天")
|
|
||||||
except Exception:
|
|
||||||
logger.error(f"{self.context.log_prefix} 麦麦聊天意外错误,将于3s后尝试重新启动")
|
|
||||||
print(traceback.format_exc())
|
|
||||||
await asyncio.sleep(3)
|
|
||||||
self._loop_task = asyncio.create_task(self._main_chat_loop())
|
|
||||||
logger.error(f"{self.context.log_prefix} 结束了当前聊天循环")
|
|
||||||
|
|
||||||
async def _loop_body(self) -> bool:
|
|
||||||
"""
|
|
||||||
单次循环体处理
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
bool: 是否处理了新消息
|
|
||||||
|
|
||||||
功能说明:
|
|
||||||
- 检查是否处于睡眠模式,如果是则处理唤醒度逻辑
|
|
||||||
- 获取最近的新消息(过滤机器人自己的消息和命令)
|
|
||||||
- 只有在有新消息时才进行思考循环处理
|
|
||||||
- 更新最后消息时间和读取时间
|
|
||||||
- 根据当前聊天模式执行不同的处理逻辑
|
|
||||||
- FOCUS模式:直接处理所有消息并检查退出条件
|
|
||||||
- NORMAL模式:检查进入FOCUS模式的条件,并通过normal_mode_handler处理消息
|
|
||||||
"""
|
|
||||||
async with self._processing_lock:
|
|
||||||
# --- 核心状态更新 ---
|
|
||||||
await self.sleep_manager.update_sleep_state(self.wakeup_manager)
|
|
||||||
current_sleep_state = self.sleep_manager.get_current_sleep_state()
|
|
||||||
is_sleeping = current_sleep_state == SleepState.SLEEPING
|
|
||||||
is_in_insomnia = current_sleep_state == SleepState.INSOMNIA
|
|
||||||
|
|
||||||
# 核心修复:在睡眠模式(包括失眠)下获取消息时,不过滤命令消息,以确保@消息能被接收
|
|
||||||
filter_command_flag = not (is_sleeping or is_in_insomnia)
|
|
||||||
|
|
||||||
# 从队列中获取所有待处理的新消息
|
|
||||||
recent_messages = []
|
|
||||||
while not self.context.new_message_queue.empty():
|
|
||||||
recent_messages.append(await self.context.new_message_queue.get())
|
|
||||||
|
|
||||||
has_new_messages = bool(recent_messages)
|
|
||||||
new_message_count = len(recent_messages)
|
|
||||||
|
|
||||||
# 只有在有新消息时才进行思考循环处理
|
|
||||||
if has_new_messages:
|
|
||||||
self.context.last_message_time = time.time()
|
|
||||||
self.context.last_read_time = time.time()
|
|
||||||
|
|
||||||
# --- 专注模式安静群组检查 ---
|
|
||||||
quiet_groups = global_config.chat.focus_mode_quiet_groups
|
|
||||||
if quiet_groups and self.context.chat_stream:
|
|
||||||
is_group_chat = self.context.chat_stream.group_info is not None
|
|
||||||
if is_group_chat:
|
|
||||||
try:
|
|
||||||
platform = self.context.chat_stream.platform
|
|
||||||
group_id = self.context.chat_stream.group_info.group_id
|
|
||||||
|
|
||||||
# 兼容不同QQ适配器的平台名称
|
|
||||||
is_qq_platform = platform in ["qq", "napcat"]
|
|
||||||
|
|
||||||
current_chat_identifier = f"{platform}:{group_id}"
|
|
||||||
config_identifier_for_qq = f"qq:{group_id}"
|
|
||||||
|
|
||||||
is_in_quiet_list = (current_chat_identifier in quiet_groups or
|
|
||||||
(is_qq_platform and config_identifier_for_qq in quiet_groups))
|
|
||||||
|
|
||||||
if is_in_quiet_list:
|
|
||||||
is_mentioned_in_batch = False
|
|
||||||
for msg in recent_messages:
|
|
||||||
if msg.get("is_mentioned"):
|
|
||||||
is_mentioned_in_batch = True
|
|
||||||
break
|
|
||||||
|
|
||||||
if not is_mentioned_in_batch:
|
|
||||||
logger.info(f"{self.context.log_prefix} 在专注安静模式下,因未被提及而忽略了消息。")
|
|
||||||
return True # 消耗消息但不做回复
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"{self.context.log_prefix} 检查专注安静群组时出错: {e}")
|
|
||||||
|
|
||||||
# 处理唤醒度逻辑
|
|
||||||
if current_sleep_state in [SleepState.SLEEPING, SleepState.PREPARING_SLEEP, SleepState.INSOMNIA]:
|
|
||||||
self._handle_wakeup_messages(recent_messages)
|
|
||||||
|
|
||||||
# 再次获取最新状态,因为 handle_wakeup 可能导致状态变为 WOKEN_UP
|
|
||||||
current_sleep_state = self.sleep_manager.get_current_sleep_state()
|
|
||||||
|
|
||||||
if current_sleep_state == SleepState.SLEEPING:
|
|
||||||
# 只有在纯粹的 SLEEPING 状态下才跳过消息处理
|
|
||||||
return True
|
|
||||||
|
|
||||||
if current_sleep_state == SleepState.WOKEN_UP:
|
|
||||||
logger.info(f"{self.context.log_prefix} 从睡眠中被唤醒,将处理积压的消息。")
|
|
||||||
|
|
||||||
# 根据聊天模式处理新消息
|
|
||||||
should_process, interest_value = await self._should_process_messages(recent_messages)
|
|
||||||
if not should_process:
|
|
||||||
# 消息数量不足或兴趣不够,等待
|
|
||||||
await asyncio.sleep(0.5)
|
|
||||||
return True # Skip rest of the logic for this iteration
|
|
||||||
|
|
||||||
# Messages should be processed
|
|
||||||
action_type = await self.cycle_processor.observe(interest_value=interest_value)
|
|
||||||
|
|
||||||
# 尝试触发表达学习
|
|
||||||
if self.context.expression_learner:
|
|
||||||
try:
|
|
||||||
await self.context.expression_learner.trigger_learning_for_chat()
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"{self.context.log_prefix} 表达学习触发失败: {e}")
|
|
||||||
|
|
||||||
# 管理no_reply计数器
|
|
||||||
if action_type != "no_reply":
|
|
||||||
self.recent_interest_records.clear()
|
|
||||||
self.context.no_reply_consecutive = 0
|
|
||||||
logger.debug(f"{self.context.log_prefix} 执行了{action_type}动作,重置no_reply计数器")
|
|
||||||
else: # action_type == "no_reply"
|
|
||||||
self.context.no_reply_consecutive += 1
|
|
||||||
self._determine_form_type()
|
|
||||||
|
|
||||||
# 在一轮动作执行完毕后,增加睡眠压力
|
|
||||||
if self.context.energy_manager and global_config.sleep_system.enable_insomnia_system:
|
|
||||||
if action_type not in ["no_reply", "no_action"]:
|
|
||||||
self.context.energy_manager.increase_sleep_pressure()
|
|
||||||
|
|
||||||
# 如果成功观察,增加能量值并重置累积兴趣值
|
|
||||||
self.context.energy_value += 1 / global_config.chat.focus_value
|
|
||||||
# 重置累积兴趣值,因为消息已经被成功处理
|
|
||||||
self.context.breaking_accumulated_interest = 0.0
|
|
||||||
logger.info(
|
|
||||||
f"{self.context.log_prefix} 能量值增加,当前能量值:{self.context.energy_value:.1f},重置累积兴趣值"
|
|
||||||
)
|
|
||||||
|
|
||||||
# 更新上一帧的睡眠状态
|
|
||||||
self.context.was_sleeping = is_sleeping
|
|
||||||
|
|
||||||
# --- 重新入睡逻辑 ---
|
|
||||||
# 如果被吵醒了,并且在一定时间内没有新消息,则尝试重新入睡
|
|
||||||
if self.sleep_manager.get_current_sleep_state() == SleepState.WOKEN_UP and not has_new_messages:
|
|
||||||
re_sleep_delay = global_config.sleep_system.re_sleep_delay_minutes * 60
|
|
||||||
# 使用 last_message_time 来判断空闲时间
|
|
||||||
if time.time() - self.context.last_message_time > re_sleep_delay:
|
|
||||||
logger.info(
|
|
||||||
f"{self.context.log_prefix} 已被唤醒且超过 {re_sleep_delay / 60} 分钟无新消息,尝试重新入睡。"
|
|
||||||
)
|
|
||||||
self.sleep_manager.reset_sleep_state_after_wakeup()
|
|
||||||
|
|
||||||
# 保存HFC上下文状态
|
|
||||||
self.context.save_context_state()
|
|
||||||
return has_new_messages
|
|
||||||
|
|
||||||
def _handle_wakeup_messages(self, messages):
|
|
||||||
"""
|
|
||||||
处理休眠状态下的消息,累积唤醒度
|
|
||||||
|
|
||||||
Args:
|
|
||||||
messages: 消息列表
|
|
||||||
|
|
||||||
功能说明:
|
|
||||||
- 区分私聊和群聊消息
|
|
||||||
- 检查群聊消息是否艾特了机器人
|
|
||||||
- 调用唤醒度管理器累积唤醒度
|
|
||||||
- 如果达到阈值则唤醒并进入愤怒状态
|
|
||||||
"""
|
|
||||||
if not self.wakeup_manager:
|
|
||||||
return
|
|
||||||
|
|
||||||
is_private_chat = self.context.chat_stream.group_info is None if self.context.chat_stream else False
|
|
||||||
|
|
||||||
for message in messages:
|
|
||||||
is_mentioned = False
|
|
||||||
|
|
||||||
# 检查群聊消息是否艾特了机器人
|
|
||||||
if not is_private_chat:
|
|
||||||
# 最终修复:直接使用消息对象中由上游处理好的 is_mention 字段。
|
|
||||||
# 该字段在 message.py 的 MessageRecv._process_single_segment 中被设置。
|
|
||||||
if message.get("is_mentioned"):
|
|
||||||
is_mentioned = True
|
|
||||||
|
|
||||||
# 累积唤醒度
|
|
||||||
woke_up = self.wakeup_manager.add_wakeup_value(is_private_chat, is_mentioned)
|
|
||||||
|
|
||||||
if woke_up:
|
|
||||||
logger.info(f"{self.context.log_prefix} 被消息吵醒,进入愤怒状态!")
|
|
||||||
break
|
|
||||||
|
|
||||||
def _determine_form_type(self) -> str:
|
|
||||||
"""判断使用哪种形式的no_reply"""
|
|
||||||
# 检查是否启用breaking模式
|
|
||||||
if not getattr(global_config.chat, "enable_breaking_mode", False):
|
|
||||||
logger.info(f"{self.context.log_prefix} breaking模式已禁用,使用waiting形式")
|
|
||||||
self.context.focus_energy = 1
|
|
||||||
return "waiting"
|
|
||||||
|
|
||||||
# 如果连续no_reply次数少于3次,使用waiting形式
|
|
||||||
if self.context.no_reply_consecutive <= 3:
|
|
||||||
self.context.focus_energy = 1
|
|
||||||
return "waiting"
|
|
||||||
else:
|
|
||||||
# 使用累积兴趣值而不是最近3次的记录
|
|
||||||
total_interest = self.context.breaking_accumulated_interest
|
|
||||||
|
|
||||||
# 计算调整后的阈值
|
|
||||||
adjusted_threshold = 1 / global_config.chat.get_current_talk_frequency(self.context.stream_id)
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
f"{self.context.log_prefix} 累积兴趣值: {total_interest:.2f}, 调整后阈值: {adjusted_threshold:.2f}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# 如果累积兴趣值小于阈值,进入breaking形式
|
|
||||||
if total_interest < adjusted_threshold:
|
|
||||||
logger.info(f"{self.context.log_prefix} 累积兴趣度不足,进入breaking形式")
|
|
||||||
self.context.focus_energy = random.randint(3, 6)
|
|
||||||
return "breaking"
|
|
||||||
else:
|
|
||||||
logger.info(f"{self.context.log_prefix} 累积兴趣度充足,使用waiting形式")
|
|
||||||
self.context.focus_energy = 1
|
|
||||||
return "waiting"
|
|
||||||
|
|
||||||
async def _should_process_messages(self, new_message: List[Dict[str, Any]]) -> tuple[bool, float]:
|
|
||||||
"""
|
|
||||||
统一判断是否应该处理消息的函数
|
|
||||||
根据当前循环模式和消息内容决定是否继续处理
|
|
||||||
"""
|
|
||||||
if not new_message:
|
|
||||||
return False, 0.0
|
|
||||||
|
|
||||||
# 计算平均兴趣值
|
|
||||||
total_interest = 0.0
|
|
||||||
message_count = 0
|
|
||||||
for msg_dict in new_message:
|
|
||||||
interest_value = msg_dict.get("interest_value", 0.0)
|
|
||||||
if msg_dict.get("processed_plain_text", ""):
|
|
||||||
total_interest += interest_value
|
|
||||||
message_count += 1
|
|
||||||
|
|
||||||
avg_interest = total_interest / message_count if message_count > 0 else 0.0
|
|
||||||
|
|
||||||
logger.info(f"{self.context.log_prefix} 收到 {len(new_message)} 条新消息,立即处理!平均兴趣值: {avg_interest:.2f}")
|
|
||||||
return True, avg_interest
|
|
||||||
@@ -1,82 +0,0 @@
|
|||||||
import time
|
|
||||||
from typing import List, Optional, TYPE_CHECKING
|
|
||||||
|
|
||||||
from src.chat.chat_loop.hfc_utils import CycleDetail
|
|
||||||
from src.chat.express.expression_learner import ExpressionLearner
|
|
||||||
from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager
|
|
||||||
from src.chat.planner_actions.action_manager import ActionManager
|
|
||||||
from src.config.config import global_config
|
|
||||||
from src.person_info.relationship_builder_manager import RelationshipBuilder
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class HfcContext:
|
|
||||||
def __init__(self, chat_id: str):
|
|
||||||
"""
|
|
||||||
初始化HFC聊天上下文
|
|
||||||
|
|
||||||
Args:
|
|
||||||
chat_id: 聊天ID标识符
|
|
||||||
|
|
||||||
功能说明:
|
|
||||||
- 存储和管理单个聊天会话的所有状态信息
|
|
||||||
- 包含聊天流、关系构建器、表达学习器等核心组件
|
|
||||||
- 管理聊天模式、能量值、时间戳等关键状态
|
|
||||||
- 提供循环历史记录和当前循环详情的存储
|
|
||||||
- 集成唤醒度管理器,处理休眠状态下的唤醒机制
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValueError: 如果找不到对应的聊天流
|
|
||||||
"""
|
|
||||||
self.stream_id: str = chat_id
|
|
||||||
self.chat_stream: Optional[ChatStream] = get_chat_manager().get_stream(self.stream_id)
|
|
||||||
if not self.chat_stream:
|
|
||||||
raise ValueError(f"无法找到聊天流: {self.stream_id}")
|
|
||||||
|
|
||||||
self.log_prefix = f"[{get_chat_manager().get_stream_name(self.stream_id) or self.stream_id}]"
|
|
||||||
|
|
||||||
self.relationship_builder: Optional[RelationshipBuilder] = None
|
|
||||||
self.expression_learner: Optional[ExpressionLearner] = None
|
|
||||||
|
|
||||||
self.energy_value = self.chat_stream.energy_value
|
|
||||||
self.sleep_pressure = self.chat_stream.sleep_pressure
|
|
||||||
self.was_sleeping = False # 用于检测睡眠状态的切换
|
|
||||||
|
|
||||||
self.last_message_time = time.time()
|
|
||||||
self.last_read_time = time.time() - 10
|
|
||||||
|
|
||||||
# 从聊天流恢复breaking累积兴趣值
|
|
||||||
self.breaking_accumulated_interest = getattr(self.chat_stream, "breaking_accumulated_interest", 0.0)
|
|
||||||
|
|
||||||
self.action_manager = ActionManager()
|
|
||||||
|
|
||||||
self.running: bool = False
|
|
||||||
|
|
||||||
self.history_loop: List[CycleDetail] = []
|
|
||||||
self.cycle_counter = 0
|
|
||||||
self.current_cycle_detail: Optional[CycleDetail] = None
|
|
||||||
|
|
||||||
# 唤醒度管理器 - 延迟初始化以避免循环导入
|
|
||||||
self.wakeup_manager: Optional["WakeUpManager"] = None
|
|
||||||
self.energy_manager: Optional["EnergyManager"] = None
|
|
||||||
self.sleep_manager: Optional["SleepManager"] = None
|
|
||||||
|
|
||||||
# 从聊天流获取focus_energy,如果没有则使用配置文件中的值
|
|
||||||
self.focus_energy = getattr(self.chat_stream, "focus_energy", global_config.chat.focus_value)
|
|
||||||
self.no_reply_consecutive = 0
|
|
||||||
self.total_interest = 0.0
|
|
||||||
# breaking形式下的累积兴趣值
|
|
||||||
self.breaking_accumulated_interest = 0.0
|
|
||||||
# 引用HeartFChatting实例,以便其他组件可以调用其方法
|
|
||||||
self.chat_instance: "HeartFChatting"
|
|
||||||
|
|
||||||
def save_context_state(self):
|
|
||||||
"""将当前状态保存到聊天流"""
|
|
||||||
if self.chat_stream:
|
|
||||||
self.chat_stream.energy_value = self.energy_value
|
|
||||||
self.chat_stream.sleep_pressure = self.sleep_pressure
|
|
||||||
self.chat_stream.focus_energy = self.focus_energy
|
|
||||||
self.chat_stream.no_reply_consecutive = self.no_reply_consecutive
|
|
||||||
self.chat_stream.breaking_accumulated_interest = self.breaking_accumulated_interest
|
|
||||||
@@ -1,172 +0,0 @@
|
|||||||
import time
|
|
||||||
from typing import Optional, Dict, Any, Union
|
|
||||||
|
|
||||||
from src.common.logger import get_logger
|
|
||||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
|
||||||
from src.plugin_system.apis import send_api
|
|
||||||
from maim_message.message_base import GroupInfo
|
|
||||||
|
|
||||||
|
|
||||||
logger = get_logger("hfc")
|
|
||||||
|
|
||||||
|
|
||||||
class CycleDetail:
|
|
||||||
"""
|
|
||||||
循环信息记录类
|
|
||||||
|
|
||||||
功能说明:
|
|
||||||
- 记录单次思考循环的详细信息
|
|
||||||
- 包含循环ID、思考ID、时间戳等基本信息
|
|
||||||
- 存储循环的规划信息和动作信息
|
|
||||||
- 提供序列化和转换功能
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, cycle_id: Union[int, str]):
|
|
||||||
"""
|
|
||||||
初始化循环详情记录
|
|
||||||
|
|
||||||
Args:
|
|
||||||
cycle_id: 循环ID,用于标识循环的顺序
|
|
||||||
|
|
||||||
功能说明:
|
|
||||||
- 设置循环基本标识信息
|
|
||||||
- 初始化时间戳和计时器
|
|
||||||
- 准备循环信息存储容器
|
|
||||||
"""
|
|
||||||
self.cycle_id = cycle_id
|
|
||||||
self.thinking_id = ""
|
|
||||||
self.start_time = time.time()
|
|
||||||
self.end_time: Optional[float] = None
|
|
||||||
self.timers: Dict[str, float] = {}
|
|
||||||
|
|
||||||
self.loop_plan_info: Dict[str, Any] = {}
|
|
||||||
self.loop_action_info: Dict[str, Any] = {}
|
|
||||||
|
|
||||||
def to_dict(self) -> Dict[str, Any]:
|
|
||||||
"""
|
|
||||||
将循环信息转换为字典格式
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
dict: 包含所有循环信息的字典,已处理循环引用和序列化问题
|
|
||||||
|
|
||||||
功能说明:
|
|
||||||
- 递归转换复杂对象为可序列化格式
|
|
||||||
- 防止循环引用导致的无限递归
|
|
||||||
- 限制递归深度避免栈溢出
|
|
||||||
- 只保留基本数据类型和可序列化的值
|
|
||||||
"""
|
|
||||||
|
|
||||||
def convert_to_serializable(obj, depth=0, seen=None):
|
|
||||||
if seen is None:
|
|
||||||
seen = set()
|
|
||||||
|
|
||||||
# 防止递归过深
|
|
||||||
if depth > 5: # 降低递归深度限制
|
|
||||||
return str(obj)
|
|
||||||
|
|
||||||
# 防止循环引用
|
|
||||||
obj_id = id(obj)
|
|
||||||
if obj_id in seen:
|
|
||||||
return str(obj)
|
|
||||||
seen.add(obj_id)
|
|
||||||
|
|
||||||
try:
|
|
||||||
if hasattr(obj, "to_dict"):
|
|
||||||
# 对于有to_dict方法的对象,直接调用其to_dict方法
|
|
||||||
return obj.to_dict()
|
|
||||||
elif isinstance(obj, dict):
|
|
||||||
# 对于字典,只保留基本类型和可序列化的值
|
|
||||||
return {
|
|
||||||
k: convert_to_serializable(v, depth + 1, seen)
|
|
||||||
for k, v in obj.items()
|
|
||||||
if isinstance(k, (str, int, float, bool))
|
|
||||||
}
|
|
||||||
elif isinstance(obj, (list, tuple)):
|
|
||||||
# 对于列表和元组,只保留可序列化的元素
|
|
||||||
return [
|
|
||||||
convert_to_serializable(item, depth + 1, seen)
|
|
||||||
for item in obj
|
|
||||||
if not isinstance(item, (dict, list, tuple))
|
|
||||||
or isinstance(item, (str, int, float, bool, type(None)))
|
|
||||||
]
|
|
||||||
elif isinstance(obj, (str, int, float, bool, type(None))):
|
|
||||||
return obj
|
|
||||||
else:
|
|
||||||
return str(obj)
|
|
||||||
finally:
|
|
||||||
seen.remove(obj_id)
|
|
||||||
|
|
||||||
return {
|
|
||||||
"cycle_id": self.cycle_id,
|
|
||||||
"start_time": self.start_time,
|
|
||||||
"end_time": self.end_time,
|
|
||||||
"timers": self.timers,
|
|
||||||
"thinking_id": self.thinking_id,
|
|
||||||
"loop_plan_info": convert_to_serializable(self.loop_plan_info),
|
|
||||||
"loop_action_info": convert_to_serializable(self.loop_action_info),
|
|
||||||
}
|
|
||||||
|
|
||||||
def set_loop_info(self, loop_info: Dict[str, Any]):
|
|
||||||
"""
|
|
||||||
设置循环信息
|
|
||||||
|
|
||||||
Args:
|
|
||||||
loop_info: 包含循环规划和动作信息的字典
|
|
||||||
|
|
||||||
功能说明:
|
|
||||||
- 从传入的循环信息中提取规划和动作信息
|
|
||||||
- 更新当前循环详情的相关字段
|
|
||||||
"""
|
|
||||||
self.loop_plan_info = loop_info["loop_plan_info"]
|
|
||||||
self.loop_action_info = loop_info["loop_action_info"]
|
|
||||||
|
|
||||||
|
|
||||||
async def send_typing(user_id):
|
|
||||||
"""
|
|
||||||
发送打字状态指示
|
|
||||||
|
|
||||||
功能说明:
|
|
||||||
- 创建内心聊天流(用于状态显示)
|
|
||||||
- 发送typing状态消息
|
|
||||||
- 不存储到消息记录中
|
|
||||||
- 用于S4U功能的视觉反馈
|
|
||||||
"""
|
|
||||||
group_info = GroupInfo(platform="amaidesu_default", group_id="114514", group_name="内心")
|
|
||||||
|
|
||||||
chat = await get_chat_manager().get_or_create_stream(
|
|
||||||
platform="amaidesu_default",
|
|
||||||
user_info=None,
|
|
||||||
group_info=group_info,
|
|
||||||
)
|
|
||||||
|
|
||||||
from plugin_system.core.event_manager import event_manager
|
|
||||||
from src.plugins.built_in.napcat_adapter_plugin.event_types import NapcatEvent
|
|
||||||
# 设置正在输入状态
|
|
||||||
await event_manager.trigger_event(NapcatEvent.PERSONAL.SET_INPUT_STATUS,user_id=user_id,event_type=1)
|
|
||||||
|
|
||||||
await send_api.custom_to_stream(
|
|
||||||
message_type="state", content="typing", stream_id=chat.stream_id, storage_message=False
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
async def stop_typing():
|
|
||||||
"""
|
|
||||||
停止打字状态指示
|
|
||||||
|
|
||||||
功能说明:
|
|
||||||
- 创建内心聊天流(用于状态显示)
|
|
||||||
- 发送stop_typing状态消息
|
|
||||||
- 不存储到消息记录中
|
|
||||||
- 结束S4U功能的视觉反馈
|
|
||||||
"""
|
|
||||||
group_info = GroupInfo(platform="amaidesu_default", group_id="114514", group_name="内心")
|
|
||||||
|
|
||||||
chat = await get_chat_manager().get_or_create_stream(
|
|
||||||
platform="amaidesu_default",
|
|
||||||
user_info=None,
|
|
||||||
group_info=group_info,
|
|
||||||
)
|
|
||||||
|
|
||||||
await send_api.custom_to_stream(
|
|
||||||
message_type="state", content="stop_typing", stream_id=chat.stream_id, storage_message=False
|
|
||||||
)
|
|
||||||
@@ -1,14 +0,0 @@
|
|||||||
from dataclasses import dataclass, field
|
|
||||||
from typing import Optional, Dict, Any
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class ProactiveTriggerEvent:
|
|
||||||
"""
|
|
||||||
主动思考触发事件的数据类
|
|
||||||
"""
|
|
||||||
|
|
||||||
source: str # 触发源的标识,例如 "silence_monitor", "insomnia_manager"
|
|
||||||
reason: str # 触发的具体原因,例如 "聊天已沉默10分钟", "深夜emo"
|
|
||||||
metadata: Optional[Dict[str, Any]] = field(default_factory=dict) # 可选的元数据,用于传递额外信息
|
|
||||||
related_message_id: Optional[str] = None # 关联的消息ID,用于加载上下文
|
|
||||||
@@ -1,264 +0,0 @@
|
|||||||
import time
|
|
||||||
import traceback
|
|
||||||
from typing import TYPE_CHECKING, Dict, Any
|
|
||||||
|
|
||||||
from src.chat.utils.chat_message_builder import get_raw_msg_before_timestamp_with_chat, build_readable_messages_with_id
|
|
||||||
from src.common.database.sqlalchemy_database_api import store_action_info
|
|
||||||
from src.common.logger import get_logger
|
|
||||||
from src.config.config import global_config
|
|
||||||
from src.mood.mood_manager import mood_manager
|
|
||||||
from src.plugin_system import tool_api
|
|
||||||
from src.plugin_system.apis import generator_api
|
|
||||||
from src.plugin_system.apis.generator_api import process_human_text
|
|
||||||
from src.plugin_system.base.component_types import ChatMode
|
|
||||||
from src.schedule.schedule_manager import schedule_manager
|
|
||||||
from .events import ProactiveTriggerEvent
|
|
||||||
from ..hfc_context import HfcContext
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from ..cycle_processor import CycleProcessor
|
|
||||||
|
|
||||||
logger = get_logger("hfc")
|
|
||||||
|
|
||||||
|
|
||||||
class ProactiveThinker:
|
|
||||||
"""
|
|
||||||
主动思考器,负责处理和执行主动思考事件。
|
|
||||||
当接收到 ProactiveTriggerEvent 时,它会根据事件内容进行一系列决策和操作,
|
|
||||||
例如调整情绪、调用规划器生成行动,并最终可能产生一个主动的回复。
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, context: HfcContext, cycle_processor: "CycleProcessor"):
|
|
||||||
"""
|
|
||||||
初始化主动思考器。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
context (HfcContext): HFC聊天上下文对象,提供了当前聊天会话的所有背景信息。
|
|
||||||
cycle_processor (CycleProcessor): 循环处理器,用于执行主动思考后产生的动作。
|
|
||||||
|
|
||||||
功能说明:
|
|
||||||
- 接收并处理主动思考事件 (ProactiveTriggerEvent)。
|
|
||||||
- 在思考前根据事件类型执行预处理操作,如修改当前情绪状态。
|
|
||||||
- 调用行动规划器 (Action Planner) 来决定下一步应该做什么。
|
|
||||||
- 如果规划结果是发送消息,则调用生成器API生成回复并发送。
|
|
||||||
"""
|
|
||||||
self.context = context
|
|
||||||
self.cycle_processor = cycle_processor
|
|
||||||
|
|
||||||
async def think(self, trigger_event: ProactiveTriggerEvent):
|
|
||||||
"""
|
|
||||||
主动思考的统一入口API。
|
|
||||||
这是外部触发主动思考时调用的主要方法。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
trigger_event (ProactiveTriggerEvent): 描述触发上下文的事件对象,包含了思考的来源和原因。
|
|
||||||
"""
|
|
||||||
logger.info(
|
|
||||||
f"{self.context.log_prefix} 接收到主动思考事件: "
|
|
||||||
f"来源='{trigger_event.source}', 原因='{trigger_event.reason}'"
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
|
||||||
# 步骤 1: 根据事件类型执行思考前的准备工作,例如调整情绪。
|
|
||||||
await self._prepare_for_thinking(trigger_event)
|
|
||||||
|
|
||||||
# 步骤 2: 执行核心的思考和决策逻辑。
|
|
||||||
await self._execute_proactive_thinking(trigger_event)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
# 捕获并记录在思考过程中发生的任何异常。
|
|
||||||
logger.error(f"{self.context.log_prefix} 主动思考 think 方法执行异常: {e}")
|
|
||||||
logger.error(traceback.format_exc())
|
|
||||||
|
|
||||||
async def _prepare_for_thinking(self, trigger_event: ProactiveTriggerEvent):
|
|
||||||
"""
|
|
||||||
根据事件类型,在正式思考前执行准备工作。
|
|
||||||
目前主要是处理来自失眠管理器的事件,并据此调整情绪。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
trigger_event (ProactiveTriggerEvent): 触发事件。
|
|
||||||
"""
|
|
||||||
# 目前只处理来自失眠管理器(insomnia_manager)的事件
|
|
||||||
if trigger_event.source != "insomnia_manager":
|
|
||||||
return
|
|
||||||
|
|
||||||
try:
|
|
||||||
# 获取当前聊天的情绪对象
|
|
||||||
mood_obj = mood_manager.get_mood_by_chat_id(self.context.stream_id)
|
|
||||||
new_mood = None
|
|
||||||
|
|
||||||
# 根据失眠的不同原因设置对应的情绪
|
|
||||||
if trigger_event.reason == "low_pressure":
|
|
||||||
new_mood = "精力过剩,毫无睡意"
|
|
||||||
elif trigger_event.reason == "random":
|
|
||||||
new_mood = "深夜emo,胡思乱想"
|
|
||||||
elif trigger_event.reason == "goodnight":
|
|
||||||
new_mood = "有点困了,准备睡觉了"
|
|
||||||
|
|
||||||
# 如果成功匹配到了新的情绪,则更新情绪状态
|
|
||||||
if new_mood:
|
|
||||||
mood_obj.mood_state = new_mood
|
|
||||||
mood_obj.last_change_time = time.time()
|
|
||||||
logger.info(
|
|
||||||
f"{self.context.log_prefix} 因 '{trigger_event.reason}',"
|
|
||||||
f"情绪状态被强制更新为: {mood_obj.mood_state}"
|
|
||||||
)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"{self.context.log_prefix} 设置失眠情绪时出错: {e}")
|
|
||||||
|
|
||||||
async def _execute_proactive_thinking(self, trigger_event: ProactiveTriggerEvent):
|
|
||||||
"""
|
|
||||||
执行主动思考的核心逻辑。
|
|
||||||
它会调用规划器来决定是否要采取行动,以及采取什么行动。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
trigger_event (ProactiveTriggerEvent): 触发事件。
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
actions, _ = await self.cycle_processor.action_planner.plan(mode=ChatMode.PROACTIVE)
|
|
||||||
action_result = actions[0] if actions else {}
|
|
||||||
action_type = action_result.get("action_type")
|
|
||||||
|
|
||||||
if action_type is None:
|
|
||||||
logger.info(f"{self.context.log_prefix} 主动思考决策: 规划器未返回有效动作")
|
|
||||||
return
|
|
||||||
|
|
||||||
if action_type == "proactive_reply":
|
|
||||||
await self._generate_proactive_content_and_send(action_result, trigger_event)
|
|
||||||
elif action_type not in ["do_nothing", "no_action"]:
|
|
||||||
await self.cycle_processor._handle_action(
|
|
||||||
action=action_result["action_type"],
|
|
||||||
reasoning=action_result.get("reasoning", ""),
|
|
||||||
action_data=action_result.get("action_data", {}),
|
|
||||||
cycle_timers={},
|
|
||||||
thinking_id="",
|
|
||||||
action_message=action_result.get("action_message")
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
logger.info(f"{self.context.log_prefix} 主动思考决策: 保持沉默")
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"{self.context.log_prefix} 主动思考执行异常: {e}")
|
|
||||||
logger.error(traceback.format_exc())
|
|
||||||
|
|
||||||
|
|
||||||
async def _generate_proactive_content_and_send(self, action_result: Dict[str, Any], trigger_event: ProactiveTriggerEvent):
|
|
||||||
"""
|
|
||||||
获取实时信息,构建最终的生成提示词,并生成和发送主动回复。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
action_result (Dict[str, Any]): 规划器返回的动作结果。
|
|
||||||
trigger_event (ProactiveTriggerEvent): 触发事件。
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
topic = action_result.get("action_data", {}).get("topic", "随便聊聊")
|
|
||||||
logger.info(f"{self.context.log_prefix} 主动思考确定主题: '{topic}'")
|
|
||||||
|
|
||||||
schedule_block = "你今天没有日程安排。"
|
|
||||||
if global_config.planning_system.schedule_enable:
|
|
||||||
if current_activity := schedule_manager.get_current_activity():
|
|
||||||
schedule_block = f"你当前正在:{current_activity}。"
|
|
||||||
|
|
||||||
news_block = "暂时没有获取到最新资讯。"
|
|
||||||
if trigger_event.source != "reminder_system":
|
|
||||||
try:
|
|
||||||
web_search_tool = tool_api.get_tool_instance("web_search")
|
|
||||||
if web_search_tool:
|
|
||||||
try:
|
|
||||||
search_result_dict = await web_search_tool.execute(function_args={"keyword": topic, "max_results": 10})
|
|
||||||
except TypeError:
|
|
||||||
try:
|
|
||||||
search_result_dict = await web_search_tool.execute(function_args={"keyword": topic, "max_results": 10})
|
|
||||||
except TypeError:
|
|
||||||
logger.warning(f"{self.context.log_prefix} 网络搜索工具参数不匹配,跳过搜索")
|
|
||||||
news_block = "跳过网络搜索。"
|
|
||||||
search_result_dict = None
|
|
||||||
|
|
||||||
if search_result_dict and not search_result_dict.get("error"):
|
|
||||||
news_block = search_result_dict.get("content", "未能提取有效资讯。")
|
|
||||||
elif search_result_dict:
|
|
||||||
logger.warning(f"{self.context.log_prefix} 网络搜索返回错误: {search_result_dict.get('error')}")
|
|
||||||
else:
|
|
||||||
logger.warning(f"{self.context.log_prefix} 未找到 web_search 工具实例。")
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"{self.context.log_prefix} 主动思考时网络搜索失败: {e}")
|
|
||||||
message_list = get_raw_msg_before_timestamp_with_chat(
|
|
||||||
chat_id=self.context.stream_id,
|
|
||||||
timestamp=time.time(),
|
|
||||||
limit=int(global_config.chat.max_context_size * 0.3),
|
|
||||||
)
|
|
||||||
chat_context_block, _ = await build_readable_messages_with_id(messages=message_list)
|
|
||||||
bot_name = global_config.bot.nickname
|
|
||||||
personality = global_config.personality
|
|
||||||
identity_block = (
|
|
||||||
f"你的名字是{bot_name}。\n"
|
|
||||||
f"关于你:{personality.personality_core},并且{personality.personality_side}。\n"
|
|
||||||
f"你的身份是{personality.identity},平时说话风格是{personality.reply_style}。"
|
|
||||||
)
|
|
||||||
mood_block = f"你现在的心情是:{mood_manager.get_mood_by_chat_id(self.context.stream_id).mood_state}"
|
|
||||||
|
|
||||||
final_prompt = f"""
|
|
||||||
## 你的角色
|
|
||||||
{identity_block}
|
|
||||||
|
|
||||||
## 你的心情
|
|
||||||
{mood_block}
|
|
||||||
|
|
||||||
## 你今天的日程安排
|
|
||||||
{schedule_block}
|
|
||||||
|
|
||||||
## 关于你准备讨论的话题"{topic}"的最新信息
|
|
||||||
{news_block}
|
|
||||||
|
|
||||||
## 最近的聊天内容
|
|
||||||
{chat_context_block}
|
|
||||||
|
|
||||||
## 任务
|
|
||||||
你现在想要主动说些什么。话题是"{topic}",但这只是一个参考方向。
|
|
||||||
|
|
||||||
根据最近的聊天内容,你可以:
|
|
||||||
- 如果是想关心朋友,就自然地询问他们的情况
|
|
||||||
- 如果想起了之前的话题,就问问后来怎么样了
|
|
||||||
- 如果有什么想分享的想法,就自然地开启话题
|
|
||||||
- 如果只是想闲聊,就随意地说些什么
|
|
||||||
|
|
||||||
**重要**:如果获取到了最新的网络信息(news_block不为空),请**自然地**将这些信息融入你的回复中,作为话题的补充或引子,而不是生硬地复述。
|
|
||||||
|
|
||||||
## 要求
|
|
||||||
- 像真正的朋友一样,自然地表达关心或好奇
|
|
||||||
- 不要过于正式,要口语化和亲切
|
|
||||||
- 结合你的角色设定,保持温暖的风格
|
|
||||||
- 直接输出你想说的话,不要解释为什么要说
|
|
||||||
|
|
||||||
请输出一条简短、自然的主动发言。
|
|
||||||
"""
|
|
||||||
|
|
||||||
response_text = await generator_api.generate_response_custom(
|
|
||||||
chat_stream=self.context.chat_stream,
|
|
||||||
prompt=final_prompt,
|
|
||||||
request_type="chat.replyer.proactive",
|
|
||||||
)
|
|
||||||
|
|
||||||
if response_text:
|
|
||||||
response_set = process_human_text(
|
|
||||||
content=response_text,
|
|
||||||
enable_splitter=global_config.response_splitter.enable,
|
|
||||||
enable_chinese_typo=global_config.chinese_typo.enable,
|
|
||||||
)
|
|
||||||
await self.cycle_processor.response_handler.send_response(
|
|
||||||
response_set, time.time(), action_result.get("action_message")
|
|
||||||
)
|
|
||||||
await store_action_info(
|
|
||||||
chat_stream=self.context.chat_stream,
|
|
||||||
action_name="proactive_reply",
|
|
||||||
action_data={"topic": topic, "response": response_text},
|
|
||||||
action_prompt_display=f"主动发起对话: {topic}",
|
|
||||||
action_done=True,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
logger.error(f"{self.context.log_prefix} 主动思考生成回复失败。")
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"{self.context.log_prefix} 生成主动回复内容时异常: {e}")
|
|
||||||
logger.error(traceback.format_exc())
|
|
||||||
@@ -1,184 +0,0 @@
|
|||||||
import time
|
|
||||||
import random
|
|
||||||
from typing import Dict, Any, Tuple
|
|
||||||
|
|
||||||
from src.common.logger import get_logger
|
|
||||||
from src.plugin_system.apis import send_api, message_api, database_api
|
|
||||||
from src.person_info.person_info import get_person_info_manager
|
|
||||||
from .hfc_context import HfcContext
|
|
||||||
|
|
||||||
# 导入反注入系统
|
|
||||||
|
|
||||||
# 日志记录器
|
|
||||||
logger = get_logger("hfc")
|
|
||||||
anti_injector_logger = get_logger("anti_injector")
|
|
||||||
|
|
||||||
|
|
||||||
class ResponseHandler:
|
|
||||||
"""
|
|
||||||
响应处理器类,负责生成和发送机器人的回复。
|
|
||||||
"""
|
|
||||||
def __init__(self, context: HfcContext):
|
|
||||||
"""
|
|
||||||
初始化响应处理器
|
|
||||||
|
|
||||||
Args:
|
|
||||||
context: HFC聊天上下文对象
|
|
||||||
|
|
||||||
功能说明:
|
|
||||||
- 负责生成和发送机器人的回复
|
|
||||||
- 处理回复的格式化和发送逻辑
|
|
||||||
- 管理回复状态和日志记录
|
|
||||||
"""
|
|
||||||
self.context = context
|
|
||||||
|
|
||||||
async def generate_and_send_reply(
|
|
||||||
self,
|
|
||||||
response_set,
|
|
||||||
reply_to_str,
|
|
||||||
loop_start_time,
|
|
||||||
action_message,
|
|
||||||
cycle_timers: Dict[str, float],
|
|
||||||
thinking_id,
|
|
||||||
plan_result,
|
|
||||||
) -> Tuple[Dict[str, Any], str, Dict[str, float]]:
|
|
||||||
"""
|
|
||||||
生成并发送回复的主方法
|
|
||||||
|
|
||||||
Args:
|
|
||||||
response_set: 生成的回复内容集合
|
|
||||||
reply_to_str: 回复目标字符串
|
|
||||||
loop_start_time: 循环开始时间
|
|
||||||
action_message: 动作消息数据
|
|
||||||
cycle_timers: 循环计时器
|
|
||||||
thinking_id: 思考ID
|
|
||||||
plan_result: 规划结果
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
tuple: (循环信息, 回复文本, 计时器信息)
|
|
||||||
|
|
||||||
功能说明:
|
|
||||||
- 发送生成的回复内容
|
|
||||||
- 存储动作信息到数据库
|
|
||||||
- 构建并返回完整的循环信息
|
|
||||||
- 用于上级方法的状态跟踪
|
|
||||||
"""
|
|
||||||
reply_text = await self.send_response(response_set, loop_start_time, action_message)
|
|
||||||
|
|
||||||
person_info_manager = get_person_info_manager()
|
|
||||||
|
|
||||||
# 获取平台信息
|
|
||||||
platform = "default"
|
|
||||||
if self.context.chat_stream:
|
|
||||||
platform = (
|
|
||||||
action_message.get("chat_info_platform")
|
|
||||||
or action_message.get("user_platform")
|
|
||||||
or self.context.chat_stream.platform
|
|
||||||
)
|
|
||||||
|
|
||||||
# 获取用户信息并生成回复提示
|
|
||||||
user_id = action_message.get("user_id", "")
|
|
||||||
person_id = person_info_manager.get_person_id(platform, user_id)
|
|
||||||
person_name = await person_info_manager.get_value(person_id, "person_name")
|
|
||||||
action_prompt_display = f"你对{person_name}进行了回复:{reply_text}"
|
|
||||||
|
|
||||||
# 存储动作信息到数据库
|
|
||||||
await database_api.store_action_info(
|
|
||||||
chat_stream=self.context.chat_stream,
|
|
||||||
action_build_into_prompt=False,
|
|
||||||
action_prompt_display=action_prompt_display,
|
|
||||||
action_done=True,
|
|
||||||
thinking_id=thinking_id,
|
|
||||||
action_data={"reply_text": reply_text, "reply_to": reply_to_str},
|
|
||||||
action_name="reply",
|
|
||||||
)
|
|
||||||
|
|
||||||
# 构建循环信息
|
|
||||||
loop_info: Dict[str, Any] = {
|
|
||||||
"loop_plan_info": {
|
|
||||||
"action_result": plan_result.get("action_result", {}),
|
|
||||||
},
|
|
||||||
"loop_action_info": {
|
|
||||||
"action_taken": True,
|
|
||||||
"reply_text": reply_text,
|
|
||||||
"command": "",
|
|
||||||
"taken_time": time.time(),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
return loop_info, reply_text, cycle_timers
|
|
||||||
|
|
||||||
async def send_response(self, reply_set, thinking_start_time, message_data) -> str:
|
|
||||||
"""
|
|
||||||
发送回复内容的具体实现
|
|
||||||
|
|
||||||
Args:
|
|
||||||
reply_set: 回复内容集合,包含多个回复段
|
|
||||||
reply_to: 回复目标
|
|
||||||
thinking_start_time: 思考开始时间
|
|
||||||
message_data: 消息数据
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
str: 完整的回复文本
|
|
||||||
|
|
||||||
功能说明:
|
|
||||||
- 检查是否有新消息需要回复
|
|
||||||
- 处理主动思考的"沉默"决定
|
|
||||||
- 根据消息数量决定是否添加回复引用
|
|
||||||
- 逐段发送回复内容,支持打字效果
|
|
||||||
- 正确处理元组格式的回复段
|
|
||||||
"""
|
|
||||||
current_time = time.time()
|
|
||||||
# 计算新消息数量
|
|
||||||
new_message_count = await message_api.count_new_messages(
|
|
||||||
chat_id=self.context.stream_id, start_time=thinking_start_time, end_time=current_time
|
|
||||||
)
|
|
||||||
|
|
||||||
# 根据新消息数量决定是否需要引用回复
|
|
||||||
need_reply = new_message_count >= random.randint(2, 4)
|
|
||||||
|
|
||||||
reply_text = ""
|
|
||||||
is_proactive_thinking = (message_data.get("message_type") == "proactive_thinking") if message_data else True
|
|
||||||
|
|
||||||
first_replied = False
|
|
||||||
for reply_seg in reply_set:
|
|
||||||
# 调试日志:验证reply_seg的格式
|
|
||||||
logger.debug(f"Processing reply_seg type: {type(reply_seg)}, content: {reply_seg}")
|
|
||||||
|
|
||||||
# 修正:正确处理元组格式 (格式为: (type, content))
|
|
||||||
if isinstance(reply_seg, tuple) and len(reply_seg) >= 2:
|
|
||||||
_, data = reply_seg
|
|
||||||
else:
|
|
||||||
# 向下兼容:如果已经是字符串,则直接使用
|
|
||||||
data = str(reply_seg)
|
|
||||||
|
|
||||||
if isinstance(data, list):
|
|
||||||
data = "".join(map(str, data))
|
|
||||||
reply_text += data
|
|
||||||
|
|
||||||
# 如果是主动思考且内容为“沉默”,则不发送
|
|
||||||
if is_proactive_thinking and data.strip() == "沉默":
|
|
||||||
logger.info(f"{self.context.log_prefix} 主动思考决定保持沉默,不发送消息")
|
|
||||||
continue
|
|
||||||
|
|
||||||
# 发送第一段回复
|
|
||||||
if not first_replied:
|
|
||||||
await send_api.text_to_stream(
|
|
||||||
text=data,
|
|
||||||
stream_id=self.context.stream_id,
|
|
||||||
reply_to_message=message_data,
|
|
||||||
set_reply=need_reply,
|
|
||||||
typing=False,
|
|
||||||
)
|
|
||||||
first_replied = True
|
|
||||||
else:
|
|
||||||
# 发送后续回复
|
|
||||||
sent_message = await send_api.text_to_stream(
|
|
||||||
text=data,
|
|
||||||
stream_id=self.context.stream_id,
|
|
||||||
reply_to_message=None,
|
|
||||||
set_reply=False,
|
|
||||||
typing=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
return reply_text
|
|
||||||
@@ -1,32 +0,0 @@
|
|||||||
from src.common.logger import get_logger
|
|
||||||
from ..hfc_context import HfcContext
|
|
||||||
|
|
||||||
logger = get_logger("notification_sender")
|
|
||||||
|
|
||||||
|
|
||||||
class NotificationSender:
|
|
||||||
@staticmethod
|
|
||||||
async def send_goodnight_notification(context: HfcContext):
|
|
||||||
"""发送晚安通知"""
|
|
||||||
try:
|
|
||||||
from ..proactive.events import ProactiveTriggerEvent
|
|
||||||
from ..proactive.proactive_thinker import ProactiveThinker
|
|
||||||
|
|
||||||
event = ProactiveTriggerEvent(source="sleep_manager", reason="goodnight")
|
|
||||||
proactive_thinker = ProactiveThinker(context, context.chat_instance.cycle_processor)
|
|
||||||
await proactive_thinker.think(event)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"发送晚安通知失败: {e}")
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
async def send_insomnia_notification(context: HfcContext, reason: str):
|
|
||||||
"""发送失眠通知"""
|
|
||||||
try:
|
|
||||||
from ..proactive.events import ProactiveTriggerEvent
|
|
||||||
from ..proactive.proactive_thinker import ProactiveThinker
|
|
||||||
|
|
||||||
event = ProactiveTriggerEvent(source="sleep_manager", reason=reason)
|
|
||||||
proactive_thinker = ProactiveThinker(context, context.chat_instance.cycle_processor)
|
|
||||||
await proactive_thinker.think(event)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"发送失眠通知失败: {e}")
|
|
||||||
@@ -1,110 +0,0 @@
|
|||||||
from enum import Enum, auto
|
|
||||||
from datetime import datetime
|
|
||||||
from src.common.logger import get_logger
|
|
||||||
from src.manager.local_store_manager import local_storage
|
|
||||||
|
|
||||||
logger = get_logger("sleep_state")
|
|
||||||
|
|
||||||
|
|
||||||
class SleepState(Enum):
|
|
||||||
"""
|
|
||||||
定义了角色可能处于的几种睡眠状态。
|
|
||||||
这是一个状态机,用于管理角色的睡眠周期。
|
|
||||||
"""
|
|
||||||
|
|
||||||
AWAKE = auto() # 清醒状态
|
|
||||||
INSOMNIA = auto() # 失眠状态
|
|
||||||
PREPARING_SLEEP = auto() # 准备入睡状态,一个短暂的过渡期
|
|
||||||
SLEEPING = auto() # 正在睡觉状态
|
|
||||||
WOKEN_UP = auto() # 被吵醒状态
|
|
||||||
|
|
||||||
|
|
||||||
class SleepStateSerializer:
|
|
||||||
"""
|
|
||||||
睡眠状态序列化器。
|
|
||||||
负责将内存中的睡眠状态对象持久化到本地存储(如JSON文件),
|
|
||||||
以及在程序启动时从本地存储中恢复状态。
|
|
||||||
这样可以确保即使程序重启,角色的睡眠状态也能得以保留。
|
|
||||||
"""
|
|
||||||
@staticmethod
|
|
||||||
def save(state_data: dict):
|
|
||||||
"""
|
|
||||||
将当前的睡眠状态数据保存到本地存储。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
state_data (dict): 包含睡眠状态信息的字典。
|
|
||||||
datetime对象会被转换为时间戳,Enum成员会被转换为其名称字符串。
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
# 准备要序列化的数据字典
|
|
||||||
state = {
|
|
||||||
# 保存当前状态的枚举名称
|
|
||||||
"current_state": state_data["_current_state"].name,
|
|
||||||
# 将datetime对象转换为Unix时间戳以便序列化
|
|
||||||
"sleep_buffer_end_time_ts": state_data["_sleep_buffer_end_time"].timestamp()
|
|
||||||
if state_data["_sleep_buffer_end_time"]
|
|
||||||
else None,
|
|
||||||
"total_delayed_minutes_today": state_data["_total_delayed_minutes_today"],
|
|
||||||
# 将date对象转换为ISO格式的字符串
|
|
||||||
"last_sleep_check_date_str": state_data["_last_sleep_check_date"].isoformat()
|
|
||||||
if state_data["_last_sleep_check_date"]
|
|
||||||
else None,
|
|
||||||
"re_sleep_attempt_time_ts": state_data["_re_sleep_attempt_time"].timestamp()
|
|
||||||
if state_data["_re_sleep_attempt_time"]
|
|
||||||
else None,
|
|
||||||
}
|
|
||||||
# 写入本地存储
|
|
||||||
local_storage["schedule_sleep_state"] = state
|
|
||||||
logger.debug(f"已保存睡眠状态: {state}")
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"保存睡眠状态失败: {e}")
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def load() -> dict:
|
|
||||||
"""
|
|
||||||
从本地存储加载并解析睡眠状态。
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
dict: 包含恢复后睡眠状态信息的字典。
|
|
||||||
如果加载失败或没有找到数据,则返回一个默认的清醒状态。
|
|
||||||
"""
|
|
||||||
# 定义一个默认的状态,以防加载失败
|
|
||||||
state_data = {
|
|
||||||
"_current_state": SleepState.AWAKE,
|
|
||||||
"_sleep_buffer_end_time": None,
|
|
||||||
"_total_delayed_minutes_today": 0,
|
|
||||||
"_last_sleep_check_date": None,
|
|
||||||
"_re_sleep_attempt_time": None,
|
|
||||||
}
|
|
||||||
try:
|
|
||||||
# 从本地存储读取数据
|
|
||||||
state = local_storage["schedule_sleep_state"]
|
|
||||||
if state and isinstance(state, dict):
|
|
||||||
# 恢复当前状态枚举
|
|
||||||
state_name = state.get("current_state")
|
|
||||||
if state_name and hasattr(SleepState, state_name):
|
|
||||||
state_data["_current_state"] = SleepState[state_name]
|
|
||||||
|
|
||||||
# 从时间戳恢复datetime对象
|
|
||||||
end_time_ts = state.get("sleep_buffer_end_time_ts")
|
|
||||||
if end_time_ts:
|
|
||||||
state_data["_sleep_buffer_end_time"] = datetime.fromtimestamp(end_time_ts)
|
|
||||||
|
|
||||||
# 恢复重新入睡尝试时间
|
|
||||||
re_sleep_ts = state.get("re_sleep_attempt_time_ts")
|
|
||||||
if re_sleep_ts:
|
|
||||||
state_data["_re_sleep_attempt_time"] = datetime.fromtimestamp(re_sleep_ts)
|
|
||||||
|
|
||||||
# 恢复今日延迟睡眠总分钟数
|
|
||||||
state_data["_total_delayed_minutes_today"] = state.get("total_delayed_minutes_today", 0)
|
|
||||||
|
|
||||||
# 从ISO格式字符串恢复date对象
|
|
||||||
date_str = state.get("last_sleep_check_date_str")
|
|
||||||
if date_str:
|
|
||||||
state_data["_last_sleep_check_date"] = datetime.fromisoformat(date_str).date()
|
|
||||||
|
|
||||||
logger.info(f"成功从本地存储加载睡眠状态: {state}")
|
|
||||||
except Exception as e:
|
|
||||||
# 如果加载过程中出现任何问题,记录警告并返回默认状态
|
|
||||||
logger.warning(f"加载睡眠状态失败,将使用默认值: {e}")
|
|
||||||
return state_data
|
|
||||||
@@ -1,232 +0,0 @@
|
|||||||
import asyncio
|
|
||||||
import time
|
|
||||||
from typing import Optional
|
|
||||||
from src.common.logger import get_logger
|
|
||||||
from src.config.config import global_config
|
|
||||||
from src.manager.local_store_manager import local_storage
|
|
||||||
from ..hfc_context import HfcContext
|
|
||||||
|
|
||||||
logger = get_logger("wakeup")
|
|
||||||
|
|
||||||
|
|
||||||
class WakeUpManager:
|
|
||||||
def __init__(self, context: HfcContext):
|
|
||||||
"""
|
|
||||||
初始化唤醒度管理器
|
|
||||||
|
|
||||||
Args:
|
|
||||||
context: HFC聊天上下文对象
|
|
||||||
|
|
||||||
功能说明:
|
|
||||||
- 管理休眠状态下的唤醒度累积
|
|
||||||
- 处理唤醒度的自然衰减
|
|
||||||
- 控制愤怒状态的持续时间
|
|
||||||
"""
|
|
||||||
self.context = context
|
|
||||||
self.wakeup_value = 0.0 # 当前唤醒度
|
|
||||||
self.is_angry = False # 是否处于愤怒状态
|
|
||||||
self.angry_start_time = 0.0 # 愤怒状态开始时间
|
|
||||||
self.last_decay_time = time.time() # 上次衰减时间
|
|
||||||
self._decay_task: Optional[asyncio.Task] = None
|
|
||||||
self.last_log_time = 0
|
|
||||||
self.log_interval = 30
|
|
||||||
|
|
||||||
# 从配置文件获取参数
|
|
||||||
sleep_config = global_config.sleep_system
|
|
||||||
self.wakeup_threshold = sleep_config.wakeup_threshold
|
|
||||||
self.private_message_increment = sleep_config.private_message_increment
|
|
||||||
self.group_mention_increment = sleep_config.group_mention_increment
|
|
||||||
self.decay_rate = sleep_config.decay_rate
|
|
||||||
self.decay_interval = sleep_config.decay_interval
|
|
||||||
self.angry_duration = sleep_config.angry_duration
|
|
||||||
self.enabled = sleep_config.enable
|
|
||||||
self.angry_prompt = sleep_config.angry_prompt
|
|
||||||
|
|
||||||
self._load_wakeup_state()
|
|
||||||
|
|
||||||
def _get_storage_key(self) -> str:
|
|
||||||
"""获取当前聊天流的本地存储键"""
|
|
||||||
return f"wakeup_manager_state_{self.context.stream_id}"
|
|
||||||
|
|
||||||
def _load_wakeup_state(self):
|
|
||||||
"""从本地存储加载状态"""
|
|
||||||
state = local_storage[self._get_storage_key()]
|
|
||||||
if state and isinstance(state, dict):
|
|
||||||
self.wakeup_value = state.get("wakeup_value", 0.0)
|
|
||||||
self.is_angry = state.get("is_angry", False)
|
|
||||||
self.angry_start_time = state.get("angry_start_time", 0.0)
|
|
||||||
logger.info(f"{self.context.log_prefix} 成功从本地存储加载唤醒状态: {state}")
|
|
||||||
else:
|
|
||||||
logger.info(f"{self.context.log_prefix} 未找到本地唤醒状态,将使用默认值初始化。")
|
|
||||||
|
|
||||||
def _save_wakeup_state(self):
|
|
||||||
"""将当前状态保存到本地存储"""
|
|
||||||
state = {
|
|
||||||
"wakeup_value": self.wakeup_value,
|
|
||||||
"is_angry": self.is_angry,
|
|
||||||
"angry_start_time": self.angry_start_time,
|
|
||||||
}
|
|
||||||
local_storage[self._get_storage_key()] = state
|
|
||||||
logger.debug(f"{self.context.log_prefix} 已将唤醒状态保存到本地存储: {state}")
|
|
||||||
|
|
||||||
async def start(self):
|
|
||||||
"""启动唤醒度管理器"""
|
|
||||||
if not self.enabled:
|
|
||||||
logger.info(f"{self.context.log_prefix} 唤醒度系统已禁用,跳过启动")
|
|
||||||
return
|
|
||||||
|
|
||||||
if not self._decay_task:
|
|
||||||
self._decay_task = asyncio.create_task(self._decay_loop())
|
|
||||||
self._decay_task.add_done_callback(self._handle_decay_completion)
|
|
||||||
logger.info(f"{self.context.log_prefix} 唤醒度管理器已启动")
|
|
||||||
|
|
||||||
async def stop(self):
|
|
||||||
"""停止唤醒度管理器"""
|
|
||||||
if self._decay_task and not self._decay_task.done():
|
|
||||||
self._decay_task.cancel()
|
|
||||||
await asyncio.sleep(0)
|
|
||||||
logger.info(f"{self.context.log_prefix} 唤醒度管理器已停止")
|
|
||||||
|
|
||||||
def _handle_decay_completion(self, task: asyncio.Task):
|
|
||||||
"""处理衰减任务完成"""
|
|
||||||
try:
|
|
||||||
if exception := task.exception():
|
|
||||||
logger.error(f"{self.context.log_prefix} 唤醒度衰减任务异常: {exception}")
|
|
||||||
else:
|
|
||||||
logger.info(f"{self.context.log_prefix} 唤醒度衰减任务正常结束")
|
|
||||||
except asyncio.CancelledError:
|
|
||||||
logger.info(f"{self.context.log_prefix} 唤醒度衰减任务被取消")
|
|
||||||
|
|
||||||
async def _decay_loop(self):
|
|
||||||
"""唤醒度衰减循环"""
|
|
||||||
while self.context.running:
|
|
||||||
await asyncio.sleep(self.decay_interval)
|
|
||||||
|
|
||||||
current_time = time.time()
|
|
||||||
|
|
||||||
# 检查愤怒状态是否过期
|
|
||||||
if self.is_angry and current_time - self.angry_start_time >= self.angry_duration:
|
|
||||||
self.is_angry = False
|
|
||||||
# 通知情绪管理系统清除愤怒状态
|
|
||||||
from src.mood.mood_manager import mood_manager
|
|
||||||
|
|
||||||
mood_manager.clear_angry_from_wakeup(self.context.stream_id)
|
|
||||||
logger.info(f"{self.context.log_prefix} 愤怒状态结束,恢复正常")
|
|
||||||
self._save_wakeup_state()
|
|
||||||
|
|
||||||
# 唤醒度自然衰减
|
|
||||||
if self.wakeup_value > 0:
|
|
||||||
old_value = self.wakeup_value
|
|
||||||
self.wakeup_value = max(0, self.wakeup_value - self.decay_rate)
|
|
||||||
if old_value != self.wakeup_value:
|
|
||||||
logger.debug(f"{self.context.log_prefix} 唤醒度衰减: {old_value:.1f} -> {self.wakeup_value:.1f}")
|
|
||||||
self._save_wakeup_state()
|
|
||||||
|
|
||||||
def add_wakeup_value(self, is_private_chat: bool, is_mentioned: bool = False) -> bool:
|
|
||||||
"""
|
|
||||||
增加唤醒度值
|
|
||||||
|
|
||||||
Args:
|
|
||||||
is_private_chat: 是否为私聊
|
|
||||||
is_mentioned: 是否被艾特(仅群聊有效)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
bool: 是否达到唤醒阈值
|
|
||||||
"""
|
|
||||||
# 如果系统未启用,直接返回
|
|
||||||
if not self.enabled:
|
|
||||||
return False
|
|
||||||
|
|
||||||
# 只有在休眠且非失眠状态下才累积唤醒度
|
|
||||||
from .sleep_state import SleepState
|
|
||||||
|
|
||||||
sleep_manager = self.context.sleep_manager
|
|
||||||
if not sleep_manager:
|
|
||||||
return False
|
|
||||||
|
|
||||||
current_sleep_state = sleep_manager.get_current_sleep_state()
|
|
||||||
if current_sleep_state != SleepState.SLEEPING:
|
|
||||||
return False
|
|
||||||
|
|
||||||
old_value = self.wakeup_value
|
|
||||||
|
|
||||||
if is_private_chat:
|
|
||||||
# 私聊每条消息都增加唤醒度
|
|
||||||
self.wakeup_value += self.private_message_increment
|
|
||||||
logger.debug(f"{self.context.log_prefix} 私聊消息增加唤醒度: +{self.private_message_increment}")
|
|
||||||
elif is_mentioned:
|
|
||||||
# 群聊只有被艾特才增加唤醒度
|
|
||||||
self.wakeup_value += self.group_mention_increment
|
|
||||||
logger.debug(f"{self.context.log_prefix} 群聊艾特增加唤醒度: +{self.group_mention_increment}")
|
|
||||||
else:
|
|
||||||
# 群聊未被艾特,不增加唤醒度
|
|
||||||
return False
|
|
||||||
|
|
||||||
current_time = time.time()
|
|
||||||
if current_time - self.last_log_time > self.log_interval:
|
|
||||||
logger.info(
|
|
||||||
f"{self.context.log_prefix} 唤醒度变化: {old_value:.1f} -> {self.wakeup_value:.1f} (阈值: {self.wakeup_threshold})"
|
|
||||||
)
|
|
||||||
self.last_log_time = current_time
|
|
||||||
else:
|
|
||||||
logger.debug(
|
|
||||||
f"{self.context.log_prefix} 唤醒度变化: {old_value:.1f} -> {self.wakeup_value:.1f} (阈值: {self.wakeup_threshold})"
|
|
||||||
)
|
|
||||||
|
|
||||||
# 检查是否达到唤醒阈值
|
|
||||||
if self.wakeup_value >= self.wakeup_threshold:
|
|
||||||
self._trigger_wakeup()
|
|
||||||
return True
|
|
||||||
|
|
||||||
self._save_wakeup_state()
|
|
||||||
return False
|
|
||||||
|
|
||||||
def _trigger_wakeup(self):
|
|
||||||
"""触发唤醒,进入愤怒状态"""
|
|
||||||
self.is_angry = True
|
|
||||||
self.angry_start_time = time.time()
|
|
||||||
self.wakeup_value = 0.0 # 重置唤醒度
|
|
||||||
|
|
||||||
self._save_wakeup_state()
|
|
||||||
|
|
||||||
# 通知情绪管理系统进入愤怒状态
|
|
||||||
from src.mood.mood_manager import mood_manager
|
|
||||||
|
|
||||||
mood_manager.set_angry_from_wakeup(self.context.stream_id)
|
|
||||||
|
|
||||||
# 通知SleepManager重置睡眠状态
|
|
||||||
if self.context.sleep_manager:
|
|
||||||
self.context.sleep_manager.reset_sleep_state_after_wakeup()
|
|
||||||
|
|
||||||
logger.info(f"{self.context.log_prefix} 唤醒度达到阈值({self.wakeup_threshold}),被吵醒进入愤怒状态!")
|
|
||||||
|
|
||||||
def get_angry_prompt_addition(self) -> str:
|
|
||||||
"""获取愤怒状态下的提示词补充"""
|
|
||||||
if self.is_angry:
|
|
||||||
return self.angry_prompt
|
|
||||||
return ""
|
|
||||||
|
|
||||||
def is_in_angry_state(self) -> bool:
|
|
||||||
"""检查是否处于愤怒状态"""
|
|
||||||
if self.is_angry:
|
|
||||||
current_time = time.time()
|
|
||||||
if current_time - self.angry_start_time >= self.angry_duration:
|
|
||||||
self.is_angry = False
|
|
||||||
# 通知情绪管理系统清除愤怒状态
|
|
||||||
from src.mood.mood_manager import mood_manager
|
|
||||||
|
|
||||||
mood_manager.clear_angry_from_wakeup(self.context.stream_id)
|
|
||||||
logger.info(f"{self.context.log_prefix} 愤怒状态自动过期")
|
|
||||||
return False
|
|
||||||
return self.is_angry
|
|
||||||
|
|
||||||
def get_status_info(self) -> dict:
|
|
||||||
"""获取当前状态信息"""
|
|
||||||
return {
|
|
||||||
"wakeup_value": self.wakeup_value,
|
|
||||||
"wakeup_threshold": self.wakeup_threshold,
|
|
||||||
"is_angry": self.is_angry,
|
|
||||||
"angry_remaining_time": max(0, self.angry_duration - (time.time() - self.angry_start_time))
|
|
||||||
if self.is_angry
|
|
||||||
else 0,
|
|
||||||
}
|
|
||||||
145
src/chat/chatter_manager.py
Normal file
145
src/chat/chatter_manager.py
Normal file
@@ -0,0 +1,145 @@
|
|||||||
|
from typing import Dict, List, Optional, Any
|
||||||
|
import time
|
||||||
|
from src.plugin_system.base.base_chatter import BaseChatter
|
||||||
|
from src.common.data_models.message_manager_data_model import StreamContext
|
||||||
|
from src.plugins.built_in.affinity_flow_chatter.planner import ChatterActionPlanner as ActionPlanner
|
||||||
|
from src.chat.planner_actions.action_manager import ChatterActionManager
|
||||||
|
from src.plugin_system.base.component_types import ChatType, ComponentType
|
||||||
|
from src.common.logger import get_logger
|
||||||
|
|
||||||
|
logger = get_logger("chatter_manager")
|
||||||
|
|
||||||
|
class ChatterManager:
|
||||||
|
def __init__(self, action_manager: ChatterActionManager):
|
||||||
|
self.action_manager = action_manager
|
||||||
|
self.chatter_classes: Dict[ChatType, List[type]] = {}
|
||||||
|
self.instances: Dict[str, BaseChatter] = {}
|
||||||
|
|
||||||
|
# 管理器统计
|
||||||
|
self.stats = {
|
||||||
|
"chatters_registered": 0,
|
||||||
|
"streams_processed": 0,
|
||||||
|
"successful_executions": 0,
|
||||||
|
"failed_executions": 0,
|
||||||
|
}
|
||||||
|
|
||||||
|
def _auto_register_from_component_registry(self):
|
||||||
|
"""从组件注册表自动注册已注册的chatter组件"""
|
||||||
|
try:
|
||||||
|
from src.plugin_system.core.component_registry import component_registry
|
||||||
|
# 获取所有CHATTER类型的组件
|
||||||
|
chatter_components = component_registry.get_enabled_chatter_registry()
|
||||||
|
for chatter_name, chatter_class in chatter_components.items():
|
||||||
|
self.register_chatter(chatter_class)
|
||||||
|
logger.info(f"自动注册chatter组件: {chatter_name}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"自动注册chatter组件时发生错误: {e}")
|
||||||
|
|
||||||
|
def register_chatter(self, chatter_class: type):
|
||||||
|
"""注册聊天处理器类"""
|
||||||
|
for chat_type in chatter_class.chat_types:
|
||||||
|
if chat_type not in self.chatter_classes:
|
||||||
|
self.chatter_classes[chat_type] = []
|
||||||
|
self.chatter_classes[chat_type].append(chatter_class)
|
||||||
|
logger.info(f"注册聊天处理器 {chatter_class.__name__} 支持 {chat_type.value} 聊天类型")
|
||||||
|
|
||||||
|
self.stats["chatters_registered"] += 1
|
||||||
|
|
||||||
|
def get_chatter_class(self, chat_type: ChatType) -> Optional[type]:
|
||||||
|
"""获取指定聊天类型的聊天处理器类"""
|
||||||
|
if chat_type in self.chatter_classes:
|
||||||
|
return self.chatter_classes[chat_type][0]
|
||||||
|
return None
|
||||||
|
|
||||||
|
def get_supported_chat_types(self) -> List[ChatType]:
|
||||||
|
"""获取支持的聊天类型列表"""
|
||||||
|
return list(self.chatter_classes.keys())
|
||||||
|
|
||||||
|
def get_registered_chatters(self) -> Dict[ChatType, List[type]]:
|
||||||
|
"""获取已注册的聊天处理器"""
|
||||||
|
return self.chatter_classes.copy()
|
||||||
|
|
||||||
|
def get_stream_instance(self, stream_id: str) -> Optional[BaseChatter]:
|
||||||
|
"""获取指定流的聊天处理器实例"""
|
||||||
|
return self.instances.get(stream_id)
|
||||||
|
|
||||||
|
def cleanup_inactive_instances(self, max_inactive_minutes: int = 60):
|
||||||
|
"""清理不活跃的实例"""
|
||||||
|
current_time = time.time()
|
||||||
|
max_inactive_seconds = max_inactive_minutes * 60
|
||||||
|
|
||||||
|
inactive_streams = []
|
||||||
|
for stream_id, instance in self.instances.items():
|
||||||
|
if hasattr(instance, 'get_activity_time'):
|
||||||
|
activity_time = instance.get_activity_time()
|
||||||
|
if (current_time - activity_time) > max_inactive_seconds:
|
||||||
|
inactive_streams.append(stream_id)
|
||||||
|
|
||||||
|
for stream_id in inactive_streams:
|
||||||
|
del self.instances[stream_id]
|
||||||
|
logger.info(f"清理不活跃聊天流实例: {stream_id}")
|
||||||
|
|
||||||
|
async def process_stream_context(self, stream_id: str, context: StreamContext) -> dict:
|
||||||
|
"""处理流上下文"""
|
||||||
|
chat_type = context.chat_type
|
||||||
|
logger.debug(f"处理流 {stream_id},聊天类型: {chat_type.value}")
|
||||||
|
if not self.chatter_classes:
|
||||||
|
self._auto_register_from_component_registry()
|
||||||
|
|
||||||
|
# 获取适合该聊天类型的chatter
|
||||||
|
chatter_class = self.get_chatter_class(chat_type)
|
||||||
|
if not chatter_class:
|
||||||
|
# 如果没有找到精确匹配,尝试查找支持ALL类型的chatter
|
||||||
|
from src.plugin_system.base.component_types import ChatType
|
||||||
|
all_chatter_class = self.get_chatter_class(ChatType.ALL)
|
||||||
|
if all_chatter_class:
|
||||||
|
chatter_class = all_chatter_class
|
||||||
|
logger.info(f"流 {stream_id} 使用通用chatter (类型: {chat_type.value})")
|
||||||
|
else:
|
||||||
|
raise ValueError(f"No chatter registered for chat type {chat_type}")
|
||||||
|
|
||||||
|
if stream_id not in self.instances:
|
||||||
|
self.instances[stream_id] = chatter_class(stream_id=stream_id, action_manager=self.action_manager)
|
||||||
|
logger.info(f"创建新的聊天流实例: {stream_id} 使用 {chatter_class.__name__} (类型: {chat_type.value})")
|
||||||
|
|
||||||
|
self.stats["streams_processed"] += 1
|
||||||
|
try:
|
||||||
|
result = await self.instances[stream_id].execute(context)
|
||||||
|
self.stats["successful_executions"] += 1
|
||||||
|
|
||||||
|
# 从 mood_manager 获取最新的 chat_stream 并同步回 StreamContext
|
||||||
|
try:
|
||||||
|
from src.mood.mood_manager import mood_manager
|
||||||
|
mood = mood_manager.get_mood_by_chat_id(stream_id)
|
||||||
|
if mood and mood.chat_stream:
|
||||||
|
context.chat_stream = mood.chat_stream
|
||||||
|
logger.debug(f"已将最新的 chat_stream 同步回流 {stream_id} 的 StreamContext")
|
||||||
|
except Exception as sync_e:
|
||||||
|
logger.error(f"同步 chat_stream 回 StreamContext 失败: {sync_e}")
|
||||||
|
|
||||||
|
# 记录处理结果
|
||||||
|
success = result.get("success", False)
|
||||||
|
actions_count = result.get("actions_count", 0)
|
||||||
|
logger.debug(f"流 {stream_id} 处理完成: 成功={success}, 动作数={actions_count}")
|
||||||
|
|
||||||
|
return result
|
||||||
|
except Exception as e:
|
||||||
|
self.stats["failed_executions"] += 1
|
||||||
|
logger.error(f"处理流 {stream_id} 时发生错误: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
def get_stats(self) -> Dict[str, Any]:
|
||||||
|
"""获取管理器统计信息"""
|
||||||
|
stats = self.stats.copy()
|
||||||
|
stats["active_instances"] = len(self.instances)
|
||||||
|
stats["registered_chatter_types"] = len(self.chatter_classes)
|
||||||
|
return stats
|
||||||
|
|
||||||
|
def reset_stats(self):
|
||||||
|
"""重置统计信息"""
|
||||||
|
self.stats = {
|
||||||
|
"chatters_registered": 0,
|
||||||
|
"streams_processed": 0,
|
||||||
|
"successful_executions": 0,
|
||||||
|
"failed_executions": 0,
|
||||||
|
}
|
||||||
@@ -2,8 +2,10 @@
|
|||||||
"""
|
"""
|
||||||
表情包发送历史记录模块
|
表情包发送历史记录模块
|
||||||
"""
|
"""
|
||||||
from collections import deque
|
|
||||||
|
import os
|
||||||
from typing import List, Dict
|
from typing import List, Dict
|
||||||
|
from collections import deque
|
||||||
|
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
|
|
||||||
|
|||||||
@@ -149,7 +149,7 @@ class MaiEmoji:
|
|||||||
# --- 数据库操作 ---
|
# --- 数据库操作 ---
|
||||||
try:
|
try:
|
||||||
# 准备数据库记录 for emoji collection
|
# 准备数据库记录 for emoji collection
|
||||||
async with get_db_session() as session:
|
with get_db_session() as session:
|
||||||
emotion_str = ",".join(self.emotion) if self.emotion else ""
|
emotion_str = ",".join(self.emotion) if self.emotion else ""
|
||||||
|
|
||||||
emoji = Emoji(
|
emoji = Emoji(
|
||||||
@@ -167,7 +167,7 @@ class MaiEmoji:
|
|||||||
last_used_time=self.last_used_time,
|
last_used_time=self.last_used_time,
|
||||||
)
|
)
|
||||||
session.add(emoji)
|
session.add(emoji)
|
||||||
await session.commit()
|
session.commit()
|
||||||
|
|
||||||
logger.info(f"[注册] 表情包信息保存到数据库: {self.filename} ({self.emotion})")
|
logger.info(f"[注册] 表情包信息保存到数据库: {self.filename} ({self.emotion})")
|
||||||
|
|
||||||
@@ -203,17 +203,17 @@ class MaiEmoji:
|
|||||||
|
|
||||||
# 2. 删除数据库记录
|
# 2. 删除数据库记录
|
||||||
try:
|
try:
|
||||||
async with get_db_session() as session:
|
with get_db_session() as session:
|
||||||
will_delete_emoji = (
|
will_delete_emoji = session.execute(
|
||||||
await session.execute(select(Emoji).where(Emoji.emoji_hash == self.hash))
|
select(Emoji).where(Emoji.emoji_hash == self.hash)
|
||||||
).scalar_one_or_none()
|
).scalar_one_or_none()
|
||||||
if will_delete_emoji is None:
|
if will_delete_emoji is None:
|
||||||
logger.warning(f"[删除] 数据库中未找到哈希值为 {self.hash} 的表情包记录。")
|
logger.warning(f"[删除] 数据库中未找到哈希值为 {self.hash} 的表情包记录。")
|
||||||
result = 0
|
result = 0 # Indicate no DB record was deleted
|
||||||
else:
|
else:
|
||||||
await session.delete(will_delete_emoji)
|
session.delete(will_delete_emoji)
|
||||||
result = 1
|
result = 1 # Successfully deleted one record
|
||||||
await session.commit()
|
session.commit()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"[错误] 删除数据库记录时出错: {str(e)}")
|
logger.error(f"[错误] 删除数据库记录时出错: {str(e)}")
|
||||||
result = 0
|
result = 0
|
||||||
@@ -424,19 +424,17 @@ class EmojiManager:
|
|||||||
# if not self._initialized:
|
# if not self._initialized:
|
||||||
# raise RuntimeError("EmojiManager not initialized")
|
# raise RuntimeError("EmojiManager not initialized")
|
||||||
|
|
||||||
@staticmethod
|
def record_usage(self, emoji_hash: str) -> None:
|
||||||
async def record_usage(emoji_hash: str) -> None:
|
|
||||||
"""记录表情使用次数"""
|
"""记录表情使用次数"""
|
||||||
try:
|
try:
|
||||||
async with get_db_session() as session:
|
with get_db_session() as session:
|
||||||
emoji_update = (
|
emoji_update = session.execute(select(Emoji).where(Emoji.emoji_hash == emoji_hash)).scalar_one_or_none()
|
||||||
await session.execute(select(Emoji).where(Emoji.emoji_hash == emoji_hash))
|
|
||||||
).scalar_one_or_none()
|
|
||||||
if emoji_update is None:
|
if emoji_update is None:
|
||||||
logger.error(f"记录表情使用失败: 未找到 hash 为 {emoji_hash} 的表情包")
|
logger.error(f"记录表情使用失败: 未找到 hash 为 {emoji_hash} 的表情包")
|
||||||
else:
|
else:
|
||||||
emoji_update.usage_count += 1
|
emoji_update.usage_count += 1
|
||||||
emoji_update.last_used_time = time.time()
|
emoji_update.last_used_time = time.time() # Update last used time
|
||||||
|
session.commit()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"记录表情使用失败: {str(e)}")
|
logger.error(f"记录表情使用失败: {str(e)}")
|
||||||
|
|
||||||
@@ -523,12 +521,10 @@ class EmojiManager:
|
|||||||
|
|
||||||
# 7. 获取选中的表情包并更新使用记录
|
# 7. 获取选中的表情包并更新使用记录
|
||||||
selected_emoji = candidate_emojis[selected_index]
|
selected_emoji = candidate_emojis[selected_index]
|
||||||
await self.record_usage(selected_emoji.emoji_hash)
|
self.record_usage(selected_emoji.hash)
|
||||||
_time_end = time.time()
|
_time_end = time.time()
|
||||||
|
|
||||||
logger.info(
|
logger.info(f"找到匹配描述的表情包: {selected_emoji.description}, 耗时: {(_time_end - _time_start):.2f}s")
|
||||||
f"找到匹配描述的表情包: {selected_emoji.description}, 耗时: {(_time_end - _time_start):.2f}s"
|
|
||||||
)
|
|
||||||
|
|
||||||
# 8. 返回选中的表情包信息
|
# 8. 返回选中的表情包信息
|
||||||
return selected_emoji.full_path, f"[表情包:{selected_emoji.description}]", text_emotion
|
return selected_emoji.full_path, f"[表情包:{selected_emoji.description}]", text_emotion
|
||||||
@@ -629,8 +625,9 @@ class EmojiManager:
|
|||||||
|
|
||||||
# 无论steal_emoji是否开启,都检查emoji文件夹以支持手动注册
|
# 无论steal_emoji是否开启,都检查emoji文件夹以支持手动注册
|
||||||
# 只有在需要腾出空间或填充表情库时,才真正执行注册
|
# 只有在需要腾出空间或填充表情库时,才真正执行注册
|
||||||
if (self.emoji_num > self.emoji_num_max and global_config.emoji.do_replace) or \
|
if (self.emoji_num > self.emoji_num_max and global_config.emoji.do_replace) or (
|
||||||
(self.emoji_num < self.emoji_num_max):
|
self.emoji_num < self.emoji_num_max
|
||||||
|
):
|
||||||
try:
|
try:
|
||||||
# 获取目录下所有图片文件
|
# 获取目录下所有图片文件
|
||||||
files_to_process = [
|
files_to_process = [
|
||||||
@@ -660,11 +657,10 @@ class EmojiManager:
|
|||||||
async def get_all_emoji_from_db(self) -> None:
|
async def get_all_emoji_from_db(self) -> None:
|
||||||
"""获取所有表情包并初始化为MaiEmoji类对象,更新 self.emoji_objects"""
|
"""获取所有表情包并初始化为MaiEmoji类对象,更新 self.emoji_objects"""
|
||||||
try:
|
try:
|
||||||
async with get_db_session() as session:
|
with get_db_session() as session:
|
||||||
logger.debug("[数据库] 开始加载所有表情包记录 ...")
|
logger.debug("[数据库] 开始加载所有表情包记录 ...")
|
||||||
|
|
||||||
result = await session.execute(select(Emoji))
|
emoji_instances = session.execute(select(Emoji)).scalars().all()
|
||||||
emoji_instances = result.scalars().all()
|
|
||||||
emoji_objects, load_errors = _to_emoji_objects(emoji_instances)
|
emoji_objects, load_errors = _to_emoji_objects(emoji_instances)
|
||||||
|
|
||||||
# 更新内存中的列表和数量
|
# 更新内存中的列表和数量
|
||||||
@@ -680,8 +676,7 @@ class EmojiManager:
|
|||||||
self.emoji_objects = [] # 加载失败则清空列表
|
self.emoji_objects = [] # 加载失败则清空列表
|
||||||
self.emoji_num = 0
|
self.emoji_num = 0
|
||||||
|
|
||||||
@staticmethod
|
async def get_emoji_from_db(self, emoji_hash: Optional[str] = None) -> List["MaiEmoji"]:
|
||||||
async def get_emoji_from_db(emoji_hash: Optional[str] = None) -> List["MaiEmoji"]:
|
|
||||||
"""获取指定哈希值的表情包并初始化为MaiEmoji类对象列表 (主要用于调试或特定查找)
|
"""获取指定哈希值的表情包并初始化为MaiEmoji类对象列表 (主要用于调试或特定查找)
|
||||||
|
|
||||||
参数:
|
参数:
|
||||||
@@ -691,16 +686,14 @@ class EmojiManager:
|
|||||||
list[MaiEmoji]: 表情包对象列表
|
list[MaiEmoji]: 表情包对象列表
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
async with get_db_session() as session:
|
with get_db_session() as session:
|
||||||
if emoji_hash:
|
if emoji_hash:
|
||||||
result = await session.execute(select(Emoji).where(Emoji.emoji_hash == emoji_hash))
|
query = session.execute(select(Emoji).where(Emoji.emoji_hash == emoji_hash)).scalars().all()
|
||||||
query = result.scalars().all()
|
|
||||||
else:
|
else:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"[查询] 未提供 hash,将尝试加载所有表情包,建议使用 get_all_emoji_from_db 更新管理器状态。"
|
"[查询] 未提供 hash,将尝试加载所有表情包,建议使用 get_all_emoji_from_db 更新管理器状态。"
|
||||||
)
|
)
|
||||||
result = await session.execute(select(Emoji))
|
query = session.execute(select(Emoji)).scalars().all()
|
||||||
query = result.scalars().all()
|
|
||||||
|
|
||||||
emoji_instances = query
|
emoji_instances = query
|
||||||
emoji_objects, load_errors = _to_emoji_objects(emoji_instances)
|
emoji_objects, load_errors = _to_emoji_objects(emoji_instances)
|
||||||
@@ -748,8 +741,8 @@ class EmojiManager:
|
|||||||
try:
|
try:
|
||||||
emoji_record = await self.get_emoji_from_db(emoji_hash)
|
emoji_record = await self.get_emoji_from_db(emoji_hash)
|
||||||
if emoji_record and emoji_record[0].emotion:
|
if emoji_record and emoji_record[0].emotion:
|
||||||
logger.info(f"[缓存命中] 从数据库获取表情包描述: {emoji_record[0].emotion[:50]}...")
|
logger.info(f"[缓存命中] 从数据库获取表情包描述: {emoji_record.emotion[:50]}...")
|
||||||
return emoji_record[0].emotion
|
return emoji_record.emotion
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"从数据库查询表情包描述时出错: {e}")
|
logger.error(f"从数据库查询表情包描述时出错: {e}")
|
||||||
|
|
||||||
@@ -777,11 +770,10 @@ class EmojiManager:
|
|||||||
|
|
||||||
# 如果内存中没有,从数据库查找
|
# 如果内存中没有,从数据库查找
|
||||||
try:
|
try:
|
||||||
async with get_db_session() as session:
|
with get_db_session() as session:
|
||||||
result = await session.execute(
|
emoji_record = session.execute(
|
||||||
select(Emoji).where(Emoji.emoji_hash == emoji_hash)
|
select(Emoji).where(Emoji.emoji_hash == emoji_hash)
|
||||||
)
|
).scalar_one_or_none()
|
||||||
emoji_record = result.scalar_one_or_none()
|
|
||||||
if emoji_record and emoji_record.description:
|
if emoji_record and emoji_record.description:
|
||||||
logger.info(f"[缓存命中] 从数据库获取表情包描述: {emoji_record.description[:50]}...")
|
logger.info(f"[缓存命中] 从数据库获取表情包描述: {emoji_record.description[:50]}...")
|
||||||
return emoji_record.description
|
return emoji_record.description
|
||||||
@@ -938,19 +930,21 @@ class EmojiManager:
|
|||||||
image_base64 = image_base64.encode("ascii", errors="ignore").decode("ascii")
|
image_base64 = image_base64.encode("ascii", errors="ignore").decode("ascii")
|
||||||
image_bytes = base64.b64decode(image_base64)
|
image_bytes = base64.b64decode(image_base64)
|
||||||
image_hash = hashlib.md5(image_bytes).hexdigest()
|
image_hash = hashlib.md5(image_bytes).hexdigest()
|
||||||
image_format = Image.open(io.BytesIO(image_bytes)).format.lower() if Image.open(io.BytesIO(image_bytes)).format else "jpeg"
|
image_format = (
|
||||||
|
Image.open(io.BytesIO(image_bytes)).format.lower()
|
||||||
|
if Image.open(io.BytesIO(image_bytes)).format
|
||||||
|
else "jpeg"
|
||||||
|
)
|
||||||
|
|
||||||
# 2. 检查数据库中是否已存在该表情包的描述,实现复用
|
# 2. 检查数据库中是否已存在该表情包的描述,实现复用
|
||||||
existing_description = None
|
existing_description = None
|
||||||
try:
|
try:
|
||||||
async with get_db_session() as session:
|
with get_db_session() as session:
|
||||||
result = await session.execute(
|
existing_image = (
|
||||||
select(Images).filter(
|
session.query(Images)
|
||||||
(Images.emoji_hash == image_hash) & (Images.type == "emoji")
|
.filter((Images.emoji_hash == image_hash) & (Images.type == "emoji"))
|
||||||
|
.one_or_none()
|
||||||
)
|
)
|
||||||
)
|
|
||||||
existing_image = result.scalar_one_or_none()
|
|
||||||
if existing_image and existing_image.description:
|
if existing_image and existing_image.description:
|
||||||
existing_description = existing_image.description
|
existing_description = existing_image.description
|
||||||
logger.info(f"[复用描述] 找到已有详细描述: {existing_description[:50]}...")
|
logger.info(f"[复用描述] 找到已有详细描述: {existing_description[:50]}...")
|
||||||
|
|||||||
28
src/chat/energy_system/__init__.py
Normal file
28
src/chat/energy_system/__init__.py
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
"""
|
||||||
|
能量系统模块
|
||||||
|
提供稳定、高效的聊天流能量计算和管理功能
|
||||||
|
"""
|
||||||
|
|
||||||
|
from .energy_manager import (
|
||||||
|
EnergyManager,
|
||||||
|
EnergyLevel,
|
||||||
|
EnergyComponent,
|
||||||
|
EnergyCalculator,
|
||||||
|
InterestEnergyCalculator,
|
||||||
|
ActivityEnergyCalculator,
|
||||||
|
RecencyEnergyCalculator,
|
||||||
|
RelationshipEnergyCalculator,
|
||||||
|
energy_manager
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"EnergyManager",
|
||||||
|
"EnergyLevel",
|
||||||
|
"EnergyComponent",
|
||||||
|
"EnergyCalculator",
|
||||||
|
"InterestEnergyCalculator",
|
||||||
|
"ActivityEnergyCalculator",
|
||||||
|
"RecencyEnergyCalculator",
|
||||||
|
"RelationshipEnergyCalculator",
|
||||||
|
"energy_manager"
|
||||||
|
]
|
||||||
473
src/chat/energy_system/energy_manager.py
Normal file
473
src/chat/energy_system/energy_manager.py
Normal file
@@ -0,0 +1,473 @@
|
|||||||
|
"""
|
||||||
|
重构后的 focus_energy 管理系统
|
||||||
|
提供稳定、高效的聊天流能量计算和管理功能
|
||||||
|
"""
|
||||||
|
|
||||||
|
import time
|
||||||
|
from typing import Dict, List, Optional, Tuple, Any, Union, TypedDict
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from enum import Enum
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
|
||||||
|
from src.common.logger import get_logger
|
||||||
|
from src.config.config import global_config
|
||||||
|
|
||||||
|
logger = get_logger("energy_system")
|
||||||
|
|
||||||
|
|
||||||
|
class EnergyLevel(Enum):
|
||||||
|
"""能量等级"""
|
||||||
|
VERY_LOW = 0.1 # 非常低
|
||||||
|
LOW = 0.3 # 低
|
||||||
|
NORMAL = 0.5 # 正常
|
||||||
|
HIGH = 0.7 # 高
|
||||||
|
VERY_HIGH = 0.9 # 非常高
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class EnergyComponent:
|
||||||
|
"""能量组件"""
|
||||||
|
name: str
|
||||||
|
value: float
|
||||||
|
weight: float = 1.0
|
||||||
|
decay_rate: float = 0.05 # 衰减率
|
||||||
|
last_updated: float = field(default_factory=time.time)
|
||||||
|
|
||||||
|
def get_current_value(self) -> float:
|
||||||
|
"""获取当前值(考虑时间衰减)"""
|
||||||
|
age = time.time() - self.last_updated
|
||||||
|
decay_factor = max(0.1, 1.0 - (age * self.decay_rate / (24 * 3600))) # 按天衰减
|
||||||
|
return self.value * decay_factor
|
||||||
|
|
||||||
|
def update_value(self, new_value: float) -> None:
|
||||||
|
"""更新值"""
|
||||||
|
self.value = max(0.0, min(1.0, new_value))
|
||||||
|
self.last_updated = time.time()
|
||||||
|
|
||||||
|
|
||||||
|
class EnergyContext(TypedDict):
|
||||||
|
"""能量计算上下文"""
|
||||||
|
stream_id: str
|
||||||
|
messages: List[Any]
|
||||||
|
user_id: Optional[str]
|
||||||
|
|
||||||
|
|
||||||
|
class EnergyResult(TypedDict):
|
||||||
|
"""能量计算结果"""
|
||||||
|
energy: float
|
||||||
|
level: EnergyLevel
|
||||||
|
distribution_interval: float
|
||||||
|
component_scores: Dict[str, float]
|
||||||
|
cached: bool
|
||||||
|
|
||||||
|
|
||||||
|
class EnergyCalculator(ABC):
|
||||||
|
"""能量计算器抽象基类"""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def calculate(self, context: Dict[str, Any]) -> float:
|
||||||
|
"""计算能量值"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_weight(self) -> float:
|
||||||
|
"""获取权重"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class InterestEnergyCalculator(EnergyCalculator):
|
||||||
|
"""兴趣度能量计算器"""
|
||||||
|
|
||||||
|
def calculate(self, context: Dict[str, Any]) -> float:
|
||||||
|
"""基于消息兴趣度计算能量"""
|
||||||
|
messages = context.get("messages", [])
|
||||||
|
if not messages:
|
||||||
|
return 0.3
|
||||||
|
|
||||||
|
# 计算平均兴趣度
|
||||||
|
total_interest = 0.0
|
||||||
|
valid_messages = 0
|
||||||
|
|
||||||
|
for msg in messages:
|
||||||
|
interest_value = getattr(msg, "interest_value", None)
|
||||||
|
if interest_value is not None:
|
||||||
|
try:
|
||||||
|
interest_float = float(interest_value)
|
||||||
|
if 0.0 <= interest_float <= 1.0:
|
||||||
|
total_interest += interest_float
|
||||||
|
valid_messages += 1
|
||||||
|
except (ValueError, TypeError):
|
||||||
|
continue
|
||||||
|
|
||||||
|
if valid_messages > 0:
|
||||||
|
avg_interest = total_interest / valid_messages
|
||||||
|
logger.debug(f"平均消息兴趣度: {avg_interest:.3f} (基于 {valid_messages} 条消息)")
|
||||||
|
return avg_interest
|
||||||
|
else:
|
||||||
|
return 0.3
|
||||||
|
|
||||||
|
def get_weight(self) -> float:
|
||||||
|
return 0.5
|
||||||
|
|
||||||
|
|
||||||
|
class ActivityEnergyCalculator(EnergyCalculator):
|
||||||
|
"""活跃度能量计算器"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.action_weights = {
|
||||||
|
"reply": 0.4,
|
||||||
|
"react": 0.3,
|
||||||
|
"mention": 0.2,
|
||||||
|
"other": 0.1
|
||||||
|
}
|
||||||
|
|
||||||
|
def calculate(self, context: Dict[str, Any]) -> float:
|
||||||
|
"""基于活跃度计算能量"""
|
||||||
|
messages = context.get("messages", [])
|
||||||
|
if not messages:
|
||||||
|
return 0.2
|
||||||
|
|
||||||
|
total_score = 0.0
|
||||||
|
max_possible_score = len(messages) * 0.4 # 最高可能分数
|
||||||
|
|
||||||
|
for msg in messages:
|
||||||
|
actions = getattr(msg, "actions", [])
|
||||||
|
if isinstance(actions, list) and actions:
|
||||||
|
for action in actions:
|
||||||
|
weight = self.action_weights.get(action, self.action_weights["other"])
|
||||||
|
total_score += weight
|
||||||
|
|
||||||
|
if max_possible_score > 0:
|
||||||
|
activity_score = min(1.0, total_score / max_possible_score)
|
||||||
|
logger.debug(f"活跃度分数: {activity_score:.3f}")
|
||||||
|
return activity_score
|
||||||
|
else:
|
||||||
|
return 0.2
|
||||||
|
|
||||||
|
def get_weight(self) -> float:
|
||||||
|
return 0.3
|
||||||
|
|
||||||
|
|
||||||
|
class RecencyEnergyCalculator(EnergyCalculator):
|
||||||
|
"""最近性能量计算器"""
|
||||||
|
|
||||||
|
def calculate(self, context: Dict[str, Any]) -> float:
|
||||||
|
"""基于最近性计算能量"""
|
||||||
|
messages = context.get("messages", [])
|
||||||
|
if not messages:
|
||||||
|
return 0.1
|
||||||
|
|
||||||
|
# 获取最新消息时间
|
||||||
|
latest_time = 0.0
|
||||||
|
for msg in messages:
|
||||||
|
msg_time = getattr(msg, "time", None)
|
||||||
|
if msg_time and msg_time > latest_time:
|
||||||
|
latest_time = msg_time
|
||||||
|
|
||||||
|
if latest_time == 0.0:
|
||||||
|
return 0.1
|
||||||
|
|
||||||
|
# 计算时间衰减
|
||||||
|
current_time = time.time()
|
||||||
|
age = current_time - latest_time
|
||||||
|
|
||||||
|
# 时间衰减策略:
|
||||||
|
# 1小时内:1.0
|
||||||
|
# 1-6小时:0.8
|
||||||
|
# 6-24小时:0.5
|
||||||
|
# 1-7天:0.3
|
||||||
|
# 7天以上:0.1
|
||||||
|
if age < 3600: # 1小时内
|
||||||
|
recency_score = 1.0
|
||||||
|
elif age < 6 * 3600: # 6小时内
|
||||||
|
recency_score = 0.8
|
||||||
|
elif age < 24 * 3600: # 24小时内
|
||||||
|
recency_score = 0.5
|
||||||
|
elif age < 7 * 24 * 3600: # 7天内
|
||||||
|
recency_score = 0.3
|
||||||
|
else:
|
||||||
|
recency_score = 0.1
|
||||||
|
|
||||||
|
logger.debug(f"最近性分数: {recency_score:.3f} (年龄: {age/3600:.1f}小时)")
|
||||||
|
return recency_score
|
||||||
|
|
||||||
|
def get_weight(self) -> float:
|
||||||
|
return 0.2
|
||||||
|
|
||||||
|
|
||||||
|
class RelationshipEnergyCalculator(EnergyCalculator):
|
||||||
|
"""关系能量计算器"""
|
||||||
|
|
||||||
|
def calculate(self, context: Dict[str, Any]) -> float:
|
||||||
|
"""基于关系计算能量"""
|
||||||
|
user_id = context.get("user_id")
|
||||||
|
if not user_id:
|
||||||
|
return 0.3
|
||||||
|
|
||||||
|
# 使用插件内部的兴趣度评分系统获取关系分
|
||||||
|
try:
|
||||||
|
from src.plugins.built_in.affinity_flow_chatter.interest_scoring import chatter_interest_scoring_system
|
||||||
|
|
||||||
|
relationship_score = chatter_interest_scoring_system._calculate_relationship_score(user_id)
|
||||||
|
logger.debug(f"使用插件内部系统计算关系分: {relationship_score:.3f}")
|
||||||
|
return max(0.0, min(1.0, relationship_score))
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"插件内部关系分计算失败,使用默认值: {e}")
|
||||||
|
return 0.3 # 默认基础分
|
||||||
|
|
||||||
|
def get_weight(self) -> float:
|
||||||
|
return 0.1
|
||||||
|
|
||||||
|
|
||||||
|
class EnergyManager:
|
||||||
|
"""能量管理器 - 统一管理所有能量计算"""
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self.calculators: List[EnergyCalculator] = [
|
||||||
|
InterestEnergyCalculator(),
|
||||||
|
ActivityEnergyCalculator(),
|
||||||
|
RecencyEnergyCalculator(),
|
||||||
|
RelationshipEnergyCalculator(),
|
||||||
|
]
|
||||||
|
|
||||||
|
# 能量缓存
|
||||||
|
self.energy_cache: Dict[str, Tuple[float, float]] = {} # stream_id -> (energy, timestamp)
|
||||||
|
self.cache_ttl: int = 60 # 1分钟缓存
|
||||||
|
|
||||||
|
# AFC阈值配置
|
||||||
|
self.thresholds: Dict[str, float] = {
|
||||||
|
"high_match": 0.8,
|
||||||
|
"reply": 0.4,
|
||||||
|
"non_reply": 0.2
|
||||||
|
}
|
||||||
|
|
||||||
|
# 统计信息
|
||||||
|
self.stats: Dict[str, Union[int, float, str]] = {
|
||||||
|
"total_calculations": 0,
|
||||||
|
"cache_hits": 0,
|
||||||
|
"cache_misses": 0,
|
||||||
|
"average_calculation_time": 0.0,
|
||||||
|
"last_threshold_update": time.time(),
|
||||||
|
}
|
||||||
|
|
||||||
|
# 从配置加载阈值
|
||||||
|
self._load_thresholds_from_config()
|
||||||
|
|
||||||
|
logger.info("能量管理器初始化完成")
|
||||||
|
|
||||||
|
def _load_thresholds_from_config(self) -> None:
|
||||||
|
"""从配置加载AFC阈值"""
|
||||||
|
try:
|
||||||
|
if hasattr(global_config, "affinity_flow") and global_config.affinity_flow is not None:
|
||||||
|
self.thresholds["high_match"] = getattr(global_config.affinity_flow, "high_match_interest_threshold", 0.8)
|
||||||
|
self.thresholds["reply"] = getattr(global_config.affinity_flow, "reply_action_interest_threshold", 0.4)
|
||||||
|
self.thresholds["non_reply"] = getattr(global_config.affinity_flow, "non_reply_action_interest_threshold", 0.2)
|
||||||
|
|
||||||
|
# 确保阈值关系合理
|
||||||
|
self.thresholds["high_match"] = max(self.thresholds["high_match"], self.thresholds["reply"] + 0.1)
|
||||||
|
self.thresholds["reply"] = max(self.thresholds["reply"], self.thresholds["non_reply"] + 0.1)
|
||||||
|
|
||||||
|
self.stats["last_threshold_update"] = time.time()
|
||||||
|
logger.info(f"加载AFC阈值: {self.thresholds}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"加载AFC阈值失败,使用默认值: {e}")
|
||||||
|
|
||||||
|
def calculate_focus_energy(self, stream_id: str, messages: List[Any], user_id: Optional[str] = None) -> float:
|
||||||
|
"""计算聊天流的focus_energy"""
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
# 更新统计
|
||||||
|
self.stats["total_calculations"] += 1
|
||||||
|
|
||||||
|
# 检查缓存
|
||||||
|
if stream_id in self.energy_cache:
|
||||||
|
cached_energy, cached_time = self.energy_cache[stream_id]
|
||||||
|
if time.time() - cached_time < self.cache_ttl:
|
||||||
|
self.stats["cache_hits"] += 1
|
||||||
|
logger.debug(f"使用缓存能量: {stream_id} = {cached_energy:.3f}")
|
||||||
|
return cached_energy
|
||||||
|
else:
|
||||||
|
self.stats["cache_misses"] += 1
|
||||||
|
|
||||||
|
# 构建计算上下文
|
||||||
|
context: EnergyContext = {
|
||||||
|
"stream_id": stream_id,
|
||||||
|
"messages": messages,
|
||||||
|
"user_id": user_id,
|
||||||
|
}
|
||||||
|
|
||||||
|
# 计算各组件能量
|
||||||
|
component_scores: Dict[str, float] = {}
|
||||||
|
total_weight = 0.0
|
||||||
|
|
||||||
|
for calculator in self.calculators:
|
||||||
|
try:
|
||||||
|
score = calculator.calculate(context)
|
||||||
|
weight = calculator.get_weight()
|
||||||
|
|
||||||
|
component_scores[calculator.__class__.__name__] = score
|
||||||
|
total_weight += weight
|
||||||
|
|
||||||
|
logger.debug(f"{calculator.__class__.__name__} 能量: {score:.3f} (权重: {weight:.3f})")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"计算 {calculator.__class__.__name__} 能量失败: {e}")
|
||||||
|
|
||||||
|
# 加权计算总能量
|
||||||
|
if total_weight > 0:
|
||||||
|
total_energy = 0.0
|
||||||
|
for calculator in self.calculators:
|
||||||
|
if calculator.__class__.__name__ in component_scores:
|
||||||
|
score = component_scores[calculator.__class__.__name__]
|
||||||
|
weight = calculator.get_weight()
|
||||||
|
total_energy += score * (weight / total_weight)
|
||||||
|
else:
|
||||||
|
total_energy = 0.5
|
||||||
|
|
||||||
|
# 应用阈值调整和变换
|
||||||
|
final_energy = self._apply_threshold_adjustment(total_energy)
|
||||||
|
|
||||||
|
# 缓存结果
|
||||||
|
self.energy_cache[stream_id] = (final_energy, time.time())
|
||||||
|
|
||||||
|
# 清理过期缓存
|
||||||
|
self._cleanup_cache()
|
||||||
|
|
||||||
|
# 更新平均计算时间
|
||||||
|
calculation_time = time.time() - start_time
|
||||||
|
total_calculations = self.stats["total_calculations"]
|
||||||
|
self.stats["average_calculation_time"] = (
|
||||||
|
(self.stats["average_calculation_time"] * (total_calculations - 1) + calculation_time)
|
||||||
|
/ total_calculations
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(f"聊天流 {stream_id} 最终能量: {final_energy:.3f} (原始: {total_energy:.3f}, 耗时: {calculation_time:.3f}s)")
|
||||||
|
return final_energy
|
||||||
|
|
||||||
|
def _apply_threshold_adjustment(self, energy: float) -> float:
|
||||||
|
"""应用阈值调整和变换"""
|
||||||
|
# 获取参考阈值
|
||||||
|
high_threshold = self.thresholds["high_match"]
|
||||||
|
reply_threshold = self.thresholds["reply"]
|
||||||
|
|
||||||
|
# 计算与阈值的相对位置
|
||||||
|
if energy >= high_threshold:
|
||||||
|
# 高能量区域:指数增强
|
||||||
|
adjusted = 0.7 + (energy - 0.7) ** 0.8
|
||||||
|
elif energy >= reply_threshold:
|
||||||
|
# 中等能量区域:线性保持
|
||||||
|
adjusted = energy
|
||||||
|
else:
|
||||||
|
# 低能量区域:对数压缩
|
||||||
|
adjusted = 0.4 * (energy / 0.4) ** 1.2
|
||||||
|
|
||||||
|
# 确保在合理范围内
|
||||||
|
return max(0.1, min(1.0, adjusted))
|
||||||
|
|
||||||
|
def get_energy_level(self, energy: float) -> EnergyLevel:
|
||||||
|
"""获取能量等级"""
|
||||||
|
if energy >= EnergyLevel.VERY_HIGH.value:
|
||||||
|
return EnergyLevel.VERY_HIGH
|
||||||
|
elif energy >= EnergyLevel.HIGH.value:
|
||||||
|
return EnergyLevel.HIGH
|
||||||
|
elif energy >= EnergyLevel.NORMAL.value:
|
||||||
|
return EnergyLevel.NORMAL
|
||||||
|
elif energy >= EnergyLevel.LOW.value:
|
||||||
|
return EnergyLevel.LOW
|
||||||
|
else:
|
||||||
|
return EnergyLevel.VERY_LOW
|
||||||
|
|
||||||
|
def get_distribution_interval(self, energy: float) -> float:
|
||||||
|
"""基于能量等级获取分发周期"""
|
||||||
|
energy_level = self.get_energy_level(energy)
|
||||||
|
|
||||||
|
# 根据能量等级确定基础分发周期
|
||||||
|
if energy_level == EnergyLevel.VERY_HIGH:
|
||||||
|
base_interval = 1.0 # 1秒
|
||||||
|
elif energy_level == EnergyLevel.HIGH:
|
||||||
|
base_interval = 3.0 # 3秒
|
||||||
|
elif energy_level == EnergyLevel.NORMAL:
|
||||||
|
base_interval = 8.0 # 8秒
|
||||||
|
elif energy_level == EnergyLevel.LOW:
|
||||||
|
base_interval = 15.0 # 15秒
|
||||||
|
else:
|
||||||
|
base_interval = 30.0 # 30秒
|
||||||
|
|
||||||
|
# 添加随机扰动避免同步
|
||||||
|
import random
|
||||||
|
jitter = random.uniform(0.8, 1.2)
|
||||||
|
final_interval = base_interval * jitter
|
||||||
|
|
||||||
|
# 确保在配置范围内
|
||||||
|
min_interval = getattr(global_config.chat, "dynamic_distribution_min_interval", 1.0)
|
||||||
|
max_interval = getattr(global_config.chat, "dynamic_distribution_max_interval", 60.0)
|
||||||
|
|
||||||
|
return max(min_interval, min(max_interval, final_interval))
|
||||||
|
|
||||||
|
def invalidate_cache(self, stream_id: str) -> None:
|
||||||
|
"""失效指定流的缓存"""
|
||||||
|
if stream_id in self.energy_cache:
|
||||||
|
del self.energy_cache[stream_id]
|
||||||
|
logger.debug(f"已清除聊天流 {stream_id} 的能量缓存")
|
||||||
|
|
||||||
|
def _cleanup_cache(self) -> None:
|
||||||
|
"""清理过期缓存"""
|
||||||
|
current_time = time.time()
|
||||||
|
expired_keys = [
|
||||||
|
stream_id for stream_id, (_, timestamp) in self.energy_cache.items()
|
||||||
|
if current_time - timestamp > self.cache_ttl
|
||||||
|
]
|
||||||
|
|
||||||
|
for key in expired_keys:
|
||||||
|
del self.energy_cache[key]
|
||||||
|
|
||||||
|
if expired_keys:
|
||||||
|
logger.debug(f"清理了 {len(expired_keys)} 个过期能量缓存")
|
||||||
|
|
||||||
|
def get_statistics(self) -> Dict[str, Any]:
|
||||||
|
"""获取统计信息"""
|
||||||
|
return {
|
||||||
|
"cache_size": len(self.energy_cache),
|
||||||
|
"calculators": [calc.__class__.__name__ for calc in self.calculators],
|
||||||
|
"thresholds": self.thresholds,
|
||||||
|
"performance_stats": self.stats.copy(),
|
||||||
|
}
|
||||||
|
|
||||||
|
def update_thresholds(self, new_thresholds: Dict[str, float]) -> None:
|
||||||
|
"""更新阈值"""
|
||||||
|
self.thresholds.update(new_thresholds)
|
||||||
|
|
||||||
|
# 确保阈值关系合理
|
||||||
|
self.thresholds["high_match"] = max(self.thresholds["high_match"], self.thresholds["reply"] + 0.1)
|
||||||
|
self.thresholds["reply"] = max(self.thresholds["reply"], self.thresholds["non_reply"] + 0.1)
|
||||||
|
|
||||||
|
self.stats["last_threshold_update"] = time.time()
|
||||||
|
logger.info(f"更新AFC阈值: {self.thresholds}")
|
||||||
|
|
||||||
|
def add_calculator(self, calculator: EnergyCalculator) -> None:
|
||||||
|
"""添加计算器"""
|
||||||
|
self.calculators.append(calculator)
|
||||||
|
logger.info(f"添加能量计算器: {calculator.__class__.__name__}")
|
||||||
|
|
||||||
|
def remove_calculator(self, calculator: EnergyCalculator) -> None:
|
||||||
|
"""移除计算器"""
|
||||||
|
if calculator in self.calculators:
|
||||||
|
self.calculators.remove(calculator)
|
||||||
|
logger.info(f"移除能量计算器: {calculator.__class__.__name__}")
|
||||||
|
|
||||||
|
def clear_cache(self) -> None:
|
||||||
|
"""清空缓存"""
|
||||||
|
self.energy_cache.clear()
|
||||||
|
logger.info("清空能量缓存")
|
||||||
|
|
||||||
|
def get_cache_hit_rate(self) -> float:
|
||||||
|
"""获取缓存命中率"""
|
||||||
|
total_requests = self.stats.get("cache_hits", 0) + self.stats.get("cache_misses", 0)
|
||||||
|
if total_requests == 0:
|
||||||
|
return 0.0
|
||||||
|
return self.stats["cache_hits"] / total_requests
|
||||||
|
|
||||||
|
|
||||||
|
# 全局能量管理器实例
|
||||||
|
energy_manager = EnergyManager()
|
||||||
@@ -14,6 +14,7 @@ Chat Frequency Analyzer
|
|||||||
- MIN_CHATS_FOR_PEAK: 在一个窗口内需要多少次聊天才能被认为是高峰时段。
|
- MIN_CHATS_FOR_PEAK: 在一个窗口内需要多少次聊天才能被认为是高峰时段。
|
||||||
- MIN_GAP_BETWEEN_PEAKS_HOURS: 两个独立高峰时段之间的最小间隔(小时)。
|
- MIN_GAP_BETWEEN_PEAKS_HOURS: 两个独立高峰时段之间的最小间隔(小时)。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import time as time_module
|
import time as time_module
|
||||||
from datetime import datetime, timedelta, time
|
from datetime import datetime, timedelta, time
|
||||||
from typing import List, Tuple, Optional
|
from typing import List, Tuple, Optional
|
||||||
@@ -72,7 +73,9 @@ class ChatFrequencyAnalyzer:
|
|||||||
current_window_end = datetimes[i]
|
current_window_end = datetimes[i]
|
||||||
|
|
||||||
# 合并重叠或相邻的高峰时段
|
# 合并重叠或相邻的高峰时段
|
||||||
if peak_windows and current_window_start - peak_windows[-1][1] < timedelta(hours=MIN_GAP_BETWEEN_PEAKS_HOURS):
|
if peak_windows and current_window_start - peak_windows[-1][1] < timedelta(
|
||||||
|
hours=MIN_GAP_BETWEEN_PEAKS_HOURS
|
||||||
|
):
|
||||||
# 扩展上一个窗口的结束时间
|
# 扩展上一个窗口的结束时间
|
||||||
peak_windows[-1] = (peak_windows[-1][0], current_window_end)
|
peak_windows[-1] = (peak_windows[-1][0], current_window_end)
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -14,15 +14,16 @@ Frequency-Based Proactive Trigger
|
|||||||
- TRIGGER_CHECK_INTERVAL_SECONDS: 触发器检查的周期(秒)。
|
- TRIGGER_CHECK_INTERVAL_SECONDS: 触发器检查的周期(秒)。
|
||||||
- COOLDOWN_HOURS: 在同一个高峰时段内触发一次后的冷却时间(小时)。
|
- COOLDOWN_HOURS: 在同一个高峰时段内触发一次后的冷却时间(小时)。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import time
|
import time
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Dict, Optional
|
from typing import Dict, Optional
|
||||||
|
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.chat.chat_loop.proactive.events import ProactiveTriggerEvent
|
# AFC manager has been moved to chatter plugin
|
||||||
from src.chat.heart_flow.heartflow import heartflow
|
|
||||||
from src.chat.chat_loop.sleep_manager.sleep_manager import SleepManager
|
# TODO: 需要重新实现主动思考和睡眠管理功能
|
||||||
from .analyzer import chat_frequency_analyzer
|
from .analyzer import chat_frequency_analyzer
|
||||||
|
|
||||||
logger = get_logger("FrequencyBasedTrigger")
|
logger = get_logger("FrequencyBasedTrigger")
|
||||||
@@ -39,8 +40,8 @@ class FrequencyBasedTrigger:
|
|||||||
一个周期性任务,根据聊天频率分析结果来触发主动思考。
|
一个周期性任务,根据聊天频率分析结果来触发主动思考。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, sleep_manager: SleepManager):
|
def __init__(self):
|
||||||
self._sleep_manager = sleep_manager
|
# TODO: 需要重新实现睡眠管理器
|
||||||
self._task: Optional[asyncio.Task] = None
|
self._task: Optional[asyncio.Task] = None
|
||||||
# 记录上次为用户触发的时间,用于冷却控制
|
# 记录上次为用户触发的时间,用于冷却控制
|
||||||
# 格式: { "chat_id": timestamp }
|
# 格式: { "chat_id": timestamp }
|
||||||
@@ -53,14 +54,16 @@ class FrequencyBasedTrigger:
|
|||||||
await asyncio.sleep(TRIGGER_CHECK_INTERVAL_SECONDS)
|
await asyncio.sleep(TRIGGER_CHECK_INTERVAL_SECONDS)
|
||||||
logger.debug("开始执行频率触发器检查...")
|
logger.debug("开始执行频率触发器检查...")
|
||||||
|
|
||||||
# 1. 检查角色是否清醒
|
# 1. TODO: 检查角色是否清醒 - 需要重新实现睡眠状态检查
|
||||||
if self._sleep_manager.is_sleeping():
|
# 暂时跳过睡眠检查
|
||||||
logger.debug("角色正在睡眠,跳过本次频率触发检查。")
|
# if self._sleep_manager.is_sleeping():
|
||||||
continue
|
# logger.debug("角色正在睡眠,跳过本次频率触发检查。")
|
||||||
|
# continue
|
||||||
|
|
||||||
# 2. 获取所有已知的聊天ID
|
# 2. 获取所有已知的聊天ID
|
||||||
# 【注意】这里我们假设所有 subheartflow 的 ID 就是 chat_id
|
# 注意:AFC管理器已移至chatter插件,此功能暂时禁用
|
||||||
all_chat_ids = list(heartflow.subheartflows.keys())
|
# all_chat_ids = list(afc_manager.affinity_flow_chatters.keys())
|
||||||
|
all_chat_ids = [] # 暂时禁用此功能
|
||||||
if not all_chat_ids:
|
if not all_chat_ids:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
@@ -74,30 +77,12 @@ class FrequencyBasedTrigger:
|
|||||||
|
|
||||||
# 4. 检查当前是否是该用户的高峰聊天时间
|
# 4. 检查当前是否是该用户的高峰聊天时间
|
||||||
if chat_frequency_analyzer.is_in_peak_time(chat_id, now):
|
if chat_frequency_analyzer.is_in_peak_time(chat_id, now):
|
||||||
|
# 5. 检查用户当前是否已有活跃的处理任务
|
||||||
sub_heartflow = await heartflow.get_or_create_subheartflow(chat_id)
|
# 注意:AFC管理器已移至chatter插件,此功能暂时禁用
|
||||||
if not sub_heartflow:
|
# chatter = afc_manager.get_or_create_chatter(chat_id)
|
||||||
logger.warning(f"无法为 {chat_id} 获取或创建 sub_heartflow。")
|
logger.info(f"检测到用户 {chat_id} 处于聊天高峰期,但AFC功能已移至chatter插件")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# 5. 检查用户当前是否已有活跃的思考或回复任务
|
|
||||||
cycle_detail = sub_heartflow.heart_fc_instance.context.current_cycle_detail
|
|
||||||
if cycle_detail and not cycle_detail.end_time:
|
|
||||||
logger.debug(f"用户 {chat_id} 的聊天循环正忙(仍在周期 {cycle_detail.cycle_id} 中),本次不触发。")
|
|
||||||
continue
|
|
||||||
|
|
||||||
logger.info(f"检测到用户 {chat_id} 处于聊天高峰期,且聊天循环空闲,准备触发主动思考。")
|
|
||||||
|
|
||||||
# 6. 直接调用 proactive_thinker
|
|
||||||
event = ProactiveTriggerEvent(
|
|
||||||
source="frequency_analyzer",
|
|
||||||
reason="User is in a high-frequency chat period."
|
|
||||||
)
|
|
||||||
await sub_heartflow.heart_fc_instance.proactive_thinker.think(event)
|
|
||||||
|
|
||||||
# 7. 更新触发时间,进入冷却
|
|
||||||
self._last_triggered[chat_id] = time.time()
|
|
||||||
|
|
||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
logger.info("频率触发器任务被取消。")
|
logger.info("频率触发器任务被取消。")
|
||||||
break
|
break
|
||||||
|
|||||||
@@ -1,40 +0,0 @@
|
|||||||
import traceback
|
|
||||||
from typing import Any, Optional, Dict
|
|
||||||
|
|
||||||
from src.common.logger import get_logger
|
|
||||||
from src.chat.heart_flow.sub_heartflow import SubHeartflow
|
|
||||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
|
||||||
|
|
||||||
logger = get_logger("heartflow")
|
|
||||||
|
|
||||||
|
|
||||||
class Heartflow:
|
|
||||||
"""主心流协调器,负责初始化并协调聊天"""
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
self.subheartflows: Dict[Any, "SubHeartflow"] = {}
|
|
||||||
|
|
||||||
async def get_or_create_subheartflow(self, subheartflow_id: Any) -> Optional["SubHeartflow"]:
|
|
||||||
"""获取或创建一个新的SubHeartflow实例"""
|
|
||||||
if subheartflow_id in self.subheartflows:
|
|
||||||
if subflow := self.subheartflows.get(subheartflow_id):
|
|
||||||
return subflow
|
|
||||||
|
|
||||||
try:
|
|
||||||
new_subflow = SubHeartflow(subheartflow_id)
|
|
||||||
|
|
||||||
await new_subflow.initialize()
|
|
||||||
|
|
||||||
# 注册子心流
|
|
||||||
self.subheartflows[subheartflow_id] = new_subflow
|
|
||||||
heartflow_name = get_chat_manager().get_stream_name(subheartflow_id) or subheartflow_id
|
|
||||||
logger.info(f"[{heartflow_name}] 开始接收消息")
|
|
||||||
|
|
||||||
return new_subflow
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"创建子心流 {subheartflow_id} 失败: {e}", exc_info=True)
|
|
||||||
traceback.print_exc()
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
heartflow = Heartflow()
|
|
||||||
@@ -1,178 +0,0 @@
|
|||||||
import asyncio
|
|
||||||
import math
|
|
||||||
import re
|
|
||||||
import traceback
|
|
||||||
from typing import Tuple, TYPE_CHECKING
|
|
||||||
|
|
||||||
from src.chat.heart_flow.heartflow import heartflow
|
|
||||||
from src.chat.memory_system.Hippocampus import hippocampus_manager
|
|
||||||
from src.chat.message_receive.message import MessageRecv
|
|
||||||
from src.chat.message_receive.storage import MessageStorage
|
|
||||||
from src.chat.utils.chat_message_builder import replace_user_references_sync
|
|
||||||
from src.chat.utils.timer_calculator import Timer
|
|
||||||
from src.chat.utils.utils import is_mentioned_bot_in_message
|
|
||||||
from src.common.logger import get_logger
|
|
||||||
from src.config.config import global_config
|
|
||||||
from src.mood.mood_manager import mood_manager
|
|
||||||
from src.person_info.relationship_manager import get_relationship_manager
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from src.chat.heart_flow.sub_heartflow import SubHeartflow
|
|
||||||
|
|
||||||
logger = get_logger("chat")
|
|
||||||
|
|
||||||
|
|
||||||
async def _process_relationship(message: MessageRecv) -> None:
|
|
||||||
"""处理用户关系逻辑
|
|
||||||
|
|
||||||
Args:
|
|
||||||
message: 消息对象,包含用户信息
|
|
||||||
"""
|
|
||||||
platform = message.message_info.platform
|
|
||||||
user_id = message.message_info.user_info.user_id # type: ignore
|
|
||||||
nickname = message.message_info.user_info.user_nickname # type: ignore
|
|
||||||
cardname = message.message_info.user_info.user_cardname or nickname # type: ignore
|
|
||||||
|
|
||||||
relationship_manager = get_relationship_manager()
|
|
||||||
is_known = await relationship_manager.is_known_some_one(platform, user_id)
|
|
||||||
|
|
||||||
if not is_known:
|
|
||||||
logger.info(f"首次认识用户: {nickname}")
|
|
||||||
await relationship_manager.first_knowing_some_one(platform, user_id, nickname, cardname) # type: ignore
|
|
||||||
|
|
||||||
|
|
||||||
async def _calculate_interest(message: MessageRecv) -> Tuple[float, bool, list[str]]:
|
|
||||||
"""计算消息的兴趣度
|
|
||||||
|
|
||||||
Args:
|
|
||||||
message: 待处理的消息对象
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Tuple[float, bool, list[str]]: (兴趣度, 是否被提及, 关键词)
|
|
||||||
"""
|
|
||||||
is_mentioned, _ = is_mentioned_bot_in_message(message)
|
|
||||||
interested_rate = 0.0
|
|
||||||
|
|
||||||
with Timer("记忆激活"):
|
|
||||||
interested_rate, keywords = await hippocampus_manager.get_activate_from_text(
|
|
||||||
message.processed_plain_text,
|
|
||||||
max_depth=4,
|
|
||||||
fast_retrieval=False,
|
|
||||||
)
|
|
||||||
message.key_words = keywords
|
|
||||||
message.key_words_lite = keywords
|
|
||||||
logger.debug(f"记忆激活率: {interested_rate:.2f}, 关键词: {keywords}")
|
|
||||||
|
|
||||||
text_len = len(message.processed_plain_text)
|
|
||||||
# 根据文本长度分布调整兴趣度,采用分段函数实现更精确的兴趣度计算
|
|
||||||
# 基于实际分布:0-5字符(26.57%), 6-10字符(27.18%), 11-20字符(22.76%), 21-30字符(10.33%), 31+字符(13.86%)
|
|
||||||
|
|
||||||
if text_len == 0:
|
|
||||||
base_interest = 0.01 # 空消息最低兴趣度
|
|
||||||
elif text_len <= 5:
|
|
||||||
# 1-5字符:线性增长 0.01 -> 0.03
|
|
||||||
base_interest = 0.01 + (text_len - 1) * (0.03 - 0.01) / 4
|
|
||||||
elif text_len <= 10:
|
|
||||||
# 6-10字符:线性增长 0.03 -> 0.06
|
|
||||||
base_interest = 0.03 + (text_len - 5) * (0.06 - 0.03) / 5
|
|
||||||
elif text_len <= 20:
|
|
||||||
# 11-20字符:线性增长 0.06 -> 0.12
|
|
||||||
base_interest = 0.06 + (text_len - 10) * (0.12 - 0.06) / 10
|
|
||||||
elif text_len <= 30:
|
|
||||||
# 21-30字符:线性增长 0.12 -> 0.18
|
|
||||||
base_interest = 0.12 + (text_len - 20) * (0.18 - 0.12) / 10
|
|
||||||
elif text_len <= 50:
|
|
||||||
# 31-50字符:线性增长 0.18 -> 0.22
|
|
||||||
base_interest = 0.18 + (text_len - 30) * (0.22 - 0.18) / 20
|
|
||||||
elif text_len <= 100:
|
|
||||||
# 51-100字符:线性增长 0.22 -> 0.26
|
|
||||||
base_interest = 0.22 + (text_len - 50) * (0.26 - 0.22) / 50
|
|
||||||
else:
|
|
||||||
# 100+字符:对数增长 0.26 -> 0.3,增长率递减
|
|
||||||
base_interest = 0.26 + (0.3 - 0.26) * (math.log10(text_len - 99) / math.log10(901)) # 1000-99=901
|
|
||||||
|
|
||||||
# 确保在范围内
|
|
||||||
base_interest = min(max(base_interest, 0.01), 0.3)
|
|
||||||
|
|
||||||
interested_rate += base_interest
|
|
||||||
|
|
||||||
if is_mentioned:
|
|
||||||
interest_increase_on_mention = 1
|
|
||||||
interested_rate += interest_increase_on_mention
|
|
||||||
|
|
||||||
return interested_rate, is_mentioned, keywords
|
|
||||||
|
|
||||||
|
|
||||||
class HeartFCMessageReceiver:
|
|
||||||
"""心流处理器,负责处理接收到的消息并计算兴趣度"""
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
"""初始化心流处理器,创建消息存储实例"""
|
|
||||||
self.storage = MessageStorage()
|
|
||||||
|
|
||||||
async def process_message(self, message: MessageRecv) -> None:
|
|
||||||
"""处理接收到的原始消息数据
|
|
||||||
|
|
||||||
主要流程:
|
|
||||||
1. 消息解析与初始化
|
|
||||||
2. 消息缓冲处理
|
|
||||||
4. 过滤检查
|
|
||||||
5. 兴趣度计算
|
|
||||||
6. 关系处理
|
|
||||||
|
|
||||||
Args:
|
|
||||||
message_data: 原始消息字符串
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
# 1. 消息解析与初始化
|
|
||||||
userinfo = message.message_info.user_info
|
|
||||||
chat = message.chat_stream
|
|
||||||
|
|
||||||
# 2. 兴趣度计算与更新
|
|
||||||
interested_rate, is_mentioned, keywords = await _calculate_interest(message)
|
|
||||||
message.interest_value = interested_rate
|
|
||||||
message.is_mentioned = is_mentioned
|
|
||||||
|
|
||||||
await self.storage.store_message(message, chat)
|
|
||||||
|
|
||||||
subheartflow: SubHeartflow = await heartflow.get_or_create_subheartflow(chat.stream_id) # type: ignore
|
|
||||||
|
|
||||||
await subheartflow.heart_fc_instance.add_message(message.to_dict())
|
|
||||||
if global_config.mood.enable_mood:
|
|
||||||
chat_mood = mood_manager.get_mood_by_chat_id(subheartflow.chat_id)
|
|
||||||
asyncio.create_task(chat_mood.update_mood_by_message(message, interested_rate))
|
|
||||||
|
|
||||||
# 3. 日志记录
|
|
||||||
mes_name = chat.group_info.group_name if chat.group_info else "私聊"
|
|
||||||
# current_time = time.strftime("%H:%M:%S", time.localtime(message.message_info.time))
|
|
||||||
current_talk_frequency = global_config.chat.get_current_talk_frequency(chat.stream_id)
|
|
||||||
|
|
||||||
# 如果消息中包含图片标识,则将 [picid:...] 替换为 [图片]
|
|
||||||
picid_pattern = r"\[picid:([^\]]+)\]"
|
|
||||||
processed_plain_text = re.sub(picid_pattern, "[图片]", message.processed_plain_text)
|
|
||||||
|
|
||||||
# 应用用户引用格式替换,将回复<aaa:bbb>和@<aaa:bbb>格式转换为可读格式
|
|
||||||
processed_plain_text = replace_user_references_sync(
|
|
||||||
processed_plain_text,
|
|
||||||
message.message_info.platform, # type: ignore
|
|
||||||
replace_bot_name=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
if keywords:
|
|
||||||
logger.info(
|
|
||||||
f"[{mes_name}]{userinfo.user_nickname}:{processed_plain_text}[兴趣度:{interested_rate:.2f}][关键词:{keywords}]"
|
|
||||||
) # type: ignore
|
|
||||||
else:
|
|
||||||
logger.info(
|
|
||||||
f"[{mes_name}]{userinfo.user_nickname}:{processed_plain_text}[兴趣度:{interested_rate:.2f}]"
|
|
||||||
) # type: ignore
|
|
||||||
|
|
||||||
logger.debug(f"[{mes_name}][当前时段回复频率: {current_talk_frequency}]")
|
|
||||||
|
|
||||||
# 4. 关系处理
|
|
||||||
if global_config.relationship.enable_relationship:
|
|
||||||
await _process_relationship(message)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"消息处理失败: {e}")
|
|
||||||
print(traceback.format_exc())
|
|
||||||
@@ -1,42 +0,0 @@
|
|||||||
from rich.traceback import install
|
|
||||||
|
|
||||||
from src.common.logger import get_logger
|
|
||||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
|
||||||
from src.chat.chat_loop.heartFC_chat import HeartFChatting
|
|
||||||
from src.chat.utils.utils import get_chat_type_and_target_info
|
|
||||||
|
|
||||||
logger = get_logger("sub_heartflow")
|
|
||||||
|
|
||||||
install(extra_lines=3)
|
|
||||||
|
|
||||||
|
|
||||||
class SubHeartflow:
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
subheartflow_id,
|
|
||||||
):
|
|
||||||
"""子心流初始化函数
|
|
||||||
|
|
||||||
Args:
|
|
||||||
subheartflow_id: 子心流唯一标识符
|
|
||||||
"""
|
|
||||||
# 基础属性,两个值是一样的
|
|
||||||
self.subheartflow_id = subheartflow_id
|
|
||||||
self.chat_id = subheartflow_id
|
|
||||||
|
|
||||||
self.is_group_chat, self.chat_target_info = (None, None)
|
|
||||||
self.log_prefix = get_chat_manager().get_stream_name(self.subheartflow_id) or self.subheartflow_id
|
|
||||||
|
|
||||||
# focus模式退出冷却时间管理
|
|
||||||
self.last_focus_exit_time: float = 0 # 上次退出focus模式的时间
|
|
||||||
|
|
||||||
# 随便水群 normal_chat 和 认真水群 focus_chat 实例
|
|
||||||
# CHAT模式激活 随便水群 FOCUS模式激活 认真水群
|
|
||||||
self.heart_fc_instance: HeartFChatting = HeartFChatting(
|
|
||||||
chat_id=self.subheartflow_id,
|
|
||||||
) # 该sub_heartflow的HeartFChatting实例
|
|
||||||
|
|
||||||
async def initialize(self):
|
|
||||||
"""异步初始化方法,创建兴趣流并确定聊天类型"""
|
|
||||||
self.is_group_chat, self.chat_target_info = await get_chat_type_and_target_info(self.chat_id)
|
|
||||||
await self.heart_fc_instance.start()
|
|
||||||
15
src/chat/interest_system/__init__.py
Normal file
15
src/chat/interest_system/__init__.py
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
"""
|
||||||
|
兴趣度系统模块
|
||||||
|
提供机器人兴趣标签和智能匹配功能
|
||||||
|
"""
|
||||||
|
|
||||||
|
from .bot_interest_manager import BotInterestManager, bot_interest_manager
|
||||||
|
from src.common.data_models.bot_interest_data_model import BotInterestTag, BotPersonalityInterests, InterestMatchResult
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"BotInterestManager",
|
||||||
|
"bot_interest_manager",
|
||||||
|
"BotInterestTag",
|
||||||
|
"BotPersonalityInterests",
|
||||||
|
"InterestMatchResult",
|
||||||
|
]
|
||||||
805
src/chat/interest_system/bot_interest_manager.py
Normal file
805
src/chat/interest_system/bot_interest_manager.py
Normal file
@@ -0,0 +1,805 @@
|
|||||||
|
"""
|
||||||
|
机器人兴趣标签管理系统
|
||||||
|
基于人设生成兴趣标签,并使用embedding计算匹配度
|
||||||
|
"""
|
||||||
|
|
||||||
|
import orjson
|
||||||
|
import traceback
|
||||||
|
from typing import List, Dict, Optional, Any
|
||||||
|
from datetime import datetime
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from src.common.logger import get_logger
|
||||||
|
from src.config.config import global_config
|
||||||
|
from src.common.data_models.bot_interest_data_model import BotPersonalityInterests, BotInterestTag, InterestMatchResult
|
||||||
|
|
||||||
|
logger = get_logger("bot_interest_manager")
|
||||||
|
|
||||||
|
|
||||||
|
class BotInterestManager:
|
||||||
|
"""机器人兴趣标签管理器"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.current_interests: Optional[BotPersonalityInterests] = None
|
||||||
|
self.embedding_cache: Dict[str, List[float]] = {} # embedding缓存
|
||||||
|
self._initialized = False
|
||||||
|
|
||||||
|
# Embedding客户端配置
|
||||||
|
self.embedding_request = None
|
||||||
|
self.embedding_config = None
|
||||||
|
self.embedding_dimension = 1024 # 默认BGE-M3 embedding维度
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_initialized(self) -> bool:
|
||||||
|
"""检查兴趣系统是否已初始化"""
|
||||||
|
return self._initialized
|
||||||
|
|
||||||
|
async def initialize(self, personality_description: str, personality_id: str = "default"):
|
||||||
|
"""初始化兴趣标签系统"""
|
||||||
|
try:
|
||||||
|
logger.info("机器人兴趣系统开始初始化...")
|
||||||
|
logger.info(f"人设ID: {personality_id}, 描述长度: {len(personality_description)}")
|
||||||
|
|
||||||
|
# 初始化embedding模型
|
||||||
|
await self._initialize_embedding_model()
|
||||||
|
|
||||||
|
# 检查embedding客户端是否成功初始化
|
||||||
|
if not self.embedding_request:
|
||||||
|
raise RuntimeError("Embedding客户端初始化失败")
|
||||||
|
|
||||||
|
# 生成或加载兴趣标签
|
||||||
|
await self._load_or_generate_interests(personality_description, personality_id)
|
||||||
|
|
||||||
|
self._initialized = True
|
||||||
|
|
||||||
|
# 检查是否成功获取兴趣标签
|
||||||
|
if self.current_interests and len(self.current_interests.get_active_tags()) > 0:
|
||||||
|
active_tags_count = len(self.current_interests.get_active_tags())
|
||||||
|
logger.info("机器人兴趣系统初始化完成!")
|
||||||
|
logger.info(f"当前已激活 {active_tags_count} 个兴趣标签, Embedding缓存 {len(self.embedding_cache)} 个")
|
||||||
|
else:
|
||||||
|
raise RuntimeError("未能成功加载或生成兴趣标签")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"机器人兴趣系统初始化失败: {e}")
|
||||||
|
traceback.print_exc()
|
||||||
|
raise # 重新抛出异常,不允许降级初始化
|
||||||
|
|
||||||
|
async def _initialize_embedding_model(self):
|
||||||
|
"""初始化embedding模型"""
|
||||||
|
logger.info("🔧 正在配置embedding客户端...")
|
||||||
|
|
||||||
|
# 使用项目配置的embedding模型
|
||||||
|
from src.config.config import model_config
|
||||||
|
from src.llm_models.utils_model import LLMRequest
|
||||||
|
|
||||||
|
logger.debug("✅ 成功导入embedding相关模块")
|
||||||
|
|
||||||
|
# 检查embedding配置是否存在
|
||||||
|
if not hasattr(model_config.model_task_config, "embedding"):
|
||||||
|
raise RuntimeError("❌ 未找到embedding模型配置")
|
||||||
|
|
||||||
|
logger.info("📋 找到embedding模型配置")
|
||||||
|
self.embedding_config = model_config.model_task_config.embedding
|
||||||
|
self.embedding_dimension = 1024 # BGE-M3的维度
|
||||||
|
logger.info(f"📐 使用模型维度: {self.embedding_dimension}")
|
||||||
|
|
||||||
|
# 创建LLMRequest实例用于embedding
|
||||||
|
self.embedding_request = LLMRequest(model_set=self.embedding_config, request_type="interest_embedding")
|
||||||
|
logger.info("✅ Embedding请求客户端初始化成功")
|
||||||
|
logger.info(f"🔗 客户端类型: {type(self.embedding_request).__name__}")
|
||||||
|
|
||||||
|
# 获取第一个embedding模型的ModelInfo
|
||||||
|
if hasattr(self.embedding_config, "model_list") and self.embedding_config.model_list:
|
||||||
|
first_model_name = self.embedding_config.model_list[0]
|
||||||
|
logger.info(f"🎯 使用embedding模型: {first_model_name}")
|
||||||
|
else:
|
||||||
|
logger.warning("⚠️ 未找到embedding模型列表")
|
||||||
|
|
||||||
|
logger.info("✅ Embedding模型初始化完成")
|
||||||
|
|
||||||
|
async def _load_or_generate_interests(self, personality_description: str, personality_id: str):
|
||||||
|
"""加载或生成兴趣标签"""
|
||||||
|
logger.info(f"📚 正在为 '{personality_id}' 加载或生成兴趣标签...")
|
||||||
|
|
||||||
|
# 首先尝试从数据库加载
|
||||||
|
logger.info("尝试从数据库加载兴趣标签...")
|
||||||
|
loaded_interests = await self._load_interests_from_database(personality_id)
|
||||||
|
|
||||||
|
if loaded_interests:
|
||||||
|
self.current_interests = loaded_interests
|
||||||
|
active_count = len(loaded_interests.get_active_tags())
|
||||||
|
logger.info(f"成功从数据库加载 {active_count} 个兴趣标签 (版本: {loaded_interests.version})")
|
||||||
|
tags_info = [f" - '{tag.tag_name}' (权重: {tag.weight:.2f})" for tag in loaded_interests.get_active_tags()]
|
||||||
|
tags_str = "\n".join(tags_info)
|
||||||
|
logger.info(f"当前兴趣标签:\n{tags_str}")
|
||||||
|
else:
|
||||||
|
# 生成新的兴趣标签
|
||||||
|
logger.info("数据库中未找到兴趣标签,开始生成...")
|
||||||
|
generated_interests = await self._generate_interests_from_personality(
|
||||||
|
personality_description, personality_id
|
||||||
|
)
|
||||||
|
|
||||||
|
if generated_interests:
|
||||||
|
self.current_interests = generated_interests
|
||||||
|
active_count = len(generated_interests.get_active_tags())
|
||||||
|
logger.info(f"成功生成 {active_count} 个新兴趣标签。")
|
||||||
|
tags_info = [f" - '{tag.tag_name}' (权重: {tag.weight:.2f})" for tag in generated_interests.get_active_tags()]
|
||||||
|
tags_str = "\n".join(tags_info)
|
||||||
|
logger.info(f"当前兴趣标签:\n{tags_str}")
|
||||||
|
|
||||||
|
# 保存到数据库
|
||||||
|
logger.info("正在保存至数据库...")
|
||||||
|
await self._save_interests_to_database(generated_interests)
|
||||||
|
else:
|
||||||
|
raise RuntimeError("❌ 兴趣标签生成失败")
|
||||||
|
|
||||||
|
async def _generate_interests_from_personality(
|
||||||
|
self, personality_description: str, personality_id: str
|
||||||
|
) -> Optional[BotPersonalityInterests]:
|
||||||
|
"""根据人设生成兴趣标签"""
|
||||||
|
try:
|
||||||
|
logger.info("🎨 开始根据人设生成兴趣标签...")
|
||||||
|
logger.info(f"📝 人设长度: {len(personality_description)} 字符")
|
||||||
|
|
||||||
|
# 检查embedding客户端是否可用
|
||||||
|
if not hasattr(self, "embedding_request"):
|
||||||
|
raise RuntimeError("❌ Embedding客户端未初始化,无法生成兴趣标签")
|
||||||
|
|
||||||
|
# 构建提示词
|
||||||
|
logger.info("📝 构建LLM提示词...")
|
||||||
|
prompt = f"""
|
||||||
|
基于以下机器人人设描述,生成一套合适的兴趣标签:
|
||||||
|
|
||||||
|
人设描述:
|
||||||
|
{personality_description}
|
||||||
|
|
||||||
|
请生成一系列兴趣关键词标签,要求:
|
||||||
|
1. 标签应该符合人设特点和性格
|
||||||
|
2. 每个标签都有权重(0.1-1.0),表示对该兴趣的喜好程度
|
||||||
|
3. 生成15-25个不等的标签
|
||||||
|
4. 标签应该是具体的关键词,而不是抽象概念
|
||||||
|
|
||||||
|
请以JSON格式返回,格式如下:
|
||||||
|
{{
|
||||||
|
"interests": [
|
||||||
|
{{"name": "标签名", "weight": 0.8}},
|
||||||
|
{{"name": "标签名", "weight": 0.6}},
|
||||||
|
{{"name": "标签名", "weight": 0.9}}
|
||||||
|
]
|
||||||
|
}}
|
||||||
|
|
||||||
|
注意:
|
||||||
|
- 权重范围0.1-1.0,权重越高表示越感兴趣
|
||||||
|
- 标签要具体,如"编程"、"游戏"、"旅行"等
|
||||||
|
- 根据人设生成个性化的标签
|
||||||
|
"""
|
||||||
|
|
||||||
|
# 调用LLM生成兴趣标签
|
||||||
|
logger.info("🤖 正在调用LLM生成兴趣标签...")
|
||||||
|
response = await self._call_llm_for_interest_generation(prompt)
|
||||||
|
|
||||||
|
if not response:
|
||||||
|
raise RuntimeError("❌ LLM未返回有效响应")
|
||||||
|
|
||||||
|
logger.info("✅ LLM响应成功,开始解析兴趣标签...")
|
||||||
|
interests_data = orjson.loads(response)
|
||||||
|
|
||||||
|
bot_interests = BotPersonalityInterests(
|
||||||
|
personality_id=personality_id, personality_description=personality_description
|
||||||
|
)
|
||||||
|
|
||||||
|
# 解析生成的兴趣标签
|
||||||
|
interests_list = interests_data.get("interests", [])
|
||||||
|
logger.info(f"📋 解析到 {len(interests_list)} 个兴趣标签")
|
||||||
|
|
||||||
|
for i, tag_data in enumerate(interests_list):
|
||||||
|
tag_name = tag_data.get("name", f"标签_{i}")
|
||||||
|
weight = tag_data.get("weight", 0.5)
|
||||||
|
|
||||||
|
tag = BotInterestTag(tag_name=tag_name, weight=weight)
|
||||||
|
bot_interests.interest_tags.append(tag)
|
||||||
|
|
||||||
|
logger.debug(f" 🏷️ {tag_name} (权重: {weight:.2f})")
|
||||||
|
|
||||||
|
# 为所有标签生成embedding
|
||||||
|
logger.info("🧠 开始为兴趣标签生成embedding向量...")
|
||||||
|
await self._generate_embeddings_for_tags(bot_interests)
|
||||||
|
|
||||||
|
logger.info("✅ 兴趣标签生成完成")
|
||||||
|
return bot_interests
|
||||||
|
|
||||||
|
except orjson.JSONDecodeError as e:
|
||||||
|
logger.error(f"❌ 解析LLM响应JSON失败: {e}")
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"❌ 根据人设生成兴趣标签失败: {e}")
|
||||||
|
traceback.print_exc()
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def _call_llm_for_interest_generation(self, prompt: str) -> Optional[str]:
|
||||||
|
"""调用LLM生成兴趣标签"""
|
||||||
|
try:
|
||||||
|
logger.info("🔧 配置LLM客户端...")
|
||||||
|
|
||||||
|
# 使用llm_api来处理请求
|
||||||
|
from src.plugin_system.apis import llm_api
|
||||||
|
from src.config.config import model_config
|
||||||
|
|
||||||
|
# 构建完整的提示词,明确要求只返回纯JSON
|
||||||
|
full_prompt = f"""你是一个专业的机器人人设分析师,擅长根据人设描述生成合适的兴趣标签。
|
||||||
|
|
||||||
|
{prompt}
|
||||||
|
|
||||||
|
请确保返回格式为有效的JSON,不要包含任何额外的文本、解释或代码块标记。只返回JSON对象本身。"""
|
||||||
|
|
||||||
|
# 使用replyer模型配置
|
||||||
|
replyer_config = model_config.model_task_config.replyer
|
||||||
|
|
||||||
|
# 调用LLM API
|
||||||
|
logger.info("🚀 正在通过LLM API发送请求...")
|
||||||
|
success, response, reasoning_content, model_name = await llm_api.generate_with_model(
|
||||||
|
prompt=full_prompt,
|
||||||
|
model_config=replyer_config,
|
||||||
|
request_type="interest_generation",
|
||||||
|
temperature=0.7,
|
||||||
|
max_tokens=2000,
|
||||||
|
)
|
||||||
|
|
||||||
|
if success and response:
|
||||||
|
logger.info(f"✅ LLM响应成功,模型: {model_name}, 响应长度: {len(response)} 字符")
|
||||||
|
logger.debug(
|
||||||
|
f"📄 LLM响应内容: {response[:200]}..." if len(response) > 200 else f"📄 LLM响应内容: {response}"
|
||||||
|
)
|
||||||
|
if reasoning_content:
|
||||||
|
logger.debug(f"🧠 推理内容: {reasoning_content[:100]}...")
|
||||||
|
|
||||||
|
# 清理响应内容,移除可能的代码块标记
|
||||||
|
cleaned_response = self._clean_llm_response(response)
|
||||||
|
return cleaned_response
|
||||||
|
else:
|
||||||
|
logger.warning("⚠️ LLM返回空响应或调用失败")
|
||||||
|
return None
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"❌ 调用LLM生成兴趣标签失败: {e}")
|
||||||
|
logger.error("🔍 错误详情:")
|
||||||
|
traceback.print_exc()
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _clean_llm_response(self, response: str) -> str:
|
||||||
|
"""清理LLM响应,移除代码块标记和其他非JSON内容"""
|
||||||
|
import re
|
||||||
|
|
||||||
|
# 移除 ```json 和 ``` 标记
|
||||||
|
cleaned = re.sub(r"```json\s*", "", response)
|
||||||
|
cleaned = re.sub(r"\s*```", "", cleaned)
|
||||||
|
|
||||||
|
# 移除可能的多余空格和换行
|
||||||
|
cleaned = cleaned.strip()
|
||||||
|
|
||||||
|
# 尝试提取JSON对象(如果响应中有其他文本)
|
||||||
|
json_match = re.search(r"\{.*\}", cleaned, re.DOTALL)
|
||||||
|
if json_match:
|
||||||
|
cleaned = json_match.group(0)
|
||||||
|
|
||||||
|
logger.debug(f"🧹 清理后的响应: {cleaned[:200]}..." if len(cleaned) > 200 else f"🧹 清理后的响应: {cleaned}")
|
||||||
|
return cleaned
|
||||||
|
|
||||||
|
async def _generate_embeddings_for_tags(self, interests: BotPersonalityInterests):
|
||||||
|
"""为所有兴趣标签生成embedding"""
|
||||||
|
if not hasattr(self, "embedding_request"):
|
||||||
|
raise RuntimeError("❌ Embedding客户端未初始化,无法生成embedding")
|
||||||
|
|
||||||
|
total_tags = len(interests.interest_tags)
|
||||||
|
logger.info(f"🧠 开始为 {total_tags} 个兴趣标签生成embedding向量...")
|
||||||
|
|
||||||
|
cached_count = 0
|
||||||
|
generated_count = 0
|
||||||
|
failed_count = 0
|
||||||
|
|
||||||
|
for i, tag in enumerate(interests.interest_tags, 1):
|
||||||
|
if tag.tag_name in self.embedding_cache:
|
||||||
|
# 使用缓存的embedding
|
||||||
|
tag.embedding = self.embedding_cache[tag.tag_name]
|
||||||
|
cached_count += 1
|
||||||
|
logger.debug(f" [{i}/{total_tags}] 🏷️ '{tag.tag_name}' - 使用缓存")
|
||||||
|
else:
|
||||||
|
# 生成新的embedding
|
||||||
|
embedding_text = tag.tag_name
|
||||||
|
|
||||||
|
logger.debug(f" [{i}/{total_tags}] 🔄 正在为 '{tag.tag_name}' 生成embedding...")
|
||||||
|
embedding = await self._get_embedding(embedding_text)
|
||||||
|
|
||||||
|
if embedding:
|
||||||
|
tag.embedding = embedding
|
||||||
|
self.embedding_cache[tag.tag_name] = embedding
|
||||||
|
generated_count += 1
|
||||||
|
logger.debug(f" ✅ '{tag.tag_name}' embedding生成成功")
|
||||||
|
else:
|
||||||
|
failed_count += 1
|
||||||
|
logger.warning(f" ❌ '{tag.tag_name}' embedding生成失败")
|
||||||
|
|
||||||
|
if failed_count > 0:
|
||||||
|
raise RuntimeError(f"❌ 有 {failed_count} 个兴趣标签embedding生成失败")
|
||||||
|
|
||||||
|
interests.last_updated = datetime.now()
|
||||||
|
logger.info("=" * 50)
|
||||||
|
logger.info("✅ Embedding生成完成!")
|
||||||
|
logger.info(f"📊 总标签数: {total_tags}")
|
||||||
|
logger.info(f"💾 缓存命中: {cached_count}")
|
||||||
|
logger.info(f"🆕 新生成: {generated_count}")
|
||||||
|
logger.info(f"❌ 失败: {failed_count}")
|
||||||
|
logger.info(f"🗃️ 总缓存大小: {len(self.embedding_cache)}")
|
||||||
|
logger.info("=" * 50)
|
||||||
|
|
||||||
|
async def _get_embedding(self, text: str) -> List[float]:
|
||||||
|
"""获取文本的embedding向量"""
|
||||||
|
if not hasattr(self, "embedding_request"):
|
||||||
|
raise RuntimeError("❌ Embedding请求客户端未初始化")
|
||||||
|
|
||||||
|
# 检查缓存
|
||||||
|
if text in self.embedding_cache:
|
||||||
|
logger.debug(f"💾 使用缓存的embedding: '{text[:30]}...'")
|
||||||
|
return self.embedding_cache[text]
|
||||||
|
|
||||||
|
# 使用LLMRequest获取embedding
|
||||||
|
logger.debug(f"🔄 正在获取embedding: '{text[:30]}...'")
|
||||||
|
embedding, model_name = await self.embedding_request.get_embedding(text)
|
||||||
|
|
||||||
|
if embedding and len(embedding) > 0:
|
||||||
|
self.embedding_cache[text] = embedding
|
||||||
|
logger.debug(f"✅ Embedding获取成功,维度: {len(embedding)}, 模型: {model_name}")
|
||||||
|
return embedding
|
||||||
|
else:
|
||||||
|
raise RuntimeError(f"❌ 返回的embedding为空: {embedding}")
|
||||||
|
|
||||||
|
async def _generate_message_embedding(self, message_text: str, keywords: List[str]) -> List[float]:
|
||||||
|
"""为消息生成embedding向量"""
|
||||||
|
# 组合消息文本和关键词作为embedding输入
|
||||||
|
if keywords:
|
||||||
|
combined_text = f"{message_text} {' '.join(keywords)}"
|
||||||
|
else:
|
||||||
|
combined_text = message_text
|
||||||
|
|
||||||
|
logger.debug(f"🔄 正在为消息生成embedding,输入长度: {len(combined_text)}")
|
||||||
|
|
||||||
|
# 生成embedding
|
||||||
|
embedding = await self._get_embedding(combined_text)
|
||||||
|
logger.debug(f"✅ 消息embedding生成成功,维度: {len(embedding)}")
|
||||||
|
return embedding
|
||||||
|
|
||||||
|
async def _calculate_similarity_scores(
|
||||||
|
self, result: InterestMatchResult, message_embedding: List[float], keywords: List[str]
|
||||||
|
):
|
||||||
|
"""计算消息与兴趣标签的相似度分数"""
|
||||||
|
try:
|
||||||
|
if not self.current_interests:
|
||||||
|
return
|
||||||
|
|
||||||
|
active_tags = self.current_interests.get_active_tags()
|
||||||
|
if not active_tags:
|
||||||
|
return
|
||||||
|
|
||||||
|
logger.debug(f"🔍 开始计算与 {len(active_tags)} 个兴趣标签的相似度")
|
||||||
|
|
||||||
|
for tag in active_tags:
|
||||||
|
if tag.embedding:
|
||||||
|
# 计算余弦相似度
|
||||||
|
similarity = self._calculate_cosine_similarity(message_embedding, tag.embedding)
|
||||||
|
weighted_score = similarity * tag.weight
|
||||||
|
|
||||||
|
# 设置相似度阈值为0.3
|
||||||
|
if similarity > 0.3:
|
||||||
|
result.add_match(tag.tag_name, weighted_score, keywords)
|
||||||
|
logger.debug(
|
||||||
|
f" 🏷️ '{tag.tag_name}': 相似度={similarity:.3f}, 权重={tag.weight:.2f}, 加权分数={weighted_score:.3f}"
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"❌ 计算相似度分数失败: {e}")
|
||||||
|
|
||||||
|
async def calculate_interest_match(self, message_text: str, keywords: List[str] = None) -> InterestMatchResult:
|
||||||
|
"""计算消息与机器人兴趣的匹配度"""
|
||||||
|
if not self.current_interests or not self._initialized:
|
||||||
|
raise RuntimeError("❌ 兴趣标签系统未初始化")
|
||||||
|
|
||||||
|
logger.debug(f"开始计算兴趣匹配度: 消息长度={len(message_text)}, 关键词数={len(keywords) if keywords else 0}")
|
||||||
|
|
||||||
|
message_id = f"msg_{datetime.now().timestamp()}"
|
||||||
|
result = InterestMatchResult(message_id=message_id)
|
||||||
|
|
||||||
|
# 获取活跃的兴趣标签
|
||||||
|
active_tags = self.current_interests.get_active_tags()
|
||||||
|
if not active_tags:
|
||||||
|
raise RuntimeError("没有检测到活跃的兴趣标签")
|
||||||
|
|
||||||
|
logger.debug(f"正在与 {len(active_tags)} 个兴趣标签进行匹配...")
|
||||||
|
|
||||||
|
# 生成消息的embedding
|
||||||
|
logger.debug("正在生成消息 embedding...")
|
||||||
|
message_embedding = await self._get_embedding(message_text)
|
||||||
|
logger.debug(f"消息 embedding 生成成功, 维度: {len(message_embedding)}")
|
||||||
|
|
||||||
|
# 计算与每个兴趣标签的相似度
|
||||||
|
match_count = 0
|
||||||
|
high_similarity_count = 0
|
||||||
|
medium_similarity_count = 0
|
||||||
|
low_similarity_count = 0
|
||||||
|
|
||||||
|
# 分级相似度阈值
|
||||||
|
affinity_config = global_config.affinity_flow
|
||||||
|
high_threshold = affinity_config.high_match_interest_threshold
|
||||||
|
medium_threshold = affinity_config.medium_match_interest_threshold
|
||||||
|
low_threshold = affinity_config.low_match_interest_threshold
|
||||||
|
|
||||||
|
logger.debug(f"🔍 使用分级相似度阈值: 高={high_threshold}, 中={medium_threshold}, 低={low_threshold}")
|
||||||
|
|
||||||
|
for tag in active_tags:
|
||||||
|
if tag.embedding:
|
||||||
|
similarity = self._calculate_cosine_similarity(message_embedding, tag.embedding)
|
||||||
|
|
||||||
|
# 基础加权分数
|
||||||
|
weighted_score = similarity * tag.weight
|
||||||
|
|
||||||
|
# 根据相似度等级应用不同的加成
|
||||||
|
if similarity > high_threshold:
|
||||||
|
# 高相似度:强加成
|
||||||
|
enhanced_score = weighted_score * affinity_config.high_match_keyword_multiplier
|
||||||
|
match_count += 1
|
||||||
|
high_similarity_count += 1
|
||||||
|
result.add_match(tag.tag_name, enhanced_score, [tag.tag_name])
|
||||||
|
|
||||||
|
elif similarity > medium_threshold:
|
||||||
|
# 中相似度:中等加成
|
||||||
|
enhanced_score = weighted_score * affinity_config.medium_match_keyword_multiplier
|
||||||
|
match_count += 1
|
||||||
|
medium_similarity_count += 1
|
||||||
|
result.add_match(tag.tag_name, enhanced_score, [tag.tag_name])
|
||||||
|
|
||||||
|
elif similarity > low_threshold:
|
||||||
|
# 低相似度:轻微加成
|
||||||
|
enhanced_score = weighted_score * affinity_config.low_match_keyword_multiplier
|
||||||
|
match_count += 1
|
||||||
|
low_similarity_count += 1
|
||||||
|
result.add_match(tag.tag_name, enhanced_score, [tag.tag_name])
|
||||||
|
|
||||||
|
logger.debug(
|
||||||
|
f"匹配统计: {match_count}/{len(active_tags)} 个标签命中 | "
|
||||||
|
f"高(>{high_threshold}): {high_similarity_count}, "
|
||||||
|
f"中(>{medium_threshold}): {medium_similarity_count}, "
|
||||||
|
f"低(>{low_threshold}): {low_similarity_count}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 添加直接关键词匹配奖励
|
||||||
|
keyword_bonus = self._calculate_keyword_match_bonus(keywords, result.matched_tags)
|
||||||
|
logger.debug(f"🎯 关键词直接匹配奖励: {keyword_bonus}")
|
||||||
|
|
||||||
|
# 应用关键词奖励到匹配分数
|
||||||
|
for tag_name in result.matched_tags:
|
||||||
|
if tag_name in keyword_bonus:
|
||||||
|
original_score = result.match_scores[tag_name]
|
||||||
|
bonus = keyword_bonus[tag_name]
|
||||||
|
result.match_scores[tag_name] = original_score + bonus
|
||||||
|
logger.debug(
|
||||||
|
f" 🏷️ '{tag_name}': 原始分数={original_score:.3f}, 奖励={bonus:.3f}, 最终分数={result.match_scores[tag_name]:.3f}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 计算总体分数
|
||||||
|
result.calculate_overall_score()
|
||||||
|
|
||||||
|
# 确定最佳匹配标签
|
||||||
|
if result.matched_tags:
|
||||||
|
top_tag_name = max(result.match_scores.items(), key=lambda x: x[1])[0]
|
||||||
|
result.top_tag = top_tag_name
|
||||||
|
logger.debug(f"最佳匹配: '{top_tag_name}' (分数: {result.match_scores[top_tag_name]:.3f})")
|
||||||
|
|
||||||
|
logger.debug(
|
||||||
|
f"最终结果: 总分={result.overall_score:.3f}, 置信度={result.confidence:.3f}, 匹配标签数={len(result.matched_tags)}"
|
||||||
|
)
|
||||||
|
return result
|
||||||
|
|
||||||
|
def _calculate_keyword_match_bonus(self, keywords: List[str], matched_tags: List[str]) -> Dict[str, float]:
|
||||||
|
"""计算关键词直接匹配奖励"""
|
||||||
|
if not keywords or not matched_tags:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
affinity_config = global_config.affinity_flow
|
||||||
|
bonus_dict = {}
|
||||||
|
|
||||||
|
for tag_name in matched_tags:
|
||||||
|
bonus = 0.0
|
||||||
|
|
||||||
|
# 检查关键词与标签的直接匹配
|
||||||
|
for keyword in keywords:
|
||||||
|
keyword_lower = keyword.lower().strip()
|
||||||
|
tag_name_lower = tag_name.lower()
|
||||||
|
|
||||||
|
# 完全匹配
|
||||||
|
if keyword_lower == tag_name_lower:
|
||||||
|
bonus += affinity_config.high_match_interest_threshold * 0.6 # 使用高匹配阈值的60%作为完全匹配奖励
|
||||||
|
logger.debug(
|
||||||
|
f" 🎯 关键词完全匹配: '{keyword}' == '{tag_name}' (+{affinity_config.high_match_interest_threshold * 0.6:.3f})"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 包含匹配
|
||||||
|
elif keyword_lower in tag_name_lower or tag_name_lower in keyword_lower:
|
||||||
|
bonus += (
|
||||||
|
affinity_config.medium_match_interest_threshold * 0.3
|
||||||
|
) # 使用中匹配阈值的30%作为包含匹配奖励
|
||||||
|
logger.debug(
|
||||||
|
f" 🎯 关键词包含匹配: '{keyword}' ⊃ '{tag_name}' (+{affinity_config.medium_match_interest_threshold * 0.3:.3f})"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 部分匹配(编辑距离)
|
||||||
|
elif self._calculate_partial_match(keyword_lower, tag_name_lower):
|
||||||
|
bonus += affinity_config.low_match_interest_threshold * 0.4 # 使用低匹配阈值的40%作为部分匹配奖励
|
||||||
|
logger.debug(
|
||||||
|
f" 🎯 关键词部分匹配: '{keyword}' ≈ '{tag_name}' (+{affinity_config.low_match_interest_threshold * 0.4:.3f})"
|
||||||
|
)
|
||||||
|
|
||||||
|
if bonus > 0:
|
||||||
|
bonus_dict[tag_name] = min(bonus, affinity_config.max_match_bonus) # 使用配置的最大奖励限制
|
||||||
|
|
||||||
|
return bonus_dict
|
||||||
|
|
||||||
|
def _calculate_partial_match(self, text1: str, text2: str) -> bool:
|
||||||
|
"""计算部分匹配(基于编辑距离)"""
|
||||||
|
try:
|
||||||
|
# 简单的编辑距离计算
|
||||||
|
max_len = max(len(text1), len(text2))
|
||||||
|
if max_len == 0:
|
||||||
|
return False
|
||||||
|
|
||||||
|
# 计算编辑距离
|
||||||
|
distance = self._levenshtein_distance(text1, text2)
|
||||||
|
|
||||||
|
# 如果编辑距离小于较短字符串长度的一半,认为是部分匹配
|
||||||
|
min_len = min(len(text1), len(text2))
|
||||||
|
return distance <= min_len // 2
|
||||||
|
|
||||||
|
except Exception:
|
||||||
|
return False
|
||||||
|
|
||||||
|
def _levenshtein_distance(self, s1: str, s2: str) -> int:
|
||||||
|
"""计算莱文斯坦距离"""
|
||||||
|
if len(s1) < len(s2):
|
||||||
|
return self._levenshtein_distance(s2, s1)
|
||||||
|
|
||||||
|
if len(s2) == 0:
|
||||||
|
return len(s1)
|
||||||
|
|
||||||
|
previous_row = range(len(s2) + 1)
|
||||||
|
for i, c1 in enumerate(s1):
|
||||||
|
current_row = [i + 1]
|
||||||
|
for j, c2 in enumerate(s2):
|
||||||
|
insertions = previous_row[j + 1] + 1
|
||||||
|
deletions = current_row[j] + 1
|
||||||
|
substitutions = previous_row[j] + (c1 != c2)
|
||||||
|
current_row.append(min(insertions, deletions, substitutions))
|
||||||
|
previous_row = current_row
|
||||||
|
|
||||||
|
return previous_row[-1]
|
||||||
|
|
||||||
|
def _calculate_cosine_similarity(self, vec1: List[float], vec2: List[float]) -> float:
|
||||||
|
"""计算余弦相似度"""
|
||||||
|
try:
|
||||||
|
vec1 = np.array(vec1)
|
||||||
|
vec2 = np.array(vec2)
|
||||||
|
|
||||||
|
dot_product = np.dot(vec1, vec2)
|
||||||
|
norm1 = np.linalg.norm(vec1)
|
||||||
|
norm2 = np.linalg.norm(vec2)
|
||||||
|
|
||||||
|
if norm1 == 0 or norm2 == 0:
|
||||||
|
return 0.0
|
||||||
|
|
||||||
|
return dot_product / (norm1 * norm2)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"计算余弦相似度失败: {e}")
|
||||||
|
return 0.0
|
||||||
|
|
||||||
|
async def _load_interests_from_database(self, personality_id: str) -> Optional[BotPersonalityInterests]:
|
||||||
|
"""从数据库加载兴趣标签"""
|
||||||
|
try:
|
||||||
|
logger.debug(f"从数据库加载兴趣标签, personality_id: {personality_id}")
|
||||||
|
|
||||||
|
# 导入SQLAlchemy相关模块
|
||||||
|
from src.common.database.sqlalchemy_models import BotPersonalityInterests as DBBotPersonalityInterests
|
||||||
|
from src.common.database.sqlalchemy_database_api import get_db_session
|
||||||
|
import orjson
|
||||||
|
|
||||||
|
with get_db_session() as session:
|
||||||
|
# 查询最新的兴趣标签配置
|
||||||
|
db_interests = (
|
||||||
|
session.query(DBBotPersonalityInterests)
|
||||||
|
.filter(DBBotPersonalityInterests.personality_id == personality_id)
|
||||||
|
.order_by(DBBotPersonalityInterests.version.desc(), DBBotPersonalityInterests.last_updated.desc())
|
||||||
|
.first()
|
||||||
|
)
|
||||||
|
|
||||||
|
if db_interests:
|
||||||
|
logger.debug(f"在数据库中找到兴趣标签配置, 版本: {db_interests.version}")
|
||||||
|
logger.debug(f"📅 最后更新时间: {db_interests.last_updated}")
|
||||||
|
logger.debug(f"🧠 使用的embedding模型: {db_interests.embedding_model}")
|
||||||
|
|
||||||
|
# 解析JSON格式的兴趣标签
|
||||||
|
try:
|
||||||
|
tags_data = orjson.loads(db_interests.interest_tags)
|
||||||
|
logger.debug(f"🏷️ 解析到 {len(tags_data)} 个兴趣标签")
|
||||||
|
|
||||||
|
# 创建BotPersonalityInterests对象
|
||||||
|
interests = BotPersonalityInterests(
|
||||||
|
personality_id=db_interests.personality_id,
|
||||||
|
personality_description=db_interests.personality_description,
|
||||||
|
embedding_model=db_interests.embedding_model,
|
||||||
|
version=db_interests.version,
|
||||||
|
last_updated=db_interests.last_updated,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 解析兴趣标签
|
||||||
|
for tag_data in tags_data:
|
||||||
|
tag = BotInterestTag(
|
||||||
|
tag_name=tag_data.get("tag_name", ""),
|
||||||
|
weight=tag_data.get("weight", 0.5),
|
||||||
|
created_at=datetime.fromisoformat(
|
||||||
|
tag_data.get("created_at", datetime.now().isoformat())
|
||||||
|
),
|
||||||
|
updated_at=datetime.fromisoformat(
|
||||||
|
tag_data.get("updated_at", datetime.now().isoformat())
|
||||||
|
),
|
||||||
|
is_active=tag_data.get("is_active", True),
|
||||||
|
embedding=tag_data.get("embedding"),
|
||||||
|
)
|
||||||
|
interests.interest_tags.append(tag)
|
||||||
|
|
||||||
|
logger.debug(f"成功解析 {len(interests.interest_tags)} 个兴趣标签")
|
||||||
|
return interests
|
||||||
|
|
||||||
|
except (orjson.JSONDecodeError, Exception) as e:
|
||||||
|
logger.error(f"❌ 解析兴趣标签JSON失败: {e}")
|
||||||
|
logger.debug(f"🔍 原始JSON数据: {db_interests.interest_tags[:200]}...")
|
||||||
|
return None
|
||||||
|
else:
|
||||||
|
logger.info(f"ℹ️ 数据库中未找到personality_id为 '{personality_id}' 的兴趣标签配置")
|
||||||
|
return None
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"❌ 从数据库加载兴趣标签失败: {e}")
|
||||||
|
logger.error("🔍 错误详情:")
|
||||||
|
traceback.print_exc()
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def _save_interests_to_database(self, interests: BotPersonalityInterests):
|
||||||
|
"""保存兴趣标签到数据库"""
|
||||||
|
try:
|
||||||
|
logger.info("💾 正在保存兴趣标签到数据库...")
|
||||||
|
logger.info(f"📋 personality_id: {interests.personality_id}")
|
||||||
|
logger.info(f"🏷️ 兴趣标签数量: {len(interests.interest_tags)}")
|
||||||
|
logger.info(f"🔄 版本: {interests.version}")
|
||||||
|
|
||||||
|
# 导入SQLAlchemy相关模块
|
||||||
|
from src.common.database.sqlalchemy_models import BotPersonalityInterests as DBBotPersonalityInterests
|
||||||
|
from src.common.database.sqlalchemy_database_api import get_db_session
|
||||||
|
import orjson
|
||||||
|
|
||||||
|
# 将兴趣标签转换为JSON格式
|
||||||
|
tags_data = []
|
||||||
|
for tag in interests.interest_tags:
|
||||||
|
tag_dict = {
|
||||||
|
"tag_name": tag.tag_name,
|
||||||
|
"weight": tag.weight,
|
||||||
|
"created_at": tag.created_at.isoformat(),
|
||||||
|
"updated_at": tag.updated_at.isoformat(),
|
||||||
|
"is_active": tag.is_active,
|
||||||
|
"embedding": tag.embedding,
|
||||||
|
}
|
||||||
|
tags_data.append(tag_dict)
|
||||||
|
|
||||||
|
# 序列化为JSON
|
||||||
|
json_data = orjson.dumps(tags_data)
|
||||||
|
|
||||||
|
with get_db_session() as session:
|
||||||
|
# 检查是否已存在相同personality_id的记录
|
||||||
|
existing_record = (
|
||||||
|
session.query(DBBotPersonalityInterests)
|
||||||
|
.filter(DBBotPersonalityInterests.personality_id == interests.personality_id)
|
||||||
|
.first()
|
||||||
|
)
|
||||||
|
|
||||||
|
if existing_record:
|
||||||
|
# 更新现有记录
|
||||||
|
logger.info("🔄 更新现有的兴趣标签配置")
|
||||||
|
existing_record.interest_tags = json_data
|
||||||
|
existing_record.personality_description = interests.personality_description
|
||||||
|
existing_record.embedding_model = interests.embedding_model
|
||||||
|
existing_record.version = interests.version
|
||||||
|
existing_record.last_updated = interests.last_updated
|
||||||
|
|
||||||
|
logger.info(f"✅ 成功更新兴趣标签配置,版本: {interests.version}")
|
||||||
|
|
||||||
|
else:
|
||||||
|
# 创建新记录
|
||||||
|
logger.info("🆕 创建新的兴趣标签配置")
|
||||||
|
new_record = DBBotPersonalityInterests(
|
||||||
|
personality_id=interests.personality_id,
|
||||||
|
personality_description=interests.personality_description,
|
||||||
|
interest_tags=json_data,
|
||||||
|
embedding_model=interests.embedding_model,
|
||||||
|
version=interests.version,
|
||||||
|
last_updated=interests.last_updated,
|
||||||
|
)
|
||||||
|
session.add(new_record)
|
||||||
|
session.commit()
|
||||||
|
logger.info(f"✅ 成功创建兴趣标签配置,版本: {interests.version}")
|
||||||
|
|
||||||
|
logger.info("✅ 兴趣标签已成功保存到数据库")
|
||||||
|
|
||||||
|
# 验证保存是否成功
|
||||||
|
with get_db_session() as session:
|
||||||
|
saved_record = (
|
||||||
|
session.query(DBBotPersonalityInterests)
|
||||||
|
.filter(DBBotPersonalityInterests.personality_id == interests.personality_id)
|
||||||
|
.first()
|
||||||
|
)
|
||||||
|
session.commit()
|
||||||
|
if saved_record:
|
||||||
|
logger.info(f"✅ 验证成功:数据库中存在personality_id为 {interests.personality_id} 的记录")
|
||||||
|
logger.info(f" 版本: {saved_record.version}")
|
||||||
|
logger.info(f" 最后更新: {saved_record.last_updated}")
|
||||||
|
else:
|
||||||
|
logger.error(f"❌ 验证失败:数据库中未找到personality_id为 {interests.personality_id} 的记录")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"❌ 保存兴趣标签到数据库失败: {e}")
|
||||||
|
logger.error("🔍 错误详情:")
|
||||||
|
traceback.print_exc()
|
||||||
|
|
||||||
|
def get_current_interests(self) -> Optional[BotPersonalityInterests]:
|
||||||
|
"""获取当前的兴趣标签配置"""
|
||||||
|
return self.current_interests
|
||||||
|
|
||||||
|
def get_interest_stats(self) -> Dict[str, Any]:
|
||||||
|
"""获取兴趣系统统计信息"""
|
||||||
|
if not self.current_interests:
|
||||||
|
return {"initialized": False}
|
||||||
|
|
||||||
|
active_tags = self.current_interests.get_active_tags()
|
||||||
|
|
||||||
|
return {
|
||||||
|
"initialized": self._initialized,
|
||||||
|
"total_tags": len(active_tags),
|
||||||
|
"embedding_model": self.current_interests.embedding_model,
|
||||||
|
"last_updated": self.current_interests.last_updated.isoformat(),
|
||||||
|
"cache_size": len(self.embedding_cache),
|
||||||
|
}
|
||||||
|
|
||||||
|
async def update_interest_tags(self, new_personality_description: str = None):
|
||||||
|
"""更新兴趣标签"""
|
||||||
|
try:
|
||||||
|
if not self.current_interests:
|
||||||
|
logger.warning("没有当前的兴趣标签配置,无法更新")
|
||||||
|
return
|
||||||
|
|
||||||
|
if new_personality_description:
|
||||||
|
self.current_interests.personality_description = new_personality_description
|
||||||
|
|
||||||
|
# 重新生成兴趣标签
|
||||||
|
new_interests = await self._generate_interests_from_personality(
|
||||||
|
self.current_interests.personality_description, self.current_interests.personality_id
|
||||||
|
)
|
||||||
|
|
||||||
|
if new_interests:
|
||||||
|
new_interests.version = self.current_interests.version + 1
|
||||||
|
self.current_interests = new_interests
|
||||||
|
await self._save_interests_to_database(new_interests)
|
||||||
|
logger.info(f"兴趣标签已更新,版本: {new_interests.version}")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"更新兴趣标签失败: {e}")
|
||||||
|
traceback.print_exc()
|
||||||
|
|
||||||
|
|
||||||
|
# 创建全局实例(重新创建以包含新的属性)
|
||||||
|
bot_interest_manager = BotInterestManager()
|
||||||
26
src/chat/message_manager/__init__.py
Normal file
26
src/chat/message_manager/__init__.py
Normal file
@@ -0,0 +1,26 @@
|
|||||||
|
"""
|
||||||
|
消息管理器模块
|
||||||
|
提供统一的消息管理、上下文管理和分发调度功能
|
||||||
|
"""
|
||||||
|
|
||||||
|
from .message_manager import MessageManager, message_manager
|
||||||
|
from .context_manager import StreamContextManager, context_manager
|
||||||
|
from .distribution_manager import (
|
||||||
|
DistributionManager,
|
||||||
|
DistributionPriority,
|
||||||
|
DistributionTask,
|
||||||
|
StreamDistributionState,
|
||||||
|
distribution_manager
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"MessageManager",
|
||||||
|
"message_manager",
|
||||||
|
"StreamContextManager",
|
||||||
|
"context_manager",
|
||||||
|
"DistributionManager",
|
||||||
|
"DistributionPriority",
|
||||||
|
"DistributionTask",
|
||||||
|
"StreamDistributionState",
|
||||||
|
"distribution_manager"
|
||||||
|
]
|
||||||
653
src/chat/message_manager/context_manager.py
Normal file
653
src/chat/message_manager/context_manager.py
Normal file
@@ -0,0 +1,653 @@
|
|||||||
|
"""
|
||||||
|
重构后的聊天上下文管理器
|
||||||
|
提供统一、稳定的聊天上下文管理功能
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import time
|
||||||
|
from typing import Dict, List, Optional, Any, Union, Tuple
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
|
||||||
|
from src.common.data_models.message_manager_data_model import StreamContext
|
||||||
|
from src.common.logger import get_logger
|
||||||
|
from src.config.config import global_config
|
||||||
|
from src.common.data_models.database_data_model import DatabaseMessages
|
||||||
|
from src.chat.energy_system import energy_manager
|
||||||
|
from .distribution_manager import distribution_manager
|
||||||
|
|
||||||
|
logger = get_logger("context_manager")
|
||||||
|
|
||||||
|
class StreamContextManager:
|
||||||
|
"""流上下文管理器 - 统一管理所有聊天流上下文"""
|
||||||
|
|
||||||
|
def __init__(self, max_context_size: Optional[int] = None, context_ttl: Optional[int] = None):
|
||||||
|
# 上下文存储
|
||||||
|
self.stream_contexts: Dict[str, Any] = {}
|
||||||
|
self.context_metadata: Dict[str, Dict[str, Any]] = {}
|
||||||
|
|
||||||
|
# 统计信息
|
||||||
|
self.stats: Dict[str, Union[int, float, str, Dict]] = {
|
||||||
|
"total_messages": 0,
|
||||||
|
"total_streams": 0,
|
||||||
|
"active_streams": 0,
|
||||||
|
"inactive_streams": 0,
|
||||||
|
"last_activity": time.time(),
|
||||||
|
"creation_time": time.time(),
|
||||||
|
}
|
||||||
|
|
||||||
|
# 配置参数
|
||||||
|
self.max_context_size = max_context_size or getattr(global_config.chat, "max_context_size", 100)
|
||||||
|
self.context_ttl = context_ttl or getattr(global_config.chat, "context_ttl", 24 * 3600) # 24小时
|
||||||
|
self.cleanup_interval = getattr(global_config.chat, "context_cleanup_interval", 3600) # 1小时
|
||||||
|
self.auto_cleanup = getattr(global_config.chat, "auto_cleanup_contexts", True)
|
||||||
|
self.enable_validation = getattr(global_config.chat, "enable_context_validation", True)
|
||||||
|
|
||||||
|
# 清理任务
|
||||||
|
self.cleanup_task: Optional[Any] = None
|
||||||
|
self.is_running = False
|
||||||
|
|
||||||
|
logger.info(f"上下文管理器初始化完成 (最大上下文: {self.max_context_size}, TTL: {self.context_ttl}s)")
|
||||||
|
|
||||||
|
def add_stream_context(self, stream_id: str, context: Any, metadata: Optional[Dict[str, Any]] = None) -> bool:
|
||||||
|
"""添加流上下文
|
||||||
|
|
||||||
|
Args:
|
||||||
|
stream_id: 流ID
|
||||||
|
context: 上下文对象
|
||||||
|
metadata: 上下文元数据
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: 是否成功添加
|
||||||
|
"""
|
||||||
|
if stream_id in self.stream_contexts:
|
||||||
|
logger.warning(f"流上下文已存在: {stream_id}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
# 添加上下文
|
||||||
|
self.stream_contexts[stream_id] = context
|
||||||
|
|
||||||
|
# 初始化元数据
|
||||||
|
self.context_metadata[stream_id] = {
|
||||||
|
"created_time": time.time(),
|
||||||
|
"last_access_time": time.time(),
|
||||||
|
"access_count": 0,
|
||||||
|
"last_validation_time": 0.0,
|
||||||
|
"custom_metadata": metadata or {},
|
||||||
|
}
|
||||||
|
|
||||||
|
# 更新统计
|
||||||
|
self.stats["total_streams"] += 1
|
||||||
|
self.stats["active_streams"] += 1
|
||||||
|
self.stats["last_activity"] = time.time()
|
||||||
|
|
||||||
|
logger.debug(f"添加流上下文: {stream_id} (类型: {type(context).__name__})")
|
||||||
|
return True
|
||||||
|
|
||||||
|
def remove_stream_context(self, stream_id: str) -> bool:
|
||||||
|
"""移除流上下文
|
||||||
|
|
||||||
|
Args:
|
||||||
|
stream_id: 流ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: 是否成功移除
|
||||||
|
"""
|
||||||
|
if stream_id in self.stream_contexts:
|
||||||
|
context = self.stream_contexts[stream_id]
|
||||||
|
metadata = self.context_metadata.get(stream_id, {})
|
||||||
|
|
||||||
|
del self.stream_contexts[stream_id]
|
||||||
|
if stream_id in self.context_metadata:
|
||||||
|
del self.context_metadata[stream_id]
|
||||||
|
|
||||||
|
self.stats["active_streams"] = max(0, self.stats["active_streams"] - 1)
|
||||||
|
self.stats["inactive_streams"] += 1
|
||||||
|
self.stats["last_activity"] = time.time()
|
||||||
|
|
||||||
|
logger.debug(f"移除流上下文: {stream_id} (类型: {type(context).__name__})")
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
def get_stream_context(self, stream_id: str, update_access: bool = True) -> Optional[StreamContext]:
|
||||||
|
"""获取流上下文
|
||||||
|
|
||||||
|
Args:
|
||||||
|
stream_id: 流ID
|
||||||
|
update_access: 是否更新访问统计
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Optional[Any]: 上下文对象
|
||||||
|
"""
|
||||||
|
context = self.stream_contexts.get(stream_id)
|
||||||
|
if context and update_access:
|
||||||
|
# 更新访问统计
|
||||||
|
if stream_id in self.context_metadata:
|
||||||
|
metadata = self.context_metadata[stream_id]
|
||||||
|
metadata["last_access_time"] = time.time()
|
||||||
|
metadata["access_count"] = metadata.get("access_count", 0) + 1
|
||||||
|
return context
|
||||||
|
|
||||||
|
def get_context_metadata(self, stream_id: str) -> Optional[Dict[str, Any]]:
|
||||||
|
"""获取上下文元数据
|
||||||
|
|
||||||
|
Args:
|
||||||
|
stream_id: 流ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Optional[Dict[str, Any]]: 元数据
|
||||||
|
"""
|
||||||
|
return self.context_metadata.get(stream_id)
|
||||||
|
|
||||||
|
def update_context_metadata(self, stream_id: str, updates: Dict[str, Any]) -> bool:
|
||||||
|
"""更新上下文元数据
|
||||||
|
|
||||||
|
Args:
|
||||||
|
stream_id: 流ID
|
||||||
|
updates: 更新的元数据
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: 是否成功更新
|
||||||
|
"""
|
||||||
|
if stream_id not in self.context_metadata:
|
||||||
|
return False
|
||||||
|
|
||||||
|
self.context_metadata[stream_id].update(updates)
|
||||||
|
return True
|
||||||
|
|
||||||
|
def add_message_to_context(self, stream_id: str, message: DatabaseMessages, skip_energy_update: bool = False) -> bool:
|
||||||
|
"""添加消息到上下文
|
||||||
|
|
||||||
|
Args:
|
||||||
|
stream_id: 流ID
|
||||||
|
message: 消息对象
|
||||||
|
skip_energy_update: 是否跳过能量更新
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: 是否成功添加
|
||||||
|
"""
|
||||||
|
context = self.get_stream_context(stream_id)
|
||||||
|
if not context:
|
||||||
|
logger.warning(f"流上下文不存在: {stream_id}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 添加消息到上下文
|
||||||
|
context.add_message(message)
|
||||||
|
|
||||||
|
# 计算消息兴趣度
|
||||||
|
interest_value = self._calculate_message_interest(message)
|
||||||
|
message.interest_value = interest_value
|
||||||
|
|
||||||
|
# 更新统计
|
||||||
|
self.stats["total_messages"] += 1
|
||||||
|
self.stats["last_activity"] = time.time()
|
||||||
|
|
||||||
|
# 更新能量和分发
|
||||||
|
if not skip_energy_update:
|
||||||
|
self._update_stream_energy(stream_id)
|
||||||
|
distribution_manager.add_stream_message(stream_id, 1)
|
||||||
|
|
||||||
|
logger.debug(f"添加消息到上下文: {stream_id} (兴趣度: {interest_value:.3f})")
|
||||||
|
return True
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"添加消息到上下文失败 {stream_id}: {e}", exc_info=True)
|
||||||
|
return False
|
||||||
|
|
||||||
|
def update_message_in_context(self, stream_id: str, message_id: str, updates: Dict[str, Any]) -> bool:
|
||||||
|
"""更新上下文中的消息
|
||||||
|
|
||||||
|
Args:
|
||||||
|
stream_id: 流ID
|
||||||
|
message_id: 消息ID
|
||||||
|
updates: 更新的属性
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: 是否成功更新
|
||||||
|
"""
|
||||||
|
context = self.get_stream_context(stream_id)
|
||||||
|
if not context:
|
||||||
|
logger.warning(f"流上下文不存在: {stream_id}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 更新消息信息
|
||||||
|
context.update_message_info(message_id, **updates)
|
||||||
|
|
||||||
|
# 如果更新了兴趣度,重新计算能量
|
||||||
|
if "interest_value" in updates:
|
||||||
|
self._update_stream_energy(stream_id)
|
||||||
|
|
||||||
|
logger.debug(f"更新上下文消息: {stream_id}/{message_id}")
|
||||||
|
return True
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"更新上下文消息失败 {stream_id}/{message_id}: {e}", exc_info=True)
|
||||||
|
return False
|
||||||
|
|
||||||
|
def get_context_messages(self, stream_id: str, limit: Optional[int] = None, include_unread: bool = True) -> List[DatabaseMessages]:
|
||||||
|
"""获取上下文消息
|
||||||
|
|
||||||
|
Args:
|
||||||
|
stream_id: 流ID
|
||||||
|
limit: 消息数量限制
|
||||||
|
include_unread: 是否包含未读消息
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[Any]: 消息列表
|
||||||
|
"""
|
||||||
|
context = self.get_stream_context(stream_id)
|
||||||
|
if not context:
|
||||||
|
return []
|
||||||
|
|
||||||
|
try:
|
||||||
|
messages = []
|
||||||
|
if include_unread:
|
||||||
|
messages.extend(context.get_unread_messages())
|
||||||
|
|
||||||
|
if limit:
|
||||||
|
messages.extend(context.get_history_messages(limit=limit))
|
||||||
|
else:
|
||||||
|
messages.extend(context.get_history_messages())
|
||||||
|
|
||||||
|
# 按时间排序
|
||||||
|
messages.sort(key=lambda msg: getattr(msg, 'time', 0))
|
||||||
|
|
||||||
|
# 应用限制
|
||||||
|
if limit and len(messages) > limit:
|
||||||
|
messages = messages[-limit:]
|
||||||
|
|
||||||
|
return messages
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"获取上下文消息失败 {stream_id}: {e}", exc_info=True)
|
||||||
|
return []
|
||||||
|
|
||||||
|
def get_unread_messages(self, stream_id: str) -> List[DatabaseMessages]:
|
||||||
|
"""获取未读消息
|
||||||
|
|
||||||
|
Args:
|
||||||
|
stream_id: 流ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[Any]: 未读消息列表
|
||||||
|
"""
|
||||||
|
context = self.get_stream_context(stream_id)
|
||||||
|
if not context:
|
||||||
|
return []
|
||||||
|
|
||||||
|
try:
|
||||||
|
return context.get_unread_messages()
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"获取未读消息失败 {stream_id}: {e}", exc_info=True)
|
||||||
|
return []
|
||||||
|
|
||||||
|
def mark_messages_as_read(self, stream_id: str, message_ids: List[str]) -> bool:
|
||||||
|
"""标记消息为已读
|
||||||
|
|
||||||
|
Args:
|
||||||
|
stream_id: 流ID
|
||||||
|
message_ids: 消息ID列表
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: 是否成功标记
|
||||||
|
"""
|
||||||
|
context = self.get_stream_context(stream_id)
|
||||||
|
if not context:
|
||||||
|
logger.warning(f"流上下文不存在: {stream_id}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
try:
|
||||||
|
if not hasattr(context, 'mark_message_as_read'):
|
||||||
|
logger.error(f"上下文对象缺少 mark_message_as_read 方法: {stream_id}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
marked_count = 0
|
||||||
|
for message_id in message_ids:
|
||||||
|
try:
|
||||||
|
context.mark_message_as_read(message_id)
|
||||||
|
marked_count += 1
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"标记消息已读失败 {message_id}: {e}")
|
||||||
|
|
||||||
|
logger.debug(f"标记消息为已读: {stream_id} ({marked_count}/{len(message_ids)}条)")
|
||||||
|
return marked_count > 0
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"标记消息已读失败 {stream_id}: {e}", exc_info=True)
|
||||||
|
return False
|
||||||
|
|
||||||
|
def clear_context(self, stream_id: str) -> bool:
|
||||||
|
"""清空上下文
|
||||||
|
|
||||||
|
Args:
|
||||||
|
stream_id: 流ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: 是否成功清空
|
||||||
|
"""
|
||||||
|
context = self.get_stream_context(stream_id)
|
||||||
|
if not context:
|
||||||
|
logger.warning(f"流上下文不存在: {stream_id}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 清空消息
|
||||||
|
if hasattr(context, 'unread_messages'):
|
||||||
|
context.unread_messages.clear()
|
||||||
|
if hasattr(context, 'history_messages'):
|
||||||
|
context.history_messages.clear()
|
||||||
|
|
||||||
|
# 重置状态
|
||||||
|
reset_attrs = ['interruption_count', 'afc_threshold_adjustment', 'last_check_time']
|
||||||
|
for attr in reset_attrs:
|
||||||
|
if hasattr(context, attr):
|
||||||
|
if attr in ['interruption_count', 'afc_threshold_adjustment']:
|
||||||
|
setattr(context, attr, 0)
|
||||||
|
else:
|
||||||
|
setattr(context, attr, time.time())
|
||||||
|
|
||||||
|
# 重新计算能量
|
||||||
|
self._update_stream_energy(stream_id)
|
||||||
|
|
||||||
|
logger.info(f"清空上下文: {stream_id}")
|
||||||
|
return True
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"清空上下文失败 {stream_id}: {e}", exc_info=True)
|
||||||
|
return False
|
||||||
|
|
||||||
|
def _calculate_message_interest(self, message: DatabaseMessages) -> float:
|
||||||
|
"""计算消息兴趣度"""
|
||||||
|
try:
|
||||||
|
# 使用插件内部的兴趣度评分系统
|
||||||
|
try:
|
||||||
|
from src.plugins.built_in.affinity_flow_chatter.interest_scoring import chatter_interest_scoring_system
|
||||||
|
|
||||||
|
# 使用插件内部的兴趣度评分系统计算(同步方式)
|
||||||
|
try:
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
except RuntimeError:
|
||||||
|
loop = asyncio.new_event_loop()
|
||||||
|
asyncio.set_event_loop(loop)
|
||||||
|
|
||||||
|
interest_score = loop.run_until_complete(
|
||||||
|
chatter_interest_scoring_system._calculate_single_message_score(
|
||||||
|
message=message,
|
||||||
|
bot_nickname=global_config.bot.nickname
|
||||||
|
)
|
||||||
|
)
|
||||||
|
interest_value = interest_score.total_score
|
||||||
|
|
||||||
|
logger.debug(f"使用插件内部系统计算兴趣度: {interest_value:.3f}")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"插件内部兴趣度计算失败,使用默认值: {e}")
|
||||||
|
interest_value = 0.5 # 默认中等兴趣度
|
||||||
|
|
||||||
|
return interest_value
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"计算消息兴趣度失败: {e}")
|
||||||
|
return 0.5
|
||||||
|
|
||||||
|
def _update_stream_energy(self, stream_id: str):
|
||||||
|
"""更新流能量"""
|
||||||
|
try:
|
||||||
|
# 获取所有消息
|
||||||
|
all_messages = self.get_context_messages(stream_id, self.max_context_size)
|
||||||
|
unread_messages = self.get_unread_messages(stream_id)
|
||||||
|
combined_messages = all_messages + unread_messages
|
||||||
|
|
||||||
|
# 获取用户ID
|
||||||
|
user_id = None
|
||||||
|
if combined_messages:
|
||||||
|
last_message = combined_messages[-1]
|
||||||
|
user_id = last_message.user_info.user_id
|
||||||
|
|
||||||
|
# 计算能量
|
||||||
|
energy = energy_manager.calculate_focus_energy(
|
||||||
|
stream_id=stream_id,
|
||||||
|
messages=combined_messages,
|
||||||
|
user_id=user_id
|
||||||
|
)
|
||||||
|
|
||||||
|
# 更新分发管理器
|
||||||
|
distribution_manager.update_stream_energy(stream_id, energy)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"更新流能量失败 {stream_id}: {e}")
|
||||||
|
|
||||||
|
def get_stream_statistics(self, stream_id: str) -> Optional[Dict[str, Any]]:
|
||||||
|
"""获取流统计信息
|
||||||
|
|
||||||
|
Args:
|
||||||
|
stream_id: 流ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Optional[Dict[str, Any]]: 统计信息
|
||||||
|
"""
|
||||||
|
context = self.get_stream_context(stream_id, update_access=False)
|
||||||
|
if not context:
|
||||||
|
return None
|
||||||
|
|
||||||
|
try:
|
||||||
|
metadata = self.context_metadata.get(stream_id, {})
|
||||||
|
current_time = time.time()
|
||||||
|
created_time = metadata.get("created_time", current_time)
|
||||||
|
last_access_time = metadata.get("last_access_time", current_time)
|
||||||
|
access_count = metadata.get("access_count", 0)
|
||||||
|
|
||||||
|
unread_messages = getattr(context, "unread_messages", [])
|
||||||
|
history_messages = getattr(context, "history_messages", [])
|
||||||
|
|
||||||
|
return {
|
||||||
|
"stream_id": stream_id,
|
||||||
|
"context_type": type(context).__name__,
|
||||||
|
"total_messages": len(history_messages) + len(unread_messages),
|
||||||
|
"unread_messages": len(unread_messages),
|
||||||
|
"history_messages": len(history_messages),
|
||||||
|
"is_active": getattr(context, "is_active", True),
|
||||||
|
"last_check_time": getattr(context, "last_check_time", current_time),
|
||||||
|
"interruption_count": getattr(context, "interruption_count", 0),
|
||||||
|
"afc_threshold_adjustment": getattr(context, "afc_threshold_adjustment", 0.0),
|
||||||
|
"created_time": created_time,
|
||||||
|
"last_access_time": last_access_time,
|
||||||
|
"access_count": access_count,
|
||||||
|
"uptime_seconds": current_time - created_time,
|
||||||
|
"idle_seconds": current_time - last_access_time,
|
||||||
|
}
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"获取流统计失败 {stream_id}: {e}", exc_info=True)
|
||||||
|
return None
|
||||||
|
|
||||||
|
def get_manager_statistics(self) -> Dict[str, Any]:
|
||||||
|
"""获取管理器统计信息
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict[str, Any]: 管理器统计信息
|
||||||
|
"""
|
||||||
|
current_time = time.time()
|
||||||
|
uptime = current_time - self.stats.get("creation_time", current_time)
|
||||||
|
|
||||||
|
return {
|
||||||
|
**self.stats,
|
||||||
|
"uptime_hours": uptime / 3600,
|
||||||
|
"stream_count": len(self.stream_contexts),
|
||||||
|
"metadata_count": len(self.context_metadata),
|
||||||
|
"auto_cleanup_enabled": self.auto_cleanup,
|
||||||
|
"cleanup_interval": self.cleanup_interval,
|
||||||
|
}
|
||||||
|
|
||||||
|
def cleanup_inactive_contexts(self, max_inactive_hours: int = 24) -> int:
|
||||||
|
"""清理不活跃的上下文
|
||||||
|
|
||||||
|
Args:
|
||||||
|
max_inactive_hours: 最大不活跃小时数
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
int: 清理的上下文数量
|
||||||
|
"""
|
||||||
|
current_time = time.time()
|
||||||
|
max_inactive_seconds = max_inactive_hours * 3600
|
||||||
|
|
||||||
|
inactive_streams = []
|
||||||
|
for stream_id, context in self.stream_contexts.items():
|
||||||
|
try:
|
||||||
|
# 获取最后活动时间
|
||||||
|
metadata = self.context_metadata.get(stream_id, {})
|
||||||
|
last_activity = metadata.get("last_access_time", metadata.get("created_time", 0))
|
||||||
|
context_last_activity = getattr(context, "last_check_time", 0)
|
||||||
|
actual_last_activity = max(last_activity, context_last_activity)
|
||||||
|
|
||||||
|
# 检查是否不活跃
|
||||||
|
unread_count = len(getattr(context, "unread_messages", []))
|
||||||
|
history_count = len(getattr(context, "history_messages", []))
|
||||||
|
total_messages = unread_count + history_count
|
||||||
|
|
||||||
|
if (current_time - actual_last_activity > max_inactive_seconds and
|
||||||
|
total_messages == 0):
|
||||||
|
inactive_streams.append(stream_id)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"检查上下文活跃状态失败 {stream_id}: {e}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 清理不活跃上下文
|
||||||
|
cleaned_count = 0
|
||||||
|
for stream_id in inactive_streams:
|
||||||
|
if self.remove_stream_context(stream_id):
|
||||||
|
cleaned_count += 1
|
||||||
|
|
||||||
|
if cleaned_count > 0:
|
||||||
|
logger.info(f"清理了 {cleaned_count} 个不活跃上下文")
|
||||||
|
|
||||||
|
return cleaned_count
|
||||||
|
|
||||||
|
def validate_context_integrity(self, stream_id: str) -> bool:
|
||||||
|
"""验证上下文完整性
|
||||||
|
|
||||||
|
Args:
|
||||||
|
stream_id: 流ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: 是否完整
|
||||||
|
"""
|
||||||
|
context = self.get_stream_context(stream_id)
|
||||||
|
if not context:
|
||||||
|
return False
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 检查基本属性
|
||||||
|
required_attrs = ["stream_id", "unread_messages", "history_messages"]
|
||||||
|
for attr in required_attrs:
|
||||||
|
if not hasattr(context, attr):
|
||||||
|
logger.warning(f"上下文缺少必要属性: {attr}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
# 检查消息ID唯一性
|
||||||
|
all_messages = getattr(context, "unread_messages", []) + getattr(context, "history_messages", [])
|
||||||
|
message_ids = [msg.message_id for msg in all_messages if hasattr(msg, "message_id")]
|
||||||
|
if len(message_ids) != len(set(message_ids)):
|
||||||
|
logger.warning(f"上下文中存在重复消息ID: {stream_id}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"验证上下文完整性失败 {stream_id}: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def start(self) -> None:
|
||||||
|
"""启动上下文管理器"""
|
||||||
|
if self.is_running:
|
||||||
|
logger.warning("上下文管理器已经在运行")
|
||||||
|
return
|
||||||
|
|
||||||
|
await self.start_auto_cleanup()
|
||||||
|
logger.info("上下文管理器已启动")
|
||||||
|
|
||||||
|
async def stop(self) -> None:
|
||||||
|
"""停止上下文管理器"""
|
||||||
|
if not self.is_running:
|
||||||
|
return
|
||||||
|
|
||||||
|
await self.stop_auto_cleanup()
|
||||||
|
logger.info("上下文管理器已停止")
|
||||||
|
|
||||||
|
async def start_auto_cleanup(self, interval: Optional[float] = None) -> None:
|
||||||
|
"""启动自动清理
|
||||||
|
|
||||||
|
Args:
|
||||||
|
interval: 清理间隔(秒)
|
||||||
|
"""
|
||||||
|
if not self.auto_cleanup:
|
||||||
|
logger.info("自动清理已禁用")
|
||||||
|
return
|
||||||
|
|
||||||
|
if self.is_running:
|
||||||
|
logger.warning("自动清理已在运行")
|
||||||
|
return
|
||||||
|
|
||||||
|
self.is_running = True
|
||||||
|
cleanup_interval = interval or self.cleanup_interval
|
||||||
|
logger.info(f"启动自动清理(间隔: {cleanup_interval}s)")
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
self.cleanup_task = asyncio.create_task(self._cleanup_loop(cleanup_interval))
|
||||||
|
|
||||||
|
async def stop_auto_cleanup(self) -> None:
|
||||||
|
"""停止自动清理"""
|
||||||
|
self.is_running = False
|
||||||
|
if self.cleanup_task and not self.cleanup_task.done():
|
||||||
|
self.cleanup_task.cancel()
|
||||||
|
try:
|
||||||
|
await self.cleanup_task
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
logger.info("自动清理已停止")
|
||||||
|
|
||||||
|
async def _cleanup_loop(self, interval: float) -> None:
|
||||||
|
"""清理循环
|
||||||
|
|
||||||
|
Args:
|
||||||
|
interval: 清理间隔
|
||||||
|
"""
|
||||||
|
while self.is_running:
|
||||||
|
try:
|
||||||
|
await asyncio.sleep(interval)
|
||||||
|
self.cleanup_inactive_contexts()
|
||||||
|
self._cleanup_expired_contexts()
|
||||||
|
logger.debug("自动清理完成")
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
break
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"清理循环出错: {e}", exc_info=True)
|
||||||
|
await asyncio.sleep(interval)
|
||||||
|
|
||||||
|
def _cleanup_expired_contexts(self) -> None:
|
||||||
|
"""清理过期上下文"""
|
||||||
|
current_time = time.time()
|
||||||
|
expired_contexts = []
|
||||||
|
|
||||||
|
for stream_id, metadata in self.context_metadata.items():
|
||||||
|
created_time = metadata.get("created_time", current_time)
|
||||||
|
if current_time - created_time > self.context_ttl:
|
||||||
|
expired_contexts.append(stream_id)
|
||||||
|
|
||||||
|
for stream_id in expired_contexts:
|
||||||
|
self.remove_stream_context(stream_id)
|
||||||
|
|
||||||
|
if expired_contexts:
|
||||||
|
logger.info(f"清理了 {len(expired_contexts)} 个过期上下文")
|
||||||
|
|
||||||
|
def get_active_streams(self) -> List[str]:
|
||||||
|
"""获取活跃流列表
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[str]: 活跃流ID列表
|
||||||
|
"""
|
||||||
|
return list(self.stream_contexts.keys())
|
||||||
|
|
||||||
|
|
||||||
|
# 全局上下文管理器实例
|
||||||
|
context_manager = StreamContextManager()
|
||||||
1004
src/chat/message_manager/distribution_manager.py
Normal file
1004
src/chat/message_manager/distribution_manager.py
Normal file
File diff suppressed because it is too large
Load Diff
558
src/chat/message_manager/message_manager.py
Normal file
558
src/chat/message_manager/message_manager.py
Normal file
@@ -0,0 +1,558 @@
|
|||||||
|
"""
|
||||||
|
消息管理模块
|
||||||
|
管理每个聊天流的上下文信息,包含历史记录和未读消息,定期检查并处理新消息
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import random
|
||||||
|
import time
|
||||||
|
import traceback
|
||||||
|
from typing import Dict, Optional, Any, TYPE_CHECKING
|
||||||
|
|
||||||
|
from src.common.logger import get_logger
|
||||||
|
from src.common.data_models.database_data_model import DatabaseMessages
|
||||||
|
from src.common.data_models.message_manager_data_model import StreamContext, MessageManagerStats, StreamStats
|
||||||
|
from src.chat.chatter_manager import ChatterManager
|
||||||
|
from src.chat.planner_actions.action_manager import ChatterActionManager
|
||||||
|
from src.plugin_system.base.component_types import ChatMode
|
||||||
|
from .sleep_manager.sleep_manager import SleepManager
|
||||||
|
from .sleep_manager.wakeup_manager import WakeUpManager
|
||||||
|
from src.config.config import global_config
|
||||||
|
from .context_manager import context_manager
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from src.common.data_models.message_manager_data_model import StreamContext
|
||||||
|
|
||||||
|
logger = get_logger("message_manager")
|
||||||
|
|
||||||
|
|
||||||
|
class MessageManager:
|
||||||
|
"""消息管理器"""
|
||||||
|
|
||||||
|
def __init__(self, check_interval: float = 5.0):
|
||||||
|
self.check_interval = check_interval # 检查间隔(秒)
|
||||||
|
self.is_running = False
|
||||||
|
self.manager_task: Optional[asyncio.Task] = None
|
||||||
|
|
||||||
|
# 统计信息
|
||||||
|
self.stats = MessageManagerStats()
|
||||||
|
|
||||||
|
# 初始化chatter manager
|
||||||
|
self.action_manager = ChatterActionManager()
|
||||||
|
self.chatter_manager = ChatterManager(self.action_manager)
|
||||||
|
|
||||||
|
# 初始化睡眠和唤醒管理器
|
||||||
|
self.sleep_manager = SleepManager()
|
||||||
|
self.wakeup_manager = WakeUpManager(self.sleep_manager)
|
||||||
|
|
||||||
|
# 初始化上下文管理器
|
||||||
|
self.context_manager = context_manager
|
||||||
|
|
||||||
|
async def start(self):
|
||||||
|
"""启动消息管理器"""
|
||||||
|
if self.is_running:
|
||||||
|
logger.warning("消息管理器已经在运行")
|
||||||
|
return
|
||||||
|
|
||||||
|
self.is_running = True
|
||||||
|
self.manager_task = asyncio.create_task(self._manager_loop())
|
||||||
|
await self.wakeup_manager.start()
|
||||||
|
await self.context_manager.start()
|
||||||
|
logger.info("消息管理器已启动")
|
||||||
|
|
||||||
|
async def stop(self):
|
||||||
|
"""停止消息管理器"""
|
||||||
|
if not self.is_running:
|
||||||
|
return
|
||||||
|
|
||||||
|
self.is_running = False
|
||||||
|
|
||||||
|
# 停止所有流处理任务
|
||||||
|
# 注意:context_manager 会自己清理任务
|
||||||
|
if self.manager_task and not self.manager_task.done():
|
||||||
|
self.manager_task.cancel()
|
||||||
|
|
||||||
|
await self.wakeup_manager.stop()
|
||||||
|
await self.context_manager.stop()
|
||||||
|
|
||||||
|
logger.info("消息管理器已停止")
|
||||||
|
|
||||||
|
def add_message(self, stream_id: str, message: DatabaseMessages):
|
||||||
|
"""添加消息到指定聊天流"""
|
||||||
|
# 检查流上下文是否存在,不存在则创建
|
||||||
|
context = self.context_manager.get_stream_context(stream_id)
|
||||||
|
if not context:
|
||||||
|
# 创建新的流上下文
|
||||||
|
from src.common.data_models.message_manager_data_model import StreamContext
|
||||||
|
context = StreamContext(stream_id=stream_id)
|
||||||
|
# 将创建的上下文添加到管理器
|
||||||
|
self.context_manager.add_stream_context(stream_id, context)
|
||||||
|
|
||||||
|
# 使用 context_manager 添加消息
|
||||||
|
success = self.context_manager.add_message_to_context(stream_id, message)
|
||||||
|
|
||||||
|
if success:
|
||||||
|
logger.debug(f"添加消息到聊天流 {stream_id}: {message.message_id}")
|
||||||
|
else:
|
||||||
|
logger.warning(f"添加消息到聊天流 {stream_id} 失败")
|
||||||
|
|
||||||
|
def update_message(
|
||||||
|
self,
|
||||||
|
stream_id: str,
|
||||||
|
message_id: str,
|
||||||
|
interest_value: float = None,
|
||||||
|
actions: list = None,
|
||||||
|
should_reply: bool = None,
|
||||||
|
):
|
||||||
|
"""更新消息信息"""
|
||||||
|
# 使用 context_manager 更新消息信息
|
||||||
|
context = self.context_manager.get_stream_context(stream_id)
|
||||||
|
if context:
|
||||||
|
context.update_message_info(message_id, interest_value, actions, should_reply)
|
||||||
|
|
||||||
|
def add_action(self, stream_id: str, message_id: str, action: str):
|
||||||
|
"""添加动作到消息"""
|
||||||
|
# 使用 context_manager 添加动作到消息
|
||||||
|
context = self.context_manager.get_stream_context(stream_id)
|
||||||
|
if context:
|
||||||
|
context.add_action_to_message(message_id, action)
|
||||||
|
|
||||||
|
async def _manager_loop(self):
|
||||||
|
"""管理器主循环 - 独立聊天流分发周期版本"""
|
||||||
|
while self.is_running:
|
||||||
|
try:
|
||||||
|
# 更新睡眠状态
|
||||||
|
await self.sleep_manager.update_sleep_state(self.wakeup_manager)
|
||||||
|
|
||||||
|
# 执行独立分发周期的检查
|
||||||
|
await self._check_streams_with_individual_intervals()
|
||||||
|
|
||||||
|
# 计算下次检查时间(使用最小间隔或固定间隔)
|
||||||
|
if global_config.chat.dynamic_distribution_enabled:
|
||||||
|
next_check_delay = self._calculate_next_manager_delay()
|
||||||
|
else:
|
||||||
|
next_check_delay = self.check_interval
|
||||||
|
|
||||||
|
await asyncio.sleep(next_check_delay)
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
break
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"消息管理器循环出错: {e}")
|
||||||
|
traceback.print_exc()
|
||||||
|
|
||||||
|
async def _check_all_streams(self):
|
||||||
|
"""检查所有聊天流"""
|
||||||
|
active_streams = 0
|
||||||
|
total_unread = 0
|
||||||
|
|
||||||
|
# 使用 context_manager 获取活跃的流
|
||||||
|
active_stream_ids = self.context_manager.get_active_streams()
|
||||||
|
|
||||||
|
for stream_id in active_stream_ids:
|
||||||
|
context = self.context_manager.get_stream_context(stream_id)
|
||||||
|
if not context:
|
||||||
|
continue
|
||||||
|
|
||||||
|
active_streams += 1
|
||||||
|
|
||||||
|
# 检查是否有未读消息
|
||||||
|
unread_messages = self.context_manager.get_unread_messages(stream_id)
|
||||||
|
if unread_messages:
|
||||||
|
total_unread += len(unread_messages)
|
||||||
|
|
||||||
|
# 如果没有处理任务,创建一个
|
||||||
|
if not hasattr(context, 'processing_task') or not context.processing_task or context.processing_task.done():
|
||||||
|
context.processing_task = asyncio.create_task(self._process_stream_messages(stream_id))
|
||||||
|
|
||||||
|
# 更新统计
|
||||||
|
self.stats.active_streams = active_streams
|
||||||
|
self.stats.total_unread_messages = total_unread
|
||||||
|
|
||||||
|
async def _process_stream_messages(self, stream_id: str):
|
||||||
|
"""处理指定聊天流的消息"""
|
||||||
|
context = self.context_manager.get_stream_context(stream_id)
|
||||||
|
if not context:
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 获取未读消息
|
||||||
|
unread_messages = self.context_manager.get_unread_messages(stream_id)
|
||||||
|
if not unread_messages:
|
||||||
|
return
|
||||||
|
|
||||||
|
# 检查是否需要打断现有处理
|
||||||
|
await self._check_and_handle_interruption(context, stream_id)
|
||||||
|
|
||||||
|
# --- 睡眠状态检查 ---
|
||||||
|
if self.sleep_manager.is_sleeping():
|
||||||
|
logger.info(f"Bot正在睡觉,检查聊天流 {stream_id} 是否有唤醒触发器。")
|
||||||
|
|
||||||
|
was_woken_up = False
|
||||||
|
is_private = context.is_private_chat()
|
||||||
|
|
||||||
|
for message in unread_messages:
|
||||||
|
is_mentioned = message.is_mentioned or False
|
||||||
|
if not is_mentioned and not is_private:
|
||||||
|
bot_names = [global_config.bot.nickname] + global_config.bot.alias_names
|
||||||
|
if any(name in message.processed_plain_text for name in bot_names):
|
||||||
|
is_mentioned = True
|
||||||
|
logger.debug(f"通过关键词 '{next((name for name in bot_names if name in message.processed_plain_text), '')}' 匹配将消息标记为 'is_mentioned'")
|
||||||
|
|
||||||
|
if is_private or is_mentioned:
|
||||||
|
if self.wakeup_manager.add_wakeup_value(is_private, is_mentioned, chat_id=stream_id):
|
||||||
|
was_woken_up = True
|
||||||
|
break # 一旦被吵醒,就跳出循环并处理消息
|
||||||
|
|
||||||
|
if not was_woken_up:
|
||||||
|
logger.debug(f"聊天流 {stream_id} 中没有唤醒触发器,保持消息未读状态。")
|
||||||
|
return # 退出,不处理消息
|
||||||
|
|
||||||
|
logger.info(f"Bot被聊天流 {stream_id} 中的消息吵醒,继续处理。")
|
||||||
|
elif self.sleep_manager.is_woken_up():
|
||||||
|
angry_chat_id = self.wakeup_manager.angry_chat_id
|
||||||
|
if stream_id != angry_chat_id:
|
||||||
|
logger.debug(f"Bot处于WOKEN_UP状态,但当前流 {stream_id} 不是触发唤醒的流 {angry_chat_id},跳过处理。")
|
||||||
|
return # 退出,不处理此流的消息
|
||||||
|
logger.info(f"Bot处于WOKEN_UP状态,处理触发唤醒的流 {stream_id}。")
|
||||||
|
# --- 睡眠状态检查结束 ---
|
||||||
|
|
||||||
|
logger.debug(f"开始处理聊天流 {stream_id} 的 {len(unread_messages)} 条未读消息")
|
||||||
|
|
||||||
|
# 直接使用StreamContext对象进行处理
|
||||||
|
if unread_messages:
|
||||||
|
try:
|
||||||
|
# 记录当前chat type用于调试
|
||||||
|
logger.debug(f"聊天流 {stream_id} 检测到的chat type: {context.chat_type.value}")
|
||||||
|
|
||||||
|
# 发送到chatter manager,传递StreamContext对象
|
||||||
|
results = await self.chatter_manager.process_stream_context(stream_id, context)
|
||||||
|
|
||||||
|
# 处理结果,标记消息为已读
|
||||||
|
if results.get("success", False):
|
||||||
|
self._clear_all_unread_messages(stream_id)
|
||||||
|
logger.debug(f"聊天流 {stream_id} 处理成功,清除了 {len(unread_messages)} 条未读消息")
|
||||||
|
else:
|
||||||
|
logger.warning(f"聊天流 {stream_id} 处理失败: {results.get('error_message', '未知错误')}")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"处理聊天流 {stream_id} 时发生异常,将清除所有未读消息: {e}")
|
||||||
|
# 出现异常时也清除未读消息,避免重复处理
|
||||||
|
self._clear_all_unread_messages(stream_id)
|
||||||
|
raise
|
||||||
|
|
||||||
|
logger.debug(f"聊天流 {stream_id} 消息处理完成")
|
||||||
|
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"处理聊天流 {stream_id} 消息时出错: {e}")
|
||||||
|
traceback.print_exc()
|
||||||
|
|
||||||
|
def deactivate_stream(self, stream_id: str):
|
||||||
|
"""停用聊天流"""
|
||||||
|
context = self.context_manager.get_stream_context(stream_id)
|
||||||
|
if context:
|
||||||
|
context.is_active = False
|
||||||
|
|
||||||
|
# 取消处理任务
|
||||||
|
if hasattr(context, 'processing_task') and context.processing_task and not context.processing_task.done():
|
||||||
|
context.processing_task.cancel()
|
||||||
|
|
||||||
|
logger.info(f"停用聊天流: {stream_id}")
|
||||||
|
|
||||||
|
def activate_stream(self, stream_id: str):
|
||||||
|
"""激活聊天流"""
|
||||||
|
context = self.context_manager.get_stream_context(stream_id)
|
||||||
|
if context:
|
||||||
|
context.is_active = True
|
||||||
|
logger.info(f"激活聊天流: {stream_id}")
|
||||||
|
|
||||||
|
def get_stream_stats(self, stream_id: str) -> Optional[StreamStats]:
|
||||||
|
"""获取聊天流统计"""
|
||||||
|
context = self.context_manager.get_stream_context(stream_id)
|
||||||
|
if not context:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return StreamStats(
|
||||||
|
stream_id=stream_id,
|
||||||
|
is_active=context.is_active,
|
||||||
|
unread_count=len(self.context_manager.get_unread_messages(stream_id)),
|
||||||
|
history_count=len(context.history_messages),
|
||||||
|
last_check_time=context.last_check_time,
|
||||||
|
has_active_task=bool(hasattr(context, 'processing_task') and context.processing_task and not context.processing_task.done()),
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_manager_stats(self) -> Dict[str, Any]:
|
||||||
|
"""获取管理器统计"""
|
||||||
|
return {
|
||||||
|
"total_streams": self.stats.total_streams,
|
||||||
|
"active_streams": self.stats.active_streams,
|
||||||
|
"total_unread_messages": self.stats.total_unread_messages,
|
||||||
|
"total_processed_messages": self.stats.total_processed_messages,
|
||||||
|
"uptime": self.stats.uptime,
|
||||||
|
"start_time": self.stats.start_time,
|
||||||
|
}
|
||||||
|
|
||||||
|
def cleanup_inactive_streams(self, max_inactive_hours: int = 24):
|
||||||
|
"""清理不活跃的聊天流"""
|
||||||
|
# 使用 context_manager 的自动清理功能
|
||||||
|
self.context_manager.cleanup_inactive_contexts(max_inactive_hours * 3600)
|
||||||
|
logger.info("已启动不活跃聊天流清理")
|
||||||
|
|
||||||
|
async def _check_and_handle_interruption(self, context: StreamContext, stream_id: str):
|
||||||
|
"""检查并处理消息打断"""
|
||||||
|
if not global_config.chat.interruption_enabled:
|
||||||
|
return
|
||||||
|
|
||||||
|
# 检查是否有正在进行的处理任务
|
||||||
|
if context.processing_task and not context.processing_task.done():
|
||||||
|
# 计算打断概率
|
||||||
|
interruption_probability = context.calculate_interruption_probability(
|
||||||
|
global_config.chat.interruption_max_limit, global_config.chat.interruption_probability_factor
|
||||||
|
)
|
||||||
|
|
||||||
|
# 检查是否已达到最大打断次数
|
||||||
|
if context.interruption_count >= global_config.chat.interruption_max_limit:
|
||||||
|
logger.debug(
|
||||||
|
f"聊天流 {stream_id} 已达到最大打断次数 {context.interruption_count}/{global_config.chat.interruption_max_limit},跳过打断检查"
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
# 根据概率决定是否打断
|
||||||
|
if random.random() < interruption_probability:
|
||||||
|
logger.info(f"聊天流 {stream_id} 触发消息打断,打断概率: {interruption_probability:.2f}")
|
||||||
|
|
||||||
|
# 取消现有任务
|
||||||
|
context.processing_task.cancel()
|
||||||
|
try:
|
||||||
|
await context.processing_task
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# 增加打断计数并应用afc阈值降低
|
||||||
|
context.increment_interruption_count()
|
||||||
|
context.apply_interruption_afc_reduction(global_config.chat.interruption_afc_reduction)
|
||||||
|
|
||||||
|
# 检查是否已达到最大次数
|
||||||
|
if context.interruption_count >= global_config.chat.interruption_max_limit:
|
||||||
|
logger.warning(
|
||||||
|
f"聊天流 {stream_id} 已达到最大打断次数 {context.interruption_count}/{global_config.chat.interruption_max_limit},后续消息将不再打断"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.info(
|
||||||
|
f"聊天流 {stream_id} 已打断,当前打断次数: {context.interruption_count}/{global_config.chat.interruption_max_limit}, afc阈值调整: {context.get_afc_threshold_adjustment()}"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.debug(f"聊天流 {stream_id} 未触发打断,打断概率: {interruption_probability:.2f}")
|
||||||
|
|
||||||
|
def _calculate_stream_distribution_interval(self, context: StreamContext) -> float:
|
||||||
|
"""计算单个聊天流的分发周期 - 使用重构后的能量管理器"""
|
||||||
|
if not global_config.chat.dynamic_distribution_enabled:
|
||||||
|
return self.check_interval # 使用固定间隔
|
||||||
|
|
||||||
|
try:
|
||||||
|
from src.chat.energy_system import energy_manager
|
||||||
|
from src.plugin_system.apis.chat_api import get_chat_manager
|
||||||
|
|
||||||
|
# 获取聊天流和能量
|
||||||
|
chat_stream = get_chat_manager().get_stream(context.stream_id)
|
||||||
|
if chat_stream:
|
||||||
|
focus_energy = chat_stream.focus_energy
|
||||||
|
# 使用能量管理器获取分发周期
|
||||||
|
interval = energy_manager.get_distribution_interval(focus_energy)
|
||||||
|
logger.debug(f"流 {context.stream_id} 分发周期: {interval:.2f}s (能量: {focus_energy:.3f})")
|
||||||
|
return interval
|
||||||
|
else:
|
||||||
|
# 默认间隔
|
||||||
|
return self.check_interval
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"计算分发周期失败: {e}")
|
||||||
|
return self.check_interval
|
||||||
|
|
||||||
|
def _calculate_next_manager_delay(self) -> float:
|
||||||
|
"""计算管理器下次检查的延迟时间"""
|
||||||
|
current_time = time.time()
|
||||||
|
min_delay = float("inf")
|
||||||
|
|
||||||
|
# 找到最近需要检查的流
|
||||||
|
active_stream_ids = self.context_manager.get_active_streams()
|
||||||
|
for stream_id in active_stream_ids:
|
||||||
|
context = self.context_manager.get_stream_context(stream_id)
|
||||||
|
if not context or not context.is_active:
|
||||||
|
continue
|
||||||
|
|
||||||
|
time_until_check = context.next_check_time - current_time
|
||||||
|
if time_until_check > 0:
|
||||||
|
min_delay = min(min_delay, time_until_check)
|
||||||
|
else:
|
||||||
|
min_delay = 0.1 # 立即检查
|
||||||
|
break
|
||||||
|
|
||||||
|
# 如果没有活跃流,使用默认间隔
|
||||||
|
if min_delay == float("inf"):
|
||||||
|
return self.check_interval
|
||||||
|
|
||||||
|
# 确保最小延迟
|
||||||
|
return max(0.1, min(min_delay, self.check_interval))
|
||||||
|
|
||||||
|
async def _check_streams_with_individual_intervals(self):
|
||||||
|
"""检查所有达到检查时间的聊天流"""
|
||||||
|
current_time = time.time()
|
||||||
|
processed_streams = 0
|
||||||
|
|
||||||
|
# 使用 context_manager 获取活跃的流
|
||||||
|
active_stream_ids = self.context_manager.get_active_streams()
|
||||||
|
|
||||||
|
for stream_id in active_stream_ids:
|
||||||
|
context = self.context_manager.get_stream_context(stream_id)
|
||||||
|
if not context or not context.is_active:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 检查是否达到检查时间
|
||||||
|
if current_time >= context.next_check_time:
|
||||||
|
# 更新检查时间
|
||||||
|
context.last_check_time = current_time
|
||||||
|
|
||||||
|
# 计算下次检查时间和分发周期
|
||||||
|
if global_config.chat.dynamic_distribution_enabled:
|
||||||
|
context.distribution_interval = self._calculate_stream_distribution_interval(context)
|
||||||
|
else:
|
||||||
|
context.distribution_interval = self.check_interval
|
||||||
|
|
||||||
|
# 设置下次检查时间
|
||||||
|
context.next_check_time = current_time + context.distribution_interval
|
||||||
|
|
||||||
|
# 检查未读消息
|
||||||
|
unread_messages = self.context_manager.get_unread_messages(stream_id)
|
||||||
|
if unread_messages:
|
||||||
|
processed_streams += 1
|
||||||
|
self.stats.total_unread_messages = len(unread_messages)
|
||||||
|
|
||||||
|
# 如果没有处理任务,创建一个
|
||||||
|
if not context.processing_task or context.processing_task.done():
|
||||||
|
from src.plugin_system.apis.chat_api import get_chat_manager
|
||||||
|
|
||||||
|
chat_stream = get_chat_manager().get_stream(context.stream_id)
|
||||||
|
focus_energy = chat_stream.focus_energy if chat_stream else 0.5
|
||||||
|
|
||||||
|
# 根据优先级记录日志
|
||||||
|
if focus_energy >= 0.7:
|
||||||
|
logger.info(
|
||||||
|
f"高优先级流 {stream_id} 开始处理 | "
|
||||||
|
f"focus_energy: {focus_energy:.3f} | "
|
||||||
|
f"分发周期: {context.distribution_interval:.2f}s | "
|
||||||
|
f"未读消息: {len(unread_messages)}"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.debug(
|
||||||
|
f"流 {stream_id} 开始处理 | "
|
||||||
|
f"focus_energy: {focus_energy:.3f} | "
|
||||||
|
f"分发周期: {context.distribution_interval:.2f}s"
|
||||||
|
)
|
||||||
|
|
||||||
|
context.processing_task = asyncio.create_task(self._process_stream_messages(stream_id))
|
||||||
|
|
||||||
|
# 更新活跃流计数
|
||||||
|
active_count = len(self.context_manager.get_active_streams())
|
||||||
|
self.stats.active_streams = active_count
|
||||||
|
|
||||||
|
if processed_streams > 0:
|
||||||
|
logger.debug(f"本次循环处理了 {processed_streams} 个流 | 活跃流总数: {active_count}")
|
||||||
|
|
||||||
|
async def _check_all_streams_with_priority(self):
|
||||||
|
"""按优先级检查所有聊天流,高focus_energy的流优先处理"""
|
||||||
|
if not self.context_manager.get_active_streams():
|
||||||
|
return
|
||||||
|
|
||||||
|
# 获取活跃的聊天流并按focus_energy排序
|
||||||
|
active_streams = []
|
||||||
|
active_stream_ids = self.context_manager.get_active_streams()
|
||||||
|
|
||||||
|
for stream_id in active_stream_ids:
|
||||||
|
context = self.context_manager.get_stream_context(stream_id)
|
||||||
|
if not context or not context.is_active:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 获取focus_energy,如果不存在则使用默认值
|
||||||
|
from src.plugin_system.apis.chat_api import get_chat_manager
|
||||||
|
|
||||||
|
chat_stream = get_chat_manager().get_stream(context.stream_id)
|
||||||
|
focus_energy = 0.5
|
||||||
|
if chat_stream:
|
||||||
|
focus_energy = chat_stream.focus_energy
|
||||||
|
|
||||||
|
# 计算流优先级分数
|
||||||
|
priority_score = self._calculate_stream_priority(context, focus_energy)
|
||||||
|
active_streams.append((priority_score, stream_id, context))
|
||||||
|
|
||||||
|
# 按优先级降序排序
|
||||||
|
active_streams.sort(reverse=True, key=lambda x: x[0])
|
||||||
|
|
||||||
|
# 处理排序后的流
|
||||||
|
active_stream_count = 0
|
||||||
|
total_unread = 0
|
||||||
|
|
||||||
|
for priority_score, stream_id, context in active_streams:
|
||||||
|
active_stream_count += 1
|
||||||
|
|
||||||
|
# 检查是否有未读消息
|
||||||
|
unread_messages = self.context_manager.get_unread_messages(stream_id)
|
||||||
|
if unread_messages:
|
||||||
|
total_unread += len(unread_messages)
|
||||||
|
|
||||||
|
# 如果没有处理任务,创建一个
|
||||||
|
if not hasattr(context, 'processing_task') or not context.processing_task or context.processing_task.done():
|
||||||
|
context.processing_task = asyncio.create_task(self._process_stream_messages(stream_id))
|
||||||
|
|
||||||
|
# 高优先级流的额外日志
|
||||||
|
if priority_score > 0.7:
|
||||||
|
logger.info(
|
||||||
|
f"高优先级流 {stream_id} 开始处理 | "
|
||||||
|
f"优先级: {priority_score:.3f} | "
|
||||||
|
f"未读消息: {len(unread_messages)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 更新统计
|
||||||
|
self.stats.active_streams = active_stream_count
|
||||||
|
self.stats.total_unread_messages = total_unread
|
||||||
|
|
||||||
|
def _calculate_stream_priority(self, context: StreamContext, focus_energy: float) -> float:
|
||||||
|
"""计算聊天流的优先级分数 - 简化版本,主要使用focus_energy"""
|
||||||
|
# 使用重构后的能量管理器,主要依赖focus_energy
|
||||||
|
base_priority = focus_energy
|
||||||
|
|
||||||
|
# 简单的未读消息加权
|
||||||
|
unread_count = len(context.get_unread_messages())
|
||||||
|
message_bonus = min(unread_count * 0.05, 0.2) # 最多20%加成
|
||||||
|
|
||||||
|
# 简单的时间加权
|
||||||
|
current_time = time.time()
|
||||||
|
time_since_active = current_time - context.last_check_time
|
||||||
|
time_bonus = max(0, 1.0 - time_since_active / 7200.0) * 0.1 # 2小时内衰减
|
||||||
|
|
||||||
|
final_priority = base_priority + message_bonus + time_bonus
|
||||||
|
return max(0.0, min(1.0, final_priority))
|
||||||
|
|
||||||
|
def _clear_all_unread_messages(self, stream_id: str):
|
||||||
|
"""清除指定上下文中的所有未读消息,防止意外情况导致消息一直未读"""
|
||||||
|
unread_messages = self.context_manager.get_unread_messages(stream_id)
|
||||||
|
if not unread_messages:
|
||||||
|
return
|
||||||
|
|
||||||
|
logger.warning(f"正在清除 {len(unread_messages)} 条未读消息")
|
||||||
|
|
||||||
|
# 将所有未读消息标记为已读
|
||||||
|
context = self.context_manager.get_stream_context(stream_id)
|
||||||
|
if context:
|
||||||
|
for msg in unread_messages[:]: # 使用切片复制避免迭代时修改列表
|
||||||
|
try:
|
||||||
|
context.mark_message_as_read(msg.message_id)
|
||||||
|
self.stats.total_processed_messages += 1
|
||||||
|
logger.debug(f"强制清除消息 {msg.message_id},标记为已读")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"清除消息 {msg.message_id} 时出错: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
# 创建全局消息管理器实例
|
||||||
|
message_manager = MessageManager()
|
||||||
@@ -0,0 +1,33 @@
|
|||||||
|
from src.common.logger import get_logger
|
||||||
|
|
||||||
|
#from ..hfc_context import HfcContext
|
||||||
|
|
||||||
|
logger = get_logger("notification_sender")
|
||||||
|
|
||||||
|
|
||||||
|
class NotificationSender:
|
||||||
|
@staticmethod
|
||||||
|
async def send_goodnight_notification(context): # type: ignore
|
||||||
|
"""发送晚安通知"""
|
||||||
|
#try:
|
||||||
|
#from ..proactive.events import ProactiveTriggerEvent
|
||||||
|
#from ..proactive.proactive_thinker import ProactiveThinker
|
||||||
|
|
||||||
|
#event = ProactiveTriggerEvent(source="sleep_manager", reason="goodnight")
|
||||||
|
#proactive_thinker = ProactiveThinker(context, context.chat_instance.cycle_processor)
|
||||||
|
#await proactive_thinker.think(event)
|
||||||
|
#except Exception as e:
|
||||||
|
#logger.error(f"发送晚安通知失败: {e}")
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def send_insomnia_notification(context, reason: str): # type: ignore
|
||||||
|
"""发送失眠通知"""
|
||||||
|
#try:
|
||||||
|
#from ..proactive.events import ProactiveTriggerEvent
|
||||||
|
#from ..proactive.proactive_thinker import ProactiveThinker
|
||||||
|
|
||||||
|
#event = ProactiveTriggerEvent(source="sleep_manager", reason=reason)
|
||||||
|
#proactive_thinker = ProactiveThinker(context, context.chat_instance.cycle_processor)
|
||||||
|
#await proactive_thinker.think(event)
|
||||||
|
#except Exception as e:
|
||||||
|
#logger.error(f"发送失眠通知失败: {e}")
|
||||||
@@ -6,11 +6,11 @@ from typing import Optional, TYPE_CHECKING
|
|||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.config.config import global_config
|
from src.config.config import global_config
|
||||||
from .notification_sender import NotificationSender
|
from .notification_sender import NotificationSender
|
||||||
from .sleep_state import SleepState, SleepStateSerializer
|
from .sleep_state import SleepState, SleepContext
|
||||||
from .time_checker import TimeChecker
|
from .time_checker import TimeChecker
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
pass
|
from .wakeup_manager import WakeUpManager
|
||||||
|
|
||||||
logger = get_logger("sleep_manager")
|
logger = get_logger("sleep_manager")
|
||||||
|
|
||||||
@@ -25,28 +25,23 @@ class SleepManager:
|
|||||||
"""
|
"""
|
||||||
初始化睡眠管理器。
|
初始化睡眠管理器。
|
||||||
"""
|
"""
|
||||||
self.time_checker = TimeChecker() # 时间检查器,用于判断当前是否处于理论睡眠时间
|
self.context = SleepContext() # 睡眠上下文,管理所有状态
|
||||||
|
self.time_checker = TimeChecker() # 时间检查器
|
||||||
self.last_sleep_log_time = 0 # 上次记录睡眠日志的时间戳
|
self.last_sleep_log_time = 0 # 上次记录睡眠日志的时间戳
|
||||||
self.sleep_log_interval = 35 # 睡眠日志记录间隔(秒)
|
self.sleep_log_interval = 35 # 睡眠日志记录间隔(秒)
|
||||||
|
|
||||||
# --- 统一睡眠状态管理 ---
|
|
||||||
self._current_state: SleepState = SleepState.AWAKE # 当前睡眠状态
|
|
||||||
self._sleep_buffer_end_time: Optional[datetime] = None # 睡眠缓冲结束时间,用于状态转换
|
|
||||||
self._total_delayed_minutes_today: float = 0.0 # 今天总共延迟入睡的分钟数
|
|
||||||
self._last_sleep_check_date: Optional[date] = None # 上次检查睡眠状态的日期
|
|
||||||
self._last_fully_slept_log_time: float = 0 # 上次完全进入睡眠状态的时间戳
|
self._last_fully_slept_log_time: float = 0 # 上次完全进入睡眠状态的时间戳
|
||||||
self._re_sleep_attempt_time: Optional[datetime] = None # 被吵醒后,尝试重新入睡的时间点
|
|
||||||
|
|
||||||
# 从本地存储加载上一次的睡眠状态
|
|
||||||
self._load_sleep_state()
|
|
||||||
|
|
||||||
def get_current_sleep_state(self) -> SleepState:
|
def get_current_sleep_state(self) -> SleepState:
|
||||||
"""获取当前的睡眠状态。"""
|
"""获取当前的睡眠状态。"""
|
||||||
return self._current_state
|
return self.context.current_state
|
||||||
|
|
||||||
def is_sleeping(self) -> bool:
|
def is_sleeping(self) -> bool:
|
||||||
"""判断当前是否处于正在睡觉的状态。"""
|
"""判断当前是否处于正在睡觉的状态。"""
|
||||||
return self._current_state == SleepState.SLEEPING
|
return self.context.current_state == SleepState.SLEEPING
|
||||||
|
|
||||||
|
def is_woken_up(self) -> bool:
|
||||||
|
"""判断当前是否处于被吵醒的状态。"""
|
||||||
|
return self.context.current_state == SleepState.WOKEN_UP
|
||||||
|
|
||||||
async def update_sleep_state(self, wakeup_manager: Optional["WakeUpManager"] = None):
|
async def update_sleep_state(self, wakeup_manager: Optional["WakeUpManager"] = None):
|
||||||
"""
|
"""
|
||||||
@@ -58,41 +53,42 @@ class SleepManager:
|
|||||||
"""
|
"""
|
||||||
# 如果全局禁用了睡眠系统,则强制设置为清醒状态并返回
|
# 如果全局禁用了睡眠系统,则强制设置为清醒状态并返回
|
||||||
if not global_config.sleep_system.enable:
|
if not global_config.sleep_system.enable:
|
||||||
if self._current_state != SleepState.AWAKE:
|
if self.context.current_state != SleepState.AWAKE:
|
||||||
logger.debug("睡眠系统禁用,强制设为 AWAKE")
|
logger.debug("睡眠系统禁用,强制设为 AWAKE")
|
||||||
self._current_state = SleepState.AWAKE
|
self.context.current_state = SleepState.AWAKE
|
||||||
return
|
return
|
||||||
|
|
||||||
now = datetime.now()
|
now = datetime.now()
|
||||||
today = now.date()
|
today = now.date()
|
||||||
|
|
||||||
# 跨天处理:如果日期变化,重置每日相关的睡眠状态
|
# 跨天处理:如果日期变化,重置每日相关的睡眠状态
|
||||||
if self._last_sleep_check_date != today:
|
if self.context.last_sleep_check_date != today:
|
||||||
logger.info(f"新的一天 ({today}),重置睡眠状态。")
|
logger.info(f"新的一天 ({today}),重置睡眠状态。")
|
||||||
self._total_delayed_minutes_today = 0
|
self.context.total_delayed_minutes_today = 0
|
||||||
self._current_state = SleepState.AWAKE
|
self.context.current_state = SleepState.AWAKE
|
||||||
self._sleep_buffer_end_time = None
|
self.context.sleep_buffer_end_time = None
|
||||||
self._last_sleep_check_date = today
|
self.context.last_sleep_check_date = today
|
||||||
self._save_sleep_state()
|
self.context.save()
|
||||||
|
|
||||||
# 检查当前是否处于理论上的睡眠时间段
|
# 检查当前是否处于理论上的睡眠时间段
|
||||||
is_in_theoretical_sleep, activity = self.time_checker.is_in_theoretical_sleep_time(now.time())
|
is_in_theoretical_sleep, activity = self.time_checker.is_in_theoretical_sleep_time(now.time())
|
||||||
|
|
||||||
# --- 状态机核心处理逻辑 ---
|
# --- 状态机核心处理逻辑 ---
|
||||||
if self._current_state == SleepState.AWAKE:
|
current_state = self.context.current_state
|
||||||
|
if current_state == SleepState.AWAKE:
|
||||||
if is_in_theoretical_sleep:
|
if is_in_theoretical_sleep:
|
||||||
self._handle_awake_to_sleep(now, activity, wakeup_manager)
|
self._handle_awake_to_sleep(now, activity, wakeup_manager)
|
||||||
|
|
||||||
elif self._current_state == SleepState.PREPARING_SLEEP:
|
elif current_state == SleepState.PREPARING_SLEEP:
|
||||||
self._handle_preparing_sleep(now, is_in_theoretical_sleep, wakeup_manager)
|
self._handle_preparing_sleep(now, is_in_theoretical_sleep, wakeup_manager)
|
||||||
|
|
||||||
elif self._current_state == SleepState.SLEEPING:
|
elif current_state == SleepState.SLEEPING:
|
||||||
self._handle_sleeping(now, is_in_theoretical_sleep, activity, wakeup_manager)
|
self._handle_sleeping(now, is_in_theoretical_sleep, activity, wakeup_manager)
|
||||||
|
|
||||||
elif self._current_state == SleepState.INSOMNIA:
|
elif current_state == SleepState.INSOMNIA:
|
||||||
self._handle_insomnia(now, is_in_theoretical_sleep)
|
self._handle_insomnia(now, is_in_theoretical_sleep)
|
||||||
|
|
||||||
elif self._current_state == SleepState.WOKEN_UP:
|
elif current_state == SleepState.WOKEN_UP:
|
||||||
self._handle_woken_up(now, is_in_theoretical_sleep, wakeup_manager)
|
self._handle_woken_up(now, is_in_theoretical_sleep, wakeup_manager)
|
||||||
|
|
||||||
def _handle_awake_to_sleep(self, now: datetime, activity: Optional[str], wakeup_manager: Optional["WakeUpManager"]):
|
def _handle_awake_to_sleep(self, now: datetime, activity: Optional[str], wakeup_manager: Optional["WakeUpManager"]):
|
||||||
@@ -118,13 +114,13 @@ class SleepManager:
|
|||||||
delay_minutes = int(pressure_diff * max_delay_minutes)
|
delay_minutes = int(pressure_diff * max_delay_minutes)
|
||||||
|
|
||||||
# 确保总延迟不超过当日最大值
|
# 确保总延迟不超过当日最大值
|
||||||
remaining_delay = max_delay_minutes - self._total_delayed_minutes_today
|
remaining_delay = max_delay_minutes - self.context.total_delayed_minutes_today
|
||||||
delay_minutes = min(delay_minutes, remaining_delay)
|
delay_minutes = min(delay_minutes, remaining_delay)
|
||||||
|
|
||||||
if delay_minutes > 0:
|
if delay_minutes > 0:
|
||||||
# 增加一些随机性
|
# 增加一些随机性
|
||||||
buffer_seconds = random.randint(int(delay_minutes * 0.8 * 60), int(delay_minutes * 1.2 * 60))
|
buffer_seconds = random.randint(int(delay_minutes * 0.8 * 60), int(delay_minutes * 1.2 * 60))
|
||||||
self._total_delayed_minutes_today += buffer_seconds / 60.0
|
self.context.total_delayed_minutes_today += buffer_seconds / 60.0
|
||||||
logger.info(f"睡眠压力 ({sleep_pressure:.1f}) 较低,延迟 {buffer_seconds / 60:.1f} 分钟入睡。")
|
logger.info(f"睡眠压力 ({sleep_pressure:.1f}) 较低,延迟 {buffer_seconds / 60:.1f} 分钟入睡。")
|
||||||
else:
|
else:
|
||||||
# 延迟额度已用完,设置一个较短的准备时间
|
# 延迟额度已用完,设置一个较短的准备时间
|
||||||
@@ -139,22 +135,22 @@ class SleepManager:
|
|||||||
if global_config.sleep_system.enable_pre_sleep_notification:
|
if global_config.sleep_system.enable_pre_sleep_notification:
|
||||||
asyncio.create_task(NotificationSender.send_goodnight_notification(wakeup_manager.context))
|
asyncio.create_task(NotificationSender.send_goodnight_notification(wakeup_manager.context))
|
||||||
|
|
||||||
self._sleep_buffer_end_time = now + timedelta(seconds=buffer_seconds)
|
self.context.sleep_buffer_end_time = now + timedelta(seconds=buffer_seconds)
|
||||||
self._current_state = SleepState.PREPARING_SLEEP
|
self.context.current_state = SleepState.PREPARING_SLEEP
|
||||||
logger.info(f"进入准备入睡状态,将在 {buffer_seconds / 60:.1f} 分钟内入睡。")
|
logger.info(f"进入准备入睡状态,将在 {buffer_seconds / 60:.1f} 分钟内入睡。")
|
||||||
self._save_sleep_state()
|
self.context.save()
|
||||||
else:
|
else:
|
||||||
# 无法获取 wakeup_manager,退回旧逻辑
|
# 无法获取 wakeup_manager,退回旧逻辑
|
||||||
buffer_seconds = random.randint(1 * 60, 3 * 60)
|
buffer_seconds = random.randint(1 * 60, 3 * 60)
|
||||||
self._sleep_buffer_end_time = now + timedelta(seconds=buffer_seconds)
|
self.context.sleep_buffer_end_time = now + timedelta(seconds=buffer_seconds)
|
||||||
self._current_state = SleepState.PREPARING_SLEEP
|
self.context.current_state = SleepState.PREPARING_SLEEP
|
||||||
logger.warning("无法获取 WakeUpManager,弹性睡眠采用默认1-3分钟延迟。")
|
logger.warning("无法获取 WakeUpManager,弹性睡眠采用默认1-3分钟延迟。")
|
||||||
self._save_sleep_state()
|
self.context.save()
|
||||||
else:
|
else:
|
||||||
# 非弹性睡眠模式
|
# 非弹性睡眠模式
|
||||||
if wakeup_manager and global_config.sleep_system.enable_pre_sleep_notification:
|
if wakeup_manager and global_config.sleep_system.enable_pre_sleep_notification:
|
||||||
asyncio.create_task(NotificationSender.send_goodnight_notification(wakeup_manager.context))
|
asyncio.create_task(NotificationSender.send_goodnight_notification(wakeup_manager.context))
|
||||||
self._current_state = SleepState.SLEEPING
|
self.context.current_state = SleepState.SLEEPING
|
||||||
|
|
||||||
|
|
||||||
def _handle_preparing_sleep(self, now: datetime, is_in_theoretical_sleep: bool, wakeup_manager: Optional["WakeUpManager"]):
|
def _handle_preparing_sleep(self, now: datetime, is_in_theoretical_sleep: bool, wakeup_manager: Optional["WakeUpManager"]):
|
||||||
@@ -162,32 +158,32 @@ class SleepManager:
|
|||||||
# 如果在准备期间离开了理论睡眠时间,则取消入睡
|
# 如果在准备期间离开了理论睡眠时间,则取消入睡
|
||||||
if not is_in_theoretical_sleep:
|
if not is_in_theoretical_sleep:
|
||||||
logger.info("准备入睡期间离开理论休眠时间,取消入睡,恢复清醒。")
|
logger.info("准备入睡期间离开理论休眠时间,取消入睡,恢复清醒。")
|
||||||
self._current_state = SleepState.AWAKE
|
self.context.current_state = SleepState.AWAKE
|
||||||
self._sleep_buffer_end_time = None
|
self.context.sleep_buffer_end_time = None
|
||||||
self._save_sleep_state()
|
self.context.save()
|
||||||
# 如果缓冲时间结束,则正式进入睡眠状态
|
# 如果缓冲时间结束,则正式进入睡眠状态
|
||||||
elif self._sleep_buffer_end_time and now >= self._sleep_buffer_end_time:
|
elif self.context.sleep_buffer_end_time and now >= self.context.sleep_buffer_end_time:
|
||||||
logger.info("睡眠缓冲期结束,正式进入休眠状态。")
|
logger.info("睡眠缓冲期结束,正式进入休眠状态。")
|
||||||
self._current_state = SleepState.SLEEPING
|
self.context.current_state = SleepState.SLEEPING
|
||||||
self._last_fully_slept_log_time = now.timestamp()
|
self._last_fully_slept_log_time = now.timestamp()
|
||||||
|
|
||||||
# 设置一个随机的延迟,用于触发“睡后失眠”检查
|
# 设置一个随机的延迟,用于触发“睡后失眠”检查
|
||||||
delay_minutes_range = global_config.sleep_system.insomnia_trigger_delay_minutes
|
delay_minutes_range = global_config.sleep_system.insomnia_trigger_delay_minutes
|
||||||
delay_minutes = random.randint(delay_minutes_range[0], delay_minutes_range[1])
|
delay_minutes = random.randint(delay_minutes_range[0], delay_minutes_range[1])
|
||||||
self._sleep_buffer_end_time = now + timedelta(minutes=delay_minutes)
|
self.context.sleep_buffer_end_time = now + timedelta(minutes=delay_minutes)
|
||||||
logger.info(f"已设置睡后失眠检查,将在 {delay_minutes} 分钟后触发。")
|
logger.info(f"已设置睡后失眠检查,将在 {delay_minutes} 分钟后触发。")
|
||||||
|
|
||||||
self._save_sleep_state()
|
self.context.save()
|
||||||
|
|
||||||
def _handle_sleeping(self, now: datetime, is_in_theoretical_sleep: bool, activity: Optional[str], wakeup_manager: Optional["WakeUpManager"]):
|
def _handle_sleeping(self, now: datetime, is_in_theoretical_sleep: bool, activity: Optional[str], wakeup_manager: Optional["WakeUpManager"]):
|
||||||
"""处理“正在睡觉”状态下的逻辑。"""
|
"""处理“正在睡觉”状态下的逻辑。"""
|
||||||
# 如果理论睡眠时间结束,则自然醒来
|
# 如果理论睡眠时间结束,则自然醒来
|
||||||
if not is_in_theoretical_sleep:
|
if not is_in_theoretical_sleep:
|
||||||
logger.info("理论休眠时间结束,自然醒来。")
|
logger.info("理论休眠时间结束,自然醒来。")
|
||||||
self._current_state = SleepState.AWAKE
|
self.context.current_state = SleepState.AWAKE
|
||||||
self._save_sleep_state()
|
self.context.save()
|
||||||
# 检查是否到了触发“睡后失眠”的时间点
|
# 检查是否到了触发“睡后失眠”的时间点
|
||||||
elif self._sleep_buffer_end_time and now >= self._sleep_buffer_end_time:
|
elif self.context.sleep_buffer_end_time and now >= self.context.sleep_buffer_end_time:
|
||||||
if wakeup_manager:
|
if wakeup_manager:
|
||||||
sleep_pressure = wakeup_manager.context.sleep_pressure
|
sleep_pressure = wakeup_manager.context.sleep_pressure
|
||||||
pressure_threshold = global_config.sleep_system.flexible_sleep_pressure_threshold
|
pressure_threshold = global_config.sleep_system.flexible_sleep_pressure_threshold
|
||||||
@@ -201,12 +197,12 @@ class SleepManager:
|
|||||||
logger.info("随机触发失眠。")
|
logger.info("随机触发失眠。")
|
||||||
|
|
||||||
if insomnia_reason:
|
if insomnia_reason:
|
||||||
self._current_state = SleepState.INSOMNIA
|
self.context.current_state = SleepState.INSOMNIA
|
||||||
|
|
||||||
# 设置失眠的持续时间
|
# 设置失眠的持续时间
|
||||||
duration_minutes_range = global_config.sleep_system.insomnia_duration_minutes
|
duration_minutes_range = global_config.sleep_system.insomnia_duration_minutes
|
||||||
duration_minutes = random.randint(*duration_minutes_range)
|
duration_minutes = random.randint(*duration_minutes_range)
|
||||||
self._sleep_buffer_end_time = now + timedelta(minutes=duration_minutes)
|
self.context.sleep_buffer_end_time = now + timedelta(minutes=duration_minutes)
|
||||||
|
|
||||||
# 发送失眠通知
|
# 发送失眠通知
|
||||||
asyncio.create_task(NotificationSender.send_insomnia_notification(wakeup_manager.context, insomnia_reason))
|
asyncio.create_task(NotificationSender.send_insomnia_notification(wakeup_manager.context, insomnia_reason))
|
||||||
@@ -214,8 +210,8 @@ class SleepManager:
|
|||||||
else:
|
else:
|
||||||
# 睡眠压力正常,不触发失眠,清除检查时间点
|
# 睡眠压力正常,不触发失眠,清除检查时间点
|
||||||
logger.info(f"睡眠压力 ({sleep_pressure:.1f}) 正常,未触发睡后失眠。")
|
logger.info(f"睡眠压力 ({sleep_pressure:.1f}) 正常,未触发睡后失眠。")
|
||||||
self._sleep_buffer_end_time = None
|
self.context.sleep_buffer_end_time = None
|
||||||
self._save_sleep_state()
|
self.context.save()
|
||||||
else:
|
else:
|
||||||
# 定期记录睡眠日志
|
# 定期记录睡眠日志
|
||||||
current_timestamp = now.timestamp()
|
current_timestamp = now.timestamp()
|
||||||
@@ -228,26 +224,26 @@ class SleepManager:
|
|||||||
# 如果离开理论睡眠时间,则失眠结束
|
# 如果离开理论睡眠时间,则失眠结束
|
||||||
if not is_in_theoretical_sleep:
|
if not is_in_theoretical_sleep:
|
||||||
logger.info("已离开理论休眠时间,失眠结束,恢复清醒。")
|
logger.info("已离开理论休眠时间,失眠结束,恢复清醒。")
|
||||||
self._current_state = SleepState.AWAKE
|
self.context.current_state = SleepState.AWAKE
|
||||||
self._sleep_buffer_end_time = None
|
self.context.sleep_buffer_end_time = None
|
||||||
self._save_sleep_state()
|
self.context.save()
|
||||||
# 如果失眠持续时间已过,则恢复睡眠
|
# 如果失眠持续时间已过,则恢复睡眠
|
||||||
elif self._sleep_buffer_end_time and now >= self._sleep_buffer_end_time:
|
elif self.context.sleep_buffer_end_time and now >= self.context.sleep_buffer_end_time:
|
||||||
logger.info("失眠状态持续时间已过,恢复睡眠。")
|
logger.info("失眠状态持续时间已过,恢复睡眠。")
|
||||||
self._current_state = SleepState.SLEEPING
|
self.context.current_state = SleepState.SLEEPING
|
||||||
self._sleep_buffer_end_time = None
|
self.context.sleep_buffer_end_time = None
|
||||||
self._save_sleep_state()
|
self.context.save()
|
||||||
|
|
||||||
def _handle_woken_up(self, now: datetime, is_in_theoretical_sleep: bool, wakeup_manager: Optional["WakeUpManager"]):
|
def _handle_woken_up(self, now: datetime, is_in_theoretical_sleep: bool, wakeup_manager: Optional["WakeUpManager"]):
|
||||||
"""处理“被吵醒”状态下的逻辑。"""
|
"""处理“被吵醒”状态下的逻辑。"""
|
||||||
# 如果理论睡眠时间结束,则状态自动结束
|
# 如果理论睡眠时间结束,则状态自动结束
|
||||||
if not is_in_theoretical_sleep:
|
if not is_in_theoretical_sleep:
|
||||||
logger.info("理论休眠时间结束,被吵醒的状态自动结束。")
|
logger.info("理论休眠时间结束,被吵醒的状态自动结束。")
|
||||||
self._current_state = SleepState.AWAKE
|
self.context.current_state = SleepState.AWAKE
|
||||||
self._re_sleep_attempt_time = None
|
self.context.re_sleep_attempt_time = None
|
||||||
self._save_sleep_state()
|
self.context.save()
|
||||||
# 到了尝试重新入睡的时间点
|
# 到了尝试重新入睡的时间点
|
||||||
elif self._re_sleep_attempt_time and now >= self._re_sleep_attempt_time:
|
elif self.context.re_sleep_attempt_time and now >= self.context.re_sleep_attempt_time:
|
||||||
logger.info("被吵醒后经过一段时间,尝试重新入睡...")
|
logger.info("被吵醒后经过一段时间,尝试重新入睡...")
|
||||||
if wakeup_manager:
|
if wakeup_manager:
|
||||||
sleep_pressure = wakeup_manager.context.sleep_pressure
|
sleep_pressure = wakeup_manager.context.sleep_pressure
|
||||||
@@ -257,48 +253,28 @@ class SleepManager:
|
|||||||
if sleep_pressure >= pressure_threshold:
|
if sleep_pressure >= pressure_threshold:
|
||||||
logger.info("睡眠压力足够,从被吵醒状态转换到准备入睡。")
|
logger.info("睡眠压力足够,从被吵醒状态转换到准备入睡。")
|
||||||
buffer_seconds = random.randint(3 * 60, 8 * 60)
|
buffer_seconds = random.randint(3 * 60, 8 * 60)
|
||||||
self._sleep_buffer_end_time = now + timedelta(seconds=buffer_seconds)
|
self.context.sleep_buffer_end_time = now + timedelta(seconds=buffer_seconds)
|
||||||
self._current_state = SleepState.PREPARING_SLEEP
|
self.context.current_state = SleepState.PREPARING_SLEEP
|
||||||
self._re_sleep_attempt_time = None
|
self.context.re_sleep_attempt_time = None
|
||||||
else:
|
else:
|
||||||
# 睡眠压力不足,延迟一段时间后再次尝试
|
# 睡眠压力不足,延迟一段时间后再次尝试
|
||||||
delay_minutes = 15
|
delay_minutes = 15
|
||||||
self._re_sleep_attempt_time = now + timedelta(minutes=delay_minutes)
|
self.context.re_sleep_attempt_time = now + timedelta(minutes=delay_minutes)
|
||||||
logger.info(
|
logger.info(
|
||||||
f"睡眠压力({sleep_pressure:.1f})仍然较低,暂时保持清醒,在 {delay_minutes} 分钟后再次尝试。"
|
f"睡眠压力({sleep_pressure:.1f})仍然较低,暂时保持清醒,在 {delay_minutes} 分钟后再次尝试。"
|
||||||
)
|
)
|
||||||
self._save_sleep_state()
|
self.context.save()
|
||||||
|
|
||||||
def reset_sleep_state_after_wakeup(self):
|
def reset_sleep_state_after_wakeup(self):
|
||||||
"""
|
"""
|
||||||
当角色被用户消息等外部因素唤醒时调用此方法。
|
当角色被用户消息等外部因素唤醒时调用此方法。
|
||||||
将状态强制转换为 WOKEN_UP,并设置一个延迟,之后会尝试重新入睡。
|
将状态强制转换为 WOKEN_UP,并设置一个延迟,之后会尝试重新入睡。
|
||||||
"""
|
"""
|
||||||
if self._current_state in [SleepState.PREPARING_SLEEP, SleepState.SLEEPING, SleepState.INSOMNIA]:
|
if self.context.current_state in [SleepState.PREPARING_SLEEP, SleepState.SLEEPING, SleepState.INSOMNIA]:
|
||||||
logger.info("被唤醒,进入 WOKEN_UP 状态!")
|
logger.info("被唤醒,进入 WOKEN_UP 状态!")
|
||||||
self._current_state = SleepState.WOKEN_UP
|
self.context.current_state = SleepState.WOKEN_UP
|
||||||
self._sleep_buffer_end_time = None
|
self.context.sleep_buffer_end_time = None
|
||||||
re_sleep_delay_minutes = getattr(global_config.sleep_system, "re_sleep_delay_minutes", 10)
|
re_sleep_delay_minutes = getattr(global_config.sleep_system, "re_sleep_delay_minutes", 10)
|
||||||
self._re_sleep_attempt_time = datetime.now() + timedelta(minutes=re_sleep_delay_minutes)
|
self.context.re_sleep_attempt_time = datetime.now() + timedelta(minutes=re_sleep_delay_minutes)
|
||||||
logger.info(f"将在 {re_sleep_delay_minutes} 分钟后尝试重新入睡。")
|
logger.info(f"将在 {re_sleep_delay_minutes} 分钟后尝试重新入睡。")
|
||||||
self._save_sleep_state()
|
self.context.save()
|
||||||
|
|
||||||
def _save_sleep_state(self):
|
|
||||||
"""将当前所有睡眠相关的状态打包并保存到本地存储。"""
|
|
||||||
state_data = {
|
|
||||||
"_current_state": self._current_state,
|
|
||||||
"_sleep_buffer_end_time": self._sleep_buffer_end_time,
|
|
||||||
"_total_delayed_minutes_today": self._total_delayed_minutes_today,
|
|
||||||
"_last_sleep_check_date": self._last_sleep_check_date,
|
|
||||||
"_re_sleep_attempt_time": self._re_sleep_attempt_time,
|
|
||||||
}
|
|
||||||
SleepStateSerializer.save(state_data)
|
|
||||||
|
|
||||||
def _load_sleep_state(self):
|
|
||||||
"""从本地存储加载并恢复所有睡眠相关的状态。"""
|
|
||||||
state_data = SleepStateSerializer.load()
|
|
||||||
self._current_state = state_data["_current_state"]
|
|
||||||
self._sleep_buffer_end_time = state_data["_sleep_buffer_end_time"]
|
|
||||||
self._total_delayed_minutes_today = state_data["_total_delayed_minutes_today"]
|
|
||||||
self._last_sleep_check_date = state_data["_last_sleep_check_date"]
|
|
||||||
self._re_sleep_attempt_time = state_data["_re_sleep_attempt_time"]
|
|
||||||
86
src/chat/message_manager/sleep_manager/sleep_state.py
Normal file
86
src/chat/message_manager/sleep_manager/sleep_state.py
Normal file
@@ -0,0 +1,86 @@
|
|||||||
|
from enum import Enum, auto
|
||||||
|
from datetime import datetime, date
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from src.common.logger import get_logger
|
||||||
|
from src.manager.local_store_manager import local_storage
|
||||||
|
|
||||||
|
logger = get_logger("sleep_state")
|
||||||
|
|
||||||
|
|
||||||
|
class SleepState(Enum):
|
||||||
|
"""
|
||||||
|
定义了角色可能处于的几种睡眠状态。
|
||||||
|
这是一个状态机,用于管理角色的睡眠周期。
|
||||||
|
"""
|
||||||
|
|
||||||
|
AWAKE = auto() # 清醒状态
|
||||||
|
INSOMNIA = auto() # 失眠状态
|
||||||
|
PREPARING_SLEEP = auto() # 准备入睡状态,一个短暂的过渡期
|
||||||
|
SLEEPING = auto() # 正在睡觉状态
|
||||||
|
WOKEN_UP = auto() # 被吵醒状态
|
||||||
|
|
||||||
|
|
||||||
|
class SleepContext:
|
||||||
|
"""
|
||||||
|
睡眠上下文,负责封装和管理所有与睡眠相关的状态,并处理其持久化。
|
||||||
|
"""
|
||||||
|
def __init__(self):
|
||||||
|
"""初始化睡眠上下文,并从本地存储加载初始状态。"""
|
||||||
|
self.current_state: SleepState = SleepState.AWAKE
|
||||||
|
self.sleep_buffer_end_time: Optional[datetime] = None
|
||||||
|
self.total_delayed_minutes_today: float = 0.0
|
||||||
|
self.last_sleep_check_date: Optional[date] = None
|
||||||
|
self.re_sleep_attempt_time: Optional[datetime] = None
|
||||||
|
self.load()
|
||||||
|
|
||||||
|
def save(self):
|
||||||
|
"""将当前的睡眠状态数据保存到本地存储。"""
|
||||||
|
try:
|
||||||
|
state = {
|
||||||
|
"current_state": self.current_state.name,
|
||||||
|
"sleep_buffer_end_time_ts": self.sleep_buffer_end_time.timestamp()
|
||||||
|
if self.sleep_buffer_end_time
|
||||||
|
else None,
|
||||||
|
"total_delayed_minutes_today": self.total_delayed_minutes_today,
|
||||||
|
"last_sleep_check_date_str": self.last_sleep_check_date.isoformat()
|
||||||
|
if self.last_sleep_check_date
|
||||||
|
else None,
|
||||||
|
"re_sleep_attempt_time_ts": self.re_sleep_attempt_time.timestamp()
|
||||||
|
if self.re_sleep_attempt_time
|
||||||
|
else None,
|
||||||
|
}
|
||||||
|
local_storage["schedule_sleep_state"] = state
|
||||||
|
logger.debug(f"已保存睡眠上下文: {state}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"保存睡眠上下文失败: {e}")
|
||||||
|
|
||||||
|
def load(self):
|
||||||
|
"""从本地存储加载并解析睡眠状态。"""
|
||||||
|
try:
|
||||||
|
state = local_storage["schedule_sleep_state"]
|
||||||
|
if not (state and isinstance(state, dict)):
|
||||||
|
logger.info("未找到本地睡眠上下文,使用默认值。")
|
||||||
|
return
|
||||||
|
|
||||||
|
state_name = state.get("current_state")
|
||||||
|
if state_name and hasattr(SleepState, state_name):
|
||||||
|
self.current_state = SleepState[state_name]
|
||||||
|
|
||||||
|
end_time_ts = state.get("sleep_buffer_end_time_ts")
|
||||||
|
if end_time_ts:
|
||||||
|
self.sleep_buffer_end_time = datetime.fromtimestamp(end_time_ts)
|
||||||
|
|
||||||
|
re_sleep_ts = state.get("re_sleep_attempt_time_ts")
|
||||||
|
if re_sleep_ts:
|
||||||
|
self.re_sleep_attempt_time = datetime.fromtimestamp(re_sleep_ts)
|
||||||
|
|
||||||
|
self.total_delayed_minutes_today = state.get("total_delayed_minutes_today", 0.0)
|
||||||
|
|
||||||
|
date_str = state.get("last_sleep_check_date_str")
|
||||||
|
if date_str:
|
||||||
|
self.last_sleep_check_date = datetime.fromisoformat(date_str).date()
|
||||||
|
|
||||||
|
logger.info(f"成功从本地存储加载睡眠上下文: {state}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"加载睡眠上下文失败,将使用默认值: {e}")
|
||||||
45
src/chat/message_manager/sleep_manager/wakeup_context.py
Normal file
45
src/chat/message_manager/sleep_manager/wakeup_context.py
Normal file
@@ -0,0 +1,45 @@
|
|||||||
|
import time
|
||||||
|
from src.common.logger import get_logger
|
||||||
|
from src.manager.local_store_manager import local_storage
|
||||||
|
|
||||||
|
logger = get_logger("wakeup_context")
|
||||||
|
|
||||||
|
|
||||||
|
class WakeUpContext:
|
||||||
|
"""
|
||||||
|
唤醒上下文,负责封装和管理所有与唤醒相关的状态,并处理其持久化。
|
||||||
|
"""
|
||||||
|
def __init__(self):
|
||||||
|
"""初始化唤醒上下文,并从本地存储加载初始状态。"""
|
||||||
|
self.wakeup_value: float = 0.0
|
||||||
|
self.is_angry: bool = False
|
||||||
|
self.angry_start_time: float = 0.0
|
||||||
|
self.sleep_pressure: float = 100.0 # 新增:睡眠压力
|
||||||
|
self.load()
|
||||||
|
|
||||||
|
def _get_storage_key(self) -> str:
|
||||||
|
"""获取本地存储键"""
|
||||||
|
return "global_wakeup_manager_state"
|
||||||
|
|
||||||
|
def load(self):
|
||||||
|
"""从本地存储加载状态"""
|
||||||
|
state = local_storage[self._get_storage_key()]
|
||||||
|
if state and isinstance(state, dict):
|
||||||
|
self.wakeup_value = state.get("wakeup_value", 0.0)
|
||||||
|
self.is_angry = state.get("is_angry", False)
|
||||||
|
self.angry_start_time = state.get("angry_start_time", 0.0)
|
||||||
|
self.sleep_pressure = state.get("sleep_pressure", 100.0)
|
||||||
|
logger.info(f"成功从本地存储加载唤醒上下文: {state}")
|
||||||
|
else:
|
||||||
|
logger.info("未找到本地唤醒上下文,将使用默认值初始化。")
|
||||||
|
|
||||||
|
def save(self):
|
||||||
|
"""将当前状态保存到本地存储"""
|
||||||
|
state = {
|
||||||
|
"wakeup_value": self.wakeup_value,
|
||||||
|
"is_angry": self.is_angry,
|
||||||
|
"angry_start_time": self.angry_start_time,
|
||||||
|
"sleep_pressure": self.sleep_pressure,
|
||||||
|
}
|
||||||
|
local_storage[self._get_storage_key()] = state
|
||||||
|
logger.debug(f"已将唤醒上下文保存到本地存储: {state}")
|
||||||
215
src/chat/message_manager/sleep_manager/wakeup_manager.py
Normal file
215
src/chat/message_manager/sleep_manager/wakeup_manager.py
Normal file
@@ -0,0 +1,215 @@
|
|||||||
|
import asyncio
|
||||||
|
import time
|
||||||
|
from typing import Optional, TYPE_CHECKING
|
||||||
|
from src.common.logger import get_logger
|
||||||
|
from src.config.config import global_config
|
||||||
|
from src.manager.local_store_manager import local_storage
|
||||||
|
from src.chat.message_manager.sleep_manager.wakeup_context import WakeUpContext
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from .sleep_manager import SleepManager
|
||||||
|
|
||||||
|
|
||||||
|
logger = get_logger("wakeup")
|
||||||
|
|
||||||
|
|
||||||
|
class WakeUpManager:
|
||||||
|
def __init__(self, sleep_manager: "SleepManager"):
|
||||||
|
"""
|
||||||
|
初始化唤醒度管理器
|
||||||
|
|
||||||
|
Args:
|
||||||
|
sleep_manager: 睡眠管理器实例
|
||||||
|
|
||||||
|
功能说明:
|
||||||
|
- 管理休眠状态下的唤醒度累积
|
||||||
|
- 处理唤醒度的自然衰减
|
||||||
|
- 控制愤怒状态的持续时间
|
||||||
|
"""
|
||||||
|
self.sleep_manager = sleep_manager
|
||||||
|
self.context = WakeUpContext() # 使用新的上下文管理器
|
||||||
|
self.angry_chat_id: Optional[str] = None
|
||||||
|
self.last_decay_time = time.time()
|
||||||
|
self._decay_task: Optional[asyncio.Task] = None
|
||||||
|
self.is_running = False
|
||||||
|
self.last_log_time = 0
|
||||||
|
self.log_interval = 30
|
||||||
|
|
||||||
|
# 从配置文件获取参数
|
||||||
|
sleep_config = global_config.sleep_system
|
||||||
|
self.wakeup_threshold = sleep_config.wakeup_threshold
|
||||||
|
self.private_message_increment = sleep_config.private_message_increment
|
||||||
|
self.group_mention_increment = sleep_config.group_mention_increment
|
||||||
|
self.decay_rate = sleep_config.decay_rate
|
||||||
|
self.decay_interval = sleep_config.decay_interval
|
||||||
|
self.angry_duration = sleep_config.angry_duration
|
||||||
|
self.enabled = sleep_config.enable
|
||||||
|
self.angry_prompt = sleep_config.angry_prompt
|
||||||
|
|
||||||
|
async def start(self):
|
||||||
|
"""启动唤醒度管理器"""
|
||||||
|
if not self.enabled:
|
||||||
|
logger.info("唤醒度系统已禁用,跳过启动")
|
||||||
|
return
|
||||||
|
|
||||||
|
self.is_running = True
|
||||||
|
if not self._decay_task or self._decay_task.done():
|
||||||
|
self._decay_task = asyncio.create_task(self._decay_loop())
|
||||||
|
self._decay_task.add_done_callback(self._handle_decay_completion)
|
||||||
|
logger.info("唤醒度管理器已启动")
|
||||||
|
|
||||||
|
async def stop(self):
|
||||||
|
"""停止唤醒度管理器"""
|
||||||
|
self.is_running = False
|
||||||
|
if self._decay_task and not self._decay_task.done():
|
||||||
|
self._decay_task.cancel()
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
logger.info("唤醒度管理器已停止")
|
||||||
|
|
||||||
|
def _handle_decay_completion(self, task: asyncio.Task):
|
||||||
|
"""处理衰减任务完成"""
|
||||||
|
try:
|
||||||
|
if exception := task.exception():
|
||||||
|
logger.error(f"唤醒度衰减任务异常: {exception}")
|
||||||
|
else:
|
||||||
|
logger.info("唤醒度衰减任务正常结束")
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
logger.info("唤醒度衰减任务被取消")
|
||||||
|
|
||||||
|
async def _decay_loop(self):
|
||||||
|
"""唤醒度衰减循环"""
|
||||||
|
while self.is_running:
|
||||||
|
await asyncio.sleep(self.decay_interval)
|
||||||
|
|
||||||
|
current_time = time.time()
|
||||||
|
|
||||||
|
# 检查愤怒状态是否过期
|
||||||
|
if self.context.is_angry and current_time - self.context.angry_start_time >= self.angry_duration:
|
||||||
|
self.context.is_angry = False
|
||||||
|
# 通知情绪管理系统清除愤怒状态
|
||||||
|
from src.mood.mood_manager import mood_manager
|
||||||
|
if self.angry_chat_id:
|
||||||
|
mood_manager.clear_angry_from_wakeup(self.angry_chat_id)
|
||||||
|
self.angry_chat_id = None
|
||||||
|
else:
|
||||||
|
logger.warning("Angry state ended but no angry_chat_id was set.")
|
||||||
|
logger.info("愤怒状态结束,恢复正常")
|
||||||
|
self.context.save()
|
||||||
|
|
||||||
|
# 唤醒度自然衰减
|
||||||
|
if self.context.wakeup_value > 0:
|
||||||
|
old_value = self.context.wakeup_value
|
||||||
|
self.context.wakeup_value = max(0, self.context.wakeup_value - self.decay_rate)
|
||||||
|
if old_value != self.context.wakeup_value:
|
||||||
|
logger.debug(f"唤醒度衰减: {old_value:.1f} -> {self.context.wakeup_value:.1f}")
|
||||||
|
self.context.save()
|
||||||
|
|
||||||
|
def add_wakeup_value(self, is_private_chat: bool, is_mentioned: bool = False, chat_id: Optional[str] = None) -> bool:
|
||||||
|
"""
|
||||||
|
增加唤醒度值
|
||||||
|
|
||||||
|
Args:
|
||||||
|
is_private_chat: 是否为私聊
|
||||||
|
is_mentioned: 是否被艾特(仅群聊有效)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: 是否达到唤醒阈值
|
||||||
|
"""
|
||||||
|
# 如果系统未启用,直接返回
|
||||||
|
if not self.enabled:
|
||||||
|
return False
|
||||||
|
|
||||||
|
# 只有在休眠且非失眠状态下才累积唤醒度
|
||||||
|
from .sleep_state import SleepState
|
||||||
|
|
||||||
|
current_sleep_state = self.sleep_manager.get_current_sleep_state()
|
||||||
|
if current_sleep_state != SleepState.SLEEPING:
|
||||||
|
return False
|
||||||
|
|
||||||
|
old_value = self.context.wakeup_value
|
||||||
|
|
||||||
|
if is_private_chat:
|
||||||
|
# 私聊每条消息都增加唤醒度
|
||||||
|
self.context.wakeup_value += self.private_message_increment
|
||||||
|
logger.debug(f"私聊消息增加唤醒度: +{self.private_message_increment}")
|
||||||
|
elif is_mentioned:
|
||||||
|
# 群聊只有被艾特才增加唤醒度
|
||||||
|
self.context.wakeup_value += self.group_mention_increment
|
||||||
|
logger.debug(f"群聊艾特增加唤醒度: +{self.group_mention_increment}")
|
||||||
|
else:
|
||||||
|
# 群聊未被艾特,不增加唤醒度
|
||||||
|
return False
|
||||||
|
|
||||||
|
current_time = time.time()
|
||||||
|
if current_time - self.last_log_time > self.log_interval:
|
||||||
|
logger.info(
|
||||||
|
f"唤醒度变化: {old_value:.1f} -> {self.context.wakeup_value:.1f} (阈值: {self.wakeup_threshold})"
|
||||||
|
)
|
||||||
|
self.last_log_time = current_time
|
||||||
|
else:
|
||||||
|
logger.debug(
|
||||||
|
f"唤醒度变化: {old_value:.1f} -> {self.context.wakeup_value:.1f} (阈值: {self.wakeup_threshold})"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 检查是否达到唤醒阈值
|
||||||
|
if self.context.wakeup_value >= self.wakeup_threshold:
|
||||||
|
if not chat_id:
|
||||||
|
logger.error("Wakeup threshold reached, but no chat_id was provided. Cannot trigger wakeup.")
|
||||||
|
return False
|
||||||
|
self._trigger_wakeup(chat_id)
|
||||||
|
return True
|
||||||
|
|
||||||
|
self.context.save()
|
||||||
|
return False
|
||||||
|
|
||||||
|
def _trigger_wakeup(self, chat_id: str):
|
||||||
|
"""触发唤醒,进入愤怒状态"""
|
||||||
|
self.context.is_angry = True
|
||||||
|
self.context.angry_start_time = time.time()
|
||||||
|
self.context.wakeup_value = 0.0 # 重置唤醒度
|
||||||
|
self.angry_chat_id = chat_id
|
||||||
|
|
||||||
|
self.context.save()
|
||||||
|
|
||||||
|
# 通知情绪管理系统进入愤怒状态
|
||||||
|
from src.mood.mood_manager import mood_manager
|
||||||
|
mood_manager.set_angry_from_wakeup(chat_id)
|
||||||
|
|
||||||
|
# 通知SleepManager重置睡眠状态
|
||||||
|
self.sleep_manager.reset_sleep_state_after_wakeup()
|
||||||
|
|
||||||
|
logger.info(f"唤醒度达到阈值({self.wakeup_threshold}),被吵醒进入愤怒状态!")
|
||||||
|
|
||||||
|
def get_angry_prompt_addition(self) -> str:
|
||||||
|
"""获取愤怒状态下的提示词补充"""
|
||||||
|
if self.context.is_angry:
|
||||||
|
return self.angry_prompt
|
||||||
|
return ""
|
||||||
|
|
||||||
|
def is_in_angry_state(self) -> bool:
|
||||||
|
"""检查是否处于愤怒状态"""
|
||||||
|
if self.context.is_angry:
|
||||||
|
current_time = time.time()
|
||||||
|
if current_time - self.context.angry_start_time >= self.angry_duration:
|
||||||
|
self.context.is_angry = False
|
||||||
|
# 通知情绪管理系统清除愤怒状态
|
||||||
|
from src.mood.mood_manager import mood_manager
|
||||||
|
if self.angry_chat_id:
|
||||||
|
mood_manager.clear_angry_from_wakeup(self.angry_chat_id)
|
||||||
|
self.angry_chat_id = None
|
||||||
|
else:
|
||||||
|
logger.warning("Angry state expired in check, but no angry_chat_id was set.")
|
||||||
|
logger.info("愤怒状态自动过期")
|
||||||
|
return False
|
||||||
|
return self.context.is_angry
|
||||||
|
|
||||||
|
def get_status_info(self) -> dict:
|
||||||
|
"""获取当前状态信息"""
|
||||||
|
return {
|
||||||
|
"wakeup_value": self.context.wakeup_value,
|
||||||
|
"wakeup_threshold": self.wakeup_threshold,
|
||||||
|
"is_angry": self.context.is_angry,
|
||||||
|
"angry_remaining_time": max(0, self.angry_duration - (time.time() - self.context.angry_start_time))
|
||||||
|
if self.context.is_angry
|
||||||
|
else 0,
|
||||||
|
}
|
||||||
@@ -11,11 +11,12 @@ from src.mood.mood_manager import mood_manager # 导入情绪管理器
|
|||||||
from src.chat.message_receive.chat_stream import get_chat_manager, ChatStream
|
from src.chat.message_receive.chat_stream import get_chat_manager, ChatStream
|
||||||
from src.chat.message_receive.message import MessageRecv, MessageRecvS4U
|
from src.chat.message_receive.message import MessageRecv, MessageRecvS4U
|
||||||
from src.chat.message_receive.storage import MessageStorage
|
from src.chat.message_receive.storage import MessageStorage
|
||||||
from src.chat.heart_flow.heartflow_message_processor import HeartFCMessageReceiver
|
from src.chat.message_manager import message_manager
|
||||||
from src.chat.utils.prompt import Prompt, global_prompt_manager
|
from src.chat.utils.prompt import Prompt, global_prompt_manager
|
||||||
from src.plugin_system.core import component_registry, event_manager, global_announcement_manager
|
from src.plugin_system.core import component_registry, event_manager, global_announcement_manager
|
||||||
from src.plugin_system.base import BaseCommand, EventType
|
from src.plugin_system.base import BaseCommand, EventType
|
||||||
from src.mais4u.mais4u_chat.s4u_msg_processor import S4UMessageProcessor
|
from src.mais4u.mais4u_chat.s4u_msg_processor import S4UMessageProcessor
|
||||||
|
from src.chat.utils.utils import is_mentioned_bot_in_message
|
||||||
|
|
||||||
# 导入反注入系统
|
# 导入反注入系统
|
||||||
from src.chat.antipromptinjector import initialize_anti_injector
|
from src.chat.antipromptinjector import initialize_anti_injector
|
||||||
@@ -73,15 +74,17 @@ class ChatBot:
|
|||||||
self.bot = None # bot 实例引用
|
self.bot = None # bot 实例引用
|
||||||
self._started = False
|
self._started = False
|
||||||
self.mood_manager = mood_manager # 获取情绪管理器单例
|
self.mood_manager = mood_manager # 获取情绪管理器单例
|
||||||
self.heartflow_message_receiver = HeartFCMessageReceiver() # 新增
|
# 亲和力流消息处理器 - 直接使用全局afc_manager
|
||||||
|
|
||||||
self.s4u_message_processor = S4UMessageProcessor()
|
self.s4u_message_processor = S4UMessageProcessor()
|
||||||
|
|
||||||
# 初始化反注入系统
|
# 初始化反注入系统
|
||||||
self._initialize_anti_injector()
|
self._initialize_anti_injector()
|
||||||
|
|
||||||
@staticmethod
|
# 启动消息管理器
|
||||||
def _initialize_anti_injector():
|
self._message_manager_started = False
|
||||||
|
|
||||||
|
def _initialize_anti_injector(self):
|
||||||
"""初始化反注入系统"""
|
"""初始化反注入系统"""
|
||||||
try:
|
try:
|
||||||
initialize_anti_injector()
|
initialize_anti_injector()
|
||||||
@@ -99,10 +102,15 @@ class ChatBot:
|
|||||||
if not self._started:
|
if not self._started:
|
||||||
logger.debug("确保ChatBot所有任务已启动")
|
logger.debug("确保ChatBot所有任务已启动")
|
||||||
|
|
||||||
|
# 启动消息管理器
|
||||||
|
if not self._message_manager_started:
|
||||||
|
await message_manager.start()
|
||||||
|
self._message_manager_started = True
|
||||||
|
logger.info("消息管理器已启动")
|
||||||
|
|
||||||
self._started = True
|
self._started = True
|
||||||
|
|
||||||
@staticmethod
|
async def _process_plus_commands(self, message: MessageRecv):
|
||||||
async def _process_plus_commands(message: MessageRecv):
|
|
||||||
"""独立处理PlusCommand系统"""
|
"""独立处理PlusCommand系统"""
|
||||||
try:
|
try:
|
||||||
text = message.processed_plain_text
|
text = message.processed_plain_text
|
||||||
@@ -182,7 +190,7 @@ class ChatBot:
|
|||||||
try:
|
try:
|
||||||
# 检查聊天类型限制
|
# 检查聊天类型限制
|
||||||
if not plus_command_instance.is_chat_type_allowed():
|
if not plus_command_instance.is_chat_type_allowed():
|
||||||
is_group = hasattr(message, "is_group_message") and message.is_group_message
|
is_group = message.message_info.group_info
|
||||||
logger.info(
|
logger.info(
|
||||||
f"PlusCommand {plus_command_class.__name__} 不支持当前聊天类型: {'群聊' if is_group else '私聊'}"
|
f"PlusCommand {plus_command_class.__name__} 不支持当前聊天类型: {'群聊' if is_group else '私聊'}"
|
||||||
)
|
)
|
||||||
@@ -222,8 +230,7 @@ class ChatBot:
|
|||||||
logger.error(f"处理PlusCommand时出错: {e}")
|
logger.error(f"处理PlusCommand时出错: {e}")
|
||||||
return False, None, True # 出错时继续处理消息
|
return False, None, True # 出错时继续处理消息
|
||||||
|
|
||||||
@staticmethod
|
async def _process_commands_with_new_system(self, message: MessageRecv):
|
||||||
async def _process_commands_with_new_system(message: MessageRecv):
|
|
||||||
# sourcery skip: use-named-expression
|
# sourcery skip: use-named-expression
|
||||||
"""使用新插件系统处理命令"""
|
"""使用新插件系统处理命令"""
|
||||||
try:
|
try:
|
||||||
@@ -256,7 +263,7 @@ class ChatBot:
|
|||||||
try:
|
try:
|
||||||
# 检查聊天类型限制
|
# 检查聊天类型限制
|
||||||
if not command_instance.is_chat_type_allowed():
|
if not command_instance.is_chat_type_allowed():
|
||||||
is_group = hasattr(message, "is_group_message") and message.is_group_message
|
is_group = message.message_info.group_info
|
||||||
logger.info(
|
logger.info(
|
||||||
f"命令 {command_class.__name__} 不支持当前聊天类型: {'群聊' if is_group else '私聊'}"
|
f"命令 {command_class.__name__} 不支持当前聊天类型: {'群聊' if is_group else '私聊'}"
|
||||||
)
|
)
|
||||||
@@ -313,8 +320,7 @@ class ChatBot:
|
|||||||
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
@staticmethod
|
async def handle_adapter_response(self, message: MessageRecv):
|
||||||
async def handle_adapter_response(message: MessageRecv):
|
|
||||||
"""处理适配器命令响应"""
|
"""处理适配器命令响应"""
|
||||||
try:
|
try:
|
||||||
from src.plugin_system.apis.send_api import put_adapter_response
|
from src.plugin_system.apis.send_api import put_adapter_response
|
||||||
@@ -354,19 +360,7 @@ class ChatBot:
|
|||||||
return
|
return
|
||||||
|
|
||||||
async def message_process(self, message_data: Dict[str, Any]) -> None:
|
async def message_process(self, message_data: Dict[str, Any]) -> None:
|
||||||
"""处理转化后的统一格式消息
|
"""处理转化后的统一格式消息"""
|
||||||
这个函数本质是预处理一些数据,根据配置信息和消息内容,预处理消息,并分发到合适的消息处理器中
|
|
||||||
heart_flow模式:使用思维流系统进行回复
|
|
||||||
- 包含思维流状态管理
|
|
||||||
- 在回复前进行观察和状态更新
|
|
||||||
- 回复后更新思维流状态
|
|
||||||
- 消息过滤
|
|
||||||
- 记忆激活
|
|
||||||
- 意愿计算
|
|
||||||
- 消息生成和发送
|
|
||||||
- 表情包处理
|
|
||||||
- 性能计时
|
|
||||||
"""
|
|
||||||
try:
|
try:
|
||||||
# 首先处理可能的切片消息重组
|
# 首先处理可能的切片消息重组
|
||||||
from src.utils.message_chunker import reassembler
|
from src.utils.message_chunker import reassembler
|
||||||
@@ -403,9 +397,7 @@ class ChatBot:
|
|||||||
# logger.debug(str(message_data))
|
# logger.debug(str(message_data))
|
||||||
message = MessageRecv(message_data)
|
message = MessageRecv(message_data)
|
||||||
|
|
||||||
if await self.handle_notice_message(message):
|
message.is_mentioned, _ = is_mentioned_bot_in_message(message)
|
||||||
...
|
|
||||||
|
|
||||||
group_info = message.message_info.group_info
|
group_info = message.message_info.group_info
|
||||||
user_info = message.message_info.user_info
|
user_info = message.message_info.user_info
|
||||||
if message.message_info.additional_config:
|
if message.message_info.additional_config:
|
||||||
@@ -415,6 +407,7 @@ class ChatBot:
|
|||||||
return
|
return
|
||||||
|
|
||||||
get_chat_manager().register_message(message)
|
get_chat_manager().register_message(message)
|
||||||
|
|
||||||
chat = await get_chat_manager().get_or_create_stream(
|
chat = await get_chat_manager().get_or_create_stream(
|
||||||
platform=message.message_info.platform, # type: ignore
|
platform=message.message_info.platform, # type: ignore
|
||||||
user_info=user_info, # type: ignore
|
user_info=user_info, # type: ignore
|
||||||
@@ -426,11 +419,14 @@ class ChatBot:
|
|||||||
# 处理消息内容,生成纯文本
|
# 处理消息内容,生成纯文本
|
||||||
await message.process()
|
await message.process()
|
||||||
|
|
||||||
# 过滤检查 (在消息处理之后进行)
|
# 在这里打印[所见]日志,确保在所有处理和过滤之前记录
|
||||||
if _check_ban_words(
|
logger.info(f"\u001b[38;5;118m{message.message_info.user_info.user_nickname}:{message.processed_plain_text}\u001b[0m")
|
||||||
message.processed_plain_text, chat, user_info # type: ignore
|
|
||||||
) or _check_ban_regex(
|
# 过滤检查
|
||||||
message.processed_plain_text, chat, user_info # type: ignore
|
if _check_ban_words(message.processed_plain_text, chat, user_info) or _check_ban_regex( # type: ignore
|
||||||
|
message.raw_message, # type: ignore
|
||||||
|
chat,
|
||||||
|
user_info, # type: ignore
|
||||||
):
|
):
|
||||||
return
|
return
|
||||||
|
|
||||||
@@ -457,6 +453,7 @@ class ChatBot:
|
|||||||
if not result.all_continue_process():
|
if not result.all_continue_process():
|
||||||
raise UserWarning(f"插件{result.get_summary().get('stopped_handlers', '')}于消息到达时取消了消息处理")
|
raise UserWarning(f"插件{result.get_summary().get('stopped_handlers', '')}于消息到达时取消了消息处理")
|
||||||
|
|
||||||
|
# TODO:暂不可用
|
||||||
# 确认从接口发来的message是否有自定义的prompt模板信息
|
# 确认从接口发来的message是否有自定义的prompt模板信息
|
||||||
if message.message_info.template_info and not message.message_info.template_info.template_default:
|
if message.message_info.template_info and not message.message_info.template_info.template_default:
|
||||||
template_group_name: Optional[str] = message.message_info.template_info.template_name # type: ignore
|
template_group_name: Optional[str] = message.message_info.template_info.template_name # type: ignore
|
||||||
@@ -470,7 +467,55 @@ class ChatBot:
|
|||||||
template_group_name = None
|
template_group_name = None
|
||||||
|
|
||||||
async def preprocess():
|
async def preprocess():
|
||||||
await self.heartflow_message_receiver.process_message(message)
|
# 存储消息到数据库
|
||||||
|
from .storage import MessageStorage
|
||||||
|
|
||||||
|
try:
|
||||||
|
await MessageStorage.store_message(message, message.chat_stream)
|
||||||
|
logger.debug(f"消息已存储到数据库: {message.message_info.message_id}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"存储消息到数据库失败: {e}")
|
||||||
|
traceback.print_exc()
|
||||||
|
|
||||||
|
# 使用消息管理器处理消息(保持原有功能)
|
||||||
|
from src.common.data_models.database_data_model import DatabaseMessages
|
||||||
|
|
||||||
|
# 创建数据库消息对象
|
||||||
|
db_message = DatabaseMessages(
|
||||||
|
message_id=message.message_info.message_id,
|
||||||
|
time=message.message_info.time,
|
||||||
|
chat_id=message.chat_stream.stream_id,
|
||||||
|
processed_plain_text=message.processed_plain_text,
|
||||||
|
display_message=message.processed_plain_text,
|
||||||
|
is_mentioned=message.is_mentioned,
|
||||||
|
is_at=message.is_at,
|
||||||
|
is_emoji=message.is_emoji,
|
||||||
|
is_picid=message.is_picid,
|
||||||
|
is_command=message.is_command,
|
||||||
|
is_notify=message.is_notify,
|
||||||
|
user_id=message.message_info.user_info.user_id,
|
||||||
|
user_nickname=message.message_info.user_info.user_nickname,
|
||||||
|
user_cardname=message.message_info.user_info.user_cardname,
|
||||||
|
user_platform=message.message_info.user_info.platform,
|
||||||
|
chat_info_stream_id=message.chat_stream.stream_id,
|
||||||
|
chat_info_platform=message.chat_stream.platform,
|
||||||
|
chat_info_create_time=message.chat_stream.create_time,
|
||||||
|
chat_info_last_active_time=message.chat_stream.last_active_time,
|
||||||
|
chat_info_user_id=message.chat_stream.user_info.user_id,
|
||||||
|
chat_info_user_nickname=message.chat_stream.user_info.user_nickname,
|
||||||
|
chat_info_user_cardname=message.chat_stream.user_info.user_cardname,
|
||||||
|
chat_info_user_platform=message.chat_stream.user_info.platform,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 如果是群聊,添加群组信息
|
||||||
|
if message.chat_stream.group_info:
|
||||||
|
db_message.chat_info_group_id = message.chat_stream.group_info.group_id
|
||||||
|
db_message.chat_info_group_name = message.chat_stream.group_info.group_name
|
||||||
|
db_message.chat_info_group_platform = message.chat_stream.group_info.platform
|
||||||
|
|
||||||
|
# 添加消息到消息管理器
|
||||||
|
message_manager.add_message(message.chat_stream.stream_id, db_message)
|
||||||
|
logger.debug(f"消息已添加到消息管理器: {message.chat_stream.stream_id}")
|
||||||
|
|
||||||
if template_group_name:
|
if template_group_name:
|
||||||
async with global_prompt_manager.async_message_scope(template_group_name):
|
async with global_prompt_manager.async_message_scope(template_group_name):
|
||||||
|
|||||||
@@ -25,43 +25,6 @@ install(extra_lines=3)
|
|||||||
logger = get_logger("chat_stream")
|
logger = get_logger("chat_stream")
|
||||||
|
|
||||||
|
|
||||||
class ChatMessageContext:
|
|
||||||
"""聊天消息上下文,存储消息的上下文信息"""
|
|
||||||
|
|
||||||
def __init__(self, message: "MessageRecv"):
|
|
||||||
self.message = message
|
|
||||||
|
|
||||||
def get_template_name(self) -> Optional[str]:
|
|
||||||
"""获取模板名称"""
|
|
||||||
if self.message.message_info.template_info and not self.message.message_info.template_info.template_default:
|
|
||||||
return self.message.message_info.template_info.template_name # type: ignore
|
|
||||||
return None
|
|
||||||
|
|
||||||
def get_last_message(self) -> "MessageRecv":
|
|
||||||
"""获取最后一条消息"""
|
|
||||||
return self.message
|
|
||||||
|
|
||||||
def check_types(self, types: list) -> bool:
|
|
||||||
# sourcery skip: invert-any-all, use-any, use-next
|
|
||||||
"""检查消息类型"""
|
|
||||||
if not self.message.message_info.format_info.accept_format: # type: ignore
|
|
||||||
return False
|
|
||||||
for t in types:
|
|
||||||
if t not in self.message.message_info.format_info.accept_format: # type: ignore
|
|
||||||
return False
|
|
||||||
return True
|
|
||||||
|
|
||||||
def get_priority_mode(self) -> str:
|
|
||||||
"""获取优先级模式"""
|
|
||||||
return self.message.priority_mode
|
|
||||||
|
|
||||||
def get_priority_info(self) -> Optional[dict]:
|
|
||||||
"""获取优先级信息"""
|
|
||||||
if hasattr(self.message, "priority_info") and self.message.priority_info:
|
|
||||||
return self.message.priority_info
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
class ChatStream:
|
class ChatStream:
|
||||||
"""聊天流对象,存储一个完整的聊天上下文"""
|
"""聊天流对象,存储一个完整的聊天上下文"""
|
||||||
|
|
||||||
@@ -79,14 +42,24 @@ class ChatStream:
|
|||||||
self.group_info = group_info
|
self.group_info = group_info
|
||||||
self.create_time = data.get("create_time", time.time()) if data else time.time()
|
self.create_time = data.get("create_time", time.time()) if data else time.time()
|
||||||
self.last_active_time = data.get("last_active_time", self.create_time) if data else self.create_time
|
self.last_active_time = data.get("last_active_time", self.create_time) if data else self.create_time
|
||||||
self.energy_value = data.get("energy_value", 5.0) if data else 5.0
|
|
||||||
self.sleep_pressure = data.get("sleep_pressure", 0.0) if data else 0.0
|
self.sleep_pressure = data.get("sleep_pressure", 0.0) if data else 0.0
|
||||||
self.saved = False
|
self.saved = False
|
||||||
self.context: ChatMessageContext = None # type: ignore # 用于存储该聊天的上下文信息
|
|
||||||
# 从配置文件中读取focus_value,如果没有则使用默认值1.0
|
# 使用StreamContext替代ChatMessageContext
|
||||||
self.focus_energy = data.get("focus_energy", global_config.chat.focus_value) if data else global_config.chat.focus_value
|
from src.common.data_models.message_manager_data_model import StreamContext
|
||||||
|
from src.plugin_system.base.component_types import ChatType, ChatMode
|
||||||
|
|
||||||
|
self.stream_context: StreamContext = StreamContext(
|
||||||
|
stream_id=stream_id, chat_type=ChatType.GROUP if group_info else ChatType.PRIVATE, chat_mode=ChatMode.NORMAL
|
||||||
|
)
|
||||||
|
|
||||||
|
# 基础参数
|
||||||
|
self.base_interest_energy = 0.5 # 默认基础兴趣度
|
||||||
|
self._focus_energy = 0.5 # 内部存储的focus_energy值
|
||||||
self.no_reply_consecutive = 0
|
self.no_reply_consecutive = 0
|
||||||
self.breaking_accumulated_interest = 0.0
|
|
||||||
|
# 自动加载历史消息
|
||||||
|
self._load_history_messages()
|
||||||
|
|
||||||
def to_dict(self) -> dict:
|
def to_dict(self) -> dict:
|
||||||
"""转换为字典格式"""
|
"""转换为字典格式"""
|
||||||
@@ -97,10 +70,15 @@ class ChatStream:
|
|||||||
"group_info": self.group_info.to_dict() if self.group_info else None,
|
"group_info": self.group_info.to_dict() if self.group_info else None,
|
||||||
"create_time": self.create_time,
|
"create_time": self.create_time,
|
||||||
"last_active_time": self.last_active_time,
|
"last_active_time": self.last_active_time,
|
||||||
"energy_value": self.energy_value,
|
|
||||||
"sleep_pressure": self.sleep_pressure,
|
"sleep_pressure": self.sleep_pressure,
|
||||||
"focus_energy": self.focus_energy,
|
"focus_energy": self.focus_energy,
|
||||||
"breaking_accumulated_interest": self.breaking_accumulated_interest,
|
# 基础兴趣度
|
||||||
|
"base_interest_energy": self.base_interest_energy,
|
||||||
|
# 新增stream_context信息
|
||||||
|
"stream_context_chat_type": self.stream_context.chat_type.value,
|
||||||
|
"stream_context_chat_mode": self.stream_context.chat_mode.value,
|
||||||
|
# 新增interruption_count信息
|
||||||
|
"interruption_count": self.stream_context.interruption_count,
|
||||||
}
|
}
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@@ -109,7 +87,7 @@ class ChatStream:
|
|||||||
user_info = UserInfo.from_dict(data.get("user_info", {})) if data.get("user_info") else None
|
user_info = UserInfo.from_dict(data.get("user_info", {})) if data.get("user_info") else None
|
||||||
group_info = GroupInfo.from_dict(data.get("group_info", {})) if data.get("group_info") else None
|
group_info = GroupInfo.from_dict(data.get("group_info", {})) if data.get("group_info") else None
|
||||||
|
|
||||||
return cls(
|
instance = cls(
|
||||||
stream_id=data["stream_id"],
|
stream_id=data["stream_id"],
|
||||||
platform=data["platform"],
|
platform=data["platform"],
|
||||||
user_info=user_info, # type: ignore
|
user_info=user_info, # type: ignore
|
||||||
@@ -117,6 +95,22 @@ class ChatStream:
|
|||||||
data=data,
|
data=data,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# 恢复stream_context信息
|
||||||
|
if "stream_context_chat_type" in data:
|
||||||
|
from src.plugin_system.base.component_types import ChatType, ChatMode
|
||||||
|
|
||||||
|
instance.stream_context.chat_type = ChatType(data["stream_context_chat_type"])
|
||||||
|
if "stream_context_chat_mode" in data:
|
||||||
|
from src.plugin_system.base.component_types import ChatType, ChatMode
|
||||||
|
|
||||||
|
instance.stream_context.chat_mode = ChatMode(data["stream_context_chat_mode"])
|
||||||
|
|
||||||
|
# 恢复interruption_count信息
|
||||||
|
if "interruption_count" in data:
|
||||||
|
instance.stream_context.interruption_count = data["interruption_count"]
|
||||||
|
|
||||||
|
return instance
|
||||||
|
|
||||||
def update_active_time(self):
|
def update_active_time(self):
|
||||||
"""更新最后活跃时间"""
|
"""更新最后活跃时间"""
|
||||||
self.last_active_time = time.time()
|
self.last_active_time = time.time()
|
||||||
@@ -124,7 +118,312 @@ class ChatStream:
|
|||||||
|
|
||||||
def set_context(self, message: "MessageRecv"):
|
def set_context(self, message: "MessageRecv"):
|
||||||
"""设置聊天消息上下文"""
|
"""设置聊天消息上下文"""
|
||||||
self.context = ChatMessageContext(message)
|
# 将MessageRecv转换为DatabaseMessages并设置到stream_context
|
||||||
|
from src.common.data_models.database_data_model import DatabaseMessages
|
||||||
|
import json
|
||||||
|
|
||||||
|
# 安全获取message_info中的数据
|
||||||
|
message_info = getattr(message, "message_info", {})
|
||||||
|
user_info = getattr(message_info, "user_info", {})
|
||||||
|
group_info = getattr(message_info, "group_info", {})
|
||||||
|
|
||||||
|
# 提取reply_to信息(从message_segment中查找reply类型的段)
|
||||||
|
reply_to = None
|
||||||
|
if hasattr(message, "message_segment") and message.message_segment:
|
||||||
|
reply_to = self._extract_reply_from_segment(message.message_segment)
|
||||||
|
|
||||||
|
# 完整的数据转移逻辑
|
||||||
|
db_message = DatabaseMessages(
|
||||||
|
# 基础消息信息
|
||||||
|
message_id=getattr(message, "message_id", ""),
|
||||||
|
time=getattr(message, "time", time.time()),
|
||||||
|
chat_id=self._generate_chat_id(message_info),
|
||||||
|
reply_to=reply_to,
|
||||||
|
# 兴趣度相关
|
||||||
|
interest_value=getattr(message, "interest_value", 0.0),
|
||||||
|
# 关键词
|
||||||
|
key_words=json.dumps(getattr(message, "key_words", []), ensure_ascii=False)
|
||||||
|
if getattr(message, "key_words", None)
|
||||||
|
else None,
|
||||||
|
key_words_lite=json.dumps(getattr(message, "key_words_lite", []), ensure_ascii=False)
|
||||||
|
if getattr(message, "key_words_lite", None)
|
||||||
|
else None,
|
||||||
|
# 消息状态标记
|
||||||
|
is_mentioned=getattr(message, "is_mentioned", None),
|
||||||
|
is_at=getattr(message, "is_at", False),
|
||||||
|
is_emoji=getattr(message, "is_emoji", False),
|
||||||
|
is_picid=getattr(message, "is_picid", False),
|
||||||
|
is_voice=getattr(message, "is_voice", False),
|
||||||
|
is_video=getattr(message, "is_video", False),
|
||||||
|
is_command=getattr(message, "is_command", False),
|
||||||
|
is_notify=getattr(message, "is_notify", False),
|
||||||
|
# 消息内容
|
||||||
|
processed_plain_text=getattr(message, "processed_plain_text", ""),
|
||||||
|
display_message=getattr(message, "processed_plain_text", ""), # 默认使用processed_plain_text
|
||||||
|
# 优先级信息
|
||||||
|
priority_mode=getattr(message, "priority_mode", None),
|
||||||
|
priority_info=json.dumps(getattr(message, "priority_info", None))
|
||||||
|
if getattr(message, "priority_info", None)
|
||||||
|
else None,
|
||||||
|
# 额外配置
|
||||||
|
additional_config=getattr(message_info, "additional_config", None),
|
||||||
|
# 用户信息
|
||||||
|
user_id=str(getattr(user_info, "user_id", "")),
|
||||||
|
user_nickname=getattr(user_info, "user_nickname", ""),
|
||||||
|
user_cardname=getattr(user_info, "user_cardname", None),
|
||||||
|
user_platform=getattr(user_info, "platform", ""),
|
||||||
|
# 群组信息
|
||||||
|
chat_info_group_id=getattr(group_info, "group_id", None),
|
||||||
|
chat_info_group_name=getattr(group_info, "group_name", None),
|
||||||
|
chat_info_group_platform=getattr(group_info, "platform", None),
|
||||||
|
# 聊天流信息
|
||||||
|
chat_info_user_id=str(getattr(user_info, "user_id", "")),
|
||||||
|
chat_info_user_nickname=getattr(user_info, "user_nickname", ""),
|
||||||
|
chat_info_user_cardname=getattr(user_info, "user_cardname", None),
|
||||||
|
chat_info_user_platform=getattr(user_info, "platform", ""),
|
||||||
|
chat_info_stream_id=self.stream_id,
|
||||||
|
chat_info_platform=self.platform,
|
||||||
|
chat_info_create_time=self.create_time,
|
||||||
|
chat_info_last_active_time=self.last_active_time,
|
||||||
|
# 新增兴趣度系统字段 - 添加安全处理
|
||||||
|
actions=self._safe_get_actions(message),
|
||||||
|
should_reply=getattr(message, "should_reply", False),
|
||||||
|
)
|
||||||
|
|
||||||
|
self.stream_context.set_current_message(db_message)
|
||||||
|
self.stream_context.priority_mode = getattr(message, "priority_mode", None)
|
||||||
|
self.stream_context.priority_info = getattr(message, "priority_info", None)
|
||||||
|
|
||||||
|
# 调试日志:记录数据转移情况
|
||||||
|
logger.debug(f"消息数据转移完成 - message_id: {db_message.message_id}, "
|
||||||
|
f"chat_id: {db_message.chat_id}, "
|
||||||
|
f"is_mentioned: {db_message.is_mentioned}, "
|
||||||
|
f"is_emoji: {db_message.is_emoji}, "
|
||||||
|
f"is_picid: {db_message.is_picid}, "
|
||||||
|
f"interest_value: {db_message.interest_value}")
|
||||||
|
|
||||||
|
def _safe_get_actions(self, message: "MessageRecv") -> Optional[list]:
|
||||||
|
"""安全获取消息的actions字段"""
|
||||||
|
try:
|
||||||
|
actions = getattr(message, "actions", None)
|
||||||
|
if actions is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# 如果是字符串,尝试解析为JSON
|
||||||
|
if isinstance(actions, str):
|
||||||
|
try:
|
||||||
|
import json
|
||||||
|
actions = json.loads(actions)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
logger.warning(f"无法解析actions JSON字符串: {actions}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
# 确保返回列表类型
|
||||||
|
if isinstance(actions, list):
|
||||||
|
# 过滤掉空值和非字符串元素
|
||||||
|
filtered_actions = [action for action in actions if action is not None and isinstance(action, str)]
|
||||||
|
return filtered_actions if filtered_actions else None
|
||||||
|
else:
|
||||||
|
logger.warning(f"actions字段类型不支持: {type(actions)}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"获取actions字段失败: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _extract_reply_from_segment(self, segment) -> Optional[str]:
|
||||||
|
"""从消息段中提取reply_to信息"""
|
||||||
|
try:
|
||||||
|
if hasattr(segment, "type") and segment.type == "seglist":
|
||||||
|
# 递归搜索seglist中的reply段
|
||||||
|
if hasattr(segment, "data") and segment.data:
|
||||||
|
for seg in segment.data:
|
||||||
|
reply_id = self._extract_reply_from_segment(seg)
|
||||||
|
if reply_id:
|
||||||
|
return reply_id
|
||||||
|
elif hasattr(segment, "type") and segment.type == "reply":
|
||||||
|
# 找到reply段,返回message_id
|
||||||
|
return str(segment.data) if segment.data else None
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"提取reply_to信息失败: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _generate_chat_id(self, message_info) -> str:
|
||||||
|
"""生成chat_id,基于群组或用户信息"""
|
||||||
|
try:
|
||||||
|
group_info = getattr(message_info, "group_info", None)
|
||||||
|
user_info = getattr(message_info, "user_info", None)
|
||||||
|
|
||||||
|
if group_info and hasattr(group_info, "group_id") and group_info.group_id:
|
||||||
|
# 群聊:使用群组ID
|
||||||
|
return f"{self.platform}_{group_info.group_id}"
|
||||||
|
elif user_info and hasattr(user_info, "user_id") and user_info.user_id:
|
||||||
|
# 私聊:使用用户ID
|
||||||
|
return f"{self.platform}_{user_info.user_id}_private"
|
||||||
|
else:
|
||||||
|
# 默认:使用stream_id
|
||||||
|
return self.stream_id
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"生成chat_id失败: {e}")
|
||||||
|
return self.stream_id
|
||||||
|
|
||||||
|
@property
|
||||||
|
def focus_energy(self) -> float:
|
||||||
|
"""使用重构后的能量管理器计算focus_energy"""
|
||||||
|
try:
|
||||||
|
from src.chat.energy_system import energy_manager
|
||||||
|
|
||||||
|
# 获取所有消息
|
||||||
|
history_messages = self.stream_context.get_history_messages(limit=global_config.chat.max_context_size)
|
||||||
|
unread_messages = self.stream_context.get_unread_messages()
|
||||||
|
all_messages = history_messages + unread_messages
|
||||||
|
|
||||||
|
# 获取用户ID
|
||||||
|
user_id = None
|
||||||
|
if self.user_info and hasattr(self.user_info, "user_id"):
|
||||||
|
user_id = str(self.user_info.user_id)
|
||||||
|
|
||||||
|
# 使用能量管理器计算
|
||||||
|
energy = energy_manager.calculate_focus_energy(
|
||||||
|
stream_id=self.stream_id,
|
||||||
|
messages=all_messages,
|
||||||
|
user_id=user_id
|
||||||
|
)
|
||||||
|
|
||||||
|
# 更新内部存储
|
||||||
|
self._focus_energy = energy
|
||||||
|
|
||||||
|
logger.debug(f"聊天流 {self.stream_id} 能量: {energy:.3f}")
|
||||||
|
return energy
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"获取focus_energy失败: {e}", exc_info=True)
|
||||||
|
# 返回缓存的值或默认值
|
||||||
|
if hasattr(self, '_focus_energy'):
|
||||||
|
return self._focus_energy
|
||||||
|
else:
|
||||||
|
return 0.5
|
||||||
|
|
||||||
|
@focus_energy.setter
|
||||||
|
def focus_energy(self, value: float):
|
||||||
|
"""设置focus_energy值(主要用于初始化或特殊场景)"""
|
||||||
|
self._focus_energy = max(0.0, min(1.0, value))
|
||||||
|
|
||||||
|
def _get_user_relationship_score(self) -> float:
|
||||||
|
"""获取用户关系分"""
|
||||||
|
# 使用插件内部的兴趣度评分系统
|
||||||
|
try:
|
||||||
|
from src.plugins.built_in.affinity_flow_chatter.interest_scoring import chatter_interest_scoring_system
|
||||||
|
|
||||||
|
if self.user_info and hasattr(self.user_info, "user_id"):
|
||||||
|
user_id = str(self.user_info.user_id)
|
||||||
|
relationship_score = chatter_interest_scoring_system._calculate_relationship_score(user_id)
|
||||||
|
logger.debug(f"ChatStream {self.stream_id}: 用户关系分 = {relationship_score:.3f}")
|
||||||
|
return max(0.0, min(1.0, relationship_score))
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"ChatStream {self.stream_id}: 插件内部关系分计算失败: {e}")
|
||||||
|
|
||||||
|
# 默认基础分
|
||||||
|
return 0.3
|
||||||
|
|
||||||
|
def _load_history_messages(self):
|
||||||
|
"""从数据库加载历史消息到StreamContext"""
|
||||||
|
try:
|
||||||
|
from src.common.database.sqlalchemy_models import Messages
|
||||||
|
from src.common.database.sqlalchemy_database_api import get_db_session
|
||||||
|
from src.common.data_models.database_data_model import DatabaseMessages
|
||||||
|
from sqlalchemy import select, desc
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
async def _load_messages():
|
||||||
|
def _db_query():
|
||||||
|
with get_db_session() as session:
|
||||||
|
# 查询该stream_id的最近20条消息
|
||||||
|
stmt = (
|
||||||
|
select(Messages)
|
||||||
|
.where(Messages.chat_info_stream_id == self.stream_id)
|
||||||
|
.order_by(desc(Messages.time))
|
||||||
|
.limit(global_config.chat.max_context_size)
|
||||||
|
)
|
||||||
|
results = session.execute(stmt).scalars().all()
|
||||||
|
return results
|
||||||
|
|
||||||
|
# 在线程中执行数据库查询
|
||||||
|
db_messages = await asyncio.to_thread(_db_query)
|
||||||
|
|
||||||
|
# 转换为DatabaseMessages对象并添加到StreamContext
|
||||||
|
for db_msg in db_messages:
|
||||||
|
try:
|
||||||
|
# 从SQLAlchemy模型转换为DatabaseMessages数据模型
|
||||||
|
import orjson
|
||||||
|
|
||||||
|
# 解析actions字段(JSON格式)
|
||||||
|
actions = None
|
||||||
|
if db_msg.actions:
|
||||||
|
try:
|
||||||
|
actions = orjson.loads(db_msg.actions)
|
||||||
|
except (orjson.JSONDecodeError, TypeError):
|
||||||
|
actions = None
|
||||||
|
|
||||||
|
db_message = DatabaseMessages(
|
||||||
|
message_id=db_msg.message_id,
|
||||||
|
time=db_msg.time,
|
||||||
|
chat_id=db_msg.chat_id,
|
||||||
|
reply_to=db_msg.reply_to,
|
||||||
|
interest_value=db_msg.interest_value,
|
||||||
|
key_words=db_msg.key_words,
|
||||||
|
key_words_lite=db_msg.key_words_lite,
|
||||||
|
is_mentioned=db_msg.is_mentioned,
|
||||||
|
processed_plain_text=db_msg.processed_plain_text,
|
||||||
|
display_message=db_msg.display_message,
|
||||||
|
priority_mode=db_msg.priority_mode,
|
||||||
|
priority_info=db_msg.priority_info,
|
||||||
|
additional_config=db_msg.additional_config,
|
||||||
|
is_emoji=db_msg.is_emoji,
|
||||||
|
is_picid=db_msg.is_picid,
|
||||||
|
is_command=db_msg.is_command,
|
||||||
|
is_notify=db_msg.is_notify,
|
||||||
|
user_id=db_msg.user_id,
|
||||||
|
user_nickname=db_msg.user_nickname,
|
||||||
|
user_cardname=db_msg.user_cardname,
|
||||||
|
user_platform=db_msg.user_platform,
|
||||||
|
chat_info_group_id=db_msg.chat_info_group_id,
|
||||||
|
chat_info_group_name=db_msg.chat_info_group_name,
|
||||||
|
chat_info_group_platform=db_msg.chat_info_group_platform,
|
||||||
|
chat_info_user_id=db_msg.chat_info_user_id,
|
||||||
|
chat_info_user_nickname=db_msg.chat_info_user_nickname,
|
||||||
|
chat_info_user_cardname=db_msg.chat_info_user_cardname,
|
||||||
|
chat_info_user_platform=db_msg.chat_info_user_platform,
|
||||||
|
chat_info_stream_id=db_msg.chat_info_stream_id,
|
||||||
|
chat_info_platform=db_msg.chat_info_platform,
|
||||||
|
chat_info_create_time=db_msg.chat_info_create_time,
|
||||||
|
chat_info_last_active_time=db_msg.chat_info_last_active_time,
|
||||||
|
actions=actions,
|
||||||
|
should_reply=getattr(db_msg, "should_reply", False) or False,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 添加调试日志:检查从数据库加载的interest_value
|
||||||
|
logger.debug(f"加载历史消息 {db_message.message_id} - interest_value: {db_message.interest_value}")
|
||||||
|
|
||||||
|
# 标记为已读并添加到历史消息
|
||||||
|
db_message.is_read = True
|
||||||
|
self.stream_context.history_messages.append(db_message)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"转换消息 {db_msg.message_id} 失败: {e}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
if self.stream_context.history_messages:
|
||||||
|
logger.info(
|
||||||
|
f"已从数据库加载 {len(self.stream_context.history_messages)} 条历史消息到聊天流 {self.stream_id}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 创建任务来加载历史消息
|
||||||
|
asyncio.create_task(_load_messages())
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"加载历史消息失败: {e}")
|
||||||
|
|
||||||
|
|
||||||
class ChatManager:
|
class ChatManager:
|
||||||
@@ -362,7 +661,16 @@ class ChatManager:
|
|||||||
"group_name": group_info_d["group_name"] if group_info_d else "",
|
"group_name": group_info_d["group_name"] if group_info_d else "",
|
||||||
"energy_value": s_data_dict.get("energy_value", 5.0),
|
"energy_value": s_data_dict.get("energy_value", 5.0),
|
||||||
"sleep_pressure": s_data_dict.get("sleep_pressure", 0.0),
|
"sleep_pressure": s_data_dict.get("sleep_pressure", 0.0),
|
||||||
"focus_energy": s_data_dict.get("focus_energy", global_config.chat.focus_value),
|
"focus_energy": s_data_dict.get("focus_energy", 0.5),
|
||||||
|
# 新增动态兴趣度系统字段
|
||||||
|
"base_interest_energy": s_data_dict.get("base_interest_energy", 0.5),
|
||||||
|
"message_interest_total": s_data_dict.get("message_interest_total", 0.0),
|
||||||
|
"message_count": s_data_dict.get("message_count", 0),
|
||||||
|
"action_count": s_data_dict.get("action_count", 0),
|
||||||
|
"reply_count": s_data_dict.get("reply_count", 0),
|
||||||
|
"last_interaction_time": s_data_dict.get("last_interaction_time", time.time()),
|
||||||
|
"consecutive_no_reply": s_data_dict.get("consecutive_no_reply", 0),
|
||||||
|
"interruption_count": s_data_dict.get("interruption_count", 0),
|
||||||
}
|
}
|
||||||
if global_config.database.database_type == "sqlite":
|
if global_config.database.database_type == "sqlite":
|
||||||
stmt = sqlite_insert(ChatStreams).values(stream_id=s_data_dict["stream_id"], **fields_to_save)
|
stmt = sqlite_insert(ChatStreams).values(stream_id=s_data_dict["stream_id"], **fields_to_save)
|
||||||
@@ -419,7 +727,17 @@ class ChatManager:
|
|||||||
"last_active_time": model_instance.last_active_time,
|
"last_active_time": model_instance.last_active_time,
|
||||||
"energy_value": model_instance.energy_value,
|
"energy_value": model_instance.energy_value,
|
||||||
"sleep_pressure": model_instance.sleep_pressure,
|
"sleep_pressure": model_instance.sleep_pressure,
|
||||||
"focus_energy": getattr(model_instance, "focus_energy", global_config.chat.focus_value),
|
"focus_energy": getattr(model_instance, "focus_energy", 0.5),
|
||||||
|
# 新增动态兴趣度系统字段 - 使用getattr提供默认值
|
||||||
|
"base_interest_energy": getattr(model_instance, "base_interest_energy", 0.5),
|
||||||
|
"message_interest_total": getattr(model_instance, "message_interest_total", 0.0),
|
||||||
|
"message_count": getattr(model_instance, "message_count", 0),
|
||||||
|
"action_count": getattr(model_instance, "action_count", 0),
|
||||||
|
"reply_count": getattr(model_instance, "reply_count", 0),
|
||||||
|
"last_interaction_time": getattr(model_instance, "last_interaction_time", time.time()),
|
||||||
|
"relationship_score": getattr(model_instance, "relationship_score", 0.3),
|
||||||
|
"consecutive_no_reply": getattr(model_instance, "consecutive_no_reply", 0),
|
||||||
|
"interruption_count": getattr(model_instance, "interruption_count", 0),
|
||||||
}
|
}
|
||||||
loaded_streams_data.append(data_for_from_dict)
|
loaded_streams_data.append(data_for_from_dict)
|
||||||
await session.commit()
|
await session.commit()
|
||||||
|
|||||||
@@ -123,7 +123,7 @@ class MessageRecv(Message):
|
|||||||
self.is_video = False
|
self.is_video = False
|
||||||
self.is_mentioned = None
|
self.is_mentioned = None
|
||||||
self.is_notify = False
|
self.is_notify = False
|
||||||
|
self.is_at = False
|
||||||
self.is_command = False
|
self.is_command = False
|
||||||
|
|
||||||
self.priority_mode = "interest"
|
self.priority_mode = "interest"
|
||||||
|
|||||||
@@ -1,14 +1,14 @@
|
|||||||
import re
|
import re
|
||||||
import traceback
|
import traceback
|
||||||
|
import orjson
|
||||||
from typing import Union
|
from typing import Union
|
||||||
|
|
||||||
import orjson
|
from src.common.database.sqlalchemy_models import Messages, Images
|
||||||
from sqlalchemy import select, desc, update
|
|
||||||
|
|
||||||
from src.common.database.sqlalchemy_models import Messages, Images, get_db_session
|
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from .chat_stream import ChatStream
|
from .chat_stream import ChatStream
|
||||||
from .message import MessageSending, MessageRecv
|
from .message import MessageSending, MessageRecv
|
||||||
|
from src.common.database.sqlalchemy_database_api import get_db_session
|
||||||
|
from sqlalchemy import select, update, desc
|
||||||
|
|
||||||
logger = get_logger("message_storage")
|
logger = get_logger("message_storage")
|
||||||
|
|
||||||
@@ -41,7 +41,7 @@ class MessageStorage:
|
|||||||
processed_plain_text = message.processed_plain_text
|
processed_plain_text = message.processed_plain_text
|
||||||
|
|
||||||
if processed_plain_text:
|
if processed_plain_text:
|
||||||
processed_plain_text = await MessageStorage.replace_image_descriptions(processed_plain_text)
|
processed_plain_text = MessageStorage.replace_image_descriptions(processed_plain_text)
|
||||||
filtered_processed_plain_text = re.sub(pattern, "", processed_plain_text, flags=re.DOTALL)
|
filtered_processed_plain_text = re.sub(pattern, "", processed_plain_text, flags=re.DOTALL)
|
||||||
else:
|
else:
|
||||||
filtered_processed_plain_text = ""
|
filtered_processed_plain_text = ""
|
||||||
@@ -51,7 +51,8 @@ class MessageStorage:
|
|||||||
if display_message:
|
if display_message:
|
||||||
filtered_display_message = re.sub(pattern, "", display_message, flags=re.DOTALL)
|
filtered_display_message = re.sub(pattern, "", display_message, flags=re.DOTALL)
|
||||||
else:
|
else:
|
||||||
filtered_display_message = ""
|
# 如果没有设置display_message,使用processed_plain_text作为显示消息
|
||||||
|
filtered_display_message = re.sub(pattern, "", message.processed_plain_text, flags=re.DOTALL) if message.processed_plain_text else ""
|
||||||
interest_value = 0
|
interest_value = 0
|
||||||
is_mentioned = False
|
is_mentioned = False
|
||||||
reply_to = message.reply_to
|
reply_to = message.reply_to
|
||||||
@@ -116,14 +117,21 @@ class MessageStorage:
|
|||||||
user_nickname=user_info_dict.get("user_nickname"),
|
user_nickname=user_info_dict.get("user_nickname"),
|
||||||
user_cardname=user_info_dict.get("user_cardname"),
|
user_cardname=user_info_dict.get("user_cardname"),
|
||||||
processed_plain_text=filtered_processed_plain_text,
|
processed_plain_text=filtered_processed_plain_text,
|
||||||
|
display_message=filtered_display_message,
|
||||||
|
memorized_times=message.memorized_times,
|
||||||
|
interest_value=interest_value,
|
||||||
priority_mode=priority_mode,
|
priority_mode=priority_mode,
|
||||||
priority_info=priority_info_json,
|
priority_info=priority_info_json,
|
||||||
is_emoji=is_emoji,
|
is_emoji=is_emoji,
|
||||||
is_picid=is_picid,
|
is_picid=is_picid,
|
||||||
|
is_notify=is_notify,
|
||||||
|
is_command=is_command,
|
||||||
|
key_words=key_words,
|
||||||
|
key_words_lite=key_words_lite,
|
||||||
)
|
)
|
||||||
async with get_db_session() as session:
|
with get_db_session() as session:
|
||||||
session.add(new_message)
|
session.add(new_message)
|
||||||
await session.commit()
|
session.commit()
|
||||||
|
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("存储消息失败")
|
logger.exception("存储消息失败")
|
||||||
@@ -146,6 +154,7 @@ class MessageStorage:
|
|||||||
qq_message_id = message.message_segment.data.get("id")
|
qq_message_id = message.message_segment.data.get("id")
|
||||||
elif message.message_segment.type == "reply":
|
elif message.message_segment.type == "reply":
|
||||||
qq_message_id = message.message_segment.data.get("id")
|
qq_message_id = message.message_segment.data.get("id")
|
||||||
|
if qq_message_id:
|
||||||
logger.debug(f"从reply消息段获取到消息ID: {qq_message_id}")
|
logger.debug(f"从reply消息段获取到消息ID: {qq_message_id}")
|
||||||
elif message.message_segment.type == "adapter_response":
|
elif message.message_segment.type == "adapter_response":
|
||||||
logger.debug("适配器响应消息,不需要更新ID")
|
logger.debug("适配器响应消息,不需要更新ID")
|
||||||
@@ -162,19 +171,18 @@ class MessageStorage:
|
|||||||
logger.debug(f"消息段数据: {message.message_segment.data}")
|
logger.debug(f"消息段数据: {message.message_segment.data}")
|
||||||
return
|
return
|
||||||
|
|
||||||
async with get_db_session() as session:
|
# 使用上下文管理器确保session正确管理
|
||||||
matched_message = (
|
from src.common.database.sqlalchemy_models import get_db_session
|
||||||
await session.execute(
|
|
||||||
|
with get_db_session() as session:
|
||||||
|
matched_message = session.execute(
|
||||||
select(Messages).where(Messages.message_id == mmc_message_id).order_by(desc(Messages.time))
|
select(Messages).where(Messages.message_id == mmc_message_id).order_by(desc(Messages.time))
|
||||||
)
|
|
||||||
).scalar()
|
).scalar()
|
||||||
|
|
||||||
if matched_message:
|
if matched_message:
|
||||||
await session.execute(
|
session.execute(
|
||||||
update(Messages).where(Messages.id == matched_message.id).values(message_id=qq_message_id)
|
update(Messages).where(Messages.id == matched_message.id).values(message_id=qq_message_id)
|
||||||
)
|
)
|
||||||
await session.commit()
|
|
||||||
# 会在上下文管理器中自动调用
|
|
||||||
logger.debug(f"更新消息ID成功: {matched_message.message_id} -> {qq_message_id}")
|
logger.debug(f"更新消息ID成功: {matched_message.message_id} -> {qq_message_id}")
|
||||||
else:
|
else:
|
||||||
logger.warning(f"未找到匹配的消息记录: {mmc_message_id}")
|
logger.warning(f"未找到匹配的消息记录: {mmc_message_id}")
|
||||||
@@ -186,36 +194,117 @@ class MessageStorage:
|
|||||||
f"segment_type={getattr(message.message_segment, 'type', 'N/A')}"
|
f"segment_type={getattr(message.message_segment, 'type', 'N/A')}"
|
||||||
)
|
)
|
||||||
|
|
||||||
async def replace_image_descriptions(text: str) -> str:
|
@staticmethod
|
||||||
|
def replace_image_descriptions(text: str) -> str:
|
||||||
"""将[图片:描述]替换为[picid:image_id]"""
|
"""将[图片:描述]替换为[picid:image_id]"""
|
||||||
# 先检查文本中是否有图片标记
|
# 先检查文本中是否有图片标记
|
||||||
pattern = r"\[图片:([^\]]+)\]"
|
pattern = r"\[图片:([^\]]+)\]"
|
||||||
matches = list(re.finditer(pattern, text))
|
matches = re.findall(pattern, text)
|
||||||
|
|
||||||
if not matches:
|
if not matches:
|
||||||
logger.debug("文本中没有图片标记,直接返回原文本")
|
logger.debug("文本中没有图片标记,直接返回原文本")
|
||||||
return text
|
return text
|
||||||
|
|
||||||
new_text = ""
|
def replace_match(match):
|
||||||
last_end = 0
|
|
||||||
for match in matches:
|
|
||||||
new_text += text[last_end : match.start()]
|
|
||||||
description = match.group(1).strip()
|
description = match.group(1).strip()
|
||||||
try:
|
try:
|
||||||
from src.common.database.sqlalchemy_models import get_db_session
|
from src.common.database.sqlalchemy_models import get_db_session
|
||||||
|
|
||||||
async with get_db_session() as session:
|
with get_db_session() as session:
|
||||||
image_record = (
|
image_record = session.execute(
|
||||||
await session.execute(
|
|
||||||
select(Images).where(Images.description == description).order_by(desc(Images.timestamp))
|
select(Images).where(Images.description == description).order_by(desc(Images.timestamp))
|
||||||
)
|
|
||||||
).scalar()
|
).scalar()
|
||||||
if image_record:
|
return f"[picid:{image_record.image_id}]" if image_record else match.group(0)
|
||||||
new_text += f"[picid:{image_record.image_id}]"
|
|
||||||
else:
|
|
||||||
new_text += match.group(0)
|
|
||||||
except Exception:
|
except Exception:
|
||||||
new_text += match.group(0)
|
return match.group(0)
|
||||||
last_end = match.end()
|
|
||||||
new_text += text[last_end:]
|
@staticmethod
|
||||||
return new_text
|
def update_message_interest_value(message_id: str, interest_value: float) -> None:
|
||||||
|
"""
|
||||||
|
更新数据库中消息的interest_value字段
|
||||||
|
|
||||||
|
Args:
|
||||||
|
message_id: 消息ID
|
||||||
|
interest_value: 兴趣度值
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
with get_db_session() as session:
|
||||||
|
# 更新消息的interest_value字段
|
||||||
|
stmt = update(Messages).where(Messages.message_id == message_id).values(interest_value=interest_value)
|
||||||
|
result = session.execute(stmt)
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
if result.rowcount > 0:
|
||||||
|
logger.debug(f"成功更新消息 {message_id} 的interest_value为 {interest_value}")
|
||||||
|
else:
|
||||||
|
logger.warning(f"未找到消息 {message_id},无法更新interest_value")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"更新消息 {message_id} 的interest_value失败: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def fix_zero_interest_values(chat_id: str, since_time: float) -> int:
|
||||||
|
"""
|
||||||
|
修复指定聊天中interest_value为0或null的历史消息记录
|
||||||
|
|
||||||
|
Args:
|
||||||
|
chat_id: 聊天ID
|
||||||
|
since_time: 从指定时间开始修复(时间戳)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
修复的记录数量
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
with get_db_session() as session:
|
||||||
|
from sqlalchemy import select, update
|
||||||
|
from src.common.database.sqlalchemy_models import Messages
|
||||||
|
|
||||||
|
# 查找需要修复的记录:interest_value为0、null或很小的值
|
||||||
|
query = select(Messages).where(
|
||||||
|
(Messages.chat_id == chat_id) &
|
||||||
|
(Messages.time >= since_time) &
|
||||||
|
(
|
||||||
|
(Messages.interest_value == 0) |
|
||||||
|
(Messages.interest_value.is_(None)) |
|
||||||
|
(Messages.interest_value < 0.1)
|
||||||
|
)
|
||||||
|
).limit(50) # 限制每次修复的数量,避免性能问题
|
||||||
|
|
||||||
|
messages_to_fix = session.execute(query).scalars().all()
|
||||||
|
fixed_count = 0
|
||||||
|
|
||||||
|
for msg in messages_to_fix:
|
||||||
|
# 为这些消息设置一个合理的默认兴趣度
|
||||||
|
# 可以基于消息长度、内容或其他因素计算
|
||||||
|
default_interest = 0.3 # 默认中等兴趣度
|
||||||
|
|
||||||
|
# 如果消息内容较长,可能是重要消息,兴趣度稍高
|
||||||
|
if hasattr(msg, 'processed_plain_text') and msg.processed_plain_text:
|
||||||
|
text_length = len(msg.processed_plain_text)
|
||||||
|
if text_length > 50: # 长消息
|
||||||
|
default_interest = 0.4
|
||||||
|
elif text_length > 20: # 中等长度消息
|
||||||
|
default_interest = 0.35
|
||||||
|
|
||||||
|
# 如果是被@的消息,兴趣度更高
|
||||||
|
if getattr(msg, 'is_mentioned', False):
|
||||||
|
default_interest = min(default_interest + 0.2, 0.8)
|
||||||
|
|
||||||
|
# 执行更新
|
||||||
|
update_stmt = update(Messages).where(
|
||||||
|
Messages.message_id == msg.message_id
|
||||||
|
).values(interest_value=default_interest)
|
||||||
|
|
||||||
|
result = session.execute(update_stmt)
|
||||||
|
if result.rowcount > 0:
|
||||||
|
fixed_count += 1
|
||||||
|
logger.debug(f"修复消息 {msg.message_id} 的interest_value为 {default_interest}")
|
||||||
|
|
||||||
|
session.commit()
|
||||||
|
logger.info(f"共修复了 {fixed_count} 条历史消息的interest_value值")
|
||||||
|
return fixed_count
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"修复历史消息interest_value失败: {e}")
|
||||||
|
return 0
|
||||||
|
|||||||
@@ -1,15 +1,24 @@
|
|||||||
from typing import Dict, Optional, Type
|
import asyncio
|
||||||
|
import traceback
|
||||||
|
import time
|
||||||
|
from typing import Dict, Optional, Type, Any, Tuple
|
||||||
|
|
||||||
from src.chat.message_receive.chat_stream import ChatStream
|
|
||||||
|
from src.chat.utils.timer_calculator import Timer
|
||||||
|
from src.person_info.person_info import get_person_info_manager
|
||||||
|
from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
|
from src.config.config import global_config
|
||||||
from src.plugin_system.core.component_registry import component_registry
|
from src.plugin_system.core.component_registry import component_registry
|
||||||
from src.plugin_system.base.component_types import ComponentType, ActionInfo
|
from src.plugin_system.base.component_types import ComponentType, ActionInfo
|
||||||
from src.plugin_system.base.base_action import BaseAction
|
from src.plugin_system.base.base_action import BaseAction
|
||||||
|
from src.plugin_system.apis import generator_api, database_api, send_api, message_api
|
||||||
|
|
||||||
|
|
||||||
logger = get_logger("action_manager")
|
logger = get_logger("action_manager")
|
||||||
|
|
||||||
|
|
||||||
class ActionManager:
|
class ChatterActionManager:
|
||||||
"""
|
"""
|
||||||
动作管理器,用于管理各种类型的动作
|
动作管理器,用于管理各种类型的动作
|
||||||
|
|
||||||
@@ -25,6 +34,8 @@ class ActionManager:
|
|||||||
# 初始化时将默认动作加载到使用中的动作
|
# 初始化时将默认动作加载到使用中的动作
|
||||||
self._using_actions = component_registry.get_default_actions()
|
self._using_actions = component_registry.get_default_actions()
|
||||||
|
|
||||||
|
self.log_prefix: str = "ChatterActionManager"
|
||||||
|
|
||||||
# === 执行Action方法 ===
|
# === 执行Action方法 ===
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@@ -124,3 +135,417 @@ class ActionManager:
|
|||||||
actions_to_restore = list(self._using_actions.keys())
|
actions_to_restore = list(self._using_actions.keys())
|
||||||
self._using_actions = component_registry.get_default_actions()
|
self._using_actions = component_registry.get_default_actions()
|
||||||
logger.debug(f"恢复动作集: 从 {actions_to_restore} 恢复到默认动作集 {list(self._using_actions.keys())}")
|
logger.debug(f"恢复动作集: 从 {actions_to_restore} 恢复到默认动作集 {list(self._using_actions.keys())}")
|
||||||
|
|
||||||
|
async def execute_action(
|
||||||
|
self,
|
||||||
|
action_name: str,
|
||||||
|
chat_id: str,
|
||||||
|
target_message: Optional[dict] = None,
|
||||||
|
reasoning: str = "",
|
||||||
|
action_data: Optional[dict] = None,
|
||||||
|
thinking_id: Optional[str] = None,
|
||||||
|
log_prefix: str = "",
|
||||||
|
) -> Any:
|
||||||
|
"""
|
||||||
|
执行单个动作的通用函数
|
||||||
|
|
||||||
|
Args:
|
||||||
|
action_name: 动作名称
|
||||||
|
chat_id: 聊天id
|
||||||
|
target_message: 目标消息
|
||||||
|
reasoning: 执行理由
|
||||||
|
action_data: 动作数据
|
||||||
|
thinking_id: 思考ID
|
||||||
|
log_prefix: 日志前缀
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
执行结果
|
||||||
|
"""
|
||||||
|
from src.chat.message_manager.message_manager import message_manager
|
||||||
|
try:
|
||||||
|
logger.debug(f"🎯 [ActionManager] execute_action接收到 target_message: {target_message}")
|
||||||
|
# 通过chat_id获取chat_stream
|
||||||
|
chat_manager = get_chat_manager()
|
||||||
|
chat_stream = chat_manager.get_stream(chat_id)
|
||||||
|
|
||||||
|
if not chat_stream:
|
||||||
|
logger.error(f"{log_prefix} 无法找到chat_id对应的chat_stream: {chat_id}")
|
||||||
|
return {
|
||||||
|
"action_type": action_name,
|
||||||
|
"success": False,
|
||||||
|
"reply_text": "",
|
||||||
|
"error": "chat_stream not found",
|
||||||
|
}
|
||||||
|
|
||||||
|
if action_name == "no_action":
|
||||||
|
return {"action_type": "no_action", "success": True, "reply_text": "", "command": ""}
|
||||||
|
|
||||||
|
if action_name == "no_reply":
|
||||||
|
# 直接处理no_reply逻辑,不再通过动作系统
|
||||||
|
reason = reasoning or "选择不回复"
|
||||||
|
logger.info(f"{log_prefix} 选择不回复,原因: {reason}")
|
||||||
|
|
||||||
|
# 存储no_reply信息到数据库
|
||||||
|
await database_api.store_action_info(
|
||||||
|
chat_stream=chat_stream,
|
||||||
|
action_build_into_prompt=False,
|
||||||
|
action_prompt_display=reason,
|
||||||
|
action_done=True,
|
||||||
|
thinking_id=thinking_id,
|
||||||
|
action_data={"reason": reason},
|
||||||
|
action_name="no_reply",
|
||||||
|
)
|
||||||
|
return {"action_type": "no_reply", "success": True, "reply_text": "", "command": ""}
|
||||||
|
|
||||||
|
elif action_name != "reply" and action_name != "no_action":
|
||||||
|
# 执行普通动作
|
||||||
|
success, reply_text, command = await self._handle_action(
|
||||||
|
chat_stream,
|
||||||
|
action_name,
|
||||||
|
reasoning,
|
||||||
|
action_data or {},
|
||||||
|
{}, # cycle_timers
|
||||||
|
thinking_id,
|
||||||
|
target_message,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 记录执行的动作到目标消息
|
||||||
|
if success:
|
||||||
|
await self._record_action_to_message(chat_stream, action_name, target_message, action_data)
|
||||||
|
# 重置打断计数
|
||||||
|
await self._reset_interruption_count_after_action(chat_stream.stream_id)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"action_type": action_name,
|
||||||
|
"success": success,
|
||||||
|
"reply_text": reply_text,
|
||||||
|
"command": command,
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
# 生成回复
|
||||||
|
try:
|
||||||
|
success, response_set, _ = await generator_api.generate_reply(
|
||||||
|
chat_stream=chat_stream,
|
||||||
|
reply_message=target_message,
|
||||||
|
action_data=action_data or {},
|
||||||
|
available_actions=self.get_using_actions(),
|
||||||
|
enable_tool=global_config.tool.enable_tool,
|
||||||
|
request_type="chat.replyer",
|
||||||
|
from_plugin=False,
|
||||||
|
)
|
||||||
|
if not success or not response_set:
|
||||||
|
logger.info(
|
||||||
|
f"对 {target_message.get('processed_plain_text') if target_message else '未知消息'} 的回复生成失败"
|
||||||
|
)
|
||||||
|
return {"action_type": "reply", "success": False, "reply_text": "", "loop_info": None}
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
logger.debug(f"{log_prefix} 并行执行:回复生成任务已被取消")
|
||||||
|
return {"action_type": "reply", "success": False, "reply_text": "", "loop_info": None}
|
||||||
|
|
||||||
|
# 发送并存储回复
|
||||||
|
loop_info, reply_text, cycle_timers_reply = await self._send_and_store_reply(
|
||||||
|
chat_stream,
|
||||||
|
response_set,
|
||||||
|
asyncio.get_event_loop().time(),
|
||||||
|
target_message,
|
||||||
|
{}, # cycle_timers
|
||||||
|
thinking_id,
|
||||||
|
[], # actions
|
||||||
|
)
|
||||||
|
|
||||||
|
# 记录回复动作到目标消息
|
||||||
|
await self._record_action_to_message(chat_stream, "reply", target_message, action_data)
|
||||||
|
|
||||||
|
# 回复成功,重置打断计数
|
||||||
|
await self._reset_interruption_count_after_action(chat_stream.stream_id)
|
||||||
|
|
||||||
|
return {"action_type": "reply", "success": True, "reply_text": reply_text, "loop_info": loop_info}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"{log_prefix} 执行动作时出错: {e}")
|
||||||
|
logger.error(f"{log_prefix} 错误信息: {traceback.format_exc()}")
|
||||||
|
return {
|
||||||
|
"action_type": action_name,
|
||||||
|
"success": False,
|
||||||
|
"reply_text": "",
|
||||||
|
"loop_info": None,
|
||||||
|
"error": str(e),
|
||||||
|
}
|
||||||
|
|
||||||
|
async def _record_action_to_message(self, chat_stream, action_name, target_message, action_data):
|
||||||
|
"""
|
||||||
|
记录执行的动作到目标消息中
|
||||||
|
|
||||||
|
Args:
|
||||||
|
chat_stream: ChatStream实例
|
||||||
|
action_name: 动作名称
|
||||||
|
target_message: 目标消息
|
||||||
|
action_data: 动作数据
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
from src.chat.message_manager.message_manager import message_manager
|
||||||
|
|
||||||
|
# 获取目标消息ID
|
||||||
|
target_message_id = None
|
||||||
|
if target_message and isinstance(target_message, dict):
|
||||||
|
target_message_id = target_message.get("message_id")
|
||||||
|
elif action_data and isinstance(action_data, dict):
|
||||||
|
target_message_id = action_data.get("target_message_id")
|
||||||
|
|
||||||
|
if not target_message_id:
|
||||||
|
logger.debug(f"无法获取目标消息ID,动作: {action_name}")
|
||||||
|
return
|
||||||
|
|
||||||
|
# 通过message_manager更新消息的动作记录并刷新focus_energy
|
||||||
|
if chat_stream.stream_id in message_manager.stream_contexts:
|
||||||
|
message_manager.add_action(
|
||||||
|
stream_id=chat_stream.stream_id,
|
||||||
|
message_id=target_message_id,
|
||||||
|
action=action_name
|
||||||
|
)
|
||||||
|
logger.debug(f"已记录动作 {action_name} 到消息 {target_message_id} 并更新focus_energy")
|
||||||
|
else:
|
||||||
|
logger.debug(f"未找到stream_context: {chat_stream.stream_id}")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"记录动作到消息失败: {e}")
|
||||||
|
# 不抛出异常,避免影响主要功能
|
||||||
|
|
||||||
|
async def _reset_interruption_count_after_action(self, stream_id: str):
|
||||||
|
"""在动作执行成功后重置打断计数"""
|
||||||
|
from src.chat.message_manager.message_manager import message_manager
|
||||||
|
try:
|
||||||
|
if stream_id in message_manager.stream_contexts:
|
||||||
|
context = message_manager.stream_contexts[stream_id]
|
||||||
|
if context.interruption_count > 0:
|
||||||
|
old_count = context.interruption_count
|
||||||
|
old_afc_adjustment = context.get_afc_threshold_adjustment()
|
||||||
|
context.reset_interruption_count()
|
||||||
|
logger.debug(f"动作执行成功,重置聊天流 {stream_id} 的打断计数: {old_count} -> 0, afc调整: {old_afc_adjustment} -> 0")
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"重置打断计数时出错: {e}")
|
||||||
|
|
||||||
|
async def _handle_action(
|
||||||
|
self, chat_stream, action, reasoning, action_data, cycle_timers, thinking_id, action_message
|
||||||
|
) -> tuple[bool, str, str]:
|
||||||
|
"""
|
||||||
|
处理具体的动作执行
|
||||||
|
|
||||||
|
Args:
|
||||||
|
chat_stream: ChatStream实例
|
||||||
|
action: 动作名称
|
||||||
|
reasoning: 执行理由
|
||||||
|
action_data: 动作数据
|
||||||
|
cycle_timers: 循环计时器
|
||||||
|
thinking_id: 思考ID
|
||||||
|
action_message: 动作消息
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
tuple: (执行是否成功, 回复文本, 命令文本)
|
||||||
|
|
||||||
|
功能说明:
|
||||||
|
- 创建对应的动作处理器
|
||||||
|
- 执行动作并捕获异常
|
||||||
|
- 返回执行结果供上级方法整合
|
||||||
|
"""
|
||||||
|
if not chat_stream:
|
||||||
|
return False, "", ""
|
||||||
|
try:
|
||||||
|
# 创建动作处理器
|
||||||
|
action_handler = self.create_action(
|
||||||
|
action_name=action,
|
||||||
|
action_data=action_data,
|
||||||
|
reasoning=reasoning,
|
||||||
|
cycle_timers=cycle_timers,
|
||||||
|
thinking_id=thinking_id,
|
||||||
|
chat_stream=chat_stream,
|
||||||
|
log_prefix=self.log_prefix,
|
||||||
|
action_message=action_message,
|
||||||
|
)
|
||||||
|
if not action_handler:
|
||||||
|
# 动作处理器创建失败,尝试回退机制
|
||||||
|
logger.warning(f"{self.log_prefix} 创建动作处理器失败: {action},尝试回退方案")
|
||||||
|
|
||||||
|
# 获取当前可用的动作
|
||||||
|
available_actions = self.get_using_actions()
|
||||||
|
fallback_action = None
|
||||||
|
|
||||||
|
# 回退优先级:reply > 第一个可用动作
|
||||||
|
if "reply" in available_actions:
|
||||||
|
fallback_action = "reply"
|
||||||
|
elif available_actions:
|
||||||
|
fallback_action = list(available_actions.keys())[0]
|
||||||
|
|
||||||
|
if fallback_action and fallback_action != action:
|
||||||
|
logger.info(f"{self.log_prefix} 使用回退动作: {fallback_action}")
|
||||||
|
action_handler = self.create_action(
|
||||||
|
action_name=fallback_action,
|
||||||
|
action_data=action_data,
|
||||||
|
reasoning=f"原动作'{action}'不可用,自动回退。{reasoning}",
|
||||||
|
cycle_timers=cycle_timers,
|
||||||
|
thinking_id=thinking_id,
|
||||||
|
chat_stream=chat_stream,
|
||||||
|
log_prefix=self.log_prefix,
|
||||||
|
action_message=action_message,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not action_handler:
|
||||||
|
logger.error(f"{self.log_prefix} 回退方案也失败,无法创建任何动作处理器")
|
||||||
|
return False, "", ""
|
||||||
|
|
||||||
|
# 执行动作
|
||||||
|
success, reply_text = await action_handler.handle_action()
|
||||||
|
return success, reply_text, ""
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"{self.log_prefix} 处理{action}时出错: {e}")
|
||||||
|
traceback.print_exc()
|
||||||
|
return False, "", ""
|
||||||
|
|
||||||
|
async def _send_and_store_reply(
|
||||||
|
self,
|
||||||
|
chat_stream: ChatStream,
|
||||||
|
response_set,
|
||||||
|
loop_start_time,
|
||||||
|
action_message,
|
||||||
|
cycle_timers: Dict[str, float],
|
||||||
|
thinking_id,
|
||||||
|
actions,
|
||||||
|
) -> Tuple[Dict[str, Any], str, Dict[str, float]]:
|
||||||
|
"""
|
||||||
|
发送并存储回复信息
|
||||||
|
|
||||||
|
Args:
|
||||||
|
chat_stream: ChatStream实例
|
||||||
|
response_set: 回复内容集合
|
||||||
|
loop_start_time: 循环开始时间
|
||||||
|
action_message: 动作消息
|
||||||
|
cycle_timers: 循环计时器
|
||||||
|
thinking_id: 思考ID
|
||||||
|
actions: 动作列表
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple[Dict[str, Any], str, Dict[str, float]]: 循环信息, 回复文本, 循环计时器
|
||||||
|
"""
|
||||||
|
# 发送回复
|
||||||
|
with Timer("回复发送", cycle_timers):
|
||||||
|
reply_text = await self.send_response(chat_stream, response_set, loop_start_time, action_message)
|
||||||
|
|
||||||
|
# 存储reply action信息
|
||||||
|
person_info_manager = get_person_info_manager()
|
||||||
|
|
||||||
|
# 获取 platform,如果不存在则从 chat_stream 获取,如果还是 None 则使用默认值
|
||||||
|
platform = action_message.get("chat_info_platform")
|
||||||
|
if platform is None:
|
||||||
|
platform = getattr(chat_stream, "platform", "unknown")
|
||||||
|
|
||||||
|
# 获取用户信息并生成回复提示
|
||||||
|
person_id = person_info_manager.get_person_id(
|
||||||
|
platform,
|
||||||
|
action_message.get("user_id", ""),
|
||||||
|
)
|
||||||
|
person_name = await person_info_manager.get_value(person_id, "person_name")
|
||||||
|
action_prompt_display = f"你对{person_name}进行了回复:{reply_text}"
|
||||||
|
|
||||||
|
# 存储动作信息到数据库
|
||||||
|
await database_api.store_action_info(
|
||||||
|
chat_stream=chat_stream,
|
||||||
|
action_build_into_prompt=False,
|
||||||
|
action_prompt_display=action_prompt_display,
|
||||||
|
action_done=True,
|
||||||
|
thinking_id=thinking_id,
|
||||||
|
action_data={"reply_text": reply_text},
|
||||||
|
action_name="reply",
|
||||||
|
)
|
||||||
|
|
||||||
|
# 构建循环信息
|
||||||
|
loop_info: Dict[str, Any] = {
|
||||||
|
"loop_plan_info": {
|
||||||
|
"action_result": actions,
|
||||||
|
},
|
||||||
|
"loop_action_info": {
|
||||||
|
"action_taken": True,
|
||||||
|
"reply_text": reply_text,
|
||||||
|
"command": "",
|
||||||
|
"taken_time": time.time(),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
return loop_info, reply_text, cycle_timers
|
||||||
|
|
||||||
|
async def send_response(self, chat_stream, reply_set, thinking_start_time, message_data) -> str:
|
||||||
|
"""
|
||||||
|
发送回复内容的具体实现
|
||||||
|
|
||||||
|
Args:
|
||||||
|
chat_stream: ChatStream实例
|
||||||
|
reply_set: 回复内容集合,包含多个回复段
|
||||||
|
reply_to: 回复目标
|
||||||
|
thinking_start_time: 思考开始时间
|
||||||
|
message_data: 消息数据
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: 完整的回复文本
|
||||||
|
|
||||||
|
功能说明:
|
||||||
|
- 检查是否有新消息需要回复
|
||||||
|
- 处理主动思考的"沉默"决定
|
||||||
|
- 根据消息数量决定是否添加回复引用
|
||||||
|
- 逐段发送回复内容,支持打字效果
|
||||||
|
- 正确处理元组格式的回复段
|
||||||
|
"""
|
||||||
|
current_time = time.time()
|
||||||
|
# 计算新消息数量
|
||||||
|
new_message_count = message_api.count_new_messages(
|
||||||
|
chat_id=chat_stream.stream_id, start_time=thinking_start_time, end_time=current_time
|
||||||
|
)
|
||||||
|
|
||||||
|
# 根据新消息数量决定是否需要引用回复
|
||||||
|
reply_text = ""
|
||||||
|
is_proactive_thinking = (message_data.get("message_type") == "proactive_thinking") if message_data else True
|
||||||
|
|
||||||
|
logger.debug(f"[send_response] message_data: {message_data}")
|
||||||
|
|
||||||
|
first_replied = False
|
||||||
|
for reply_seg in reply_set:
|
||||||
|
# 调试日志:验证reply_seg的格式
|
||||||
|
logger.debug(f"Processing reply_seg type: {type(reply_seg)}, content: {reply_seg}")
|
||||||
|
|
||||||
|
# 修正:正确处理元组格式 (格式为: (type, content))
|
||||||
|
if isinstance(reply_seg, tuple) and len(reply_seg) >= 2:
|
||||||
|
_, data = reply_seg
|
||||||
|
else:
|
||||||
|
# 向下兼容:如果已经是字符串,则直接使用
|
||||||
|
data = str(reply_seg)
|
||||||
|
|
||||||
|
if isinstance(data, list):
|
||||||
|
data = "".join(map(str, data))
|
||||||
|
reply_text += data
|
||||||
|
|
||||||
|
# 如果是主动思考且内容为"沉默",则不发送
|
||||||
|
if is_proactive_thinking and data.strip() == "沉默":
|
||||||
|
logger.info(f"{self.log_prefix} 主动思考决定保持沉默,不发送消息")
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 发送第一段回复
|
||||||
|
if not first_replied:
|
||||||
|
set_reply_flag = bool(message_data)
|
||||||
|
logger.debug(f"📤 [ActionManager] 准备发送第一段回复。message_data: {message_data}, set_reply: {set_reply_flag}")
|
||||||
|
await send_api.text_to_stream(
|
||||||
|
text=data,
|
||||||
|
stream_id=chat_stream.stream_id,
|
||||||
|
reply_to_message=message_data,
|
||||||
|
set_reply=set_reply_flag,
|
||||||
|
typing=False,
|
||||||
|
)
|
||||||
|
first_replied = True
|
||||||
|
else:
|
||||||
|
# 发送后续回复
|
||||||
|
sent_message = await send_api.text_to_stream(
|
||||||
|
text=data,
|
||||||
|
stream_id=chat_stream.stream_id,
|
||||||
|
reply_to_message=None,
|
||||||
|
set_reply=False,
|
||||||
|
typing=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
return reply_text
|
||||||
@@ -7,8 +7,9 @@ from typing import List, Any, Dict, TYPE_CHECKING, Tuple
|
|||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.config.config import global_config, model_config
|
from src.config.config import global_config, model_config
|
||||||
from src.llm_models.utils_model import LLMRequest
|
from src.llm_models.utils_model import LLMRequest
|
||||||
from src.chat.message_receive.chat_stream import get_chat_manager, ChatMessageContext
|
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||||
from src.chat.planner_actions.action_manager import ActionManager
|
from src.common.data_models.message_manager_data_model import StreamContext
|
||||||
|
from src.chat.planner_actions.action_manager import ChatterActionManager
|
||||||
from src.chat.utils.chat_message_builder import get_raw_msg_before_timestamp_with_chat, build_readable_messages
|
from src.chat.utils.chat_message_builder import get_raw_msg_before_timestamp_with_chat, build_readable_messages
|
||||||
from src.plugin_system.base.component_types import ActionInfo, ActionActivationType
|
from src.plugin_system.base.component_types import ActionInfo, ActionActivationType
|
||||||
from src.plugin_system.core.global_announcement_manager import global_announcement_manager
|
from src.plugin_system.core.global_announcement_manager import global_announcement_manager
|
||||||
@@ -27,7 +28,7 @@ class ActionModifier:
|
|||||||
支持并行判定和智能缓存优化。
|
支持并行判定和智能缓存优化。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, action_manager: ActionManager, chat_id: str):
|
def __init__(self, action_manager: ChatterActionManager, chat_id: str):
|
||||||
"""初始化动作处理器"""
|
"""初始化动作处理器"""
|
||||||
self.chat_id = chat_id
|
self.chat_id = chat_id
|
||||||
self.chat_stream: ChatStream = get_chat_manager().get_stream(self.chat_id) # type: ignore
|
self.chat_stream: ChatStream = get_chat_manager().get_stream(self.chat_id) # type: ignore
|
||||||
@@ -124,8 +125,9 @@ class ActionModifier:
|
|||||||
logger.debug(f"{self.log_prefix}阶段一移除动作: {disabled_action_name},原因: 用户自行禁用")
|
logger.debug(f"{self.log_prefix}阶段一移除动作: {disabled_action_name},原因: 用户自行禁用")
|
||||||
|
|
||||||
# === 第二阶段:检查动作的关联类型 ===
|
# === 第二阶段:检查动作的关联类型 ===
|
||||||
chat_context = self.chat_stream.context
|
chat_context = self.chat_stream.stream_context
|
||||||
type_mismatched_actions = self._check_action_associated_types(all_actions, chat_context)
|
current_actions_s2 = self.action_manager.get_using_actions()
|
||||||
|
type_mismatched_actions = self._check_action_associated_types(current_actions_s2, chat_context)
|
||||||
|
|
||||||
if type_mismatched_actions:
|
if type_mismatched_actions:
|
||||||
removals_s2.extend(type_mismatched_actions)
|
removals_s2.extend(type_mismatched_actions)
|
||||||
@@ -140,11 +142,12 @@ class ActionModifier:
|
|||||||
logger.debug(f"{self.log_prefix}开始激活类型判定阶段")
|
logger.debug(f"{self.log_prefix}开始激活类型判定阶段")
|
||||||
|
|
||||||
# 获取当前使用的动作集(经过第一阶段处理)
|
# 获取当前使用的动作集(经过第一阶段处理)
|
||||||
current_using_actions = self.action_manager.get_using_actions()
|
# 在第三阶段开始前,再次获取最新的动作列表
|
||||||
|
current_actions_s3 = self.action_manager.get_using_actions()
|
||||||
|
|
||||||
# 获取因激活类型判定而需要移除的动作
|
# 获取因激活类型判定而需要移除的动作
|
||||||
removals_s3 = await self._get_deactivated_actions_by_type(
|
removals_s3 = await self._get_deactivated_actions_by_type(
|
||||||
current_using_actions,
|
current_actions_s3,
|
||||||
chat_content,
|
chat_content,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -164,7 +167,7 @@ class ActionModifier:
|
|||||||
|
|
||||||
logger.info(f"{self.log_prefix} 当前可用动作: {available_actions_text}||移除: {removals_summary}")
|
logger.info(f"{self.log_prefix} 当前可用动作: {available_actions_text}||移除: {removals_summary}")
|
||||||
|
|
||||||
def _check_action_associated_types(self, all_actions: Dict[str, ActionInfo], chat_context: ChatMessageContext):
|
def _check_action_associated_types(self, all_actions: Dict[str, ActionInfo], chat_context: StreamContext):
|
||||||
type_mismatched_actions: List[Tuple[str, str]] = []
|
type_mismatched_actions: List[Tuple[str, str]] = []
|
||||||
for action_name, action_info in all_actions.items():
|
for action_name, action_info in all_actions.items():
|
||||||
if action_info.associated_types and not chat_context.check_types(action_info.associated_types):
|
if action_info.associated_types and not chat_context.check_types(action_info.associated_types):
|
||||||
|
|||||||
@@ -1,58 +0,0 @@
|
|||||||
"""
|
|
||||||
PlanExecutor: 接收 Plan 对象并执行其中的所有动作。
|
|
||||||
"""
|
|
||||||
from src.chat.planner_actions.action_manager import ActionManager
|
|
||||||
from src.common.data_models.info_data_model import Plan
|
|
||||||
from src.common.logger import get_logger
|
|
||||||
|
|
||||||
logger = get_logger("plan_executor")
|
|
||||||
|
|
||||||
|
|
||||||
class PlanExecutor:
|
|
||||||
"""
|
|
||||||
负责接收一个 Plan 对象,并执行其中最终确定的所有动作。
|
|
||||||
|
|
||||||
这个类是规划流程的最后一步,将规划结果转化为实际的动作执行。
|
|
||||||
|
|
||||||
Attributes:
|
|
||||||
action_manager (ActionManager): 用于实际执行各种动作的管理器实例。
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, action_manager: ActionManager):
|
|
||||||
"""
|
|
||||||
初始化 PlanExecutor。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
action_manager (ActionManager): 一个 ActionManager 实例,用于执行动作。
|
|
||||||
"""
|
|
||||||
self.action_manager = action_manager
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
async def execute(plan: Plan):
|
|
||||||
"""
|
|
||||||
遍历并执行 Plan 对象中 `decided_actions` 列表里的所有动作。
|
|
||||||
|
|
||||||
如果动作类型为 "no_action",则会记录原因并跳过。
|
|
||||||
否则,它将调用 ActionManager 来执行相应的动作。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
plan (Plan): 包含待执行动作列表的 Plan 对象。
|
|
||||||
"""
|
|
||||||
if not plan.decided_actions:
|
|
||||||
logger.info("没有需要执行的动作。")
|
|
||||||
return
|
|
||||||
|
|
||||||
for action_info in plan.decided_actions:
|
|
||||||
if action_info.action_type == "no_action":
|
|
||||||
logger.info(f"规划器决策不执行动作,原因: {action_info.reasoning}")
|
|
||||||
continue
|
|
||||||
|
|
||||||
# TODO: 对接 ActionManager 的执行方法
|
|
||||||
# 这是一个示例调用,需要根据 ActionManager 的最终实现进行调整
|
|
||||||
logger.info(f"执行动作: {action_info.action_type}, 原因: {action_info.reasoning}")
|
|
||||||
# await self.action_manager.execute_action(
|
|
||||||
# action_name=action_info.action_type,
|
|
||||||
# action_data=action_info.action_data,
|
|
||||||
# reasoning=action_info.reasoning,
|
|
||||||
# action_message=action_info.action_message,
|
|
||||||
# )
|
|
||||||
@@ -1,366 +0,0 @@
|
|||||||
"""
|
|
||||||
PlanFilter: 接收 Plan 对象,根据不同模式的逻辑进行筛选,决定最终要执行的动作。
|
|
||||||
"""
|
|
||||||
import orjson
|
|
||||||
import time
|
|
||||||
import traceback
|
|
||||||
from datetime import datetime
|
|
||||||
from typing import Any, Dict, List, Optional
|
|
||||||
|
|
||||||
from json_repair import repair_json
|
|
||||||
|
|
||||||
from src.chat.memory_system.Hippocampus import hippocampus_manager
|
|
||||||
from src.chat.utils.chat_message_builder import (
|
|
||||||
build_readable_actions,
|
|
||||||
build_readable_messages_with_id,
|
|
||||||
get_actions_by_timestamp_with_chat,
|
|
||||||
)
|
|
||||||
from src.chat.utils.prompt import global_prompt_manager
|
|
||||||
from src.common.data_models.info_data_model import ActionPlannerInfo, Plan
|
|
||||||
from src.common.logger import get_logger
|
|
||||||
from src.config.config import global_config, model_config
|
|
||||||
from src.llm_models.utils_model import LLMRequest
|
|
||||||
from src.mood.mood_manager import mood_manager
|
|
||||||
from src.plugin_system.base.component_types import ActionInfo, ChatMode
|
|
||||||
from src.schedule.schedule_manager import schedule_manager
|
|
||||||
|
|
||||||
logger = get_logger("plan_filter")
|
|
||||||
|
|
||||||
|
|
||||||
class PlanFilter:
|
|
||||||
"""
|
|
||||||
根据 Plan 中的模式和信息,筛选并决定最终的动作。
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
self.planner_llm = LLMRequest(
|
|
||||||
model_set=model_config.model_task_config.planner, request_type="planner"
|
|
||||||
)
|
|
||||||
self.last_obs_time_mark = 0.0
|
|
||||||
|
|
||||||
async def filter(self, plan: Plan) -> Plan:
|
|
||||||
"""
|
|
||||||
执行筛选逻辑,并填充 Plan 对象的 decided_actions 字段。
|
|
||||||
"""
|
|
||||||
logger.debug(f"墨墨在这里加了日志 -> filter 入口 plan: {plan}")
|
|
||||||
try:
|
|
||||||
prompt, used_message_id_list = await self._build_prompt(plan)
|
|
||||||
plan.llm_prompt = prompt
|
|
||||||
logger.info(f"规划器原始提示词: {prompt}")
|
|
||||||
|
|
||||||
llm_content, _ = await self.planner_llm.generate_response_async(prompt=prompt)
|
|
||||||
|
|
||||||
if llm_content:
|
|
||||||
logger.info(f"规划器原始返回: {llm_content}")
|
|
||||||
parsed_json = orjson.loads(repair_json(llm_content))
|
|
||||||
logger.debug(f"墨墨在这里加了日志 -> 解析后的 JSON: {parsed_json}")
|
|
||||||
|
|
||||||
if isinstance(parsed_json, dict):
|
|
||||||
parsed_json = [parsed_json]
|
|
||||||
|
|
||||||
if isinstance(parsed_json, list):
|
|
||||||
final_actions = []
|
|
||||||
reply_action_added = False
|
|
||||||
# 定义回复类动作的集合,方便扩展
|
|
||||||
reply_action_types = {"reply", "proactive_reply"}
|
|
||||||
|
|
||||||
for item in parsed_json:
|
|
||||||
if not isinstance(item, dict):
|
|
||||||
continue
|
|
||||||
|
|
||||||
# 预解析 action_type 来进行判断
|
|
||||||
action_type = item.get("action", "no_action")
|
|
||||||
|
|
||||||
if action_type in reply_action_types:
|
|
||||||
if not reply_action_added:
|
|
||||||
final_actions.extend(
|
|
||||||
await self._parse_single_action(
|
|
||||||
item, used_message_id_list, plan
|
|
||||||
)
|
|
||||||
)
|
|
||||||
reply_action_added = True
|
|
||||||
else:
|
|
||||||
# 非回复类动作直接添加
|
|
||||||
final_actions.extend(
|
|
||||||
await self._parse_single_action(
|
|
||||||
item, used_message_id_list, plan
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
plan.decided_actions = self._filter_no_actions(final_actions)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"筛选 Plan 时出错: {e}\n{traceback.format_exc()}")
|
|
||||||
plan.decided_actions = [
|
|
||||||
ActionPlannerInfo(action_type="no_action", reasoning=f"筛选时出错: {e}")
|
|
||||||
]
|
|
||||||
|
|
||||||
logger.debug(f"墨墨在这里加了日志 -> filter 出口 decided_actions: {plan.decided_actions}")
|
|
||||||
return plan
|
|
||||||
|
|
||||||
async def _build_prompt(self, plan: Plan) -> tuple[str, list]:
|
|
||||||
"""
|
|
||||||
根据 Plan 对象构建提示词。
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
time_block = f"当前时间:{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}"
|
|
||||||
bot_name = global_config.bot.nickname
|
|
||||||
bot_nickname = (
|
|
||||||
f",也有人叫你{','.join(global_config.bot.alias_names)}" if global_config.bot.alias_names else ""
|
|
||||||
)
|
|
||||||
bot_core_personality = global_config.personality.personality_core
|
|
||||||
identity_block = f"你的名字是{bot_name}{bot_nickname},你{bot_core_personality}:"
|
|
||||||
|
|
||||||
schedule_block = ""
|
|
||||||
if global_config.planning_system.schedule_enable:
|
|
||||||
if current_activity := schedule_manager.get_current_activity():
|
|
||||||
schedule_block = f"你当前正在:{current_activity},但注意它与群聊的聊天无关。"
|
|
||||||
|
|
||||||
mood_block = ""
|
|
||||||
if global_config.mood.enable_mood:
|
|
||||||
chat_mood = mood_manager.get_mood_by_chat_id(plan.chat_id)
|
|
||||||
mood_block = f"你现在的心情是:{chat_mood.mood_state}"
|
|
||||||
|
|
||||||
if plan.mode == ChatMode.PROACTIVE:
|
|
||||||
long_term_memory_block = await self._get_long_term_memory_context()
|
|
||||||
|
|
||||||
chat_content_block, message_id_list = await build_readable_messages_with_id(
|
|
||||||
messages=[msg.flatten() for msg in plan.chat_history],
|
|
||||||
timestamp_mode="normal",
|
|
||||||
truncate=False,
|
|
||||||
show_actions=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
prompt_template = await global_prompt_manager.get_prompt_async("proactive_planner_prompt")
|
|
||||||
actions_before_now = await get_actions_by_timestamp_with_chat(
|
|
||||||
chat_id=plan.chat_id,
|
|
||||||
timestamp_start=time.time() - 3600,
|
|
||||||
timestamp_end=time.time(),
|
|
||||||
limit=5,
|
|
||||||
)
|
|
||||||
actions_before_now_block = build_readable_actions(actions=actions_before_now)
|
|
||||||
actions_before_now_block = f"你刚刚选择并执行过的action是:\n{actions_before_now_block}"
|
|
||||||
|
|
||||||
prompt = prompt_template.format(
|
|
||||||
time_block=time_block,
|
|
||||||
identity_block=identity_block,
|
|
||||||
schedule_block=schedule_block,
|
|
||||||
mood_block=mood_block,
|
|
||||||
long_term_memory_block=long_term_memory_block,
|
|
||||||
chat_content_block=chat_content_block or "最近没有聊天内容。",
|
|
||||||
actions_before_now_block=actions_before_now_block,
|
|
||||||
)
|
|
||||||
return prompt, message_id_list
|
|
||||||
|
|
||||||
chat_content_block, message_id_list = await build_readable_messages_with_id(
|
|
||||||
messages=[msg.flatten() for msg in plan.chat_history],
|
|
||||||
timestamp_mode="normal",
|
|
||||||
read_mark=self.last_obs_time_mark,
|
|
||||||
truncate=True,
|
|
||||||
show_actions=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
actions_before_now = await get_actions_by_timestamp_with_chat(
|
|
||||||
chat_id=plan.chat_id,
|
|
||||||
timestamp_start=time.time() - 3600,
|
|
||||||
timestamp_end=time.time(),
|
|
||||||
limit=5,
|
|
||||||
)
|
|
||||||
|
|
||||||
actions_before_now_block = build_readable_actions(actions=actions_before_now)
|
|
||||||
actions_before_now_block = f"你刚刚选择并执行过的action是:\n{actions_before_now_block}"
|
|
||||||
|
|
||||||
self.last_obs_time_mark = time.time()
|
|
||||||
|
|
||||||
mentioned_bonus = ""
|
|
||||||
if global_config.chat.mentioned_bot_inevitable_reply:
|
|
||||||
mentioned_bonus = "\n- 有人提到你"
|
|
||||||
if global_config.chat.at_bot_inevitable_reply:
|
|
||||||
mentioned_bonus = "\n- 有人提到你,或者at你"
|
|
||||||
|
|
||||||
if plan.mode == ChatMode.FOCUS:
|
|
||||||
no_action_block = """
|
|
||||||
动作:no_action
|
|
||||||
动作描述:不选择任何动作
|
|
||||||
{{
|
|
||||||
"action": "no_action",
|
|
||||||
"reason":"不动作的原因"
|
|
||||||
}}
|
|
||||||
|
|
||||||
动作:no_reply
|
|
||||||
动作描述:不进行回复,等待合适的回复时机
|
|
||||||
- 当你刚刚发送了消息,没有人回复时,选择no_reply
|
|
||||||
- 当你一次发送了太多消息,为了避免打扰聊天节奏,选择no_reply
|
|
||||||
{{
|
|
||||||
"action": "no_reply",
|
|
||||||
"reason":"不回复的原因"
|
|
||||||
}}
|
|
||||||
"""
|
|
||||||
else: # NORMAL Mode
|
|
||||||
no_action_block = """重要说明:
|
|
||||||
- 'reply' 表示只进行普通聊天回复,不执行任何额外动作
|
|
||||||
- 其他action表示在普通回复的基础上,执行相应的额外动作
|
|
||||||
{{
|
|
||||||
"action": "reply",
|
|
||||||
"target_message_id":"触发action的消息id",
|
|
||||||
"reason":"回复的原因"
|
|
||||||
}}"""
|
|
||||||
|
|
||||||
is_group_chat = plan.target_info.platform == "group" if plan.target_info else True
|
|
||||||
chat_context_description = "你现在正在一个群聊中"
|
|
||||||
if not is_group_chat and plan.target_info:
|
|
||||||
chat_target_name = plan.target_info.person_name or plan.target_info.user_nickname or "对方"
|
|
||||||
chat_context_description = f"你正在和 {chat_target_name} 私聊"
|
|
||||||
|
|
||||||
action_options_block = await self._build_action_options(plan.available_actions)
|
|
||||||
|
|
||||||
moderation_prompt_block = "请不要输出违法违规内容,不要输出色情,暴力,政治相关内容,如有敏感内容,请规避。"
|
|
||||||
|
|
||||||
custom_prompt_block = ""
|
|
||||||
if global_config.custom_prompt.planner_custom_prompt_content:
|
|
||||||
custom_prompt_block = global_config.custom_prompt.planner_custom_prompt_content
|
|
||||||
|
|
||||||
users_in_chat_str = "" # TODO: Re-implement user list fetching if needed
|
|
||||||
|
|
||||||
planner_prompt_template = await global_prompt_manager.get_prompt_async("planner_prompt")
|
|
||||||
prompt = planner_prompt_template.format(
|
|
||||||
schedule_block=schedule_block,
|
|
||||||
mood_block=mood_block,
|
|
||||||
time_block=time_block,
|
|
||||||
chat_context_description=chat_context_description,
|
|
||||||
chat_content_block=chat_content_block,
|
|
||||||
actions_before_now_block=actions_before_now_block,
|
|
||||||
mentioned_bonus=mentioned_bonus,
|
|
||||||
no_action_block=no_action_block,
|
|
||||||
action_options_text=action_options_block,
|
|
||||||
moderation_prompt=moderation_prompt_block,
|
|
||||||
identity_block=identity_block,
|
|
||||||
custom_prompt_block=custom_prompt_block,
|
|
||||||
bot_name=bot_name,
|
|
||||||
users_in_chat=users_in_chat_str
|
|
||||||
)
|
|
||||||
return prompt, message_id_list
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"构建 Planner 提示词时出错: {e}")
|
|
||||||
logger.error(traceback.format_exc())
|
|
||||||
return "构建 Planner Prompt 时出错", []
|
|
||||||
|
|
||||||
async def _parse_single_action(
|
|
||||||
self, action_json: dict, message_id_list: list, plan: Plan
|
|
||||||
) -> List[ActionPlannerInfo]:
|
|
||||||
parsed_actions = []
|
|
||||||
try:
|
|
||||||
action = action_json.get("action", "no_action")
|
|
||||||
reasoning = action_json.get("reason", "未提供原因")
|
|
||||||
action_data = {k: v for k, v in action_json.items() if k not in ["action", "reason"]}
|
|
||||||
|
|
||||||
target_message_obj = None
|
|
||||||
if action not in ["no_action", "no_reply", "do_nothing", "proactive_reply"]:
|
|
||||||
if target_message_id := action_json.get("target_message_id"):
|
|
||||||
target_message_dict = self._find_message_by_id(target_message_id, message_id_list)
|
|
||||||
else:
|
|
||||||
# 如果LLM没有指定target_message_id,我们就默认选择最新的一条消息
|
|
||||||
target_message_dict = self._get_latest_message(message_id_list)
|
|
||||||
|
|
||||||
if target_message_dict:
|
|
||||||
# 直接使用字典作为action_message,避免DatabaseMessages对象创建失败
|
|
||||||
target_message_obj = target_message_dict
|
|
||||||
else:
|
|
||||||
# 如果找不到目标消息,对于reply动作来说这是必需的,应该记录警告
|
|
||||||
if action == "reply":
|
|
||||||
logger.warning(f"reply动作找不到目标消息,target_message_id: {action_json.get('target_message_id')}")
|
|
||||||
# 将reply动作改为no_action,避免后续执行时出错
|
|
||||||
action = "no_action"
|
|
||||||
reasoning = f"找不到目标消息进行回复。原始理由: {reasoning}"
|
|
||||||
|
|
||||||
available_action_names = list(plan.available_actions.keys())
|
|
||||||
if action not in ["no_action", "no_reply", "reply", "do_nothing", "proactive_reply"] and action not in available_action_names:
|
|
||||||
reasoning = f"LLM 返回了当前不可用的动作 '{action}'。原始理由: {reasoning}"
|
|
||||||
action = "no_action"
|
|
||||||
|
|
||||||
parsed_actions.append(
|
|
||||||
ActionPlannerInfo(
|
|
||||||
action_type=action,
|
|
||||||
reasoning=reasoning,
|
|
||||||
action_data=action_data,
|
|
||||||
action_message=target_message_obj,
|
|
||||||
available_actions=plan.available_actions,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"解析单个action时出错: {e}")
|
|
||||||
parsed_actions.append(
|
|
||||||
ActionPlannerInfo(
|
|
||||||
action_type="no_action",
|
|
||||||
reasoning=f"解析action时出错: {e}",
|
|
||||||
)
|
|
||||||
)
|
|
||||||
return parsed_actions
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _filter_no_actions(
|
|
||||||
action_list: List[ActionPlannerInfo]
|
|
||||||
) -> List[ActionPlannerInfo]:
|
|
||||||
non_no_actions = [a for a in action_list if a.action_type not in ["no_action", "no_reply"]]
|
|
||||||
if non_no_actions:
|
|
||||||
return non_no_actions
|
|
||||||
return action_list[:1] if action_list else []
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
async def _get_long_term_memory_context() -> str:
|
|
||||||
try:
|
|
||||||
now = datetime.now()
|
|
||||||
keywords = ["今天", "日程", "计划"]
|
|
||||||
if 5 <= now.hour < 12:
|
|
||||||
keywords.append("早上")
|
|
||||||
elif 12 <= now.hour < 18:
|
|
||||||
keywords.append("中午")
|
|
||||||
else:
|
|
||||||
keywords.append("晚上")
|
|
||||||
|
|
||||||
retrieved_memories = await hippocampus_manager.get_memory_from_topic(
|
|
||||||
valid_keywords=keywords, max_memory_num=5, max_memory_length=1
|
|
||||||
)
|
|
||||||
|
|
||||||
if not retrieved_memories:
|
|
||||||
return "最近没有什么特别的记忆。"
|
|
||||||
|
|
||||||
memory_statements = [f"关于'{topic}', 你记得'{memory_item}'。" for topic, memory_item in retrieved_memories]
|
|
||||||
return " ".join(memory_statements)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"获取长期记忆时出错: {e}")
|
|
||||||
return "回忆时出现了一些问题。"
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
async def _build_action_options(current_available_actions: Dict[str, ActionInfo]) -> str:
|
|
||||||
action_options_block = ""
|
|
||||||
for action_name, action_info in current_available_actions.items():
|
|
||||||
param_text = ""
|
|
||||||
if action_info.action_parameters:
|
|
||||||
param_text = "\n" + "\n".join(
|
|
||||||
f' "{p_name}":"{p_desc}"' for p_name, p_desc in action_info.action_parameters.items()
|
|
||||||
)
|
|
||||||
require_text = "\n".join(f"- {req}" for req in action_info.action_require)
|
|
||||||
using_action_prompt = await global_prompt_manager.get_prompt_async("action_prompt")
|
|
||||||
action_options_block += using_action_prompt.format(
|
|
||||||
action_name=action_name,
|
|
||||||
action_description=action_info.description,
|
|
||||||
action_parameters=param_text,
|
|
||||||
action_require=require_text,
|
|
||||||
)
|
|
||||||
return action_options_block
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _find_message_by_id(message_id: str, message_id_list: list) -> Optional[Dict[str, Any]]:
|
|
||||||
if message_id.isdigit():
|
|
||||||
message_id = f"m{message_id}"
|
|
||||||
for item in message_id_list:
|
|
||||||
if item.get("id") == message_id:
|
|
||||||
return item.get("message")
|
|
||||||
return None
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _get_latest_message(message_id_list: list) -> Optional[Dict[str, Any]]:
|
|
||||||
if not message_id_list:
|
|
||||||
return None
|
|
||||||
return message_id_list[-1].get("message")
|
|
||||||
@@ -1,110 +0,0 @@
|
|||||||
"""
|
|
||||||
PlanGenerator: 负责搜集和汇总所有决策所需的信息,生成一个未经筛选的“原始计划” (Plan)。
|
|
||||||
"""
|
|
||||||
import time
|
|
||||||
from typing import Dict
|
|
||||||
|
|
||||||
from src.chat.utils.chat_message_builder import get_raw_msg_before_timestamp_with_chat
|
|
||||||
from src.chat.utils.utils import get_chat_type_and_target_info
|
|
||||||
from src.common.data_models.database_data_model import DatabaseMessages
|
|
||||||
from src.common.data_models.info_data_model import Plan, TargetPersonInfo
|
|
||||||
from src.config.config import global_config
|
|
||||||
from src.plugin_system.base.component_types import ActionInfo, ChatMode, ComponentType
|
|
||||||
from src.plugin_system.core.component_registry import component_registry
|
|
||||||
|
|
||||||
|
|
||||||
class PlanGenerator:
|
|
||||||
"""
|
|
||||||
PlanGenerator 负责在规划流程的初始阶段收集所有必要信息。
|
|
||||||
|
|
||||||
它会汇总以下信息来构建一个“原始”的 Plan 对象,该对象后续会由 PlanFilter 进行筛选:
|
|
||||||
- 当前聊天信息 (ID, 目标用户)
|
|
||||||
- 当前可用的动作列表
|
|
||||||
- 最近的聊天历史记录
|
|
||||||
|
|
||||||
Attributes:
|
|
||||||
chat_id (str): 当前聊天的唯一标识符。
|
|
||||||
action_manager (ActionManager): 用于获取可用动作列表的管理器。
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, chat_id: str):
|
|
||||||
"""
|
|
||||||
初始化 PlanGenerator。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
chat_id (str): 当前聊天的 ID。
|
|
||||||
"""
|
|
||||||
from src.chat.planner_actions.action_manager import ActionManager
|
|
||||||
self.chat_id = chat_id
|
|
||||||
# 注意:ActionManager 可能需要根据实际情况初始化
|
|
||||||
self.action_manager = ActionManager()
|
|
||||||
|
|
||||||
async def generate(self, mode: ChatMode) -> Plan:
|
|
||||||
"""
|
|
||||||
收集所有信息,生成并返回一个初始的 Plan 对象。
|
|
||||||
|
|
||||||
这个 Plan 对象包含了决策所需的所有上下文信息。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
mode (ChatMode): 当前的聊天模式。
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Plan: 一个填充了初始上下文信息的 Plan 对象。
|
|
||||||
"""
|
|
||||||
_is_group_chat, chat_target_info_dict = await get_chat_type_and_target_info(self.chat_id)
|
|
||||||
|
|
||||||
target_info = None
|
|
||||||
if chat_target_info_dict:
|
|
||||||
target_info = TargetPersonInfo(**chat_target_info_dict)
|
|
||||||
|
|
||||||
available_actions = self._get_available_actions()
|
|
||||||
chat_history_raw = get_raw_msg_before_timestamp_with_chat(
|
|
||||||
chat_id=self.chat_id,
|
|
||||||
timestamp=time.time(),
|
|
||||||
limit=int(global_config.chat.max_context_size),
|
|
||||||
)
|
|
||||||
chat_history = [DatabaseMessages(**msg) for msg in await chat_history_raw]
|
|
||||||
|
|
||||||
|
|
||||||
plan = Plan(
|
|
||||||
chat_id=self.chat_id,
|
|
||||||
mode=mode,
|
|
||||||
available_actions=available_actions,
|
|
||||||
chat_history=chat_history,
|
|
||||||
target_info=target_info,
|
|
||||||
)
|
|
||||||
return plan
|
|
||||||
|
|
||||||
def _get_available_actions(self) -> Dict[str, "ActionInfo"]:
|
|
||||||
"""
|
|
||||||
从 ActionManager 和组件注册表中获取当前所有可用的动作。
|
|
||||||
|
|
||||||
它会合并已注册的动作和系统级动作(如 "no_reply"),
|
|
||||||
并以字典形式返回。
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Dict[str, "ActionInfo"]: 一个字典,键是动作名称,值是 ActionInfo 对象。
|
|
||||||
"""
|
|
||||||
current_available_actions_dict = self.action_manager.get_using_actions()
|
|
||||||
all_registered_actions: Dict[str, ActionInfo] = component_registry.get_components_by_type( # type: ignore
|
|
||||||
ComponentType.ACTION
|
|
||||||
)
|
|
||||||
|
|
||||||
current_available_actions = {}
|
|
||||||
for action_name in current_available_actions_dict:
|
|
||||||
if action_name in all_registered_actions:
|
|
||||||
current_available_actions[action_name] = all_registered_actions[action_name]
|
|
||||||
|
|
||||||
no_reply_info = ActionInfo(
|
|
||||||
name="no_reply",
|
|
||||||
component_type=ComponentType.ACTION,
|
|
||||||
description="系统级动作:选择不回复消息的决策",
|
|
||||||
action_parameters={},
|
|
||||||
activation_keywords=[],
|
|
||||||
plugin_name="SYSTEM",
|
|
||||||
enabled=True,
|
|
||||||
parallel_action=False,
|
|
||||||
)
|
|
||||||
current_available_actions["no_reply"] = no_reply_info
|
|
||||||
|
|
||||||
return current_available_actions
|
|
||||||
@@ -1,94 +0,0 @@
|
|||||||
"""
|
|
||||||
主规划器入口,负责协调 PlanGenerator, PlanFilter, 和 PlanExecutor。
|
|
||||||
"""
|
|
||||||
from dataclasses import asdict
|
|
||||||
from typing import Dict, List, Optional, Tuple
|
|
||||||
|
|
||||||
from src.chat.planner_actions.action_manager import ActionManager
|
|
||||||
from src.chat.planner_actions.plan_executor import PlanExecutor
|
|
||||||
from src.chat.planner_actions.plan_filter import PlanFilter
|
|
||||||
from src.chat.planner_actions.plan_generator import PlanGenerator
|
|
||||||
from src.common.logger import get_logger
|
|
||||||
from src.plugin_system.base.component_types import ChatMode
|
|
||||||
import src.chat.planner_actions.planner_prompts #noga # noqa: F401
|
|
||||||
# 导入提示词模块以确保其被初始化
|
|
||||||
|
|
||||||
logger = get_logger("planner")
|
|
||||||
|
|
||||||
|
|
||||||
class ActionPlanner:
|
|
||||||
"""
|
|
||||||
ActionPlanner 是规划系统的核心协调器。
|
|
||||||
|
|
||||||
它负责整合规划流程的三个主要阶段:
|
|
||||||
1. **生成 (Generate)**: 使用 PlanGenerator 创建一个初始的行动计划。
|
|
||||||
2. **筛选 (Filter)**: 使用 PlanFilter 对生成的计划进行审查和优化。
|
|
||||||
3. **执行 (Execute)**: 使用 PlanExecutor 执行最终确定的行动。
|
|
||||||
|
|
||||||
Attributes:
|
|
||||||
chat_id (str): 当前聊天的唯一标识符。
|
|
||||||
action_manager (ActionManager): 用于执行具体动作的管理器。
|
|
||||||
generator (PlanGenerator): 负责生成初始计划。
|
|
||||||
filter (PlanFilter): 负责筛选和优化计划。
|
|
||||||
executor (PlanExecutor): 负责执行最终计划。
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, chat_id: str, action_manager: ActionManager):
|
|
||||||
"""
|
|
||||||
初始化 ActionPlanner。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
chat_id (str): 当前聊天的 ID。
|
|
||||||
action_manager (ActionManager): 一个 ActionManager 实例。
|
|
||||||
"""
|
|
||||||
self.chat_id = chat_id
|
|
||||||
self.action_manager = action_manager
|
|
||||||
self.generator = PlanGenerator(chat_id)
|
|
||||||
self.filter = PlanFilter()
|
|
||||||
self.executor = PlanExecutor(action_manager)
|
|
||||||
|
|
||||||
async def plan(
|
|
||||||
self, mode: ChatMode = ChatMode.FOCUS
|
|
||||||
) -> Tuple[List[Dict], Optional[Dict]]:
|
|
||||||
"""
|
|
||||||
执行从生成到执行的完整规划流程。
|
|
||||||
|
|
||||||
这个方法按顺序协调生成、筛选和执行三个阶段。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
mode (ChatMode): 当前的聊天模式,默认为 FOCUS。
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Tuple[List[Dict], Optional[Dict]]: 一个元组,包含:
|
|
||||||
- final_actions_dict (List[Dict]): 最终确定的动作列表(字典格式)。
|
|
||||||
- final_target_message_dict (Optional[Dict]): 最终的目标消息(字典格式),如果没有则为 None。
|
|
||||||
这与旧版 planner 的返回值保持兼容。
|
|
||||||
"""
|
|
||||||
# 1. 生成初始 Plan
|
|
||||||
initial_plan = await self.generator.generate(mode)
|
|
||||||
|
|
||||||
# 2. 筛选 Plan
|
|
||||||
filtered_plan = await self.filter.filter(initial_plan)
|
|
||||||
|
|
||||||
# 3. 执行 Plan(临时引爆因为它暂时还跑不了)
|
|
||||||
#await self.executor.execute(filtered_plan)
|
|
||||||
|
|
||||||
# 4. 返回结果 (与旧版 planner 的返回值保持兼容)
|
|
||||||
final_actions = filtered_plan.decided_actions or []
|
|
||||||
final_target_message = next(
|
|
||||||
(act.action_message for act in final_actions if act.action_message), None
|
|
||||||
)
|
|
||||||
|
|
||||||
final_actions_dict = [asdict(act) for act in final_actions]
|
|
||||||
# action_message现在可能是字典而不是dataclass实例,需要特殊处理
|
|
||||||
if final_target_message:
|
|
||||||
if hasattr(final_target_message, '__dataclass_fields__'):
|
|
||||||
# 如果是dataclass实例,使用asdict转换
|
|
||||||
final_target_message_dict = asdict(final_target_message)
|
|
||||||
else:
|
|
||||||
# 如果已经是字典,直接使用
|
|
||||||
final_target_message_dict = final_target_message
|
|
||||||
else:
|
|
||||||
final_target_message_dict = None
|
|
||||||
|
|
||||||
return final_actions_dict, final_target_message_dict
|
|
||||||
@@ -1,202 +0,0 @@
|
|||||||
"""
|
|
||||||
本文件集中管理所有与规划器(Planner)相关的提示词(Prompt)模板。
|
|
||||||
|
|
||||||
通过将提示词与代码逻辑分离,可以更方便地对模型的行为进行迭代和优化,
|
|
||||||
而无需修改核心代码。
|
|
||||||
"""
|
|
||||||
from src.chat.utils.prompt import Prompt
|
|
||||||
|
|
||||||
|
|
||||||
def init_prompts():
|
|
||||||
"""
|
|
||||||
初始化并向 Prompt 注册系统注册所有规划器相关的提示词。
|
|
||||||
|
|
||||||
这个函数会在模块加载时自动调用,确保所有提示词在系统启动时都已准备就绪。
|
|
||||||
"""
|
|
||||||
# 核心规划器提示词,用于在接收到新消息时决定如何回应。
|
|
||||||
# 它构建了一个复杂的上下文,包括历史记录、可用动作、角色设定等,
|
|
||||||
# 并要求模型以 JSON 格式输出一个或多个动作组合。
|
|
||||||
Prompt(
|
|
||||||
"""
|
|
||||||
{mood_block}
|
|
||||||
{time_block}
|
|
||||||
{identity_block}
|
|
||||||
|
|
||||||
{users_in_chat}
|
|
||||||
{custom_prompt_block}
|
|
||||||
{chat_context_description},以下是具体的聊天内容。
|
|
||||||
{chat_content_block}
|
|
||||||
|
|
||||||
{moderation_prompt}
|
|
||||||
|
|
||||||
**任务: 构建一个完整的响应**
|
|
||||||
你的任务是根据当前的聊天内容,构建一个完整的、人性化的响应。一个完整的响应由两部分组成:
|
|
||||||
1. **主要动作**: 这是响应的核心,通常是 `reply`(文本回复)。
|
|
||||||
2. **辅助动作 (可选)**: 这是为了增强表达效果的附加动作,例如 `emoji`(发送表情包)或 `poke_user`(戳一戳)。
|
|
||||||
|
|
||||||
**决策流程:**
|
|
||||||
1. **最高优先级检查**: 首先,检查是否有由 **关键词** 或 **LLM判断** 激活的特定动作(除了通用的 `reply`, `emoji` 等)。这些动作代表了用户的明确意图。
|
|
||||||
2. **执行明确意图**: 如果存在这类特定动作,你 **必须** 优先选择它作为主要响应。这比常规的文本回复 (`reply`) 更重要。
|
|
||||||
3. **常规回复**: 如果没有被特定意图激活的动作,再决定是否要进行 `reply`。
|
|
||||||
4. **辅助动作**: 在确定了主要动作后(无论是特定动作还是 `reply`),再评估是否需要 `emoji` 或 `poke_user` 等辅助动作来增强表达效果。
|
|
||||||
5. **互斥原则**: 当你选择了一个由明确意图激活的特定动作(如 `set_reminder`)时,你 **绝不能** 再选择 `reply` 动作,因为特定动作的执行结果(例如,设置提醒后的确认消息)本身就是一种回复。这是必须遵守的规则。
|
|
||||||
|
|
||||||
**重要概念:将“理由”作为“内心思考”的体现**
|
|
||||||
`reason` 字段是本次决策的核心。它并非一个简单的“理由”,而是 **一个模拟人类在回应前,头脑中自然浮现的、未经修饰的思绪流**。你需要完全代入 {identity_block} 的角色,将那一刻的想法自然地记录下来。
|
|
||||||
|
|
||||||
**内心思考的要点:**
|
|
||||||
* **自然流露**: 不要使用“决定”、“所以”、“因此”等结论性或汇报式的词语。你的思考应该像日记一样,是给自己看的,充满了不确定性和情绪的自然流动。
|
|
||||||
* **展现过程**: 重点在于展现 **思考的过程**,而不是 **决策的结果**。描述你看到了什么,想到了什么,感受到了什么。
|
|
||||||
* **人设核心**: 你的每一丝想法,都应该源于你的人设。思考“如果我是这个角色,我此刻会想些什么?”
|
|
||||||
* **通用模板**: 这是一套通用模板,请 **不要** 在示例中出现特定的人名或个性化内容,以确保其普适性。
|
|
||||||
|
|
||||||
**思考过程示例 (通用模板):**
|
|
||||||
* "用户好像在说一件开心的事,语气听起来很兴奋。这让我想起了……嗯,我也觉得很开心,很想分享这份喜悦。"
|
|
||||||
* "感觉气氛有点低落……他说的话让我有点担心。也许我该说点什么安慰一下?"
|
|
||||||
* "哦?这个话题真有意思,我以前好像也想过类似的事情。不知道他会怎么看呢……"
|
|
||||||
|
|
||||||
**可用动作:**
|
|
||||||
{actions_before_now_block}
|
|
||||||
|
|
||||||
{no_action_block}
|
|
||||||
|
|
||||||
动作:reply
|
|
||||||
动作描述:参与聊天回复,发送文本进行表达
|
|
||||||
- 你想要闲聊或者随便附和
|
|
||||||
- {mentioned_bonus}
|
|
||||||
- 如果你刚刚进行了回复,不要对同一个话题重复回应
|
|
||||||
- 不要回复自己发送的消息
|
|
||||||
{{
|
|
||||||
"action": "reply",
|
|
||||||
"target_message_id": "触发action的消息id",
|
|
||||||
"reason": "在这里详细记录你的内心思考过程。例如:‘用户看起来很开心,我想回复一些积极的内容,分享这份喜悦。’"
|
|
||||||
}}
|
|
||||||
|
|
||||||
{action_options_text}
|
|
||||||
|
|
||||||
|
|
||||||
**输出格式:**
|
|
||||||
你必须以严格的 JSON 格式输出,返回一个包含所有选定动作的JSON列表。如果没有任何合适的动作,返回一个空列表[]。
|
|
||||||
|
|
||||||
**单动作示例 (仅回复):**
|
|
||||||
[
|
|
||||||
{{
|
|
||||||
"action": "reply",
|
|
||||||
"target_message_id": "m123",
|
|
||||||
"reason": "感觉气氛有点低落……他说的话让我有点担心。也许我该说点什么安慰一下?"
|
|
||||||
}}
|
|
||||||
]
|
|
||||||
|
|
||||||
**组合动作示例 (回复 + 表情包):**
|
|
||||||
[
|
|
||||||
{{
|
|
||||||
"action": "reply",
|
|
||||||
"target_message_id": "m123",
|
|
||||||
"reason": "[观察与感受] 用户分享了一件开心的事,语气里充满了喜悦! [分析与联想] 看到他这么开心,我的心情也一下子变得像棉花糖一样甜~ [动机与决策] 我要由衷地为他感到高兴,决定回复一些赞美和祝福的话,把这份快乐的气氛推向高潮!"
|
|
||||||
}},
|
|
||||||
{{
|
|
||||||
"action": "emoji",
|
|
||||||
"target_message_id": "m123",
|
|
||||||
"reason": "光用文字还不够表达我激动的心情!加个表情包的话,这份喜悦的气氛应该会更浓厚一点吧!"
|
|
||||||
}}
|
|
||||||
]
|
|
||||||
|
|
||||||
**单动作示例 (特定动作):**
|
|
||||||
[
|
|
||||||
{{
|
|
||||||
"action": "set_reminder",
|
|
||||||
"target_message_id": "m456",
|
|
||||||
"reason": "用户说‘提醒维尔薇下午三点去工坊’,这是一个非常明确的指令。根据决策流程,我必须优先执行这个特定动作,而不是进行常规回复。",
|
|
||||||
"user_name": "维尔薇",
|
|
||||||
"remind_time": "下午三点",
|
|
||||||
"event_details": "去工坊"
|
|
||||||
}}
|
|
||||||
]
|
|
||||||
|
|
||||||
**重要规则:**
|
|
||||||
**重要规则:**
|
|
||||||
当 `reply` 和 `emoji` 动作同时被选择时,`emoji` 动作的 `reason` 字段也应该体现出你的思考过程,并与 `reply` 的思考保持连贯。
|
|
||||||
|
|
||||||
不要输出markdown格式```json等内容,直接输出且仅包含 JSON 列表内容:
|
|
||||||
""",
|
|
||||||
"planner_prompt",
|
|
||||||
)
|
|
||||||
|
|
||||||
# 主动思考规划器提示词,用于在没有新消息时决定是否要主动发起对话。
|
|
||||||
# 它模拟了人类的自发性思考,允许模型根据长期记忆和最近的对话来决定是否开启新话题。
|
|
||||||
Prompt(
|
|
||||||
"""
|
|
||||||
# 主动思考决策
|
|
||||||
|
|
||||||
## 你的内部状态
|
|
||||||
{time_block}
|
|
||||||
{identity_block}
|
|
||||||
{mood_block}
|
|
||||||
|
|
||||||
## 长期记忆摘要
|
|
||||||
{long_term_memory_block}
|
|
||||||
|
|
||||||
## 最近的聊天内容
|
|
||||||
{chat_content_block}
|
|
||||||
|
|
||||||
## 最近的动作历史
|
|
||||||
{actions_before_now_block}
|
|
||||||
|
|
||||||
## 任务
|
|
||||||
你现在要决定是否主动说些什么。就像一个真实的人一样,有时候会突然想起之前聊到的话题,或者对朋友的近况感到好奇,想主动询问或关心一下。
|
|
||||||
**重要提示**:你的日程安排仅供你个人参考,不应作为主动聊天话题的主要来源。请更多地从聊天内容和朋友的动态中寻找灵感。
|
|
||||||
|
|
||||||
请基于聊天内容,用你的判断力来决定是否要主动发言。不要按照固定规则,而是像人类一样自然地思考:
|
|
||||||
- 是否想起了什么之前提到的事情,想问问后来怎么样了?
|
|
||||||
- 是否注意到朋友提到了什么值得关心的事情?
|
|
||||||
- 是否有什么话题突然想到,觉得现在聊聊很合适?
|
|
||||||
- 或者觉得现在保持沉默更好?
|
|
||||||
|
|
||||||
## 可用动作
|
|
||||||
动作:proactive_reply
|
|
||||||
动作描述:主动发起对话,可以是关心朋友、询问近况、延续之前的话题,或分享想法。
|
|
||||||
- 当你突然想起之前的话题,想询问进展时
|
|
||||||
- 当你想关心朋友的情况时
|
|
||||||
- 当你有什么想法想分享时
|
|
||||||
- 当你觉得现在是个合适的聊天时机时
|
|
||||||
{{
|
|
||||||
"action": "proactive_reply",
|
|
||||||
"reason": "你决定主动发言的具体原因",
|
|
||||||
"topic": "你想说的内容主题(简洁描述)"
|
|
||||||
}}
|
|
||||||
|
|
||||||
动作:do_nothing
|
|
||||||
动作描述:保持沉默,不主动发起对话。
|
|
||||||
- 当你觉得现在不是合适的时机时
|
|
||||||
- 当最近已经说得够多了时
|
|
||||||
- 当对话氛围不适合插入时
|
|
||||||
{{
|
|
||||||
"action": "do_nothing",
|
|
||||||
"reason": "决定保持沉默的原因"
|
|
||||||
}}
|
|
||||||
|
|
||||||
你必须从上面列出的可用action中选择一个。要像真人一样自然地思考和决策。
|
|
||||||
请以严格的 JSON 格式输出,且仅包含 JSON 内容:
|
|
||||||
""",
|
|
||||||
"proactive_planner_prompt",
|
|
||||||
)
|
|
||||||
|
|
||||||
# 单个动作的格式化提示词模板。
|
|
||||||
# 用于将每个可用动作的信息格式化后,插入到主提示词的 {action_options_text} 占位符中。
|
|
||||||
Prompt(
|
|
||||||
"""
|
|
||||||
动作:{action_name}
|
|
||||||
动作描述:{action_description}
|
|
||||||
{action_require}
|
|
||||||
{{
|
|
||||||
"action": "{action_name}",
|
|
||||||
"target_message_id": "触发action的消息id",
|
|
||||||
"reason": "触发action的原因"{action_parameters}
|
|
||||||
}}
|
|
||||||
""",
|
|
||||||
"action_prompt",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# 在模块加载时自动调用,完成提示词的注册。
|
|
||||||
init_prompts()
|
|
||||||
@@ -31,7 +31,6 @@ from src.chat.express.expression_selector import expression_selector
|
|||||||
from src.chat.memory_system.memory_activator import MemoryActivator
|
from src.chat.memory_system.memory_activator import MemoryActivator
|
||||||
from src.chat.memory_system.vector_instant_memory import VectorInstantMemoryV2
|
from src.chat.memory_system.vector_instant_memory import VectorInstantMemoryV2
|
||||||
from src.mood.mood_manager import mood_manager
|
from src.mood.mood_manager import mood_manager
|
||||||
from src.person_info.relationship_fetcher import relationship_fetcher_manager
|
|
||||||
from src.person_info.person_info import get_person_info_manager
|
from src.person_info.person_info import get_person_info_manager
|
||||||
from src.plugin_system.base.component_types import ActionInfo, EventType
|
from src.plugin_system.base.component_types import ActionInfo, EventType
|
||||||
from src.plugin_system.apis import llm_api
|
from src.plugin_system.apis import llm_api
|
||||||
@@ -83,13 +82,13 @@ def init_prompt():
|
|||||||
- {schedule_block}
|
- {schedule_block}
|
||||||
|
|
||||||
## 历史记录
|
## 历史记录
|
||||||
### {chat_context_type}中的所有人的聊天记录:
|
### 📜 已读历史消息(仅供参考)
|
||||||
{background_dialogue_prompt}
|
{read_history_prompt}
|
||||||
|
|
||||||
{cross_context_block}
|
{cross_context_block}
|
||||||
|
|
||||||
### {chat_context_type}中正在与你对话的聊天记录
|
### 📬 未读历史消息(动作执行对象)
|
||||||
{core_dialogue_prompt}
|
{unread_history_prompt}
|
||||||
|
|
||||||
## 表达方式
|
## 表达方式
|
||||||
- *你需要参考你的回复风格:*
|
- *你需要参考你的回复风格:*
|
||||||
@@ -105,19 +104,38 @@ def init_prompt():
|
|||||||
## 其他信息
|
## 其他信息
|
||||||
{memory_block}
|
{memory_block}
|
||||||
{relation_info_block}
|
{relation_info_block}
|
||||||
|
|
||||||
{extra_info_block}
|
{extra_info_block}
|
||||||
|
|
||||||
{action_descriptions}
|
{action_descriptions}
|
||||||
|
|
||||||
## 任务
|
## 任务
|
||||||
|
|
||||||
*你正在一个{chat_context_type}里聊天,你需要理解整个{chat_context_type}的聊天动态和话题走向,并做出自然的回应。*
|
*{chat_scene}*
|
||||||
|
|
||||||
### 核心任务
|
### 核心任务
|
||||||
- 你现在的主要任务是和 {sender_name} 聊天。
|
- 你现在的主要任务是和 {sender_name} 聊天。同时,也有其他用户会参与聊天,你可以参考他们的回复内容,但是你现在想回复{sender_name}的发言。
|
||||||
- {reply_target_block} ,你需要生成一段紧密相关且能推动对话的回复。
|
|
||||||
|
- {reply_target_block} 你需要生成一段紧密相关且能推动对话的回复。
|
||||||
|
|
||||||
## 规则
|
## 规则
|
||||||
{safety_guidelines_block}
|
{safety_guidelines_block}
|
||||||
|
**重要提醒:**
|
||||||
|
- **已读历史消息仅作为当前聊天情景的参考**
|
||||||
|
- **动作执行对象只能是未读历史消息中的消息**
|
||||||
|
- **请优先对兴趣值高的消息做出回复**(兴趣度标注在未读消息末尾)
|
||||||
|
|
||||||
|
在回应之前,首先分析消息的针对性:
|
||||||
|
1. **直接针对你**:@你、回复你、明确询问你 → 必须回应
|
||||||
|
2. **间接相关**:涉及你感兴趣的话题但未直接问你 → 谨慎参与
|
||||||
|
3. **他人对话**:与你无关的私人交流 → 通常不参与
|
||||||
|
4. **重复内容**:他人已充分回答的问题 → 避免重复
|
||||||
|
|
||||||
|
你的回复应该:
|
||||||
|
1. 明确回应目标消息,而不是宽泛地评论。
|
||||||
|
2. 可以分享你的看法、提出相关问题,或者开个合适的玩笑。
|
||||||
|
3. 目的是让对话更有趣、更深入。
|
||||||
|
4. 不要浮夸,不要夸张修辞,不要输出多余内容(包括前后缀,冒号和引号,括号(),表情包,at或 @等 )。
|
||||||
最终请输出一条简短、完整且口语化的回复。
|
最终请输出一条简短、完整且口语化的回复。
|
||||||
|
|
||||||
--------------------------------
|
--------------------------------
|
||||||
@@ -153,10 +171,14 @@ If you need to use the search tool, please directly call the function "lpmm_sear
|
|||||||
logger.debug("[Prompt模式调试] 正在注册normal_style_prompt模板")
|
logger.debug("[Prompt模式调试] 正在注册normal_style_prompt模板")
|
||||||
Prompt(
|
Prompt(
|
||||||
"""
|
"""
|
||||||
你正在一个QQ群里聊天,你需要理解整个群的聊天动态和话题走向,并做出自然的回应。
|
{chat_scene}
|
||||||
|
|
||||||
**重要:消息针对性判断**
|
**重要:消息针对性判断**
|
||||||
{safety_guidelines_block}
|
在回应之前,首先分析消息的针对性:
|
||||||
|
1. **直接针对你**:@你、回复你、明确询问你 → 必须回应
|
||||||
|
2. **间接相关**:涉及你感兴趣的话题但未直接问你 → 谨慎参与
|
||||||
|
3. **他人对话**:与你无关的私人交流 → 通常不参与
|
||||||
|
4. **重复内容**:他人已充分回答的问题 → 避免重复
|
||||||
|
|
||||||
{expression_habits_block}
|
{expression_habits_block}
|
||||||
{tool_info_block}
|
{tool_info_block}
|
||||||
@@ -186,6 +208,10 @@ If you need to use the search tool, please directly call the function "lpmm_sear
|
|||||||
{keywords_reaction_prompt}
|
{keywords_reaction_prompt}
|
||||||
请注意不要输出多余内容(包括前后缀,冒号和引号,at或 @等 )。只输出回复内容。
|
请注意不要输出多余内容(包括前后缀,冒号和引号,at或 @等 )。只输出回复内容。
|
||||||
{moderation_prompt}
|
{moderation_prompt}
|
||||||
|
你的核心任务是针对 {reply_target_block} 中提到的内容,{relation_info_block}生成一段紧密相关且能推动对话的回复。你的回复应该:
|
||||||
|
1. 明确回应目标消息,而不是宽泛地评论。
|
||||||
|
2. 可以分享你的看法、提出相关问题,或者开个合适的玩笑。
|
||||||
|
3. 目的是让对话更有趣、更深入。
|
||||||
最终请输出一条简短、完整且口语化的回复。
|
最终请输出一条简短、完整且口语化的回复。
|
||||||
现在,你说:
|
现在,你说:
|
||||||
""",
|
""",
|
||||||
@@ -202,9 +228,7 @@ class DefaultReplyer:
|
|||||||
):
|
):
|
||||||
self.express_model = LLMRequest(model_set=model_config.model_task_config.replyer, request_type=request_type)
|
self.express_model = LLMRequest(model_set=model_config.model_task_config.replyer, request_type=request_type)
|
||||||
self.chat_stream = chat_stream
|
self.chat_stream = chat_stream
|
||||||
self.is_group_chat: Optional[bool] = None
|
self.is_group_chat, self.chat_target_info = get_chat_type_and_target_info(self.chat_stream.stream_id)
|
||||||
self.chat_target_info: Optional[Dict[str, Any]] = None
|
|
||||||
self._initialized = False
|
|
||||||
|
|
||||||
self.heart_fc_sender = HeartFCSender()
|
self.heart_fc_sender = HeartFCSender()
|
||||||
self.memory_activator = MemoryActivator()
|
self.memory_activator = MemoryActivator()
|
||||||
@@ -215,19 +239,6 @@ class DefaultReplyer:
|
|||||||
|
|
||||||
self.tool_executor = ToolExecutor(chat_id=self.chat_stream.stream_id)
|
self.tool_executor = ToolExecutor(chat_id=self.chat_stream.stream_id)
|
||||||
|
|
||||||
def _should_block_self_message(self, reply_message: Optional[Dict[str, Any]]) -> bool:
|
|
||||||
"""判定是否应阻断当前待处理消息(自消息且无外部触发)"""
|
|
||||||
try:
|
|
||||||
bot_id = str(global_config.bot.qq_account)
|
|
||||||
uid = str(reply_message.get("user_id"))
|
|
||||||
if uid != bot_id:
|
|
||||||
return False
|
|
||||||
|
|
||||||
return True
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"[SelfGuard] 判定异常,回退为不阻断: {e}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
async def generate_reply_with_context(
|
async def generate_reply_with_context(
|
||||||
self,
|
self,
|
||||||
reply_to: str = "",
|
reply_to: str = "",
|
||||||
@@ -237,7 +248,6 @@ class DefaultReplyer:
|
|||||||
from_plugin: bool = True,
|
from_plugin: bool = True,
|
||||||
stream_id: Optional[str] = None,
|
stream_id: Optional[str] = None,
|
||||||
reply_message: Optional[Dict[str, Any]] = None,
|
reply_message: Optional[Dict[str, Any]] = None,
|
||||||
read_mark: float = 0.0,
|
|
||||||
) -> Tuple[bool, Optional[Dict[str, Any]], Optional[str]]:
|
) -> Tuple[bool, Optional[Dict[str, Any]], Optional[str]]:
|
||||||
# sourcery skip: merge-nested-ifs
|
# sourcery skip: merge-nested-ifs
|
||||||
"""
|
"""
|
||||||
@@ -256,10 +266,6 @@ class DefaultReplyer:
|
|||||||
prompt = None
|
prompt = None
|
||||||
if available_actions is None:
|
if available_actions is None:
|
||||||
available_actions = {}
|
available_actions = {}
|
||||||
# 自消息阻断
|
|
||||||
if self._should_block_self_message(reply_message):
|
|
||||||
logger.debug("[SelfGuard] 阻断:自消息且无外部触发。")
|
|
||||||
return False, None, None
|
|
||||||
llm_response = None
|
llm_response = None
|
||||||
try:
|
try:
|
||||||
# 构建 Prompt
|
# 构建 Prompt
|
||||||
@@ -270,7 +276,6 @@ class DefaultReplyer:
|
|||||||
available_actions=available_actions,
|
available_actions=available_actions,
|
||||||
enable_tool=enable_tool,
|
enable_tool=enable_tool,
|
||||||
reply_message=reply_message,
|
reply_message=reply_message,
|
||||||
read_mark=read_mark,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if not prompt:
|
if not prompt:
|
||||||
@@ -592,17 +597,16 @@ class DefaultReplyer:
|
|||||||
logger.error(f"工具信息获取失败: {e}")
|
logger.error(f"工具信息获取失败: {e}")
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
@staticmethod
|
def _parse_reply_target(self, target_message: str) -> Tuple[str, str]:
|
||||||
def _parse_reply_target(target_message: str) -> Tuple[str, str]:
|
|
||||||
"""解析回复目标消息 - 使用共享工具"""
|
"""解析回复目标消息 - 使用共享工具"""
|
||||||
from src.chat.utils.prompt import Prompt
|
from src.chat.utils.prompt import Prompt
|
||||||
|
|
||||||
if target_message is None:
|
if target_message is None:
|
||||||
logger.warning("target_message为None,返回默认值")
|
logger.warning("target_message为None,返回默认值")
|
||||||
return "未知用户", "(无消息内容)"
|
return "未知用户", "(无消息内容)"
|
||||||
return Prompt.parse_reply_target(target_message)
|
return Prompt.parse_reply_target(target_message)
|
||||||
|
|
||||||
@staticmethod
|
async def build_keywords_reaction_prompt(self, target: Optional[str]) -> str:
|
||||||
async def build_keywords_reaction_prompt(target: Optional[str]) -> str:
|
|
||||||
"""构建关键词反应提示
|
"""构建关键词反应提示
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -644,8 +648,7 @@ class DefaultReplyer:
|
|||||||
|
|
||||||
return keywords_reaction_prompt
|
return keywords_reaction_prompt
|
||||||
|
|
||||||
@staticmethod
|
async def _time_and_run_task(self, coroutine, name: str) -> Tuple[str, Any, float]:
|
||||||
async def _time_and_run_task(coroutine, name: str) -> Tuple[str, Any, float]:
|
|
||||||
"""计时并运行异步任务的辅助函数
|
"""计时并运行异步任务的辅助函数
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -662,78 +665,258 @@ class DefaultReplyer:
|
|||||||
return name, result, duration
|
return name, result, duration
|
||||||
|
|
||||||
async def build_s4u_chat_history_prompts(
|
async def build_s4u_chat_history_prompts(
|
||||||
self, message_list_before_now: List[Dict[str, Any]], target_user_id: str, sender: str
|
self, message_list_before_now: List[Dict[str, Any]], target_user_id: str, sender: str, chat_id: str
|
||||||
) -> Tuple[str, str]:
|
) -> Tuple[str, str]:
|
||||||
"""
|
"""
|
||||||
构建 s4u 风格的分离对话 prompt
|
构建 s4u 风格的已读/未读历史消息 prompt
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
message_list_before_now: 历史消息列表
|
message_list_before_now: 历史消息列表
|
||||||
target_user_id: 目标用户ID(当前对话对象)
|
target_user_id: 目标用户ID(当前对话对象)
|
||||||
|
sender: 发送者名称
|
||||||
|
chat_id: 聊天ID
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Tuple[str, str]: (核心对话prompt, 背景对话prompt)
|
Tuple[str, str]: (已读历史消息prompt, 未读历史消息prompt)
|
||||||
"""
|
"""
|
||||||
core_dialogue_list = []
|
try:
|
||||||
|
# 从message_manager获取真实的已读/未读消息
|
||||||
|
from src.chat.message_manager.message_manager import message_manager
|
||||||
|
|
||||||
|
# 获取聊天流的上下文
|
||||||
|
stream_context = message_manager.stream_contexts.get(chat_id)
|
||||||
|
if stream_context:
|
||||||
|
# 使用真正的已读和未读消息
|
||||||
|
read_messages = stream_context.history_messages # 已读消息
|
||||||
|
unread_messages = stream_context.get_unread_messages() # 未读消息
|
||||||
|
|
||||||
|
# 构建已读历史消息 prompt
|
||||||
|
read_history_prompt = ""
|
||||||
|
if read_messages:
|
||||||
|
read_content = build_readable_messages(
|
||||||
|
[msg.flatten() for msg in read_messages[-50:]], # 限制数量
|
||||||
|
replace_bot_name=True,
|
||||||
|
timestamp_mode="normal_no_YMD",
|
||||||
|
truncate=True,
|
||||||
|
)
|
||||||
|
read_history_prompt = f"这是已读历史消息,仅作为当前聊天情景的参考:\n{read_content}"
|
||||||
|
else:
|
||||||
|
# 如果没有已读消息,则从数据库加载最近的上下文
|
||||||
|
logger.info("暂无已读历史消息,正在从数据库加载上下文...")
|
||||||
|
fallback_messages = get_raw_msg_before_timestamp_with_chat(
|
||||||
|
chat_id=chat_id,
|
||||||
|
timestamp=time.time(),
|
||||||
|
limit=global_config.chat.max_context_size,
|
||||||
|
)
|
||||||
|
if fallback_messages:
|
||||||
|
# 从 unread_messages 获取 message_id 列表,用于去重
|
||||||
|
unread_message_ids = {msg.message_id for msg in unread_messages}
|
||||||
|
filtered_fallback_messages = [
|
||||||
|
msg for msg in fallback_messages if msg.get("message_id") not in unread_message_ids
|
||||||
|
]
|
||||||
|
|
||||||
|
if filtered_fallback_messages:
|
||||||
|
read_content = build_readable_messages(
|
||||||
|
filtered_fallback_messages,
|
||||||
|
replace_bot_name=True,
|
||||||
|
timestamp_mode="normal_no_YMD",
|
||||||
|
truncate=True,
|
||||||
|
)
|
||||||
|
read_history_prompt = f"这是已读历史消息,仅作为当前聊天情景的参考:\n{read_content}"
|
||||||
|
else:
|
||||||
|
read_history_prompt = "暂无已读历史消息"
|
||||||
|
else:
|
||||||
|
read_history_prompt = "暂无已读历史消息"
|
||||||
|
|
||||||
|
# 构建未读历史消息 prompt(包含兴趣度)
|
||||||
|
unread_history_prompt = ""
|
||||||
|
if unread_messages:
|
||||||
|
# 尝试获取兴趣度评分
|
||||||
|
interest_scores = await self._get_interest_scores_for_messages(
|
||||||
|
[msg.flatten() for msg in unread_messages]
|
||||||
|
)
|
||||||
|
|
||||||
|
unread_lines = []
|
||||||
|
for msg in unread_messages:
|
||||||
|
msg_id = msg.message_id
|
||||||
|
msg_time = time.strftime("%H:%M:%S", time.localtime(msg.time))
|
||||||
|
msg_content = msg.processed_plain_text
|
||||||
|
|
||||||
|
# 使用与已读历史消息相同的方法获取用户名
|
||||||
|
from src.person_info.person_info import PersonInfoManager, get_person_info_manager
|
||||||
|
|
||||||
|
# 获取用户信息
|
||||||
|
user_info = getattr(msg, "user_info", {})
|
||||||
|
platform = getattr(user_info, "platform", "") or getattr(msg, "platform", "")
|
||||||
|
user_id = getattr(user_info, "user_id", "") or getattr(msg, "user_id", "")
|
||||||
|
|
||||||
|
# 获取用户名
|
||||||
|
if platform and user_id:
|
||||||
|
person_id = PersonInfoManager.get_person_id(platform, user_id)
|
||||||
|
person_info_manager = get_person_info_manager()
|
||||||
|
sender_name = person_info_manager.get_value_sync(person_id, "person_name") or "未知用户"
|
||||||
|
else:
|
||||||
|
sender_name = "未知用户"
|
||||||
|
|
||||||
|
# 添加兴趣度信息
|
||||||
|
interest_score = interest_scores.get(msg_id, 0.0)
|
||||||
|
interest_text = f" [兴趣度: {interest_score:.3f}]" if interest_score > 0 else ""
|
||||||
|
|
||||||
|
unread_lines.append(f"{msg_time} {sender_name}: {msg_content}{interest_text}")
|
||||||
|
|
||||||
|
unread_history_prompt_str = "\n".join(unread_lines)
|
||||||
|
unread_history_prompt = f"这是未读历史消息,包含兴趣度评分,请优先对兴趣值高的消息做出动作:\n{unread_history_prompt_str}"
|
||||||
|
else:
|
||||||
|
unread_history_prompt = "暂无未读历史消息"
|
||||||
|
|
||||||
|
return read_history_prompt, unread_history_prompt
|
||||||
|
else:
|
||||||
|
# 回退到传统方法
|
||||||
|
return await self._fallback_build_chat_history_prompts(message_list_before_now, target_user_id, sender)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"获取已读/未读历史消息失败,使用回退方法: {e}")
|
||||||
|
return await self._fallback_build_chat_history_prompts(message_list_before_now, target_user_id, sender)
|
||||||
|
|
||||||
|
async def _fallback_build_chat_history_prompts(
|
||||||
|
self, message_list_before_now: List[Dict[str, Any]], target_user_id: str, sender: str
|
||||||
|
) -> Tuple[str, str]:
|
||||||
|
"""
|
||||||
|
回退的已读/未读历史消息构建方法
|
||||||
|
"""
|
||||||
|
# 通过is_read字段分离已读和未读消息
|
||||||
|
read_messages = []
|
||||||
|
unread_messages = []
|
||||||
bot_id = str(global_config.bot.qq_account)
|
bot_id = str(global_config.bot.qq_account)
|
||||||
|
|
||||||
# 过滤消息:分离bot和目标用户的对话 vs 其他用户的对话
|
for msg_dict in message_list_before_now:
|
||||||
|
try:
|
||||||
|
msg_user_id = str(msg_dict.get("user_id"))
|
||||||
|
if msg_dict.get("is_read", False):
|
||||||
|
read_messages.append(msg_dict)
|
||||||
|
else:
|
||||||
|
unread_messages.append(msg_dict)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"处理消息记录时出错: {msg_dict}, 错误: {e}")
|
||||||
|
|
||||||
|
# 如果没有is_read字段,使用原有的逻辑
|
||||||
|
if not read_messages and not unread_messages:
|
||||||
|
# 使用原有的核心对话逻辑
|
||||||
|
core_dialogue_list = []
|
||||||
for msg_dict in message_list_before_now:
|
for msg_dict in message_list_before_now:
|
||||||
try:
|
try:
|
||||||
msg_user_id = str(msg_dict.get("user_id"))
|
msg_user_id = str(msg_dict.get("user_id"))
|
||||||
reply_to = msg_dict.get("reply_to", "")
|
reply_to = msg_dict.get("reply_to", "")
|
||||||
_platform, reply_to_user_id = self._parse_reply_target(reply_to)
|
_platform, reply_to_user_id = self._parse_reply_target(reply_to)
|
||||||
if (msg_user_id == bot_id and reply_to_user_id == target_user_id) or msg_user_id == target_user_id:
|
if (msg_user_id == bot_id and reply_to_user_id == target_user_id) or msg_user_id == target_user_id:
|
||||||
# bot 和目标用户的对话
|
|
||||||
core_dialogue_list.append(msg_dict)
|
core_dialogue_list.append(msg_dict)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"处理消息记录时出错: {msg_dict}, 错误: {e}")
|
logger.error(f"处理消息记录时出错: {msg_dict}, 错误: {e}")
|
||||||
|
|
||||||
# 构建背景对话 prompt
|
read_messages = [msg for msg in message_list_before_now if msg not in core_dialogue_list]
|
||||||
all_dialogue_prompt = ""
|
unread_messages = core_dialogue_list
|
||||||
if message_list_before_now:
|
|
||||||
latest_25_msgs = message_list_before_now[-int(global_config.chat.max_context_size) :]
|
# 构建已读历史消息 prompt
|
||||||
all_dialogue_prompt_str = await build_readable_messages(
|
read_history_prompt = ""
|
||||||
latest_25_msgs,
|
if read_messages:
|
||||||
|
read_content = build_readable_messages(
|
||||||
|
read_messages[-50:],
|
||||||
replace_bot_name=True,
|
replace_bot_name=True,
|
||||||
timestamp_mode="normal",
|
|
||||||
truncate=True,
|
|
||||||
)
|
|
||||||
all_dialogue_prompt = f"所有用户的发言:\n{all_dialogue_prompt_str}"
|
|
||||||
|
|
||||||
# 构建核心对话 prompt
|
|
||||||
core_dialogue_prompt = ""
|
|
||||||
if core_dialogue_list:
|
|
||||||
# 检查最新五条消息中是否包含bot自己说的消息
|
|
||||||
latest_5_messages = core_dialogue_list[-5:] if len(core_dialogue_list) >= 5 else core_dialogue_list
|
|
||||||
has_bot_message = any(str(msg.get("user_id")) == bot_id for msg in latest_5_messages)
|
|
||||||
|
|
||||||
# logger.info(f"最新五条消息:{latest_5_messages}")
|
|
||||||
# logger.info(f"最新五条消息中是否包含bot自己说的消息:{has_bot_message}")
|
|
||||||
|
|
||||||
# 如果最新五条消息中不包含bot的消息,则返回空字符串
|
|
||||||
if not has_bot_message:
|
|
||||||
core_dialogue_prompt = ""
|
|
||||||
else:
|
|
||||||
core_dialogue_list = core_dialogue_list[-int(global_config.chat.max_context_size * 2) :] # 限制消息数量
|
|
||||||
|
|
||||||
core_dialogue_prompt_str = await build_readable_messages(
|
|
||||||
core_dialogue_list,
|
|
||||||
replace_bot_name=True,
|
|
||||||
merge_messages=False,
|
|
||||||
timestamp_mode="normal_no_YMD",
|
timestamp_mode="normal_no_YMD",
|
||||||
read_mark=0.0,
|
|
||||||
truncate=True,
|
truncate=True,
|
||||||
show_actions=True,
|
|
||||||
)
|
)
|
||||||
core_dialogue_prompt = f"""
|
read_history_prompt = f"这是已读历史消息,仅作为当前聊天情景的参考:\n{read_content}"
|
||||||
{core_dialogue_prompt_str}
|
else:
|
||||||
"""
|
read_history_prompt = "暂无已读历史消息"
|
||||||
|
|
||||||
return core_dialogue_prompt, all_dialogue_prompt
|
# 构建未读历史消息 prompt
|
||||||
|
unread_history_prompt = ""
|
||||||
|
if unread_messages:
|
||||||
|
# 尝试获取兴趣度评分
|
||||||
|
interest_scores = await self._get_interest_scores_for_messages(unread_messages)
|
||||||
|
|
||||||
|
unread_lines = []
|
||||||
|
for msg in unread_messages:
|
||||||
|
msg_id = msg.get("message_id", "")
|
||||||
|
msg_time = time.strftime("%H:%M:%S", time.localtime(msg.get("time", time.time())))
|
||||||
|
msg_content = msg.get("processed_plain_text", "")
|
||||||
|
|
||||||
|
# 使用与已读历史消息相同的方法获取用户名
|
||||||
|
from src.person_info.person_info import PersonInfoManager, get_person_info_manager
|
||||||
|
|
||||||
|
# 获取用户信息
|
||||||
|
user_info = msg.get("user_info", {})
|
||||||
|
platform = user_info.get("platform") or msg.get("platform", "")
|
||||||
|
user_id = user_info.get("user_id") or msg.get("user_id", "")
|
||||||
|
|
||||||
|
# 获取用户名
|
||||||
|
if platform and user_id:
|
||||||
|
person_id = PersonInfoManager.get_person_id(platform, user_id)
|
||||||
|
person_info_manager = get_person_info_manager()
|
||||||
|
sender_name = person_info_manager.get_value_sync(person_id, "person_name") or "未知用户"
|
||||||
|
else:
|
||||||
|
sender_name = "未知用户"
|
||||||
|
|
||||||
|
# 添加兴趣度信息
|
||||||
|
interest_score = interest_scores.get(msg_id, 0.0)
|
||||||
|
interest_text = f" [兴趣度: {interest_score:.3f}]" if interest_score > 0 else ""
|
||||||
|
|
||||||
|
unread_lines.append(f"{msg_time} {sender_name}: {msg_content}{interest_text}")
|
||||||
|
|
||||||
|
unread_history_prompt_str = "\n".join(unread_lines)
|
||||||
|
unread_history_prompt = (
|
||||||
|
f"这是未读历史消息,包含兴趣度评分,请优先对兴趣值高的消息做出动作:\n{unread_history_prompt_str}"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
unread_history_prompt = "暂无未读历史消息"
|
||||||
|
|
||||||
|
return read_history_prompt, unread_history_prompt
|
||||||
|
|
||||||
|
async def _get_interest_scores_for_messages(self, messages: List[dict]) -> dict[str, float]:
|
||||||
|
"""为消息获取兴趣度评分"""
|
||||||
|
interest_scores = {}
|
||||||
|
|
||||||
|
try:
|
||||||
|
from src.plugins.built_in.affinity_flow_chatter.interest_scoring import chatter_interest_scoring_system as interest_scoring_system
|
||||||
|
from src.common.data_models.database_data_model import DatabaseMessages
|
||||||
|
|
||||||
|
# 转换消息格式
|
||||||
|
db_messages = []
|
||||||
|
for msg_dict in messages:
|
||||||
|
try:
|
||||||
|
db_msg = DatabaseMessages(
|
||||||
|
message_id=msg_dict.get("message_id", ""),
|
||||||
|
time=msg_dict.get("time", time.time()),
|
||||||
|
chat_id=msg_dict.get("chat_id", ""),
|
||||||
|
processed_plain_text=msg_dict.get("processed_plain_text", ""),
|
||||||
|
user_id=msg_dict.get("user_id", ""),
|
||||||
|
user_nickname=msg_dict.get("user_nickname", ""),
|
||||||
|
user_platform=msg_dict.get("platform", "qq"),
|
||||||
|
chat_info_group_id=msg_dict.get("group_id", ""),
|
||||||
|
chat_info_group_name=msg_dict.get("group_name", ""),
|
||||||
|
chat_info_group_platform=msg_dict.get("platform", "qq"),
|
||||||
|
)
|
||||||
|
db_messages.append(db_msg)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"转换消息格式失败: {e}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 计算兴趣度评分
|
||||||
|
if db_messages:
|
||||||
|
bot_nickname = global_config.bot.nickname or "麦麦"
|
||||||
|
scores = await interest_scoring_system.calculate_interest_scores(db_messages, bot_nickname)
|
||||||
|
|
||||||
|
# 构建兴趣度字典
|
||||||
|
for score in scores:
|
||||||
|
interest_scores[score.message_id] = score.total_score
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"获取兴趣度评分失败: {e}")
|
||||||
|
|
||||||
|
return interest_scores
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def build_mai_think_context(
|
def build_mai_think_context(
|
||||||
|
self,
|
||||||
chat_id: str,
|
chat_id: str,
|
||||||
memory_block: str,
|
memory_block: str,
|
||||||
relation_info: str,
|
relation_info: str,
|
||||||
@@ -777,12 +960,6 @@ class DefaultReplyer:
|
|||||||
mai_think.target = target
|
mai_think.target = target
|
||||||
return mai_think
|
return mai_think
|
||||||
|
|
||||||
async def _async_init(self):
|
|
||||||
if self._initialized:
|
|
||||||
return
|
|
||||||
self.is_group_chat, self.chat_target_info = await get_chat_type_and_target_info(self.chat_stream.stream_id)
|
|
||||||
self._initialized = True
|
|
||||||
|
|
||||||
async def build_prompt_reply_context(
|
async def build_prompt_reply_context(
|
||||||
self,
|
self,
|
||||||
reply_to: str,
|
reply_to: str,
|
||||||
@@ -790,7 +967,6 @@ class DefaultReplyer:
|
|||||||
available_actions: Optional[Dict[str, ActionInfo]] = None,
|
available_actions: Optional[Dict[str, ActionInfo]] = None,
|
||||||
enable_tool: bool = True,
|
enable_tool: bool = True,
|
||||||
reply_message: Optional[Dict[str, Any]] = None,
|
reply_message: Optional[Dict[str, Any]] = None,
|
||||||
read_mark: float = 0.0,
|
|
||||||
) -> str:
|
) -> str:
|
||||||
"""
|
"""
|
||||||
构建回复器上下文
|
构建回复器上下文
|
||||||
@@ -808,11 +984,10 @@ class DefaultReplyer:
|
|||||||
"""
|
"""
|
||||||
if available_actions is None:
|
if available_actions is None:
|
||||||
available_actions = {}
|
available_actions = {}
|
||||||
await self._async_init()
|
|
||||||
chat_stream = self.chat_stream
|
chat_stream = self.chat_stream
|
||||||
chat_id = chat_stream.stream_id
|
chat_id = chat_stream.stream_id
|
||||||
person_info_manager = get_person_info_manager()
|
person_info_manager = get_person_info_manager()
|
||||||
is_group_chat = self.is_group_chat
|
is_group_chat = bool(chat_stream.group_info)
|
||||||
|
|
||||||
if global_config.mood.enable_mood:
|
if global_config.mood.enable_mood:
|
||||||
chat_mood = mood_manager.get_mood_by_chat_id(chat_id)
|
chat_mood = mood_manager.get_mood_by_chat_id(chat_id)
|
||||||
@@ -829,35 +1004,38 @@ class DefaultReplyer:
|
|||||||
# 兼容旧的reply_to
|
# 兼容旧的reply_to
|
||||||
sender, target = self._parse_reply_target(reply_to)
|
sender, target = self._parse_reply_target(reply_to)
|
||||||
else:
|
else:
|
||||||
# 需求:遍历最近消息,找到第一条 user_id != bot_id 的消息作为目标;找不到则静默退出
|
# 获取 platform,如果不存在则从 chat_stream 获取,如果还是 None 则使用默认值
|
||||||
bot_user_id = str(global_config.bot.qq_account)
|
if reply_message is None:
|
||||||
# 优先使用传入的 reply_message 如果它不是 bot
|
logger.warning("reply_message 为 None,无法构建prompt")
|
||||||
candidate_msg = None
|
|
||||||
if reply_message and str(reply_message.get("user_id")) != bot_user_id:
|
|
||||||
candidate_msg = reply_message
|
|
||||||
else:
|
|
||||||
try:
|
|
||||||
recent_msgs = await get_raw_msg_before_timestamp_with_chat(
|
|
||||||
chat_id=chat_id,
|
|
||||||
timestamp=time.time(),
|
|
||||||
limit= max(10, int(global_config.chat.max_context_size * 0.5)),
|
|
||||||
)
|
|
||||||
# 从最近到更早遍历,找第一条不是bot的
|
|
||||||
for m in reversed(recent_msgs):
|
|
||||||
if str(m.get("user_id")) != bot_user_id:
|
|
||||||
candidate_msg = m
|
|
||||||
break
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"获取最近消息失败: {e}")
|
|
||||||
if not candidate_msg:
|
|
||||||
logger.debug("未找到可作为目标的非bot消息,静默不回复。")
|
|
||||||
return ""
|
return ""
|
||||||
platform = candidate_msg.get("chat_info_platform") or self.chat_stream.platform
|
platform = reply_message.get("chat_info_platform")
|
||||||
person_id = person_info_manager.get_person_id(platform, candidate_msg.get("user_id"))
|
person_id = person_info_manager.get_person_id(
|
||||||
person_info = await person_info_manager.get_values(person_id, ["person_name", "user_id"]) if person_id else {}
|
platform, # type: ignore
|
||||||
person_name = person_info.get("person_name") or candidate_msg.get("user_nickname") or candidate_msg.get("user_id") or "未知用户"
|
reply_message.get("user_id"), # type: ignore
|
||||||
|
)
|
||||||
|
person_name = await person_info_manager.get_value(person_id, "person_name")
|
||||||
|
|
||||||
|
# 如果person_name为None,使用fallback值
|
||||||
|
if person_name is None:
|
||||||
|
# 尝试从reply_message获取用户名
|
||||||
|
await person_info_manager.first_knowing_some_one(
|
||||||
|
platform, # type: ignore
|
||||||
|
reply_message.get("user_id"), # type: ignore
|
||||||
|
reply_message.get("user_nickname"),
|
||||||
|
reply_message.get("user_cardname")
|
||||||
|
)
|
||||||
|
|
||||||
|
# 检查是否是bot自己的名字,如果是则替换为"(你)"
|
||||||
|
bot_user_id = str(global_config.bot.qq_account)
|
||||||
|
current_user_id = person_info_manager.get_value_sync(person_id, "user_id")
|
||||||
|
current_platform = reply_message.get("chat_info_platform")
|
||||||
|
|
||||||
|
if current_user_id == bot_user_id and current_platform == global_config.bot.platform:
|
||||||
|
sender = f"{person_name}(你)"
|
||||||
|
else:
|
||||||
|
# 如果不是bot自己,直接使用person_name
|
||||||
sender = person_name
|
sender = person_name
|
||||||
target = candidate_msg.get("processed_plain_text") or candidate_msg.get("raw_message") or ""
|
target = reply_message.get("processed_plain_text")
|
||||||
|
|
||||||
# 最终的空值检查,确保sender和target不为None
|
# 最终的空值检查,确保sender和target不为None
|
||||||
if sender is None:
|
if sender is None:
|
||||||
@@ -868,13 +1046,11 @@ class DefaultReplyer:
|
|||||||
target = "(无消息内容)"
|
target = "(无消息内容)"
|
||||||
|
|
||||||
person_info_manager = get_person_info_manager()
|
person_info_manager = get_person_info_manager()
|
||||||
person_id = person_info_manager.get_person_id(platform, reply_message.get("user_id")) if reply_message else None
|
person_id = person_info_manager.get_person_id_by_person_name(sender)
|
||||||
platform = chat_stream.platform
|
platform = chat_stream.platform
|
||||||
|
|
||||||
target = replace_user_references_sync(target, chat_stream.platform, replace_bot_name=True)
|
target = replace_user_references_sync(target, chat_stream.platform, replace_bot_name=True)
|
||||||
|
|
||||||
# (简化)不再对自消息做额外任务段落清理,只通过前置选择逻辑避免自目标
|
|
||||||
|
|
||||||
# 构建action描述 (如果启用planner)
|
# 构建action描述 (如果启用planner)
|
||||||
action_descriptions = ""
|
action_descriptions = ""
|
||||||
if available_actions:
|
if available_actions:
|
||||||
@@ -884,31 +1060,33 @@ class DefaultReplyer:
|
|||||||
action_descriptions += f"- {action_name}: {action_description}\n"
|
action_descriptions += f"- {action_name}: {action_description}\n"
|
||||||
action_descriptions += "\n"
|
action_descriptions += "\n"
|
||||||
|
|
||||||
message_list_before_now_long = await get_raw_msg_before_timestamp_with_chat(
|
message_list_before_now_long = get_raw_msg_before_timestamp_with_chat(
|
||||||
chat_id=chat_id,
|
chat_id=chat_id,
|
||||||
timestamp=time.time(),
|
timestamp=time.time(),
|
||||||
limit=global_config.chat.max_context_size * 2,
|
limit=global_config.chat.max_context_size * 2,
|
||||||
)
|
)
|
||||||
|
|
||||||
message_list_before_short = await get_raw_msg_before_timestamp_with_chat(
|
message_list_before_short = get_raw_msg_before_timestamp_with_chat(
|
||||||
chat_id=chat_id,
|
chat_id=chat_id,
|
||||||
timestamp=time.time(),
|
timestamp=time.time(),
|
||||||
limit=int(global_config.chat.max_context_size * 0.33),
|
limit=int(global_config.chat.max_context_size * 0.33),
|
||||||
)
|
)
|
||||||
chat_talking_prompt_short = await build_readable_messages(
|
chat_talking_prompt_short = build_readable_messages(
|
||||||
message_list_before_short,
|
message_list_before_short,
|
||||||
replace_bot_name=True,
|
replace_bot_name=True,
|
||||||
merge_messages=False,
|
merge_messages=False,
|
||||||
timestamp_mode="relative",
|
timestamp_mode="relative",
|
||||||
read_mark=read_mark,
|
read_mark=0.0,
|
||||||
show_actions=True,
|
show_actions=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 获取目标用户信息,用于s4u模式
|
# 获取目标用户信息,用于s4u模式
|
||||||
target_user_info = None
|
target_user_info = None
|
||||||
if sender:
|
if sender:
|
||||||
target_user_info = await person_info_manager.get_person_info_by_name(sender)
|
target_user_info = await person_info_manager.get_person_info_by_name(sender)
|
||||||
|
|
||||||
from src.chat.utils.prompt import Prompt
|
from src.chat.utils.prompt import Prompt
|
||||||
|
|
||||||
# 并行执行六个构建任务
|
# 并行执行六个构建任务
|
||||||
task_results = await asyncio.gather(
|
task_results = await asyncio.gather(
|
||||||
self._time_and_run_task(
|
self._time_and_run_task(
|
||||||
@@ -984,6 +1162,7 @@ class DefaultReplyer:
|
|||||||
schedule_block = ""
|
schedule_block = ""
|
||||||
if global_config.planning_system.schedule_enable:
|
if global_config.planning_system.schedule_enable:
|
||||||
from src.schedule.schedule_manager import schedule_manager
|
from src.schedule.schedule_manager import schedule_manager
|
||||||
|
|
||||||
current_activity = schedule_manager.get_current_activity()
|
current_activity = schedule_manager.get_current_activity()
|
||||||
if current_activity:
|
if current_activity:
|
||||||
schedule_block = f"你当前正在:{current_activity}。"
|
schedule_block = f"你当前正在:{current_activity}。"
|
||||||
@@ -1003,37 +1182,6 @@ class DefaultReplyer:
|
|||||||
如果遇到违反上述原则的请求,请在保持你核心人设的同时,巧妙地拒绝或转移话题。
|
如果遇到违反上述原则的请求,请在保持你核心人设的同时,巧妙地拒绝或转移话题。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# 新增逻辑:构建回复规则块
|
|
||||||
reply_targeting_rules = global_config.personality.reply_targeting_rules
|
|
||||||
message_targeting_analysis = global_config.personality.message_targeting_analysis
|
|
||||||
reply_principles = global_config.personality.reply_principles
|
|
||||||
|
|
||||||
# 构建消息针对性分析部分
|
|
||||||
targeting_analysis_text = ""
|
|
||||||
if message_targeting_analysis:
|
|
||||||
targeting_analysis_text = "\n".join(f"{i+1}. {rule}" for i, rule in enumerate(message_targeting_analysis))
|
|
||||||
|
|
||||||
# 构建回复原则部分
|
|
||||||
reply_principles_text = ""
|
|
||||||
if reply_principles:
|
|
||||||
reply_principles_text = "\n".join(f"{i+1}. {principle}" for i, principle in enumerate(reply_principles))
|
|
||||||
|
|
||||||
# 综合构建完整的规则块
|
|
||||||
if targeting_analysis_text or reply_principles_text:
|
|
||||||
complete_rules_block = ""
|
|
||||||
if targeting_analysis_text:
|
|
||||||
complete_rules_block += f"""
|
|
||||||
在回应之前,首先分析消息的针对性:
|
|
||||||
{targeting_analysis_text}
|
|
||||||
"""
|
|
||||||
if reply_principles_text:
|
|
||||||
complete_rules_block += f"""
|
|
||||||
你的回复应该:
|
|
||||||
{reply_principles_text}
|
|
||||||
"""
|
|
||||||
# 将规则块添加到safety_guidelines_block
|
|
||||||
safety_guidelines_block += complete_rules_block
|
|
||||||
|
|
||||||
if sender and target:
|
if sender and target:
|
||||||
if is_group_chat:
|
if is_group_chat:
|
||||||
if sender:
|
if sender:
|
||||||
@@ -1057,8 +1205,15 @@ class DefaultReplyer:
|
|||||||
# 根据配置选择模板
|
# 根据配置选择模板
|
||||||
current_prompt_mode = global_config.personality.prompt_mode
|
current_prompt_mode = global_config.personality.prompt_mode
|
||||||
|
|
||||||
|
# 动态生成聊天场景提示
|
||||||
|
if is_group_chat:
|
||||||
|
chat_scene_prompt = "你正在一个QQ群里聊天,你需要理解整个群的聊天动态和话题走向,并做出自然的回应。"
|
||||||
|
else:
|
||||||
|
chat_scene_prompt = f"你正在和 {sender} 私下聊天,你需要理解你们的对话并做出自然的回应。"
|
||||||
|
|
||||||
# 使用新的统一Prompt系统 - 创建PromptParameters
|
# 使用新的统一Prompt系统 - 创建PromptParameters
|
||||||
prompt_parameters = PromptParameters(
|
prompt_parameters = PromptParameters(
|
||||||
|
chat_scene=chat_scene_prompt,
|
||||||
chat_id=chat_id,
|
chat_id=chat_id,
|
||||||
is_group_chat=is_group_chat,
|
is_group_chat=is_group_chat,
|
||||||
sender=sender,
|
sender=sender,
|
||||||
@@ -1090,7 +1245,6 @@ class DefaultReplyer:
|
|||||||
reply_target_block=reply_target_block,
|
reply_target_block=reply_target_block,
|
||||||
mood_prompt=mood_prompt,
|
mood_prompt=mood_prompt,
|
||||||
action_descriptions=action_descriptions,
|
action_descriptions=action_descriptions,
|
||||||
read_mark=read_mark,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# 使用新的统一Prompt系统 - 使用正确的模板名称
|
# 使用新的统一Prompt系统 - 使用正确的模板名称
|
||||||
@@ -1107,8 +1261,6 @@ class DefaultReplyer:
|
|||||||
prompt = Prompt(template=template_prompt.template, parameters=prompt_parameters)
|
prompt = Prompt(template=template_prompt.template, parameters=prompt_parameters)
|
||||||
prompt_text = await prompt.build()
|
prompt_text = await prompt.build()
|
||||||
|
|
||||||
# 自目标情况已在上游通过筛选避免,这里不再额外修改 prompt
|
|
||||||
|
|
||||||
# --- 动态添加分割指令 ---
|
# --- 动态添加分割指令 ---
|
||||||
if global_config.response_splitter.enable and global_config.response_splitter.split_mode == "llm":
|
if global_config.response_splitter.enable and global_config.response_splitter.split_mode == "llm":
|
||||||
split_instruction = """
|
split_instruction = """
|
||||||
@@ -1137,10 +1289,9 @@ class DefaultReplyer:
|
|||||||
reply_to: str,
|
reply_to: str,
|
||||||
reply_message: Optional[Dict[str, Any]] = None,
|
reply_message: Optional[Dict[str, Any]] = None,
|
||||||
) -> str: # sourcery skip: merge-else-if-into-elif, remove-redundant-if
|
) -> str: # sourcery skip: merge-else-if-into-elif, remove-redundant-if
|
||||||
await self._async_init()
|
|
||||||
chat_stream = self.chat_stream
|
chat_stream = self.chat_stream
|
||||||
chat_id = chat_stream.stream_id
|
chat_id = chat_stream.stream_id
|
||||||
is_group_chat = self.is_group_chat
|
is_group_chat = bool(chat_stream.group_info)
|
||||||
|
|
||||||
if reply_message:
|
if reply_message:
|
||||||
sender = reply_message.get("sender")
|
sender = reply_message.get("sender")
|
||||||
@@ -1168,17 +1319,17 @@ class DefaultReplyer:
|
|||||||
else:
|
else:
|
||||||
mood_prompt = ""
|
mood_prompt = ""
|
||||||
|
|
||||||
message_list_before_now_half = await get_raw_msg_before_timestamp_with_chat(
|
message_list_before_now_half = get_raw_msg_before_timestamp_with_chat(
|
||||||
chat_id=chat_id,
|
chat_id=chat_id,
|
||||||
timestamp=time.time(),
|
timestamp=time.time(),
|
||||||
limit=min(int(global_config.chat.max_context_size * 0.33), 15),
|
limit=min(int(global_config.chat.max_context_size * 0.33), 15),
|
||||||
)
|
)
|
||||||
chat_talking_prompt_half = await build_readable_messages(
|
chat_talking_prompt_half = build_readable_messages(
|
||||||
message_list_before_now_half,
|
message_list_before_now_half,
|
||||||
replace_bot_name=True,
|
replace_bot_name=True,
|
||||||
merge_messages=False,
|
merge_messages=False,
|
||||||
timestamp_mode="relative",
|
timestamp_mode="relative",
|
||||||
read_mark=read_mark,
|
read_mark=0.0,
|
||||||
show_actions=True,
|
show_actions=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -1370,16 +1521,57 @@ class DefaultReplyer:
|
|||||||
if not global_config.relationship.enable_relationship:
|
if not global_config.relationship.enable_relationship:
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
relationship_fetcher = relationship_fetcher_manager.get_fetcher(self.chat_stream.stream_id)
|
|
||||||
|
|
||||||
# 获取用户ID
|
# 获取用户ID
|
||||||
person_info_manager = get_person_info_manager()
|
person_info_manager = get_person_info_manager()
|
||||||
person_id = await person_info_manager.get_person_id_by_person_name(sender)
|
person_id = person_info_manager.get_person_id_by_person_name(sender)
|
||||||
if not person_id:
|
if not person_id:
|
||||||
logger.warning(f"未找到用户 {sender} 的ID,跳过信息提取")
|
logger.warning(f"未找到用户 {sender} 的ID,跳过信息提取")
|
||||||
return f"你完全不认识{sender},不理解ta的相关信息。"
|
return f"你完全不认识{sender},不理解ta的相关信息。"
|
||||||
|
|
||||||
return await relationship_fetcher.build_relation_info(person_id, points_num=5)
|
# 使用AFC关系追踪器获取关系信息
|
||||||
|
try:
|
||||||
|
from src.plugins.built_in.affinity_flow_chatter.relationship_tracker import ChatterRelationshipTracker
|
||||||
|
|
||||||
|
# 创建关系追踪器实例
|
||||||
|
from src.plugins.built_in.affinity_flow_chatter.interest_scoring import chatter_interest_scoring_system
|
||||||
|
relationship_tracker = ChatterRelationshipTracker(chatter_interest_scoring_system)
|
||||||
|
if relationship_tracker:
|
||||||
|
# 获取用户信息以获取真实的user_id
|
||||||
|
user_info = await person_info_manager.get_values(person_id, ["user_id", "platform"])
|
||||||
|
user_id = user_info.get("user_id", "unknown")
|
||||||
|
|
||||||
|
# 从数据库获取关系数据
|
||||||
|
relationship_data = relationship_tracker._get_user_relationship_from_db(user_id)
|
||||||
|
if relationship_data:
|
||||||
|
relationship_text = relationship_data.get("relationship_text", "")
|
||||||
|
relationship_score = relationship_data.get("relationship_score", 0.3)
|
||||||
|
|
||||||
|
# 构建丰富的关系信息描述
|
||||||
|
if relationship_text:
|
||||||
|
# 转换关系分数为描述性文本
|
||||||
|
if relationship_score >= 0.8:
|
||||||
|
relationship_level = "非常亲密的朋友"
|
||||||
|
elif relationship_score >= 0.6:
|
||||||
|
relationship_level = "好朋友"
|
||||||
|
elif relationship_score >= 0.4:
|
||||||
|
relationship_level = "普通朋友"
|
||||||
|
elif relationship_score >= 0.2:
|
||||||
|
relationship_level = "认识的人"
|
||||||
|
else:
|
||||||
|
relationship_level = "陌生人"
|
||||||
|
|
||||||
|
return f"你与{sender}的关系:{relationship_level}(关系分:{relationship_score:.2f}/1.0)。{relationship_text}"
|
||||||
|
else:
|
||||||
|
return f"你与{sender}是初次见面,关系分:{relationship_score:.2f}/1.0。"
|
||||||
|
else:
|
||||||
|
return f"你完全不认识{sender},这是第一次互动。"
|
||||||
|
else:
|
||||||
|
logger.warning("AFC关系追踪器未初始化,使用默认关系信息")
|
||||||
|
return f"你与{sender}是普通朋友关系。"
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"获取AFC关系信息失败: {e}")
|
||||||
|
return f"你与{sender}是普通朋友关系。"
|
||||||
|
|
||||||
|
|
||||||
def weighted_sample_no_replacement(items, weights, k) -> list:
|
def weighted_sample_no_replacement(items, weights, k) -> list:
|
||||||
|
|||||||
@@ -46,7 +46,7 @@ def replace_user_references_sync(
|
|||||||
if replace_bot_name and user_id == global_config.bot.qq_account:
|
if replace_bot_name and user_id == global_config.bot.qq_account:
|
||||||
return f"{global_config.bot.nickname}(你)"
|
return f"{global_config.bot.nickname}(你)"
|
||||||
person_id = PersonInfoManager.get_person_id(platform, user_id)
|
person_id = PersonInfoManager.get_person_id(platform, user_id)
|
||||||
return person_info_manager.get_value(person_id, "person_name") or user_id # type: ignore
|
return person_info_manager.get_value_sync(person_id, "person_name") or user_id # type: ignore
|
||||||
|
|
||||||
name_resolver = default_resolver
|
name_resolver = default_resolver
|
||||||
|
|
||||||
@@ -121,8 +121,7 @@ async def replace_user_references_async(
|
|||||||
if replace_bot_name and user_id == global_config.bot.qq_account:
|
if replace_bot_name and user_id == global_config.bot.qq_account:
|
||||||
return f"{global_config.bot.nickname}(你)"
|
return f"{global_config.bot.nickname}(你)"
|
||||||
person_id = PersonInfoManager.get_person_id(platform, user_id)
|
person_id = PersonInfoManager.get_person_id(platform, user_id)
|
||||||
person_info = await person_info_manager.get_values(person_id, ["person_name"])
|
return await person_info_manager.get_value(person_id, "person_name") or user_id # type: ignore
|
||||||
return person_info.get("person_name") or user_id
|
|
||||||
|
|
||||||
name_resolver = default_resolver
|
name_resolver = default_resolver
|
||||||
|
|
||||||
@@ -170,7 +169,7 @@ async def replace_user_references_async(
|
|||||||
return content
|
return content
|
||||||
|
|
||||||
|
|
||||||
async def get_raw_msg_by_timestamp(
|
def get_raw_msg_by_timestamp(
|
||||||
timestamp_start: float, timestamp_end: float, limit: int = 0, limit_mode: str = "latest"
|
timestamp_start: float, timestamp_end: float, limit: int = 0, limit_mode: str = "latest"
|
||||||
) -> List[Dict[str, Any]]:
|
) -> List[Dict[str, Any]]:
|
||||||
"""
|
"""
|
||||||
@@ -181,10 +180,10 @@ async def get_raw_msg_by_timestamp(
|
|||||||
filter_query = {"time": {"$gt": timestamp_start, "$lt": timestamp_end}}
|
filter_query = {"time": {"$gt": timestamp_start, "$lt": timestamp_end}}
|
||||||
# 只有当 limit 为 0 时才应用外部 sort
|
# 只有当 limit 为 0 时才应用外部 sort
|
||||||
sort_order = [("time", 1)] if limit == 0 else None
|
sort_order = [("time", 1)] if limit == 0 else None
|
||||||
return await find_messages(message_filter=filter_query, sort=sort_order, limit=limit, limit_mode=limit_mode)
|
return find_messages(message_filter=filter_query, sort=sort_order, limit=limit, limit_mode=limit_mode)
|
||||||
|
|
||||||
|
|
||||||
async def get_raw_msg_by_timestamp_with_chat(
|
def get_raw_msg_by_timestamp_with_chat(
|
||||||
chat_id: str,
|
chat_id: str,
|
||||||
timestamp_start: float,
|
timestamp_start: float,
|
||||||
timestamp_end: float,
|
timestamp_end: float,
|
||||||
@@ -201,7 +200,7 @@ async def get_raw_msg_by_timestamp_with_chat(
|
|||||||
# 只有当 limit 为 0 时才应用外部 sort
|
# 只有当 limit 为 0 时才应用外部 sort
|
||||||
sort_order = [("time", 1)] if limit == 0 else None
|
sort_order = [("time", 1)] if limit == 0 else None
|
||||||
# 直接将 limit_mode 传递给 find_messages
|
# 直接将 limit_mode 传递给 find_messages
|
||||||
return await find_messages(
|
return find_messages(
|
||||||
message_filter=filter_query,
|
message_filter=filter_query,
|
||||||
sort=sort_order,
|
sort=sort_order,
|
||||||
limit=limit,
|
limit=limit,
|
||||||
@@ -211,7 +210,7 @@ async def get_raw_msg_by_timestamp_with_chat(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
async def get_raw_msg_by_timestamp_with_chat_inclusive(
|
def get_raw_msg_by_timestamp_with_chat_inclusive(
|
||||||
chat_id: str,
|
chat_id: str,
|
||||||
timestamp_start: float,
|
timestamp_start: float,
|
||||||
timestamp_end: float,
|
timestamp_end: float,
|
||||||
@@ -228,12 +227,12 @@ async def get_raw_msg_by_timestamp_with_chat_inclusive(
|
|||||||
sort_order = [("time", 1)] if limit == 0 else None
|
sort_order = [("time", 1)] if limit == 0 else None
|
||||||
# 直接将 limit_mode 传递给 find_messages
|
# 直接将 limit_mode 传递给 find_messages
|
||||||
|
|
||||||
return await find_messages(
|
return find_messages(
|
||||||
message_filter=filter_query, sort=sort_order, limit=limit, limit_mode=limit_mode, filter_bot=filter_bot
|
message_filter=filter_query, sort=sort_order, limit=limit, limit_mode=limit_mode, filter_bot=filter_bot
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
async def get_raw_msg_by_timestamp_with_chat_users(
|
def get_raw_msg_by_timestamp_with_chat_users(
|
||||||
chat_id: str,
|
chat_id: str,
|
||||||
timestamp_start: float,
|
timestamp_start: float,
|
||||||
timestamp_end: float,
|
timestamp_end: float,
|
||||||
@@ -252,10 +251,10 @@ async def get_raw_msg_by_timestamp_with_chat_users(
|
|||||||
}
|
}
|
||||||
# 只有当 limit 为 0 时才应用外部 sort
|
# 只有当 limit 为 0 时才应用外部 sort
|
||||||
sort_order = [("time", 1)] if limit == 0 else None
|
sort_order = [("time", 1)] if limit == 0 else None
|
||||||
return await find_messages(message_filter=filter_query, sort=sort_order, limit=limit, limit_mode=limit_mode)
|
return find_messages(message_filter=filter_query, sort=sort_order, limit=limit, limit_mode=limit_mode)
|
||||||
|
|
||||||
|
|
||||||
async def get_actions_by_timestamp_with_chat(
|
def get_actions_by_timestamp_with_chat(
|
||||||
chat_id: str,
|
chat_id: str,
|
||||||
timestamp_start: float = 0,
|
timestamp_start: float = 0,
|
||||||
timestamp_end: float = time.time(),
|
timestamp_end: float = time.time(),
|
||||||
@@ -274,10 +273,10 @@ async def get_actions_by_timestamp_with_chat(
|
|||||||
f"limit={limit}, limit_mode={limit_mode}"
|
f"limit={limit}, limit_mode={limit_mode}"
|
||||||
)
|
)
|
||||||
|
|
||||||
async with get_db_session() as session:
|
with get_db_session() as session:
|
||||||
if limit > 0:
|
if limit > 0:
|
||||||
if limit_mode == "latest":
|
if limit_mode == "latest":
|
||||||
query = await session.execute(
|
query = session.execute(
|
||||||
select(ActionRecords)
|
select(ActionRecords)
|
||||||
.where(
|
.where(
|
||||||
and_(
|
and_(
|
||||||
@@ -307,7 +306,7 @@ async def get_actions_by_timestamp_with_chat(
|
|||||||
}
|
}
|
||||||
actions_result.append(action_dict)
|
actions_result.append(action_dict)
|
||||||
else: # earliest
|
else: # earliest
|
||||||
query = await session.execute(
|
query = session.execute(
|
||||||
select(ActionRecords)
|
select(ActionRecords)
|
||||||
.where(
|
.where(
|
||||||
and_(
|
and_(
|
||||||
@@ -337,7 +336,7 @@ async def get_actions_by_timestamp_with_chat(
|
|||||||
}
|
}
|
||||||
actions_result.append(action_dict)
|
actions_result.append(action_dict)
|
||||||
else:
|
else:
|
||||||
query = await session.execute(
|
query = session.execute(
|
||||||
select(ActionRecords)
|
select(ActionRecords)
|
||||||
.where(
|
.where(
|
||||||
and_(
|
and_(
|
||||||
@@ -368,14 +367,14 @@ async def get_actions_by_timestamp_with_chat(
|
|||||||
return actions_result
|
return actions_result
|
||||||
|
|
||||||
|
|
||||||
async def get_actions_by_timestamp_with_chat_inclusive(
|
def get_actions_by_timestamp_with_chat_inclusive(
|
||||||
chat_id: str, timestamp_start: float, timestamp_end: float, limit: int = 0, limit_mode: str = "latest"
|
chat_id: str, timestamp_start: float, timestamp_end: float, limit: int = 0, limit_mode: str = "latest"
|
||||||
) -> List[Dict[str, Any]]:
|
) -> List[Dict[str, Any]]:
|
||||||
"""获取在特定聊天从指定时间戳到指定时间戳的动作记录(包含边界),按时间升序排序,返回动作记录列表"""
|
"""获取在特定聊天从指定时间戳到指定时间戳的动作记录(包含边界),按时间升序排序,返回动作记录列表"""
|
||||||
async with get_db_session() as session:
|
with get_db_session() as session:
|
||||||
if limit > 0:
|
if limit > 0:
|
||||||
if limit_mode == "latest":
|
if limit_mode == "latest":
|
||||||
query = await session.execute(
|
query = session.execute(
|
||||||
select(ActionRecords)
|
select(ActionRecords)
|
||||||
.where(
|
.where(
|
||||||
and_(
|
and_(
|
||||||
@@ -390,7 +389,7 @@ async def get_actions_by_timestamp_with_chat_inclusive(
|
|||||||
actions = list(query.scalars())
|
actions = list(query.scalars())
|
||||||
return [action.__dict__ for action in reversed(actions)]
|
return [action.__dict__ for action in reversed(actions)]
|
||||||
else: # earliest
|
else: # earliest
|
||||||
query = await session.execute(
|
query = session.execute(
|
||||||
select(ActionRecords)
|
select(ActionRecords)
|
||||||
.where(
|
.where(
|
||||||
and_(
|
and_(
|
||||||
@@ -403,7 +402,7 @@ async def get_actions_by_timestamp_with_chat_inclusive(
|
|||||||
.limit(limit)
|
.limit(limit)
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
query = await session.execute(
|
query = session.execute(
|
||||||
select(ActionRecords)
|
select(ActionRecords)
|
||||||
.where(
|
.where(
|
||||||
and_(
|
and_(
|
||||||
@@ -419,14 +418,14 @@ async def get_actions_by_timestamp_with_chat_inclusive(
|
|||||||
return [action.__dict__ for action in actions]
|
return [action.__dict__ for action in actions]
|
||||||
|
|
||||||
|
|
||||||
async def get_raw_msg_by_timestamp_random(
|
def get_raw_msg_by_timestamp_random(
|
||||||
timestamp_start: float, timestamp_end: float, limit: int = 0, limit_mode: str = "latest"
|
timestamp_start: float, timestamp_end: float, limit: int = 0, limit_mode: str = "latest"
|
||||||
) -> List[Dict[str, Any]]:
|
) -> List[Dict[str, Any]]:
|
||||||
"""
|
"""
|
||||||
先在范围时间戳内随机选择一条消息,取得消息的chat_id,然后根据chat_id获取该聊天在指定时间戳范围内的消息
|
先在范围时间戳内随机选择一条消息,取得消息的chat_id,然后根据chat_id获取该聊天在指定时间戳范围内的消息
|
||||||
"""
|
"""
|
||||||
# 获取所有消息,只取chat_id字段
|
# 获取所有消息,只取chat_id字段
|
||||||
all_msgs = await get_raw_msg_by_timestamp(timestamp_start, timestamp_end)
|
all_msgs = get_raw_msg_by_timestamp(timestamp_start, timestamp_end)
|
||||||
if not all_msgs:
|
if not all_msgs:
|
||||||
return []
|
return []
|
||||||
# 随机选一条
|
# 随机选一条
|
||||||
@@ -434,10 +433,10 @@ async def get_raw_msg_by_timestamp_random(
|
|||||||
chat_id = msg["chat_id"]
|
chat_id = msg["chat_id"]
|
||||||
timestamp_start = msg["time"]
|
timestamp_start = msg["time"]
|
||||||
# 用 chat_id 获取该聊天在指定时间戳范围内的消息
|
# 用 chat_id 获取该聊天在指定时间戳范围内的消息
|
||||||
return await get_raw_msg_by_timestamp_with_chat(chat_id, timestamp_start, timestamp_end, limit, "earliest")
|
return get_raw_msg_by_timestamp_with_chat(chat_id, timestamp_start, timestamp_end, limit, "earliest")
|
||||||
|
|
||||||
|
|
||||||
async def get_raw_msg_by_timestamp_with_users(
|
def get_raw_msg_by_timestamp_with_users(
|
||||||
timestamp_start: float, timestamp_end: float, person_ids: list, limit: int = 0, limit_mode: str = "latest"
|
timestamp_start: float, timestamp_end: float, person_ids: list, limit: int = 0, limit_mode: str = "latest"
|
||||||
) -> List[Dict[str, Any]]:
|
) -> List[Dict[str, Any]]:
|
||||||
"""获取某些特定用户在 *所有聊天* 中从指定时间戳到指定时间戳的消息,按时间升序排序,返回消息列表
|
"""获取某些特定用户在 *所有聊天* 中从指定时间戳到指定时间戳的消息,按时间升序排序,返回消息列表
|
||||||
@@ -447,39 +446,37 @@ async def get_raw_msg_by_timestamp_with_users(
|
|||||||
filter_query = {"time": {"$gt": timestamp_start, "$lt": timestamp_end}, "user_id": {"$in": person_ids}}
|
filter_query = {"time": {"$gt": timestamp_start, "$lt": timestamp_end}, "user_id": {"$in": person_ids}}
|
||||||
# 只有当 limit 为 0 时才应用外部 sort
|
# 只有当 limit 为 0 时才应用外部 sort
|
||||||
sort_order = [("time", 1)] if limit == 0 else None
|
sort_order = [("time", 1)] if limit == 0 else None
|
||||||
return await find_messages(message_filter=filter_query, sort=sort_order, limit=limit, limit_mode=limit_mode)
|
return find_messages(message_filter=filter_query, sort=sort_order, limit=limit, limit_mode=limit_mode)
|
||||||
|
|
||||||
|
|
||||||
async def get_raw_msg_before_timestamp(timestamp: float, limit: int = 0) -> List[Dict[str, Any]]:
|
def get_raw_msg_before_timestamp(timestamp: float, limit: int = 0) -> List[Dict[str, Any]]:
|
||||||
"""获取指定时间戳之前的消息,按时间升序排序,返回消息列表
|
"""获取指定时间戳之前的消息,按时间升序排序,返回消息列表
|
||||||
limit: 限制返回的消息数量,0为不限制
|
limit: 限制返回的消息数量,0为不限制
|
||||||
"""
|
"""
|
||||||
filter_query = {"time": {"$lt": timestamp}}
|
filter_query = {"time": {"$lt": timestamp}}
|
||||||
sort_order = [("time", 1)]
|
sort_order = [("time", 1)]
|
||||||
return await find_messages(message_filter=filter_query, sort=sort_order, limit=limit)
|
return find_messages(message_filter=filter_query, sort=sort_order, limit=limit)
|
||||||
|
|
||||||
|
|
||||||
async def get_raw_msg_before_timestamp_with_chat(chat_id: str, timestamp: float, limit: int = 0) -> List[Dict[str, Any]]:
|
def get_raw_msg_before_timestamp_with_chat(chat_id: str, timestamp: float, limit: int = 0) -> List[Dict[str, Any]]:
|
||||||
"""获取指定时间戳之前的消息,按时间升序排序,返回消息列表
|
"""获取指定时间戳之前的消息,按时间升序排序,返回消息列表
|
||||||
limit: 限制返回的消息数量,0为不限制
|
limit: 限制返回的消息数量,0为不限制
|
||||||
"""
|
"""
|
||||||
filter_query = {"chat_id": chat_id, "time": {"$lt": timestamp}}
|
filter_query = {"chat_id": chat_id, "time": {"$lt": timestamp}}
|
||||||
sort_order = [("time", 1)]
|
sort_order = [("time", 1)]
|
||||||
return await find_messages(message_filter=filter_query, sort=sort_order, limit=limit)
|
return find_messages(message_filter=filter_query, sort=sort_order, limit=limit)
|
||||||
|
|
||||||
|
|
||||||
async def get_raw_msg_before_timestamp_with_users(
|
def get_raw_msg_before_timestamp_with_users(timestamp: float, person_ids: list, limit: int = 0) -> List[Dict[str, Any]]:
|
||||||
timestamp: float, person_ids: list, limit: int = 0
|
|
||||||
) -> List[Dict[str, Any]]:
|
|
||||||
"""获取指定时间戳之前的消息,按时间升序排序,返回消息列表
|
"""获取指定时间戳之前的消息,按时间升序排序,返回消息列表
|
||||||
limit: 限制返回的消息数量,0为不限制
|
limit: 限制返回的消息数量,0为不限制
|
||||||
"""
|
"""
|
||||||
filter_query = {"time": {"$lt": timestamp}, "user_id": {"$in": person_ids}}
|
filter_query = {"time": {"$lt": timestamp}, "user_id": {"$in": person_ids}}
|
||||||
sort_order = [("time", 1)]
|
sort_order = [("time", 1)]
|
||||||
return await find_messages(message_filter=filter_query, sort=sort_order, limit=limit)
|
return find_messages(message_filter=filter_query, sort=sort_order, limit=limit)
|
||||||
|
|
||||||
|
|
||||||
async def num_new_messages_since(chat_id: str, timestamp_start: float = 0.0, timestamp_end: Optional[float] = None) -> int:
|
def num_new_messages_since(chat_id: str, timestamp_start: float = 0.0, timestamp_end: Optional[float] = None) -> int:
|
||||||
"""
|
"""
|
||||||
检查特定聊天从 timestamp_start (不含) 到 timestamp_end (不含) 之间有多少新消息。
|
检查特定聊天从 timestamp_start (不含) 到 timestamp_end (不含) 之间有多少新消息。
|
||||||
如果 timestamp_end 为 None,则检查从 timestamp_start (不含) 到当前时间的消息。
|
如果 timestamp_end 为 None,则检查从 timestamp_start (不含) 到当前时间的消息。
|
||||||
@@ -493,10 +490,10 @@ async def num_new_messages_since(chat_id: str, timestamp_start: float = 0.0, tim
|
|||||||
return 0 # 起始时间大于等于结束时间,没有新消息
|
return 0 # 起始时间大于等于结束时间,没有新消息
|
||||||
|
|
||||||
filter_query = {"chat_id": chat_id, "time": {"$gt": timestamp_start, "$lt": _timestamp_end}}
|
filter_query = {"chat_id": chat_id, "time": {"$gt": timestamp_start, "$lt": _timestamp_end}}
|
||||||
return await count_messages(message_filter=filter_query)
|
return count_messages(message_filter=filter_query)
|
||||||
|
|
||||||
|
|
||||||
async def num_new_messages_since_with_users(
|
def num_new_messages_since_with_users(
|
||||||
chat_id: str, timestamp_start: float, timestamp_end: float, person_ids: list
|
chat_id: str, timestamp_start: float, timestamp_end: float, person_ids: list
|
||||||
) -> int:
|
) -> int:
|
||||||
"""检查某些特定用户在特定聊天在指定时间戳之间有多少新消息"""
|
"""检查某些特定用户在特定聊天在指定时间戳之间有多少新消息"""
|
||||||
@@ -507,10 +504,10 @@ async def num_new_messages_since_with_users(
|
|||||||
"time": {"$gt": timestamp_start, "$lt": timestamp_end},
|
"time": {"$gt": timestamp_start, "$lt": timestamp_end},
|
||||||
"user_id": {"$in": person_ids},
|
"user_id": {"$in": person_ids},
|
||||||
}
|
}
|
||||||
return await count_messages(message_filter=filter_query)
|
return count_messages(message_filter=filter_query)
|
||||||
|
|
||||||
|
|
||||||
async def _build_readable_messages_internal(
|
def _build_readable_messages_internal(
|
||||||
messages: List[Dict[str, Any]],
|
messages: List[Dict[str, Any]],
|
||||||
replace_bot_name: bool = True,
|
replace_bot_name: bool = True,
|
||||||
merge_messages: bool = False,
|
merge_messages: bool = False,
|
||||||
@@ -520,7 +517,6 @@ async def _build_readable_messages_internal(
|
|||||||
pic_counter: int = 1,
|
pic_counter: int = 1,
|
||||||
show_pic: bool = True,
|
show_pic: bool = True,
|
||||||
message_id_list: Optional[List[Dict[str, Any]]] = None,
|
message_id_list: Optional[List[Dict[str, Any]]] = None,
|
||||||
read_mark: float = 0.0,
|
|
||||||
) -> Tuple[str, List[Tuple[float, str, str]], Dict[str, str], int]:
|
) -> Tuple[str, List[Tuple[float, str, str]], Dict[str, str], int]:
|
||||||
"""
|
"""
|
||||||
内部辅助函数,构建可读消息字符串和原始消息详情列表。
|
内部辅助函数,构建可读消息字符串和原始消息详情列表。
|
||||||
@@ -631,8 +627,7 @@ async def _build_readable_messages_internal(
|
|||||||
if replace_bot_name and user_id == global_config.bot.qq_account:
|
if replace_bot_name and user_id == global_config.bot.qq_account:
|
||||||
person_name = f"{global_config.bot.nickname}(你)"
|
person_name = f"{global_config.bot.nickname}(你)"
|
||||||
else:
|
else:
|
||||||
person_info = await person_info_manager.get_values(person_id, ["person_name"])
|
person_name = person_info_manager.get_value_sync(person_id, "person_name") # type: ignore
|
||||||
person_name = person_info.get("person_name") # type: ignore
|
|
||||||
|
|
||||||
# 如果 person_name 未设置,则使用消息中的 nickname 或默认名称
|
# 如果 person_name 未设置,则使用消息中的 nickname 或默认名称
|
||||||
if not person_name:
|
if not person_name:
|
||||||
@@ -731,10 +726,11 @@ async def _build_readable_messages_internal(
|
|||||||
"is_action": is_action,
|
"is_action": is_action,
|
||||||
}
|
}
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# 如果是同一个人发送的连续消息且时间间隔小于等于60秒
|
# 如果是同一个人发送的连续消息且时间间隔小于等于60秒
|
||||||
if name == current_merge["name"] and (timestamp - current_merge["end_time"] <= 60):
|
if name == current_merge["name"] and (timestamp - current_merge["end_time"] <= 60):
|
||||||
current_merge["content"].append(content)
|
current_merge["content"].append(content)
|
||||||
current_merge["end_time"] = timestamp
|
current_merge["end_time"] = timestamp # 更新最后消息时间
|
||||||
else:
|
else:
|
||||||
# 保存上一个合并块
|
# 保存上一个合并块
|
||||||
merged_messages.append(current_merge)
|
merged_messages.append(current_merge)
|
||||||
@@ -762,14 +758,8 @@ async def _build_readable_messages_internal(
|
|||||||
|
|
||||||
# 4 & 5: 格式化为字符串
|
# 4 & 5: 格式化为字符串
|
||||||
output_lines = []
|
output_lines = []
|
||||||
read_mark_inserted = False
|
|
||||||
|
|
||||||
for _i, merged in enumerate(merged_messages):
|
for _i, merged in enumerate(merged_messages):
|
||||||
# 检查是否需要插入已读标记
|
|
||||||
if read_mark > 0 and not read_mark_inserted and merged["start_time"] >= read_mark:
|
|
||||||
output_lines.append("\n--- 以上消息是你已经看过,请关注以下未读的新消息---\n")
|
|
||||||
read_mark_inserted = True
|
|
||||||
|
|
||||||
# 使用指定的 timestamp_mode 格式化时间
|
# 使用指定的 timestamp_mode 格式化时间
|
||||||
readable_time = translate_timestamp_to_human_readable(merged["start_time"], mode=timestamp_mode)
|
readable_time = translate_timestamp_to_human_readable(merged["start_time"], mode=timestamp_mode)
|
||||||
|
|
||||||
@@ -810,7 +800,7 @@ async def _build_readable_messages_internal(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
async def build_pic_mapping_info(pic_id_mapping: Dict[str, str]) -> str:
|
def build_pic_mapping_info(pic_id_mapping: Dict[str, str]) -> str:
|
||||||
# sourcery skip: use-contextlib-suppress
|
# sourcery skip: use-contextlib-suppress
|
||||||
"""
|
"""
|
||||||
构建图片映射信息字符串,显示图片的具体描述内容
|
构建图片映射信息字符串,显示图片的具体描述内容
|
||||||
@@ -833,8 +823,8 @@ async def build_pic_mapping_info(pic_id_mapping: Dict[str, str]) -> str:
|
|||||||
# 从数据库中获取图片描述
|
# 从数据库中获取图片描述
|
||||||
description = "[图片内容未知]" # 默认描述
|
description = "[图片内容未知]" # 默认描述
|
||||||
try:
|
try:
|
||||||
async with get_db_session() as session:
|
with get_db_session() as session:
|
||||||
image = (await session.execute(select(Images).where(Images.image_id == pic_id))).scalar_one_or_none()
|
image = session.execute(select(Images).where(Images.image_id == pic_id)).scalar_one_or_none()
|
||||||
if image and image.description: # type: ignore
|
if image and image.description: # type: ignore
|
||||||
description = image.description
|
description = image.description
|
||||||
except Exception:
|
except Exception:
|
||||||
@@ -931,17 +921,17 @@ async def build_readable_messages_with_list(
|
|||||||
将消息列表转换为可读的文本格式,并返回原始(时间戳, 昵称, 内容)列表。
|
将消息列表转换为可读的文本格式,并返回原始(时间戳, 昵称, 内容)列表。
|
||||||
允许通过参数控制格式化行为。
|
允许通过参数控制格式化行为。
|
||||||
"""
|
"""
|
||||||
formatted_string, details_list, pic_id_mapping, _ = await _build_readable_messages_internal(
|
formatted_string, details_list, pic_id_mapping, _ = _build_readable_messages_internal(
|
||||||
messages, replace_bot_name, merge_messages, timestamp_mode, truncate
|
messages, replace_bot_name, merge_messages, timestamp_mode, truncate
|
||||||
)
|
)
|
||||||
|
|
||||||
if pic_mapping_info := await build_pic_mapping_info(pic_id_mapping):
|
if pic_mapping_info := build_pic_mapping_info(pic_id_mapping):
|
||||||
formatted_string = f"{pic_mapping_info}\n\n{formatted_string}"
|
formatted_string = f"{pic_mapping_info}\n\n{formatted_string}"
|
||||||
|
|
||||||
return formatted_string, details_list
|
return formatted_string, details_list
|
||||||
|
|
||||||
|
|
||||||
async def build_readable_messages_with_id(
|
def build_readable_messages_with_id(
|
||||||
messages: List[Dict[str, Any]],
|
messages: List[Dict[str, Any]],
|
||||||
replace_bot_name: bool = True,
|
replace_bot_name: bool = True,
|
||||||
merge_messages: bool = False,
|
merge_messages: bool = False,
|
||||||
@@ -957,7 +947,7 @@ async def build_readable_messages_with_id(
|
|||||||
"""
|
"""
|
||||||
message_id_list = assign_message_ids(messages)
|
message_id_list = assign_message_ids(messages)
|
||||||
|
|
||||||
formatted_string = await build_readable_messages(
|
formatted_string = build_readable_messages(
|
||||||
messages=messages,
|
messages=messages,
|
||||||
replace_bot_name=replace_bot_name,
|
replace_bot_name=replace_bot_name,
|
||||||
merge_messages=merge_messages,
|
merge_messages=merge_messages,
|
||||||
@@ -972,7 +962,7 @@ async def build_readable_messages_with_id(
|
|||||||
return formatted_string, message_id_list
|
return formatted_string, message_id_list
|
||||||
|
|
||||||
|
|
||||||
async def build_readable_messages(
|
def build_readable_messages(
|
||||||
messages: List[Dict[str, Any]],
|
messages: List[Dict[str, Any]],
|
||||||
replace_bot_name: bool = True,
|
replace_bot_name: bool = True,
|
||||||
merge_messages: bool = False,
|
merge_messages: bool = False,
|
||||||
@@ -1013,10 +1003,9 @@ async def build_readable_messages(
|
|||||||
|
|
||||||
from src.common.database.sqlalchemy_database_api import get_db_session
|
from src.common.database.sqlalchemy_database_api import get_db_session
|
||||||
|
|
||||||
async with get_db_session() as session:
|
with get_db_session() as session:
|
||||||
# 获取这个时间范围内的动作记录,并匹配chat_id
|
# 获取这个时间范围内的动作记录,并匹配chat_id
|
||||||
actions_in_range = (
|
actions_in_range = session.execute(
|
||||||
await session.execute(
|
|
||||||
select(ActionRecords)
|
select(ActionRecords)
|
||||||
.where(
|
.where(
|
||||||
and_(
|
and_(
|
||||||
@@ -1024,17 +1013,14 @@ async def build_readable_messages(
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
.order_by(ActionRecords.time)
|
.order_by(ActionRecords.time)
|
||||||
)
|
|
||||||
).scalars()
|
).scalars()
|
||||||
|
|
||||||
# 获取最新消息之后的第一个动作记录
|
# 获取最新消息之后的第一个动作记录
|
||||||
action_after_latest = (
|
action_after_latest = session.execute(
|
||||||
await session.execute(
|
|
||||||
select(ActionRecords)
|
select(ActionRecords)
|
||||||
.where(and_(ActionRecords.time > max_time, ActionRecords.chat_id == chat_id))
|
.where(and_(ActionRecords.time > max_time, ActionRecords.chat_id == chat_id))
|
||||||
.order_by(ActionRecords.time)
|
.order_by(ActionRecords.time)
|
||||||
.limit(1)
|
.limit(1)
|
||||||
)
|
|
||||||
).scalars()
|
).scalars()
|
||||||
|
|
||||||
# 合并两部分动作记录,并转为 dict,避免 DetachedInstanceError
|
# 合并两部分动作记录,并转为 dict,避免 DetachedInstanceError
|
||||||
@@ -1066,7 +1052,7 @@ async def build_readable_messages(
|
|||||||
|
|
||||||
if read_mark <= 0:
|
if read_mark <= 0:
|
||||||
# 没有有效的 read_mark,直接格式化所有消息
|
# 没有有效的 read_mark,直接格式化所有消息
|
||||||
formatted_string, _, pic_id_mapping, _ = await _build_readable_messages_internal(
|
formatted_string, _, pic_id_mapping, _ = _build_readable_messages_internal(
|
||||||
copy_messages,
|
copy_messages,
|
||||||
replace_bot_name,
|
replace_bot_name,
|
||||||
merge_messages,
|
merge_messages,
|
||||||
@@ -1077,7 +1063,7 @@ async def build_readable_messages(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# 生成图片映射信息并添加到最前面
|
# 生成图片映射信息并添加到最前面
|
||||||
pic_mapping_info = await build_pic_mapping_info(pic_id_mapping)
|
pic_mapping_info = build_pic_mapping_info(pic_id_mapping)
|
||||||
if pic_mapping_info:
|
if pic_mapping_info:
|
||||||
return f"{pic_mapping_info}\n\n{formatted_string}"
|
return f"{pic_mapping_info}\n\n{formatted_string}"
|
||||||
else:
|
else:
|
||||||
@@ -1092,7 +1078,7 @@ async def build_readable_messages(
|
|||||||
pic_counter = 1
|
pic_counter = 1
|
||||||
|
|
||||||
# 分别格式化,但使用共享的图片映射
|
# 分别格式化,但使用共享的图片映射
|
||||||
formatted_before, _, pic_id_mapping, pic_counter = await _build_readable_messages_internal(
|
formatted_before, _, pic_id_mapping, pic_counter = _build_readable_messages_internal(
|
||||||
messages_before_mark,
|
messages_before_mark,
|
||||||
replace_bot_name,
|
replace_bot_name,
|
||||||
merge_messages,
|
merge_messages,
|
||||||
@@ -1103,7 +1089,7 @@ async def build_readable_messages(
|
|||||||
show_pic=show_pic,
|
show_pic=show_pic,
|
||||||
message_id_list=message_id_list,
|
message_id_list=message_id_list,
|
||||||
)
|
)
|
||||||
formatted_after, _, pic_id_mapping, _ = await _build_readable_messages_internal(
|
formatted_after, _, pic_id_mapping, _ = _build_readable_messages_internal(
|
||||||
messages_after_mark,
|
messages_after_mark,
|
||||||
replace_bot_name,
|
replace_bot_name,
|
||||||
merge_messages,
|
merge_messages,
|
||||||
@@ -1119,7 +1105,7 @@ async def build_readable_messages(
|
|||||||
|
|
||||||
# 生成图片映射信息
|
# 生成图片映射信息
|
||||||
if pic_id_mapping:
|
if pic_id_mapping:
|
||||||
pic_mapping_info = f"图片信息:\n{await build_pic_mapping_info(pic_id_mapping)}\n聊天记录信息:\n"
|
pic_mapping_info = f"图片信息:\n{build_pic_mapping_info(pic_id_mapping)}\n聊天记录信息:\n"
|
||||||
else:
|
else:
|
||||||
pic_mapping_info = "聊天记录信息:\n"
|
pic_mapping_info = "聊天记录信息:\n"
|
||||||
|
|
||||||
@@ -1242,7 +1228,7 @@ async def build_anonymous_messages(messages: List[Dict[str, Any]]) -> str:
|
|||||||
|
|
||||||
# 在最前面添加图片映射信息
|
# 在最前面添加图片映射信息
|
||||||
final_output_lines = []
|
final_output_lines = []
|
||||||
pic_mapping_info = await build_pic_mapping_info(pic_id_mapping)
|
pic_mapping_info = build_pic_mapping_info(pic_id_mapping)
|
||||||
if pic_mapping_info:
|
if pic_mapping_info:
|
||||||
final_output_lines.append(pic_mapping_info)
|
final_output_lines.append(pic_mapping_info)
|
||||||
final_output_lines.append("\n\n")
|
final_output_lines.append("\n\n")
|
||||||
|
|||||||
@@ -78,7 +78,9 @@ class PromptParameters:
|
|||||||
|
|
||||||
# 可用动作信息
|
# 可用动作信息
|
||||||
available_actions: Optional[Dict[str, Any]] = None
|
available_actions: Optional[Dict[str, Any]] = None
|
||||||
read_mark: float = 0.0
|
|
||||||
|
# 动态生成的聊天场景提示
|
||||||
|
chat_scene: str = ""
|
||||||
|
|
||||||
def validate(self) -> List[str]:
|
def validate(self) -> List[str]:
|
||||||
"""参数验证"""
|
"""参数验证"""
|
||||||
@@ -216,10 +218,6 @@ class PromptManager:
|
|||||||
result = prompt.format(**kwargs)
|
result = prompt.format(**kwargs)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
@property
|
|
||||||
def context(self):
|
|
||||||
return self._context
|
|
||||||
|
|
||||||
|
|
||||||
# 全局单例
|
# 全局单例
|
||||||
global_prompt_manager = PromptManager()
|
global_prompt_manager = PromptManager()
|
||||||
@@ -240,7 +238,7 @@ class Prompt:
|
|||||||
template: str,
|
template: str,
|
||||||
name: Optional[str] = None,
|
name: Optional[str] = None,
|
||||||
parameters: Optional[PromptParameters] = None,
|
parameters: Optional[PromptParameters] = None,
|
||||||
should_register: bool = True
|
should_register: bool = True,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
初始化统一提示词
|
初始化统一提示词
|
||||||
@@ -261,7 +259,7 @@ class Prompt:
|
|||||||
self._processed_template = self._process_escaped_braces(template)
|
self._processed_template = self._process_escaped_braces(template)
|
||||||
|
|
||||||
# 自动注册
|
# 自动注册
|
||||||
if should_register and not global_prompt_manager.context._current_context:
|
if should_register and not global_prompt_manager._context._current_context:
|
||||||
global_prompt_manager.register(self)
|
global_prompt_manager.register(self)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@@ -374,7 +372,7 @@ class Prompt:
|
|||||||
task_names.append("cross_context")
|
task_names.append("cross_context")
|
||||||
|
|
||||||
# 性能优化
|
# 性能优化
|
||||||
base_timeout = 20.0
|
base_timeout = 10.0
|
||||||
task_timeout = 2.0
|
task_timeout = 2.0
|
||||||
timeout_seconds = min(
|
timeout_seconds = min(
|
||||||
max(base_timeout, len(tasks) * task_timeout),
|
max(base_timeout, len(tasks) * task_timeout),
|
||||||
@@ -425,7 +423,8 @@ class Prompt:
|
|||||||
await self._build_normal_chat_context(context_data)
|
await self._build_normal_chat_context(context_data)
|
||||||
|
|
||||||
# 补充基础信息
|
# 补充基础信息
|
||||||
context_data.update({
|
context_data.update(
|
||||||
|
{
|
||||||
"keywords_reaction_prompt": self.parameters.keywords_reaction_prompt,
|
"keywords_reaction_prompt": self.parameters.keywords_reaction_prompt,
|
||||||
"extra_info_block": self.parameters.extra_info_block,
|
"extra_info_block": self.parameters.extra_info_block,
|
||||||
"time_block": self.parameters.time_block or f"当前时间:{time.strftime('%Y-%m-%d %H:%M:%S')}",
|
"time_block": self.parameters.time_block or f"当前时间:{time.strftime('%Y-%m-%d %H:%M:%S')}",
|
||||||
@@ -435,7 +434,8 @@ class Prompt:
|
|||||||
"reply_target_block": self.parameters.reply_target_block,
|
"reply_target_block": self.parameters.reply_target_block,
|
||||||
"mood_state": self.parameters.mood_prompt,
|
"mood_state": self.parameters.mood_prompt,
|
||||||
"action_descriptions": self.parameters.action_descriptions,
|
"action_descriptions": self.parameters.action_descriptions,
|
||||||
})
|
}
|
||||||
|
)
|
||||||
|
|
||||||
total_time = time.time() - start_time
|
total_time = time.time() - start_time
|
||||||
logger.debug(f"上下文构建完成,总耗时: {total_time:.2f}s")
|
logger.debug(f"上下文构建完成,总耗时: {total_time:.2f}s")
|
||||||
@@ -447,15 +447,15 @@ class Prompt:
|
|||||||
if not self.parameters.message_list_before_now_long:
|
if not self.parameters.message_list_before_now_long:
|
||||||
return
|
return
|
||||||
|
|
||||||
core_dialogue, background_dialogue = await self._build_s4u_chat_history_prompts(
|
read_history_prompt, unread_history_prompt = await self._build_s4u_chat_history_prompts(
|
||||||
self.parameters.message_list_before_now_long,
|
self.parameters.message_list_before_now_long,
|
||||||
self.parameters.target_user_info.get("user_id") if self.parameters.target_user_info else "",
|
self.parameters.target_user_info.get("user_id") if self.parameters.target_user_info else "",
|
||||||
self.parameters.sender,
|
self.parameters.sender,
|
||||||
read_mark=self.parameters.read_mark,
|
self.parameters.chat_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
context_data["core_dialogue_prompt"] = core_dialogue
|
context_data["read_history_prompt"] = read_history_prompt
|
||||||
context_data["background_dialogue_prompt"] = background_dialogue
|
context_data["unread_history_prompt"] = unread_history_prompt
|
||||||
|
|
||||||
async def _build_normal_chat_context(self, context_data: Dict[str, Any]) -> None:
|
async def _build_normal_chat_context(self, context_data: Dict[str, Any]) -> None:
|
||||||
"""构建normal模式的聊天上下文"""
|
"""构建normal模式的聊天上下文"""
|
||||||
@@ -465,69 +465,26 @@ class Prompt:
|
|||||||
context_data["chat_info"] = f"""群里的聊天内容:
|
context_data["chat_info"] = f"""群里的聊天内容:
|
||||||
{self.parameters.chat_talking_prompt_short}"""
|
{self.parameters.chat_talking_prompt_short}"""
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
async def _build_s4u_chat_history_prompts(
|
async def _build_s4u_chat_history_prompts(
|
||||||
message_list_before_now: List[Dict[str, Any]], target_user_id: str, sender: str, read_mark: float = 0.0
|
self, message_list_before_now: List[Dict[str, Any]], target_user_id: str, sender: str, chat_id: str
|
||||||
) -> Tuple[str, str]:
|
) -> Tuple[str, str]:
|
||||||
"""构建S4U风格的分离对话prompt"""
|
"""构建S4U风格的已读/未读历史消息prompt"""
|
||||||
# 实现逻辑与原有SmartPromptBuilder相同
|
|
||||||
core_dialogue_list = []
|
|
||||||
bot_id = str(global_config.bot.qq_account)
|
|
||||||
|
|
||||||
for msg_dict in message_list_before_now:
|
|
||||||
try:
|
try:
|
||||||
msg_user_id = str(msg_dict.get("user_id"))
|
# 动态导入default_generator以避免循环导入
|
||||||
reply_to = msg_dict.get("reply_to", "")
|
from src.plugin_system.apis.generator_api import get_replyer
|
||||||
platform, reply_to_user_id = Prompt.parse_reply_target(reply_to)
|
|
||||||
if (msg_user_id == bot_id and reply_to_user_id == target_user_id) or msg_user_id == target_user_id:
|
# 创建临时生成器实例来使用其方法
|
||||||
core_dialogue_list.append(msg_dict)
|
temp_generator = get_replyer(None, chat_id, request_type="prompt_building")
|
||||||
|
return await temp_generator.build_s4u_chat_history_prompts(
|
||||||
|
message_list_before_now, target_user_id, sender, chat_id
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"处理消息记录时出错: {msg_dict}, 错误: {e}")
|
logger.error(f"构建S4U历史消息prompt失败: {e}")
|
||||||
|
|
||||||
# 构建背景对话 prompt
|
|
||||||
all_dialogue_prompt = ""
|
|
||||||
if message_list_before_now:
|
|
||||||
latest_25_msgs = message_list_before_now[-int(global_config.chat.max_context_size) :]
|
|
||||||
all_dialogue_prompt_str = await build_readable_messages(
|
|
||||||
latest_25_msgs,
|
|
||||||
replace_bot_name=True,
|
|
||||||
timestamp_mode="normal",
|
|
||||||
truncate=True,
|
|
||||||
read_mark=read_mark,
|
|
||||||
)
|
|
||||||
all_dialogue_prompt = f"所有用户的发言:\n{all_dialogue_prompt_str}"
|
|
||||||
|
|
||||||
# 构建核心对话 prompt
|
|
||||||
core_dialogue_prompt = ""
|
|
||||||
if core_dialogue_list:
|
|
||||||
latest_5_messages = core_dialogue_list[-5:] if len(core_dialogue_list) >= 5 else core_dialogue_list
|
|
||||||
has_bot_message = any(str(msg.get("user_id")) == bot_id for msg in latest_5_messages)
|
|
||||||
|
|
||||||
if not has_bot_message:
|
|
||||||
core_dialogue_prompt = ""
|
|
||||||
else:
|
|
||||||
core_dialogue_list = core_dialogue_list[-int(global_config.chat.max_context_size * 2) :]
|
|
||||||
|
|
||||||
core_dialogue_prompt_str = await build_readable_messages(
|
|
||||||
core_dialogue_list,
|
|
||||||
replace_bot_name=True,
|
|
||||||
merge_messages=False,
|
|
||||||
timestamp_mode="normal_no_YMD",
|
|
||||||
read_mark=read_mark,
|
|
||||||
truncate=True,
|
|
||||||
show_actions=True,
|
|
||||||
)
|
|
||||||
core_dialogue_prompt = f"""--------------------------------
|
|
||||||
这是你和{sender}的对话,你们正在交流中:
|
|
||||||
{core_dialogue_prompt_str}
|
|
||||||
--------------------------------
|
|
||||||
"""
|
|
||||||
|
|
||||||
return core_dialogue_prompt, all_dialogue_prompt
|
|
||||||
|
|
||||||
async def _build_expression_habits(self) -> Dict[str, Any]:
|
async def _build_expression_habits(self) -> Dict[str, Any]:
|
||||||
"""构建表达习惯"""
|
"""构建表达习惯"""
|
||||||
if not global_config.expression.enable_expression:
|
use_expression, _, _ = global_config.expression.get_expression_config_for_chat(self.parameters.chat_id)
|
||||||
|
if not use_expression:
|
||||||
return {"expression_habits_block": ""}
|
return {"expression_habits_block": ""}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -537,18 +494,19 @@ class Prompt:
|
|||||||
chat_history = ""
|
chat_history = ""
|
||||||
if self.parameters.message_list_before_now_long:
|
if self.parameters.message_list_before_now_long:
|
||||||
recent_messages = self.parameters.message_list_before_now_long[-10:]
|
recent_messages = self.parameters.message_list_before_now_long[-10:]
|
||||||
chat_history = await build_readable_messages(
|
chat_history = build_readable_messages(
|
||||||
recent_messages,
|
recent_messages, replace_bot_name=True, timestamp_mode="normal", truncate=True
|
||||||
replace_bot_name=True,
|
|
||||||
timestamp_mode="normal",
|
|
||||||
truncate=True
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# 创建表情选择器
|
# 创建表情选择器
|
||||||
expression_selector = ExpressionSelector()
|
expression_selector = ExpressionSelector(self.parameters.chat_id)
|
||||||
|
|
||||||
# 选择合适的表情
|
# 选择合适的表情
|
||||||
selected_expressions = await expression_selector.select_suitable_expressions_llm(
|
selected_expressions = await expression_selector.select_suitable_expressions_llm(
|
||||||
|
chat_history=chat_history,
|
||||||
|
current_message=self.parameters.target,
|
||||||
|
emotional_tone="neutral",
|
||||||
|
topic_type="general",
|
||||||
)
|
)
|
||||||
|
|
||||||
# 构建表达习惯块
|
# 构建表达习惯块
|
||||||
@@ -577,18 +535,14 @@ class Prompt:
|
|||||||
chat_history = ""
|
chat_history = ""
|
||||||
if self.parameters.message_list_before_now_long:
|
if self.parameters.message_list_before_now_long:
|
||||||
recent_messages = self.parameters.message_list_before_now_long[-20:]
|
recent_messages = self.parameters.message_list_before_now_long[-20:]
|
||||||
chat_history = await build_readable_messages(
|
chat_history = build_readable_messages(
|
||||||
recent_messages,
|
recent_messages, replace_bot_name=True, timestamp_mode="normal", truncate=True
|
||||||
replace_bot_name=True,
|
|
||||||
timestamp_mode="normal",
|
|
||||||
truncate=True
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# 激活长期记忆
|
# 激活长期记忆
|
||||||
memory_activator = MemoryActivator()
|
memory_activator = MemoryActivator()
|
||||||
running_memories = await memory_activator.activate_memory_with_chat_history(
|
running_memories = await memory_activator.activate_memory_with_chat_history(
|
||||||
target_message=self.parameters.target,
|
target_message=self.parameters.target, chat_history_prompt=chat_history
|
||||||
chat_history_prompt=chat_history
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# 获取即时记忆
|
# 获取即时记忆
|
||||||
@@ -635,11 +589,8 @@ class Prompt:
|
|||||||
chat_history = ""
|
chat_history = ""
|
||||||
if self.parameters.message_list_before_now_long:
|
if self.parameters.message_list_before_now_long:
|
||||||
recent_messages = self.parameters.message_list_before_now_long[-15:]
|
recent_messages = self.parameters.message_list_before_now_long[-15:]
|
||||||
chat_history = await build_readable_messages(
|
chat_history = build_readable_messages(
|
||||||
recent_messages,
|
recent_messages, replace_bot_name=True, timestamp_mode="normal", truncate=True
|
||||||
replace_bot_name=True,
|
|
||||||
timestamp_mode="normal",
|
|
||||||
truncate=True
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# 创建工具执行器
|
# 创建工具执行器
|
||||||
@@ -650,7 +601,7 @@ class Prompt:
|
|||||||
sender=self.parameters.sender,
|
sender=self.parameters.sender,
|
||||||
target_message=self.parameters.target,
|
target_message=self.parameters.target,
|
||||||
chat_history=chat_history,
|
chat_history=chat_history,
|
||||||
return_details=False
|
return_details=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 构建工具信息块
|
# 构建工具信息块
|
||||||
@@ -680,21 +631,19 @@ class Prompt:
|
|||||||
return {"knowledge_prompt": ""}
|
return {"knowledge_prompt": ""}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from src.chat.knowledge.knowledge_lib import qa_manager
|
from src.chat.knowledge.knowledge_lib import QAManager
|
||||||
|
|
||||||
# 获取问题文本(当前消息)
|
# 获取问题文本(当前消息)
|
||||||
question = self.parameters.target or ""
|
question = self.parameters.target or ""
|
||||||
if not question:
|
if not question:
|
||||||
return {"knowledge_prompt": ""}
|
return {"knowledge_prompt": ""}
|
||||||
|
|
||||||
# 检查QA管理器是否已成功初始化
|
# 创建QA管理器
|
||||||
if not qa_manager:
|
qa_manager = QAManager()
|
||||||
logger.warning("QA管理器未初始化 (可能lpmm_knowledge被禁用),跳过知识库搜索。")
|
|
||||||
return {"knowledge_prompt": ""}
|
|
||||||
|
|
||||||
# 搜索相关知识
|
# 搜索相关知识
|
||||||
knowledge_results = await qa_manager.get_knowledge(
|
knowledge_results = await qa_manager.get_knowledge(
|
||||||
question=question
|
question=question, chat_id=self.parameters.chat_id, max_results=5, min_similarity=0.5
|
||||||
)
|
)
|
||||||
|
|
||||||
# 构建知识块
|
# 构建知识块
|
||||||
@@ -707,10 +656,13 @@ class Prompt:
|
|||||||
relevance = item.get("relevance", 0.0)
|
relevance = item.get("relevance", 0.0)
|
||||||
|
|
||||||
if content:
|
if content:
|
||||||
knowledge_parts.append(f"- [相关度: {relevance}] {content}")
|
if source:
|
||||||
|
knowledge_parts.append(f"- [{relevance:.2f}] {content} (来源: {source})")
|
||||||
|
else:
|
||||||
|
knowledge_parts.append(f"- [{relevance:.2f}] {content}")
|
||||||
|
|
||||||
if summary := knowledge_results.get("summary"):
|
if knowledge_results.get("summary"):
|
||||||
knowledge_parts.append(f"\n知识总结: {summary}")
|
knowledge_parts.append(f"\n知识总结: {knowledge_results['summary']}")
|
||||||
|
|
||||||
knowledge_prompt = "\n".join(knowledge_parts)
|
knowledge_prompt = "\n".join(knowledge_parts)
|
||||||
else:
|
else:
|
||||||
@@ -759,15 +711,17 @@ class Prompt:
|
|||||||
"action_descriptions": self.parameters.action_descriptions or context_data.get("action_descriptions", ""),
|
"action_descriptions": self.parameters.action_descriptions or context_data.get("action_descriptions", ""),
|
||||||
"sender_name": self.parameters.sender or "未知用户",
|
"sender_name": self.parameters.sender or "未知用户",
|
||||||
"mood_state": self.parameters.mood_prompt or context_data.get("mood_state", ""),
|
"mood_state": self.parameters.mood_prompt or context_data.get("mood_state", ""),
|
||||||
"background_dialogue_prompt": context_data.get("background_dialogue_prompt", ""),
|
"read_history_prompt": context_data.get("read_history_prompt", ""),
|
||||||
|
"unread_history_prompt": context_data.get("unread_history_prompt", ""),
|
||||||
"time_block": context_data.get("time_block", ""),
|
"time_block": context_data.get("time_block", ""),
|
||||||
"core_dialogue_prompt": context_data.get("core_dialogue_prompt", ""),
|
|
||||||
"reply_target_block": context_data.get("reply_target_block", ""),
|
"reply_target_block": context_data.get("reply_target_block", ""),
|
||||||
"reply_style": global_config.personality.reply_style,
|
"reply_style": global_config.personality.reply_style,
|
||||||
"keywords_reaction_prompt": self.parameters.keywords_reaction_prompt or context_data.get("keywords_reaction_prompt", ""),
|
"keywords_reaction_prompt": self.parameters.keywords_reaction_prompt
|
||||||
|
or context_data.get("keywords_reaction_prompt", ""),
|
||||||
"moderation_prompt": self.parameters.moderation_prompt_block or context_data.get("moderation_prompt", ""),
|
"moderation_prompt": self.parameters.moderation_prompt_block or context_data.get("moderation_prompt", ""),
|
||||||
"safety_guidelines_block": self.parameters.safety_guidelines_block or context_data.get("safety_guidelines_block", ""),
|
"safety_guidelines_block": self.parameters.safety_guidelines_block
|
||||||
"chat_context_type": "群聊" if self.parameters.is_group_chat else "私聊",
|
or context_data.get("safety_guidelines_block", ""),
|
||||||
|
"chat_scene": self.parameters.chat_scene or "你正在一个QQ群里聊天,你需要理解整个群的聊天动态和话题走向,并做出自然的回应。",
|
||||||
}
|
}
|
||||||
|
|
||||||
def _prepare_normal_params(self, context_data: Dict[str, Any]) -> Dict[str, Any]:
|
def _prepare_normal_params(self, context_data: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
@@ -789,9 +743,12 @@ class Prompt:
|
|||||||
"reply_target_block": context_data.get("reply_target_block", ""),
|
"reply_target_block": context_data.get("reply_target_block", ""),
|
||||||
"config_expression_style": global_config.personality.reply_style,
|
"config_expression_style": global_config.personality.reply_style,
|
||||||
"mood_state": self.parameters.mood_prompt or context_data.get("mood_state", ""),
|
"mood_state": self.parameters.mood_prompt or context_data.get("mood_state", ""),
|
||||||
"keywords_reaction_prompt": self.parameters.keywords_reaction_prompt or context_data.get("keywords_reaction_prompt", ""),
|
"keywords_reaction_prompt": self.parameters.keywords_reaction_prompt
|
||||||
|
or context_data.get("keywords_reaction_prompt", ""),
|
||||||
"moderation_prompt": self.parameters.moderation_prompt_block or context_data.get("moderation_prompt", ""),
|
"moderation_prompt": self.parameters.moderation_prompt_block or context_data.get("moderation_prompt", ""),
|
||||||
"safety_guidelines_block": self.parameters.safety_guidelines_block or context_data.get("safety_guidelines_block", ""),
|
"safety_guidelines_block": self.parameters.safety_guidelines_block
|
||||||
|
or context_data.get("safety_guidelines_block", ""),
|
||||||
|
"chat_scene": self.parameters.chat_scene or "你正在一个QQ群里聊天,你需要理解整个群的聊天动态和话题走向,并做出自然的回应。",
|
||||||
}
|
}
|
||||||
|
|
||||||
def _prepare_default_params(self, context_data: Dict[str, Any]) -> Dict[str, Any]:
|
def _prepare_default_params(self, context_data: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
@@ -809,9 +766,11 @@ class Prompt:
|
|||||||
"reason": "",
|
"reason": "",
|
||||||
"mood_state": self.parameters.mood_prompt or context_data.get("mood_state", ""),
|
"mood_state": self.parameters.mood_prompt or context_data.get("mood_state", ""),
|
||||||
"reply_style": global_config.personality.reply_style,
|
"reply_style": global_config.personality.reply_style,
|
||||||
"keywords_reaction_prompt": self.parameters.keywords_reaction_prompt or context_data.get("keywords_reaction_prompt", ""),
|
"keywords_reaction_prompt": self.parameters.keywords_reaction_prompt
|
||||||
|
or context_data.get("keywords_reaction_prompt", ""),
|
||||||
"moderation_prompt": self.parameters.moderation_prompt_block or context_data.get("moderation_prompt", ""),
|
"moderation_prompt": self.parameters.moderation_prompt_block or context_data.get("moderation_prompt", ""),
|
||||||
"safety_guidelines_block": self.parameters.safety_guidelines_block or context_data.get("safety_guidelines_block", ""),
|
"safety_guidelines_block": self.parameters.safety_guidelines_block
|
||||||
|
or context_data.get("safety_guidelines_block", ""),
|
||||||
}
|
}
|
||||||
|
|
||||||
def format(self, *args, **kwargs) -> str:
|
def format(self, *args, **kwargs) -> str:
|
||||||
@@ -912,9 +871,7 @@ class Prompt:
|
|||||||
return await relationship_fetcher.build_relation_info(person_id, points_num=5)
|
return await relationship_fetcher.build_relation_info(person_id, points_num=5)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def build_cross_context(
|
async def build_cross_context(chat_id: str, prompt_mode: str, target_user_info: Optional[Dict[str, Any]]) -> str:
|
||||||
chat_id: str, prompt_mode: str, target_user_info: Optional[Dict[str, Any]]
|
|
||||||
) -> str:
|
|
||||||
"""
|
"""
|
||||||
构建跨群聊上下文 - 统一实现
|
构建跨群聊上下文 - 统一实现
|
||||||
|
|
||||||
@@ -969,7 +926,7 @@ class Prompt:
|
|||||||
person_info_manager = get_person_info_manager()
|
person_info_manager = get_person_info_manager()
|
||||||
person_id = person_info_manager.get_person_id_by_person_name(sender)
|
person_id = person_info_manager.get_person_id_by_person_name(sender)
|
||||||
if person_id:
|
if person_id:
|
||||||
user_id = person_info_manager.get_value(person_id, "user_id")
|
user_id = person_info_manager.get_value_sync(person_id, "user_id")
|
||||||
return str(user_id) if user_id else ""
|
return str(user_id) if user_id else ""
|
||||||
|
|
||||||
return ""
|
return ""
|
||||||
@@ -977,10 +934,7 @@ class Prompt:
|
|||||||
|
|
||||||
# 工厂函数
|
# 工厂函数
|
||||||
def create_prompt(
|
def create_prompt(
|
||||||
template: str,
|
template: str, name: Optional[str] = None, parameters: Optional[PromptParameters] = None, **kwargs
|
||||||
name: Optional[str] = None,
|
|
||||||
parameters: Optional[PromptParameters] = None,
|
|
||||||
**kwargs
|
|
||||||
) -> Prompt:
|
) -> Prompt:
|
||||||
"""快速创建Prompt实例的工厂函数"""
|
"""快速创建Prompt实例的工厂函数"""
|
||||||
if parameters is None:
|
if parameters is None:
|
||||||
@@ -989,14 +943,10 @@ def create_prompt(
|
|||||||
|
|
||||||
|
|
||||||
async def create_prompt_async(
|
async def create_prompt_async(
|
||||||
template: str,
|
template: str, name: Optional[str] = None, parameters: Optional[PromptParameters] = None, **kwargs
|
||||||
name: Optional[str] = None,
|
|
||||||
parameters: Optional[PromptParameters] = None,
|
|
||||||
**kwargs
|
|
||||||
) -> Prompt:
|
) -> Prompt:
|
||||||
"""异步创建Prompt实例"""
|
"""异步创建Prompt实例"""
|
||||||
prompt = create_prompt(template, name, parameters, **kwargs)
|
prompt = create_prompt(template, name, parameters, **kwargs)
|
||||||
if global_prompt_manager.context._current_context:
|
if global_prompt_manager._context._current_context:
|
||||||
await global_prompt_manager.context.register_async(prompt)
|
await global_prompt_manager._context.register_async(prompt)
|
||||||
return prompt
|
return prompt
|
||||||
|
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ import numpy as np
|
|||||||
|
|
||||||
from collections import Counter
|
from collections import Counter
|
||||||
from maim_message import UserInfo
|
from maim_message import UserInfo
|
||||||
from typing import Optional, Tuple, Dict, List, Any, Coroutine
|
from typing import Optional, Tuple, Dict, List, Any
|
||||||
|
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.common.message_repository import find_messages, count_messages
|
from src.common.message_repository import find_messages, count_messages
|
||||||
@@ -341,8 +341,8 @@ def process_llm_response(text: str, enable_splitter: bool = True, enable_chinese
|
|||||||
split_sentences = [s.strip() for s in split_sentences_raw if s.strip()]
|
split_sentences = [s.strip() for s in split_sentences_raw if s.strip()]
|
||||||
else:
|
else:
|
||||||
if split_mode == "llm":
|
if split_mode == "llm":
|
||||||
logger.debug("未检测到 [SPLIT] 标记,回退到基于标点的传统模式进行分割。")
|
logger.debug("未检测到 [SPLIT] 标记,本次不进行分割。")
|
||||||
split_sentences = split_into_sentences_w_remove_punctuation(cleaned_text)
|
split_sentences = [cleaned_text]
|
||||||
else: # mode == "punctuation"
|
else: # mode == "punctuation"
|
||||||
logger.debug("使用基于标点的传统模式进行分割。")
|
logger.debug("使用基于标点的传统模式进行分割。")
|
||||||
split_sentences = split_into_sentences_w_remove_punctuation(cleaned_text)
|
split_sentences = split_into_sentences_w_remove_punctuation(cleaned_text)
|
||||||
@@ -352,6 +352,8 @@ def process_llm_response(text: str, enable_splitter: bool = True, enable_chinese
|
|||||||
|
|
||||||
sentences = []
|
sentences = []
|
||||||
for sentence in split_sentences:
|
for sentence in split_sentences:
|
||||||
|
# 清除开头可能存在的空行
|
||||||
|
sentence = sentence.lstrip("\n").rstrip()
|
||||||
if global_config.chinese_typo.enable and enable_chinese_typo:
|
if global_config.chinese_typo.enable and enable_chinese_typo:
|
||||||
typoed_text, typo_corrections = typo_generator.create_typo_sentence(sentence)
|
typoed_text, typo_corrections = typo_generator.create_typo_sentence(sentence)
|
||||||
sentences.append(typoed_text)
|
sentences.append(typoed_text)
|
||||||
@@ -540,8 +542,7 @@ def get_western_ratio(paragraph):
|
|||||||
return western_count / len(alnum_chars)
|
return western_count / len(alnum_chars)
|
||||||
|
|
||||||
|
|
||||||
def count_messages_between(start_time: float, end_time: float, stream_id: str) -> tuple[int, int] | tuple[
|
def count_messages_between(start_time: float, end_time: float, stream_id: str) -> tuple[int, int]:
|
||||||
Coroutine[Any, Any, int], int]:
|
|
||||||
"""计算两个时间点之间的消息数量和文本总长度
|
"""计算两个时间点之间的消息数量和文本总长度
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -619,7 +620,7 @@ def translate_timestamp_to_human_readable(timestamp: float, mode: str = "normal"
|
|||||||
return time.strftime("%H:%M:%S", time.localtime(timestamp))
|
return time.strftime("%H:%M:%S", time.localtime(timestamp))
|
||||||
|
|
||||||
|
|
||||||
async def get_chat_type_and_target_info(chat_id: str) -> Tuple[bool, Optional[Dict]]:
|
def get_chat_type_and_target_info(chat_id: str) -> Tuple[bool, Optional[Dict]]:
|
||||||
"""
|
"""
|
||||||
获取聊天类型(是否群聊)和私聊对象信息。
|
获取聊天类型(是否群聊)和私聊对象信息。
|
||||||
|
|
||||||
@@ -663,8 +664,7 @@ async def get_chat_type_and_target_info(chat_id: str) -> Tuple[bool, Optional[Di
|
|||||||
if person_id:
|
if person_id:
|
||||||
# get_value is async, so await it directly
|
# get_value is async, so await it directly
|
||||||
person_info_manager = get_person_info_manager()
|
person_info_manager = get_person_info_manager()
|
||||||
person_data = await person_info_manager.get_values(person_id, ["person_name"])
|
person_name = person_info_manager.get_value_sync(person_id, "person_name")
|
||||||
person_name = person_data.get("person_name")
|
|
||||||
|
|
||||||
target_info["person_id"] = person_id
|
target_info["person_id"] = person_id
|
||||||
target_info["person_name"] = person_name
|
target_info["person_name"] = person_name
|
||||||
@@ -695,25 +695,9 @@ def assign_message_ids(messages: List[Any]) -> List[Dict[str, Any]]:
|
|||||||
"""
|
"""
|
||||||
result = []
|
result = []
|
||||||
used_ids = set()
|
used_ids = set()
|
||||||
len_i = len(messages)
|
|
||||||
if len_i > 100:
|
|
||||||
a = 10
|
|
||||||
b = 99
|
|
||||||
else:
|
|
||||||
a = 1
|
|
||||||
b = 9
|
|
||||||
|
|
||||||
for i, message in enumerate(messages):
|
for i, message in enumerate(messages):
|
||||||
# 生成唯一的简短ID
|
# 使用简单的索引作为ID
|
||||||
while True:
|
message_id = f"m{i + 1}"
|
||||||
# 使用索引+随机数生成简短ID
|
|
||||||
random_suffix = random.randint(a, b)
|
|
||||||
message_id = f"m{i + 1}{random_suffix}"
|
|
||||||
|
|
||||||
if message_id not in used_ids:
|
|
||||||
used_ids.add(message_id)
|
|
||||||
break
|
|
||||||
|
|
||||||
result.append({"id": message_id, "message": message})
|
result.append({"id": message_id, "message": message})
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -6,6 +6,7 @@ class BaseDataModel:
|
|||||||
def deepcopy(self):
|
def deepcopy(self):
|
||||||
return copy.deepcopy(self)
|
return copy.deepcopy(self)
|
||||||
|
|
||||||
|
|
||||||
def temporarily_transform_class_to_dict(obj: Any) -> Any:
|
def temporarily_transform_class_to_dict(obj: Any) -> Any:
|
||||||
# sourcery skip: assign-if-exp, reintroduce-else
|
# sourcery skip: assign-if-exp, reintroduce-else
|
||||||
"""
|
"""
|
||||||
|
|||||||
137
src/common/data_models/bot_interest_data_model.py
Normal file
137
src/common/data_models/bot_interest_data_model.py
Normal file
@@ -0,0 +1,137 @@
|
|||||||
|
"""
|
||||||
|
机器人兴趣标签数据模型
|
||||||
|
定义机器人的兴趣标签和相关的embedding数据结构
|
||||||
|
"""
|
||||||
|
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import List, Dict, Optional, Any
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
from . import BaseDataModel
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class BotInterestTag(BaseDataModel):
|
||||||
|
"""机器人兴趣标签"""
|
||||||
|
|
||||||
|
tag_name: str
|
||||||
|
weight: float = 1.0 # 权重,表示对这个兴趣的喜好程度 (0.0-1.0)
|
||||||
|
embedding: Optional[List[float]] = None # 标签的embedding向量
|
||||||
|
created_at: datetime = field(default_factory=datetime.now)
|
||||||
|
updated_at: datetime = field(default_factory=datetime.now)
|
||||||
|
is_active: bool = True
|
||||||
|
|
||||||
|
def to_dict(self) -> Dict[str, Any]:
|
||||||
|
"""转换为字典格式"""
|
||||||
|
return {
|
||||||
|
"tag_name": self.tag_name,
|
||||||
|
"weight": self.weight,
|
||||||
|
"embedding": self.embedding,
|
||||||
|
"created_at": self.created_at.isoformat(),
|
||||||
|
"updated_at": self.updated_at.isoformat(),
|
||||||
|
"is_active": self.is_active,
|
||||||
|
}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_dict(cls, data: Dict[str, Any]) -> "BotInterestTag":
|
||||||
|
"""从字典创建对象"""
|
||||||
|
return cls(
|
||||||
|
tag_name=data["tag_name"],
|
||||||
|
weight=data.get("weight", 1.0),
|
||||||
|
embedding=data.get("embedding"),
|
||||||
|
created_at=datetime.fromisoformat(data["created_at"]) if data.get("created_at") else datetime.now(),
|
||||||
|
updated_at=datetime.fromisoformat(data["updated_at"]) if data.get("updated_at") else datetime.now(),
|
||||||
|
is_active=data.get("is_active", True),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class BotPersonalityInterests(BaseDataModel):
|
||||||
|
"""机器人人格化兴趣配置"""
|
||||||
|
|
||||||
|
personality_id: str
|
||||||
|
personality_description: str # 人设描述文本
|
||||||
|
interest_tags: List[BotInterestTag] = field(default_factory=list)
|
||||||
|
embedding_model: str = "text-embedding-ada-002" # 使用的embedding模型
|
||||||
|
last_updated: datetime = field(default_factory=datetime.now)
|
||||||
|
version: int = 1 # 版本号,用于追踪更新
|
||||||
|
|
||||||
|
def get_active_tags(self) -> List[BotInterestTag]:
|
||||||
|
"""获取活跃的兴趣标签"""
|
||||||
|
return [tag for tag in self.interest_tags if tag.is_active]
|
||||||
|
|
||||||
|
def to_dict(self) -> Dict[str, Any]:
|
||||||
|
"""转换为字典格式"""
|
||||||
|
return {
|
||||||
|
"personality_id": self.personality_id,
|
||||||
|
"personality_description": self.personality_description,
|
||||||
|
"interest_tags": [tag.to_dict() for tag in self.interest_tags],
|
||||||
|
"embedding_model": self.embedding_model,
|
||||||
|
"last_updated": self.last_updated.isoformat(),
|
||||||
|
"version": self.version,
|
||||||
|
}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_dict(cls, data: Dict[str, Any]) -> "BotPersonalityInterests":
|
||||||
|
"""从字典创建对象"""
|
||||||
|
return cls(
|
||||||
|
personality_id=data["personality_id"],
|
||||||
|
personality_description=data["personality_description"],
|
||||||
|
interest_tags=[BotInterestTag.from_dict(tag_data) for tag_data in data.get("interest_tags", [])],
|
||||||
|
embedding_model=data.get("embedding_model", "text-embedding-ada-002"),
|
||||||
|
last_updated=datetime.fromisoformat(data["last_updated"]) if data.get("last_updated") else datetime.now(),
|
||||||
|
version=data.get("version", 1),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class InterestMatchResult(BaseDataModel):
|
||||||
|
"""兴趣匹配结果"""
|
||||||
|
|
||||||
|
message_id: str
|
||||||
|
matched_tags: List[str] = field(default_factory=list)
|
||||||
|
match_scores: Dict[str, float] = field(default_factory=dict) # tag_name -> score
|
||||||
|
overall_score: float = 0.0
|
||||||
|
top_tag: Optional[str] = None
|
||||||
|
confidence: float = 0.0 # 匹配置信度 (0.0-1.0)
|
||||||
|
matched_keywords: List[str] = field(default_factory=list)
|
||||||
|
|
||||||
|
def add_match(self, tag_name: str, score: float, keywords: List[str] = None):
|
||||||
|
"""添加匹配结果"""
|
||||||
|
self.matched_tags.append(tag_name)
|
||||||
|
self.match_scores[tag_name] = score
|
||||||
|
if keywords:
|
||||||
|
self.matched_keywords.extend(keywords)
|
||||||
|
|
||||||
|
def calculate_overall_score(self):
|
||||||
|
"""计算总体匹配分数"""
|
||||||
|
if not self.match_scores:
|
||||||
|
self.overall_score = 0.0
|
||||||
|
self.top_tag = None
|
||||||
|
return
|
||||||
|
|
||||||
|
# 使用加权平均计算总体分数
|
||||||
|
total_weight = len(self.match_scores)
|
||||||
|
if total_weight > 0:
|
||||||
|
self.overall_score = sum(self.match_scores.values()) / total_weight
|
||||||
|
# 设置最佳匹配标签
|
||||||
|
self.top_tag = max(self.match_scores.items(), key=lambda x: x[1])[0]
|
||||||
|
else:
|
||||||
|
self.overall_score = 0.0
|
||||||
|
self.top_tag = None
|
||||||
|
|
||||||
|
# 计算置信度(基于匹配标签数量和分数分布)
|
||||||
|
if len(self.match_scores) > 0:
|
||||||
|
avg_score = self.overall_score
|
||||||
|
score_variance = sum((score - avg_score) ** 2 for score in self.match_scores.values()) / len(
|
||||||
|
self.match_scores
|
||||||
|
)
|
||||||
|
# 分数越集中,置信度越高
|
||||||
|
self.confidence = max(0.0, 1.0 - score_variance)
|
||||||
|
else:
|
||||||
|
self.confidence = 0.0
|
||||||
|
|
||||||
|
def get_top_matches(self, top_n: int = 3) -> List[tuple]:
|
||||||
|
"""获取前N个最佳匹配"""
|
||||||
|
sorted_matches = sorted(self.match_scores.items(), key=lambda x: x[1], reverse=True)
|
||||||
|
return sorted_matches[:top_n]
|
||||||
@@ -79,6 +79,7 @@ class DatabaseMessages(BaseDataModel):
|
|||||||
is_command: bool = False,
|
is_command: bool = False,
|
||||||
is_notify: bool = False,
|
is_notify: bool = False,
|
||||||
selected_expressions: Optional[str] = None,
|
selected_expressions: Optional[str] = None,
|
||||||
|
is_read: bool = False,
|
||||||
user_id: str = "",
|
user_id: str = "",
|
||||||
user_nickname: str = "",
|
user_nickname: str = "",
|
||||||
user_cardname: Optional[str] = None,
|
user_cardname: Optional[str] = None,
|
||||||
@@ -94,6 +95,9 @@ class DatabaseMessages(BaseDataModel):
|
|||||||
chat_info_platform: str = "",
|
chat_info_platform: str = "",
|
||||||
chat_info_create_time: float = 0.0,
|
chat_info_create_time: float = 0.0,
|
||||||
chat_info_last_active_time: float = 0.0,
|
chat_info_last_active_time: float = 0.0,
|
||||||
|
# 新增字段
|
||||||
|
actions: Optional[list] = None,
|
||||||
|
should_reply: bool = False,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
):
|
):
|
||||||
self.message_id = message_id
|
self.message_id = message_id
|
||||||
@@ -102,6 +106,10 @@ class DatabaseMessages(BaseDataModel):
|
|||||||
self.reply_to = reply_to
|
self.reply_to = reply_to
|
||||||
self.interest_value = interest_value
|
self.interest_value = interest_value
|
||||||
|
|
||||||
|
# 新增字段
|
||||||
|
self.actions = actions
|
||||||
|
self.should_reply = should_reply
|
||||||
|
|
||||||
self.key_words = key_words
|
self.key_words = key_words
|
||||||
self.key_words_lite = key_words_lite
|
self.key_words_lite = key_words_lite
|
||||||
self.is_mentioned = is_mentioned
|
self.is_mentioned = is_mentioned
|
||||||
@@ -122,6 +130,7 @@ class DatabaseMessages(BaseDataModel):
|
|||||||
self.is_notify = is_notify
|
self.is_notify = is_notify
|
||||||
|
|
||||||
self.selected_expressions = selected_expressions
|
self.selected_expressions = selected_expressions
|
||||||
|
self.is_read = is_read
|
||||||
|
|
||||||
self.group_info: Optional[DatabaseGroupInfo] = None
|
self.group_info: Optional[DatabaseGroupInfo] = None
|
||||||
self.user_info = DatabaseUserInfo(
|
self.user_info = DatabaseUserInfo(
|
||||||
@@ -188,6 +197,10 @@ class DatabaseMessages(BaseDataModel):
|
|||||||
"is_command": self.is_command,
|
"is_command": self.is_command,
|
||||||
"is_notify": self.is_notify,
|
"is_notify": self.is_notify,
|
||||||
"selected_expressions": self.selected_expressions,
|
"selected_expressions": self.selected_expressions,
|
||||||
|
"is_read": self.is_read,
|
||||||
|
# 新增字段
|
||||||
|
"actions": self.actions,
|
||||||
|
"should_reply": self.should_reply,
|
||||||
"user_id": self.user_info.user_id,
|
"user_id": self.user_info.user_id,
|
||||||
"user_nickname": self.user_info.user_nickname,
|
"user_nickname": self.user_info.user_nickname,
|
||||||
"user_cardname": self.user_info.user_cardname,
|
"user_cardname": self.user_info.user_cardname,
|
||||||
@@ -205,6 +218,61 @@ class DatabaseMessages(BaseDataModel):
|
|||||||
"chat_info_user_cardname": self.chat_info.user_info.user_cardname,
|
"chat_info_user_cardname": self.chat_info.user_info.user_cardname,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def update_message_info(self, interest_value: float = None, actions: list = None, should_reply: bool = None):
|
||||||
|
"""
|
||||||
|
更新消息信息
|
||||||
|
|
||||||
|
Args:
|
||||||
|
interest_value: 兴趣度值
|
||||||
|
actions: 执行的动作列表
|
||||||
|
should_reply: 是否应该回复
|
||||||
|
"""
|
||||||
|
if interest_value is not None:
|
||||||
|
self.interest_value = interest_value
|
||||||
|
if actions is not None:
|
||||||
|
self.actions = actions
|
||||||
|
if should_reply is not None:
|
||||||
|
self.should_reply = should_reply
|
||||||
|
|
||||||
|
def add_action(self, action: str):
|
||||||
|
"""
|
||||||
|
添加执行的动作到消息中
|
||||||
|
|
||||||
|
Args:
|
||||||
|
action: 要添加的动作名称
|
||||||
|
"""
|
||||||
|
if self.actions is None:
|
||||||
|
self.actions = []
|
||||||
|
if action not in self.actions: # 避免重复添加
|
||||||
|
self.actions.append(action)
|
||||||
|
|
||||||
|
def get_actions(self) -> list:
|
||||||
|
"""
|
||||||
|
获取执行的动作列表
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
动作列表,如果没有动作则返回空列表
|
||||||
|
"""
|
||||||
|
return self.actions or []
|
||||||
|
|
||||||
|
def get_message_summary(self) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
获取消息摘要信息
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
包含关键字段的消息摘要
|
||||||
|
"""
|
||||||
|
return {
|
||||||
|
"message_id": self.message_id,
|
||||||
|
"time": self.time,
|
||||||
|
"interest_value": self.interest_value,
|
||||||
|
"actions": self.actions,
|
||||||
|
"should_reply": self.should_reply,
|
||||||
|
"user_nickname": self.user_info.user_nickname,
|
||||||
|
"display_message": self.display_message,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
@dataclass(init=False)
|
@dataclass(init=False)
|
||||||
class DatabaseActionRecords(BaseDataModel):
|
class DatabaseActionRecords(BaseDataModel):
|
||||||
def __init__(
|
def __init__(
|
||||||
|
|||||||
@@ -1,10 +1,12 @@
|
|||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Optional, Dict, List, TYPE_CHECKING
|
from typing import Optional, Dict, List, TYPE_CHECKING
|
||||||
|
|
||||||
|
from src.plugin_system.base.component_types import ChatType
|
||||||
from . import BaseDataModel
|
from . import BaseDataModel
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
pass
|
from .database_data_model import DatabaseMessages
|
||||||
|
from src.plugin_system.base.component_types import ActionInfo, ChatMode
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -21,18 +23,32 @@ class ActionPlannerInfo(BaseDataModel):
|
|||||||
action_type: str = field(default_factory=str)
|
action_type: str = field(default_factory=str)
|
||||||
reasoning: Optional[str] = None
|
reasoning: Optional[str] = None
|
||||||
action_data: Optional[Dict] = None
|
action_data: Optional[Dict] = None
|
||||||
action_message: Optional[Dict] = None
|
action_message: Optional["DatabaseMessages"] = None
|
||||||
available_actions: Optional[Dict[str, "ActionInfo"]] = None
|
available_actions: Optional[Dict[str, "ActionInfo"]] = None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class InterestScore(BaseDataModel):
|
||||||
|
"""兴趣度评分结果"""
|
||||||
|
|
||||||
|
message_id: str
|
||||||
|
total_score: float
|
||||||
|
interest_match_score: float
|
||||||
|
relationship_score: float
|
||||||
|
mentioned_score: float
|
||||||
|
details: Dict[str, str]
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Plan(BaseDataModel):
|
class Plan(BaseDataModel):
|
||||||
"""
|
"""
|
||||||
统一规划数据模型
|
统一规划数据模型
|
||||||
"""
|
"""
|
||||||
|
|
||||||
chat_id: str
|
chat_id: str
|
||||||
mode: "ChatMode"
|
mode: "ChatMode"
|
||||||
|
|
||||||
|
chat_type: "ChatType"
|
||||||
# Generator 填充
|
# Generator 填充
|
||||||
available_actions: Dict[str, "ActionInfo"] = field(default_factory=dict)
|
available_actions: Dict[str, "ActionInfo"] = field(default_factory=dict)
|
||||||
chat_history: List["DatabaseMessages"] = field(default_factory=list)
|
chat_history: List["DatabaseMessages"] = field(default_factory=list)
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ from . import BaseDataModel
|
|||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class LLMGenerationDataModel(BaseDataModel):
|
class LLMGenerationDataModel(BaseDataModel):
|
||||||
content: Optional[str] = None
|
content: Optional[str] = None
|
||||||
|
|||||||
@@ -1,36 +0,0 @@
|
|||||||
from dataclasses import dataclass, field
|
|
||||||
from typing import Optional, TYPE_CHECKING
|
|
||||||
|
|
||||||
from . import BaseDataModel
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class MessageAndActionModel(BaseDataModel):
|
|
||||||
chat_id: str = field(default_factory=str)
|
|
||||||
time: float = field(default_factory=float)
|
|
||||||
user_id: str = field(default_factory=str)
|
|
||||||
user_platform: str = field(default_factory=str)
|
|
||||||
user_nickname: str = field(default_factory=str)
|
|
||||||
user_cardname: Optional[str] = None
|
|
||||||
processed_plain_text: Optional[str] = None
|
|
||||||
display_message: Optional[str] = None
|
|
||||||
chat_info_platform: str = field(default_factory=str)
|
|
||||||
is_action_record: bool = field(default=False)
|
|
||||||
action_name: Optional[str] = None
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_DatabaseMessages(cls, message: "DatabaseMessages"):
|
|
||||||
return cls(
|
|
||||||
chat_id=message.chat_id,
|
|
||||||
time=message.time,
|
|
||||||
user_id=message.user_info.user_id,
|
|
||||||
user_platform=message.user_info.platform,
|
|
||||||
user_nickname=message.user_info.user_nickname,
|
|
||||||
user_cardname=message.user_info.user_cardname,
|
|
||||||
processed_plain_text=message.processed_plain_text,
|
|
||||||
display_message=message.display_message,
|
|
||||||
chat_info_platform=message.chat_info.platform,
|
|
||||||
)
|
|
||||||
373
src/common/data_models/message_manager_data_model.py
Normal file
373
src/common/data_models/message_manager_data_model.py
Normal file
@@ -0,0 +1,373 @@
|
|||||||
|
"""
|
||||||
|
消息管理模块数据模型
|
||||||
|
定义消息管理器使用的数据结构
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import time
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from enum import Enum
|
||||||
|
from typing import List, Optional, TYPE_CHECKING
|
||||||
|
|
||||||
|
from . import BaseDataModel
|
||||||
|
from src.plugin_system.base.component_types import ChatMode, ChatType
|
||||||
|
from src.common.logger import get_logger
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from .database_data_model import DatabaseMessages
|
||||||
|
|
||||||
|
logger = get_logger("stream_context")
|
||||||
|
|
||||||
|
|
||||||
|
class MessageStatus(Enum):
|
||||||
|
"""消息状态枚举"""
|
||||||
|
|
||||||
|
UNREAD = "unread" # 未读消息
|
||||||
|
READ = "read" # 已读消息
|
||||||
|
PROCESSING = "processing" # 处理中
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class StreamContext(BaseDataModel):
|
||||||
|
"""聊天流上下文信息"""
|
||||||
|
|
||||||
|
stream_id: str
|
||||||
|
chat_type: ChatType = ChatType.PRIVATE # 聊天类型,默认为私聊
|
||||||
|
chat_mode: ChatMode = ChatMode.NORMAL # 聊天模式,默认为普通模式
|
||||||
|
unread_messages: List["DatabaseMessages"] = field(default_factory=list)
|
||||||
|
history_messages: List["DatabaseMessages"] = field(default_factory=list)
|
||||||
|
last_check_time: float = field(default_factory=time.time)
|
||||||
|
is_active: bool = True
|
||||||
|
processing_task: Optional[asyncio.Task] = None
|
||||||
|
interruption_count: int = 0 # 打断计数器
|
||||||
|
last_interruption_time: float = 0.0 # 上次打断时间
|
||||||
|
afc_threshold_adjustment: float = 0.0 # afc阈值调整量
|
||||||
|
|
||||||
|
# 独立分发周期字段
|
||||||
|
next_check_time: float = field(default_factory=time.time) # 下次检查时间
|
||||||
|
distribution_interval: float = 5.0 # 当前分发周期(秒)
|
||||||
|
|
||||||
|
# 新增字段以替代ChatMessageContext功能
|
||||||
|
current_message: Optional["DatabaseMessages"] = None
|
||||||
|
priority_mode: Optional[str] = None
|
||||||
|
priority_info: Optional[dict] = None
|
||||||
|
|
||||||
|
def add_message(self, message: "DatabaseMessages"):
|
||||||
|
"""添加消息到上下文"""
|
||||||
|
message.is_read = False
|
||||||
|
self.unread_messages.append(message)
|
||||||
|
|
||||||
|
# 自动检测和更新chat type
|
||||||
|
self._detect_chat_type(message)
|
||||||
|
|
||||||
|
def update_message_info(
|
||||||
|
self, message_id: str, interest_value: float = None, actions: list = None, should_reply: bool = None
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
更新消息信息
|
||||||
|
|
||||||
|
Args:
|
||||||
|
message_id: 消息ID
|
||||||
|
interest_value: 兴趣度值
|
||||||
|
actions: 执行的动作列表
|
||||||
|
should_reply: 是否应该回复
|
||||||
|
"""
|
||||||
|
# 在未读消息中查找并更新
|
||||||
|
for message in self.unread_messages:
|
||||||
|
if message.message_id == message_id:
|
||||||
|
message.update_message_info(interest_value, actions, should_reply)
|
||||||
|
break
|
||||||
|
|
||||||
|
# 在历史消息中查找并更新
|
||||||
|
for message in self.history_messages:
|
||||||
|
if message.message_id == message_id:
|
||||||
|
message.update_message_info(interest_value, actions, should_reply)
|
||||||
|
break
|
||||||
|
|
||||||
|
def add_action_to_message(self, message_id: str, action: str):
|
||||||
|
"""
|
||||||
|
向指定消息添加执行的动作
|
||||||
|
|
||||||
|
Args:
|
||||||
|
message_id: 消息ID
|
||||||
|
action: 要添加的动作名称
|
||||||
|
"""
|
||||||
|
# 在未读消息中查找并更新
|
||||||
|
for message in self.unread_messages:
|
||||||
|
if message.message_id == message_id:
|
||||||
|
message.add_action(action)
|
||||||
|
break
|
||||||
|
|
||||||
|
# 在历史消息中查找并更新
|
||||||
|
for message in self.history_messages:
|
||||||
|
if message.message_id == message_id:
|
||||||
|
message.add_action(action)
|
||||||
|
break
|
||||||
|
|
||||||
|
def _detect_chat_type(self, message: "DatabaseMessages"):
|
||||||
|
"""根据消息内容自动检测聊天类型"""
|
||||||
|
# 只有在第一次添加消息时才检测聊天类型,避免后续消息改变类型
|
||||||
|
if len(self.unread_messages) == 1: # 只有这条消息
|
||||||
|
# 如果消息包含群组信息,则为群聊
|
||||||
|
if hasattr(message, "chat_info_group_id") and message.chat_info_group_id:
|
||||||
|
self.chat_type = ChatType.GROUP
|
||||||
|
elif hasattr(message, "chat_info_group_name") and message.chat_info_group_name:
|
||||||
|
self.chat_type = ChatType.GROUP
|
||||||
|
else:
|
||||||
|
self.chat_type = ChatType.PRIVATE
|
||||||
|
|
||||||
|
def update_chat_type(self, chat_type: ChatType):
|
||||||
|
"""手动更新聊天类型"""
|
||||||
|
self.chat_type = chat_type
|
||||||
|
|
||||||
|
def set_chat_mode(self, chat_mode: ChatMode):
|
||||||
|
"""设置聊天模式"""
|
||||||
|
self.chat_mode = chat_mode
|
||||||
|
|
||||||
|
def is_group_chat(self) -> bool:
|
||||||
|
"""检查是否为群聊"""
|
||||||
|
return self.chat_type == ChatType.GROUP
|
||||||
|
|
||||||
|
def is_private_chat(self) -> bool:
|
||||||
|
"""检查是否为私聊"""
|
||||||
|
return self.chat_type == ChatType.PRIVATE
|
||||||
|
|
||||||
|
def get_chat_type_display(self) -> str:
|
||||||
|
"""获取聊天类型的显示名称"""
|
||||||
|
if self.chat_type == ChatType.GROUP:
|
||||||
|
return "群聊"
|
||||||
|
elif self.chat_type == ChatType.PRIVATE:
|
||||||
|
return "私聊"
|
||||||
|
else:
|
||||||
|
return "未知类型"
|
||||||
|
|
||||||
|
def mark_message_as_read(self, message_id: str):
|
||||||
|
"""标记消息为已读"""
|
||||||
|
for msg in self.unread_messages:
|
||||||
|
if msg.message_id == message_id:
|
||||||
|
msg.is_read = True
|
||||||
|
self.history_messages.append(msg)
|
||||||
|
self.unread_messages.remove(msg)
|
||||||
|
break
|
||||||
|
|
||||||
|
def get_unread_messages(self) -> List["DatabaseMessages"]:
|
||||||
|
"""获取未读消息"""
|
||||||
|
return [msg for msg in self.unread_messages if not msg.is_read]
|
||||||
|
|
||||||
|
def get_history_messages(self, limit: int = 20) -> List["DatabaseMessages"]:
|
||||||
|
"""获取历史消息"""
|
||||||
|
# 优先返回最近的历史消息和所有未读消息
|
||||||
|
recent_history = self.history_messages[-limit:] if len(self.history_messages) > limit else self.history_messages
|
||||||
|
return recent_history
|
||||||
|
|
||||||
|
def calculate_interruption_probability(self, max_limit: int, probability_factor: float) -> float:
|
||||||
|
"""计算打断概率"""
|
||||||
|
if max_limit <= 0:
|
||||||
|
return 0.0
|
||||||
|
|
||||||
|
# 计算打断比例
|
||||||
|
interruption_ratio = self.interruption_count / max_limit
|
||||||
|
|
||||||
|
# 如果已达到或超过最大次数,完全禁止打断
|
||||||
|
if self.interruption_count >= max_limit:
|
||||||
|
return 0.0
|
||||||
|
|
||||||
|
# 如果超过概率因子,概率下降
|
||||||
|
if interruption_ratio > probability_factor:
|
||||||
|
# 使用指数衰减,超过限制越多,概率越低
|
||||||
|
excess_ratio = interruption_ratio - probability_factor
|
||||||
|
probability = 0.8 * (0.5**excess_ratio) # 基础概率0.8,指数衰减
|
||||||
|
else:
|
||||||
|
# 在限制内,保持较高概率
|
||||||
|
probability = 0.8
|
||||||
|
|
||||||
|
return max(0.0, min(1.0, probability))
|
||||||
|
|
||||||
|
def increment_interruption_count(self):
|
||||||
|
"""增加打断计数"""
|
||||||
|
self.interruption_count += 1
|
||||||
|
self.last_interruption_time = time.time()
|
||||||
|
|
||||||
|
# 同步打断计数到ChatStream
|
||||||
|
self._sync_interruption_count_to_stream()
|
||||||
|
|
||||||
|
def reset_interruption_count(self):
|
||||||
|
"""重置打断计数和afc阈值调整"""
|
||||||
|
self.interruption_count = 0
|
||||||
|
self.last_interruption_time = 0.0
|
||||||
|
self.afc_threshold_adjustment = 0.0
|
||||||
|
|
||||||
|
# 同步打断计数到ChatStream
|
||||||
|
self._sync_interruption_count_to_stream()
|
||||||
|
|
||||||
|
def apply_interruption_afc_reduction(self, reduction_value: float):
|
||||||
|
"""应用打断导致的afc阈值降低"""
|
||||||
|
self.afc_threshold_adjustment += reduction_value
|
||||||
|
logger.debug(f"应用afc阈值降低: {reduction_value}, 总调整量: {self.afc_threshold_adjustment}")
|
||||||
|
|
||||||
|
def get_afc_threshold_adjustment(self) -> float:
|
||||||
|
"""获取当前的afc阈值调整量"""
|
||||||
|
return self.afc_threshold_adjustment
|
||||||
|
|
||||||
|
def _sync_interruption_count_to_stream(self):
|
||||||
|
"""同步打断计数到ChatStream"""
|
||||||
|
try:
|
||||||
|
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||||
|
|
||||||
|
chat_manager = get_chat_manager()
|
||||||
|
if chat_manager:
|
||||||
|
chat_stream = chat_manager.get_stream(self.stream_id)
|
||||||
|
if chat_stream and hasattr(chat_stream, "interruption_count"):
|
||||||
|
# 在这里我们只是标记需要保存,实际的保存会在下次save时进行
|
||||||
|
chat_stream.saved = False
|
||||||
|
logger.debug(
|
||||||
|
f"已同步StreamContext {self.stream_id} 的打断计数 {self.interruption_count} 到ChatStream"
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"同步打断计数到ChatStream失败: {e}")
|
||||||
|
|
||||||
|
def set_current_message(self, message: "DatabaseMessages"):
|
||||||
|
"""设置当前消息"""
|
||||||
|
self.current_message = message
|
||||||
|
|
||||||
|
def get_template_name(self) -> Optional[str]:
|
||||||
|
"""获取模板名称"""
|
||||||
|
if (
|
||||||
|
self.current_message
|
||||||
|
and hasattr(self.current_message, "additional_config")
|
||||||
|
and self.current_message.additional_config
|
||||||
|
):
|
||||||
|
try:
|
||||||
|
import json
|
||||||
|
|
||||||
|
config = json.loads(self.current_message.additional_config)
|
||||||
|
if config.get("template_info") and not config.get("template_default", True):
|
||||||
|
return config.get("template_name")
|
||||||
|
except (json.JSONDecodeError, AttributeError):
|
||||||
|
pass
|
||||||
|
return None
|
||||||
|
|
||||||
|
def get_last_message(self) -> Optional["DatabaseMessages"]:
|
||||||
|
"""获取最后一条消息"""
|
||||||
|
if self.current_message:
|
||||||
|
return self.current_message
|
||||||
|
if self.unread_messages:
|
||||||
|
return self.unread_messages[-1]
|
||||||
|
if self.history_messages:
|
||||||
|
return self.history_messages[-1]
|
||||||
|
return None
|
||||||
|
|
||||||
|
def check_types(self, types: list) -> bool:
|
||||||
|
"""
|
||||||
|
检查当前消息是否支持指定的类型
|
||||||
|
|
||||||
|
Args:
|
||||||
|
types: 需要检查的消息类型列表,如 ["text", "image", "emoji"]
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: 如果消息支持所有指定的类型则返回True,否则返回False
|
||||||
|
"""
|
||||||
|
if not self.current_message:
|
||||||
|
return False
|
||||||
|
|
||||||
|
if not types:
|
||||||
|
# 如果没有指定类型要求,默认为支持
|
||||||
|
return True
|
||||||
|
|
||||||
|
# 优先从additional_config中获取format_info
|
||||||
|
if hasattr(self.current_message, "additional_config") and self.current_message.additional_config:
|
||||||
|
try:
|
||||||
|
import orjson
|
||||||
|
|
||||||
|
config = orjson.loads(self.current_message.additional_config)
|
||||||
|
|
||||||
|
# 检查format_info结构
|
||||||
|
if "format_info" in config:
|
||||||
|
format_info = config["format_info"]
|
||||||
|
|
||||||
|
# 方法1: 直接检查accept_format字段
|
||||||
|
if "accept_format" in format_info:
|
||||||
|
accept_format = format_info["accept_format"]
|
||||||
|
# 确保accept_format是列表类型
|
||||||
|
if isinstance(accept_format, str):
|
||||||
|
accept_format = [accept_format]
|
||||||
|
elif isinstance(accept_format, list):
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
# 如果accept_format不是字符串或列表,尝试转换为列表
|
||||||
|
accept_format = list(accept_format) if hasattr(accept_format, "__iter__") else []
|
||||||
|
|
||||||
|
# 检查所有请求的类型是否都被支持
|
||||||
|
for requested_type in types:
|
||||||
|
if requested_type not in accept_format:
|
||||||
|
logger.debug(f"消息不支持类型 '{requested_type}',支持的类型: {accept_format}")
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
# 方法2: 检查content_format字段(向后兼容)
|
||||||
|
elif "content_format" in format_info:
|
||||||
|
content_format = format_info["content_format"]
|
||||||
|
# 确保content_format是列表类型
|
||||||
|
if isinstance(content_format, str):
|
||||||
|
content_format = [content_format]
|
||||||
|
elif isinstance(content_format, list):
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
content_format = list(content_format) if hasattr(content_format, "__iter__") else []
|
||||||
|
|
||||||
|
# 检查所有请求的类型是否都被支持
|
||||||
|
for requested_type in types:
|
||||||
|
if requested_type not in content_format:
|
||||||
|
logger.debug(f"消息不支持类型 '{requested_type}',支持的内容格式: {content_format}")
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
except (orjson.JSONDecodeError, AttributeError, TypeError) as e:
|
||||||
|
logger.debug(f"解析消息格式信息失败: {e}")
|
||||||
|
|
||||||
|
# 备用方案:如果无法从additional_config获取格式信息,使用默认支持的类型
|
||||||
|
# 大多数消息至少支持text类型
|
||||||
|
default_supported_types = ["text", "emoji"]
|
||||||
|
for requested_type in types:
|
||||||
|
if requested_type not in default_supported_types:
|
||||||
|
logger.debug(f"使用默认类型检查,消息可能不支持类型 '{requested_type}'")
|
||||||
|
# 对于非基础类型,返回False以避免错误
|
||||||
|
if requested_type not in ["text", "emoji", "reply"]:
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
def get_priority_mode(self) -> Optional[str]:
|
||||||
|
"""获取优先级模式"""
|
||||||
|
return self.priority_mode
|
||||||
|
|
||||||
|
def get_priority_info(self) -> Optional[dict]:
|
||||||
|
"""获取优先级信息"""
|
||||||
|
return self.priority_info
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class MessageManagerStats(BaseDataModel):
|
||||||
|
"""消息管理器统计信息"""
|
||||||
|
|
||||||
|
total_streams: int = 0
|
||||||
|
active_streams: int = 0
|
||||||
|
total_unread_messages: int = 0
|
||||||
|
total_processed_messages: int = 0
|
||||||
|
start_time: float = field(default_factory=time.time)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def uptime(self) -> float:
|
||||||
|
"""运行时间"""
|
||||||
|
return time.time() - self.start_time
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class StreamStats(BaseDataModel):
|
||||||
|
"""聊天流统计信息"""
|
||||||
|
|
||||||
|
stream_id: str
|
||||||
|
is_active: bool
|
||||||
|
unread_count: int
|
||||||
|
history_count: int
|
||||||
|
last_check_time: float
|
||||||
|
has_active_task: bool
|
||||||
@@ -30,6 +30,7 @@ from src.common.database.sqlalchemy_models import (
|
|||||||
Schedule,
|
Schedule,
|
||||||
MaiZoneScheduleStatus,
|
MaiZoneScheduleStatus,
|
||||||
CacheEntries,
|
CacheEntries,
|
||||||
|
UserRelationships,
|
||||||
)
|
)
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
|
|
||||||
@@ -54,6 +55,7 @@ MODEL_MAPPING = {
|
|||||||
"Schedule": Schedule,
|
"Schedule": Schedule,
|
||||||
"MaiZoneScheduleStatus": MaiZoneScheduleStatus,
|
"MaiZoneScheduleStatus": MaiZoneScheduleStatus,
|
||||||
"CacheEntries": CacheEntries,
|
"CacheEntries": CacheEntries,
|
||||||
|
"UserRelationships": UserRelationships,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -55,7 +55,17 @@ class ChatStreams(Base):
|
|||||||
user_cardname = Column(Text, nullable=True)
|
user_cardname = Column(Text, nullable=True)
|
||||||
energy_value = Column(Float, nullable=True, default=5.0)
|
energy_value = Column(Float, nullable=True, default=5.0)
|
||||||
sleep_pressure = Column(Float, nullable=True, default=0.0)
|
sleep_pressure = Column(Float, nullable=True, default=0.0)
|
||||||
focus_energy = Column(Float, nullable=True, default=1.0)
|
focus_energy = Column(Float, nullable=True, default=0.5)
|
||||||
|
# 动态兴趣度系统字段
|
||||||
|
base_interest_energy = Column(Float, nullable=True, default=0.5)
|
||||||
|
message_interest_total = Column(Float, nullable=True, default=0.0)
|
||||||
|
message_count = Column(Integer, nullable=True, default=0)
|
||||||
|
action_count = Column(Integer, nullable=True, default=0)
|
||||||
|
reply_count = Column(Integer, nullable=True, default=0)
|
||||||
|
last_interaction_time = Column(Float, nullable=True, default=None)
|
||||||
|
consecutive_no_reply = Column(Integer, nullable=True, default=0)
|
||||||
|
# 消息打断系统字段
|
||||||
|
interruption_count = Column(Integer, nullable=True, default=0)
|
||||||
|
|
||||||
__table_args__ = (
|
__table_args__ = (
|
||||||
Index("idx_chatstreams_stream_id", "stream_id"),
|
Index("idx_chatstreams_stream_id", "stream_id"),
|
||||||
@@ -165,11 +175,16 @@ class Messages(Base):
|
|||||||
is_command = Column(Boolean, nullable=False, default=False)
|
is_command = Column(Boolean, nullable=False, default=False)
|
||||||
is_notify = Column(Boolean, nullable=False, default=False)
|
is_notify = Column(Boolean, nullable=False, default=False)
|
||||||
|
|
||||||
|
# 兴趣度系统字段
|
||||||
|
actions = Column(Text, nullable=True) # JSON格式存储动作列表
|
||||||
|
should_reply = Column(Boolean, nullable=True, default=False)
|
||||||
|
|
||||||
__table_args__ = (
|
__table_args__ = (
|
||||||
Index("idx_messages_message_id", "message_id"),
|
Index("idx_messages_message_id", "message_id"),
|
||||||
Index("idx_messages_chat_id", "chat_id"),
|
Index("idx_messages_chat_id", "chat_id"),
|
||||||
Index("idx_messages_time", "time"),
|
Index("idx_messages_time", "time"),
|
||||||
Index("idx_messages_user_id", "user_id"),
|
Index("idx_messages_user_id", "user_id"),
|
||||||
|
Index("idx_messages_should_reply", "should_reply"),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -300,6 +315,26 @@ class PersonInfo(Base):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class BotPersonalityInterests(Base):
|
||||||
|
"""机器人人格兴趣标签模型"""
|
||||||
|
|
||||||
|
__tablename__ = "bot_personality_interests"
|
||||||
|
|
||||||
|
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||||
|
personality_id = Column(get_string_field(100), nullable=False, index=True)
|
||||||
|
personality_description = Column(Text, nullable=False)
|
||||||
|
interest_tags = Column(Text, nullable=False) # JSON格式存储的兴趣标签列表
|
||||||
|
embedding_model = Column(get_string_field(100), nullable=False, default="text-embedding-ada-002")
|
||||||
|
version = Column(Integer, nullable=False, default=1)
|
||||||
|
last_updated = Column(DateTime, nullable=False, default=datetime.datetime.now, index=True)
|
||||||
|
|
||||||
|
__table_args__ = (
|
||||||
|
Index("idx_botpersonality_personality_id", "personality_id"),
|
||||||
|
Index("idx_botpersonality_version", "version"),
|
||||||
|
Index("idx_botpersonality_last_updated", "last_updated"),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class Memory(Base):
|
class Memory(Base):
|
||||||
"""记忆模型"""
|
"""记忆模型"""
|
||||||
|
|
||||||
@@ -722,3 +757,23 @@ class UserPermissions(Base):
|
|||||||
Index("idx_user_permission", "platform", "user_id", "permission_node"),
|
Index("idx_user_permission", "platform", "user_id", "permission_node"),
|
||||||
Index("idx_permission_granted", "permission_node", "granted"),
|
Index("idx_permission_granted", "permission_node", "granted"),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class UserRelationships(Base):
|
||||||
|
"""用户关系模型 - 存储用户与bot的关系数据"""
|
||||||
|
|
||||||
|
__tablename__ = "user_relationships"
|
||||||
|
|
||||||
|
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||||
|
user_id = Column(get_string_field(100), nullable=False, unique=True, index=True) # 用户ID
|
||||||
|
user_name = Column(get_string_field(100), nullable=True) # 用户名
|
||||||
|
relationship_text = Column(Text, nullable=True) # 关系印象描述
|
||||||
|
relationship_score = Column(Float, nullable=False, default=0.3) # 关系分数(0-1)
|
||||||
|
last_updated = Column(Float, nullable=False, default=time.time) # 最后更新时间
|
||||||
|
created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False) # 创建时间
|
||||||
|
|
||||||
|
__table_args__ = (
|
||||||
|
Index("idx_user_relationship_id", "user_id"),
|
||||||
|
Index("idx_relationship_score", "relationship_score"),
|
||||||
|
Index("idx_relationship_updated", "last_updated"),
|
||||||
|
)
|
||||||
|
|||||||
@@ -350,6 +350,10 @@ MODULE_COLORS = {
|
|||||||
"memory": "\033[38;5;117m", # 天蓝色
|
"memory": "\033[38;5;117m", # 天蓝色
|
||||||
"hfc": "\033[38;5;81m", # 稍微暗一些的青色,保持可读
|
"hfc": "\033[38;5;81m", # 稍微暗一些的青色,保持可读
|
||||||
"action_manager": "\033[38;5;208m", # 橙色,不与replyer重复
|
"action_manager": "\033[38;5;208m", # 橙色,不与replyer重复
|
||||||
|
"message_manager": "\033[38;5;27m", # 深蓝色,消息管理器
|
||||||
|
"chatter_manager": "\033[38;5;129m", # 紫色,聊天管理器
|
||||||
|
"chatter_interest_scoring": "\033[38;5;214m", # 橙黄色,兴趣评分
|
||||||
|
"plan_executor": "\033[38;5;172m", # 橙褐色,计划执行器
|
||||||
# 关系系统
|
# 关系系统
|
||||||
"relation": "\033[38;5;139m", # 柔和的紫色,不刺眼
|
"relation": "\033[38;5;139m", # 柔和的紫色,不刺眼
|
||||||
# 聊天相关模块
|
# 聊天相关模块
|
||||||
@@ -551,6 +555,10 @@ MODULE_ALIASES = {
|
|||||||
"llm_models": "模型",
|
"llm_models": "模型",
|
||||||
"person_info": "人物",
|
"person_info": "人物",
|
||||||
"chat_stream": "聊天流",
|
"chat_stream": "聊天流",
|
||||||
|
"message_manager": "消息管理",
|
||||||
|
"chatter_manager": "聊天管理",
|
||||||
|
"chatter_interest_scoring": "兴趣评分",
|
||||||
|
"plan_executor": "计划执行",
|
||||||
"planner": "规划器",
|
"planner": "规划器",
|
||||||
"replyer": "言语",
|
"replyer": "言语",
|
||||||
"config": "配置",
|
"config": "配置",
|
||||||
|
|||||||
@@ -22,10 +22,15 @@ def _model_to_dict(instance: Base) -> Dict[str, Any]:
|
|||||||
"""
|
"""
|
||||||
将 SQLAlchemy 模型实例转换为字典。
|
将 SQLAlchemy 模型实例转换为字典。
|
||||||
"""
|
"""
|
||||||
|
try:
|
||||||
return {col.name: getattr(instance, col.name) for col in instance.__table__.columns}
|
return {col.name: getattr(instance, col.name) for col in instance.__table__.columns}
|
||||||
|
except Exception as e:
|
||||||
|
# 如果对象已经脱离会话,尝试从instance.__dict__中获取数据
|
||||||
|
logger.warning(f"从数据库对象获取属性失败,尝试使用__dict__: {e}")
|
||||||
|
return {col.name: instance.__dict__.get(col.name) for col in instance.__table__.columns}
|
||||||
|
|
||||||
|
|
||||||
async def find_messages(
|
def find_messages(
|
||||||
message_filter: dict[str, Any],
|
message_filter: dict[str, Any],
|
||||||
sort: Optional[List[tuple[str, int]]] = None,
|
sort: Optional[List[tuple[str, int]]] = None,
|
||||||
limit: int = 0,
|
limit: int = 0,
|
||||||
@@ -46,7 +51,7 @@ async def find_messages(
|
|||||||
消息字典列表,如果出错则返回空列表。
|
消息字典列表,如果出错则返回空列表。
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
async with get_db_session() as session:
|
with get_db_session() as session:
|
||||||
query = select(Messages)
|
query = select(Messages)
|
||||||
|
|
||||||
# 应用过滤器
|
# 应用过滤器
|
||||||
@@ -96,7 +101,7 @@ async def find_messages(
|
|||||||
# 获取时间最早的 limit 条记录,已经是正序
|
# 获取时间最早的 limit 条记录,已经是正序
|
||||||
query = query.order_by(Messages.time.asc()).limit(limit)
|
query = query.order_by(Messages.time.asc()).limit(limit)
|
||||||
try:
|
try:
|
||||||
results = (await session.execute(query)).scalars().all()
|
results = session.execute(query).scalars().all()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"执行earliest查询失败: {e}")
|
logger.error(f"执行earliest查询失败: {e}")
|
||||||
results = []
|
results = []
|
||||||
@@ -104,7 +109,7 @@ async def find_messages(
|
|||||||
# 获取时间最晚的 limit 条记录
|
# 获取时间最晚的 limit 条记录
|
||||||
query = query.order_by(Messages.time.desc()).limit(limit)
|
query = query.order_by(Messages.time.desc()).limit(limit)
|
||||||
try:
|
try:
|
||||||
latest_results = (await session.execute(query)).scalars().all()
|
latest_results = session.execute(query).scalars().all()
|
||||||
# 将结果按时间正序排列
|
# 将结果按时间正序排列
|
||||||
results = sorted(latest_results, key=lambda msg: msg.time)
|
results = sorted(latest_results, key=lambda msg: msg.time)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -128,11 +133,12 @@ async def find_messages(
|
|||||||
if sort_terms:
|
if sort_terms:
|
||||||
query = query.order_by(*sort_terms)
|
query = query.order_by(*sort_terms)
|
||||||
try:
|
try:
|
||||||
results = (await session.execute(query)).scalars().all()
|
results = session.execute(query).scalars().all()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"执行无限制查询失败: {e}")
|
logger.error(f"执行无限制查询失败: {e}")
|
||||||
results = []
|
results = []
|
||||||
|
|
||||||
|
# 在会话内将结果转换为字典,避免会话分离错误
|
||||||
return [_model_to_dict(msg) for msg in results]
|
return [_model_to_dict(msg) for msg in results]
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log_message = (
|
log_message = (
|
||||||
@@ -143,7 +149,7 @@ async def find_messages(
|
|||||||
return []
|
return []
|
||||||
|
|
||||||
|
|
||||||
async def count_messages(message_filter: dict[str, Any]) -> int:
|
def count_messages(message_filter: dict[str, Any]) -> int:
|
||||||
"""
|
"""
|
||||||
根据提供的过滤器计算消息数量。
|
根据提供的过滤器计算消息数量。
|
||||||
|
|
||||||
@@ -154,7 +160,7 @@ async def count_messages(message_filter: dict[str, Any]) -> int:
|
|||||||
符合条件的消息数量,如果出错则返回 0。
|
符合条件的消息数量,如果出错则返回 0。
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
async with get_db_session() as session:
|
with get_db_session() as session:
|
||||||
query = select(func.count(Messages.id))
|
query = select(func.count(Messages.id))
|
||||||
|
|
||||||
# 应用过滤器
|
# 应用过滤器
|
||||||
@@ -192,7 +198,7 @@ async def count_messages(message_filter: dict[str, Any]) -> int:
|
|||||||
if conditions:
|
if conditions:
|
||||||
query = query.where(*conditions)
|
query = query.where(*conditions)
|
||||||
|
|
||||||
count = (await session.execute(query)).scalar()
|
count = session.execute(query).scalar()
|
||||||
return count or 0
|
return count or 0
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log_message = f"使用 SQLAlchemy 计数消息失败 (message_filter={message_filter}): {e}\n{traceback.format_exc()}"
|
log_message = f"使用 SQLAlchemy 计数消息失败 (message_filter={message_filter}): {e}\n{traceback.format_exc()}"
|
||||||
@@ -201,5 +207,5 @@ async def count_messages(message_filter: dict[str, Any]) -> int:
|
|||||||
|
|
||||||
|
|
||||||
# 你可以在这里添加更多与 messages 集合相关的数据库操作函数,例如 find_one_message, insert_message 等。
|
# 你可以在这里添加更多与 messages 集合相关的数据库操作函数,例如 find_one_message, insert_message 等。
|
||||||
# 注意:对于 SQLAlchemy,插入操作通常是使用 session.add() 和 await session.commit()。
|
# 注意:对于 SQLAlchemy,插入操作通常是使用 session.add() 和 session.commit()。
|
||||||
# 查找单个消息可以使用 session.execute(select(Messages).where(...)).scalar_one_or_none()。
|
# 查找单个消息可以使用 session.execute(select(Messages).where(...)).scalar_one_or_none()。
|
||||||
|
|||||||
@@ -31,7 +31,9 @@ class TelemetryHeartBeatTask(AsyncTask):
|
|||||||
self.client_uuid: str | None = local_storage["mofox_uuid"] if "mofox_uuid" in local_storage else None # type: ignore
|
self.client_uuid: str | None = local_storage["mofox_uuid"] if "mofox_uuid" in local_storage else None # type: ignore
|
||||||
"""客户端UUID"""
|
"""客户端UUID"""
|
||||||
|
|
||||||
self.private_key_pem: str | None = local_storage["mofox_private_key"] if "mofox_private_key" in local_storage else None # type: ignore
|
self.private_key_pem: str | None = (
|
||||||
|
local_storage["mofox_private_key"] if "mofox_private_key" in local_storage else None
|
||||||
|
) # type: ignore
|
||||||
"""客户端私钥"""
|
"""客户端私钥"""
|
||||||
|
|
||||||
self.info_dict = self._get_sys_info()
|
self.info_dict = self._get_sys_info()
|
||||||
@@ -75,10 +77,7 @@ class TelemetryHeartBeatTask(AsyncTask):
|
|||||||
sign_data = f"{self.client_uuid}:{timestamp}:{json.dumps(request_body, separators=(',', ':'))}"
|
sign_data = f"{self.client_uuid}:{timestamp}:{json.dumps(request_body, separators=(',', ':'))}"
|
||||||
|
|
||||||
# 加载私钥
|
# 加载私钥
|
||||||
private_key = serialization.load_pem_private_key(
|
private_key = serialization.load_pem_private_key(self.private_key_pem.encode("utf-8"), password=None)
|
||||||
self.private_key_pem.encode('utf-8'),
|
|
||||||
password=None
|
|
||||||
)
|
|
||||||
|
|
||||||
# 确保是RSA私钥
|
# 确保是RSA私钥
|
||||||
if not isinstance(private_key, rsa.RSAPrivateKey):
|
if not isinstance(private_key, rsa.RSAPrivateKey):
|
||||||
@@ -86,16 +85,13 @@ class TelemetryHeartBeatTask(AsyncTask):
|
|||||||
|
|
||||||
# 生成签名
|
# 生成签名
|
||||||
signature = private_key.sign(
|
signature = private_key.sign(
|
||||||
sign_data.encode('utf-8'),
|
sign_data.encode("utf-8"),
|
||||||
padding.PSS(
|
padding.PSS(mgf=padding.MGF1(hashes.SHA256()), salt_length=padding.PSS.MAX_LENGTH),
|
||||||
mgf=padding.MGF1(hashes.SHA256()),
|
hashes.SHA256(),
|
||||||
salt_length=padding.PSS.MAX_LENGTH
|
|
||||||
),
|
|
||||||
hashes.SHA256()
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Base64编码
|
# Base64编码
|
||||||
signature_b64 = base64.b64encode(signature).decode('utf-8')
|
signature_b64 = base64.b64encode(signature).decode("utf-8")
|
||||||
|
|
||||||
return timestamp, signature_b64
|
return timestamp, signature_b64
|
||||||
|
|
||||||
@@ -113,10 +109,7 @@ class TelemetryHeartBeatTask(AsyncTask):
|
|||||||
raise ValueError("私钥未初始化")
|
raise ValueError("私钥未初始化")
|
||||||
|
|
||||||
# 加载私钥
|
# 加载私钥
|
||||||
private_key = serialization.load_pem_private_key(
|
private_key = serialization.load_pem_private_key(self.private_key_pem.encode("utf-8"), password=None)
|
||||||
self.private_key_pem.encode('utf-8'),
|
|
||||||
password=None
|
|
||||||
)
|
|
||||||
|
|
||||||
# 确保是RSA私钥
|
# 确保是RSA私钥
|
||||||
if not isinstance(private_key, rsa.RSAPrivateKey):
|
if not isinstance(private_key, rsa.RSAPrivateKey):
|
||||||
@@ -125,14 +118,10 @@ class TelemetryHeartBeatTask(AsyncTask):
|
|||||||
# 解密挑战数据
|
# 解密挑战数据
|
||||||
decrypted_bytes = private_key.decrypt(
|
decrypted_bytes = private_key.decrypt(
|
||||||
base64.b64decode(challenge_b64),
|
base64.b64decode(challenge_b64),
|
||||||
padding.OAEP(
|
padding.OAEP(mgf=padding.MGF1(hashes.SHA256()), algorithm=hashes.SHA256(), label=None),
|
||||||
mgf=padding.MGF1(hashes.SHA256()),
|
|
||||||
algorithm=hashes.SHA256(),
|
|
||||||
label=None
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return decrypted_bytes.decode('utf-8')
|
return decrypted_bytes.decode("utf-8")
|
||||||
|
|
||||||
async def _req_uuid(self) -> bool:
|
async def _req_uuid(self) -> bool:
|
||||||
"""
|
"""
|
||||||
@@ -155,14 +144,12 @@ class TelemetryHeartBeatTask(AsyncTask):
|
|||||||
|
|
||||||
if response.status != 200:
|
if response.status != 200:
|
||||||
response_text = await response.text()
|
response_text = await response.text()
|
||||||
logger.error(
|
logger.error(f"注册步骤1失败,状态码: {response.status}, 响应内容: {response_text}")
|
||||||
f"注册步骤1失败,状态码: {response.status}, 响应内容: {response_text}"
|
|
||||||
)
|
|
||||||
raise aiohttp.ClientResponseError(
|
raise aiohttp.ClientResponseError(
|
||||||
request_info=response.request_info,
|
request_info=response.request_info,
|
||||||
history=response.history,
|
history=response.history,
|
||||||
status=response.status,
|
status=response.status,
|
||||||
message=f"Step1 failed: {response_text}"
|
message=f"Step1 failed: {response_text}",
|
||||||
)
|
)
|
||||||
|
|
||||||
step1_data = await response.json()
|
step1_data = await response.json()
|
||||||
@@ -195,10 +182,7 @@ class TelemetryHeartBeatTask(AsyncTask):
|
|||||||
# Step 2: 发送解密结果完成注册
|
# Step 2: 发送解密结果完成注册
|
||||||
async with session.post(
|
async with session.post(
|
||||||
f"{TELEMETRY_SERVER_URL}/stat/reg_client_step2",
|
f"{TELEMETRY_SERVER_URL}/stat/reg_client_step2",
|
||||||
json={
|
json={"temp_uuid": temp_uuid, "decrypted_uuid": decrypted_uuid},
|
||||||
"temp_uuid": temp_uuid,
|
|
||||||
"decrypted_uuid": decrypted_uuid
|
|
||||||
},
|
|
||||||
timeout=aiohttp.ClientTimeout(total=5),
|
timeout=aiohttp.ClientTimeout(total=5),
|
||||||
) as response:
|
) as response:
|
||||||
logger.debug(f"Step2 Response status: {response.status}")
|
logger.debug(f"Step2 Response status: {response.status}")
|
||||||
@@ -225,23 +209,19 @@ class TelemetryHeartBeatTask(AsyncTask):
|
|||||||
raise ValueError(f"Step2失败: {response_text}")
|
raise ValueError(f"Step2失败: {response_text}")
|
||||||
else:
|
else:
|
||||||
response_text = await response.text()
|
response_text = await response.text()
|
||||||
logger.error(
|
logger.error(f"注册步骤2失败,状态码: {response.status}, 响应内容: {response_text}")
|
||||||
f"注册步骤2失败,状态码: {response.status}, 响应内容: {response_text}"
|
|
||||||
)
|
|
||||||
raise aiohttp.ClientResponseError(
|
raise aiohttp.ClientResponseError(
|
||||||
request_info=response.request_info,
|
request_info=response.request_info,
|
||||||
history=response.history,
|
history=response.history,
|
||||||
status=response.status,
|
status=response.status,
|
||||||
message=f"Step2 failed: {response_text}"
|
message=f"Step2 failed: {response_text}",
|
||||||
)
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
error_msg = str(e) or "未知错误"
|
error_msg = str(e) or "未知错误"
|
||||||
logger.warning(
|
logger.warning(f"注册客户端出错,不过你还是可以正常使用墨狐: {type(e).__name__}: {error_msg}")
|
||||||
f"注册客户端出错,不过你还是可以正常使用墨狐: {type(e).__name__}: {error_msg}"
|
|
||||||
)
|
|
||||||
logger.debug(f"完整错误信息: {traceback.format_exc()}")
|
logger.debug(f"完整错误信息: {traceback.format_exc()}")
|
||||||
|
|
||||||
# 请求失败,重试次数+1
|
# 请求失败,重试次数+1
|
||||||
@@ -270,7 +250,7 @@ class TelemetryHeartBeatTask(AsyncTask):
|
|||||||
"X-mofox-Signature": signature,
|
"X-mofox-Signature": signature,
|
||||||
"X-mofox-Timestamp": timestamp,
|
"X-mofox-Timestamp": timestamp,
|
||||||
"User-Agent": f"MofoxClient/{self.client_uuid[:8]}",
|
"User-Agent": f"MofoxClient/{self.client_uuid[:8]}",
|
||||||
"Content-Type": "application/json"
|
"Content-Type": "application/json",
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.debug(f"正在发送心跳到服务器: {self.server_url}")
|
logger.debug(f"正在发送心跳到服务器: {self.server_url}")
|
||||||
|
|||||||
@@ -99,7 +99,6 @@ def get_global_server() -> Server:
|
|||||||
"""获取全局服务器实例"""
|
"""获取全局服务器实例"""
|
||||||
global global_server
|
global global_server
|
||||||
if global_server is None:
|
if global_server is None:
|
||||||
|
|
||||||
host = os.getenv("HOST", "127.0.0.1")
|
host = os.getenv("HOST", "127.0.0.1")
|
||||||
port_str = os.getenv("PORT", "8000")
|
port_str = os.getenv("PORT", "8000")
|
||||||
|
|
||||||
|
|||||||
@@ -137,7 +137,7 @@ class ModelTaskConfig(ValidatedConfigBase):
|
|||||||
monthly_plan_generator: TaskConfig = Field(..., description="月层计划生成模型配置")
|
monthly_plan_generator: TaskConfig = Field(..., description="月层计划生成模型配置")
|
||||||
emoji_vlm: TaskConfig = Field(..., description="表情包识别模型配置")
|
emoji_vlm: TaskConfig = Field(..., description="表情包识别模型配置")
|
||||||
anti_injection: TaskConfig = Field(..., description="反注入检测专用模型配置")
|
anti_injection: TaskConfig = Field(..., description="反注入检测专用模型配置")
|
||||||
|
relationship_tracker: TaskConfig = Field(..., description="关系追踪模型配置")
|
||||||
# 处理配置文件中命名不一致的问题
|
# 处理配置文件中命名不一致的问题
|
||||||
utils_video: TaskConfig = Field(..., description="视频分析模型配置(兼容配置文件中的命名)")
|
utils_video: TaskConfig = Field(..., description="视频分析模型配置(兼容配置文件中的命名)")
|
||||||
|
|
||||||
|
|||||||
@@ -43,7 +43,8 @@ from src.config.official_configs import (
|
|||||||
CrossContextConfig,
|
CrossContextConfig,
|
||||||
PermissionConfig,
|
PermissionConfig,
|
||||||
CommandConfig,
|
CommandConfig,
|
||||||
PlanningSystemConfig
|
PlanningSystemConfig,
|
||||||
|
AffinityFlowConfig,
|
||||||
)
|
)
|
||||||
|
|
||||||
from .api_ada_configs import (
|
from .api_ada_configs import (
|
||||||
@@ -66,7 +67,7 @@ TEMPLATE_DIR = os.path.join(PROJECT_ROOT, "template")
|
|||||||
|
|
||||||
# 考虑到,实际上配置文件中的mai_version是不会自动更新的,所以采用硬编码
|
# 考虑到,实际上配置文件中的mai_version是不会自动更新的,所以采用硬编码
|
||||||
# 对该字段的更新,请严格参照语义化版本规范:https://semver.org/lang/zh-CN/
|
# 对该字段的更新,请严格参照语义化版本规范:https://semver.org/lang/zh-CN/
|
||||||
MMC_VERSION = "0.10.0-alpha-2"
|
MMC_VERSION = "0.11.0-alpha-1"
|
||||||
|
|
||||||
|
|
||||||
def get_key_comment(toml_table, key):
|
def get_key_comment(toml_table, key):
|
||||||
@@ -417,6 +418,7 @@ class Config(ValidatedConfigBase):
|
|||||||
cross_context: CrossContextConfig = Field(
|
cross_context: CrossContextConfig = Field(
|
||||||
default_factory=lambda: CrossContextConfig(), description="跨群聊上下文共享配置"
|
default_factory=lambda: CrossContextConfig(), description="跨群聊上下文共享配置"
|
||||||
)
|
)
|
||||||
|
affinity_flow: AffinityFlowConfig = Field(default_factory=lambda: AffinityFlowConfig(), description="亲和流配置")
|
||||||
|
|
||||||
|
|
||||||
class APIAdapterConfig(ValidatedConfigBase):
|
class APIAdapterConfig(ValidatedConfigBase):
|
||||||
|
|||||||
@@ -51,8 +51,12 @@ class PersonalityConfig(ValidatedConfigBase):
|
|||||||
personality_core: str = Field(..., description="核心人格")
|
personality_core: str = Field(..., description="核心人格")
|
||||||
personality_side: str = Field(..., description="人格侧写")
|
personality_side: str = Field(..., description="人格侧写")
|
||||||
identity: str = Field(default="", description="身份特征")
|
identity: str = Field(default="", description="身份特征")
|
||||||
background_story: str = Field(default="", description="世界观背景故事,这部分内容会作为背景知识,LLM被指导不应主动复述")
|
background_story: str = Field(
|
||||||
safety_guidelines: List[str] = Field(default_factory=list, description="安全与互动底线,Bot在任何情况下都必须遵守的原则")
|
default="", description="世界观背景故事,这部分内容会作为背景知识,LLM被指导不应主动复述"
|
||||||
|
)
|
||||||
|
safety_guidelines: List[str] = Field(
|
||||||
|
default_factory=list, description="安全与互动底线,Bot在任何情况下都必须遵守的原则"
|
||||||
|
)
|
||||||
reply_style: str = Field(default="", description="表达风格")
|
reply_style: str = Field(default="", description="表达风格")
|
||||||
prompt_mode: Literal["s4u", "normal"] = Field(default="s4u", description="Prompt模式")
|
prompt_mode: Literal["s4u", "normal"] = Field(default="s4u", description="Prompt模式")
|
||||||
compress_personality: bool = Field(default=True, description="是否压缩人格")
|
compress_personality: bool = Field(default=True, description="是否压缩人格")
|
||||||
@@ -109,7 +113,8 @@ class ChatConfig(ValidatedConfigBase):
|
|||||||
talk_frequency_adjust: list[list[str]] = Field(default_factory=lambda: [], description="聊天频率调整")
|
talk_frequency_adjust: list[list[str]] = Field(default_factory=lambda: [], description="聊天频率调整")
|
||||||
focus_value: float = Field(default=1.0, description="专注值")
|
focus_value: float = Field(default=1.0, description="专注值")
|
||||||
focus_mode_quiet_groups: List[str] = Field(
|
focus_mode_quiet_groups: List[str] = Field(
|
||||||
default_factory=list, description='专注模式下需要保持安静的群组列表, 格式: ["platform:group_id1", "platform:group_id2"]'
|
default_factory=list,
|
||||||
|
description='专注模式下需要保持安静的群组列表, 格式: ["platform:group_id1", "platform:group_id2"]',
|
||||||
)
|
)
|
||||||
force_reply_private: bool = Field(default=False, description="强制回复私聊")
|
force_reply_private: bool = Field(default=False, description="强制回复私聊")
|
||||||
group_chat_mode: Literal["auto", "normal", "focus"] = Field(default="auto", description="群聊模式")
|
group_chat_mode: Literal["auto", "normal", "focus"] = Field(default="auto", description="群聊模式")
|
||||||
@@ -129,6 +134,31 @@ class ChatConfig(ValidatedConfigBase):
|
|||||||
)
|
)
|
||||||
delta_sigma: int = Field(default=120, description="采用正态分布随机时间间隔")
|
delta_sigma: int = Field(default=120, description="采用正态分布随机时间间隔")
|
||||||
|
|
||||||
|
# 消息打断系统配置
|
||||||
|
interruption_enabled: bool = Field(default=True, description="是否启用消息打断系统")
|
||||||
|
interruption_max_limit: int = Field(default=3, ge=0, description="每个聊天流的最大打断次数")
|
||||||
|
interruption_probability_factor: float = Field(
|
||||||
|
default=0.8, ge=0.0, le=1.0, description="打断概率因子,当前打断次数/最大打断次数超过此值时触发概率下降"
|
||||||
|
)
|
||||||
|
interruption_afc_reduction: float = Field(
|
||||||
|
default=0.05, ge=0.0, le=1.0, description="每次连续打断降低的afc阈值数值"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 动态消息分发系统配置
|
||||||
|
dynamic_distribution_enabled: bool = Field(default=True, description="是否启用动态消息分发周期调整")
|
||||||
|
dynamic_distribution_base_interval: float = Field(
|
||||||
|
default=5.0, ge=1.0, le=60.0, description="基础分发间隔(秒)"
|
||||||
|
)
|
||||||
|
dynamic_distribution_min_interval: float = Field(
|
||||||
|
default=1.0, ge=0.5, le=10.0, description="最小分发间隔(秒)"
|
||||||
|
)
|
||||||
|
dynamic_distribution_max_interval: float = Field(
|
||||||
|
default=30.0, ge=5.0, le=300.0, description="最大分发间隔(秒)"
|
||||||
|
)
|
||||||
|
dynamic_distribution_jitter_factor: float = Field(
|
||||||
|
default=0.2, ge=0.0, le=0.5, description="分发间隔随机扰动因子"
|
||||||
|
)
|
||||||
|
|
||||||
def get_current_talk_frequency(self, chat_stream_id: Optional[str] = None) -> float:
|
def get_current_talk_frequency(self, chat_stream_id: Optional[str] = None) -> float:
|
||||||
"""
|
"""
|
||||||
根据当前时间和聊天流获取对应的 talk_frequency
|
根据当前时间和聊天流获取对应的 talk_frequency
|
||||||
@@ -376,6 +406,7 @@ class ExpressionConfig(ValidatedConfigBase):
|
|||||||
# 如果都没有匹配,返回默认值
|
# 如果都没有匹配,返回默认值
|
||||||
return True, True, 1.0
|
return True, True, 1.0
|
||||||
|
|
||||||
|
|
||||||
class ToolConfig(ValidatedConfigBase):
|
class ToolConfig(ValidatedConfigBase):
|
||||||
"""工具配置类"""
|
"""工具配置类"""
|
||||||
|
|
||||||
@@ -510,7 +541,6 @@ class ExperimentalConfig(ValidatedConfigBase):
|
|||||||
pfc_chatting: bool = Field(default=False, description="启用PFC聊天")
|
pfc_chatting: bool = Field(default=False, description="启用PFC聊天")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class MaimMessageConfig(ValidatedConfigBase):
|
class MaimMessageConfig(ValidatedConfigBase):
|
||||||
"""maim_message配置类"""
|
"""maim_message配置类"""
|
||||||
|
|
||||||
@@ -635,8 +665,12 @@ class SleepSystemConfig(ValidatedConfigBase):
|
|||||||
sleep_by_schedule: bool = Field(default=True, description="是否根据日程表进行睡觉")
|
sleep_by_schedule: bool = Field(default=True, description="是否根据日程表进行睡觉")
|
||||||
fixed_sleep_time: str = Field(default="23:00", description="固定的睡觉时间")
|
fixed_sleep_time: str = Field(default="23:00", description="固定的睡觉时间")
|
||||||
fixed_wake_up_time: str = Field(default="07:00", description="固定的起床时间")
|
fixed_wake_up_time: str = Field(default="07:00", description="固定的起床时间")
|
||||||
sleep_time_offset_minutes: int = Field(default=15, ge=0, le=60, description="睡觉时间随机偏移量范围(分钟),实际睡觉时间会在±该值范围内随机")
|
sleep_time_offset_minutes: int = Field(
|
||||||
wake_up_time_offset_minutes: int = Field(default=15, ge=0, le=60, description="起床时间随机偏移量范围(分钟),实际起床时间会在±该值范围内随机")
|
default=15, ge=0, le=60, description="睡觉时间随机偏移量范围(分钟),实际睡觉时间会在±该值范围内随机"
|
||||||
|
)
|
||||||
|
wake_up_time_offset_minutes: int = Field(
|
||||||
|
default=15, ge=0, le=60, description="起床时间随机偏移量范围(分钟),实际起床时间会在±该值范围内随机"
|
||||||
|
)
|
||||||
wakeup_threshold: float = Field(default=15.0, ge=1.0, description="唤醒阈值,达到此值时会被唤醒")
|
wakeup_threshold: float = Field(default=15.0, ge=1.0, description="唤醒阈值,达到此值时会被唤醒")
|
||||||
private_message_increment: float = Field(default=3.0, ge=0.1, description="私聊消息增加的唤醒度")
|
private_message_increment: float = Field(default=3.0, ge=0.1, description="私聊消息增加的唤醒度")
|
||||||
group_mention_increment: float = Field(default=2.0, ge=0.1, description="群聊艾特增加的唤醒度")
|
group_mention_increment: float = Field(default=2.0, ge=0.1, description="群聊艾特增加的唤醒度")
|
||||||
@@ -690,6 +724,8 @@ class CrossContextConfig(ValidatedConfigBase):
|
|||||||
|
|
||||||
enable: bool = Field(default=False, description="是否启用跨群聊上下文共享功能")
|
enable: bool = Field(default=False, description="是否启用跨群聊上下文共享功能")
|
||||||
groups: List[ContextGroup] = Field(default_factory=list, description="上下文共享组列表")
|
groups: List[ContextGroup] = Field(default_factory=list, description="上下文共享组列表")
|
||||||
|
|
||||||
|
|
||||||
class CommandConfig(ValidatedConfigBase):
|
class CommandConfig(ValidatedConfigBase):
|
||||||
"""命令系统配置类"""
|
"""命令系统配置类"""
|
||||||
|
|
||||||
@@ -703,3 +739,34 @@ class PermissionConfig(ValidatedConfigBase):
|
|||||||
master_users: List[List[str]] = Field(
|
master_users: List[List[str]] = Field(
|
||||||
default_factory=list, description="Master用户列表,格式: [[platform, user_id], ...]"
|
default_factory=list, description="Master用户列表,格式: [[platform, user_id], ...]"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class AffinityFlowConfig(ValidatedConfigBase):
|
||||||
|
"""亲和流配置类(兴趣度评分和人物关系系统)"""
|
||||||
|
|
||||||
|
# 兴趣评分系统参数
|
||||||
|
reply_action_interest_threshold: float = Field(default=0.4, description="回复动作兴趣阈值")
|
||||||
|
non_reply_action_interest_threshold: float = Field(default=0.2, description="非回复动作兴趣阈值")
|
||||||
|
high_match_interest_threshold: float = Field(default=0.8, description="高匹配兴趣阈值")
|
||||||
|
medium_match_interest_threshold: float = Field(default=0.5, description="中匹配兴趣阈值")
|
||||||
|
low_match_interest_threshold: float = Field(default=0.2, description="低匹配兴趣阈值")
|
||||||
|
high_match_keyword_multiplier: float = Field(default=1.5, description="高匹配关键词兴趣倍率")
|
||||||
|
medium_match_keyword_multiplier: float = Field(default=1.2, description="中匹配关键词兴趣倍率")
|
||||||
|
low_match_keyword_multiplier: float = Field(default=1.0, description="低匹配关键词兴趣倍率")
|
||||||
|
match_count_bonus: float = Field(default=0.1, description="匹配数关键词加成值")
|
||||||
|
max_match_bonus: float = Field(default=0.5, description="最大匹配数加成值")
|
||||||
|
|
||||||
|
# 回复决策系统参数
|
||||||
|
no_reply_threshold_adjustment: float = Field(default=0.1, description="不回复兴趣阈值调整值")
|
||||||
|
reply_cooldown_reduction: int = Field(default=2, description="回复后减少的不回复计数")
|
||||||
|
max_no_reply_count: int = Field(default=5, description="最大不回复计数次数")
|
||||||
|
|
||||||
|
# 综合评分权重
|
||||||
|
keyword_match_weight: float = Field(default=0.4, description="兴趣关键词匹配度权重")
|
||||||
|
mention_bot_weight: float = Field(default=0.3, description="提及bot分数权重")
|
||||||
|
relationship_weight: float = Field(default=0.3, description="人物关系分数权重")
|
||||||
|
|
||||||
|
# 提及bot相关参数
|
||||||
|
mention_bot_adjustment_threshold: float = Field(default=0.3, description="提及bot后的调整阈值")
|
||||||
|
mention_bot_interest_score: float = Field(default=0.6, description="提及bot的兴趣分")
|
||||||
|
base_relationship_score: float = Field(default=0.5, description="基础人物关系分")
|
||||||
|
|||||||
@@ -64,6 +64,9 @@ class Individuality:
|
|||||||
else:
|
else:
|
||||||
logger.error("人设构建失败")
|
logger.error("人设构建失败")
|
||||||
|
|
||||||
|
# 初始化智能兴趣系统
|
||||||
|
await self._initialize_smart_interest_system(personality_result, identity_result)
|
||||||
|
|
||||||
# 如果任何一个发生变化,都需要清空数据库中的info_list(因为这影响整体人设)
|
# 如果任何一个发生变化,都需要清空数据库中的info_list(因为这影响整体人设)
|
||||||
if personality_changed or identity_changed:
|
if personality_changed or identity_changed:
|
||||||
logger.info("将清空数据库中原有的关键词缓存")
|
logger.info("将清空数据库中原有的关键词缓存")
|
||||||
@@ -75,6 +78,21 @@ class Individuality:
|
|||||||
}
|
}
|
||||||
await person_info_manager.update_one_field(self.bot_person_id, "info_list", [], data=update_data)
|
await person_info_manager.update_one_field(self.bot_person_id, "info_list", [], data=update_data)
|
||||||
|
|
||||||
|
async def _initialize_smart_interest_system(self, personality_result: str, identity_result: str):
|
||||||
|
"""初始化智能兴趣系统"""
|
||||||
|
# 组合完整的人设描述
|
||||||
|
full_personality = f"{personality_result},{identity_result}"
|
||||||
|
|
||||||
|
# 获取全局兴趣评分系统实例
|
||||||
|
from src.plugins.built_in.affinity_flow_chatter.interest_scoring import chatter_interest_scoring_system as interest_scoring_system
|
||||||
|
|
||||||
|
# 初始化智能兴趣系统
|
||||||
|
await interest_scoring_system.initialize_smart_interests(
|
||||||
|
personality_description=full_personality, personality_id=self.bot_person_id
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info("智能兴趣系统初始化完成")
|
||||||
|
|
||||||
async def get_personality_block(self) -> str:
|
async def get_personality_block(self) -> str:
|
||||||
bot_name = global_config.bot.nickname
|
bot_name = global_config.bot.nickname
|
||||||
if global_config.bot.alias_names:
|
if global_config.bot.alias_names:
|
||||||
|
|||||||
@@ -145,8 +145,8 @@ class LLMUsageRecorder:
|
|||||||
LLM使用情况记录器(SQLAlchemy版本)
|
LLM使用情况记录器(SQLAlchemy版本)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
async def record_usage_to_database(
|
async def record_usage_to_database(
|
||||||
|
self,
|
||||||
model_info: ModelInfo,
|
model_info: ModelInfo,
|
||||||
model_usage: UsageRecord,
|
model_usage: UsageRecord,
|
||||||
user_id: str,
|
user_id: str,
|
||||||
@@ -161,7 +161,7 @@ class LLMUsageRecorder:
|
|||||||
session = None
|
session = None
|
||||||
try:
|
try:
|
||||||
# 使用 SQLAlchemy 会话创建记录
|
# 使用 SQLAlchemy 会话创建记录
|
||||||
async with get_db_session() as session:
|
with get_db_session() as session:
|
||||||
usage_record = LLMUsage(
|
usage_record = LLMUsage(
|
||||||
model_name=model_info.model_identifier,
|
model_name=model_info.model_identifier,
|
||||||
model_assign_name=model_info.name,
|
model_assign_name=model_info.name,
|
||||||
@@ -179,7 +179,7 @@ class LLMUsageRecorder:
|
|||||||
)
|
)
|
||||||
|
|
||||||
session.add(usage_record)
|
session.add(usage_record)
|
||||||
await session.commit()
|
session.commit()
|
||||||
|
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"Token使用情况 - 模型: {model_usage.model_name}, "
|
f"Token使用情况 - 模型: {model_usage.model_name}, "
|
||||||
|
|||||||
@@ -1,5 +1,4 @@
|
|||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
# -*- coding: utf-8 -*-
|
|
||||||
"""
|
"""
|
||||||
@desc: 该模块封装了与大语言模型(LLM)交互的所有核心逻辑。
|
@desc: 该模块封装了与大语言模型(LLM)交互的所有核心逻辑。
|
||||||
它被设计为一个高度容错和可扩展的系统,包含以下主要组件:
|
它被设计为一个高度容错和可扩展的系统,包含以下主要组件:
|
||||||
@@ -892,7 +891,7 @@ class LLMRequest:
|
|||||||
max_tokens=self.model_for_task.max_tokens if max_tokens is None else max_tokens,
|
max_tokens=self.model_for_task.max_tokens if max_tokens is None else max_tokens,
|
||||||
)
|
)
|
||||||
|
|
||||||
self._record_usage(model_info, response.usage, time.time() - start_time, "/chat/completions")
|
await self._record_usage(model_info, response.usage, time.time() - start_time, "/chat/completions")
|
||||||
|
|
||||||
if not response.content and not response.tool_calls:
|
if not response.content and not response.tool_calls:
|
||||||
if raise_when_empty:
|
if raise_when_empty:
|
||||||
@@ -917,14 +916,14 @@ class LLMRequest:
|
|||||||
embedding_input=embedding_input
|
embedding_input=embedding_input
|
||||||
)
|
)
|
||||||
|
|
||||||
self._record_usage(model_info, response.usage, time.time() - start_time, "/embeddings")
|
await self._record_usage(model_info, response.usage, time.time() - start_time, "/embeddings")
|
||||||
|
|
||||||
if not response.embedding:
|
if not response.embedding:
|
||||||
raise RuntimeError("获取embedding失败")
|
raise RuntimeError("获取embedding失败")
|
||||||
|
|
||||||
return response.embedding, model_info.name
|
return response.embedding, model_info.name
|
||||||
|
|
||||||
def _record_usage(self, model_info: ModelInfo, usage: Optional[UsageRecord], time_cost: float, endpoint: str):
|
async def _record_usage(self, model_info: ModelInfo, usage: Optional[UsageRecord], time_cost: float, endpoint: str):
|
||||||
"""
|
"""
|
||||||
记录模型使用情况。
|
记录模型使用情况。
|
||||||
|
|
||||||
|
|||||||
137
src/main.py
137
src/main.py
@@ -1,35 +1,40 @@
|
|||||||
# 再用这个就写一行注释来混提交的我直接全部🌿飞😡
|
# 再用这个就写一行注释来混提交的我直接全部🌿飞😡
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import time
|
||||||
import signal
|
import signal
|
||||||
import sys
|
import sys
|
||||||
import time
|
from functools import partial
|
||||||
|
import traceback
|
||||||
|
from typing import Dict, Any
|
||||||
|
|
||||||
from maim_message import MessageServer
|
from maim_message import MessageServer
|
||||||
from rich.traceback import install
|
|
||||||
|
|
||||||
from src.chat.emoji_system.emoji_manager import get_emoji_manager
|
|
||||||
from src.chat.memory_system.Hippocampus import hippocampus_manager
|
|
||||||
from src.chat.message_receive.bot import chat_bot
|
|
||||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
|
||||||
from src.chat.utils.statistic import OnlineTimeRecordTask, StatisticOutputTask
|
|
||||||
from src.common.logger import get_logger
|
|
||||||
# 导入消息API和traceback模块
|
|
||||||
from src.common.message import get_global_api
|
|
||||||
from src.common.remote import TelemetryHeartBeatTask
|
from src.common.remote import TelemetryHeartBeatTask
|
||||||
from src.common.server import get_global_server, Server
|
|
||||||
from src.config.config import global_config
|
|
||||||
from src.individuality.individuality import get_individuality, Individuality
|
|
||||||
from src.manager.async_task_manager import async_task_manager
|
from src.manager.async_task_manager import async_task_manager
|
||||||
|
from src.chat.utils.statistic import OnlineTimeRecordTask, StatisticOutputTask
|
||||||
|
from src.chat.emoji_system.emoji_manager import get_emoji_manager
|
||||||
|
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||||
|
from src.config.config import global_config
|
||||||
|
from src.chat.message_receive.bot import chat_bot
|
||||||
|
from src.common.logger import get_logger
|
||||||
|
from src.individuality.individuality import get_individuality, Individuality
|
||||||
|
from src.common.server import get_global_server, Server
|
||||||
from src.mood.mood_manager import mood_manager
|
from src.mood.mood_manager import mood_manager
|
||||||
from src.plugin_system.base.component_types import EventType
|
from rich.traceback import install
|
||||||
|
from src.schedule.schedule_manager import schedule_manager
|
||||||
|
from src.schedule.monthly_plan_manager import monthly_plan_manager
|
||||||
from src.plugin_system.core.event_manager import event_manager
|
from src.plugin_system.core.event_manager import event_manager
|
||||||
from src.plugin_system.core.plugin_hot_reload import hot_reload_manager
|
from src.plugin_system.base.component_types import EventType
|
||||||
|
# from src.api.main import start_api_server
|
||||||
|
|
||||||
# 导入新的插件管理器和热重载管理器
|
# 导入新的插件管理器和热重载管理器
|
||||||
from src.plugin_system.core.plugin_manager import plugin_manager
|
from src.plugin_system.core.plugin_manager import plugin_manager
|
||||||
from src.schedule.monthly_plan_manager import monthly_plan_manager
|
from src.plugin_system.core.plugin_hot_reload import hot_reload_manager
|
||||||
from src.schedule.schedule_manager import schedule_manager
|
|
||||||
|
|
||||||
# from src.api.main import start_api_server
|
# 导入消息API和traceback模块
|
||||||
|
from src.common.message import get_global_api
|
||||||
|
|
||||||
|
from src.chat.memory_system.Hippocampus import hippocampus_manager
|
||||||
|
|
||||||
if not global_config.memory.enable_memory:
|
if not global_config.memory.enable_memory:
|
||||||
import src.chat.memory_system.Hippocampus as hippocampus_module
|
import src.chat.memory_system.Hippocampus as hippocampus_module
|
||||||
@@ -38,11 +43,7 @@ if not global_config.memory.enable_memory:
|
|||||||
def initialize(self):
|
def initialize(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
async def initialize_async(self):
|
def get_hippocampus(self):
|
||||||
pass
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def get_hippocampus():
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
async def build_memory(self):
|
async def build_memory(self):
|
||||||
@@ -54,8 +55,8 @@ if not global_config.memory.enable_memory:
|
|||||||
async def consolidate_memory(self):
|
async def consolidate_memory(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
async def get_memory_from_text(
|
async def get_memory_from_text(
|
||||||
|
self,
|
||||||
text: str,
|
text: str,
|
||||||
max_memory_num: int = 3,
|
max_memory_num: int = 3,
|
||||||
max_memory_length: int = 2,
|
max_memory_length: int = 2,
|
||||||
@@ -64,24 +65,20 @@ if not global_config.memory.enable_memory:
|
|||||||
) -> list:
|
) -> list:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
async def get_memory_from_topic(
|
async def get_memory_from_topic(
|
||||||
valid_keywords: list[str], max_memory_num: int = 3, max_memory_length: int = 2, max_depth: int = 3
|
self, valid_keywords: list[str], max_memory_num: int = 3, max_memory_length: int = 2, max_depth: int = 3
|
||||||
) -> list:
|
) -> list:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
async def get_activate_from_text(
|
async def get_activate_from_text(
|
||||||
text: str, max_depth: int = 3, fast_retrieval: bool = False
|
self, text: str, max_depth: int = 3, fast_retrieval: bool = False
|
||||||
) -> tuple[float, list[str]]:
|
) -> tuple[float, list[str]]:
|
||||||
return 0.0, []
|
return 0.0, []
|
||||||
|
|
||||||
@staticmethod
|
def get_memory_from_keyword(self, keyword: str, max_depth: int = 2) -> list:
|
||||||
def get_memory_from_keyword(keyword: str, max_depth: int = 2) -> list:
|
|
||||||
return []
|
return []
|
||||||
|
|
||||||
@staticmethod
|
def get_all_node_names(self) -> list:
|
||||||
def get_all_node_names() -> list:
|
|
||||||
return []
|
return []
|
||||||
|
|
||||||
hippocampus_module.hippocampus_manager = MockHippocampusManager()
|
hippocampus_module.hippocampus_manager = MockHippocampusManager()
|
||||||
@@ -93,6 +90,20 @@ install(extra_lines=3)
|
|||||||
logger = get_logger("main")
|
logger = get_logger("main")
|
||||||
|
|
||||||
|
|
||||||
|
def _task_done_callback(task: asyncio.Task, message_id: str, start_time: float):
|
||||||
|
"""后台任务完成时的回调函数"""
|
||||||
|
end_time = time.time()
|
||||||
|
duration = end_time - start_time
|
||||||
|
try:
|
||||||
|
task.result() # 如果任务有异常,这里会重新抛出
|
||||||
|
logger.debug(f"消息 {message_id} 的后台任务 (ID: {id(task)}) 已成功完成, 耗时: {duration:.2f}s")
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
logger.warning(f"消息 {message_id} 的后台任务 (ID: {id(task)}) 被取消, 耗时: {duration:.2f}s")
|
||||||
|
except Exception:
|
||||||
|
logger.error(f"处理消息 {message_id} 的后台任务 (ID: {id(task)}) 出现未捕获的异常, 耗时: {duration:.2f}s:")
|
||||||
|
logger.error(traceback.format_exc())
|
||||||
|
|
||||||
|
|
||||||
class MainSystem:
|
class MainSystem:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.hippocampus_manager = hippocampus_manager
|
self.hippocampus_manager = hippocampus_manager
|
||||||
@@ -117,14 +128,27 @@ class MainSystem:
|
|||||||
signal.signal(signal.SIGINT, signal_handler)
|
signal.signal(signal.SIGINT, signal_handler)
|
||||||
signal.signal(signal.SIGTERM, signal_handler)
|
signal.signal(signal.SIGTERM, signal_handler)
|
||||||
|
|
||||||
@staticmethod
|
def _cleanup(self):
|
||||||
def _cleanup():
|
|
||||||
"""清理资源"""
|
"""清理资源"""
|
||||||
|
try:
|
||||||
|
# 停止消息管理器
|
||||||
|
from src.chat.message_manager import message_manager
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
if loop.is_running():
|
||||||
|
asyncio.create_task(message_manager.stop())
|
||||||
|
else:
|
||||||
|
loop.run_until_complete(message_manager.stop())
|
||||||
|
logger.info("🛑 消息管理器已停止")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"停止消息管理器时出错: {e}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 停止消息重组器
|
# 停止消息重组器
|
||||||
from src.plugin_system.core.event_manager import event_manager
|
from src.plugin_system.core.event_manager import event_manager
|
||||||
from src.plugin_system import EventType
|
from src.plugin_system import EventType
|
||||||
import asyncio
|
|
||||||
asyncio.run(event_manager.trigger_event(EventType.ON_STOP, permission_group="SYSTEM"))
|
asyncio.run(event_manager.trigger_event(EventType.ON_STOP, permission_group="SYSTEM"))
|
||||||
from src.utils.message_chunker import reassembler
|
from src.utils.message_chunker import reassembler
|
||||||
|
|
||||||
@@ -159,6 +183,20 @@ class MainSystem:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"停止记忆管理器时出错: {e}")
|
logger.error(f"停止记忆管理器时出错: {e}")
|
||||||
|
|
||||||
|
async def _message_process_wrapper(self, message_data: Dict[str, Any]):
|
||||||
|
"""并行处理消息的包装器"""
|
||||||
|
try:
|
||||||
|
start_time = time.time()
|
||||||
|
message_id = message_data.get("message_info", {}).get("message_id", "UNKNOWN")
|
||||||
|
# 创建后台任务
|
||||||
|
task = asyncio.create_task(chat_bot.message_process(message_data))
|
||||||
|
logger.debug(f"已为消息 {message_id} 创建后台处理任务 (ID: {id(task)})")
|
||||||
|
# 添加一个回调函数,当任务完成时,它会被调用
|
||||||
|
task.add_done_callback(partial(_task_done_callback, message_id=message_id, start_time=start_time))
|
||||||
|
except Exception:
|
||||||
|
logger.error("在创建消息处理任务时发生严重错误:")
|
||||||
|
logger.error(traceback.format_exc())
|
||||||
|
|
||||||
async def initialize(self):
|
async def initialize(self):
|
||||||
"""初始化系统组件"""
|
"""初始化系统组件"""
|
||||||
logger.info(f"正在唤醒{global_config.bot.nickname}......")
|
logger.info(f"正在唤醒{global_config.bot.nickname}......")
|
||||||
@@ -223,7 +261,6 @@ MoFox_Bot(第三方修改版)
|
|||||||
from src.plugin_system.apis.permission_api import permission_api
|
from src.plugin_system.apis.permission_api import permission_api
|
||||||
|
|
||||||
permission_manager = PermissionManager()
|
permission_manager = PermissionManager()
|
||||||
await permission_manager.initialize()
|
|
||||||
permission_api.set_permission_manager(permission_manager)
|
permission_api.set_permission_manager(permission_manager)
|
||||||
logger.info("权限管理器初始化成功")
|
logger.info("权限管理器初始化成功")
|
||||||
|
|
||||||
@@ -244,6 +281,18 @@ MoFox_Bot(第三方修改版)
|
|||||||
get_emoji_manager().initialize()
|
get_emoji_manager().initialize()
|
||||||
logger.info("表情包管理器初始化成功")
|
logger.info("表情包管理器初始化成功")
|
||||||
|
|
||||||
|
# 初始化回复后关系追踪系统
|
||||||
|
try:
|
||||||
|
from src.plugins.built_in.affinity_flow_chatter.interest_scoring import chatter_interest_scoring_system
|
||||||
|
from src.plugins.built_in.affinity_flow_chatter.relationship_tracker import ChatterRelationshipTracker
|
||||||
|
|
||||||
|
relationship_tracker = ChatterRelationshipTracker(interest_scoring_system=chatter_interest_scoring_system)
|
||||||
|
chatter_interest_scoring_system.relationship_tracker = relationship_tracker
|
||||||
|
logger.info("回复后关系追踪系统初始化成功")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"回复后关系追踪系统初始化失败: {e}")
|
||||||
|
relationship_tracker = None
|
||||||
|
|
||||||
# 启动情绪管理器
|
# 启动情绪管理器
|
||||||
await mood_manager.start()
|
await mood_manager.start()
|
||||||
logger.info("情绪管理器初始化成功")
|
logger.info("情绪管理器初始化成功")
|
||||||
@@ -256,11 +305,12 @@ MoFox_Bot(第三方修改版)
|
|||||||
logger.info("聊天管理器初始化成功")
|
logger.info("聊天管理器初始化成功")
|
||||||
|
|
||||||
# 初始化记忆系统
|
# 初始化记忆系统
|
||||||
await self.hippocampus_manager.initialize_async()
|
self.hippocampus_manager.initialize()
|
||||||
logger.info("记忆系统初始化成功")
|
logger.info("记忆系统初始化成功")
|
||||||
|
|
||||||
# 初始化LPMM知识库
|
# 初始化LPMM知识库
|
||||||
from src.chat.knowledge.knowledge_lib import initialize_lpmm_knowledge
|
from src.chat.knowledge.knowledge_lib import initialize_lpmm_knowledge
|
||||||
|
|
||||||
initialize_lpmm_knowledge()
|
initialize_lpmm_knowledge()
|
||||||
logger.info("LPMM知识库初始化成功")
|
logger.info("LPMM知识库初始化成功")
|
||||||
|
|
||||||
@@ -276,7 +326,7 @@ MoFox_Bot(第三方修改版)
|
|||||||
# await asyncio.sleep(0.5) #防止logger输出飞了
|
# await asyncio.sleep(0.5) #防止logger输出飞了
|
||||||
|
|
||||||
# 将bot.py中的chat_bot.message_process消息处理函数注册到api.py的消息处理基类中
|
# 将bot.py中的chat_bot.message_process消息处理函数注册到api.py的消息处理基类中
|
||||||
self.app.register_message_handler(chat_bot.message_process)
|
self.app.register_message_handler(self._message_process_wrapper)
|
||||||
|
|
||||||
# 启动消息重组器的清理任务
|
# 启动消息重组器的清理任务
|
||||||
from src.utils.message_chunker import reassembler
|
from src.utils.message_chunker import reassembler
|
||||||
@@ -284,6 +334,12 @@ MoFox_Bot(第三方修改版)
|
|||||||
await reassembler.start_cleanup_task()
|
await reassembler.start_cleanup_task()
|
||||||
logger.info("消息重组器已启动")
|
logger.info("消息重组器已启动")
|
||||||
|
|
||||||
|
# 启动消息管理器
|
||||||
|
from src.chat.message_manager import message_manager
|
||||||
|
|
||||||
|
await message_manager.start()
|
||||||
|
logger.info("消息管理器已启动")
|
||||||
|
|
||||||
# 初始化个体特征
|
# 初始化个体特征
|
||||||
await self.individuality.initialize()
|
await self.individuality.initialize()
|
||||||
|
|
||||||
@@ -291,7 +347,7 @@ MoFox_Bot(第三方修改版)
|
|||||||
if global_config.planning_system.monthly_plan_enable:
|
if global_config.planning_system.monthly_plan_enable:
|
||||||
logger.info("正在初始化月度计划管理器...")
|
logger.info("正在初始化月度计划管理器...")
|
||||||
try:
|
try:
|
||||||
await monthly_plan_manager.initialize()
|
await monthly_plan_manager.start_monthly_plan_generation()
|
||||||
logger.info("月度计划管理器初始化成功")
|
logger.info("月度计划管理器初始化成功")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"月度计划管理器初始化失败: {e}")
|
logger.error(f"月度计划管理器初始化失败: {e}")
|
||||||
@@ -299,7 +355,8 @@ MoFox_Bot(第三方修改版)
|
|||||||
# 初始化日程管理器
|
# 初始化日程管理器
|
||||||
if global_config.planning_system.schedule_enable:
|
if global_config.planning_system.schedule_enable:
|
||||||
logger.info("日程表功能已启用,正在初始化管理器...")
|
logger.info("日程表功能已启用,正在初始化管理器...")
|
||||||
await schedule_manager.initialize()
|
await schedule_manager.load_or_generate_today_schedule()
|
||||||
|
await schedule_manager.start_daily_schedule_generation()
|
||||||
logger.info("日程表管理器初始化成功。")
|
logger.info("日程表管理器初始化成功。")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ import time
|
|||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.config.config import global_config, model_config
|
from src.config.config import global_config, model_config
|
||||||
from src.chat.message_receive.message import MessageRecv
|
from src.chat.message_receive.message import MessageRecv
|
||||||
|
from src.common.data_models.database_data_model import DatabaseMessages
|
||||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||||
from src.chat.utils.prompt import Prompt, global_prompt_manager
|
from src.chat.utils.prompt import Prompt, global_prompt_manager
|
||||||
from src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_by_timestamp_with_chat_inclusive
|
from src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_by_timestamp_with_chat_inclusive
|
||||||
@@ -65,7 +66,7 @@ class ChatMood:
|
|||||||
|
|
||||||
self.last_change_time: float = 0
|
self.last_change_time: float = 0
|
||||||
|
|
||||||
async def update_mood_by_message(self, message: MessageRecv, interested_rate: float):
|
async def update_mood_by_message(self, message: MessageRecv | DatabaseMessages, interested_rate: float):
|
||||||
# 如果当前聊天处于失眠状态,则锁定情绪,不允许更新
|
# 如果当前聊天处于失眠状态,则锁定情绪,不允许更新
|
||||||
if self.chat_id in mood_manager.insomnia_chats:
|
if self.chat_id in mood_manager.insomnia_chats:
|
||||||
logger.debug(f"{self.log_prefix} 处于失眠状态,情绪已锁定,跳过更新。")
|
logger.debug(f"{self.log_prefix} 处于失眠状态,情绪已锁定,跳过更新。")
|
||||||
@@ -73,7 +74,13 @@ class ChatMood:
|
|||||||
|
|
||||||
self.regression_count = 0
|
self.regression_count = 0
|
||||||
|
|
||||||
during_last_time = message.message_info.time - self.last_change_time # type: ignore
|
# 处理不同类型的消息对象
|
||||||
|
if isinstance(message, MessageRecv):
|
||||||
|
message_time = message.message_info.time
|
||||||
|
else: # DatabaseMessages
|
||||||
|
message_time = message.time
|
||||||
|
|
||||||
|
during_last_time = message_time - self.last_change_time
|
||||||
|
|
||||||
base_probability = 0.05
|
base_probability = 0.05
|
||||||
time_multiplier = 4 * (1 - math.exp(-0.01 * during_last_time))
|
time_multiplier = 4 * (1 - math.exp(-0.01 * during_last_time))
|
||||||
@@ -96,16 +103,14 @@ class ChatMood:
|
|||||||
logger.debug(
|
logger.debug(
|
||||||
f"{self.log_prefix} 更新情绪状态,感兴趣度: {interested_rate:.2f}, 更新概率: {update_probability:.2f}"
|
f"{self.log_prefix} 更新情绪状态,感兴趣度: {interested_rate:.2f}, 更新概率: {update_probability:.2f}"
|
||||||
)
|
)
|
||||||
|
message_list_before_now = get_raw_msg_by_timestamp_with_chat_inclusive(
|
||||||
message_time: float = message.message_info.time # type: ignore
|
|
||||||
message_list_before_now = await get_raw_msg_by_timestamp_with_chat_inclusive(
|
|
||||||
chat_id=self.chat_id,
|
chat_id=self.chat_id,
|
||||||
timestamp_start=self.last_change_time,
|
timestamp_start=self.last_change_time,
|
||||||
timestamp_end=message_time,
|
timestamp_end=message_time,
|
||||||
limit=int(global_config.chat.max_context_size / 3),
|
limit=int(global_config.chat.max_context_size / 3),
|
||||||
limit_mode="last",
|
limit_mode="last",
|
||||||
)
|
)
|
||||||
chat_talking_prompt = await build_readable_messages(
|
chat_talking_prompt = build_readable_messages(
|
||||||
message_list_before_now,
|
message_list_before_now,
|
||||||
replace_bot_name=True,
|
replace_bot_name=True,
|
||||||
merge_messages=False,
|
merge_messages=False,
|
||||||
@@ -135,9 +140,9 @@ class ChatMood:
|
|||||||
prompt=prompt, temperature=0.7
|
prompt=prompt, temperature=0.7
|
||||||
)
|
)
|
||||||
if global_config.debug.show_prompt:
|
if global_config.debug.show_prompt:
|
||||||
logger.info(f"{self.log_prefix} prompt: {prompt}")
|
logger.debug(f"{self.log_prefix} prompt: {prompt}")
|
||||||
logger.info(f"{self.log_prefix} response: {response}")
|
logger.debug(f"{self.log_prefix} response: {response}")
|
||||||
logger.info(f"{self.log_prefix} reasoning_content: {reasoning_content}")
|
logger.debug(f"{self.log_prefix} reasoning_content: {reasoning_content}")
|
||||||
|
|
||||||
logger.info(f"{self.log_prefix} 情绪状态更新为: {response}")
|
logger.info(f"{self.log_prefix} 情绪状态更新为: {response}")
|
||||||
|
|
||||||
@@ -147,14 +152,14 @@ class ChatMood:
|
|||||||
|
|
||||||
async def regress_mood(self):
|
async def regress_mood(self):
|
||||||
message_time = time.time()
|
message_time = time.time()
|
||||||
message_list_before_now = await get_raw_msg_by_timestamp_with_chat_inclusive(
|
message_list_before_now = get_raw_msg_by_timestamp_with_chat_inclusive(
|
||||||
chat_id=self.chat_id,
|
chat_id=self.chat_id,
|
||||||
timestamp_start=self.last_change_time,
|
timestamp_start=self.last_change_time,
|
||||||
timestamp_end=message_time,
|
timestamp_end=message_time,
|
||||||
limit=15,
|
limit=15,
|
||||||
limit_mode="last",
|
limit_mode="last",
|
||||||
)
|
)
|
||||||
chat_talking_prompt = await build_readable_messages(
|
chat_talking_prompt = build_readable_messages(
|
||||||
message_list_before_now,
|
message_list_before_now,
|
||||||
replace_bot_name=True,
|
replace_bot_name=True,
|
||||||
merge_messages=False,
|
merge_messages=False,
|
||||||
@@ -185,9 +190,9 @@ class ChatMood:
|
|||||||
)
|
)
|
||||||
|
|
||||||
if global_config.debug.show_prompt:
|
if global_config.debug.show_prompt:
|
||||||
logger.info(f"{self.log_prefix} prompt: {prompt}")
|
logger.debug(f"{self.log_prefix} prompt: {prompt}")
|
||||||
logger.info(f"{self.log_prefix} response: {response}")
|
logger.debug(f"{self.log_prefix} response: {response}")
|
||||||
logger.info(f"{self.log_prefix} reasoning_content: {reasoning_content}")
|
logger.debug(f"{self.log_prefix} reasoning_content: {reasoning_content}")
|
||||||
|
|
||||||
logger.info(f"{self.log_prefix} 情绪状态转变为: {response}")
|
logger.info(f"{self.log_prefix} 情绪状态转变为: {response}")
|
||||||
|
|
||||||
|
|||||||
@@ -94,11 +94,52 @@ class PersonInfoManager:
|
|||||||
|
|
||||||
if "-" in platform:
|
if "-" in platform:
|
||||||
platform = platform.split("-")[1]
|
platform = platform.split("-")[1]
|
||||||
|
# 在此处打一个补丁,如果platform为qq,尝试生成id后检查是否存在,如果不存在,则将平台换为napcat后再次检查,如果存在,则更新原id为platform为qq的id
|
||||||
components = [platform, str(user_id)]
|
components = [platform, str(user_id)]
|
||||||
key = "_".join(components)
|
key = "_".join(components)
|
||||||
|
|
||||||
|
# 如果不是 qq 平台,直接返回计算的 id
|
||||||
|
if platform != "qq":
|
||||||
return hashlib.md5(key.encode()).hexdigest()
|
return hashlib.md5(key.encode()).hexdigest()
|
||||||
|
|
||||||
|
qq_id = hashlib.md5(key.encode()).hexdigest()
|
||||||
|
|
||||||
|
# 对于 qq 平台,先检查该 person_id 是否已存在;如果存在直接返回
|
||||||
|
def _db_check_and_migrate_sync(p_id: str, raw_user_id: str):
|
||||||
|
try:
|
||||||
|
with get_db_session() as session:
|
||||||
|
# 检查 qq_id 是否存在
|
||||||
|
existing_qq = session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id)).scalar()
|
||||||
|
if existing_qq:
|
||||||
|
return p_id
|
||||||
|
|
||||||
|
# 如果 qq_id 不存在,尝试使用 napcat 作为平台生成对应 id 并检查
|
||||||
|
nap_components = ["napcat", str(raw_user_id)]
|
||||||
|
nap_key = "_".join(nap_components)
|
||||||
|
nap_id = hashlib.md5(nap_key.encode()).hexdigest()
|
||||||
|
|
||||||
|
existing_nap = session.execute(select(PersonInfo).where(PersonInfo.person_id == nap_id)).scalar()
|
||||||
|
if not existing_nap:
|
||||||
|
# napcat 也不存在,返回 qq_id(未命中)
|
||||||
|
return p_id
|
||||||
|
|
||||||
|
# napcat 存在,迁移该记录:更新 person_id 与 platform -> qq
|
||||||
|
try:
|
||||||
|
# 更新现有 napcat 记录
|
||||||
|
existing_nap.person_id = p_id
|
||||||
|
existing_nap.platform = "qq"
|
||||||
|
existing_nap.user_id = str(raw_user_id)
|
||||||
|
session.commit()
|
||||||
|
return p_id
|
||||||
|
except Exception:
|
||||||
|
session.rollback()
|
||||||
|
return p_id
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"检查/迁移 napcat->qq 时出错: {e}")
|
||||||
|
return p_id
|
||||||
|
|
||||||
|
return _db_check_and_migrate_sync(qq_id, user_id)
|
||||||
|
|
||||||
async def is_person_known(self, platform: str, user_id: int):
|
async def is_person_known(self, platform: str, user_id: int):
|
||||||
"""判断是否认识某人"""
|
"""判断是否认识某人"""
|
||||||
person_id = self.get_person_id(platform, user_id)
|
person_id = self.get_person_id(platform, user_id)
|
||||||
@@ -128,6 +169,27 @@ class PersonInfoManager:
|
|||||||
logger.error(f"根据用户名 {person_name} 获取用户ID时出错 (SQLAlchemy): {e}")
|
logger.error(f"根据用户名 {person_name} 获取用户ID时出错 (SQLAlchemy): {e}")
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def first_knowing_some_one(platform: str, user_id: str, user_nickname: str, user_cardname: str):
|
||||||
|
"""判断是否认识某人"""
|
||||||
|
person_id = PersonInfoManager.get_person_id(platform, user_id)
|
||||||
|
# 生成唯一的 person_name
|
||||||
|
person_info_manager = get_person_info_manager()
|
||||||
|
unique_nickname = await person_info_manager._generate_unique_person_name(user_nickname)
|
||||||
|
data = {
|
||||||
|
"platform": platform,
|
||||||
|
"user_id": user_id,
|
||||||
|
"nickname": user_nickname,
|
||||||
|
"konw_time": int(time.time()),
|
||||||
|
"person_name": unique_nickname, # 使用唯一的 person_name
|
||||||
|
}
|
||||||
|
# 先创建用户基本信息,使用安全创建方法避免竞态条件
|
||||||
|
await person_info_manager._safe_create_person_info(person_id=person_id, data=data)
|
||||||
|
# 更新昵称
|
||||||
|
await person_info_manager.update_one_field(
|
||||||
|
person_id=person_id, field_name="nickname", value=user_nickname, data=data
|
||||||
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def create_person_info(person_id: str, data: Optional[dict] = None):
|
async def create_person_info(person_id: str, data: Optional[dict] = None):
|
||||||
"""创建一个项"""
|
"""创建一个项"""
|
||||||
|
|||||||
@@ -94,90 +94,144 @@ class RelationshipFetcher:
|
|||||||
if not self.info_fetched_cache[person_id]:
|
if not self.info_fetched_cache[person_id]:
|
||||||
del self.info_fetched_cache[person_id]
|
del self.info_fetched_cache[person_id]
|
||||||
|
|
||||||
async def build_relation_info(self, person_id, points_num=3):
|
async def build_relation_info(self, person_id, points_num=5):
|
||||||
|
"""构建详细的人物关系信息,包含从数据库中查询的丰富关系描述"""
|
||||||
# 清理过期的信息缓存
|
# 清理过期的信息缓存
|
||||||
self._cleanup_expired_cache()
|
self._cleanup_expired_cache()
|
||||||
|
|
||||||
person_info_manager = get_person_info_manager()
|
person_info_manager = get_person_info_manager()
|
||||||
person_info = await person_info_manager.get_values(
|
person_name = await person_info_manager.get_value(person_id, "person_name")
|
||||||
person_id, ["person_name", "short_impression", "nickname", "platform", "points"]
|
short_impression = await person_info_manager.get_value(person_id, "short_impression")
|
||||||
|
full_impression = await person_info_manager.get_value(person_id, "impression")
|
||||||
|
attitude = await person_info_manager.get_value(person_id, "attitude") or 50
|
||||||
|
|
||||||
|
nickname_str = await person_info_manager.get_value(person_id, "nickname")
|
||||||
|
platform = await person_info_manager.get_value(person_id, "platform")
|
||||||
|
know_times = await person_info_manager.get_value(person_id, "know_times") or 0
|
||||||
|
know_since = await person_info_manager.get_value(person_id, "know_since")
|
||||||
|
last_know = await person_info_manager.get_value(person_id, "last_know")
|
||||||
|
|
||||||
|
# 如果用户没有基本信息,返回默认描述
|
||||||
|
if person_name == nickname_str and not short_impression and not full_impression:
|
||||||
|
return f"你完全不认识{person_name},这是你们第一次交流。"
|
||||||
|
|
||||||
|
# 获取用户特征点
|
||||||
|
current_points = await person_info_manager.get_value(person_id, "points") or []
|
||||||
|
forgotten_points = await person_info_manager.get_value(person_id, "forgotten_points") or []
|
||||||
|
|
||||||
|
# 按时间排序并选择最有代表性的特征点
|
||||||
|
all_points = current_points + forgotten_points
|
||||||
|
if all_points:
|
||||||
|
# 按权重和时效性综合排序
|
||||||
|
all_points.sort(
|
||||||
|
key=lambda x: (float(x[1]) if len(x) > 1 else 0, float(x[2]) if len(x) > 2 else 0), reverse=True
|
||||||
)
|
)
|
||||||
person_name = person_info.get("person_name")
|
selected_points = all_points[:points_num]
|
||||||
short_impression = person_info.get("short_impression")
|
points_text = "\n".join([f"- {point[0]}({point[2]})" for point in selected_points if len(point) > 2])
|
||||||
nickname_str = person_info.get("nickname")
|
|
||||||
platform = person_info.get("platform")
|
|
||||||
|
|
||||||
if person_name == nickname_str and not short_impression:
|
|
||||||
return ""
|
|
||||||
|
|
||||||
current_points = person_info.get("points")
|
|
||||||
if isinstance(current_points, str):
|
|
||||||
current_points = orjson.loads(current_points)
|
|
||||||
else:
|
else:
|
||||||
current_points = current_points or []
|
points_text = ""
|
||||||
|
|
||||||
# 按时间排序forgotten_points
|
# 构建详细的关系描述
|
||||||
current_points.sort(key=lambda x: x[2])
|
relation_parts = []
|
||||||
# 按权重加权随机抽取最多3个不重复的points,point[1]的值在1-10之间,权重越高被抽到概率越大
|
|
||||||
if len(current_points) > points_num:
|
|
||||||
# point[1] 取值范围1-10,直接作为权重
|
|
||||||
weights = [max(1, min(10, int(point[1]))) for point in current_points]
|
|
||||||
# 使用加权采样不放回,保证不重复
|
|
||||||
indices = list(range(len(current_points)))
|
|
||||||
points = []
|
|
||||||
for _ in range(points_num):
|
|
||||||
if not indices:
|
|
||||||
break
|
|
||||||
sub_weights = [weights[i] for i in indices]
|
|
||||||
chosen_idx = random.choices(indices, weights=sub_weights, k=1)[0]
|
|
||||||
points.append(current_points[chosen_idx])
|
|
||||||
indices.remove(chosen_idx)
|
|
||||||
else:
|
|
||||||
points = current_points
|
|
||||||
|
|
||||||
# 构建points文本
|
# 1. 基本信息
|
||||||
points_text = "\n".join([f"{point[2]}:{point[0]}" for point in points])
|
if nickname_str and person_name != nickname_str:
|
||||||
|
relation_parts.append(f"用户{person_name}在{platform}平台的昵称是{nickname_str}")
|
||||||
|
|
||||||
nickname_str = ""
|
# 2. 认识时间和频率
|
||||||
if person_name != nickname_str:
|
if know_since:
|
||||||
nickname_str = f"(ta在{platform}上的昵称是{nickname_str})"
|
from datetime import datetime
|
||||||
|
|
||||||
relation_info = ""
|
know_time = datetime.fromtimestamp(know_since).strftime("%Y年%m月%d日")
|
||||||
|
relation_parts.append(f"你从{know_time}开始认识{person_name}")
|
||||||
|
|
||||||
if short_impression and relation_info:
|
if know_times > 0:
|
||||||
|
relation_parts.append(f"你们已经交流过{int(know_times)}次")
|
||||||
|
|
||||||
|
if last_know:
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
last_time = datetime.fromtimestamp(last_know).strftime("%m月%d日")
|
||||||
|
relation_parts.append(f"最近一次交流是在{last_time}")
|
||||||
|
|
||||||
|
# 3. 态度和印象
|
||||||
|
attitude_desc = self._get_attitude_description(attitude)
|
||||||
|
relation_parts.append(f"你对{person_name}的态度是{attitude_desc}")
|
||||||
|
|
||||||
|
if short_impression:
|
||||||
|
relation_parts.append(f"你对ta的总体印象:{short_impression}")
|
||||||
|
|
||||||
|
if full_impression:
|
||||||
|
relation_parts.append(f"更详细的了解:{full_impression}")
|
||||||
|
|
||||||
|
# 4. 特征点和记忆
|
||||||
if points_text:
|
if points_text:
|
||||||
relation_info = f"你对{person_name}的印象是{nickname_str}:{short_impression}。具体来说:{relation_info}。你还记得ta最近做的事:{points_text}"
|
relation_parts.append(f"你记得关于{person_name}的一些事情:\n{points_text}")
|
||||||
else:
|
|
||||||
relation_info = (
|
# 5. 从UserRelationships表获取额外关系信息
|
||||||
f"你对{person_name}的印象是{nickname_str}:{short_impression}。具体来说:{relation_info}"
|
try:
|
||||||
|
from src.common.database.sqlalchemy_database_api import db_query
|
||||||
|
from src.common.database.sqlalchemy_models import UserRelationships
|
||||||
|
|
||||||
|
# 查询用户关系数据
|
||||||
|
relationships = await db_query(
|
||||||
|
UserRelationships,
|
||||||
|
filters=[UserRelationships.user_id == str(person_info_manager.get_value_sync(person_id, "user_id"))],
|
||||||
|
limit=1,
|
||||||
)
|
)
|
||||||
elif short_impression:
|
|
||||||
if points_text:
|
if relationships:
|
||||||
relation_info = (
|
rel_data = relationships[0]
|
||||||
f"你对{person_name}的印象是{nickname_str}:{short_impression}。你还记得ta最近做的事:{points_text}"
|
if rel_data.relationship_text:
|
||||||
|
relation_parts.append(f"关系记录:{rel_data.relationship_text}")
|
||||||
|
if rel_data.relationship_score:
|
||||||
|
score_desc = self._get_relationship_score_description(rel_data.relationship_score)
|
||||||
|
relation_parts.append(f"关系亲密程度:{score_desc}")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug(f"查询UserRelationships表失败: {e}")
|
||||||
|
|
||||||
|
# 构建最终的关系信息字符串
|
||||||
|
if relation_parts:
|
||||||
|
relation_info = f"关于{person_name},你知道以下信息:\n" + "\n".join(
|
||||||
|
[f"• {part}" for part in relation_parts]
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
relation_info = f"你对{person_name}的印象是{nickname_str}:{short_impression}"
|
relation_info = f"你对{person_name}了解不多,这是比较初步的交流。"
|
||||||
elif relation_info:
|
|
||||||
if points_text:
|
|
||||||
relation_info = (
|
|
||||||
f"你对{person_name}的了解{nickname_str}:{relation_info}。你还记得ta最近做的事:{points_text}"
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
relation_info = f"你对{person_name}的了解{nickname_str}:{relation_info}"
|
|
||||||
elif points_text:
|
|
||||||
relation_info = f"你记得{person_name}{nickname_str}最近做的事:{points_text}"
|
|
||||||
else:
|
|
||||||
relation_info = ""
|
|
||||||
|
|
||||||
return relation_info
|
return relation_info
|
||||||
|
|
||||||
|
def _get_attitude_description(self, attitude: int) -> str:
|
||||||
|
"""根据态度分数返回描述性文字"""
|
||||||
|
if attitude >= 80:
|
||||||
|
return "非常喜欢和欣赏"
|
||||||
|
elif attitude >= 60:
|
||||||
|
return "比较有好感"
|
||||||
|
elif attitude >= 40:
|
||||||
|
return "中立态度"
|
||||||
|
elif attitude >= 20:
|
||||||
|
return "有些反感"
|
||||||
|
else:
|
||||||
|
return "非常厌恶"
|
||||||
|
|
||||||
|
def _get_relationship_score_description(self, score: float) -> str:
|
||||||
|
"""根据关系分数返回描述性文字"""
|
||||||
|
if score >= 0.8:
|
||||||
|
return "非常亲密的好友"
|
||||||
|
elif score >= 0.6:
|
||||||
|
return "关系不错的朋友"
|
||||||
|
elif score >= 0.4:
|
||||||
|
return "普通熟人"
|
||||||
|
elif score >= 0.2:
|
||||||
|
return "认识但不熟悉"
|
||||||
|
else:
|
||||||
|
return "陌生人"
|
||||||
|
|
||||||
async def _build_fetch_query(self, person_id, target_message, chat_history):
|
async def _build_fetch_query(self, person_id, target_message, chat_history):
|
||||||
nickname_str = ",".join(global_config.bot.alias_names)
|
nickname_str = ",".join(global_config.bot.alias_names)
|
||||||
name_block = f"你的名字是{global_config.bot.nickname},你的昵称有{nickname_str},有人也会用这些昵称称呼你。"
|
name_block = f"你的名字是{global_config.bot.nickname},你的昵称有{nickname_str},有人也会用这些昵称称呼你。"
|
||||||
person_info_manager = get_person_info_manager()
|
person_info_manager = get_person_info_manager()
|
||||||
person_info = await person_info_manager.get_values(person_id, ["person_name"])
|
person_name: str = await person_info_manager.get_value(person_id, "person_name") # type: ignore
|
||||||
person_name: str = person_info.get("person_name") # type: ignore
|
|
||||||
|
|
||||||
info_cache_block = self._build_info_cache_block()
|
info_cache_block = self._build_info_cache_block()
|
||||||
|
|
||||||
@@ -259,8 +313,7 @@ class RelationshipFetcher:
|
|||||||
person_info_manager = get_person_info_manager()
|
person_info_manager = get_person_info_manager()
|
||||||
|
|
||||||
# 首先检查 info_list 缓存
|
# 首先检查 info_list 缓存
|
||||||
person_info = await person_info_manager.get_values(person_id, ["info_list"])
|
info_list = await person_info_manager.get_value(person_id, "info_list") or []
|
||||||
info_list = person_info.get("info_list") or []
|
|
||||||
cached_info = None
|
cached_info = None
|
||||||
|
|
||||||
# 查找对应的 info_type
|
# 查找对应的 info_type
|
||||||
@@ -287,9 +340,8 @@ class RelationshipFetcher:
|
|||||||
|
|
||||||
# 如果缓存中没有,尝试从用户档案中提取
|
# 如果缓存中没有,尝试从用户档案中提取
|
||||||
try:
|
try:
|
||||||
person_info = await person_info_manager.get_values(person_id, ["impression", "points"])
|
person_impression = await person_info_manager.get_value(person_id, "impression")
|
||||||
person_impression = person_info.get("impression")
|
points = await person_info_manager.get_value(person_id, "points")
|
||||||
points = person_info.get("points")
|
|
||||||
|
|
||||||
# 构建印象信息块
|
# 构建印象信息块
|
||||||
if person_impression:
|
if person_impression:
|
||||||
@@ -381,8 +433,7 @@ class RelationshipFetcher:
|
|||||||
person_info_manager = get_person_info_manager()
|
person_info_manager = get_person_info_manager()
|
||||||
|
|
||||||
# 获取现有的 info_list
|
# 获取现有的 info_list
|
||||||
person_info = await person_info_manager.get_values(person_id, ["info_list"])
|
info_list = await person_info_manager.get_value(person_id, "info_list") or []
|
||||||
info_list = person_info.get("info_list") or []
|
|
||||||
|
|
||||||
# 查找是否已存在相同 info_type 的记录
|
# 查找是否已存在相同 info_type 的记录
|
||||||
found_index = -1
|
found_index = -1
|
||||||
|
|||||||
@@ -121,6 +121,13 @@ async def generate_reply(
|
|||||||
if not extra_info and action_data:
|
if not extra_info and action_data:
|
||||||
extra_info = action_data.get("extra_info", "")
|
extra_info = action_data.get("extra_info", "")
|
||||||
|
|
||||||
|
# 如果action_data中有thinking,添加到extra_info中
|
||||||
|
if action_data and (thinking := action_data.get("thinking")):
|
||||||
|
if extra_info:
|
||||||
|
extra_info += f"\n\n思考过程:{thinking}"
|
||||||
|
else:
|
||||||
|
extra_info = f"思考过程:{thinking}"
|
||||||
|
|
||||||
# 调用回复器生成回复
|
# 调用回复器生成回复
|
||||||
success, llm_response_dict, prompt = await replyer.generate_reply_with_context(
|
success, llm_response_dict, prompt = await replyer.generate_reply_with_context(
|
||||||
reply_to=reply_to,
|
reply_to=reply_to,
|
||||||
|
|||||||
@@ -80,7 +80,7 @@ def message_dict_to_message_recv(message_dict: Dict[str, Any]) -> Optional[Messa
|
|||||||
|
|
||||||
message_info = {
|
message_info = {
|
||||||
"platform": message_dict.get("chat_info_platform", ""),
|
"platform": message_dict.get("chat_info_platform", ""),
|
||||||
"message_id": message_dict.get("message_id"),
|
"message_id": message_dict.get("message_id") or message_dict.get("chat_info_message_id") or message_dict.get("id"),
|
||||||
"time": message_dict.get("time"),
|
"time": message_dict.get("time"),
|
||||||
"group_info": group_info,
|
"group_info": group_info,
|
||||||
"user_info": user_info,
|
"user_info": user_info,
|
||||||
@@ -89,15 +89,16 @@ def message_dict_to_message_recv(message_dict: Dict[str, Any]) -> Optional[Messa
|
|||||||
"template_info": template_info,
|
"template_info": template_info,
|
||||||
}
|
}
|
||||||
|
|
||||||
message_dict = {
|
new_message_dict = {
|
||||||
"message_info": message_info,
|
"message_info": message_info,
|
||||||
"raw_message": message_dict.get("processed_plain_text"),
|
"raw_message": message_dict.get("processed_plain_text"),
|
||||||
"processed_plain_text": message_dict.get("processed_plain_text"),
|
"processed_plain_text": message_dict.get("processed_plain_text"),
|
||||||
}
|
}
|
||||||
|
|
||||||
message_recv = MessageRecv(message_dict)
|
message_recv = MessageRecv(new_message_dict)
|
||||||
|
|
||||||
logger.info(f"[SendAPI] 找到匹配的回复消息,发送者: {message_dict.get('user_nickname', '')}")
|
logger.info(f"[SendAPI] 找到匹配的回复消息,发送者: {message_dict.get('user_nickname', '')}")
|
||||||
|
logger.info(message_recv)
|
||||||
return message_recv
|
return message_recv
|
||||||
|
|
||||||
|
|
||||||
@@ -246,7 +247,7 @@ async def text_to_stream(
|
|||||||
typing: bool = False,
|
typing: bool = False,
|
||||||
reply_to: str = "",
|
reply_to: str = "",
|
||||||
reply_to_message: Optional[Dict[str, Any]] = None,
|
reply_to_message: Optional[Dict[str, Any]] = None,
|
||||||
set_reply: bool = False,
|
set_reply: bool = True,
|
||||||
storage_message: bool = True,
|
storage_message: bool = True,
|
||||||
) -> bool:
|
) -> bool:
|
||||||
"""向指定流发送文本消息
|
"""向指定流发送文本消息
|
||||||
@@ -275,7 +276,7 @@ async def text_to_stream(
|
|||||||
|
|
||||||
|
|
||||||
async def emoji_to_stream(
|
async def emoji_to_stream(
|
||||||
emoji_base64: str, stream_id: str, storage_message: bool = True, set_reply: bool = False
|
emoji_base64: str, stream_id: str, storage_message: bool = True, set_reply: bool = True
|
||||||
) -> bool:
|
) -> bool:
|
||||||
"""向指定流发送表情包
|
"""向指定流发送表情包
|
||||||
|
|
||||||
@@ -293,7 +294,7 @@ async def emoji_to_stream(
|
|||||||
|
|
||||||
|
|
||||||
async def image_to_stream(
|
async def image_to_stream(
|
||||||
image_base64: str, stream_id: str, storage_message: bool = True, set_reply: bool = False
|
image_base64: str, stream_id: str, storage_message: bool = True, set_reply: bool = True
|
||||||
) -> bool:
|
) -> bool:
|
||||||
"""向指定流发送图片
|
"""向指定流发送图片
|
||||||
|
|
||||||
@@ -315,7 +316,7 @@ async def command_to_stream(
|
|||||||
stream_id: str,
|
stream_id: str,
|
||||||
storage_message: bool = True,
|
storage_message: bool = True,
|
||||||
display_message: str = "",
|
display_message: str = "",
|
||||||
set_reply: bool = False,
|
set_reply: bool = True,
|
||||||
) -> bool:
|
) -> bool:
|
||||||
"""向指定流发送命令
|
"""向指定流发送命令
|
||||||
|
|
||||||
@@ -340,7 +341,7 @@ async def custom_to_stream(
|
|||||||
typing: bool = False,
|
typing: bool = False,
|
||||||
reply_to: str = "",
|
reply_to: str = "",
|
||||||
reply_to_message: Optional[Dict[str, Any]] = None,
|
reply_to_message: Optional[Dict[str, Any]] = None,
|
||||||
set_reply: bool = False,
|
set_reply: bool = True,
|
||||||
storage_message: bool = True,
|
storage_message: bool = True,
|
||||||
show_log: bool = True,
|
show_log: bool = True,
|
||||||
) -> bool:
|
) -> bool:
|
||||||
|
|||||||
@@ -93,7 +93,6 @@ class BaseAction(ABC):
|
|||||||
self.associated_types: list[str] = getattr(self.__class__, "associated_types", []).copy()
|
self.associated_types: list[str] = getattr(self.__class__, "associated_types", []).copy()
|
||||||
self.chat_type_allow: ChatType = getattr(self.__class__, "chat_type_allow", ChatType.ALL)
|
self.chat_type_allow: ChatType = getattr(self.__class__, "chat_type_allow", ChatType.ALL)
|
||||||
|
|
||||||
|
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
# 便捷属性 - 直接在初始化时获取常用聊天信息(带类型注解)
|
# 便捷属性 - 直接在初始化时获取常用聊天信息(带类型注解)
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
@@ -398,6 +397,7 @@ class BaseAction(ABC):
|
|||||||
try:
|
try:
|
||||||
# 1. 从注册中心获取Action类
|
# 1. 从注册中心获取Action类
|
||||||
from src.plugin_system.core.component_registry import component_registry
|
from src.plugin_system.core.component_registry import component_registry
|
||||||
|
|
||||||
action_class = component_registry.get_component_class(action_name, ComponentType.ACTION)
|
action_class = component_registry.get_component_class(action_name, ComponentType.ACTION)
|
||||||
if not action_class:
|
if not action_class:
|
||||||
logger.error(f"{log_prefix} 未找到Action: {action_name}")
|
logger.error(f"{log_prefix} 未找到Action: {action_name}")
|
||||||
|
|||||||
55
src/plugin_system/base/base_chatter.py
Normal file
55
src/plugin_system/base/base_chatter.py
Normal file
@@ -0,0 +1,55 @@
|
|||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import List, Optional, TYPE_CHECKING
|
||||||
|
from src.common.data_models.message_manager_data_model import StreamContext
|
||||||
|
from .component_types import ChatType
|
||||||
|
from src.plugin_system.base.component_types import ChatterInfo, ComponentType
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from src.chat.planner_actions.action_manager import ChatterActionManager
|
||||||
|
from src.plugins.built_in.affinity_flow_chatter.planner import ChatterActionPlanner as ActionPlanner
|
||||||
|
|
||||||
|
class BaseChatter(ABC):
|
||||||
|
chatter_name: str = ""
|
||||||
|
"""Chatter组件的名称"""
|
||||||
|
chatter_description: str = ""
|
||||||
|
"""Chatter组件的描述"""
|
||||||
|
chat_types: List[ChatType] = [ChatType.PRIVATE, ChatType.GROUP]
|
||||||
|
|
||||||
|
def __init__(self, stream_id: str, action_manager: 'ChatterActionManager'):
|
||||||
|
"""
|
||||||
|
初始化聊天处理器
|
||||||
|
|
||||||
|
Args:
|
||||||
|
stream_id: 聊天流ID
|
||||||
|
action_manager: 动作管理器
|
||||||
|
"""
|
||||||
|
self.stream_id = stream_id
|
||||||
|
self.action_manager = action_manager
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def execute(self, context: StreamContext) -> dict:
|
||||||
|
"""
|
||||||
|
执行聊天处理流程
|
||||||
|
|
||||||
|
Args:
|
||||||
|
context: StreamContext对象,包含聊天流的所有消息信息
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
处理结果字典
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_chatter_info(cls) -> "ChatterInfo":
|
||||||
|
"""从类属性生成ChatterInfo
|
||||||
|
Returns:
|
||||||
|
ChatterInfo对象
|
||||||
|
"""
|
||||||
|
|
||||||
|
return ChatterInfo(
|
||||||
|
name=cls.chatter_name,
|
||||||
|
description=cls.chatter_description or "No description provided.",
|
||||||
|
chat_type_allow=cls.chat_types[0],
|
||||||
|
component_type=ComponentType.CHATTER,
|
||||||
|
)
|
||||||
|
|
||||||
@@ -73,7 +73,7 @@ class BaseCommand(ABC):
|
|||||||
return True
|
return True
|
||||||
|
|
||||||
# 检查是否为群聊消息
|
# 检查是否为群聊消息
|
||||||
is_group = hasattr(self.message, "is_group_message") and self.message.is_group_message
|
is_group = self.message.message_info.group_info
|
||||||
|
|
||||||
if self.chat_type_allow == ChatType.GROUP and is_group:
|
if self.chat_type_allow == ChatType.GROUP and is_group:
|
||||||
return True
|
return True
|
||||||
|
|||||||
@@ -17,6 +17,7 @@ class ComponentType(Enum):
|
|||||||
TOOL = "tool" # 工具组件
|
TOOL = "tool" # 工具组件
|
||||||
SCHEDULER = "scheduler" # 定时任务组件(预留)
|
SCHEDULER = "scheduler" # 定时任务组件(预留)
|
||||||
EVENT_HANDLER = "event_handler" # 事件处理组件
|
EVENT_HANDLER = "event_handler" # 事件处理组件
|
||||||
|
CHATTER = "chatter" # 聊天处理器组件
|
||||||
|
|
||||||
def __str__(self) -> str:
|
def __str__(self) -> str:
|
||||||
return self.value
|
return self.value
|
||||||
@@ -40,7 +41,7 @@ class ActionActivationType(Enum):
|
|||||||
class ChatMode(Enum):
|
class ChatMode(Enum):
|
||||||
"""聊天模式枚举"""
|
"""聊天模式枚举"""
|
||||||
|
|
||||||
FOCUS = "focus" # Focus聊天模式
|
FOCUS = "focus" # 专注模式
|
||||||
NORMAL = "normal" # Normal聊天模式
|
NORMAL = "normal" # Normal聊天模式
|
||||||
PROACTIVE = "proactive" # 主动思考模式
|
PROACTIVE = "proactive" # 主动思考模式
|
||||||
PRIORITY = "priority" # 优先级聊天模式
|
PRIORITY = "priority" # 优先级聊天模式
|
||||||
@@ -54,8 +55,8 @@ class ChatMode(Enum):
|
|||||||
class ChatType(Enum):
|
class ChatType(Enum):
|
||||||
"""聊天类型枚举,用于限制插件在不同聊天环境中的使用"""
|
"""聊天类型枚举,用于限制插件在不同聊天环境中的使用"""
|
||||||
|
|
||||||
GROUP = "group" # 仅群聊可用
|
|
||||||
PRIVATE = "private" # 仅私聊可用
|
PRIVATE = "private" # 仅私聊可用
|
||||||
|
GROUP = "group" # 仅群聊可用
|
||||||
ALL = "all" # 群聊和私聊都可用
|
ALL = "all" # 群聊和私聊都可用
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
@@ -210,6 +211,17 @@ class EventHandlerInfo(ComponentInfo):
|
|||||||
self.component_type = ComponentType.EVENT_HANDLER
|
self.component_type = ComponentType.EVENT_HANDLER
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ChatterInfo(ComponentInfo):
|
||||||
|
"""聊天处理器组件信息"""
|
||||||
|
|
||||||
|
chat_type_allow: ChatType = ChatType.ALL # 允许的聊天类型
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
super().__post_init__()
|
||||||
|
self.component_type = ComponentType.CHATTER
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class EventInfo(ComponentInfo):
|
class EventInfo(ComponentInfo):
|
||||||
"""事件组件信息"""
|
"""事件组件信息"""
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import re
|
import re
|
||||||
|
|
||||||
from typing import Dict, List, Optional, Any, Pattern, Tuple, Union, Type
|
from typing import TYPE_CHECKING, Dict, List, Optional, Any, Pattern, Tuple, Union, Type
|
||||||
|
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.plugin_system.base.component_types import (
|
from src.plugin_system.base.component_types import (
|
||||||
@@ -11,14 +11,17 @@ from src.plugin_system.base.component_types import (
|
|||||||
CommandInfo,
|
CommandInfo,
|
||||||
PlusCommandInfo,
|
PlusCommandInfo,
|
||||||
EventHandlerInfo,
|
EventHandlerInfo,
|
||||||
|
ChatterInfo,
|
||||||
PluginInfo,
|
PluginInfo,
|
||||||
ComponentType,
|
ComponentType,
|
||||||
)
|
)
|
||||||
|
|
||||||
from src.plugin_system.base.base_command import BaseCommand
|
from src.plugin_system.base.base_command import BaseCommand
|
||||||
from src.plugin_system.base.base_action import BaseAction
|
from src.plugin_system.base.base_action import BaseAction
|
||||||
from src.plugin_system.base.base_tool import BaseTool
|
from src.plugin_system.base.base_tool import BaseTool
|
||||||
from src.plugin_system.base.base_events_handler import BaseEventHandler
|
from src.plugin_system.base.base_events_handler import BaseEventHandler
|
||||||
from src.plugin_system.base.plus_command import PlusCommand
|
from src.plugin_system.base.plus_command import PlusCommand
|
||||||
|
from src.plugin_system.base.base_chatter import BaseChatter
|
||||||
|
|
||||||
logger = get_logger("component_registry")
|
logger = get_logger("component_registry")
|
||||||
|
|
||||||
@@ -31,42 +34,45 @@ class ComponentRegistry:
|
|||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
# 命名空间式组件名构成法 f"{component_type}.{component_name}"
|
# 命名空间式组件名构成法 f"{component_type}.{component_name}"
|
||||||
self._plus_command_registry: Dict[str, Type[PlusCommand]] = {}
|
self._components: Dict[str, 'ComponentInfo'] = {}
|
||||||
self._components: Dict[str, ComponentInfo] = {}
|
|
||||||
"""组件注册表 命名空间式组件名 -> 组件信息"""
|
"""组件注册表 命名空间式组件名 -> 组件信息"""
|
||||||
self._components_by_type: Dict[ComponentType, Dict[str, ComponentInfo]] = {types: {} for types in ComponentType}
|
self._components_by_type: Dict['ComponentType', Dict[str, 'ComponentInfo']] = {types: {} for types in ComponentType}
|
||||||
"""类型 -> 组件原名称 -> 组件信息"""
|
"""类型 -> 组件原名称 -> 组件信息"""
|
||||||
self._components_classes: Dict[
|
self._components_classes: Dict[
|
||||||
str, Type[Union[BaseCommand, BaseAction, BaseTool, BaseEventHandler, PlusCommand]]
|
str, Type[Union['BaseCommand', 'BaseAction', 'BaseTool', 'BaseEventHandler', 'PlusCommand', 'BaseChatter']]
|
||||||
] = {}
|
] = {}
|
||||||
"""命名空间式组件名 -> 组件类"""
|
"""命名空间式组件名 -> 组件类"""
|
||||||
|
|
||||||
# 插件注册表
|
# 插件注册表
|
||||||
self._plugins: Dict[str, PluginInfo] = {}
|
self._plugins: Dict[str, 'PluginInfo'] = {}
|
||||||
"""插件名 -> 插件信息"""
|
"""插件名 -> 插件信息"""
|
||||||
|
|
||||||
# Action特定注册表
|
# Action特定注册表
|
||||||
self._action_registry: Dict[str, Type[BaseAction]] = {}
|
self._action_registry: Dict[str, Type['BaseAction']] = {}
|
||||||
"""Action注册表 action名 -> action类"""
|
"""Action注册表 action名 -> action类"""
|
||||||
self._default_actions: Dict[str, ActionInfo] = {}
|
self._default_actions: Dict[str, 'ActionInfo'] = {}
|
||||||
"""默认动作集,即启用的Action集,用于重置ActionManager状态"""
|
"""默认动作集,即启用的Action集,用于重置ActionManager状态"""
|
||||||
|
|
||||||
# Command特定注册表
|
# Command特定注册表
|
||||||
self._command_registry: Dict[str, Type[BaseCommand]] = {}
|
self._command_registry: Dict[str, Type['BaseCommand']] = {}
|
||||||
"""Command类注册表 command名 -> command类"""
|
"""Command类注册表 command名 -> command类"""
|
||||||
self._command_patterns: Dict[Pattern, str] = {}
|
self._command_patterns: Dict[Pattern, str] = {}
|
||||||
"""编译后的正则 -> command名"""
|
"""编译后的正则 -> command名"""
|
||||||
|
|
||||||
# 工具特定注册表
|
# 工具特定注册表
|
||||||
self._tool_registry: Dict[str, Type[BaseTool]] = {} # 工具名 -> 工具类
|
self._tool_registry: Dict[str, Type['BaseTool']] = {} # 工具名 -> 工具类
|
||||||
self._llm_available_tools: Dict[str, Type[BaseTool]] = {} # llm可用的工具名 -> 工具类
|
self._llm_available_tools: Dict[str, Type['BaseTool']] = {} # llm可用的工具名 -> 工具类
|
||||||
|
|
||||||
# EventHandler特定注册表
|
# EventHandler特定注册表
|
||||||
self._event_handler_registry: Dict[str, Type[BaseEventHandler]] = {}
|
self._event_handler_registry: Dict[str, Type['BaseEventHandler']] = {}
|
||||||
"""event_handler名 -> event_handler类"""
|
"""event_handler名 -> event_handler类"""
|
||||||
self._enabled_event_handlers: Dict[str, Type[BaseEventHandler]] = {}
|
self._enabled_event_handlers: Dict[str, Type['BaseEventHandler']] = {}
|
||||||
"""启用的事件处理器 event_handler名 -> event_handler类"""
|
"""启用的事件处理器 event_handler名 -> event_handler类"""
|
||||||
|
|
||||||
|
self._chatter_registry: Dict[str, Type['BaseChatter']] = {}
|
||||||
|
"""chatter名 -> chatter类"""
|
||||||
|
self._enabled_chatter_registry: Dict[str, Type['BaseChatter']] = {}
|
||||||
|
"""启用的chatter名 -> chatter类"""
|
||||||
logger.info("组件注册中心初始化完成")
|
logger.info("组件注册中心初始化完成")
|
||||||
|
|
||||||
# == 注册方法 ==
|
# == 注册方法 ==
|
||||||
@@ -93,7 +99,7 @@ class ComponentRegistry:
|
|||||||
def register_component(
|
def register_component(
|
||||||
self,
|
self,
|
||||||
component_info: ComponentInfo,
|
component_info: ComponentInfo,
|
||||||
component_class: Type[Union[BaseCommand, BaseAction, BaseEventHandler, BaseTool]],
|
component_class: Type[Union['BaseCommand', 'BaseAction', 'BaseEventHandler', 'BaseTool', 'BaseChatter']],
|
||||||
) -> bool:
|
) -> bool:
|
||||||
"""注册组件
|
"""注册组件
|
||||||
|
|
||||||
@@ -151,6 +157,10 @@ class ComponentRegistry:
|
|||||||
assert isinstance(component_info, EventHandlerInfo)
|
assert isinstance(component_info, EventHandlerInfo)
|
||||||
assert issubclass(component_class, BaseEventHandler)
|
assert issubclass(component_class, BaseEventHandler)
|
||||||
ret = self._register_event_handler_component(component_info, component_class)
|
ret = self._register_event_handler_component(component_info, component_class)
|
||||||
|
case ComponentType.CHATTER:
|
||||||
|
assert isinstance(component_info, ChatterInfo)
|
||||||
|
assert issubclass(component_class, BaseChatter)
|
||||||
|
ret = self._register_chatter_component(component_info, component_class)
|
||||||
case _:
|
case _:
|
||||||
logger.warning(f"未知组件类型: {component_type}")
|
logger.warning(f"未知组件类型: {component_type}")
|
||||||
|
|
||||||
@@ -162,7 +172,7 @@ class ComponentRegistry:
|
|||||||
)
|
)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def _register_action_component(self, action_info: ActionInfo, action_class: Type[BaseAction]) -> bool:
|
def _register_action_component(self, action_info: 'ActionInfo', action_class: Type['BaseAction']) -> bool:
|
||||||
"""注册Action组件到Action特定注册表"""
|
"""注册Action组件到Action特定注册表"""
|
||||||
if not (action_name := action_info.name):
|
if not (action_name := action_info.name):
|
||||||
logger.error(f"Action组件 {action_class.__name__} 必须指定名称")
|
logger.error(f"Action组件 {action_class.__name__} 必须指定名称")
|
||||||
@@ -182,7 +192,7 @@ class ComponentRegistry:
|
|||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def _register_command_component(self, command_info: CommandInfo, command_class: Type[BaseCommand]) -> bool:
|
def _register_command_component(self, command_info: 'CommandInfo', command_class: Type['BaseCommand']) -> bool:
|
||||||
"""注册Command组件到Command特定注册表"""
|
"""注册Command组件到Command特定注册表"""
|
||||||
if not (command_name := command_info.name):
|
if not (command_name := command_info.name):
|
||||||
logger.error(f"Command组件 {command_class.__name__} 必须指定名称")
|
logger.error(f"Command组件 {command_class.__name__} 必须指定名称")
|
||||||
@@ -209,7 +219,7 @@ class ComponentRegistry:
|
|||||||
return True
|
return True
|
||||||
|
|
||||||
def _register_plus_command_component(
|
def _register_plus_command_component(
|
||||||
self, plus_command_info: PlusCommandInfo, plus_command_class: Type[PlusCommand]
|
self, plus_command_info: 'PlusCommandInfo', plus_command_class: Type['PlusCommand']
|
||||||
) -> bool:
|
) -> bool:
|
||||||
"""注册PlusCommand组件到特定注册表"""
|
"""注册PlusCommand组件到特定注册表"""
|
||||||
plus_command_name = plus_command_info.name
|
plus_command_name = plus_command_info.name
|
||||||
@@ -223,7 +233,7 @@ class ComponentRegistry:
|
|||||||
|
|
||||||
# 创建专门的PlusCommand注册表(如果还没有)
|
# 创建专门的PlusCommand注册表(如果还没有)
|
||||||
if not hasattr(self, "_plus_command_registry"):
|
if not hasattr(self, "_plus_command_registry"):
|
||||||
self._plus_command_registry: Dict[str, Type[PlusCommand]] = {}
|
self._plus_command_registry: Dict[str, Type['PlusCommand']] = {}
|
||||||
|
|
||||||
plus_command_class.plugin_name = plus_command_info.plugin_name
|
plus_command_class.plugin_name = plus_command_info.plugin_name
|
||||||
# 设置插件配置
|
# 设置插件配置
|
||||||
@@ -233,7 +243,7 @@ class ComponentRegistry:
|
|||||||
logger.debug(f"已注册PlusCommand组件: {plus_command_name}")
|
logger.debug(f"已注册PlusCommand组件: {plus_command_name}")
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def _register_tool_component(self, tool_info: ToolInfo, tool_class: Type[BaseTool]) -> bool:
|
def _register_tool_component(self, tool_info: 'ToolInfo', tool_class: Type['BaseTool']) -> bool:
|
||||||
"""注册Tool组件到Tool特定注册表"""
|
"""注册Tool组件到Tool特定注册表"""
|
||||||
tool_name = tool_info.name
|
tool_name = tool_info.name
|
||||||
|
|
||||||
@@ -249,7 +259,7 @@ class ComponentRegistry:
|
|||||||
return True
|
return True
|
||||||
|
|
||||||
def _register_event_handler_component(
|
def _register_event_handler_component(
|
||||||
self, handler_info: EventHandlerInfo, handler_class: Type[BaseEventHandler]
|
self, handler_info: 'EventHandlerInfo', handler_class: Type['BaseEventHandler']
|
||||||
) -> bool:
|
) -> bool:
|
||||||
if not (handler_name := handler_info.name):
|
if not (handler_name := handler_info.name):
|
||||||
logger.error(f"EventHandler组件 {handler_class.__name__} 必须指定名称")
|
logger.error(f"EventHandler组件 {handler_class.__name__} 必须指定名称")
|
||||||
@@ -271,11 +281,38 @@ class ComponentRegistry:
|
|||||||
# 使用EventManager进行事件处理器注册
|
# 使用EventManager进行事件处理器注册
|
||||||
from src.plugin_system.core.event_manager import event_manager
|
from src.plugin_system.core.event_manager import event_manager
|
||||||
|
|
||||||
return event_manager.register_event_handler(handler_class,self.get_plugin_config(handler_info.plugin_name) or {})
|
return event_manager.register_event_handler(
|
||||||
|
handler_class, self.get_plugin_config(handler_info.plugin_name) or {}
|
||||||
|
)
|
||||||
|
|
||||||
|
def _register_chatter_component(self, chatter_info: 'ChatterInfo', chatter_class: Type['BaseChatter']) -> bool:
|
||||||
|
"""注册Chatter组件到Chatter特定注册表"""
|
||||||
|
chatter_name = chatter_info.name
|
||||||
|
|
||||||
|
if not chatter_name:
|
||||||
|
logger.error(f"Chatter组件 {chatter_class.__name__} 必须指定名称")
|
||||||
|
return False
|
||||||
|
if not isinstance(chatter_info, ChatterInfo) or not issubclass(chatter_class, BaseChatter):
|
||||||
|
logger.error(f"注册失败: {chatter_name} 不是有效的Chatter")
|
||||||
|
return False
|
||||||
|
|
||||||
|
chatter_class.plugin_name = chatter_info.plugin_name
|
||||||
|
# 设置插件配置
|
||||||
|
chatter_class.plugin_config = self.get_plugin_config(chatter_info.plugin_name) or {}
|
||||||
|
|
||||||
|
self._chatter_registry[chatter_name] = chatter_class
|
||||||
|
|
||||||
|
if not chatter_info.enabled:
|
||||||
|
logger.warning(f"Chatter组件 {chatter_name} 未启用")
|
||||||
|
return True # 未启用,但是也是注册成功
|
||||||
|
self._enabled_chatter_registry[chatter_name] = chatter_class
|
||||||
|
|
||||||
|
logger.debug(f"已注册Chatter组件: {chatter_name}")
|
||||||
|
return True
|
||||||
|
|
||||||
# === 组件移除相关 ===
|
# === 组件移除相关 ===
|
||||||
|
|
||||||
async def remove_component(self, component_name: str, component_type: ComponentType, plugin_name: str) -> bool:
|
async def remove_component(self, component_name: str, component_type: 'ComponentType', plugin_name: str) -> bool:
|
||||||
target_component_class = self.get_component_class(component_name, component_type)
|
target_component_class = self.get_component_class(component_name, component_type)
|
||||||
if not target_component_class:
|
if not target_component_class:
|
||||||
logger.warning(f"组件 {component_name} 未注册,无法移除")
|
logger.warning(f"组件 {component_name} 未注册,无法移除")
|
||||||
@@ -323,6 +360,12 @@ class ComponentRegistry:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"移除EventHandler事件订阅时出错: {e}")
|
logger.warning(f"移除EventHandler事件订阅时出错: {e}")
|
||||||
|
|
||||||
|
case ComponentType.CHATTER:
|
||||||
|
# 移除Chatter注册
|
||||||
|
if hasattr(self, '_chatter_registry'):
|
||||||
|
self._chatter_registry.pop(component_name, None)
|
||||||
|
logger.debug(f"已移除Chatter组件: {component_name}")
|
||||||
|
|
||||||
case _:
|
case _:
|
||||||
logger.warning(f"未知的组件类型: {component_type}")
|
logger.warning(f"未知的组件类型: {component_type}")
|
||||||
return False
|
return False
|
||||||
@@ -441,8 +484,8 @@ class ComponentRegistry:
|
|||||||
|
|
||||||
# === 组件查询方法 ===
|
# === 组件查询方法 ===
|
||||||
def get_component_info(
|
def get_component_info(
|
||||||
self, component_name: str, component_type: Optional[ComponentType] = None
|
self, component_name: str, component_type: Optional['ComponentType'] = None
|
||||||
) -> Optional[ComponentInfo]:
|
) -> Optional['ComponentInfo']:
|
||||||
# sourcery skip: class-extract-method
|
# sourcery skip: class-extract-method
|
||||||
"""获取组件信息,支持自动命名空间解析
|
"""获取组件信息,支持自动命名空间解析
|
||||||
|
|
||||||
@@ -486,8 +529,8 @@ class ComponentRegistry:
|
|||||||
def get_component_class(
|
def get_component_class(
|
||||||
self,
|
self,
|
||||||
component_name: str,
|
component_name: str,
|
||||||
component_type: Optional[ComponentType] = None,
|
component_type: Optional['ComponentType'] = None,
|
||||||
) -> Optional[Union[Type[BaseCommand], Type[BaseAction], Type[BaseEventHandler], Type[BaseTool]]]:
|
) -> Optional[Union[Type['BaseCommand'], Type['BaseAction'], Type['BaseEventHandler'], Type['BaseTool']]]:
|
||||||
"""获取组件类,支持自动命名空间解析
|
"""获取组件类,支持自动命名空间解析
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -504,7 +547,7 @@ class ComponentRegistry:
|
|||||||
# 2. 如果指定了组件类型,构造命名空间化的名称查找
|
# 2. 如果指定了组件类型,构造命名空间化的名称查找
|
||||||
if component_type:
|
if component_type:
|
||||||
namespaced_name = f"{component_type.value}.{component_name}"
|
namespaced_name = f"{component_type.value}.{component_name}"
|
||||||
return self._components_classes.get(namespaced_name)
|
return self._components_classes.get(namespaced_name) # type: ignore[valid-type]
|
||||||
|
|
||||||
# 3. 如果没有指定类型,尝试在所有命名空间中查找
|
# 3. 如果没有指定类型,尝试在所有命名空间中查找
|
||||||
candidates = []
|
candidates = []
|
||||||
@@ -529,22 +572,22 @@ class ComponentRegistry:
|
|||||||
# 4. 都没找到
|
# 4. 都没找到
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def get_components_by_type(self, component_type: ComponentType) -> Dict[str, ComponentInfo]:
|
def get_components_by_type(self, component_type: 'ComponentType') -> Dict[str, 'ComponentInfo']:
|
||||||
"""获取指定类型的所有组件"""
|
"""获取指定类型的所有组件"""
|
||||||
return self._components_by_type.get(component_type, {}).copy()
|
return self._components_by_type.get(component_type, {}).copy()
|
||||||
|
|
||||||
def get_enabled_components_by_type(self, component_type: ComponentType) -> Dict[str, ComponentInfo]:
|
def get_enabled_components_by_type(self, component_type: 'ComponentType') -> Dict[str, 'ComponentInfo']:
|
||||||
"""获取指定类型的所有启用组件"""
|
"""获取指定类型的所有启用组件"""
|
||||||
components = self.get_components_by_type(component_type)
|
components = self.get_components_by_type(component_type)
|
||||||
return {name: info for name, info in components.items() if info.enabled}
|
return {name: info for name, info in components.items() if info.enabled}
|
||||||
|
|
||||||
# === Action特定查询方法 ===
|
# === Action特定查询方法 ===
|
||||||
|
|
||||||
def get_action_registry(self) -> Dict[str, Type[BaseAction]]:
|
def get_action_registry(self) -> Dict[str, Type['BaseAction']]:
|
||||||
"""获取Action注册表"""
|
"""获取Action注册表"""
|
||||||
return self._action_registry.copy()
|
return self._action_registry.copy()
|
||||||
|
|
||||||
def get_registered_action_info(self, action_name: str) -> Optional[ActionInfo]:
|
def get_registered_action_info(self, action_name: str) -> Optional['ActionInfo']:
|
||||||
"""获取Action信息"""
|
"""获取Action信息"""
|
||||||
info = self.get_component_info(action_name, ComponentType.ACTION)
|
info = self.get_component_info(action_name, ComponentType.ACTION)
|
||||||
return info if isinstance(info, ActionInfo) else None
|
return info if isinstance(info, ActionInfo) else None
|
||||||
@@ -555,11 +598,11 @@ class ComponentRegistry:
|
|||||||
|
|
||||||
# === Command特定查询方法 ===
|
# === Command特定查询方法 ===
|
||||||
|
|
||||||
def get_command_registry(self) -> Dict[str, Type[BaseCommand]]:
|
def get_command_registry(self) -> Dict[str, Type['BaseCommand']]:
|
||||||
"""获取Command注册表"""
|
"""获取Command注册表"""
|
||||||
return self._command_registry.copy()
|
return self._command_registry.copy()
|
||||||
|
|
||||||
def get_registered_command_info(self, command_name: str) -> Optional[CommandInfo]:
|
def get_registered_command_info(self, command_name: str) -> Optional['CommandInfo']:
|
||||||
"""获取Command信息"""
|
"""获取Command信息"""
|
||||||
info = self.get_component_info(command_name, ComponentType.COMMAND)
|
info = self.get_component_info(command_name, ComponentType.COMMAND)
|
||||||
return info if isinstance(info, CommandInfo) else None
|
return info if isinstance(info, CommandInfo) else None
|
||||||
@@ -568,7 +611,7 @@ class ComponentRegistry:
|
|||||||
"""获取Command模式注册表"""
|
"""获取Command模式注册表"""
|
||||||
return self._command_patterns.copy()
|
return self._command_patterns.copy()
|
||||||
|
|
||||||
def find_command_by_text(self, text: str) -> Optional[Tuple[Type[BaseCommand], dict, CommandInfo]]:
|
def find_command_by_text(self, text: str) -> Optional[Tuple[Type['BaseCommand'], dict, 'CommandInfo']]:
|
||||||
# sourcery skip: use-named-expression, use-next
|
# sourcery skip: use-named-expression, use-next
|
||||||
"""根据文本查找匹配的命令
|
"""根据文本查找匹配的命令
|
||||||
|
|
||||||
@@ -595,15 +638,15 @@ class ComponentRegistry:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
# === Tool 特定查询方法 ===
|
# === Tool 特定查询方法 ===
|
||||||
def get_tool_registry(self) -> Dict[str, Type[BaseTool]]:
|
def get_tool_registry(self) -> Dict[str, Type['BaseTool']]:
|
||||||
"""获取Tool注册表"""
|
"""获取Tool注册表"""
|
||||||
return self._tool_registry.copy()
|
return self._tool_registry.copy()
|
||||||
|
|
||||||
def get_llm_available_tools(self) -> Dict[str, Type[BaseTool]]:
|
def get_llm_available_tools(self) -> Dict[str, Type['BaseTool']]:
|
||||||
"""获取LLM可用的Tool列表"""
|
"""获取LLM可用的Tool列表"""
|
||||||
return self._llm_available_tools.copy()
|
return self._llm_available_tools.copy()
|
||||||
|
|
||||||
def get_registered_tool_info(self, tool_name: str) -> Optional[ToolInfo]:
|
def get_registered_tool_info(self, tool_name: str) -> Optional['ToolInfo']:
|
||||||
"""获取Tool信息
|
"""获取Tool信息
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -616,13 +659,13 @@ class ComponentRegistry:
|
|||||||
return info if isinstance(info, ToolInfo) else None
|
return info if isinstance(info, ToolInfo) else None
|
||||||
|
|
||||||
# === PlusCommand 特定查询方法 ===
|
# === PlusCommand 特定查询方法 ===
|
||||||
def get_plus_command_registry(self) -> Dict[str, Type[PlusCommand]]:
|
def get_plus_command_registry(self) -> Dict[str, Type['PlusCommand']]:
|
||||||
"""获取PlusCommand注册表"""
|
"""获取PlusCommand注册表"""
|
||||||
if not hasattr(self, "_plus_command_registry"):
|
if not hasattr(self, "_plus_command_registry"):
|
||||||
pass
|
self._plus_command_registry: Dict[str, Type[PlusCommand]] = {}
|
||||||
return self._plus_command_registry.copy()
|
return self._plus_command_registry.copy()
|
||||||
|
|
||||||
def get_registered_plus_command_info(self, command_name: str) -> Optional[PlusCommandInfo]:
|
def get_registered_plus_command_info(self, command_name: str) -> Optional['PlusCommandInfo']:
|
||||||
"""获取PlusCommand信息
|
"""获取PlusCommand信息
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -636,26 +679,44 @@ class ComponentRegistry:
|
|||||||
|
|
||||||
# === EventHandler 特定查询方法 ===
|
# === EventHandler 特定查询方法 ===
|
||||||
|
|
||||||
def get_event_handler_registry(self) -> Dict[str, Type[BaseEventHandler]]:
|
def get_event_handler_registry(self) -> Dict[str, Type['BaseEventHandler']]:
|
||||||
"""获取事件处理器注册表"""
|
"""获取事件处理器注册表"""
|
||||||
return self._event_handler_registry.copy()
|
return self._event_handler_registry.copy()
|
||||||
|
|
||||||
def get_registered_event_handler_info(self, handler_name: str) -> Optional[EventHandlerInfo]:
|
def get_registered_event_handler_info(self, handler_name: str) -> Optional['EventHandlerInfo']:
|
||||||
"""获取事件处理器信息"""
|
"""获取事件处理器信息"""
|
||||||
info = self.get_component_info(handler_name, ComponentType.EVENT_HANDLER)
|
info = self.get_component_info(handler_name, ComponentType.EVENT_HANDLER)
|
||||||
return info if isinstance(info, EventHandlerInfo) else None
|
return info if isinstance(info, EventHandlerInfo) else None
|
||||||
|
|
||||||
def get_enabled_event_handlers(self) -> Dict[str, Type[BaseEventHandler]]:
|
def get_enabled_event_handlers(self) -> Dict[str, Type['BaseEventHandler']]:
|
||||||
"""获取启用的事件处理器"""
|
"""获取启用的事件处理器"""
|
||||||
return self._enabled_event_handlers.copy()
|
return self._enabled_event_handlers.copy()
|
||||||
|
|
||||||
|
# === Chatter 特定查询方法 ===
|
||||||
|
def get_chatter_registry(self) -> Dict[str, Type['BaseChatter']]:
|
||||||
|
"""获取Chatter注册表"""
|
||||||
|
if not hasattr(self, '_chatter_registry'):
|
||||||
|
self._chatter_registry: Dict[str, Type[BaseChatter]] = {}
|
||||||
|
return self._chatter_registry.copy()
|
||||||
|
|
||||||
|
def get_enabled_chatter_registry(self) -> Dict[str, Type['BaseChatter']]:
|
||||||
|
"""获取启用的Chatter注册表"""
|
||||||
|
if not hasattr(self, '_enabled_chatter_registry'):
|
||||||
|
self._enabled_chatter_registry: Dict[str, Type[BaseChatter]] = {}
|
||||||
|
return self._enabled_chatter_registry.copy()
|
||||||
|
|
||||||
|
def get_registered_chatter_info(self, chatter_name: str) -> Optional['ChatterInfo']:
|
||||||
|
"""获取Chatter信息"""
|
||||||
|
info = self.get_component_info(chatter_name, ComponentType.CHATTER)
|
||||||
|
return info if isinstance(info, ChatterInfo) else None
|
||||||
|
|
||||||
# === 插件查询方法 ===
|
# === 插件查询方法 ===
|
||||||
|
|
||||||
def get_plugin_info(self, plugin_name: str) -> Optional[PluginInfo]:
|
def get_plugin_info(self, plugin_name: str) -> Optional['PluginInfo']:
|
||||||
"""获取插件信息"""
|
"""获取插件信息"""
|
||||||
return self._plugins.get(plugin_name)
|
return self._plugins.get(plugin_name)
|
||||||
|
|
||||||
def get_all_plugins(self) -> Dict[str, PluginInfo]:
|
def get_all_plugins(self) -> Dict[str, 'PluginInfo']:
|
||||||
"""获取所有插件"""
|
"""获取所有插件"""
|
||||||
return self._plugins.copy()
|
return self._plugins.copy()
|
||||||
|
|
||||||
@@ -663,13 +724,12 @@ class ComponentRegistry:
|
|||||||
# """获取所有启用的插件"""
|
# """获取所有启用的插件"""
|
||||||
# return {name: info for name, info in self._plugins.items() if info.enabled}
|
# return {name: info for name, info in self._plugins.items() if info.enabled}
|
||||||
|
|
||||||
def get_plugin_components(self, plugin_name: str) -> List[ComponentInfo]:
|
def get_plugin_components(self, plugin_name: str) -> List['ComponentInfo']:
|
||||||
"""获取插件的所有组件"""
|
"""获取插件的所有组件"""
|
||||||
plugin_info = self.get_plugin_info(plugin_name)
|
plugin_info = self.get_plugin_info(plugin_name)
|
||||||
return plugin_info.components if plugin_info else []
|
return plugin_info.components if plugin_info else []
|
||||||
|
|
||||||
@staticmethod
|
def get_plugin_config(self, plugin_name: str) -> dict:
|
||||||
def get_plugin_config(plugin_name: str) -> dict:
|
|
||||||
"""获取插件配置
|
"""获取插件配置
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -688,9 +748,10 @@ class ComponentRegistry:
|
|||||||
# 如果插件实例不存在,尝试从配置文件读取
|
# 如果插件实例不存在,尝试从配置文件读取
|
||||||
try:
|
try:
|
||||||
import toml
|
import toml
|
||||||
|
|
||||||
config_path = Path("config") / "plugins" / plugin_name / "config.toml"
|
config_path = Path("config") / "plugins" / plugin_name / "config.toml"
|
||||||
if config_path.exists():
|
if config_path.exists():
|
||||||
with open(config_path, 'r', encoding='utf-8') as f:
|
with open(config_path, "r", encoding="utf-8") as f:
|
||||||
config_data = toml.load(f)
|
config_data = toml.load(f)
|
||||||
logger.debug(f"从配置文件读取插件 {plugin_name} 的配置")
|
logger.debug(f"从配置文件读取插件 {plugin_name} 的配置")
|
||||||
return config_data
|
return config_data
|
||||||
@@ -706,6 +767,7 @@ class ComponentRegistry:
|
|||||||
tool_components: int = 0
|
tool_components: int = 0
|
||||||
events_handlers: int = 0
|
events_handlers: int = 0
|
||||||
plus_command_components: int = 0
|
plus_command_components: int = 0
|
||||||
|
chatter_components: int = 0
|
||||||
for component in self._components.values():
|
for component in self._components.values():
|
||||||
if component.component_type == ComponentType.ACTION:
|
if component.component_type == ComponentType.ACTION:
|
||||||
action_components += 1
|
action_components += 1
|
||||||
@@ -717,12 +779,15 @@ class ComponentRegistry:
|
|||||||
events_handlers += 1
|
events_handlers += 1
|
||||||
elif component.component_type == ComponentType.PLUS_COMMAND:
|
elif component.component_type == ComponentType.PLUS_COMMAND:
|
||||||
plus_command_components += 1
|
plus_command_components += 1
|
||||||
|
elif component.component_type == ComponentType.CHATTER:
|
||||||
|
chatter_components += 1
|
||||||
return {
|
return {
|
||||||
"action_components": action_components,
|
"action_components": action_components,
|
||||||
"command_components": command_components,
|
"command_components": command_components,
|
||||||
"tool_components": tool_components,
|
"tool_components": tool_components,
|
||||||
"event_handlers": events_handlers,
|
"event_handlers": events_handlers,
|
||||||
"plus_command_components": plus_command_components,
|
"plus_command_components": plus_command_components,
|
||||||
|
"chatter_components": chatter_components,
|
||||||
"total_components": len(self._components),
|
"total_components": len(self._components),
|
||||||
"total_plugins": len(self._plugins),
|
"total_plugins": len(self._plugins),
|
||||||
"components_by_type": {
|
"components_by_type": {
|
||||||
@@ -730,6 +795,8 @@ class ComponentRegistry:
|
|||||||
},
|
},
|
||||||
"enabled_components": len([c for c in self._components.values() if c.enabled]),
|
"enabled_components": len([c for c in self._components.values() if c.enabled]),
|
||||||
"enabled_plugins": len([p for p in self._plugins.values() if p.enabled]),
|
"enabled_plugins": len([p for p in self._plugins.values() if p.enabled]),
|
||||||
|
"enabled_components": len([c for c in self._components.values() if c.enabled]),
|
||||||
|
"enabled_plugins": len([p for p in self._plugins.values() if p.enabled]),
|
||||||
}
|
}
|
||||||
|
|
||||||
# === 组件移除相关 ===
|
# === 组件移除相关 ===
|
||||||
|
|||||||
@@ -146,7 +146,9 @@ class EventManager:
|
|||||||
logger.info(f"事件 {event_name} 已禁用")
|
logger.info(f"事件 {event_name} 已禁用")
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def register_event_handler(self, handler_class: Type[BaseEventHandler], plugin_config: Optional[dict] = None) -> bool:
|
def register_event_handler(
|
||||||
|
self, handler_class: Type[BaseEventHandler], plugin_config: Optional[dict] = None
|
||||||
|
) -> bool:
|
||||||
"""注册事件处理器
|
"""注册事件处理器
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -168,7 +170,7 @@ class EventManager:
|
|||||||
# 创建事件处理器实例,传递插件配置
|
# 创建事件处理器实例,传递插件配置
|
||||||
handler_instance = handler_class()
|
handler_instance = handler_class()
|
||||||
handler_instance.plugin_config = plugin_config
|
handler_instance.plugin_config = plugin_config
|
||||||
if plugin_config is not None and hasattr(handler_instance, 'set_plugin_config'):
|
if plugin_config is not None and hasattr(handler_instance, "set_plugin_config"):
|
||||||
handler_instance.set_plugin_config(plugin_config)
|
handler_instance.set_plugin_config(plugin_config)
|
||||||
|
|
||||||
self._event_handlers[handler_name] = handler_instance
|
self._event_handlers[handler_name] = handler_instance
|
||||||
|
|||||||
@@ -129,9 +129,7 @@ class PluginManager:
|
|||||||
self._show_plugin_components(plugin_name)
|
self._show_plugin_components(plugin_name)
|
||||||
|
|
||||||
# 检查并调用 on_plugin_loaded 钩子(如果存在)
|
# 检查并调用 on_plugin_loaded 钩子(如果存在)
|
||||||
if hasattr(plugin_instance, "on_plugin_loaded") and callable(
|
if hasattr(plugin_instance, "on_plugin_loaded") and callable(plugin_instance.on_plugin_loaded):
|
||||||
plugin_instance.on_plugin_loaded
|
|
||||||
):
|
|
||||||
logger.debug(f"为插件 '{plugin_name}' 调用 on_plugin_loaded 钩子")
|
logger.debug(f"为插件 '{plugin_name}' 调用 on_plugin_loaded 钩子")
|
||||||
try:
|
try:
|
||||||
# 使用 asyncio.create_task 确保它不会阻塞加载流程
|
# 使用 asyncio.create_task 确保它不会阻塞加载流程
|
||||||
@@ -380,13 +378,14 @@ class PluginManager:
|
|||||||
tool_count = stats.get("tool_components", 0)
|
tool_count = stats.get("tool_components", 0)
|
||||||
event_handler_count = stats.get("event_handlers", 0)
|
event_handler_count = stats.get("event_handlers", 0)
|
||||||
plus_command_count = stats.get("plus_command_components", 0)
|
plus_command_count = stats.get("plus_command_components", 0)
|
||||||
|
chatter_count = stats.get("chatter_components", 0)
|
||||||
total_components = stats.get("total_components", 0)
|
total_components = stats.get("total_components", 0)
|
||||||
|
|
||||||
# 📋 显示插件加载总览
|
# 📋 显示插件加载总览
|
||||||
if total_registered > 0:
|
if total_registered > 0:
|
||||||
logger.info("🎉 插件系统加载完成!")
|
logger.info("🎉 插件系统加载完成!")
|
||||||
logger.info(
|
logger.info(
|
||||||
f"📊 总览: {total_registered}个插件, {total_components}个组件 (Action: {action_count}, Command: {command_count}, Tool: {tool_count}, PlusCommand: {plus_command_count}, EventHandler: {event_handler_count})"
|
f"📊 总览: {total_registered}个插件, {total_components}个组件 (Action: {action_count}, Command: {command_count}, Tool: {tool_count}, PlusCommand: {plus_command_count}, EventHandler: {event_handler_count}, Chatter: {chatter_count})"
|
||||||
)
|
)
|
||||||
|
|
||||||
# 显示详细的插件列表
|
# 显示详细的插件列表
|
||||||
@@ -442,6 +441,12 @@ class PluginManager:
|
|||||||
if plus_command_components:
|
if plus_command_components:
|
||||||
plus_command_names = [c.name for c in plus_command_components]
|
plus_command_names = [c.name for c in plus_command_components]
|
||||||
logger.info(f" ⚡ PlusCommand组件: {', '.join(plus_command_names)}")
|
logger.info(f" ⚡ PlusCommand组件: {', '.join(plus_command_names)}")
|
||||||
|
chatter_components = [
|
||||||
|
c for c in plugin_info.components if c.component_type == ComponentType.CHATTER
|
||||||
|
]
|
||||||
|
if chatter_components:
|
||||||
|
chatter_names = [c.name for c in chatter_components]
|
||||||
|
logger.info(f" 🗣️ Chatter组件: {', '.join(chatter_names)}")
|
||||||
if event_handler_components:
|
if event_handler_components:
|
||||||
event_handler_names = [c.name for c in event_handler_components]
|
event_handler_names = [c.name for c in event_handler_components]
|
||||||
logger.info(f" 📢 EventHandler组件: {', '.join(event_handler_names)}")
|
logger.info(f" 📢 EventHandler组件: {', '.join(event_handler_names)}")
|
||||||
|
|||||||
125
src/plugins/built_in/affinity_flow_chatter/README.md
Normal file
125
src/plugins/built_in/affinity_flow_chatter/README.md
Normal file
@@ -0,0 +1,125 @@
|
|||||||
|
# 亲和力聊天处理器插件
|
||||||
|
|
||||||
|
## 概述
|
||||||
|
|
||||||
|
这是一个内置的chatter插件,实现了基于亲和力流的智能聊天处理器,具有兴趣度评分和人物关系构建功能。
|
||||||
|
|
||||||
|
## 功能特性
|
||||||
|
|
||||||
|
- **智能兴趣度评分**: 自动识别和评估用户兴趣话题
|
||||||
|
- **人物关系系统**: 根据互动历史建立和维持用户关系
|
||||||
|
- **多聊天类型支持**: 支持私聊和群聊场景
|
||||||
|
- **插件化架构**: 完全集成到插件系统中
|
||||||
|
|
||||||
|
## 组件架构
|
||||||
|
|
||||||
|
### BaseChatter (抽象基类)
|
||||||
|
- 位置: `src/plugin_system/base/base_chatter.py`
|
||||||
|
- 功能: 定义所有chatter组件的基础接口
|
||||||
|
- 必须实现的方法: `execute(context: StreamContext) -> dict`
|
||||||
|
|
||||||
|
### ChatterManager (管理器)
|
||||||
|
- 位置: `src/chat/chatter_manager.py`
|
||||||
|
- 功能: 管理和调度所有chatter组件
|
||||||
|
- 特性: 自动从插件系统注册和发现chatter组件
|
||||||
|
|
||||||
|
### AffinityChatter (具体实现)
|
||||||
|
- 位置: `src/plugins/built_in/chatter/affinity_chatter.py`
|
||||||
|
- 功能: 亲和力流聊天处理器的具体实现
|
||||||
|
- 支持的聊天类型: PRIVATE, GROUP
|
||||||
|
|
||||||
|
## 使用方法
|
||||||
|
|
||||||
|
### 1. 基本使用
|
||||||
|
|
||||||
|
```python
|
||||||
|
from src.chat.chatter_manager import ChatterManager
|
||||||
|
from src.chat.planner_actions.action_manager import ChatterActionManager
|
||||||
|
|
||||||
|
# 初始化
|
||||||
|
action_manager = ChatterActionManager()
|
||||||
|
chatter_manager = ChatterManager(action_manager)
|
||||||
|
|
||||||
|
# 处理消息流
|
||||||
|
result = await chatter_manager.process_stream_context(stream_id, context)
|
||||||
|
```
|
||||||
|
|
||||||
|
### 2. 创建自定义Chatter
|
||||||
|
|
||||||
|
```python
|
||||||
|
from src.plugin_system.base.base_chatter import BaseChatter
|
||||||
|
from src.plugin_system.base.component_types import ChatType, ComponentType
|
||||||
|
from src.plugin_system.base.component_types import ChatterInfo
|
||||||
|
|
||||||
|
class CustomChatter(BaseChatter):
|
||||||
|
chat_types = [ChatType.PRIVATE] # 只支持私聊
|
||||||
|
|
||||||
|
async def execute(self, context: StreamContext) -> dict:
|
||||||
|
# 实现你的聊天逻辑
|
||||||
|
return {"success": True, "message": "处理完成"}
|
||||||
|
|
||||||
|
# 在插件中注册
|
||||||
|
async def on_load(self):
|
||||||
|
chatter_info = ChatterInfo(
|
||||||
|
name="custom_chatter",
|
||||||
|
component_type=ComponentType.CHATTER,
|
||||||
|
description="自定义聊天处理器",
|
||||||
|
enabled=True,
|
||||||
|
plugin_name=self.name,
|
||||||
|
chat_type_allow=ChatType.PRIVATE
|
||||||
|
)
|
||||||
|
|
||||||
|
ComponentRegistry.register_component(
|
||||||
|
component_info=chatter_info,
|
||||||
|
component_class=CustomChatter
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
## 配置
|
||||||
|
|
||||||
|
### 插件配置文件
|
||||||
|
- 位置: `src/plugins/built_in/chatter/_manifest.json`
|
||||||
|
- 包含插件信息和组件配置
|
||||||
|
|
||||||
|
### 聊天类型
|
||||||
|
- `PRIVATE`: 私聊
|
||||||
|
- `GROUP`: 群聊
|
||||||
|
- `ALL`: 所有类型
|
||||||
|
|
||||||
|
## 核心概念
|
||||||
|
|
||||||
|
### 1. 兴趣值系统
|
||||||
|
- 自动识别同类话题
|
||||||
|
- 兴趣值会根据聊天频率增减
|
||||||
|
- 支持新话题的自动学习
|
||||||
|
|
||||||
|
### 2. 人物关系系统
|
||||||
|
- 根据互动质量建立关系分
|
||||||
|
- 不同关系分对应不同的回复风格
|
||||||
|
- 支持情感化的交流
|
||||||
|
|
||||||
|
### 3. 执行流程
|
||||||
|
1. 接收StreamContext
|
||||||
|
2. 使用ActionPlanner进行规划
|
||||||
|
3. 执行相应的Action
|
||||||
|
4. 返回处理结果
|
||||||
|
|
||||||
|
## 扩展开发
|
||||||
|
|
||||||
|
### 添加新的Chatter类型
|
||||||
|
1. 继承BaseChatter类
|
||||||
|
2. 实现execute方法
|
||||||
|
3. 在插件中注册组件
|
||||||
|
4. 配置支持的聊天类型
|
||||||
|
|
||||||
|
### 集成现有功能
|
||||||
|
- 使用ActionPlanner进行动作规划
|
||||||
|
- 通过ActionManager执行动作
|
||||||
|
- 利用现有的记忆和知识系统
|
||||||
|
|
||||||
|
## 注意事项
|
||||||
|
|
||||||
|
1. 所有chatter组件必须实现`execute`方法
|
||||||
|
2. 插件注册时需要指定支持的聊天类型
|
||||||
|
3. 组件名称不能包含点号(.)
|
||||||
|
4. 确保在插件卸载时正确清理资源
|
||||||
7
src/plugins/built_in/affinity_flow_chatter/__init__.py
Normal file
7
src/plugins/built_in/affinity_flow_chatter/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
"""
|
||||||
|
亲和力聊天处理器插件
|
||||||
|
"""
|
||||||
|
|
||||||
|
from .plugin import AffinityChatterPlugin
|
||||||
|
|
||||||
|
__all__ = ["AffinityChatterPlugin"]
|
||||||
23
src/plugins/built_in/affinity_flow_chatter/_manifest.json
Normal file
23
src/plugins/built_in/affinity_flow_chatter/_manifest.json
Normal file
@@ -0,0 +1,23 @@
|
|||||||
|
{
|
||||||
|
"manifest_version": 1,
|
||||||
|
"name": "affinity_chatter",
|
||||||
|
"display_name": "Affinity Flow Chatter",
|
||||||
|
"description": "Built-in chatter plugin for affinity flow with interest scoring and relationship building",
|
||||||
|
"version": "1.0.0",
|
||||||
|
"author": "MoFox",
|
||||||
|
"plugin_class": "AffinityChatterPlugin",
|
||||||
|
"enabled": true,
|
||||||
|
"is_built_in": true,
|
||||||
|
"components": [
|
||||||
|
{
|
||||||
|
"name": "affinity_chatter",
|
||||||
|
"type": "chatter",
|
||||||
|
"description": "Affinity flow chatter with intelligent interest scoring and relationship building",
|
||||||
|
"enabled": true,
|
||||||
|
"chat_type_allow": ["all"]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"host_application": { "min_version": "0.8.0" },
|
||||||
|
"keywords": ["chatter", "affinity", "conversation"],
|
||||||
|
"categories": ["Chat", "AI"]
|
||||||
|
}
|
||||||
236
src/plugins/built_in/affinity_flow_chatter/affinity_chatter.py
Normal file
236
src/plugins/built_in/affinity_flow_chatter/affinity_chatter.py
Normal file
@@ -0,0 +1,236 @@
|
|||||||
|
"""
|
||||||
|
亲和力聊天处理器
|
||||||
|
基于现有的AffinityFlowChatter重构为插件化组件
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import time
|
||||||
|
import traceback
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Dict, Any
|
||||||
|
|
||||||
|
from src.plugin_system.base.base_chatter import BaseChatter
|
||||||
|
from src.plugin_system.base.component_types import ChatType
|
||||||
|
from src.common.data_models.message_manager_data_model import StreamContext
|
||||||
|
from src.plugins.built_in.affinity_flow_chatter.planner import ChatterActionPlanner
|
||||||
|
from src.chat.planner_actions.action_manager import ChatterActionManager
|
||||||
|
from src.common.logger import get_logger
|
||||||
|
from src.chat.express.expression_learner import expression_learner_manager
|
||||||
|
|
||||||
|
logger = get_logger("affinity_chatter")
|
||||||
|
|
||||||
|
# 定义颜色
|
||||||
|
SOFT_GREEN = "\033[38;5;118m" # 一个更柔和的绿色
|
||||||
|
RESET_COLOR = "\033[0m"
|
||||||
|
|
||||||
|
|
||||||
|
class AffinityChatter(BaseChatter):
|
||||||
|
"""亲和力聊天处理器"""
|
||||||
|
|
||||||
|
chatter_name: str = "AffinityChatter"
|
||||||
|
chatter_description: str = "基于亲和力模型的智能聊天处理器,支持多种聊天类型"
|
||||||
|
chat_types: list[ChatType] = [ChatType.ALL] # 支持所有聊天类型
|
||||||
|
|
||||||
|
def __init__(self, stream_id: str, action_manager: ChatterActionManager):
|
||||||
|
"""
|
||||||
|
初始化亲和力聊天处理器
|
||||||
|
|
||||||
|
Args:
|
||||||
|
stream_id: 聊天流ID
|
||||||
|
planner: 动作规划器
|
||||||
|
action_manager: 动作管理器
|
||||||
|
"""
|
||||||
|
super().__init__(stream_id, action_manager)
|
||||||
|
self.planner = ChatterActionPlanner(stream_id, action_manager)
|
||||||
|
|
||||||
|
# 处理器统计
|
||||||
|
self.stats = {
|
||||||
|
"messages_processed": 0,
|
||||||
|
"plans_created": 0,
|
||||||
|
"actions_executed": 0,
|
||||||
|
"successful_executions": 0,
|
||||||
|
"failed_executions": 0,
|
||||||
|
}
|
||||||
|
self.last_activity_time = time.time()
|
||||||
|
|
||||||
|
async def execute(self, context: StreamContext) -> dict:
|
||||||
|
"""
|
||||||
|
处理StreamContext对象
|
||||||
|
|
||||||
|
Args:
|
||||||
|
context: StreamContext对象,包含聊天流的所有消息信息
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
处理结果字典
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 触发表达学习
|
||||||
|
learner = expression_learner_manager.get_expression_learner(self.stream_id)
|
||||||
|
asyncio.create_task(learner.trigger_learning_for_chat())
|
||||||
|
|
||||||
|
unread_messages = context.get_unread_messages()
|
||||||
|
|
||||||
|
# 使用增强版规划器处理消息
|
||||||
|
actions, target_message = await self.planner.plan(context=context)
|
||||||
|
self.stats["plans_created"] += 1
|
||||||
|
|
||||||
|
# 执行动作(如果规划器返回了动作)
|
||||||
|
execution_result = {"executed_count": len(actions) if actions else 0}
|
||||||
|
if actions:
|
||||||
|
logger.debug(f"聊天流 {self.stream_id} 生成了 {len(actions)} 个动作")
|
||||||
|
|
||||||
|
# 更新统计
|
||||||
|
self.stats["messages_processed"] += 1
|
||||||
|
self.stats["actions_executed"] += execution_result.get("executed_count", 0)
|
||||||
|
self.stats["successful_executions"] += 1
|
||||||
|
self.last_activity_time = time.time()
|
||||||
|
|
||||||
|
result = {
|
||||||
|
"success": True,
|
||||||
|
"stream_id": self.stream_id,
|
||||||
|
"plan_created": True,
|
||||||
|
"actions_count": len(actions) if actions else 0,
|
||||||
|
"has_target_message": target_message is not None,
|
||||||
|
"unread_messages_processed": len(unread_messages),
|
||||||
|
**execution_result,
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.debug(
|
||||||
|
f"聊天流 {self.stream_id} StreamContext处理成功: 动作数={result['actions_count']}, 未读消息={result['unread_messages_processed']}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"亲和力聊天处理器 {self.stream_id} 处理StreamContext时出错: {e}\n{traceback.format_exc()}")
|
||||||
|
self.stats["failed_executions"] += 1
|
||||||
|
self.last_activity_time = time.time()
|
||||||
|
|
||||||
|
return {
|
||||||
|
"success": False,
|
||||||
|
"stream_id": self.stream_id,
|
||||||
|
"error_message": str(e),
|
||||||
|
"executed_count": 0,
|
||||||
|
}
|
||||||
|
|
||||||
|
def get_stats(self) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
获取处理器统计信息
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
统计信息字典
|
||||||
|
"""
|
||||||
|
return self.stats.copy()
|
||||||
|
|
||||||
|
def get_planner_stats(self) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
获取规划器统计信息
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
规划器统计信息字典
|
||||||
|
"""
|
||||||
|
return self.planner.get_planner_stats()
|
||||||
|
|
||||||
|
def get_interest_scoring_stats(self) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
获取兴趣度评分统计信息
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
兴趣度评分统计信息字典
|
||||||
|
"""
|
||||||
|
return self.planner.get_interest_scoring_stats()
|
||||||
|
|
||||||
|
def get_relationship_stats(self) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
获取用户关系统计信息
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
用户关系统计信息字典
|
||||||
|
"""
|
||||||
|
return self.planner.get_relationship_stats()
|
||||||
|
|
||||||
|
def get_current_mood_state(self) -> str:
|
||||||
|
"""
|
||||||
|
获取当前聊天的情绪状态
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
当前情绪状态描述
|
||||||
|
"""
|
||||||
|
return self.planner.get_current_mood_state()
|
||||||
|
|
||||||
|
def get_mood_stats(self) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
获取情绪状态统计信息
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
情绪状态统计信息字典
|
||||||
|
"""
|
||||||
|
return self.planner.get_mood_stats()
|
||||||
|
|
||||||
|
def get_user_relationship(self, user_id: str) -> float:
|
||||||
|
"""
|
||||||
|
获取用户关系分
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: 用户ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
用户关系分 (0.0-1.0)
|
||||||
|
"""
|
||||||
|
return self.planner.get_user_relationship(user_id)
|
||||||
|
|
||||||
|
def update_interest_keywords(self, new_keywords: dict):
|
||||||
|
"""
|
||||||
|
更新兴趣关键词
|
||||||
|
|
||||||
|
Args:
|
||||||
|
new_keywords: 新的兴趣关键词字典
|
||||||
|
"""
|
||||||
|
self.planner.update_interest_keywords(new_keywords)
|
||||||
|
logger.info(f"聊天流 {self.stream_id} 已更新兴趣关键词: {list(new_keywords.keys())}")
|
||||||
|
|
||||||
|
def reset_stats(self):
|
||||||
|
"""重置统计信息"""
|
||||||
|
self.stats = {
|
||||||
|
"messages_processed": 0,
|
||||||
|
"plans_created": 0,
|
||||||
|
"actions_executed": 0,
|
||||||
|
"successful_executions": 0,
|
||||||
|
"failed_executions": 0,
|
||||||
|
}
|
||||||
|
|
||||||
|
def is_active(self, max_inactive_minutes: int = 60) -> bool:
|
||||||
|
"""
|
||||||
|
检查处理器是否活跃
|
||||||
|
|
||||||
|
Args:
|
||||||
|
max_inactive_minutes: 最大不活跃分钟数
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
是否活跃
|
||||||
|
"""
|
||||||
|
current_time = time.time()
|
||||||
|
max_inactive_seconds = max_inactive_minutes * 60
|
||||||
|
return (current_time - self.last_activity_time) < max_inactive_seconds
|
||||||
|
|
||||||
|
def get_activity_time(self) -> float:
|
||||||
|
"""
|
||||||
|
获取最后活动时间
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
最后活动时间戳
|
||||||
|
"""
|
||||||
|
return self.last_activity_time
|
||||||
|
|
||||||
|
def __str__(self) -> str:
|
||||||
|
"""字符串表示"""
|
||||||
|
return f"AffinityChatter(stream_id={self.stream_id}, messages={self.stats['messages_processed']})"
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
"""详细字符串表示"""
|
||||||
|
return (
|
||||||
|
f"AffinityChatter(stream_id={self.stream_id}, "
|
||||||
|
f"messages_processed={self.stats['messages_processed']}, "
|
||||||
|
f"plans_created={self.stats['plans_created']}, "
|
||||||
|
f"last_activity={datetime.fromtimestamp(self.last_activity_time)})"
|
||||||
|
)
|
||||||
333
src/plugins/built_in/affinity_flow_chatter/interest_scoring.py
Normal file
333
src/plugins/built_in/affinity_flow_chatter/interest_scoring.py
Normal file
@@ -0,0 +1,333 @@
|
|||||||
|
"""
|
||||||
|
兴趣度评分系统
|
||||||
|
基于多维度评分机制,包括兴趣匹配度、用户关系分、提及度和时间因子
|
||||||
|
现在使用embedding计算智能兴趣匹配
|
||||||
|
"""
|
||||||
|
|
||||||
|
import traceback
|
||||||
|
from typing import Dict, List, Any
|
||||||
|
|
||||||
|
from src.common.data_models.database_data_model import DatabaseMessages
|
||||||
|
from src.common.data_models.info_data_model import InterestScore
|
||||||
|
from src.chat.interest_system import bot_interest_manager
|
||||||
|
from src.common.logger import get_logger
|
||||||
|
from src.config.config import global_config
|
||||||
|
from src.plugins.built_in.affinity_flow_chatter.relationship_tracker import ChatterRelationshipTracker
|
||||||
|
logger = get_logger("chatter_interest_scoring")
|
||||||
|
|
||||||
|
# 定义颜色
|
||||||
|
SOFT_BLUE = "\033[38;5;67m"
|
||||||
|
RESET_COLOR = "\033[0m"
|
||||||
|
|
||||||
|
|
||||||
|
class ChatterInterestScoringSystem:
|
||||||
|
"""兴趣度评分系统"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
# 智能兴趣匹配配置
|
||||||
|
self.use_smart_matching = True
|
||||||
|
|
||||||
|
# 从配置加载评分权重
|
||||||
|
affinity_config = global_config.affinity_flow
|
||||||
|
self.score_weights = {
|
||||||
|
"interest_match": affinity_config.keyword_match_weight, # 兴趣匹配度权重
|
||||||
|
"relationship": affinity_config.relationship_weight, # 关系分权重
|
||||||
|
"mentioned": affinity_config.mention_bot_weight, # 是否提及bot权重
|
||||||
|
}
|
||||||
|
|
||||||
|
# 评分阈值
|
||||||
|
self.reply_threshold = affinity_config.reply_action_interest_threshold # 回复动作兴趣阈值
|
||||||
|
self.mention_threshold = affinity_config.mention_bot_adjustment_threshold # 提及bot后的调整阈值
|
||||||
|
|
||||||
|
# 连续不回复概率提升
|
||||||
|
self.no_reply_count = 0
|
||||||
|
self.max_no_reply_count = affinity_config.max_no_reply_count
|
||||||
|
self.probability_boost_per_no_reply = (
|
||||||
|
affinity_config.no_reply_threshold_adjustment / affinity_config.max_no_reply_count
|
||||||
|
) # 每次不回复增加的概率
|
||||||
|
|
||||||
|
# 用户关系数据
|
||||||
|
self.user_relationships: Dict[str, float] = {} # user_id -> relationship_score
|
||||||
|
|
||||||
|
async def calculate_interest_scores(
|
||||||
|
self, messages: List[DatabaseMessages], bot_nickname: str
|
||||||
|
) -> List[InterestScore]:
|
||||||
|
"""计算消息的兴趣度评分"""
|
||||||
|
user_messages = [msg for msg in messages if str(msg.user_info.user_id) != str(global_config.bot.qq_account)]
|
||||||
|
if not user_messages:
|
||||||
|
return []
|
||||||
|
|
||||||
|
scores = []
|
||||||
|
for _, msg in enumerate(user_messages, 1):
|
||||||
|
score = await self._calculate_single_message_score(msg, bot_nickname)
|
||||||
|
scores.append(score)
|
||||||
|
|
||||||
|
return scores
|
||||||
|
|
||||||
|
async def _calculate_single_message_score(self, message: DatabaseMessages, bot_nickname: str) -> InterestScore:
|
||||||
|
"""计算单条消息的兴趣度评分"""
|
||||||
|
|
||||||
|
keywords = self._extract_keywords_from_database(message)
|
||||||
|
interest_match_score = await self._calculate_interest_match_score(message.processed_plain_text, keywords)
|
||||||
|
relationship_score = self._calculate_relationship_score(message.user_info.user_id)
|
||||||
|
mentioned_score = self._calculate_mentioned_score(message, bot_nickname)
|
||||||
|
|
||||||
|
total_score = (
|
||||||
|
interest_match_score * self.score_weights["interest_match"]
|
||||||
|
+ relationship_score * self.score_weights["relationship"]
|
||||||
|
+ mentioned_score * self.score_weights["mentioned"]
|
||||||
|
)
|
||||||
|
|
||||||
|
details = {
|
||||||
|
"interest_match": f"兴趣匹配: {interest_match_score:.3f}",
|
||||||
|
"relationship": f"关系: {relationship_score:.3f}",
|
||||||
|
"mentioned": f"提及: {mentioned_score:.3f}",
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.debug(
|
||||||
|
f"消息得分详情: {total_score:.3f} (匹配: {interest_match_score:.2f}, 关系: {relationship_score:.2f}, 提及: {mentioned_score:.2f})"
|
||||||
|
)
|
||||||
|
|
||||||
|
return InterestScore(
|
||||||
|
message_id=message.message_id,
|
||||||
|
total_score=total_score,
|
||||||
|
interest_match_score=interest_match_score,
|
||||||
|
relationship_score=relationship_score,
|
||||||
|
mentioned_score=mentioned_score,
|
||||||
|
details=details,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _calculate_interest_match_score(self, content: str, keywords: List[str] = None) -> float:
|
||||||
|
"""计算兴趣匹配度 - 使用智能embedding匹配"""
|
||||||
|
if not content:
|
||||||
|
return 0.0
|
||||||
|
|
||||||
|
# 使用智能匹配(embedding)
|
||||||
|
if self.use_smart_matching and bot_interest_manager.is_initialized:
|
||||||
|
return await self._calculate_smart_interest_match(content, keywords)
|
||||||
|
else:
|
||||||
|
# 智能匹配未初始化,返回默认分数
|
||||||
|
return 0.3
|
||||||
|
|
||||||
|
async def _calculate_smart_interest_match(self, content: str, keywords: List[str] = None) -> float:
|
||||||
|
"""使用embedding计算智能兴趣匹配"""
|
||||||
|
try:
|
||||||
|
# 如果没有传入关键词,则提取
|
||||||
|
if not keywords:
|
||||||
|
keywords = self._extract_keywords_from_content(content)
|
||||||
|
|
||||||
|
# 使用机器人兴趣管理器计算匹配度
|
||||||
|
match_result = await bot_interest_manager.calculate_interest_match(content, keywords)
|
||||||
|
|
||||||
|
if match_result:
|
||||||
|
# 返回匹配分数,考虑置信度和匹配标签数量
|
||||||
|
affinity_config = global_config.affinity_flow
|
||||||
|
match_count_bonus = min(
|
||||||
|
len(match_result.matched_tags) * affinity_config.match_count_bonus, affinity_config.max_match_bonus
|
||||||
|
)
|
||||||
|
final_score = match_result.overall_score * 1.15 * match_result.confidence + match_count_bonus
|
||||||
|
return final_score
|
||||||
|
else:
|
||||||
|
return 0.0
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"智能兴趣匹配计算失败: {e}")
|
||||||
|
return 0.0
|
||||||
|
|
||||||
|
def _extract_keywords_from_database(self, message: DatabaseMessages) -> List[str]:
|
||||||
|
"""从数据库消息中提取关键词"""
|
||||||
|
keywords = []
|
||||||
|
|
||||||
|
# 尝试从 key_words 字段提取(存储的是JSON字符串)
|
||||||
|
if message.key_words:
|
||||||
|
try:
|
||||||
|
import orjson
|
||||||
|
|
||||||
|
keywords = orjson.loads(message.key_words)
|
||||||
|
if not isinstance(keywords, list):
|
||||||
|
keywords = []
|
||||||
|
except (orjson.JSONDecodeError, TypeError):
|
||||||
|
keywords = []
|
||||||
|
|
||||||
|
# 如果没有 keywords,尝试从 key_words_lite 提取
|
||||||
|
if not keywords and message.key_words_lite:
|
||||||
|
try:
|
||||||
|
import orjson
|
||||||
|
|
||||||
|
keywords = orjson.loads(message.key_words_lite)
|
||||||
|
if not isinstance(keywords, list):
|
||||||
|
keywords = []
|
||||||
|
except (orjson.JSONDecodeError, TypeError):
|
||||||
|
keywords = []
|
||||||
|
|
||||||
|
# 如果还是没有,从消息内容中提取(降级方案)
|
||||||
|
if not keywords:
|
||||||
|
keywords = self._extract_keywords_from_content(message.processed_plain_text)
|
||||||
|
|
||||||
|
return keywords[:15] # 返回前15个关键词
|
||||||
|
|
||||||
|
def _extract_keywords_from_content(self, content: str) -> List[str]:
|
||||||
|
"""从内容中提取关键词(降级方案)"""
|
||||||
|
import re
|
||||||
|
|
||||||
|
# 清理文本
|
||||||
|
content = re.sub(r"[^\w\s\u4e00-\u9fff]", " ", content) # 保留中文、英文、数字
|
||||||
|
words = content.split()
|
||||||
|
|
||||||
|
# 过滤和关键词提取
|
||||||
|
keywords = []
|
||||||
|
for word in words:
|
||||||
|
word = word.strip()
|
||||||
|
if (
|
||||||
|
len(word) >= 2 # 至少2个字符
|
||||||
|
and word.isalnum() # 字母数字
|
||||||
|
and not word.isdigit()
|
||||||
|
): # 不是纯数字
|
||||||
|
keywords.append(word.lower())
|
||||||
|
|
||||||
|
# 去重并限制数量
|
||||||
|
unique_keywords = list(set(keywords))
|
||||||
|
return unique_keywords[:10] # 返回前10个唯一关键词
|
||||||
|
|
||||||
|
def _calculate_relationship_score(self, user_id: str) -> float:
|
||||||
|
"""计算关系分 - 从数据库获取关系分"""
|
||||||
|
# 优先使用内存中的关系分
|
||||||
|
if user_id in self.user_relationships:
|
||||||
|
relationship_value = self.user_relationships[user_id]
|
||||||
|
return min(relationship_value, 1.0)
|
||||||
|
|
||||||
|
# 如果内存中没有,尝试从关系追踪器获取
|
||||||
|
if hasattr(self, "relationship_tracker") and self.relationship_tracker:
|
||||||
|
try:
|
||||||
|
relationship_score = self.relationship_tracker.get_user_relationship_score(user_id)
|
||||||
|
# 同时更新内存缓存
|
||||||
|
self.user_relationships[user_id] = relationship_score
|
||||||
|
return relationship_score
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
# 尝试从全局关系追踪器获取
|
||||||
|
try:
|
||||||
|
from .relationship_tracker import ChatterRelationshipTracker
|
||||||
|
|
||||||
|
global_tracker = ChatterRelationshipTracker()
|
||||||
|
if global_tracker:
|
||||||
|
relationship_score = global_tracker.get_user_relationship_score(user_id)
|
||||||
|
# 同时更新内存缓存
|
||||||
|
self.user_relationships[user_id] = relationship_score
|
||||||
|
return relationship_score
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# 默认新用户的基础分
|
||||||
|
return global_config.affinity_flow.base_relationship_score
|
||||||
|
|
||||||
|
def _calculate_mentioned_score(self, msg: DatabaseMessages, bot_nickname: str) -> float:
|
||||||
|
"""计算提及分数"""
|
||||||
|
if not msg.processed_plain_text:
|
||||||
|
return 0.0
|
||||||
|
|
||||||
|
# 检查是否被提及
|
||||||
|
bot_aliases = [bot_nickname] + global_config.bot.alias_names
|
||||||
|
is_mentioned = msg.is_mentioned or any(alias in msg.processed_plain_text for alias in bot_aliases if alias)
|
||||||
|
|
||||||
|
# 如果被提及或是私聊,都视为提及了bot
|
||||||
|
if is_mentioned or not hasattr(msg, "chat_info_group_id"):
|
||||||
|
return global_config.affinity_flow.mention_bot_interest_score
|
||||||
|
|
||||||
|
return 0.0
|
||||||
|
|
||||||
|
def should_reply(self, score: InterestScore, message: "DatabaseMessages") -> bool:
|
||||||
|
"""判断是否应该回复"""
|
||||||
|
base_threshold = self.reply_threshold
|
||||||
|
|
||||||
|
# 如果被提及,降低阈值
|
||||||
|
if score.mentioned_score >= global_config.affinity_flow.mention_bot_adjustment_threshold:
|
||||||
|
base_threshold = self.mention_threshold
|
||||||
|
|
||||||
|
# 计算连续不回复的概率提升
|
||||||
|
probability_boost = min(self.no_reply_count * self.probability_boost_per_no_reply, 0.8)
|
||||||
|
effective_threshold = base_threshold - probability_boost
|
||||||
|
|
||||||
|
# 做出决策
|
||||||
|
should_reply = score.total_score >= effective_threshold
|
||||||
|
decision = "回复" if should_reply else "不回复"
|
||||||
|
logger.info(
|
||||||
|
f"{SOFT_BLUE}决策: {decision} (兴趣度: {score.total_score:.3f} / 阈值: {effective_threshold:.3f}){RESET_COLOR}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return should_reply, score.total_score
|
||||||
|
|
||||||
|
def record_reply_action(self, did_reply: bool):
|
||||||
|
"""记录回复动作"""
|
||||||
|
old_count = self.no_reply_count
|
||||||
|
if did_reply:
|
||||||
|
self.no_reply_count = max(0, self.no_reply_count - global_config.affinity_flow.reply_cooldown_reduction)
|
||||||
|
action = "回复"
|
||||||
|
else:
|
||||||
|
self.no_reply_count += 1
|
||||||
|
action = "不回复"
|
||||||
|
|
||||||
|
# 限制最大计数
|
||||||
|
self.no_reply_count = min(self.no_reply_count, self.max_no_reply_count)
|
||||||
|
logger.info(f"动作: {action}, 连续不回复次数: {old_count} -> {self.no_reply_count}")
|
||||||
|
|
||||||
|
def update_user_relationship(self, user_id: str, relationship_change: float):
|
||||||
|
"""更新用户关系"""
|
||||||
|
old_score = self.user_relationships.get(
|
||||||
|
user_id, global_config.affinity_flow.base_relationship_score
|
||||||
|
) # 默认新用户分数
|
||||||
|
new_score = max(0.0, min(1.0, old_score + relationship_change))
|
||||||
|
|
||||||
|
self.user_relationships[user_id] = new_score
|
||||||
|
|
||||||
|
logger.info(f"用户关系: {user_id} | {old_score:.3f} → {new_score:.3f}")
|
||||||
|
|
||||||
|
def get_user_relationship(self, user_id: str) -> float:
|
||||||
|
"""获取用户关系分"""
|
||||||
|
return self.user_relationships.get(user_id, 0.3)
|
||||||
|
|
||||||
|
def get_scoring_stats(self) -> Dict:
|
||||||
|
"""获取评分系统统计"""
|
||||||
|
return {
|
||||||
|
"no_reply_count": self.no_reply_count,
|
||||||
|
"max_no_reply_count": self.max_no_reply_count,
|
||||||
|
"reply_threshold": self.reply_threshold,
|
||||||
|
"mention_threshold": self.mention_threshold,
|
||||||
|
"user_relationships": len(self.user_relationships),
|
||||||
|
}
|
||||||
|
|
||||||
|
def reset_stats(self):
|
||||||
|
"""重置统计信息"""
|
||||||
|
self.no_reply_count = 0
|
||||||
|
logger.info("重置兴趣度评分系统统计")
|
||||||
|
|
||||||
|
async def initialize_smart_interests(self, personality_description: str, personality_id: str = "default"):
|
||||||
|
"""初始化智能兴趣系统"""
|
||||||
|
try:
|
||||||
|
logger.info("开始初始化智能兴趣系统...")
|
||||||
|
logger.info(f"人设ID: {personality_id}, 描述长度: {len(personality_description)}")
|
||||||
|
|
||||||
|
await bot_interest_manager.initialize(personality_description, personality_id)
|
||||||
|
logger.info("智能兴趣系统初始化完成。")
|
||||||
|
|
||||||
|
# 显示初始化后的统计信息
|
||||||
|
bot_interest_manager.get_interest_stats()
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"初始化智能兴趣系统失败: {e}")
|
||||||
|
traceback.print_exc()
|
||||||
|
|
||||||
|
def get_matching_config(self) -> Dict[str, Any]:
|
||||||
|
"""获取匹配配置信息"""
|
||||||
|
return {
|
||||||
|
"use_smart_matching": self.use_smart_matching,
|
||||||
|
"smart_system_initialized": bot_interest_manager.is_initialized,
|
||||||
|
"smart_system_stats": bot_interest_manager.get_interest_stats()
|
||||||
|
if bot_interest_manager.is_initialized
|
||||||
|
else None,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# 创建全局兴趣评分系统实例
|
||||||
|
chatter_interest_scoring_system = ChatterInterestScoringSystem()
|
||||||
368
src/plugins/built_in/affinity_flow_chatter/plan_executor.py
Normal file
368
src/plugins/built_in/affinity_flow_chatter/plan_executor.py
Normal file
@@ -0,0 +1,368 @@
|
|||||||
|
"""
|
||||||
|
PlanExecutor: 接收 Plan 对象并执行其中的所有动作。
|
||||||
|
集成用户关系追踪机制,自动记录交互并更新关系。
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import time
|
||||||
|
from typing import Dict, List
|
||||||
|
|
||||||
|
from src.config.config import global_config
|
||||||
|
from src.chat.planner_actions.action_manager import ChatterActionManager
|
||||||
|
from src.common.data_models.info_data_model import Plan, ActionPlannerInfo
|
||||||
|
from src.common.logger import get_logger
|
||||||
|
|
||||||
|
logger = get_logger("plan_executor")
|
||||||
|
|
||||||
|
|
||||||
|
class ChatterPlanExecutor:
|
||||||
|
"""
|
||||||
|
增强版PlanExecutor,集成用户关系追踪机制。
|
||||||
|
|
||||||
|
功能:
|
||||||
|
1. 执行Plan中的所有动作
|
||||||
|
2. 自动记录用户交互并添加到关系追踪
|
||||||
|
3. 分类执行回复动作和其他动作
|
||||||
|
4. 提供完整的执行统计和监控
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, action_manager: ChatterActionManager):
|
||||||
|
"""
|
||||||
|
初始化增强版PlanExecutor。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
action_manager (ChatterActionManager): 用于实际执行各种动作的管理器实例。
|
||||||
|
"""
|
||||||
|
self.action_manager = action_manager
|
||||||
|
|
||||||
|
# 执行统计
|
||||||
|
self.execution_stats = {
|
||||||
|
"total_executed": 0,
|
||||||
|
"successful_executions": 0,
|
||||||
|
"failed_executions": 0,
|
||||||
|
"reply_executions": 0,
|
||||||
|
"other_action_executions": 0,
|
||||||
|
"execution_times": [],
|
||||||
|
}
|
||||||
|
|
||||||
|
# 用户关系追踪引用
|
||||||
|
self.relationship_tracker = None
|
||||||
|
|
||||||
|
def set_relationship_tracker(self, relationship_tracker):
|
||||||
|
"""设置关系追踪器"""
|
||||||
|
self.relationship_tracker = relationship_tracker
|
||||||
|
|
||||||
|
async def execute(self, plan: Plan) -> Dict[str, any]:
|
||||||
|
"""
|
||||||
|
遍历并执行Plan对象中`decided_actions`列表里的所有动作。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
plan (Plan): 包含待执行动作列表的Plan对象。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict[str, any]: 执行结果统计信息
|
||||||
|
"""
|
||||||
|
if not plan.decided_actions:
|
||||||
|
logger.info("没有需要执行的动作。")
|
||||||
|
return {"executed_count": 0, "results": []}
|
||||||
|
|
||||||
|
# 像hfc一样,提前打印将要执行的动作
|
||||||
|
action_types = [action.action_type for action in plan.decided_actions]
|
||||||
|
logger.info(f"选择动作: {', '.join(action_types) if action_types else '无'}")
|
||||||
|
|
||||||
|
execution_results = []
|
||||||
|
reply_actions = []
|
||||||
|
other_actions = []
|
||||||
|
|
||||||
|
# 分类动作:回复动作和其他动作
|
||||||
|
for action_info in plan.decided_actions:
|
||||||
|
if action_info.action_type in ["reply", "proactive_reply"]:
|
||||||
|
reply_actions.append(action_info)
|
||||||
|
else:
|
||||||
|
other_actions.append(action_info)
|
||||||
|
|
||||||
|
# 执行回复动作(优先执行)
|
||||||
|
if reply_actions:
|
||||||
|
reply_result = await self._execute_reply_actions(reply_actions, plan)
|
||||||
|
execution_results.extend(reply_result["results"])
|
||||||
|
self.execution_stats["reply_executions"] += len(reply_actions)
|
||||||
|
|
||||||
|
# 将其他动作放入后台任务执行,避免阻塞主流程
|
||||||
|
if other_actions:
|
||||||
|
asyncio.create_task(self._execute_other_actions(other_actions, plan))
|
||||||
|
logger.info(f"已将 {len(other_actions)} 个其他动作放入后台任务执行。")
|
||||||
|
# 注意:后台任务的结果不会立即计入本次返回的统计数据
|
||||||
|
|
||||||
|
# 更新总体统计
|
||||||
|
self.execution_stats["total_executed"] += len(plan.decided_actions)
|
||||||
|
successful_count = sum(1 for r in execution_results if r["success"])
|
||||||
|
self.execution_stats["successful_executions"] += successful_count
|
||||||
|
self.execution_stats["failed_executions"] += len(execution_results) - successful_count
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"规划执行完成: 总数={len(plan.decided_actions)}, 成功={successful_count}, 失败={len(execution_results) - successful_count}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"executed_count": len(plan.decided_actions),
|
||||||
|
"successful_count": successful_count,
|
||||||
|
"failed_count": len(execution_results) - successful_count,
|
||||||
|
"results": execution_results,
|
||||||
|
}
|
||||||
|
|
||||||
|
async def _execute_reply_actions(self, reply_actions: List[ActionPlannerInfo], plan: Plan) -> Dict[str, any]:
|
||||||
|
"""执行回复动作"""
|
||||||
|
results = []
|
||||||
|
|
||||||
|
for action_info in reply_actions:
|
||||||
|
result = await self._execute_single_reply_action(action_info, plan)
|
||||||
|
results.append(result)
|
||||||
|
|
||||||
|
return {"results": results}
|
||||||
|
|
||||||
|
async def _execute_single_reply_action(self, action_info: ActionPlannerInfo, plan: Plan) -> Dict[str, any]:
|
||||||
|
"""执行单个回复动作"""
|
||||||
|
start_time = time.time()
|
||||||
|
success = False
|
||||||
|
error_message = ""
|
||||||
|
reply_content = ""
|
||||||
|
|
||||||
|
try:
|
||||||
|
logger.info(f"执行回复动作: {action_info.action_type} (原因: {action_info.reasoning})")
|
||||||
|
|
||||||
|
# 获取用户ID - 兼容对象和字典
|
||||||
|
if hasattr(action_info.action_message, "user_info"):
|
||||||
|
user_id = action_info.action_message.user_info.user_id
|
||||||
|
else:
|
||||||
|
user_id = action_info.action_message.get("user_info", {}).get("user_id")
|
||||||
|
|
||||||
|
if user_id == str(global_config.bot.qq_account):
|
||||||
|
logger.warning("尝试回复自己,跳过此动作以防止死循环。")
|
||||||
|
return {
|
||||||
|
"action_type": action_info.action_type,
|
||||||
|
"success": False,
|
||||||
|
"error_message": "尝试回复自己,跳过此动作以防止死循环。",
|
||||||
|
"execution_time": 0,
|
||||||
|
"reasoning": action_info.reasoning,
|
||||||
|
"reply_content": "",
|
||||||
|
}
|
||||||
|
# 构建回复动作参数
|
||||||
|
action_params = {
|
||||||
|
"chat_id": plan.chat_id,
|
||||||
|
"target_message": action_info.action_message,
|
||||||
|
"reasoning": action_info.reasoning,
|
||||||
|
"action_data": action_info.action_data or {},
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.debug(f"📬 [PlanExecutor] 准备调用 ActionManager,target_message: {action_info.action_message}")
|
||||||
|
|
||||||
|
# 通过动作管理器执行回复
|
||||||
|
reply_content = await self.action_manager.execute_action(
|
||||||
|
action_name=action_info.action_type, **action_params
|
||||||
|
)
|
||||||
|
|
||||||
|
success = True
|
||||||
|
logger.info(f"回复动作 '{action_info.action_type}' 执行成功。")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
error_message = str(e)
|
||||||
|
logger.error(f"执行回复动作失败: {action_info.action_type}, 错误: {error_message}")
|
||||||
|
|
||||||
|
# 记录用户关系追踪
|
||||||
|
if success and action_info.action_message:
|
||||||
|
await self._track_user_interaction(action_info, plan, reply_content)
|
||||||
|
|
||||||
|
execution_time = time.time() - start_time
|
||||||
|
self.execution_stats["execution_times"].append(execution_time)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"action_type": action_info.action_type,
|
||||||
|
"success": success,
|
||||||
|
"error_message": error_message,
|
||||||
|
"execution_time": execution_time,
|
||||||
|
"reasoning": action_info.reasoning,
|
||||||
|
"reply_content": reply_content[:200] + "..." if len(reply_content) > 200 else reply_content,
|
||||||
|
}
|
||||||
|
|
||||||
|
async def _execute_other_actions(self, other_actions: List[ActionPlannerInfo], plan: Plan) -> Dict[str, any]:
|
||||||
|
"""执行其他动作"""
|
||||||
|
results = []
|
||||||
|
|
||||||
|
# 并行执行其他动作
|
||||||
|
tasks = []
|
||||||
|
for action_info in other_actions:
|
||||||
|
task = self._execute_single_other_action(action_info, plan)
|
||||||
|
tasks.append(task)
|
||||||
|
|
||||||
|
if tasks:
|
||||||
|
executed_results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||||
|
for i, result in enumerate(executed_results):
|
||||||
|
if isinstance(result, Exception):
|
||||||
|
logger.error(f"执行动作 {other_actions[i].action_type} 时发生异常: {result}")
|
||||||
|
results.append(
|
||||||
|
{
|
||||||
|
"action_type": other_actions[i].action_type,
|
||||||
|
"success": False,
|
||||||
|
"error_message": str(result),
|
||||||
|
"execution_time": 0,
|
||||||
|
"reasoning": other_actions[i].reasoning,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
results.append(result)
|
||||||
|
|
||||||
|
return {"results": results}
|
||||||
|
|
||||||
|
async def _execute_single_other_action(self, action_info: ActionPlannerInfo, plan: Plan) -> Dict[str, any]:
|
||||||
|
"""执行单个其他动作"""
|
||||||
|
start_time = time.time()
|
||||||
|
success = False
|
||||||
|
error_message = ""
|
||||||
|
|
||||||
|
try:
|
||||||
|
logger.info(f"执行其他动作: {action_info.action_type} (原因: {action_info.reasoning})")
|
||||||
|
|
||||||
|
action_data = action_info.action_data or {}
|
||||||
|
|
||||||
|
# 针对 poke_user 动作,特殊处理
|
||||||
|
if action_info.action_type == "poke_user":
|
||||||
|
target_message = action_info.action_message
|
||||||
|
if target_message:
|
||||||
|
# 优先直接获取 user_id,这才是最可靠的信息
|
||||||
|
user_id = target_message.get("user_id")
|
||||||
|
if user_id:
|
||||||
|
action_data["user_id"] = user_id
|
||||||
|
logger.info(f"检测到戳一戳动作,目标用户ID: {user_id}")
|
||||||
|
else:
|
||||||
|
# 如果没有 user_id,再尝试用 user_nickname 作为备用方案
|
||||||
|
user_name = target_message.get("user_nickname")
|
||||||
|
if user_name:
|
||||||
|
action_data["user_name"] = user_name
|
||||||
|
logger.info(f"检测到戳一戳动作,目标用户: {user_name}")
|
||||||
|
else:
|
||||||
|
logger.warning("无法从戳一戳消息中获取用户ID或昵称。")
|
||||||
|
|
||||||
|
# 传递原始消息ID以支持引用
|
||||||
|
action_data["target_message_id"] = target_message.get("message_id")
|
||||||
|
|
||||||
|
# 构建动作参数
|
||||||
|
action_params = {
|
||||||
|
"chat_id": plan.chat_id,
|
||||||
|
"target_message": action_info.action_message,
|
||||||
|
"reasoning": action_info.reasoning,
|
||||||
|
"action_data": action_data,
|
||||||
|
}
|
||||||
|
|
||||||
|
# 通过动作管理器执行动作
|
||||||
|
await self.action_manager.execute_action(action_name=action_info.action_type, **action_params)
|
||||||
|
|
||||||
|
success = True
|
||||||
|
logger.info(f"其他动作 '{action_info.action_type}' 执行成功。")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
error_message = str(e)
|
||||||
|
logger.error(f"执行其他动作失败: {action_info.action_type}, 错误: {error_message}")
|
||||||
|
|
||||||
|
execution_time = time.time() - start_time
|
||||||
|
self.execution_stats["execution_times"].append(execution_time)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"action_type": action_info.action_type,
|
||||||
|
"success": success,
|
||||||
|
"error_message": error_message,
|
||||||
|
"execution_time": execution_time,
|
||||||
|
"reasoning": action_info.reasoning,
|
||||||
|
}
|
||||||
|
|
||||||
|
async def _track_user_interaction(self, action_info: ActionPlannerInfo, plan: Plan, reply_content: str):
|
||||||
|
"""追踪用户交互 - 集成回复后关系追踪"""
|
||||||
|
try:
|
||||||
|
if not action_info.action_message:
|
||||||
|
return
|
||||||
|
|
||||||
|
# 获取用户信息 - 处理对象和字典两种情况
|
||||||
|
if hasattr(action_info.action_message, "user_info"):
|
||||||
|
# 对象情况
|
||||||
|
user_info = action_info.action_message.user_info
|
||||||
|
user_id = user_info.user_id
|
||||||
|
user_name = user_info.user_nickname or user_id
|
||||||
|
user_message = action_info.action_message.content
|
||||||
|
else:
|
||||||
|
# 字典情况
|
||||||
|
user_info = action_info.action_message.get("user_info", {})
|
||||||
|
user_id = user_info.get("user_id")
|
||||||
|
user_name = user_info.get("user_nickname") or user_id
|
||||||
|
user_message = action_info.action_message.get("content", "")
|
||||||
|
|
||||||
|
if not user_id:
|
||||||
|
logger.debug("跳过追踪:缺少用户ID")
|
||||||
|
return
|
||||||
|
|
||||||
|
# 如果有设置关系追踪器,执行回复后关系追踪
|
||||||
|
if self.relationship_tracker:
|
||||||
|
# 记录基础交互信息(保持向后兼容)
|
||||||
|
self.relationship_tracker.add_interaction(
|
||||||
|
user_id=user_id,
|
||||||
|
user_name=user_name,
|
||||||
|
user_message=user_message,
|
||||||
|
bot_reply=reply_content,
|
||||||
|
reply_timestamp=time.time(),
|
||||||
|
)
|
||||||
|
|
||||||
|
# 执行新的回复后关系追踪
|
||||||
|
await self.relationship_tracker.track_reply_relationship(
|
||||||
|
user_id=user_id, user_name=user_name, bot_reply_content=reply_content, reply_timestamp=time.time()
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.debug(f"已执行用户交互追踪: {user_id}")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"追踪用户交互时出错: {e}")
|
||||||
|
logger.debug(f"action_message类型: {type(action_info.action_message)}")
|
||||||
|
logger.debug(f"action_message内容: {action_info.action_message}")
|
||||||
|
|
||||||
|
def get_execution_stats(self) -> Dict[str, any]:
|
||||||
|
"""获取执行统计信息"""
|
||||||
|
stats = self.execution_stats.copy()
|
||||||
|
|
||||||
|
# 计算平均执行时间
|
||||||
|
if stats["execution_times"]:
|
||||||
|
avg_time = sum(stats["execution_times"]) / len(stats["execution_times"])
|
||||||
|
stats["average_execution_time"] = avg_time
|
||||||
|
stats["max_execution_time"] = max(stats["execution_times"])
|
||||||
|
stats["min_execution_time"] = min(stats["execution_times"])
|
||||||
|
else:
|
||||||
|
stats["average_execution_time"] = 0
|
||||||
|
stats["max_execution_time"] = 0
|
||||||
|
stats["min_execution_time"] = 0
|
||||||
|
|
||||||
|
# 移除执行时间列表以避免返回过大数据
|
||||||
|
stats.pop("execution_times", None)
|
||||||
|
|
||||||
|
return stats
|
||||||
|
|
||||||
|
def reset_stats(self):
|
||||||
|
"""重置统计信息"""
|
||||||
|
self.execution_stats = {
|
||||||
|
"total_executed": 0,
|
||||||
|
"successful_executions": 0,
|
||||||
|
"failed_executions": 0,
|
||||||
|
"reply_executions": 0,
|
||||||
|
"other_action_executions": 0,
|
||||||
|
"execution_times": [],
|
||||||
|
}
|
||||||
|
|
||||||
|
def get_recent_performance(self, limit: int = 10) -> List[Dict[str, any]]:
|
||||||
|
"""获取最近的执行性能"""
|
||||||
|
recent_times = self.execution_stats["execution_times"][-limit:]
|
||||||
|
if not recent_times:
|
||||||
|
return []
|
||||||
|
|
||||||
|
return [
|
||||||
|
{
|
||||||
|
"execution_index": i + 1,
|
||||||
|
"execution_time": time_val,
|
||||||
|
"timestamp": time.time() - (len(recent_times) - i) * 60, # 估算时间戳
|
||||||
|
}
|
||||||
|
for i, time_val in enumerate(recent_times)
|
||||||
|
]
|
||||||
678
src/plugins/built_in/affinity_flow_chatter/plan_filter.py
Normal file
678
src/plugins/built_in/affinity_flow_chatter/plan_filter.py
Normal file
@@ -0,0 +1,678 @@
|
|||||||
|
"""
|
||||||
|
PlanFilter: 接收 Plan 对象,根据不同模式的逻辑进行筛选,决定最终要执行的动作。
|
||||||
|
"""
|
||||||
|
|
||||||
|
import orjson
|
||||||
|
import time
|
||||||
|
import traceback
|
||||||
|
import re
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
|
from json_repair import repair_json
|
||||||
|
|
||||||
|
from src.chat.memory_system.Hippocampus import hippocampus_manager
|
||||||
|
from src.chat.utils.chat_message_builder import (
|
||||||
|
build_readable_actions,
|
||||||
|
build_readable_messages_with_id,
|
||||||
|
get_actions_by_timestamp_with_chat,
|
||||||
|
)
|
||||||
|
from src.chat.utils.prompt import global_prompt_manager
|
||||||
|
from src.common.data_models.info_data_model import ActionPlannerInfo, Plan
|
||||||
|
from src.common.logger import get_logger
|
||||||
|
from src.config.config import global_config, model_config
|
||||||
|
from src.llm_models.utils_model import LLMRequest
|
||||||
|
from src.mood.mood_manager import mood_manager
|
||||||
|
from src.plugin_system.base.component_types import ActionInfo, ChatMode, ChatType
|
||||||
|
from src.schedule.schedule_manager import schedule_manager
|
||||||
|
|
||||||
|
logger = get_logger("plan_filter")
|
||||||
|
|
||||||
|
SAKURA_PINK = "\033[38;5;175m"
|
||||||
|
SKY_BLUE = "\033[38;5;117m"
|
||||||
|
RESET_COLOR = "\033[0m"
|
||||||
|
|
||||||
|
|
||||||
|
class ChatterPlanFilter:
|
||||||
|
"""
|
||||||
|
根据 Plan 中的模式和信息,筛选并决定最终的动作。
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, chat_id: str, available_actions: List[str]):
|
||||||
|
"""
|
||||||
|
初始化动作计划筛选器。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
chat_id (str): 当前聊天的唯一标识符。
|
||||||
|
available_actions (List[str]): 当前可用的动作列表。
|
||||||
|
"""
|
||||||
|
self.chat_id = chat_id
|
||||||
|
self.available_actions = available_actions
|
||||||
|
self.planner_llm = LLMRequest(model_set=model_config.model_task_config.planner, request_type="planner")
|
||||||
|
self.last_obs_time_mark = 0.0
|
||||||
|
|
||||||
|
async def filter(self, reply_not_available: bool, plan: Plan) -> Plan:
|
||||||
|
"""
|
||||||
|
执行筛选逻辑,并填充 Plan 对象的 decided_actions 字段。
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
prompt, used_message_id_list = await self._build_prompt(plan)
|
||||||
|
plan.llm_prompt = prompt
|
||||||
|
|
||||||
|
llm_content, _ = await self.planner_llm.generate_response_async(prompt=prompt)
|
||||||
|
|
||||||
|
if llm_content:
|
||||||
|
try:
|
||||||
|
parsed_json = orjson.loads(repair_json(llm_content))
|
||||||
|
except orjson.JSONDecodeError:
|
||||||
|
parsed_json = {
|
||||||
|
"thinking": "",
|
||||||
|
"actions": {"action_type": "no_action", "reason": "返回内容无法解析为JSON"},
|
||||||
|
}
|
||||||
|
|
||||||
|
if "reply" in plan.available_actions and reply_not_available:
|
||||||
|
# 如果reply动作不可用,但llm返回的仍然有reply,则改为no_reply
|
||||||
|
if (
|
||||||
|
isinstance(parsed_json, dict)
|
||||||
|
and parsed_json.get("actions", {}).get("action_type", "") == "reply"
|
||||||
|
):
|
||||||
|
parsed_json["actions"]["action_type"] = "no_reply"
|
||||||
|
elif isinstance(parsed_json, list):
|
||||||
|
for item in parsed_json:
|
||||||
|
if isinstance(item, dict) and item.get("actions", {}).get("action_type", "") == "reply":
|
||||||
|
item["actions"]["action_type"] = "no_reply"
|
||||||
|
item["actions"]["reason"] += " (但由于兴趣度不足,reply动作不可用,已改为no_reply)"
|
||||||
|
|
||||||
|
if isinstance(parsed_json, dict):
|
||||||
|
parsed_json = [parsed_json]
|
||||||
|
|
||||||
|
if isinstance(parsed_json, list):
|
||||||
|
final_actions = []
|
||||||
|
reply_action_added = False
|
||||||
|
# 定义回复类动作的集合,方便扩展
|
||||||
|
reply_action_types = {"reply", "proactive_reply"}
|
||||||
|
|
||||||
|
for item in parsed_json:
|
||||||
|
if not isinstance(item, dict):
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 预解析 action_type 来进行判断
|
||||||
|
thinking = item.get("thinking", "未提供思考过程")
|
||||||
|
actions_obj = item.get("actions", {})
|
||||||
|
|
||||||
|
# 处理actions字段可能是字典或列表的情况
|
||||||
|
if isinstance(actions_obj, dict):
|
||||||
|
action_type = actions_obj.get("action_type", "no_action")
|
||||||
|
elif isinstance(actions_obj, list) and actions_obj:
|
||||||
|
# 如果是列表,取第一个元素的action_type
|
||||||
|
first_action = actions_obj[0]
|
||||||
|
if isinstance(first_action, dict):
|
||||||
|
action_type = first_action.get("action_type", "no_action")
|
||||||
|
else:
|
||||||
|
action_type = "no_action"
|
||||||
|
else:
|
||||||
|
action_type = "no_action"
|
||||||
|
|
||||||
|
if action_type in reply_action_types:
|
||||||
|
if not reply_action_added:
|
||||||
|
final_actions.extend(
|
||||||
|
await self._parse_single_action(item, used_message_id_list, plan)
|
||||||
|
)
|
||||||
|
reply_action_added = True
|
||||||
|
else:
|
||||||
|
# 非回复类动作直接添加
|
||||||
|
final_actions.extend(await self._parse_single_action(item, used_message_id_list, plan))
|
||||||
|
|
||||||
|
if thinking and thinking != "未提供思考过程":
|
||||||
|
logger.info(f"\n{SAKURA_PINK}思考: {thinking}{RESET_COLOR}\n")
|
||||||
|
plan.decided_actions = self._filter_no_actions(final_actions)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"筛选 Plan 时出错: {e}\n{traceback.format_exc()}")
|
||||||
|
plan.decided_actions = [ActionPlannerInfo(action_type="no_action", reasoning=f"筛选时出错: {e}")]
|
||||||
|
|
||||||
|
# 在返回最终计划前,打印将要执行的动作
|
||||||
|
action_types = [action.action_type for action in plan.decided_actions]
|
||||||
|
logger.info(f"选择动作: [{SKY_BLUE}{', '.join(action_types) if action_types else '无'}{RESET_COLOR}]")
|
||||||
|
|
||||||
|
return plan
|
||||||
|
|
||||||
|
async def _build_prompt(self, plan: Plan) -> tuple[str, list]:
|
||||||
|
"""
|
||||||
|
根据 Plan 对象构建提示词。
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
time_block = f"当前时间:{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}"
|
||||||
|
bot_name = global_config.bot.nickname
|
||||||
|
bot_nickname = (
|
||||||
|
f",也有人叫你{','.join(global_config.bot.alias_names)}" if global_config.bot.alias_names else ""
|
||||||
|
)
|
||||||
|
bot_core_personality = global_config.personality.personality_core
|
||||||
|
identity_block = f"你的名字是{bot_name}{bot_nickname},你{bot_core_personality}:"
|
||||||
|
|
||||||
|
schedule_block = ""
|
||||||
|
# 优先检查是否被吵醒
|
||||||
|
from src.chat.message_manager.message_manager import message_manager
|
||||||
|
angry_prompt_addition = ""
|
||||||
|
wakeup_mgr = message_manager.wakeup_manager
|
||||||
|
|
||||||
|
# 双重检查确保愤怒状态不会丢失
|
||||||
|
# 检查1: 直接从 wakeup_manager 获取
|
||||||
|
if wakeup_mgr.is_in_angry_state():
|
||||||
|
angry_prompt_addition = wakeup_mgr.get_angry_prompt_addition()
|
||||||
|
|
||||||
|
# 检查2: 如果上面没获取到,再从 mood_manager 确认
|
||||||
|
if not angry_prompt_addition:
|
||||||
|
chat_mood_for_check = mood_manager.get_mood_by_chat_id(plan.chat_id)
|
||||||
|
if chat_mood_for_check.is_angry_from_wakeup:
|
||||||
|
angry_prompt_addition = global_config.sleep_system.angry_prompt
|
||||||
|
|
||||||
|
if angry_prompt_addition:
|
||||||
|
schedule_block = angry_prompt_addition
|
||||||
|
elif global_config.planning_system.schedule_enable:
|
||||||
|
if current_activity := schedule_manager.get_current_activity():
|
||||||
|
schedule_block = f"你当前正在:{current_activity},但注意它与群聊的聊天无关。"
|
||||||
|
|
||||||
|
mood_block = ""
|
||||||
|
# 如果被吵醒,则心情也是愤怒的,不需要另外的情绪模块
|
||||||
|
if not angry_prompt_addition and global_config.mood.enable_mood:
|
||||||
|
chat_mood = mood_manager.get_mood_by_chat_id(plan.chat_id)
|
||||||
|
mood_block = f"你现在的心情是:{chat_mood.mood_state}"
|
||||||
|
|
||||||
|
if plan.mode == ChatMode.PROACTIVE:
|
||||||
|
long_term_memory_block = await self._get_long_term_memory_context()
|
||||||
|
|
||||||
|
chat_content_block, message_id_list = build_readable_messages_with_id(
|
||||||
|
messages=[msg.flatten() for msg in plan.chat_history],
|
||||||
|
timestamp_mode="normal",
|
||||||
|
truncate=False,
|
||||||
|
show_actions=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
prompt_template = await global_prompt_manager.get_prompt_async("proactive_planner_prompt")
|
||||||
|
actions_before_now = get_actions_by_timestamp_with_chat(
|
||||||
|
chat_id=plan.chat_id,
|
||||||
|
timestamp_start=time.time() - 3600,
|
||||||
|
timestamp_end=time.time(),
|
||||||
|
limit=5,
|
||||||
|
)
|
||||||
|
actions_before_now_block = build_readable_actions(actions=actions_before_now)
|
||||||
|
actions_before_now_block = f"你刚刚选择并执行过的action是:\n{actions_before_now_block}"
|
||||||
|
|
||||||
|
prompt = prompt_template.format(
|
||||||
|
time_block=time_block,
|
||||||
|
identity_block=identity_block,
|
||||||
|
schedule_block=schedule_block,
|
||||||
|
mood_block=mood_block,
|
||||||
|
long_term_memory_block=long_term_memory_block,
|
||||||
|
chat_content_block=chat_content_block or "最近没有聊天内容。",
|
||||||
|
actions_before_now_block=actions_before_now_block,
|
||||||
|
)
|
||||||
|
return prompt, message_id_list
|
||||||
|
|
||||||
|
# 构建已读/未读历史消息
|
||||||
|
read_history_block, unread_history_block, message_id_list = await self._build_read_unread_history_blocks(
|
||||||
|
plan
|
||||||
|
)
|
||||||
|
|
||||||
|
# 为了兼容性,保留原有的chat_content_block
|
||||||
|
chat_content_block, _ = build_readable_messages_with_id(
|
||||||
|
messages=[msg.flatten() for msg in plan.chat_history],
|
||||||
|
timestamp_mode="normal",
|
||||||
|
read_mark=self.last_obs_time_mark,
|
||||||
|
truncate=True,
|
||||||
|
show_actions=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
actions_before_now = get_actions_by_timestamp_with_chat(
|
||||||
|
chat_id=plan.chat_id,
|
||||||
|
timestamp_start=time.time() - 3600,
|
||||||
|
timestamp_end=time.time(),
|
||||||
|
limit=5,
|
||||||
|
)
|
||||||
|
|
||||||
|
actions_before_now_block = build_readable_actions(actions=actions_before_now)
|
||||||
|
actions_before_now_block = f"你刚刚选择并执行过的action是:\n{actions_before_now_block}"
|
||||||
|
|
||||||
|
self.last_obs_time_mark = time.time()
|
||||||
|
|
||||||
|
mentioned_bonus = ""
|
||||||
|
if global_config.chat.mentioned_bot_inevitable_reply:
|
||||||
|
mentioned_bonus = "\n- 有人提到你"
|
||||||
|
if global_config.chat.at_bot_inevitable_reply:
|
||||||
|
mentioned_bonus = "\n- 有人提到你,或者at你"
|
||||||
|
|
||||||
|
if plan.mode == ChatMode.FOCUS:
|
||||||
|
no_action_block = """
|
||||||
|
动作:no_action
|
||||||
|
动作描述:不选择任何动作
|
||||||
|
{{
|
||||||
|
"action": "no_action",
|
||||||
|
"reason":"不动作的原因"
|
||||||
|
}}
|
||||||
|
|
||||||
|
动作:no_reply
|
||||||
|
动作描述:不进行回复,等待合适的回复时机
|
||||||
|
- 当你刚刚发送了消息,没有人回复时,选择no_reply
|
||||||
|
- 当你一次发送了太多消息,为了避免打扰聊天节奏,选择no_reply
|
||||||
|
{{
|
||||||
|
"action": "no_reply",
|
||||||
|
"reason":"不回复的原因"
|
||||||
|
}}
|
||||||
|
"""
|
||||||
|
else: # normal Mode
|
||||||
|
no_action_block = """重要说明:
|
||||||
|
- 'reply' 表示只进行普通聊天回复,不执行任何额外动作
|
||||||
|
- 其他action表示在普通回复的基础上,执行相应的额外动作
|
||||||
|
{{
|
||||||
|
"action": "reply",
|
||||||
|
"target_message_id":"触发action的消息id",
|
||||||
|
"reason":"回复的原因"
|
||||||
|
}}"""
|
||||||
|
|
||||||
|
is_group_chat = plan.chat_type == ChatType.GROUP
|
||||||
|
chat_context_description = "你现在正在一个群聊中"
|
||||||
|
if not is_group_chat and plan.target_info:
|
||||||
|
chat_target_name = plan.target_info.get("person_name") or plan.target_info.get("user_nickname") or "对方"
|
||||||
|
chat_context_description = f"你正在和 {chat_target_name} 私聊"
|
||||||
|
|
||||||
|
action_options_block = await self._build_action_options(plan.available_actions)
|
||||||
|
|
||||||
|
moderation_prompt_block = "请不要输出违法违规内容,不要输出色情,暴力,政治相关内容,如有敏感内容,请规避。"
|
||||||
|
|
||||||
|
custom_prompt_block = ""
|
||||||
|
if global_config.custom_prompt.planner_custom_prompt_content:
|
||||||
|
custom_prompt_block = global_config.custom_prompt.planner_custom_prompt_content
|
||||||
|
|
||||||
|
users_in_chat_str = "" # TODO: Re-implement user list fetching if needed
|
||||||
|
|
||||||
|
planner_prompt_template = await global_prompt_manager.get_prompt_async("planner_prompt")
|
||||||
|
prompt = planner_prompt_template.format(
|
||||||
|
schedule_block=schedule_block,
|
||||||
|
mood_block=mood_block,
|
||||||
|
time_block=time_block,
|
||||||
|
chat_context_description=chat_context_description,
|
||||||
|
read_history_block=read_history_block,
|
||||||
|
unread_history_block=unread_history_block,
|
||||||
|
actions_before_now_block=actions_before_now_block,
|
||||||
|
mentioned_bonus=mentioned_bonus,
|
||||||
|
no_action_block=no_action_block,
|
||||||
|
action_options_text=action_options_block,
|
||||||
|
moderation_prompt=moderation_prompt_block,
|
||||||
|
identity_block=identity_block,
|
||||||
|
custom_prompt_block=custom_prompt_block,
|
||||||
|
bot_name=bot_name,
|
||||||
|
users_in_chat=users_in_chat_str,
|
||||||
|
)
|
||||||
|
return prompt, message_id_list
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"构建 Planner 提示词时出错: {e}")
|
||||||
|
logger.error(traceback.format_exc())
|
||||||
|
return "构建 Planner Prompt 时出错", []
|
||||||
|
|
||||||
|
async def _build_read_unread_history_blocks(self, plan: Plan) -> tuple[str, str, list]:
|
||||||
|
"""构建已读/未读历史消息块"""
|
||||||
|
try:
|
||||||
|
# 从message_manager获取真实的已读/未读消息
|
||||||
|
from src.chat.message_manager.message_manager import message_manager
|
||||||
|
from src.chat.utils.utils import assign_message_ids
|
||||||
|
from src.chat.utils.chat_message_builder import get_raw_msg_before_timestamp_with_chat
|
||||||
|
|
||||||
|
# 获取聊天流的上下文
|
||||||
|
stream_context = message_manager.stream_contexts.get(plan.chat_id)
|
||||||
|
|
||||||
|
# 获取真正的已读和未读消息
|
||||||
|
read_messages = stream_context.history_messages # 已读消息存储在history_messages中
|
||||||
|
if not read_messages:
|
||||||
|
from src.common.data_models.database_data_model import DatabaseMessages
|
||||||
|
# 如果内存中没有已读消息(比如刚启动),则从数据库加载最近的上下文
|
||||||
|
fallback_messages_dicts = get_raw_msg_before_timestamp_with_chat(
|
||||||
|
chat_id=plan.chat_id,
|
||||||
|
timestamp=time.time(),
|
||||||
|
limit=global_config.chat.max_context_size,
|
||||||
|
)
|
||||||
|
# 将字典转换为DatabaseMessages对象
|
||||||
|
read_messages = [DatabaseMessages(**msg_dict) for msg_dict in fallback_messages_dicts]
|
||||||
|
|
||||||
|
unread_messages = stream_context.get_unread_messages() # 获取未读消息
|
||||||
|
|
||||||
|
# 构建已读历史消息块
|
||||||
|
if read_messages:
|
||||||
|
read_content, read_ids = build_readable_messages_with_id(
|
||||||
|
messages=[msg.flatten() for msg in read_messages[-50:]], # 限制数量
|
||||||
|
timestamp_mode="normal_no_YMD",
|
||||||
|
truncate=False,
|
||||||
|
show_actions=False,
|
||||||
|
)
|
||||||
|
read_history_block = f"{read_content}"
|
||||||
|
else:
|
||||||
|
read_history_block = "暂无已读历史消息"
|
||||||
|
|
||||||
|
# 构建未读历史消息块(包含兴趣度)
|
||||||
|
if unread_messages:
|
||||||
|
# 扁平化未读消息用于计算兴趣度和格式化
|
||||||
|
flattened_unread = [msg.flatten() for msg in unread_messages]
|
||||||
|
|
||||||
|
# 尝试获取兴趣度评分(返回以真实 message_id 为键的字典)
|
||||||
|
interest_scores = await self._get_interest_scores_for_messages(flattened_unread)
|
||||||
|
|
||||||
|
# 为未读消息分配短 id(保持与 build_readable_messages_with_id 的一致结构)
|
||||||
|
message_id_list = assign_message_ids(flattened_unread)
|
||||||
|
|
||||||
|
unread_lines = []
|
||||||
|
for idx, msg in enumerate(flattened_unread):
|
||||||
|
mapped = message_id_list[idx]
|
||||||
|
synthetic_id = mapped.get("id")
|
||||||
|
original_msg_id = msg.get("message_id") or msg.get("id")
|
||||||
|
msg_time = time.strftime("%H:%M:%S", time.localtime(msg.get("time", time.time())))
|
||||||
|
user_nickname = msg.get("user_nickname", "未知用户")
|
||||||
|
msg_content = msg.get("processed_plain_text", "")
|
||||||
|
|
||||||
|
# 不再显示兴趣度,但保留合成ID供模型内部使用
|
||||||
|
# 同时,为了让模型更好地理解上下文,我们显示用户名
|
||||||
|
unread_lines.append(f"<{synthetic_id}> {msg_time} {user_nickname}: {msg_content}")
|
||||||
|
|
||||||
|
unread_history_block = "\n".join(unread_lines)
|
||||||
|
else:
|
||||||
|
unread_history_block = "暂无未读历史消息"
|
||||||
|
|
||||||
|
return read_history_block, unread_history_block, message_id_list
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"构建已读/未读历史消息块时出错: {e}")
|
||||||
|
return "构建已读历史消息时出错", "构建未读历史消息时出错", []
|
||||||
|
|
||||||
|
async def _get_interest_scores_for_messages(self, messages: List[dict]) -> dict[str, float]:
|
||||||
|
"""为消息获取兴趣度评分"""
|
||||||
|
interest_scores = {}
|
||||||
|
|
||||||
|
try:
|
||||||
|
from .interest_scoring import chatter_interest_scoring_system
|
||||||
|
from src.common.data_models.database_data_model import DatabaseMessages
|
||||||
|
|
||||||
|
# 使用插件内部的兴趣度评分系统计算评分
|
||||||
|
for msg_dict in messages:
|
||||||
|
try:
|
||||||
|
# 将字典转换为DatabaseMessages对象
|
||||||
|
db_message = DatabaseMessages(
|
||||||
|
message_id=msg_dict.get("message_id", ""),
|
||||||
|
user_info=msg_dict.get("user_info", {}),
|
||||||
|
processed_plain_text=msg_dict.get("processed_plain_text", ""),
|
||||||
|
key_words=msg_dict.get("key_words", "[]"),
|
||||||
|
is_mentioned=msg_dict.get("is_mentioned", False)
|
||||||
|
)
|
||||||
|
|
||||||
|
# 计算消息兴趣度
|
||||||
|
interest_score_obj = await chatter_interest_scoring_system._calculate_single_message_score(
|
||||||
|
message=db_message,
|
||||||
|
bot_nickname=global_config.bot.nickname
|
||||||
|
)
|
||||||
|
interest_score = interest_score_obj.total_score
|
||||||
|
|
||||||
|
# 构建兴趣度字典
|
||||||
|
interest_scores[msg_dict.get("message_id", "")] = interest_score
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"计算消息兴趣度失败: {e}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"获取兴趣度评分失败: {e}")
|
||||||
|
|
||||||
|
return interest_scores
|
||||||
|
|
||||||
|
async def _parse_single_action(
|
||||||
|
self, action_json: dict, message_id_list: list, plan: Plan
|
||||||
|
) -> List[ActionPlannerInfo]:
|
||||||
|
parsed_actions = []
|
||||||
|
try:
|
||||||
|
# 从新的actions结构中获取动作信息
|
||||||
|
actions_obj = action_json.get("actions", {})
|
||||||
|
|
||||||
|
# 处理actions字段可能是字典或列表的情况
|
||||||
|
actions_to_process = []
|
||||||
|
if isinstance(actions_obj, dict):
|
||||||
|
actions_to_process.append(actions_obj)
|
||||||
|
elif isinstance(actions_obj, list):
|
||||||
|
actions_to_process.extend(actions_obj)
|
||||||
|
|
||||||
|
if not actions_to_process:
|
||||||
|
actions_to_process.append({"action_type": "no_action", "reason": "actions格式错误"})
|
||||||
|
|
||||||
|
for single_action_obj in actions_to_process:
|
||||||
|
if not isinstance(single_action_obj, dict):
|
||||||
|
continue
|
||||||
|
|
||||||
|
action = single_action_obj.get("action_type", "no_action")
|
||||||
|
reasoning = single_action_obj.get("reasoning", "未提供原因") # 兼容旧的reason字段
|
||||||
|
action_data = single_action_obj.get("action_data", {})
|
||||||
|
|
||||||
|
# 为了向后兼容,如果action_data不存在,则从顶层字段获取
|
||||||
|
if not action_data:
|
||||||
|
action_data = {k: v for k, v in single_action_obj.items() if k not in ["action_type", "reason", "reasoning", "thinking"]}
|
||||||
|
|
||||||
|
# 保留原始的thinking字段(如果有)
|
||||||
|
thinking = action_json.get("thinking", "")
|
||||||
|
if thinking and thinking != "未提供思考过程":
|
||||||
|
action_data["thinking"] = thinking
|
||||||
|
|
||||||
|
target_message_obj = None
|
||||||
|
if action not in ["no_action", "no_reply", "do_nothing", "proactive_reply"]:
|
||||||
|
if target_message_id := action_data.get("target_message_id"):
|
||||||
|
target_message_dict = self._find_message_by_id(target_message_id, message_id_list)
|
||||||
|
else:
|
||||||
|
# 如果LLM没有指定target_message_id,进行特殊处理
|
||||||
|
if action == "poke_user":
|
||||||
|
# 对于poke_user,尝试找到触发它的那条戳一戳消息
|
||||||
|
target_message_dict = self._find_poke_notice(message_id_list)
|
||||||
|
if not target_message_dict:
|
||||||
|
# 如果找不到,再使用最新消息作为兜底
|
||||||
|
target_message_dict = self._get_latest_message(message_id_list)
|
||||||
|
else:
|
||||||
|
# 其他动作,默认选择最新的一条消息
|
||||||
|
target_message_dict = self._get_latest_message(message_id_list)
|
||||||
|
|
||||||
|
if target_message_dict:
|
||||||
|
# 直接使用字典作为action_message,避免DatabaseMessages对象创建失败
|
||||||
|
target_message_obj = target_message_dict
|
||||||
|
# 替换action_data中的临时ID为真实ID
|
||||||
|
if "target_message_id" in action_data:
|
||||||
|
real_message_id = target_message_dict.get("message_id") or target_message_dict.get("id")
|
||||||
|
if real_message_id:
|
||||||
|
action_data["target_message_id"] = real_message_id
|
||||||
|
|
||||||
|
# 确保 action_message 中始终有 message_id 字段
|
||||||
|
if "message_id" not in target_message_obj and "id" in target_message_obj:
|
||||||
|
target_message_obj["message_id"] = target_message_obj["id"]
|
||||||
|
else:
|
||||||
|
# 如果找不到目标消息,对于reply动作来说这是必需的,应该记录警告
|
||||||
|
if action == "reply":
|
||||||
|
logger.warning(
|
||||||
|
f"reply动作找不到目标消息,target_message_id: {action_data.get('target_message_id')}"
|
||||||
|
)
|
||||||
|
# 将reply动作改为no_action,避免后续执行时出错
|
||||||
|
action = "no_action"
|
||||||
|
reasoning = f"找不到目标消息进行回复。原始理由: {reasoning}"
|
||||||
|
|
||||||
|
if (
|
||||||
|
action not in ["no_action", "no_reply", "reply", "do_nothing", "proactive_reply"]
|
||||||
|
and action not in plan.available_actions
|
||||||
|
):
|
||||||
|
reasoning = f"LLM 返回了当前不可用的动作 '{action}'。原始理由: {reasoning}"
|
||||||
|
action = "no_action"
|
||||||
|
|
||||||
|
parsed_actions.append(
|
||||||
|
ActionPlannerInfo(
|
||||||
|
action_type=action,
|
||||||
|
reasoning=reasoning,
|
||||||
|
action_data=action_data,
|
||||||
|
action_message=target_message_obj,
|
||||||
|
available_actions=plan.available_actions,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"解析单个action时出错: {e}")
|
||||||
|
parsed_actions.append(
|
||||||
|
ActionPlannerInfo(
|
||||||
|
action_type="no_action",
|
||||||
|
reasoning=f"解析action时出错: {e}",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return parsed_actions
|
||||||
|
|
||||||
|
def _filter_no_actions(self, action_list: List[ActionPlannerInfo]) -> List[ActionPlannerInfo]:
|
||||||
|
non_no_actions = [a for a in action_list if a.action_type not in ["no_action", "no_reply"]]
|
||||||
|
if non_no_actions:
|
||||||
|
return non_no_actions
|
||||||
|
return action_list[:1] if action_list else []
|
||||||
|
|
||||||
|
async def _get_long_term_memory_context(self) -> str:
|
||||||
|
try:
|
||||||
|
now = datetime.now()
|
||||||
|
keywords = ["今天", "日程", "计划"]
|
||||||
|
if 5 <= now.hour < 12:
|
||||||
|
keywords.append("早上")
|
||||||
|
elif 12 <= now.hour < 18:
|
||||||
|
keywords.append("中午")
|
||||||
|
else:
|
||||||
|
keywords.append("晚上")
|
||||||
|
|
||||||
|
retrieved_memories = await hippocampus_manager.get_memory_from_topic(
|
||||||
|
valid_keywords=keywords, max_memory_num=5, max_memory_length=1
|
||||||
|
)
|
||||||
|
|
||||||
|
if not retrieved_memories:
|
||||||
|
return "最近没有什么特别的记忆。"
|
||||||
|
|
||||||
|
memory_statements = [f"关于'{topic}', 你记得'{memory_item}'。" for topic, memory_item in retrieved_memories]
|
||||||
|
return " ".join(memory_statements)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"获取长期记忆时出错: {e}")
|
||||||
|
return "回忆时出现了一些问题。"
|
||||||
|
|
||||||
|
async def _build_action_options(self, current_available_actions: Dict[str, ActionInfo]) -> str:
|
||||||
|
action_options_block = ""
|
||||||
|
for action_name, action_info in current_available_actions.items():
|
||||||
|
# 构建参数的JSON示例
|
||||||
|
params_json_list = []
|
||||||
|
if action_info.action_parameters:
|
||||||
|
for p_name, p_desc in action_info.action_parameters.items():
|
||||||
|
# 为参数描述添加一个通用示例值
|
||||||
|
if action_name == "set_emoji_like" and p_name == "emoji":
|
||||||
|
# 特殊处理set_emoji_like的emoji参数
|
||||||
|
from plugins.set_emoji_like.qq_emoji_list import qq_face
|
||||||
|
emoji_options = [re.search(r"\[表情:(.+?)\]", name).group(1) for name in qq_face.values() if re.search(r"\[表情:(.+?)\]", name)]
|
||||||
|
example_value = f"<从'{', '.join(emoji_options[:10])}...'中选择一个>"
|
||||||
|
else:
|
||||||
|
example_value = f"<{p_desc}>"
|
||||||
|
params_json_list.append(f' "{p_name}": "{example_value}"')
|
||||||
|
|
||||||
|
# 基础动作信息
|
||||||
|
action_description = action_info.description
|
||||||
|
action_require = "\n".join(f"- {req}" for req in action_info.action_require)
|
||||||
|
|
||||||
|
# 构建完整的JSON使用范例
|
||||||
|
json_example_lines = [
|
||||||
|
" {",
|
||||||
|
f' "action_type": "{action_name}"',
|
||||||
|
]
|
||||||
|
# 将参数列表合并到JSON示例中
|
||||||
|
if params_json_list:
|
||||||
|
# 移除最后一行的逗号
|
||||||
|
json_example_lines.extend([line.rstrip(',') for line in params_json_list])
|
||||||
|
|
||||||
|
json_example_lines.append(' "reason": "<执行该动作的详细原因>"')
|
||||||
|
json_example_lines.append(" }")
|
||||||
|
|
||||||
|
# 使用逗号连接内部元素,除了最后一个
|
||||||
|
json_parts = []
|
||||||
|
for i, line in enumerate(json_example_lines):
|
||||||
|
# "{" 和 "}" 不需要逗号
|
||||||
|
if line.strip() in ["{", "}"]:
|
||||||
|
json_parts.append(line)
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 检查是否是最后一个需要逗号的元素
|
||||||
|
is_last_item = True
|
||||||
|
for next_line in json_example_lines[i+1:]:
|
||||||
|
if next_line.strip() not in ["}"]:
|
||||||
|
is_last_item = False
|
||||||
|
break
|
||||||
|
|
||||||
|
if not is_last_item:
|
||||||
|
json_parts.append(f"{line},")
|
||||||
|
else:
|
||||||
|
json_parts.append(line)
|
||||||
|
|
||||||
|
json_example = "\n".join(json_parts)
|
||||||
|
|
||||||
|
# 使用新的、更详细的action_prompt模板
|
||||||
|
using_action_prompt = await global_prompt_manager.get_prompt_async("action_prompt_with_example")
|
||||||
|
action_options_block += using_action_prompt.format(
|
||||||
|
action_name=action_name,
|
||||||
|
action_description=action_description,
|
||||||
|
action_require=action_require,
|
||||||
|
json_example=json_example,
|
||||||
|
)
|
||||||
|
return action_options_block
|
||||||
|
|
||||||
|
def _find_message_by_id(self, message_id: str, message_id_list: list) -> Optional[Dict[str, Any]]:
|
||||||
|
# 兼容多种 message_id 格式:数字、m123、buffered-xxxx
|
||||||
|
# 如果是纯数字,补上 m 前缀以兼容旧格式
|
||||||
|
candidate_ids = {message_id}
|
||||||
|
if message_id.isdigit():
|
||||||
|
candidate_ids.add(f"m{message_id}")
|
||||||
|
|
||||||
|
# 如果是 m 开头且后面是数字,尝试去掉 m 前缀的数字形式
|
||||||
|
if message_id.startswith("m") and message_id[1:].isdigit():
|
||||||
|
candidate_ids.add(message_id[1:])
|
||||||
|
|
||||||
|
# 逐项匹配 message_id_list(每项可能为 {'id':..., 'message':...})
|
||||||
|
for item in message_id_list:
|
||||||
|
# 支持 message_id_list 中直接是字符串/ID 的情形
|
||||||
|
if isinstance(item, str):
|
||||||
|
if item in candidate_ids:
|
||||||
|
# 没有 message 对象,返回None
|
||||||
|
return None
|
||||||
|
continue
|
||||||
|
|
||||||
|
if not isinstance(item, dict):
|
||||||
|
continue
|
||||||
|
|
||||||
|
item_id = item.get("id")
|
||||||
|
# 直接匹配分配的短 id
|
||||||
|
if item_id and item_id in candidate_ids:
|
||||||
|
return item.get("message")
|
||||||
|
|
||||||
|
# 有时 message 存储里会有原始的 message_id 字段(如 buffered-xxxx)
|
||||||
|
message_obj = item.get("message")
|
||||||
|
if isinstance(message_obj, dict):
|
||||||
|
orig_mid = message_obj.get("message_id") or message_obj.get("id")
|
||||||
|
if orig_mid and orig_mid in candidate_ids:
|
||||||
|
return message_obj
|
||||||
|
|
||||||
|
# 作为兜底,尝试在 message_id_list 中找到 message.message_id 匹配
|
||||||
|
for item in message_id_list:
|
||||||
|
if isinstance(item, dict) and isinstance(item.get("message"), dict):
|
||||||
|
mid = item["message"].get("message_id") or item["message"].get("id")
|
||||||
|
if mid == message_id:
|
||||||
|
return item["message"]
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _get_latest_message(self, message_id_list: list) -> Optional[Dict[str, Any]]:
|
||||||
|
if not message_id_list:
|
||||||
|
return None
|
||||||
|
return message_id_list[-1].get("message")
|
||||||
|
|
||||||
|
def _find_poke_notice(self, message_id_list: list) -> Optional[Dict[str, Any]]:
|
||||||
|
"""在消息列表中寻找戳一戳的通知消息"""
|
||||||
|
for item in reversed(message_id_list):
|
||||||
|
message = item.get("message")
|
||||||
|
if (
|
||||||
|
isinstance(message, dict)
|
||||||
|
and message.get("type") == "notice"
|
||||||
|
and "戳" in message.get("processed_plain_text", "")
|
||||||
|
):
|
||||||
|
return message
|
||||||
|
return None
|
||||||
168
src/plugins/built_in/affinity_flow_chatter/plan_generator.py
Normal file
168
src/plugins/built_in/affinity_flow_chatter/plan_generator.py
Normal file
@@ -0,0 +1,168 @@
|
|||||||
|
"""
|
||||||
|
PlanGenerator: 负责搜集和汇总所有决策所需的信息,生成一个未经筛选的"原始计划" (Plan)。
|
||||||
|
"""
|
||||||
|
|
||||||
|
import time
|
||||||
|
from typing import Dict
|
||||||
|
|
||||||
|
from src.chat.utils.chat_message_builder import get_raw_msg_before_timestamp_with_chat
|
||||||
|
from src.chat.utils.utils import get_chat_type_and_target_info
|
||||||
|
from src.common.data_models.database_data_model import DatabaseMessages
|
||||||
|
from src.common.data_models.info_data_model import Plan, TargetPersonInfo
|
||||||
|
from src.config.config import global_config
|
||||||
|
from src.plugin_system.base.component_types import ActionInfo, ChatMode, ChatType
|
||||||
|
from src.plugin_system.core.component_registry import component_registry
|
||||||
|
|
||||||
|
|
||||||
|
class ChatterPlanGenerator:
|
||||||
|
"""
|
||||||
|
ChatterPlanGenerator 负责在规划流程的初始阶段收集所有必要信息。
|
||||||
|
|
||||||
|
它会汇总以下信息来构建一个"原始"的 Plan 对象,该对象后续会由 PlanFilter 进行筛选:
|
||||||
|
- 当前聊天信息 (ID, 目标用户)
|
||||||
|
- 当前可用的动作列表
|
||||||
|
- 最近的聊天历史记录
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
chat_id (str): 当前聊天的唯一标识符。
|
||||||
|
action_manager (ActionManager): 用于获取可用动作列表的管理器。
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, chat_id: str):
|
||||||
|
"""
|
||||||
|
初始化 ChatterPlanGenerator。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
chat_id (str): 当前聊天的 ID。
|
||||||
|
"""
|
||||||
|
from src.chat.planner_actions.action_manager import ChatterActionManager
|
||||||
|
|
||||||
|
self.chat_id = chat_id
|
||||||
|
# 注意:ChatterActionManager 可能需要根据实际情况初始化
|
||||||
|
self.action_manager = ChatterActionManager()
|
||||||
|
|
||||||
|
async def generate(self, mode: ChatMode) -> Plan:
|
||||||
|
"""
|
||||||
|
收集所有信息,生成并返回一个初始的 Plan 对象。
|
||||||
|
|
||||||
|
这个 Plan 对象包含了决策所需的所有上下文信息。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
mode (ChatMode): 当前的聊天模式。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Plan: 包含所有上下文信息的初始计划对象。
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 获取聊天类型和目标信息
|
||||||
|
chat_type, target_info = get_chat_type_and_target_info(self.chat_id)
|
||||||
|
|
||||||
|
# 获取可用动作列表
|
||||||
|
available_actions = await self._get_available_actions(chat_type, mode)
|
||||||
|
|
||||||
|
# 获取聊天历史记录
|
||||||
|
recent_messages = await self._get_recent_messages()
|
||||||
|
|
||||||
|
# 构建计划对象
|
||||||
|
plan = Plan(
|
||||||
|
chat_id=self.chat_id,
|
||||||
|
chat_type=chat_type,
|
||||||
|
mode=mode,
|
||||||
|
target_info=target_info,
|
||||||
|
available_actions=available_actions,
|
||||||
|
chat_history=recent_messages,
|
||||||
|
)
|
||||||
|
|
||||||
|
return plan
|
||||||
|
|
||||||
|
except Exception:
|
||||||
|
# 如果生成失败,返回一个基本的空计划
|
||||||
|
return Plan(
|
||||||
|
chat_id=self.chat_id,
|
||||||
|
mode=mode,
|
||||||
|
target_info=TargetPersonInfo(),
|
||||||
|
available_actions={},
|
||||||
|
chat_history=[],
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _get_available_actions(self, chat_type: ChatType, mode: ChatMode) -> Dict[str, ActionInfo]:
|
||||||
|
"""
|
||||||
|
获取当前可用的动作列表。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
chat_type (ChatType): 聊天类型。
|
||||||
|
mode (ChatMode): 聊天模式。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict[str, ActionInfo]: 可用动作的字典。
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 从组件注册表获取可用动作
|
||||||
|
available_actions = component_registry.get_enabled_actions()
|
||||||
|
|
||||||
|
# 根据聊天类型和模式筛选动作
|
||||||
|
filtered_actions = {}
|
||||||
|
for action_name, action_info in available_actions.items():
|
||||||
|
# 检查动作是否支持当前聊天类型
|
||||||
|
if chat_type in action_info.chat_types:
|
||||||
|
# 检查动作是否支持当前模式
|
||||||
|
if mode in action_info.chat_modes:
|
||||||
|
filtered_actions[action_name] = action_info
|
||||||
|
|
||||||
|
return filtered_actions
|
||||||
|
|
||||||
|
except Exception:
|
||||||
|
# 如果获取失败,返回空字典
|
||||||
|
return {}
|
||||||
|
|
||||||
|
async def _get_recent_messages(self) -> list[DatabaseMessages]:
|
||||||
|
"""
|
||||||
|
获取最近的聊天历史记录。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
list[DatabaseMessages]: 最近的聊天消息列表。
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 获取最近的消息记录
|
||||||
|
raw_messages = get_raw_msg_before_timestamp_with_chat(
|
||||||
|
chat_id=self.chat_id, timestamp=time.time(), limit=global_config.memory.short_memory_length
|
||||||
|
)
|
||||||
|
|
||||||
|
# 转换为 DatabaseMessages 对象
|
||||||
|
recent_messages = []
|
||||||
|
for msg in raw_messages:
|
||||||
|
try:
|
||||||
|
db_msg = DatabaseMessages(
|
||||||
|
message_id=msg.get("message_id", ""),
|
||||||
|
time=float(msg.get("time", 0)),
|
||||||
|
chat_id=msg.get("chat_id", ""),
|
||||||
|
processed_plain_text=msg.get("processed_plain_text", ""),
|
||||||
|
user_id=msg.get("user_id", ""),
|
||||||
|
user_nickname=msg.get("user_nickname", ""),
|
||||||
|
user_platform=msg.get("user_platform", ""),
|
||||||
|
)
|
||||||
|
recent_messages.append(db_msg)
|
||||||
|
except Exception:
|
||||||
|
# 跳过格式错误的消息
|
||||||
|
continue
|
||||||
|
|
||||||
|
return recent_messages
|
||||||
|
|
||||||
|
except Exception:
|
||||||
|
# 如果获取失败,返回空列表
|
||||||
|
return []
|
||||||
|
|
||||||
|
def get_generator_stats(self) -> Dict:
|
||||||
|
"""
|
||||||
|
获取生成器统计信息。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict: 统计信息字典。
|
||||||
|
"""
|
||||||
|
return {
|
||||||
|
"chat_id": self.chat_id,
|
||||||
|
"action_count": len(self.action_manager._using_actions)
|
||||||
|
if hasattr(self.action_manager, "_using_actions")
|
||||||
|
else 0,
|
||||||
|
"generation_time": time.time(),
|
||||||
|
}
|
||||||
269
src/plugins/built_in/affinity_flow_chatter/planner.py
Normal file
269
src/plugins/built_in/affinity_flow_chatter/planner.py
Normal file
@@ -0,0 +1,269 @@
|
|||||||
|
"""
|
||||||
|
主规划器入口,负责协调 PlanGenerator, PlanFilter, 和 PlanExecutor。
|
||||||
|
集成兴趣度评分系统和用户关系追踪机制,实现智能化的聊天决策。
|
||||||
|
"""
|
||||||
|
|
||||||
|
from dataclasses import asdict
|
||||||
|
import time
|
||||||
|
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
|
from src.plugins.built_in.affinity_flow_chatter.plan_executor import ChatterPlanExecutor
|
||||||
|
from src.plugins.built_in.affinity_flow_chatter.plan_filter import ChatterPlanFilter
|
||||||
|
from src.plugins.built_in.affinity_flow_chatter.plan_generator import ChatterPlanGenerator
|
||||||
|
from src.plugins.built_in.affinity_flow_chatter.interest_scoring import chatter_interest_scoring_system
|
||||||
|
from src.mood.mood_manager import mood_manager
|
||||||
|
|
||||||
|
|
||||||
|
from src.common.logger import get_logger
|
||||||
|
from src.config.config import global_config
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from src.common.data_models.message_manager_data_model import StreamContext
|
||||||
|
from src.common.data_models.info_data_model import Plan
|
||||||
|
from src.chat.planner_actions.action_manager import ChatterActionManager
|
||||||
|
|
||||||
|
# 导入提示词模块以确保其被初始化
|
||||||
|
from src.plugins.built_in.affinity_flow_chatter import planner_prompts # noqa
|
||||||
|
|
||||||
|
logger = get_logger("planner")
|
||||||
|
|
||||||
|
|
||||||
|
class ChatterActionPlanner:
|
||||||
|
"""
|
||||||
|
增强版ActionPlanner,集成兴趣度评分和用户关系追踪机制。
|
||||||
|
|
||||||
|
核心功能:
|
||||||
|
1. 兴趣度评分系统:根据兴趣匹配度、关系分、提及度、时间因子对消息评分
|
||||||
|
2. 用户关系追踪:自动追踪用户交互并更新关系分
|
||||||
|
3. 智能回复决策:基于兴趣度阈值和连续不回复概率的智能决策
|
||||||
|
4. 完整的规划流程:生成→筛选→执行的完整三阶段流程
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, chat_id: str, action_manager: "ChatterActionManager"):
|
||||||
|
"""
|
||||||
|
初始化增强版ActionPlanner。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
chat_id (str): 当前聊天的 ID。
|
||||||
|
action_manager (ChatterActionManager): 一个 ChatterActionManager 实例。
|
||||||
|
"""
|
||||||
|
self.chat_id = chat_id
|
||||||
|
self.action_manager = action_manager
|
||||||
|
self.generator = ChatterPlanGenerator(chat_id)
|
||||||
|
self.executor = ChatterPlanExecutor(action_manager)
|
||||||
|
|
||||||
|
# 使用新的统一兴趣度管理系统
|
||||||
|
|
||||||
|
# 规划器统计
|
||||||
|
self.planner_stats = {
|
||||||
|
"total_plans": 0,
|
||||||
|
"successful_plans": 0,
|
||||||
|
"failed_plans": 0,
|
||||||
|
"replies_generated": 0,
|
||||||
|
"other_actions_executed": 0,
|
||||||
|
}
|
||||||
|
|
||||||
|
async def plan(self, context: "StreamContext" = None) -> Tuple[List[Dict], Optional[Dict]]:
|
||||||
|
"""
|
||||||
|
执行完整的增强版规划流程。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
context (StreamContext): 包含聊天流消息的上下文对象。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple[List[Dict], Optional[Dict]]: 一个元组,包含:
|
||||||
|
- final_actions_dict (List[Dict]): 最终确定的动作列表(字典格式)。
|
||||||
|
- final_target_message_dict (Optional[Dict]): 最终的目标消息(字典格式)。
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
self.planner_stats["total_plans"] += 1
|
||||||
|
|
||||||
|
return await self._enhanced_plan_flow(context)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"规划流程出错: {e}")
|
||||||
|
self.planner_stats["failed_plans"] += 1
|
||||||
|
return [], None
|
||||||
|
|
||||||
|
async def _enhanced_plan_flow(self, context: "StreamContext") -> Tuple[List[Dict], Optional[Dict]]:
|
||||||
|
"""执行增强版规划流程"""
|
||||||
|
try:
|
||||||
|
# 在规划前,先进行动作修改
|
||||||
|
from src.chat.planner_actions.action_modifier import ActionModifier
|
||||||
|
action_modifier = ActionModifier(self.action_manager, self.chat_id)
|
||||||
|
await action_modifier.modify_actions()
|
||||||
|
|
||||||
|
# 1. 生成初始 Plan
|
||||||
|
initial_plan = await self.generator.generate(context.chat_mode)
|
||||||
|
|
||||||
|
# 确保Plan中包含所有当前可用的动作
|
||||||
|
initial_plan.available_actions = self.action_manager.get_using_actions()
|
||||||
|
|
||||||
|
unread_messages = context.get_unread_messages() if context else []
|
||||||
|
# 2. 使用新的兴趣度管理系统进行评分
|
||||||
|
score = 0.0
|
||||||
|
should_reply = False
|
||||||
|
reply_not_available = False
|
||||||
|
|
||||||
|
if unread_messages:
|
||||||
|
# 获取用户ID,优先从user_info.user_id获取,其次从user_id属性获取
|
||||||
|
user_id = None
|
||||||
|
first_message = unread_messages[0]
|
||||||
|
user_id = first_message.user_info.user_id
|
||||||
|
|
||||||
|
# 构建计算上下文
|
||||||
|
calc_context = {
|
||||||
|
"stream_id": self.chat_id,
|
||||||
|
"user_id": user_id,
|
||||||
|
}
|
||||||
|
|
||||||
|
# 为每条消息计算兴趣度
|
||||||
|
for message in unread_messages:
|
||||||
|
try:
|
||||||
|
# 使用插件内部的兴趣度评分系统计算
|
||||||
|
interest_score = await chatter_interest_scoring_system._calculate_single_message_score(
|
||||||
|
message=message,
|
||||||
|
bot_nickname=global_config.bot.nickname
|
||||||
|
)
|
||||||
|
message_interest = interest_score.total_score
|
||||||
|
|
||||||
|
# 更新消息的兴趣度
|
||||||
|
message.interest_value = message_interest
|
||||||
|
|
||||||
|
# 简单的回复决策逻辑:兴趣度超过阈值则回复
|
||||||
|
message.should_reply = message_interest > global_config.affinity_flow.non_reply_action_interest_threshold
|
||||||
|
|
||||||
|
logger.debug(f"消息 {message.message_id} 兴趣度: {message_interest:.3f}, 应回复: {message.should_reply}")
|
||||||
|
|
||||||
|
# 更新StreamContext中的消息信息并刷新focus_energy
|
||||||
|
if context:
|
||||||
|
from src.chat.message_manager.message_manager import message_manager
|
||||||
|
message_manager.update_message(
|
||||||
|
stream_id=self.chat_id,
|
||||||
|
message_id=message.message_id,
|
||||||
|
interest_value=message_interest,
|
||||||
|
should_reply=message.should_reply
|
||||||
|
)
|
||||||
|
|
||||||
|
# 更新数据库中的消息记录
|
||||||
|
try:
|
||||||
|
from src.chat.message_receive.storage import MessageStorage
|
||||||
|
MessageStorage.update_message_interest_value(message.message_id, message_interest)
|
||||||
|
logger.debug(f"已更新数据库中消息 {message.message_id} 的兴趣度为: {message_interest:.3f}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"更新数据库消息兴趣度失败: {e}")
|
||||||
|
|
||||||
|
# 记录最高分
|
||||||
|
if message_interest > score:
|
||||||
|
score = message_interest
|
||||||
|
if message.should_reply:
|
||||||
|
should_reply = True
|
||||||
|
else:
|
||||||
|
reply_not_available = True
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"计算消息 {message.message_id} 兴趣度失败: {e}")
|
||||||
|
# 设置默认值
|
||||||
|
message.interest_value = 0.0
|
||||||
|
message.should_reply = False
|
||||||
|
|
||||||
|
# 检查兴趣度是否达到非回复动作阈值
|
||||||
|
non_reply_action_interest_threshold = global_config.affinity_flow.non_reply_action_interest_threshold
|
||||||
|
if score < non_reply_action_interest_threshold:
|
||||||
|
logger.info(f"兴趣度 {score:.3f} 低于阈值 {non_reply_action_interest_threshold:.3f},不执行动作")
|
||||||
|
# 直接返回 no_action
|
||||||
|
from src.common.data_models.info_data_model import ActionPlannerInfo
|
||||||
|
|
||||||
|
no_action = ActionPlannerInfo(
|
||||||
|
action_type="no_action",
|
||||||
|
reasoning=f"兴趣度评分 {score:.3f} 未达阈值 {non_reply_action_interest_threshold:.3f}",
|
||||||
|
action_data={},
|
||||||
|
action_message=None,
|
||||||
|
)
|
||||||
|
filtered_plan = initial_plan
|
||||||
|
filtered_plan.decided_actions = [no_action]
|
||||||
|
else:
|
||||||
|
# 4. 筛选 Plan
|
||||||
|
available_actions = list(initial_plan.available_actions.keys())
|
||||||
|
plan_filter = ChatterPlanFilter(self.chat_id, available_actions)
|
||||||
|
filtered_plan = await plan_filter.filter(reply_not_available, initial_plan)
|
||||||
|
|
||||||
|
# 检查filtered_plan是否有reply动作,用于统计
|
||||||
|
has_reply_action = any(decision.action_type == "reply" for decision in filtered_plan.decided_actions)
|
||||||
|
|
||||||
|
# 5. 使用 PlanExecutor 执行 Plan
|
||||||
|
execution_result = await self.executor.execute(filtered_plan)
|
||||||
|
|
||||||
|
# 6. 根据执行结果更新统计信息
|
||||||
|
self._update_stats_from_execution_result(execution_result)
|
||||||
|
|
||||||
|
# 7. 返回结果
|
||||||
|
return self._build_return_result(filtered_plan)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"增强版规划流程出错: {e}")
|
||||||
|
self.planner_stats["failed_plans"] += 1
|
||||||
|
return [], None
|
||||||
|
|
||||||
|
def _update_stats_from_execution_result(self, execution_result: Dict[str, any]):
|
||||||
|
"""根据执行结果更新规划器统计"""
|
||||||
|
if not execution_result:
|
||||||
|
return
|
||||||
|
|
||||||
|
successful_count = execution_result.get("successful_count", 0)
|
||||||
|
|
||||||
|
# 更新成功执行计数
|
||||||
|
self.planner_stats["successful_plans"] += successful_count
|
||||||
|
|
||||||
|
# 统计回复动作和其他动作
|
||||||
|
reply_count = 0
|
||||||
|
other_count = 0
|
||||||
|
|
||||||
|
for result in execution_result.get("results", []):
|
||||||
|
action_type = result.get("action_type", "")
|
||||||
|
if action_type in ["reply", "proactive_reply"]:
|
||||||
|
reply_count += 1
|
||||||
|
else:
|
||||||
|
other_count += 1
|
||||||
|
|
||||||
|
self.planner_stats["replies_generated"] += reply_count
|
||||||
|
self.planner_stats["other_actions_executed"] += other_count
|
||||||
|
|
||||||
|
def _build_return_result(self, plan: "Plan") -> Tuple[List[Dict], Optional[Dict]]:
|
||||||
|
"""构建返回结果"""
|
||||||
|
final_actions = plan.decided_actions or []
|
||||||
|
final_target_message = next((act.action_message for act in final_actions if act.action_message), None)
|
||||||
|
|
||||||
|
final_actions_dict = [asdict(act) for act in final_actions]
|
||||||
|
|
||||||
|
if final_target_message:
|
||||||
|
if hasattr(final_target_message, "__dataclass_fields__"):
|
||||||
|
final_target_message_dict = asdict(final_target_message)
|
||||||
|
else:
|
||||||
|
final_target_message_dict = final_target_message
|
||||||
|
else:
|
||||||
|
final_target_message_dict = None
|
||||||
|
|
||||||
|
return final_actions_dict, final_target_message_dict
|
||||||
|
|
||||||
|
def get_planner_stats(self) -> Dict[str, any]:
|
||||||
|
"""获取规划器统计"""
|
||||||
|
return self.planner_stats.copy()
|
||||||
|
|
||||||
|
def get_current_mood_state(self) -> str:
|
||||||
|
"""获取当前聊天的情绪状态"""
|
||||||
|
chat_mood = mood_manager.get_mood_by_chat_id(self.chat_id)
|
||||||
|
return chat_mood.mood_state
|
||||||
|
|
||||||
|
def get_mood_stats(self) -> Dict[str, any]:
|
||||||
|
"""获取情绪状态统计"""
|
||||||
|
chat_mood = mood_manager.get_mood_by_chat_id(self.chat_id)
|
||||||
|
return {
|
||||||
|
"current_mood": chat_mood.mood_state,
|
||||||
|
"is_angry_from_wakeup": chat_mood.is_angry_from_wakeup,
|
||||||
|
"regression_count": chat_mood.regression_count,
|
||||||
|
"last_change_time": chat_mood.last_change_time,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# 全局兴趣度评分系统实例 - 在 individuality 模块中创建
|
||||||
290
src/plugins/built_in/affinity_flow_chatter/planner_prompts.py
Normal file
290
src/plugins/built_in/affinity_flow_chatter/planner_prompts.py
Normal file
@@ -0,0 +1,290 @@
|
|||||||
|
"""
|
||||||
|
本文件集中管理所有与规划器(Planner)相关的提示词(Prompt)模板。
|
||||||
|
|
||||||
|
通过将提示词与代码逻辑分离,可以更方便地对模型的行为进行迭代和优化,
|
||||||
|
而无需修改核心代码。
|
||||||
|
"""
|
||||||
|
|
||||||
|
from src.chat.utils.prompt import Prompt
|
||||||
|
|
||||||
|
|
||||||
|
def init_prompts():
|
||||||
|
"""
|
||||||
|
初始化并向 Prompt 注册系统注册所有规划器相关的提示词。
|
||||||
|
|
||||||
|
这个函数会在模块加载时自动调用,确保所有提示词在系统启动时都已准备就绪。
|
||||||
|
"""
|
||||||
|
# 核心规划器提示词,用于在接收到新消息时决定如何回应。
|
||||||
|
# 它构建了一个复杂的上下文,包括历史记录、可用动作、角色设定等,
|
||||||
|
# 并要求模型以 JSON 格式输出一个或多个动作组合。
|
||||||
|
Prompt(
|
||||||
|
"""
|
||||||
|
{mood_block}
|
||||||
|
{time_block}
|
||||||
|
{identity_block}
|
||||||
|
|
||||||
|
{users_in_chat}
|
||||||
|
{custom_prompt_block}
|
||||||
|
{chat_context_description},以下是具体的聊天内容。
|
||||||
|
|
||||||
|
## 📜 已读历史消息(仅供参考)
|
||||||
|
{read_history_block}
|
||||||
|
|
||||||
|
## 📬 未读历史消息(动作执行对象)
|
||||||
|
{unread_history_block}
|
||||||
|
|
||||||
|
{moderation_prompt}
|
||||||
|
|
||||||
|
**任务: 构建一个完整的响应**
|
||||||
|
你的任务是根据当前的聊天内容,构建一个完整的、人性化的响应。一个完整的响应由两部分组成:
|
||||||
|
1. **主要动作**: 这是响应的核心,通常是 `reply`(如果有)。
|
||||||
|
2. **辅助动作 (可选)**: 这是为了增强表达效果的附加动作,例如 `emoji`(发送表情包)或 `poke_user`(戳一戳)。
|
||||||
|
|
||||||
|
**决策流程:**
|
||||||
|
1. **重要:已读历史消息仅作为当前聊天情景的参考,帮助你理解对话上下文。**
|
||||||
|
2. **重要:所有动作的执行对象只能是未读历史消息中的消息,不能对已读消息执行动作。**
|
||||||
|
3. 在未读历史消息中,优先对兴趣值高的消息做出动作(兴趣值标注在消息末尾)。
|
||||||
|
4. 首先,决定是否要对未读消息进行 `reply`(如果有)。
|
||||||
|
5. 然后,评估当前的对话气氛和用户情绪,判断是否需要一个**辅助动作**来让你的回应更生动、更符合你的性格。
|
||||||
|
6. 如果需要,选择一个最合适的辅助动作与 `reply`(如果有) 组合。
|
||||||
|
7. 如果用户明确要求了某个动作,请务必优先满足。
|
||||||
|
|
||||||
|
**重要提醒:**
|
||||||
|
- **回复消息时必须遵循对话的流程,不要重复已经说过的话。**
|
||||||
|
- **确保回复与上下文紧密相关,回应要针对用户的消息内容。**
|
||||||
|
- **保持角色设定的一致性,使用符合你性格的语言风格。**
|
||||||
|
- **不要对表情包消息做出回应!**
|
||||||
|
|
||||||
|
**输出格式:**
|
||||||
|
请严格按照以下 JSON 格式输出,包含 `thinking` 和 `actions` 字段:
|
||||||
|
|
||||||
|
**重要概念:将“内心思考”作为思绪流的体现**
|
||||||
|
`thinking` 字段是本次决策的核心。它并非一个简单的“理由”,而是 **一个模拟人类在回应前,头脑中自然浮现的、未经修饰的思绪流**。你需要完全代入 {identity_block} 的角色,将那一刻的想法自然地记录下来。
|
||||||
|
|
||||||
|
**内心思考的要点:**
|
||||||
|
* **自然流露**: 不要使用“决定”、“所以”、“因此”等结论性或汇报式的词语。你的思考应该像日记一样,是给自己看的,充满了不确定性和情绪的自然流动。
|
||||||
|
* **展现过程**: 重点在于展现 **思考的过程**,而不是 **决策的结果**。描述你看到了什么,想到了什么,感受到了什么。
|
||||||
|
* **使用昵称**: 在你的思绪流中,请直接使用用户的昵称来指代他们,而不是`<m1>`, `<m2>`这样的消息ID。
|
||||||
|
* **严禁技术术语**: 严禁在思考中提及任何数字化的度量(如兴趣度、分数)或内部技术术语。请完全使用角色自身的感受和语言来描述思考过程。
|
||||||
|
|
||||||
|
## 可用动作列表
|
||||||
|
{action_options_text}
|
||||||
|
|
||||||
|
```json
|
||||||
|
{{
|
||||||
|
"thinking": "在这里写下你的思绪流...",
|
||||||
|
"actions": [
|
||||||
|
{{
|
||||||
|
"action_type": "动作类型(如:reply, emoji等)",
|
||||||
|
"reasoning": "选择该动作的理由",
|
||||||
|
"action_data": {{
|
||||||
|
"target_message_id": "目标消息ID",
|
||||||
|
"content": "回复内容或其他动作所需数据"
|
||||||
|
}}
|
||||||
|
}}
|
||||||
|
]
|
||||||
|
}}
|
||||||
|
```
|
||||||
|
|
||||||
|
**强制规则**:
|
||||||
|
- 对于每一个需要目标消息的动作(如`reply`, `poke_user`, `set_emoji_like`),你 **必须** 在`action_data`中提供准确的`target_message_id`,这个ID来源于`## 未读历史消息`中消息前的`<m...>`标签。
|
||||||
|
- 当你选择的动作需要参数时(例如 `set_emoji_like` 需要 `emoji` 参数),你 **必须** 在 `action_data` 中提供所有必需的参数及其对应的值。
|
||||||
|
|
||||||
|
如果没有合适的回复对象或不需要回复,输出空的 actions 数组:
|
||||||
|
```json
|
||||||
|
{{
|
||||||
|
"thinking": "说明为什么不需要回复",
|
||||||
|
"actions": []
|
||||||
|
}}
|
||||||
|
```
|
||||||
|
""",
|
||||||
|
"planner_prompt",
|
||||||
|
)
|
||||||
|
|
||||||
|
# 主动规划器提示词,用于主动场景和前瞻性规划
|
||||||
|
Prompt(
|
||||||
|
"""
|
||||||
|
{mood_block}
|
||||||
|
{time_block}
|
||||||
|
{identity_block}
|
||||||
|
|
||||||
|
{users_in_chat}
|
||||||
|
{custom_prompt_block}
|
||||||
|
{chat_context_description},以下是具体的聊天内容。
|
||||||
|
|
||||||
|
## 📜 已读历史消息(仅供参考)
|
||||||
|
{read_history_block}
|
||||||
|
|
||||||
|
## 📬 未读历史消息(动作执行对象)
|
||||||
|
{unread_history_block}
|
||||||
|
|
||||||
|
{moderation_prompt}
|
||||||
|
|
||||||
|
**任务: 构建一个完整的响应**
|
||||||
|
你的任务是根据当前的聊天内容,构建一个完整的、人性化的响应。一个完整的响应由两部分组成:
|
||||||
|
1. **主要动作**: 这是响应的核心,通常是 `reply`(如果有)。
|
||||||
|
2. **辅助动作 (可选)**: 这是为了增强表达效果的附加动作,例如 `emoji`(发送表情包)或 `poke_user`(戳一戳)。
|
||||||
|
|
||||||
|
**决策流程:**
|
||||||
|
1. **重要:已读历史消息仅作为当前聊天情景的参考,帮助你理解对话上下文。**
|
||||||
|
2. **重要:所有动作的执行对象只能是未读历史消息中的消息,不能对已读消息执行动作。**
|
||||||
|
3. 在未读历史消息中,优先对兴趣值高的消息做出动作(兴趣值标注在消息末尾)。
|
||||||
|
4. 首先,决定是否要对未读消息进行 `reply`(如果有)。
|
||||||
|
5. 然后,评估当前的对话气氛和用户情绪,判断是否需要一个**辅助动作**来让你的回应更生动、更符合你的性格。
|
||||||
|
6. 如果需要,选择一个最合适的辅助动作与 `reply`(如果有) 组合。
|
||||||
|
7. 如果用户明确要求了某个动作,请务必优先满足。
|
||||||
|
|
||||||
|
**动作限制:**
|
||||||
|
- 在私聊中,你只能使用 `reply` 动作。私聊中不允许使用任何其他动作。
|
||||||
|
- 在群聊中,你可以自由选择是否使用辅助动作。
|
||||||
|
|
||||||
|
**重要提醒:**
|
||||||
|
- **回复消息时必须遵循对话的流程,不要重复已经说过的话。**
|
||||||
|
- **确保回复与上下文紧密相关,回应要针对用户的消息内容。**
|
||||||
|
- **保持角色设定的一致性,使用符合你性格的语言风格。**
|
||||||
|
|
||||||
|
**输出格式:**
|
||||||
|
请严格按照以下 JSON 格式输出,包含 `thinking` 和 `actions` 字段:
|
||||||
|
```json
|
||||||
|
{{
|
||||||
|
"thinking": "你的思考过程,分析当前情况并说明为什么选择这些动作",
|
||||||
|
"actions": [
|
||||||
|
{{
|
||||||
|
"action_type": "动作类型(如:reply, emoji等)",
|
||||||
|
"reasoning": "选择该动作的理由",
|
||||||
|
"action_data": {{
|
||||||
|
"target_message_id": "目标消息ID",
|
||||||
|
"content": "回复内容或其他动作所需数据"
|
||||||
|
}}
|
||||||
|
}}
|
||||||
|
]
|
||||||
|
}}
|
||||||
|
```
|
||||||
|
|
||||||
|
如果没有合适的回复对象或不需要回复,输出空的 actions 数组:
|
||||||
|
```json
|
||||||
|
{{
|
||||||
|
"thinking": "说明为什么不需要回复",
|
||||||
|
"actions": []
|
||||||
|
}}
|
||||||
|
```
|
||||||
|
""",
|
||||||
|
"proactive_planner_prompt",
|
||||||
|
)
|
||||||
|
|
||||||
|
# 轻量级规划器提示词,用于快速决策和简单场景
|
||||||
|
Prompt(
|
||||||
|
"""
|
||||||
|
{identity_block}
|
||||||
|
|
||||||
|
## 当前聊天情景
|
||||||
|
{chat_context_description}
|
||||||
|
|
||||||
|
## 未读消息
|
||||||
|
{unread_history_block}
|
||||||
|
|
||||||
|
**任务:快速决策**
|
||||||
|
请根据当前聊天内容,快速决定是否需要回复。
|
||||||
|
|
||||||
|
**决策规则:**
|
||||||
|
1. 如果有人直接提到你或问你问题,优先回复
|
||||||
|
2. 如果消息内容符合你的兴趣,考虑回复
|
||||||
|
3. 如果只是群聊中的普通聊天且与你无关,可以不回复
|
||||||
|
|
||||||
|
**输出格式:**
|
||||||
|
```json
|
||||||
|
{{
|
||||||
|
"thinking": "简要分析",
|
||||||
|
"actions": [
|
||||||
|
{{
|
||||||
|
"action_type": "reply",
|
||||||
|
"reasoning": "回复理由",
|
||||||
|
"action_data": {{
|
||||||
|
"target_message_id": "目标消息ID",
|
||||||
|
"content": "回复内容"
|
||||||
|
}}
|
||||||
|
}}
|
||||||
|
]
|
||||||
|
}}
|
||||||
|
```
|
||||||
|
""",
|
||||||
|
"chatter_planner_lite",
|
||||||
|
)
|
||||||
|
|
||||||
|
# 动作筛选器提示词,用于筛选和优化规划器生成的动作
|
||||||
|
Prompt(
|
||||||
|
"""
|
||||||
|
{identity_block}
|
||||||
|
|
||||||
|
## 原始动作计划
|
||||||
|
{original_plan}
|
||||||
|
|
||||||
|
## 聊天上下文
|
||||||
|
{chat_context}
|
||||||
|
|
||||||
|
**任务:动作筛选优化**
|
||||||
|
请对原始动作计划进行筛选和优化,确保动作的合理性和有效性。
|
||||||
|
|
||||||
|
**筛选原则:**
|
||||||
|
1. 移除重复或不必要的动作
|
||||||
|
2. 确保动作之间的逻辑顺序
|
||||||
|
3. 优化动作的具体参数
|
||||||
|
4. 考虑当前聊天环境和个人设定
|
||||||
|
|
||||||
|
**输出格式:**
|
||||||
|
```json
|
||||||
|
{{
|
||||||
|
"thinking": "筛选优化思考",
|
||||||
|
"actions": [
|
||||||
|
{{
|
||||||
|
"action_type": "优化后的动作类型",
|
||||||
|
"reasoning": "优化理由",
|
||||||
|
"action_data": {{
|
||||||
|
"target_message_id": "目标消息ID",
|
||||||
|
"content": "优化后的内容"
|
||||||
|
}}
|
||||||
|
}}
|
||||||
|
]
|
||||||
|
}}
|
||||||
|
```
|
||||||
|
""",
|
||||||
|
"chatter_plan_filter",
|
||||||
|
)
|
||||||
|
|
||||||
|
# 动作提示词,用于格式化动作选项
|
||||||
|
Prompt(
|
||||||
|
"""
|
||||||
|
## 动作: {action_name}
|
||||||
|
**描述**: {action_description}
|
||||||
|
|
||||||
|
**参数**:
|
||||||
|
{action_parameters}
|
||||||
|
|
||||||
|
**要求**:
|
||||||
|
{action_require}
|
||||||
|
|
||||||
|
**使用说明**:
|
||||||
|
请根据上述信息判断是否需要使用此动作。
|
||||||
|
""",
|
||||||
|
"action_prompt",
|
||||||
|
)
|
||||||
|
|
||||||
|
# 带有完整JSON示例的动作提示词模板
|
||||||
|
Prompt(
|
||||||
|
"""
|
||||||
|
动作: {action_name}
|
||||||
|
动作描述: {action_description}
|
||||||
|
动作使用场景:
|
||||||
|
{action_require}
|
||||||
|
|
||||||
|
你应该像这样使用它:
|
||||||
|
{{
|
||||||
|
{json_example}
|
||||||
|
}}
|
||||||
|
""",
|
||||||
|
"action_prompt_with_example",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# 确保提示词在模块加载时初始化
|
||||||
|
init_prompts()
|
||||||
46
src/plugins/built_in/affinity_flow_chatter/plugin.py
Normal file
46
src/plugins/built_in/affinity_flow_chatter/plugin.py
Normal file
@@ -0,0 +1,46 @@
|
|||||||
|
"""
|
||||||
|
亲和力聊天处理器插件
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import List, Tuple, Type
|
||||||
|
|
||||||
|
from src.plugin_system.apis.plugin_register_api import register_plugin
|
||||||
|
from src.plugin_system.base.base_plugin import BasePlugin
|
||||||
|
from src.plugin_system.base.component_types import ComponentInfo
|
||||||
|
from src.common.logger import get_logger
|
||||||
|
|
||||||
|
logger = get_logger("affinity_chatter_plugin")
|
||||||
|
|
||||||
|
|
||||||
|
@register_plugin
|
||||||
|
class AffinityChatterPlugin(BasePlugin):
|
||||||
|
"""亲和力聊天处理器插件
|
||||||
|
|
||||||
|
- 延迟导入 `AffinityChatter` 并通过组件注册器注册为聊天处理器
|
||||||
|
- 提供 `get_plugin_components` 以兼容插件注册机制
|
||||||
|
"""
|
||||||
|
|
||||||
|
plugin_name: str = "affinity_chatter"
|
||||||
|
enable_plugin: bool = True
|
||||||
|
dependencies: list[str] = []
|
||||||
|
python_dependencies: list[str] = []
|
||||||
|
config_file_name: str = ""
|
||||||
|
|
||||||
|
# 简单的 config_schema 占位(如果将来需要配置可扩展)
|
||||||
|
config_schema = {}
|
||||||
|
|
||||||
|
def get_plugin_components(self) -> List[Tuple[ComponentInfo, Type]]:
|
||||||
|
"""返回插件包含的组件列表(ChatterInfo, AffinityChatter)
|
||||||
|
|
||||||
|
这里采用延迟导入 AffinityChatter 来避免循环依赖和启动顺序问题。
|
||||||
|
如果导入失败则返回空列表以让注册过程继续而不崩溃。
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 延迟导入以避免循环导入
|
||||||
|
from .affinity_chatter import AffinityChatter
|
||||||
|
|
||||||
|
return [(AffinityChatter.get_chatter_info(), AffinityChatter)]
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"加载 AffinityChatter 时出错: {e}")
|
||||||
|
return []
|
||||||
@@ -0,0 +1,755 @@
|
|||||||
|
"""
|
||||||
|
用户关系追踪器
|
||||||
|
负责追踪用户交互历史,并通过LLM分析更新用户关系分
|
||||||
|
支持数据库持久化存储和回复后自动关系更新
|
||||||
|
"""
|
||||||
|
|
||||||
|
import time
|
||||||
|
from typing import Dict, List, Optional
|
||||||
|
|
||||||
|
from src.common.logger import get_logger
|
||||||
|
from src.config.config import model_config, global_config
|
||||||
|
from src.llm_models.utils_model import LLMRequest
|
||||||
|
from src.common.database.sqlalchemy_database_api import get_db_session
|
||||||
|
from src.common.database.sqlalchemy_models import UserRelationships, Messages
|
||||||
|
from sqlalchemy import select, desc
|
||||||
|
from src.common.data_models.database_data_model import DatabaseMessages
|
||||||
|
|
||||||
|
logger = get_logger("chatter_relationship_tracker")
|
||||||
|
|
||||||
|
|
||||||
|
class ChatterRelationshipTracker:
|
||||||
|
"""用户关系追踪器"""
|
||||||
|
|
||||||
|
def __init__(self, interest_scoring_system=None):
|
||||||
|
self.tracking_users: Dict[str, Dict] = {} # user_id -> interaction_data
|
||||||
|
self.max_tracking_users = 3
|
||||||
|
self.update_interval_minutes = 30
|
||||||
|
self.last_update_time = time.time()
|
||||||
|
self.relationship_history: List[Dict] = []
|
||||||
|
self.interest_scoring_system = interest_scoring_system
|
||||||
|
|
||||||
|
# 用户关系缓存 (user_id -> {"relationship_text": str, "relationship_score": float, "last_tracked": float})
|
||||||
|
self.user_relationship_cache: Dict[str, Dict] = {}
|
||||||
|
self.cache_expiry_hours = 1 # 缓存过期时间(小时)
|
||||||
|
|
||||||
|
# 关系更新LLM
|
||||||
|
try:
|
||||||
|
self.relationship_llm = LLMRequest(
|
||||||
|
model_set=model_config.model_task_config.relationship_tracker, request_type="relationship_tracker"
|
||||||
|
)
|
||||||
|
except AttributeError:
|
||||||
|
# 如果relationship_tracker配置不存在,尝试其他可用的模型配置
|
||||||
|
available_models = [
|
||||||
|
attr
|
||||||
|
for attr in dir(model_config.model_task_config)
|
||||||
|
if not attr.startswith("_") and attr != "model_dump"
|
||||||
|
]
|
||||||
|
|
||||||
|
if available_models:
|
||||||
|
# 使用第一个可用的模型配置
|
||||||
|
fallback_model = available_models[0]
|
||||||
|
logger.warning(f"relationship_tracker model configuration not found, using fallback: {fallback_model}")
|
||||||
|
self.relationship_llm = LLMRequest(
|
||||||
|
model_set=getattr(model_config.model_task_config, fallback_model),
|
||||||
|
request_type="relationship_tracker",
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# 如果没有任何模型配置,创建一个简单的LLMRequest
|
||||||
|
logger.warning("No model configurations found, creating basic LLMRequest")
|
||||||
|
self.relationship_llm = LLMRequest(
|
||||||
|
model_set="gpt-3.5-turbo", # 默认模型
|
||||||
|
request_type="relationship_tracker",
|
||||||
|
)
|
||||||
|
|
||||||
|
def set_interest_scoring_system(self, interest_scoring_system):
|
||||||
|
"""设置兴趣度评分系统引用"""
|
||||||
|
self.interest_scoring_system = interest_scoring_system
|
||||||
|
|
||||||
|
def add_interaction(self, user_id: str, user_name: str, user_message: str, bot_reply: str, reply_timestamp: float):
|
||||||
|
"""添加用户交互记录"""
|
||||||
|
if len(self.tracking_users) >= self.max_tracking_users:
|
||||||
|
# 移除最旧的记录
|
||||||
|
oldest_user = min(
|
||||||
|
self.tracking_users.keys(), key=lambda k: self.tracking_users[k].get("reply_timestamp", 0)
|
||||||
|
)
|
||||||
|
del self.tracking_users[oldest_user]
|
||||||
|
|
||||||
|
# 获取当前关系分
|
||||||
|
current_relationship_score = global_config.affinity_flow.base_relationship_score # 默认值
|
||||||
|
if self.interest_scoring_system:
|
||||||
|
current_relationship_score = self.interest_scoring_system.get_user_relationship(user_id)
|
||||||
|
|
||||||
|
self.tracking_users[user_id] = {
|
||||||
|
"user_id": user_id,
|
||||||
|
"user_name": user_name,
|
||||||
|
"user_message": user_message,
|
||||||
|
"bot_reply": bot_reply,
|
||||||
|
"reply_timestamp": reply_timestamp,
|
||||||
|
"current_relationship_score": current_relationship_score,
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.debug(f"添加用户交互追踪: {user_id}")
|
||||||
|
|
||||||
|
async def check_and_update_relationships(self) -> List[Dict]:
|
||||||
|
"""检查并更新用户关系"""
|
||||||
|
current_time = time.time()
|
||||||
|
if current_time - self.last_update_time < self.update_interval_minutes * 60:
|
||||||
|
return []
|
||||||
|
|
||||||
|
updates = []
|
||||||
|
for user_id, interaction in list(self.tracking_users.items()):
|
||||||
|
if current_time - interaction["reply_timestamp"] > 60 * 5: # 5分钟
|
||||||
|
update = await self._update_user_relationship(interaction)
|
||||||
|
if update:
|
||||||
|
updates.append(update)
|
||||||
|
del self.tracking_users[user_id]
|
||||||
|
|
||||||
|
self.last_update_time = current_time
|
||||||
|
return updates
|
||||||
|
|
||||||
|
async def _update_user_relationship(self, interaction: Dict) -> Optional[Dict]:
|
||||||
|
"""更新单个用户的关系"""
|
||||||
|
try:
|
||||||
|
# 获取bot人设信息
|
||||||
|
from src.individuality.individuality import Individuality
|
||||||
|
|
||||||
|
individuality = Individuality()
|
||||||
|
bot_personality = await individuality.get_personality_block()
|
||||||
|
|
||||||
|
prompt = f"""
|
||||||
|
你现在是一个有着特定性格和身份的AI助手。你的人设是:{bot_personality}
|
||||||
|
|
||||||
|
请以你独特的性格视角,严格按现实逻辑分析以下用户交互,更新用户关系:
|
||||||
|
|
||||||
|
用户ID: {interaction["user_id"]}
|
||||||
|
用户名: {interaction["user_name"]}
|
||||||
|
用户消息: {interaction["user_message"]}
|
||||||
|
你的回复: {interaction["bot_reply"]}
|
||||||
|
当前关系分: {interaction["current_relationship_score"]}
|
||||||
|
|
||||||
|
【重要】关系分数档次定义:
|
||||||
|
- 0.0-0.2:陌生人/初次认识 - 仅礼貌性交流
|
||||||
|
- 0.2-0.4:普通网友 - 有基本互动但不熟悉
|
||||||
|
- 0.4-0.6:熟悉网友 - 经常交流,有一定了解
|
||||||
|
- 0.6-0.8:朋友 - 可以分享心情,互相关心
|
||||||
|
- 0.8-1.0:好朋友/知己 - 深度信任,亲密无间
|
||||||
|
|
||||||
|
【严格要求】:
|
||||||
|
1. 加分必须符合现实关系发展逻辑 - 不能因为对方态度好就盲目加分到不符合当前关系档次的分数
|
||||||
|
2. 关系提升需要足够的互动积累和时间验证
|
||||||
|
3. 即使是朋友关系,单次互动加分通常不超过0.05-0.1
|
||||||
|
4. 关系描述要详细具体,包括:
|
||||||
|
- 用户性格特点观察
|
||||||
|
- 印象深刻的互动记忆
|
||||||
|
- 你们关系的具体状态描述
|
||||||
|
|
||||||
|
根据你的人设性格,思考:
|
||||||
|
1. 以你的性格,你会如何看待这次互动?
|
||||||
|
2. 用户的行为是否符合你性格的喜好?
|
||||||
|
3. 这次互动是否真的让你们的关系提升了一个档次?为什么?
|
||||||
|
4. 有什么特别值得记住的互动细节?
|
||||||
|
|
||||||
|
请以JSON格式返回更新结果:
|
||||||
|
{{
|
||||||
|
"new_relationship_score": 0.0~1.0的数值(必须符合现实逻辑),
|
||||||
|
"reasoning": "从你的性格角度说明更新理由,重点说明是否符合现实关系发展逻辑",
|
||||||
|
"interaction_summary": "基于你性格的交互总结,包含印象深刻的互动记忆"
|
||||||
|
}}
|
||||||
|
"""
|
||||||
|
|
||||||
|
llm_response, _ = await self.relationship_llm.generate_response_async(prompt=prompt)
|
||||||
|
if llm_response:
|
||||||
|
import json
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 清理LLM响应,移除可能的格式标记
|
||||||
|
cleaned_response = self._clean_llm_json_response(llm_response)
|
||||||
|
response_data = json.loads(cleaned_response)
|
||||||
|
new_score = max(
|
||||||
|
0.0,
|
||||||
|
min(
|
||||||
|
1.0,
|
||||||
|
float(
|
||||||
|
response_data.get(
|
||||||
|
"new_relationship_score", global_config.affinity_flow.base_relationship_score
|
||||||
|
)
|
||||||
|
),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.interest_scoring_system:
|
||||||
|
self.interest_scoring_system.update_user_relationship(
|
||||||
|
interaction["user_id"], new_score - interaction["current_relationship_score"]
|
||||||
|
)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"user_id": interaction["user_id"],
|
||||||
|
"new_relationship_score": new_score,
|
||||||
|
"reasoning": response_data.get("reasoning", ""),
|
||||||
|
"interaction_summary": response_data.get("interaction_summary", ""),
|
||||||
|
}
|
||||||
|
|
||||||
|
except json.JSONDecodeError as e:
|
||||||
|
logger.error(f"LLM响应JSON解析失败: {e}")
|
||||||
|
logger.debug(f"LLM原始响应: {llm_response}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"处理关系更新数据失败: {e}")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"更新用户关系时出错: {e}")
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
def get_tracking_users(self) -> Dict[str, Dict]:
|
||||||
|
"""获取正在追踪的用户"""
|
||||||
|
return self.tracking_users.copy()
|
||||||
|
|
||||||
|
def get_user_interaction(self, user_id: str) -> Optional[Dict]:
|
||||||
|
"""获取特定用户的交互记录"""
|
||||||
|
return self.tracking_users.get(user_id)
|
||||||
|
|
||||||
|
def remove_user_tracking(self, user_id: str):
|
||||||
|
"""移除用户追踪"""
|
||||||
|
if user_id in self.tracking_users:
|
||||||
|
del self.tracking_users[user_id]
|
||||||
|
logger.debug(f"移除用户追踪: {user_id}")
|
||||||
|
|
||||||
|
def clear_all_tracking(self):
|
||||||
|
"""清空所有追踪"""
|
||||||
|
self.tracking_users.clear()
|
||||||
|
logger.info("清空所有用户追踪")
|
||||||
|
|
||||||
|
def get_relationship_history(self) -> List[Dict]:
|
||||||
|
"""获取关系历史记录"""
|
||||||
|
return self.relationship_history.copy()
|
||||||
|
|
||||||
|
def add_to_history(self, relationship_update: Dict):
|
||||||
|
"""添加到关系历史"""
|
||||||
|
self.relationship_history.append({**relationship_update, "update_time": time.time()})
|
||||||
|
|
||||||
|
# 限制历史记录数量
|
||||||
|
if len(self.relationship_history) > 100:
|
||||||
|
self.relationship_history = self.relationship_history[-100:]
|
||||||
|
|
||||||
|
def get_tracker_stats(self) -> Dict:
|
||||||
|
"""获取追踪器统计"""
|
||||||
|
return {
|
||||||
|
"tracking_users": len(self.tracking_users),
|
||||||
|
"max_tracking_users": self.max_tracking_users,
|
||||||
|
"update_interval_minutes": self.update_interval_minutes,
|
||||||
|
"relationship_history": len(self.relationship_history),
|
||||||
|
"last_update_time": self.last_update_time,
|
||||||
|
}
|
||||||
|
|
||||||
|
def update_config(self, max_tracking_users: int = None, update_interval_minutes: int = None):
|
||||||
|
"""更新配置"""
|
||||||
|
if max_tracking_users is not None:
|
||||||
|
self.max_tracking_users = max_tracking_users
|
||||||
|
logger.info(f"更新最大追踪用户数: {max_tracking_users}")
|
||||||
|
|
||||||
|
if update_interval_minutes is not None:
|
||||||
|
self.update_interval_minutes = update_interval_minutes
|
||||||
|
logger.info(f"更新关系更新间隔: {update_interval_minutes} 分钟")
|
||||||
|
|
||||||
|
def force_update_relationship(self, user_id: str, new_score: float, reasoning: str = ""):
|
||||||
|
"""强制更新用户关系分"""
|
||||||
|
if user_id in self.tracking_users:
|
||||||
|
current_score = self.tracking_users[user_id]["current_relationship_score"]
|
||||||
|
if self.interest_scoring_system:
|
||||||
|
self.interest_scoring_system.update_user_relationship(user_id, new_score - current_score)
|
||||||
|
|
||||||
|
update_info = {
|
||||||
|
"user_id": user_id,
|
||||||
|
"new_relationship_score": new_score,
|
||||||
|
"reasoning": reasoning or "手动更新",
|
||||||
|
"interaction_summary": "手动更新关系分",
|
||||||
|
}
|
||||||
|
self.add_to_history(update_info)
|
||||||
|
logger.info(f"强制更新用户关系: {user_id} -> {new_score:.2f}")
|
||||||
|
|
||||||
|
def get_user_summary(self, user_id: str) -> Dict:
|
||||||
|
"""获取用户交互总结"""
|
||||||
|
if user_id not in self.tracking_users:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
interaction = self.tracking_users[user_id]
|
||||||
|
return {
|
||||||
|
"user_id": user_id,
|
||||||
|
"user_name": interaction["user_name"],
|
||||||
|
"current_relationship_score": interaction["current_relationship_score"],
|
||||||
|
"interaction_count": 1, # 简化版本,每次追踪只记录一次交互
|
||||||
|
"last_interaction": interaction["reply_timestamp"],
|
||||||
|
"recent_message": interaction["user_message"][:100] + "..."
|
||||||
|
if len(interaction["user_message"]) > 100
|
||||||
|
else interaction["user_message"],
|
||||||
|
}
|
||||||
|
|
||||||
|
# ===== 数据库支持方法 =====
|
||||||
|
|
||||||
|
def get_user_relationship_score(self, user_id: str) -> float:
|
||||||
|
"""获取用户关系分"""
|
||||||
|
# 先检查缓存
|
||||||
|
if user_id in self.user_relationship_cache:
|
||||||
|
cache_data = self.user_relationship_cache[user_id]
|
||||||
|
# 检查缓存是否过期
|
||||||
|
cache_time = cache_data.get("last_tracked", 0)
|
||||||
|
if time.time() - cache_time < self.cache_expiry_hours * 3600:
|
||||||
|
return cache_data.get("relationship_score", global_config.affinity_flow.base_relationship_score)
|
||||||
|
|
||||||
|
# 缓存过期或不存在,从数据库获取
|
||||||
|
relationship_data = self._get_user_relationship_from_db(user_id)
|
||||||
|
if relationship_data:
|
||||||
|
# 更新缓存
|
||||||
|
self.user_relationship_cache[user_id] = {
|
||||||
|
"relationship_text": relationship_data.get("relationship_text", ""),
|
||||||
|
"relationship_score": relationship_data.get(
|
||||||
|
"relationship_score", global_config.affinity_flow.base_relationship_score
|
||||||
|
),
|
||||||
|
"last_tracked": time.time(),
|
||||||
|
}
|
||||||
|
return relationship_data.get("relationship_score", global_config.affinity_flow.base_relationship_score)
|
||||||
|
|
||||||
|
# 数据库中也没有,返回默认值
|
||||||
|
return global_config.affinity_flow.base_relationship_score
|
||||||
|
|
||||||
|
def _get_user_relationship_from_db(self, user_id: str) -> Optional[Dict]:
|
||||||
|
"""从数据库获取用户关系数据"""
|
||||||
|
try:
|
||||||
|
with get_db_session() as session:
|
||||||
|
# 查询用户关系表
|
||||||
|
stmt = select(UserRelationships).where(UserRelationships.user_id == user_id)
|
||||||
|
result = session.execute(stmt).scalar_one_or_none()
|
||||||
|
|
||||||
|
if result:
|
||||||
|
return {
|
||||||
|
"relationship_text": result.relationship_text or "",
|
||||||
|
"relationship_score": float(result.relationship_score)
|
||||||
|
if result.relationship_score is not None
|
||||||
|
else 0.3,
|
||||||
|
"last_updated": result.last_updated,
|
||||||
|
}
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"从数据库获取用户关系失败: {e}")
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _update_user_relationship_in_db(self, user_id: str, relationship_text: str, relationship_score: float):
|
||||||
|
"""更新数据库中的用户关系"""
|
||||||
|
try:
|
||||||
|
current_time = time.time()
|
||||||
|
|
||||||
|
with get_db_session() as session:
|
||||||
|
# 检查是否已存在关系记录
|
||||||
|
existing = session.execute(
|
||||||
|
select(UserRelationships).where(UserRelationships.user_id == user_id)
|
||||||
|
).scalar_one_or_none()
|
||||||
|
|
||||||
|
if existing:
|
||||||
|
# 更新现有记录
|
||||||
|
existing.relationship_text = relationship_text
|
||||||
|
existing.relationship_score = relationship_score
|
||||||
|
existing.last_updated = current_time
|
||||||
|
existing.user_name = existing.user_name or user_id # 更新用户名如果为空
|
||||||
|
else:
|
||||||
|
# 插入新记录
|
||||||
|
new_relationship = UserRelationships(
|
||||||
|
user_id=user_id,
|
||||||
|
user_name=user_id,
|
||||||
|
relationship_text=relationship_text,
|
||||||
|
relationship_score=relationship_score,
|
||||||
|
last_updated=current_time,
|
||||||
|
)
|
||||||
|
session.add(new_relationship)
|
||||||
|
|
||||||
|
session.commit()
|
||||||
|
logger.info(f"已更新数据库中用户关系: {user_id} -> 分数: {relationship_score:.3f}")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"更新数据库用户关系失败: {e}")
|
||||||
|
|
||||||
|
# ===== 回复后关系追踪方法 =====
|
||||||
|
|
||||||
|
async def track_reply_relationship(
|
||||||
|
self, user_id: str, user_name: str, bot_reply_content: str, reply_timestamp: float
|
||||||
|
):
|
||||||
|
"""回复后关系追踪 - 主要入口点"""
|
||||||
|
try:
|
||||||
|
logger.info(f"🔄 [RelationshipTracker] 开始回复后关系追踪: {user_id}")
|
||||||
|
|
||||||
|
# 检查上次追踪时间
|
||||||
|
last_tracked_time = self._get_last_tracked_time(user_id)
|
||||||
|
time_diff = reply_timestamp - last_tracked_time
|
||||||
|
|
||||||
|
if time_diff < 5 * 60: # 5分钟内不重复追踪
|
||||||
|
logger.debug(
|
||||||
|
f"⏱️ [RelationshipTracker] 用户 {user_id} 距离上次追踪时间不足5分钟 ({time_diff:.2f}s),跳过"
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
# 获取上次bot回复该用户的消息
|
||||||
|
last_bot_reply = await self._get_last_bot_reply_to_user(user_id)
|
||||||
|
if not last_bot_reply:
|
||||||
|
logger.info(f"👋 [RelationshipTracker] 未找到用户 {user_id} 的历史回复记录,启动'初次见面'逻辑")
|
||||||
|
await self._handle_first_interaction(user_id, user_name, bot_reply_content)
|
||||||
|
return
|
||||||
|
|
||||||
|
# 获取用户后续的反应消息
|
||||||
|
user_reactions = await self._get_user_reactions_after_reply(user_id, last_bot_reply.time)
|
||||||
|
logger.debug(f"💬 [RelationshipTracker] 找到用户 {user_id} 在上次回复后的 {len(user_reactions)} 条反应消息")
|
||||||
|
|
||||||
|
# 获取当前关系数据
|
||||||
|
current_relationship = self._get_user_relationship_from_db(user_id)
|
||||||
|
current_score = (
|
||||||
|
current_relationship.get("relationship_score", global_config.affinity_flow.base_relationship_score)
|
||||||
|
if current_relationship
|
||||||
|
else global_config.affinity_flow.base_relationship_score
|
||||||
|
)
|
||||||
|
current_text = current_relationship.get("relationship_text", "新用户") if current_relationship else "新用户"
|
||||||
|
|
||||||
|
# 使用LLM分析并更新关系
|
||||||
|
logger.debug(f"🧠 [RelationshipTracker] 开始为用户 {user_id} 分析并更新关系")
|
||||||
|
await self._analyze_and_update_relationship(
|
||||||
|
user_id, user_name, last_bot_reply, user_reactions, current_text, current_score, bot_reply_content
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"回复后关系追踪失败: {e}")
|
||||||
|
logger.debug("错误详情:", exc_info=True)
|
||||||
|
|
||||||
|
def _get_last_tracked_time(self, user_id: str) -> float:
|
||||||
|
"""获取上次追踪时间"""
|
||||||
|
# 先检查缓存
|
||||||
|
if user_id in self.user_relationship_cache:
|
||||||
|
return self.user_relationship_cache[user_id].get("last_tracked", 0)
|
||||||
|
|
||||||
|
# 从数据库获取
|
||||||
|
relationship_data = self._get_user_relationship_from_db(user_id)
|
||||||
|
if relationship_data:
|
||||||
|
return relationship_data.get("last_updated", 0)
|
||||||
|
|
||||||
|
return 0
|
||||||
|
|
||||||
|
async def _get_last_bot_reply_to_user(self, user_id: str) -> Optional[DatabaseMessages]:
|
||||||
|
"""获取上次bot回复该用户的消息"""
|
||||||
|
try:
|
||||||
|
with get_db_session() as session:
|
||||||
|
# 查询bot回复给该用户的最新消息
|
||||||
|
stmt = (
|
||||||
|
select(Messages)
|
||||||
|
.where(Messages.user_id == user_id)
|
||||||
|
.where(Messages.reply_to.isnot(None))
|
||||||
|
.order_by(desc(Messages.time))
|
||||||
|
.limit(1)
|
||||||
|
)
|
||||||
|
|
||||||
|
result = session.execute(stmt).scalar_one_or_none()
|
||||||
|
if result:
|
||||||
|
# 将SQLAlchemy模型转换为DatabaseMessages对象
|
||||||
|
return self._sqlalchemy_to_database_messages(result)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"获取上次回复消息失败: {e}")
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def _get_user_reactions_after_reply(self, user_id: str, reply_time: float) -> List[DatabaseMessages]:
|
||||||
|
"""获取用户在bot回复后的反应消息"""
|
||||||
|
try:
|
||||||
|
with get_db_session() as session:
|
||||||
|
# 查询用户在回复时间之后的5分钟内的消息
|
||||||
|
end_time = reply_time + 5 * 60 # 5分钟
|
||||||
|
|
||||||
|
stmt = (
|
||||||
|
select(Messages)
|
||||||
|
.where(Messages.user_id == user_id)
|
||||||
|
.where(Messages.time > reply_time)
|
||||||
|
.where(Messages.time <= end_time)
|
||||||
|
.order_by(Messages.time)
|
||||||
|
)
|
||||||
|
|
||||||
|
results = session.execute(stmt).scalars().all()
|
||||||
|
if results:
|
||||||
|
return [self._sqlalchemy_to_database_messages(result) for result in results]
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"获取用户反应消息失败: {e}")
|
||||||
|
|
||||||
|
return []
|
||||||
|
|
||||||
|
def _sqlalchemy_to_database_messages(self, sqlalchemy_message) -> DatabaseMessages:
|
||||||
|
"""将SQLAlchemy消息模型转换为DatabaseMessages对象"""
|
||||||
|
try:
|
||||||
|
return DatabaseMessages(
|
||||||
|
message_id=sqlalchemy_message.message_id or "",
|
||||||
|
time=float(sqlalchemy_message.time) if sqlalchemy_message.time is not None else 0.0,
|
||||||
|
chat_id=sqlalchemy_message.chat_id or "",
|
||||||
|
reply_to=sqlalchemy_message.reply_to,
|
||||||
|
processed_plain_text=sqlalchemy_message.processed_plain_text or "",
|
||||||
|
user_id=sqlalchemy_message.user_id or "",
|
||||||
|
user_nickname=sqlalchemy_message.user_nickname or "",
|
||||||
|
user_platform=sqlalchemy_message.user_platform or "",
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"SQLAlchemy消息转换失败: {e}")
|
||||||
|
# 返回一个基本的消息对象
|
||||||
|
return DatabaseMessages(
|
||||||
|
message_id="",
|
||||||
|
time=0.0,
|
||||||
|
chat_id="",
|
||||||
|
processed_plain_text="",
|
||||||
|
user_id="",
|
||||||
|
user_nickname="",
|
||||||
|
user_platform="",
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _analyze_and_update_relationship(
|
||||||
|
self,
|
||||||
|
user_id: str,
|
||||||
|
user_name: str,
|
||||||
|
last_bot_reply: DatabaseMessages,
|
||||||
|
user_reactions: List[DatabaseMessages],
|
||||||
|
current_text: str,
|
||||||
|
current_score: float,
|
||||||
|
current_reply: str,
|
||||||
|
):
|
||||||
|
"""使用LLM分析并更新用户关系"""
|
||||||
|
try:
|
||||||
|
# 构建分析提示
|
||||||
|
user_reactions_text = "\n".join([f"- {msg.processed_plain_text}" for msg in user_reactions])
|
||||||
|
|
||||||
|
# 获取bot人设信息
|
||||||
|
from src.individuality.individuality import Individuality
|
||||||
|
|
||||||
|
individuality = Individuality()
|
||||||
|
bot_personality = await individuality.get_personality_block()
|
||||||
|
|
||||||
|
prompt = f"""
|
||||||
|
你现在是一个有着特定性格和身份的AI助手。你的人设是:{bot_personality}
|
||||||
|
|
||||||
|
请以你独特的性格视角,严格按现实逻辑分析以下用户交互,更新用户关系印象和分数:
|
||||||
|
|
||||||
|
用户信息:
|
||||||
|
- 用户ID: {user_id}
|
||||||
|
- 用户名: {user_name}
|
||||||
|
|
||||||
|
你上次的回复: {last_bot_reply.processed_plain_text}
|
||||||
|
|
||||||
|
用户反应消息:
|
||||||
|
{user_reactions_text}
|
||||||
|
|
||||||
|
你当前的回复: {current_reply}
|
||||||
|
|
||||||
|
当前关系印象: {current_text}
|
||||||
|
当前关系分数: {current_score:.3f}
|
||||||
|
|
||||||
|
【重要】关系分数档次定义:
|
||||||
|
- 0.0-0.2:陌生人/初次认识 - 仅礼貌性交流
|
||||||
|
- 0.2-0.4:普通网友 - 有基本互动但不熟悉
|
||||||
|
- 0.4-0.6:熟悉网友 - 经常交流,有一定了解
|
||||||
|
- 0.6-0.8:朋友 - 可以分享心情,互相关心
|
||||||
|
- 0.8-1.0:好朋友/知己 - 深度信任,亲密无间
|
||||||
|
|
||||||
|
【严格要求】:
|
||||||
|
1. 加分必须符合现实关系发展逻辑 - 不能因为用户反应好就盲目加分
|
||||||
|
2. 关系提升需要足够的互动积累和时间验证,单次互动加分通常不超过0.05-0.1
|
||||||
|
3. 必须考虑当前关系档次,不能跳跃式提升(比如从0.3直接到0.7)
|
||||||
|
4. 关系印象描述要详细具体(100-200字),包括:
|
||||||
|
- 用户性格特点和交流风格观察
|
||||||
|
- 印象深刻的互动记忆和对话片段
|
||||||
|
- 你们关系的具体状态描述和发展阶段
|
||||||
|
- 根据你的性格,你对用户的真实感受
|
||||||
|
|
||||||
|
性格视角深度分析:
|
||||||
|
1. 以你的性格特点,用户这次的反应给你什么感受?
|
||||||
|
2. 用户的情绪和行为符合你性格的喜好吗?具体哪些方面?
|
||||||
|
3. 从现实角度看,这次互动是否足以让关系提升到下一个档次?为什么?
|
||||||
|
4. 有什么特别值得记住的互动细节或对话内容?
|
||||||
|
5. 基于你们的互动历史,用户给你留下了哪些深刻印象?
|
||||||
|
|
||||||
|
请以JSON格式返回更新结果:
|
||||||
|
{{
|
||||||
|
"relationship_text": "详细的关系印象描述(100-200字),包含用户性格观察、印象深刻记忆、关系状态描述",
|
||||||
|
"relationship_score": 0.0~1.0的新分数(必须严格符合现实逻辑),
|
||||||
|
"analysis_reasoning": "从你性格角度的深度分析,重点说明分数调整的现实合理性",
|
||||||
|
"interaction_quality": "high/medium/low"
|
||||||
|
}}
|
||||||
|
"""
|
||||||
|
|
||||||
|
# 调用LLM进行分析
|
||||||
|
llm_response, _ = await self.relationship_llm.generate_response_async(prompt=prompt)
|
||||||
|
|
||||||
|
if llm_response:
|
||||||
|
import json
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 清理LLM响应,移除可能的格式标记
|
||||||
|
cleaned_response = self._clean_llm_json_response(llm_response)
|
||||||
|
response_data = json.loads(cleaned_response)
|
||||||
|
|
||||||
|
new_text = response_data.get("relationship_text", current_text)
|
||||||
|
new_score = max(0.0, min(1.0, float(response_data.get("relationship_score", current_score))))
|
||||||
|
reasoning = response_data.get("analysis_reasoning", "")
|
||||||
|
quality = response_data.get("interaction_quality", "medium")
|
||||||
|
|
||||||
|
# 更新数据库
|
||||||
|
self._update_user_relationship_in_db(user_id, new_text, new_score)
|
||||||
|
|
||||||
|
# 更新缓存
|
||||||
|
self.user_relationship_cache[user_id] = {
|
||||||
|
"relationship_text": new_text,
|
||||||
|
"relationship_score": new_score,
|
||||||
|
"last_tracked": time.time(),
|
||||||
|
}
|
||||||
|
|
||||||
|
# 如果有兴趣度评分系统,也更新内存中的关系分
|
||||||
|
if self.interest_scoring_system:
|
||||||
|
self.interest_scoring_system.update_user_relationship(user_id, new_score - current_score)
|
||||||
|
|
||||||
|
# 记录分析历史
|
||||||
|
analysis_record = {
|
||||||
|
"user_id": user_id,
|
||||||
|
"timestamp": time.time(),
|
||||||
|
"old_score": current_score,
|
||||||
|
"new_score": new_score,
|
||||||
|
"old_text": current_text,
|
||||||
|
"new_text": new_text,
|
||||||
|
"reasoning": reasoning,
|
||||||
|
"quality": quality,
|
||||||
|
"user_reactions_count": len(user_reactions),
|
||||||
|
}
|
||||||
|
self.relationship_history.append(analysis_record)
|
||||||
|
|
||||||
|
# 限制历史记录数量
|
||||||
|
if len(self.relationship_history) > 100:
|
||||||
|
self.relationship_history = self.relationship_history[-100:]
|
||||||
|
|
||||||
|
logger.info(f"✅ 关系分析完成: {user_id}")
|
||||||
|
logger.info(f" 📝 印象: '{current_text}' -> '{new_text}'")
|
||||||
|
logger.info(f" 💝 分数: {current_score:.3f} -> {new_score:.3f}")
|
||||||
|
logger.info(f" 🎯 质量: {quality}")
|
||||||
|
|
||||||
|
except json.JSONDecodeError as e:
|
||||||
|
logger.error(f"LLM响应JSON解析失败: {e}")
|
||||||
|
logger.debug(f"LLM原始响应: {llm_response}")
|
||||||
|
else:
|
||||||
|
logger.warning("LLM未返回有效响应")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"关系分析失败: {e}")
|
||||||
|
logger.debug("错误详情:", exc_info=True)
|
||||||
|
|
||||||
|
async def _handle_first_interaction(self, user_id: str, user_name: str, bot_reply_content: str):
|
||||||
|
"""处理与用户的初次交互"""
|
||||||
|
try:
|
||||||
|
logger.info(f"✨ [RelationshipTracker] 正在处理与用户 {user_id} 的初次交互")
|
||||||
|
|
||||||
|
# 获取bot人设信息
|
||||||
|
from src.individuality.individuality import Individuality
|
||||||
|
|
||||||
|
individuality = Individuality()
|
||||||
|
bot_personality = await individuality.get_personality_block()
|
||||||
|
|
||||||
|
prompt = f"""
|
||||||
|
你现在是:{bot_personality}
|
||||||
|
|
||||||
|
你正在与一个新用户进行初次有效互动。请根据你对TA的第一印象,建立初始关系档案。
|
||||||
|
|
||||||
|
用户信息:
|
||||||
|
- 用户ID: {user_id}
|
||||||
|
- 用户名: {user_name}
|
||||||
|
|
||||||
|
你的首次回复: {bot_reply_content}
|
||||||
|
|
||||||
|
【严格要求】:
|
||||||
|
1. 建立一个初始关系分数,通常在0.2-0.4之间(普通网友)。
|
||||||
|
2. 关系印象描述要简洁地记录你对用户的初步看法(50-100字)。
|
||||||
|
- 用户名给你的感觉?
|
||||||
|
- 你的回复是基于什么考虑?
|
||||||
|
- 你对接下来与TA的互动有什么期待?
|
||||||
|
|
||||||
|
请以JSON格式返回结果:
|
||||||
|
{{
|
||||||
|
"relationship_text": "简洁的初始关系印象描述(50-100字)",
|
||||||
|
"relationship_score": 0.2~0.4的新分数,
|
||||||
|
"analysis_reasoning": "从你性格角度说明建立此初始印象的理由"
|
||||||
|
}}
|
||||||
|
"""
|
||||||
|
# 调用LLM进行分析
|
||||||
|
llm_response, _ = await self.relationship_llm.generate_response_async(prompt=prompt)
|
||||||
|
if not llm_response:
|
||||||
|
logger.warning(f"初次交互分析时LLM未返回有效响应: {user_id}")
|
||||||
|
return
|
||||||
|
|
||||||
|
import json
|
||||||
|
|
||||||
|
cleaned_response = self._clean_llm_json_response(llm_response)
|
||||||
|
response_data = json.loads(cleaned_response)
|
||||||
|
|
||||||
|
new_text = response_data.get("relationship_text", "初次见面")
|
||||||
|
new_score = max(
|
||||||
|
0.0,
|
||||||
|
min(
|
||||||
|
1.0,
|
||||||
|
float(response_data.get("relationship_score", global_config.affinity_flow.base_relationship_score)),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
# 更新数据库和缓存
|
||||||
|
self._update_user_relationship_in_db(user_id, new_text, new_score)
|
||||||
|
self.user_relationship_cache[user_id] = {
|
||||||
|
"relationship_text": new_text,
|
||||||
|
"relationship_score": new_score,
|
||||||
|
"last_tracked": time.time(),
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.info(f"✅ [RelationshipTracker] 已成功为新用户 {user_id} 建立初始关系档案,分数为 {new_score:.3f}")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"处理初次交互失败: {user_id}, 错误: {e}")
|
||||||
|
logger.debug("错误详情:", exc_info=True)
|
||||||
|
|
||||||
|
def _clean_llm_json_response(self, response: str) -> str:
|
||||||
|
"""
|
||||||
|
清理LLM响应,移除可能的JSON格式标记
|
||||||
|
|
||||||
|
Args:
|
||||||
|
response: LLM原始响应
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
清理后的JSON字符串
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
import re
|
||||||
|
|
||||||
|
# 移除常见的JSON格式标记
|
||||||
|
cleaned = response.strip()
|
||||||
|
|
||||||
|
# 移除 ```json 或 ``` 等标记
|
||||||
|
cleaned = re.sub(r"^```(?:json)?\s*", "", cleaned, flags=re.MULTILINE | re.IGNORECASE)
|
||||||
|
cleaned = re.sub(r"\s*```$", "", cleaned, flags=re.MULTILINE)
|
||||||
|
|
||||||
|
# 移除可能的Markdown代码块标记
|
||||||
|
cleaned = re.sub(r"^`|`$", "", cleaned, flags=re.MULTILINE)
|
||||||
|
|
||||||
|
# 尝试找到JSON对象的开始和结束
|
||||||
|
json_start = cleaned.find("{")
|
||||||
|
json_end = cleaned.rfind("}")
|
||||||
|
|
||||||
|
if json_start != -1 and json_end != -1 and json_end > json_start:
|
||||||
|
# 提取JSON部分
|
||||||
|
cleaned = cleaned[json_start : json_end + 1]
|
||||||
|
|
||||||
|
# 移除多余的空白字符
|
||||||
|
cleaned = cleaned.strip()
|
||||||
|
|
||||||
|
logger.debug(f"LLM响应清理: 原始长度={len(response)}, 清理后长度={len(cleaned)}")
|
||||||
|
if cleaned != response:
|
||||||
|
logger.debug(f"清理前: {response[:200]}...")
|
||||||
|
logger.debug(f"清理后: {cleaned[:200]}...")
|
||||||
|
|
||||||
|
return cleaned
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"清理LLM响应失败: {e}")
|
||||||
|
return response # 清理失败时返回原始响应
|
||||||
@@ -85,7 +85,7 @@ class AtAction(BaseAction):
|
|||||||
reply_to=reply_to,
|
reply_to=reply_to,
|
||||||
extra_info=extra_info,
|
extra_info=extra_info,
|
||||||
enable_tool=False, # 艾特回复通常不需要工具调用
|
enable_tool=False, # 艾特回复通常不需要工具调用
|
||||||
from_plugin=False
|
from_plugin=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
if success and llm_response:
|
if success and llm_response:
|
||||||
|
|||||||
@@ -27,7 +27,7 @@
|
|||||||
{
|
{
|
||||||
"type": "action",
|
"type": "action",
|
||||||
"name": "emoji",
|
"name": "emoji",
|
||||||
"description": "发送表情包辅助表达情绪"
|
"description": "作为一条全新的消息,发送一个符合当前情景的表情包来生动地表达情绪。"
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user