Merge branch 'dev' into dev-api-ada to resolve conflicts
This commit is contained in:
@@ -2,7 +2,7 @@ import asyncio
|
||||
import time
|
||||
import traceback
|
||||
import random
|
||||
from typing import List, Optional, Dict, Any
|
||||
from typing import List, Optional, Dict, Any, Tuple
|
||||
from rich.traceback import install
|
||||
|
||||
from src.config.config import global_config
|
||||
@@ -18,11 +18,12 @@ from src.chat.chat_loop.hfc_utils import CycleDetail
|
||||
from src.person_info.relationship_builder_manager import relationship_builder_manager
|
||||
from src.person_info.person_info import get_person_info_manager
|
||||
from src.plugin_system.base.component_types import ActionInfo, ChatMode
|
||||
from src.plugin_system.apis import generator_api, send_api, message_api
|
||||
from src.plugin_system.apis import generator_api, send_api, message_api, database_api
|
||||
from src.chat.willing.willing_manager import get_willing_manager
|
||||
from src.mais4u.mai_think import mai_thinking_manager
|
||||
from maim_message.message_base import GroupInfo
|
||||
from src.mais4u.constant_s4u import ENABLE_S4U
|
||||
from src.plugins.built_in.core_actions.no_reply import NoReplyAction
|
||||
from src.chat.chat_loop.hfc_utils import send_typing, stop_typing
|
||||
|
||||
ERROR_LOOP_INFO = {
|
||||
"loop_plan_info": {
|
||||
@@ -88,11 +89,6 @@ class HeartFChatting:
|
||||
|
||||
self.loop_mode = ChatMode.NORMAL # 初始循环模式为普通模式
|
||||
|
||||
# 新增:消息计数器和疲惫阈值
|
||||
self._message_count = 0 # 发送的消息计数
|
||||
self._message_threshold = max(10, int(30 * global_config.chat.focus_value))
|
||||
self._fatigue_triggered = False # 是否已触发疲惫退出
|
||||
|
||||
self.action_manager = ActionManager()
|
||||
self.action_planner = ActionPlanner(chat_id=self.stream_id, action_manager=self.action_manager)
|
||||
self.action_modifier = ActionModifier(action_manager=self.action_manager, chat_id=self.stream_id)
|
||||
@@ -112,7 +108,6 @@ class HeartFChatting:
|
||||
|
||||
self.last_read_time = time.time() - 1
|
||||
|
||||
self.willing_amplifier = 1
|
||||
self.willing_manager = get_willing_manager()
|
||||
|
||||
logger.info(f"{self.log_prefix} HeartFChatting 初始化完成")
|
||||
@@ -182,6 +177,9 @@ class HeartFChatting:
|
||||
if self.loop_mode == ChatMode.NORMAL:
|
||||
self.energy_value -= 0.3
|
||||
self.energy_value = max(self.energy_value, 0.3)
|
||||
if self.loop_mode == ChatMode.FOCUS:
|
||||
self.energy_value -= 0.6
|
||||
self.energy_value = max(self.energy_value, 0.3)
|
||||
|
||||
def print_cycle_info(self, cycle_timers):
|
||||
# 记录循环信息和计时器结果
|
||||
@@ -200,9 +198,9 @@ class HeartFChatting:
|
||||
async def _loopbody(self):
|
||||
if self.loop_mode == ChatMode.FOCUS:
|
||||
if await self._observe():
|
||||
self.energy_value -= 1 * global_config.chat.focus_value
|
||||
self.energy_value -= 1 / global_config.chat.focus_value
|
||||
else:
|
||||
self.energy_value -= 3 * global_config.chat.focus_value
|
||||
self.energy_value -= 3 / global_config.chat.focus_value
|
||||
if self.energy_value <= 1:
|
||||
self.energy_value = 1
|
||||
self.loop_mode = ChatMode.NORMAL
|
||||
@@ -218,15 +216,17 @@ class HeartFChatting:
|
||||
limit_mode="earliest",
|
||||
filter_bot=True,
|
||||
)
|
||||
if global_config.chat.focus_value != 0:
|
||||
if len(new_messages_data) > 3 / pow(global_config.chat.focus_value, 0.5):
|
||||
self.loop_mode = ChatMode.FOCUS
|
||||
self.energy_value = (
|
||||
10 + (len(new_messages_data) / (3 / pow(global_config.chat.focus_value, 0.5))) * 10
|
||||
)
|
||||
return True
|
||||
|
||||
if len(new_messages_data) > 3 * global_config.chat.focus_value:
|
||||
self.loop_mode = ChatMode.FOCUS
|
||||
self.energy_value = 10 + (len(new_messages_data) / (3 * global_config.chat.focus_value)) * 10
|
||||
return True
|
||||
|
||||
if self.energy_value >= 30 * global_config.chat.focus_value:
|
||||
self.loop_mode = ChatMode.FOCUS
|
||||
return True
|
||||
if self.energy_value >= 30:
|
||||
self.loop_mode = ChatMode.FOCUS
|
||||
return True
|
||||
|
||||
if new_messages_data:
|
||||
earliest_messages_data = new_messages_data[0]
|
||||
@@ -235,10 +235,10 @@ class HeartFChatting:
|
||||
if_think = await self.normal_response(earliest_messages_data)
|
||||
if if_think:
|
||||
factor = max(global_config.chat.focus_value, 0.1)
|
||||
self.energy_value *= 1.1 / factor
|
||||
self.energy_value *= 1.1 * factor
|
||||
logger.info(f"{self.log_prefix} 进行了思考,能量值按倍数增加,当前能量值:{self.energy_value:.1f}")
|
||||
else:
|
||||
self.energy_value += 0.1 / global_config.chat.focus_value
|
||||
self.energy_value += 0.1 * global_config.chat.focus_value
|
||||
logger.debug(f"{self.log_prefix} 没有进行思考,能量值线性增加,当前能量值:{self.energy_value:.1f}")
|
||||
|
||||
logger.debug(f"{self.log_prefix} 当前能量值:{self.energy_value:.1f}")
|
||||
@@ -257,44 +257,69 @@ class HeartFChatting:
|
||||
person_name = await person_info_manager.get_value(person_id, "person_name")
|
||||
return f"{person_name}:{message_data.get('processed_plain_text')}"
|
||||
|
||||
async def send_typing(self):
|
||||
group_info = GroupInfo(platform="amaidesu_default", group_id="114514", group_name="内心")
|
||||
async def _send_and_store_reply(
|
||||
self,
|
||||
response_set,
|
||||
reply_to_str,
|
||||
loop_start_time,
|
||||
action_message,
|
||||
cycle_timers: Dict[str, float],
|
||||
thinking_id,
|
||||
plan_result,
|
||||
) -> Tuple[Dict[str, Any], str, Dict[str, float]]:
|
||||
with Timer("回复发送", cycle_timers):
|
||||
reply_text = await self._send_response(response_set, reply_to_str, loop_start_time, action_message)
|
||||
|
||||
chat = await get_chat_manager().get_or_create_stream(
|
||||
platform="amaidesu_default",
|
||||
user_info=None,
|
||||
group_info=group_info,
|
||||
# 存储reply action信息
|
||||
person_info_manager = get_person_info_manager()
|
||||
person_id = person_info_manager.get_person_id(
|
||||
action_message.get("chat_info_platform", ""),
|
||||
action_message.get("user_id", ""),
|
||||
)
|
||||
person_name = await person_info_manager.get_value(person_id, "person_name")
|
||||
action_prompt_display = f"你对{person_name}进行了回复:{reply_text}"
|
||||
|
||||
await database_api.store_action_info(
|
||||
chat_stream=self.chat_stream,
|
||||
action_build_into_prompt=False,
|
||||
action_prompt_display=action_prompt_display,
|
||||
action_done=True,
|
||||
thinking_id=thinking_id,
|
||||
action_data={"reply_text": reply_text, "reply_to": reply_to_str},
|
||||
action_name="reply",
|
||||
)
|
||||
|
||||
await send_api.custom_to_stream(
|
||||
message_type="state", content="typing", stream_id=chat.stream_id, storage_message=False
|
||||
)
|
||||
# 构建循环信息
|
||||
loop_info: Dict[str, Any] = {
|
||||
"loop_plan_info": {
|
||||
"action_result": plan_result.get("action_result", {}),
|
||||
},
|
||||
"loop_action_info": {
|
||||
"action_taken": True,
|
||||
"reply_text": reply_text,
|
||||
"command": "",
|
||||
"taken_time": time.time(),
|
||||
},
|
||||
}
|
||||
|
||||
async def stop_typing(self):
|
||||
group_info = GroupInfo(platform="amaidesu_default", group_id="114514", group_name="内心")
|
||||
|
||||
chat = await get_chat_manager().get_or_create_stream(
|
||||
platform="amaidesu_default",
|
||||
user_info=None,
|
||||
group_info=group_info,
|
||||
)
|
||||
|
||||
await send_api.custom_to_stream(
|
||||
message_type="state", content="stop_typing", stream_id=chat.stream_id, storage_message=False
|
||||
)
|
||||
return loop_info, reply_text, cycle_timers
|
||||
|
||||
async def _observe(self, message_data: Optional[Dict[str, Any]] = None):
|
||||
# sourcery skip: hoist-statement-from-if, merge-comparisons, reintroduce-else
|
||||
if not message_data:
|
||||
message_data = {}
|
||||
action_type = "no_action"
|
||||
reply_text = "" # 初始化reply_text变量,避免UnboundLocalError
|
||||
gen_task = None # 初始化gen_task变量,避免UnboundLocalError
|
||||
reply_to_str = "" # 初始化reply_to_str变量
|
||||
|
||||
# 创建新的循环信息
|
||||
cycle_timers, thinking_id = self.start_cycle()
|
||||
|
||||
logger.info(f"{self.log_prefix} 开始第{self._cycle_counter}次思考[模式:{self.loop_mode}]")
|
||||
|
||||
|
||||
if ENABLE_S4U:
|
||||
await self.send_typing()
|
||||
await send_typing()
|
||||
|
||||
async with global_prompt_manager.async_message_scope(self.chat_stream.context.get_template_name()):
|
||||
loop_start_time = time.time()
|
||||
@@ -310,95 +335,254 @@ class HeartFChatting:
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 动作修改失败: {e}")
|
||||
|
||||
# 如果normal,开始一个回复生成进程,先准备好回复(其实是和planer同时进行的)
|
||||
# 检查是否在normal模式下没有可用动作(除了reply相关动作)
|
||||
skip_planner = False
|
||||
if self.loop_mode == ChatMode.NORMAL:
|
||||
reply_to_str = await self.build_reply_to_str(message_data)
|
||||
gen_task = asyncio.create_task(self._generate_response(message_data, available_actions, reply_to_str))
|
||||
# 过滤掉reply相关的动作,检查是否还有其他动作
|
||||
non_reply_actions = {
|
||||
k: v for k, v in available_actions.items() if k not in ["reply", "no_reply", "no_action"]
|
||||
}
|
||||
|
||||
with Timer("规划器", cycle_timers):
|
||||
plan_result, target_message = await self.action_planner.plan(mode=self.loop_mode)
|
||||
if not non_reply_actions:
|
||||
skip_planner = True
|
||||
logger.info(f"{self.log_prefix} Normal模式下没有可用动作,直接回复")
|
||||
|
||||
action_result: dict = plan_result.get("action_result", {}) # type: ignore
|
||||
action_type, action_data, reasoning, is_parallel = (
|
||||
action_result.get("action_type", "error"),
|
||||
action_result.get("action_data", {}),
|
||||
action_result.get("reasoning", "未提供理由"),
|
||||
action_result.get("is_parallel", True),
|
||||
)
|
||||
# 直接设置为reply动作
|
||||
action_type = "reply"
|
||||
reasoning = ""
|
||||
action_data = {"loop_start_time": loop_start_time}
|
||||
is_parallel = False
|
||||
|
||||
action_data["loop_start_time"] = loop_start_time
|
||||
# 构建plan_result用于后续处理
|
||||
plan_result = {
|
||||
"action_result": {
|
||||
"action_type": action_type,
|
||||
"action_data": action_data,
|
||||
"reasoning": reasoning,
|
||||
"timestamp": time.time(),
|
||||
"is_parallel": is_parallel,
|
||||
},
|
||||
"action_prompt": "",
|
||||
}
|
||||
target_message = message_data
|
||||
|
||||
if self.loop_mode == ChatMode.NORMAL:
|
||||
if action_type == "no_action":
|
||||
logger.info(f"[{self.log_prefix}] {global_config.bot.nickname} 决定进行回复")
|
||||
elif is_parallel:
|
||||
logger.info(
|
||||
f"[{self.log_prefix}] {global_config.bot.nickname} 决定进行回复, 同时执行{action_type}动作"
|
||||
# 如果normal模式且不跳过规划器,开始一个回复生成进程,先准备好回复(其实是和planer同时进行的)
|
||||
if not skip_planner:
|
||||
reply_to_str = await self.build_reply_to_str(message_data)
|
||||
gen_task = asyncio.create_task(
|
||||
self._generate_response(
|
||||
message_data=message_data,
|
||||
available_actions=available_actions,
|
||||
reply_to=reply_to_str,
|
||||
request_type="chat.replyer.normal",
|
||||
)
|
||||
)
|
||||
else:
|
||||
logger.info(f"[{self.log_prefix}] {global_config.bot.nickname} 决定执行{action_type}动作")
|
||||
|
||||
if action_type == "no_action":
|
||||
if not skip_planner:
|
||||
with Timer("规划器", cycle_timers):
|
||||
plan_result, target_message = await self.action_planner.plan(mode=self.loop_mode)
|
||||
|
||||
action_result: Dict[str, Any] = plan_result.get("action_result", {}) # type: ignore
|
||||
action_type, action_data, reasoning, is_parallel = (
|
||||
action_result.get("action_type", "error"),
|
||||
action_result.get("action_data", {}),
|
||||
action_result.get("reasoning", "未提供理由"),
|
||||
action_result.get("is_parallel", True),
|
||||
)
|
||||
|
||||
action_data["loop_start_time"] = loop_start_time
|
||||
|
||||
if action_type == "reply":
|
||||
logger.info(f"{self.log_prefix}{global_config.bot.nickname} 决定进行回复")
|
||||
elif is_parallel:
|
||||
logger.info(f"{self.log_prefix}{global_config.bot.nickname} 决定进行回复, 同时执行{action_type}动作")
|
||||
else:
|
||||
# 只有在gen_task存在时才进行相关操作
|
||||
if gen_task:
|
||||
if not gen_task.done():
|
||||
gen_task.cancel()
|
||||
logger.debug(f"{self.log_prefix} 已取消预生成的回复任务")
|
||||
logger.info(
|
||||
f"{self.log_prefix}{global_config.bot.nickname} 原本想要回复,但选择执行{action_type},不发表回复"
|
||||
)
|
||||
elif generation_result := gen_task.result():
|
||||
content = " ".join([item[1] for item in generation_result if item[0] == "text"])
|
||||
logger.debug(f"{self.log_prefix} 预生成的回复任务已完成")
|
||||
logger.info(
|
||||
f"{self.log_prefix}{global_config.bot.nickname} 原本想要回复:{content},但选择执行{action_type},不发表回复"
|
||||
)
|
||||
else:
|
||||
logger.warning(f"{self.log_prefix} 预生成的回复任务未生成有效内容")
|
||||
|
||||
action_message: Dict[str, Any] = message_data or target_message # type: ignore
|
||||
if action_type == "reply":
|
||||
# 等待回复生成完毕
|
||||
gather_timeout = global_config.chat.thinking_timeout
|
||||
try:
|
||||
response_set = await asyncio.wait_for(gen_task, timeout=gather_timeout)
|
||||
except asyncio.TimeoutError:
|
||||
response_set = None
|
||||
if self.loop_mode == ChatMode.NORMAL:
|
||||
# 只有在gen_task存在时才等待
|
||||
if not gen_task:
|
||||
reply_to_str = await self.build_reply_to_str(message_data)
|
||||
gen_task = asyncio.create_task(
|
||||
self._generate_response(
|
||||
message_data=message_data,
|
||||
available_actions=available_actions,
|
||||
reply_to=reply_to_str,
|
||||
request_type="chat.replyer.normal",
|
||||
)
|
||||
)
|
||||
|
||||
if response_set:
|
||||
content = " ".join([item[1] for item in response_set if item[0] == "text"])
|
||||
gather_timeout = global_config.chat.thinking_timeout
|
||||
try:
|
||||
response_set = await asyncio.wait_for(gen_task, timeout=gather_timeout)
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning(f"{self.log_prefix} 回复生成超时>{global_config.chat.thinking_timeout}s,已跳过")
|
||||
response_set = None
|
||||
|
||||
# 模型炸了,没有回复内容生成
|
||||
if not response_set:
|
||||
logger.warning(f"[{self.log_prefix}] 模型未生成回复内容")
|
||||
return False
|
||||
elif action_type not in ["no_action"] and not is_parallel:
|
||||
logger.info(
|
||||
f"[{self.log_prefix}] {global_config.bot.nickname} 原本想要回复:{content},但选择执行{action_type},不发表回复"
|
||||
)
|
||||
return False
|
||||
# 模型炸了或超时,没有回复内容生成
|
||||
if not response_set:
|
||||
logger.warning(f"{self.log_prefix}模型未生成回复内容")
|
||||
return False
|
||||
else:
|
||||
logger.info(f"{self.log_prefix}{global_config.bot.nickname} 决定进行回复 (focus模式)")
|
||||
|
||||
logger.info(f"[{self.log_prefix}] {global_config.bot.nickname} 决定的回复内容: {content}")
|
||||
# 构建reply_to字符串
|
||||
reply_to_str = await self.build_reply_to_str(action_message)
|
||||
|
||||
# 发送回复 (不再需要传入 chat)
|
||||
reply_text = await self._send_response(response_set, reply_to_str, loop_start_time,message_data)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
if ENABLE_S4U:
|
||||
await self.stop_typing()
|
||||
await mai_thinking_manager.get_mai_think(self.stream_id).do_think_after_response(reply_text)
|
||||
# 生成回复
|
||||
with Timer("回复生成", cycle_timers):
|
||||
response_set = await self._generate_response(
|
||||
message_data=action_message,
|
||||
available_actions=available_actions,
|
||||
reply_to=reply_to_str,
|
||||
request_type="chat.replyer.focus",
|
||||
)
|
||||
|
||||
if not response_set:
|
||||
logger.warning(f"{self.log_prefix}模型未生成回复内容")
|
||||
return False
|
||||
|
||||
loop_info, reply_text, cycle_timers = await self._send_and_store_reply(
|
||||
response_set, reply_to_str, loop_start_time, action_message, cycle_timers, thinking_id, plan_result
|
||||
)
|
||||
|
||||
return True
|
||||
|
||||
else:
|
||||
action_message: Dict[str, Any] = message_data or target_message # type: ignore
|
||||
# 并行执行:同时进行回复发送和动作执行
|
||||
# 先置空防止未定义错误
|
||||
background_reply_task = None
|
||||
background_action_task = None
|
||||
# 如果是并行执行且在normal模式下,需要等待预生成的回复任务完成并发送回复
|
||||
if self.loop_mode == ChatMode.NORMAL and is_parallel and gen_task:
|
||||
|
||||
# 动作执行计时
|
||||
with Timer("动作执行", cycle_timers):
|
||||
success, reply_text, command = await self._handle_action(
|
||||
action_type, reasoning, action_data, cycle_timers, thinking_id, action_message
|
||||
async def handle_reply_task() -> Tuple[Optional[Dict[str, Any]], str, Dict[str, float]]:
|
||||
# 等待预生成的回复任务完成
|
||||
gather_timeout = global_config.chat.thinking_timeout
|
||||
try:
|
||||
response_set = await asyncio.wait_for(gen_task, timeout=gather_timeout)
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning(
|
||||
f"{self.log_prefix} 并行执行:回复生成超时>{global_config.chat.thinking_timeout}s,已跳过"
|
||||
)
|
||||
return None, "", {}
|
||||
except asyncio.CancelledError:
|
||||
logger.debug(f"{self.log_prefix} 并行执行:回复生成任务已被取消")
|
||||
return None, "", {}
|
||||
|
||||
if not response_set:
|
||||
logger.warning(f"{self.log_prefix} 模型超时或生成回复内容为空")
|
||||
return None, "", {}
|
||||
|
||||
reply_to_str = await self.build_reply_to_str(action_message)
|
||||
loop_info, reply_text, cycle_timers_reply = await self._send_and_store_reply(
|
||||
response_set,
|
||||
reply_to_str,
|
||||
loop_start_time,
|
||||
action_message,
|
||||
cycle_timers,
|
||||
thinking_id,
|
||||
plan_result,
|
||||
)
|
||||
return loop_info, reply_text, cycle_timers_reply
|
||||
|
||||
# 执行回复任务并赋值到变量
|
||||
background_reply_task = asyncio.create_task(handle_reply_task())
|
||||
|
||||
# 动作执行任务
|
||||
async def handle_action_task():
|
||||
with Timer("动作执行", cycle_timers):
|
||||
success, reply_text, command = await self._handle_action(
|
||||
action_type, reasoning, action_data, cycle_timers, thinking_id, action_message
|
||||
)
|
||||
return success, reply_text, command
|
||||
|
||||
# 执行动作任务并赋值到变量
|
||||
background_action_task = asyncio.create_task(handle_action_task())
|
||||
|
||||
reply_loop_info = None
|
||||
reply_text_from_reply = ""
|
||||
action_success = False
|
||||
action_reply_text = ""
|
||||
action_command = ""
|
||||
|
||||
# 并行执行所有任务
|
||||
if background_reply_task:
|
||||
results = await asyncio.gather(
|
||||
background_reply_task, background_action_task, return_exceptions=True
|
||||
)
|
||||
# 处理回复任务结果
|
||||
reply_result = results[0]
|
||||
if isinstance(reply_result, BaseException):
|
||||
logger.error(f"{self.log_prefix} 回复任务执行异常: {reply_result}")
|
||||
elif reply_result and reply_result[0] is not None:
|
||||
reply_loop_info, reply_text_from_reply, _ = reply_result
|
||||
|
||||
loop_info = {
|
||||
"loop_plan_info": {
|
||||
"action_result": plan_result.get("action_result", {}),
|
||||
},
|
||||
"loop_action_info": {
|
||||
"action_taken": success,
|
||||
"reply_text": reply_text,
|
||||
"command": command,
|
||||
"taken_time": time.time(),
|
||||
},
|
||||
}
|
||||
# 处理动作任务结果
|
||||
action_task_result = results[1]
|
||||
if isinstance(action_task_result, BaseException):
|
||||
logger.error(f"{self.log_prefix} 动作任务执行异常: {action_task_result}")
|
||||
else:
|
||||
action_success, action_reply_text, action_command = action_task_result
|
||||
else:
|
||||
results = await asyncio.gather(background_action_task, return_exceptions=True)
|
||||
# 只有动作任务
|
||||
action_task_result = results[0]
|
||||
if isinstance(action_task_result, BaseException):
|
||||
logger.error(f"{self.log_prefix} 动作任务执行异常: {action_task_result}")
|
||||
else:
|
||||
action_success, action_reply_text, action_command = action_task_result
|
||||
|
||||
if loop_info["loop_action_info"]["command"] == "stop_focus_chat":
|
||||
logger.info(f"{self.log_prefix} 麦麦决定停止专注聊天")
|
||||
return False
|
||||
# 停止该聊天模式的循环
|
||||
# 构建最终的循环信息
|
||||
if reply_loop_info:
|
||||
# 如果有回复信息,使用回复的loop_info作为基础
|
||||
loop_info = reply_loop_info
|
||||
# 更新动作执行信息
|
||||
loop_info["loop_action_info"].update(
|
||||
{
|
||||
"action_taken": action_success,
|
||||
"command": action_command,
|
||||
"taken_time": time.time(),
|
||||
}
|
||||
)
|
||||
reply_text = reply_text_from_reply
|
||||
else:
|
||||
# 没有回复信息,构建纯动作的loop_info
|
||||
loop_info = {
|
||||
"loop_plan_info": {
|
||||
"action_result": plan_result.get("action_result", {}),
|
||||
},
|
||||
"loop_action_info": {
|
||||
"action_taken": action_success,
|
||||
"reply_text": action_reply_text,
|
||||
"command": action_command,
|
||||
"taken_time": time.time(),
|
||||
},
|
||||
}
|
||||
reply_text = action_reply_text
|
||||
|
||||
if ENABLE_S4U:
|
||||
await stop_typing()
|
||||
await mai_thinking_manager.get_mai_think(self.stream_id).do_think_after_response(reply_text)
|
||||
|
||||
self.end_cycle(loop_info, cycle_timers)
|
||||
self.print_cycle_info(cycle_timers)
|
||||
@@ -406,8 +590,16 @@ class HeartFChatting:
|
||||
if self.loop_mode == ChatMode.NORMAL:
|
||||
await self.willing_manager.after_generate_reply_handle(message_data.get("message_id", ""))
|
||||
|
||||
# 管理no_reply计数器:当执行了非no_reply动作时,重置计数器
|
||||
if action_type != "no_reply" and action_type != "no_action":
|
||||
# 导入NoReplyAction并重置计数器
|
||||
NoReplyAction.reset_consecutive_count()
|
||||
logger.info(f"{self.log_prefix} 执行了{action_type}动作,重置no_reply计数器")
|
||||
return True
|
||||
elif action_type == "no_action":
|
||||
# 当执行回复动作时,也重置no_reply计数器s
|
||||
NoReplyAction.reset_consecutive_count()
|
||||
logger.info(f"{self.log_prefix} 执行了回复动作,重置no_reply计数器")
|
||||
|
||||
return True
|
||||
|
||||
@@ -435,7 +627,7 @@ class HeartFChatting:
|
||||
action: str,
|
||||
reasoning: str,
|
||||
action_data: dict,
|
||||
cycle_timers: dict,
|
||||
cycle_timers: Dict[str, float],
|
||||
thinking_id: str,
|
||||
action_message: dict,
|
||||
) -> tuple[bool, str, str]:
|
||||
@@ -501,7 +693,7 @@ class HeartFChatting:
|
||||
在"兴趣"模式下,判断是否回复并生成内容。
|
||||
"""
|
||||
|
||||
interested_rate = (message_data.get("interest_value") or 0.0) * self.willing_amplifier
|
||||
interested_rate = (message_data.get("interest_value") or 0.0) * global_config.chat.willing_amplifier
|
||||
|
||||
self.willing_manager.setup(message_data, self.chat_stream)
|
||||
|
||||
@@ -515,8 +707,8 @@ class HeartFChatting:
|
||||
reply_probability += additional_config["maimcore_reply_probability_gain"]
|
||||
reply_probability = min(max(reply_probability, 0), 1) # 确保概率在 0-1 之间
|
||||
|
||||
talk_frequency = global_config.chat.get_current_talk_frequency(self.stream_id)
|
||||
reply_probability = talk_frequency * reply_probability
|
||||
talk_frequency = global_config.chat.get_current_talk_frequency(self.stream_id)
|
||||
reply_probability = talk_frequency * reply_probability
|
||||
|
||||
# 处理表情包
|
||||
if message_data.get("is_emoji") or message_data.get("is_picid"):
|
||||
@@ -544,7 +736,11 @@ class HeartFChatting:
|
||||
return False
|
||||
|
||||
async def _generate_response(
|
||||
self, message_data: dict, available_actions: Optional[Dict[str, ActionInfo]], reply_to: str
|
||||
self,
|
||||
message_data: dict,
|
||||
available_actions: Optional[Dict[str, ActionInfo]],
|
||||
reply_to: str,
|
||||
request_type: str = "chat.replyer.normal",
|
||||
) -> Optional[list]:
|
||||
"""生成普通回复"""
|
||||
try:
|
||||
@@ -552,8 +748,8 @@ class HeartFChatting:
|
||||
chat_stream=self.chat_stream,
|
||||
reply_to=reply_to,
|
||||
available_actions=available_actions,
|
||||
enable_tool=global_config.tool.enable_in_normal_chat,
|
||||
request_type="chat.replyer.normal",
|
||||
enable_tool=global_config.tool.enable_tool,
|
||||
request_type=request_type,
|
||||
)
|
||||
|
||||
if not success or not reply_set:
|
||||
@@ -563,10 +759,10 @@ class HeartFChatting:
|
||||
return reply_set
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[{self.log_prefix}] 回复生成出现错误:{str(e)} {traceback.format_exc()}")
|
||||
logger.error(f"{self.log_prefix}回复生成出现错误:{str(e)} {traceback.format_exc()}")
|
||||
return None
|
||||
|
||||
async def _send_response(self, reply_set, reply_to, thinking_start_time, message_data):
|
||||
async def _send_response(self, reply_set, reply_to, thinking_start_time, message_data) -> str:
|
||||
current_time = time.time()
|
||||
new_message_count = message_api.count_new_messages(
|
||||
chat_id=self.chat_stream.stream_id, start_time=thinking_start_time, end_time=current_time
|
||||
@@ -578,13 +774,9 @@ class HeartFChatting:
|
||||
need_reply = new_message_count >= random.randint(2, 4)
|
||||
|
||||
if need_reply:
|
||||
logger.info(
|
||||
f"{self.log_prefix} 从思考到回复,共有{new_message_count}条新消息,使用引用回复"
|
||||
)
|
||||
logger.info(f"{self.log_prefix} 从思考到回复,共有{new_message_count}条新消息,使用引用回复")
|
||||
else:
|
||||
logger.debug(
|
||||
f"{self.log_prefix} 从思考到回复,共有{new_message_count}条新消息,不使用引用回复"
|
||||
)
|
||||
logger.info(f"{self.log_prefix} 从思考到回复,共有{new_message_count}条新消息,不使用引用回复")
|
||||
|
||||
reply_text = ""
|
||||
first_replied = False
|
||||
|
||||
@@ -1,10 +1,13 @@
|
||||
import time
|
||||
|
||||
from typing import Optional, Dict, Any
|
||||
|
||||
from src.config.config import global_config
|
||||
from src.common.message_repository import count_messages
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from src.plugin_system.apis import send_api
|
||||
from maim_message.message_base import GroupInfo
|
||||
|
||||
from src.common.message_repository import count_messages
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -106,3 +109,30 @@ def get_recent_message_stats(minutes: float = 30, chat_id: Optional[str] = None)
|
||||
bot_reply_count = count_messages(bot_filter)
|
||||
|
||||
return {"bot_reply_count": bot_reply_count, "total_message_count": total_message_count}
|
||||
|
||||
|
||||
async def send_typing():
|
||||
group_info = GroupInfo(platform="amaidesu_default", group_id="114514", group_name="内心")
|
||||
|
||||
chat = await get_chat_manager().get_or_create_stream(
|
||||
platform="amaidesu_default",
|
||||
user_info=None,
|
||||
group_info=group_info,
|
||||
)
|
||||
|
||||
await send_api.custom_to_stream(
|
||||
message_type="state", content="typing", stream_id=chat.stream_id, storage_message=False
|
||||
)
|
||||
|
||||
async def stop_typing():
|
||||
group_info = GroupInfo(platform="amaidesu_default", group_id="114514", group_name="内心")
|
||||
|
||||
chat = await get_chat_manager().get_or_create_stream(
|
||||
platform="amaidesu_default",
|
||||
user_info=None,
|
||||
group_info=group_info,
|
||||
)
|
||||
|
||||
await send_api.custom_to_stream(
|
||||
message_type="state", content="stop_typing", stream_id=chat.stream_id, storage_message=False
|
||||
)
|
||||
@@ -525,9 +525,9 @@ class EmojiManager:
|
||||
如果文件已被删除,则执行对象的删除方法并从列表中移除
|
||||
"""
|
||||
try:
|
||||
if not self.emoji_objects:
|
||||
logger.warning("[检查] emoji_objects为空,跳过完整性检查")
|
||||
return
|
||||
# if not self.emoji_objects:
|
||||
# logger.warning("[检查] emoji_objects为空,跳过完整性检查")
|
||||
# return
|
||||
|
||||
total_count = len(self.emoji_objects)
|
||||
self.emoji_num = total_count
|
||||
@@ -707,6 +707,38 @@ class EmojiManager:
|
||||
return emoji
|
||||
return None # 如果循环结束还没找到,则返回 None
|
||||
|
||||
async def get_emoji_description_by_hash(self, emoji_hash: str) -> Optional[str]:
|
||||
"""根据哈希值获取已注册表情包的描述
|
||||
|
||||
Args:
|
||||
emoji_hash: 表情包的哈希值
|
||||
|
||||
Returns:
|
||||
Optional[str]: 表情包描述,如果未找到则返回None
|
||||
"""
|
||||
try:
|
||||
# 先从内存中查找
|
||||
emoji = await self.get_emoji_from_manager(emoji_hash)
|
||||
if emoji and emoji.description:
|
||||
logger.info(f"[缓存命中] 从内存获取表情包描述: {emoji.description[:50]}...")
|
||||
return emoji.description
|
||||
|
||||
# 如果内存中没有,从数据库查找
|
||||
self._ensure_db()
|
||||
try:
|
||||
emoji_record = Emoji.get_or_none(Emoji.emoji_hash == emoji_hash)
|
||||
if emoji_record and emoji_record.description:
|
||||
logger.info(f"[缓存命中] 从数据库获取表情包描述: {emoji_record.description[:50]}...")
|
||||
return emoji_record.description
|
||||
except Exception as e:
|
||||
logger.error(f"从数据库查询表情包描述时出错: {e}")
|
||||
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取表情包描述失败 (Hash: {emoji_hash}): {str(e)}")
|
||||
return None
|
||||
|
||||
async def delete_emoji(self, emoji_hash: str) -> bool:
|
||||
"""根据哈希值删除表情包
|
||||
|
||||
|
||||
@@ -51,7 +51,7 @@ def init_prompt() -> None:
|
||||
当"想说明某个具体的事实观点,但懒得明说,或者不便明说,或表达一种默契",使用"懂的都懂"
|
||||
当"当涉及游戏相关时,表示意外的夸赞,略带戏谑意味"时,使用"这么强!"
|
||||
|
||||
注意不要总结你自己(SELF)的发言
|
||||
请注意:不要总结你自己(SELF)的发言
|
||||
现在请你概括
|
||||
"""
|
||||
Prompt(learn_style_prompt, "learn_style_prompt")
|
||||
@@ -330,48 +330,8 @@ class ExpressionLearner:
|
||||
"""
|
||||
current_time = time.time()
|
||||
|
||||
# 全局衰减所有已存储的表达方式
|
||||
for type in ["style", "grammar"]:
|
||||
base_dir = os.path.join("data", "expression", f"learnt_{type}")
|
||||
if not os.path.exists(base_dir):
|
||||
logger.debug(f"目录不存在,跳过衰减: {base_dir}")
|
||||
continue
|
||||
|
||||
try:
|
||||
chat_ids = os.listdir(base_dir)
|
||||
logger.debug(f"在 {base_dir} 中找到 {len(chat_ids)} 个聊天ID目录进行衰减")
|
||||
except Exception as e:
|
||||
logger.error(f"读取目录失败 {base_dir}: {e}")
|
||||
continue
|
||||
|
||||
for chat_id in chat_ids:
|
||||
file_path = os.path.join(base_dir, chat_id, "expressions.json")
|
||||
if not os.path.exists(file_path):
|
||||
continue
|
||||
|
||||
try:
|
||||
with open(file_path, "r", encoding="utf-8") as f:
|
||||
expressions = json.load(f)
|
||||
|
||||
if not isinstance(expressions, list):
|
||||
logger.warning(f"表达方式文件格式错误,跳过衰减: {file_path}")
|
||||
continue
|
||||
|
||||
# 应用全局衰减
|
||||
decayed_expressions = self.apply_decay_to_expressions(expressions, current_time)
|
||||
|
||||
# 保存衰减后的结果
|
||||
with open(file_path, "w", encoding="utf-8") as f:
|
||||
json.dump(decayed_expressions, f, ensure_ascii=False, indent=2)
|
||||
|
||||
logger.debug(f"已对 {file_path} 应用衰减,剩余 {len(decayed_expressions)} 个表达方式")
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(f"JSON解析失败,跳过衰减 {file_path}: {e}")
|
||||
except PermissionError as e:
|
||||
logger.error(f"权限不足,无法更新 {file_path}: {e}")
|
||||
except Exception as e:
|
||||
logger.error(f"全局衰减{type}表达方式失败 {file_path}: {e}")
|
||||
continue
|
||||
# 全局衰减所有已存储的表达方式(直接操作数据库)
|
||||
self._apply_global_decay_to_database(current_time)
|
||||
|
||||
learnt_style: Optional[List[Tuple[str, str, str]]] = []
|
||||
learnt_grammar: Optional[List[Tuple[str, str, str]]] = []
|
||||
@@ -388,6 +348,42 @@ class ExpressionLearner:
|
||||
|
||||
return learnt_style, learnt_grammar
|
||||
|
||||
def _apply_global_decay_to_database(self, current_time: float) -> None:
|
||||
"""
|
||||
对数据库中的所有表达方式应用全局衰减
|
||||
"""
|
||||
try:
|
||||
# 获取所有表达方式
|
||||
all_expressions = Expression.select()
|
||||
|
||||
updated_count = 0
|
||||
deleted_count = 0
|
||||
|
||||
for expr in all_expressions:
|
||||
# 计算时间差
|
||||
last_active = expr.last_active_time
|
||||
time_diff_days = (current_time - last_active) / (24 * 3600) # 转换为天
|
||||
|
||||
# 计算衰减值
|
||||
decay_value = self.calculate_decay_factor(time_diff_days)
|
||||
new_count = max(0.01, expr.count - decay_value)
|
||||
|
||||
if new_count <= 0.01:
|
||||
# 如果count太小,删除这个表达方式
|
||||
expr.delete_instance()
|
||||
deleted_count += 1
|
||||
else:
|
||||
# 更新count
|
||||
expr.count = new_count
|
||||
expr.save()
|
||||
updated_count += 1
|
||||
|
||||
if updated_count > 0 or deleted_count > 0:
|
||||
logger.info(f"全局衰减完成:更新了 {updated_count} 个表达方式,删除了 {deleted_count} 个表达方式")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"数据库全局衰减失败: {e}")
|
||||
|
||||
def calculate_decay_factor(self, time_diff_days: float) -> float:
|
||||
"""
|
||||
计算衰减值
|
||||
@@ -410,30 +406,6 @@ class ExpressionLearner:
|
||||
|
||||
return min(0.01, decay)
|
||||
|
||||
def apply_decay_to_expressions(
|
||||
self, expressions: List[Dict[str, Any]], current_time: float
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
对表达式列表应用衰减
|
||||
返回衰减后的表达式列表,移除count小于0的项
|
||||
"""
|
||||
result = []
|
||||
for expr in expressions:
|
||||
# 确保last_active_time存在,如果不存在则使用current_time
|
||||
if "last_active_time" not in expr:
|
||||
expr["last_active_time"] = current_time
|
||||
|
||||
last_active = expr["last_active_time"]
|
||||
time_diff_days = (current_time - last_active) / (24 * 3600) # 转换为天
|
||||
|
||||
decay_value = self.calculate_decay_factor(time_diff_days)
|
||||
expr["count"] = max(0.01, expr.get("count", 1) - decay_value)
|
||||
|
||||
if expr["count"] > 0:
|
||||
result.append(expr)
|
||||
|
||||
return result
|
||||
|
||||
async def learn_and_store(self, type: str, num: int = 10) -> List[Tuple[str, str, str]]:
|
||||
# sourcery skip: use-join
|
||||
"""
|
||||
|
||||
@@ -2,7 +2,7 @@ import json
|
||||
import time
|
||||
import random
|
||||
|
||||
from typing import List, Dict, Tuple, Optional
|
||||
from typing import List, Dict, Tuple, Optional, Any
|
||||
from json_repair import repair_json
|
||||
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
@@ -117,36 +117,42 @@ class ExpressionSelector:
|
||||
|
||||
def get_random_expressions(
|
||||
self, chat_id: str, total_num: int, style_percentage: float, grammar_percentage: float
|
||||
) -> Tuple[List[Dict[str, str]], List[Dict[str, str]]]:
|
||||
) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
|
||||
# 支持多chat_id合并抽选
|
||||
related_chat_ids = self.get_related_chat_ids(chat_id)
|
||||
style_exprs = []
|
||||
grammar_exprs = []
|
||||
for cid in related_chat_ids:
|
||||
style_query = Expression.select().where((Expression.chat_id == cid) & (Expression.type == "style"))
|
||||
grammar_query = Expression.select().where((Expression.chat_id == cid) & (Expression.type == "grammar"))
|
||||
style_exprs.extend([
|
||||
{
|
||||
"situation": expr.situation,
|
||||
"style": expr.style,
|
||||
"count": expr.count,
|
||||
"last_active_time": expr.last_active_time,
|
||||
"source_id": cid,
|
||||
"type": "style",
|
||||
"create_date": expr.create_date if expr.create_date is not None else expr.last_active_time,
|
||||
} for expr in style_query
|
||||
])
|
||||
grammar_exprs.extend([
|
||||
{
|
||||
"situation": expr.situation,
|
||||
"style": expr.style,
|
||||
"count": expr.count,
|
||||
"last_active_time": expr.last_active_time,
|
||||
"source_id": cid,
|
||||
"type": "grammar",
|
||||
"create_date": expr.create_date if expr.create_date is not None else expr.last_active_time,
|
||||
} for expr in grammar_query
|
||||
])
|
||||
|
||||
# 优化:一次性查询所有相关chat_id的表达方式
|
||||
style_query = Expression.select().where(
|
||||
(Expression.chat_id.in_(related_chat_ids)) & (Expression.type == "style")
|
||||
)
|
||||
grammar_query = Expression.select().where(
|
||||
(Expression.chat_id.in_(related_chat_ids)) & (Expression.type == "grammar")
|
||||
)
|
||||
|
||||
style_exprs = [
|
||||
{
|
||||
"situation": expr.situation,
|
||||
"style": expr.style,
|
||||
"count": expr.count,
|
||||
"last_active_time": expr.last_active_time,
|
||||
"source_id": expr.chat_id,
|
||||
"type": "style",
|
||||
"create_date": expr.create_date if expr.create_date is not None else expr.last_active_time,
|
||||
} for expr in style_query
|
||||
]
|
||||
|
||||
grammar_exprs = [
|
||||
{
|
||||
"situation": expr.situation,
|
||||
"style": expr.style,
|
||||
"count": expr.count,
|
||||
"last_active_time": expr.last_active_time,
|
||||
"source_id": expr.chat_id,
|
||||
"type": "grammar",
|
||||
"create_date": expr.create_date if expr.create_date is not None else expr.last_active_time,
|
||||
} for expr in grammar_query
|
||||
]
|
||||
|
||||
style_num = int(total_num * style_percentage)
|
||||
grammar_num = int(total_num * grammar_percentage)
|
||||
# 按权重抽样(使用count作为权重)
|
||||
@@ -162,7 +168,7 @@ class ExpressionSelector:
|
||||
selected_grammar = []
|
||||
return selected_style, selected_grammar
|
||||
|
||||
def update_expressions_count_batch(self, expressions_to_update: List[Dict[str, str]], increment: float = 0.1):
|
||||
def update_expressions_count_batch(self, expressions_to_update: List[Dict[str, Any]], increment: float = 0.1):
|
||||
"""对一批表达方式更新count值,按chat_id+type分组后一次性写入数据库"""
|
||||
if not expressions_to_update:
|
||||
return
|
||||
@@ -203,7 +209,7 @@ class ExpressionSelector:
|
||||
max_num: int = 10,
|
||||
min_num: int = 5,
|
||||
target_message: Optional[str] = None,
|
||||
) -> List[Dict[str, str]]:
|
||||
) -> List[Dict[str, Any]]:
|
||||
# sourcery skip: inline-variable, list-comprehension
|
||||
"""使用LLM选择适合的表达方式"""
|
||||
|
||||
@@ -273,6 +279,7 @@ class ExpressionSelector:
|
||||
|
||||
if not isinstance(result, dict) or "selected_situations" not in result:
|
||||
logger.error("LLM返回格式错误")
|
||||
logger.info(f"LLM返回结果: \n{content}")
|
||||
return []
|
||||
|
||||
selected_indices = result["selected_situations"]
|
||||
|
||||
@@ -12,6 +12,7 @@ from src.chat.message_receive.storage import MessageStorage
|
||||
from src.chat.heart_flow.heartflow import heartflow
|
||||
from src.chat.utils.utils import is_mentioned_bot_in_message
|
||||
from src.chat.utils.timer_calculator import Timer
|
||||
from src.chat.utils.chat_message_builder import replace_user_references_sync
|
||||
from src.common.logger import get_logger
|
||||
from src.person_info.relationship_manager import get_relationship_manager
|
||||
from src.mood.mood_manager import mood_manager
|
||||
@@ -56,16 +57,41 @@ async def _calculate_interest(message: MessageRecv) -> Tuple[float, bool]:
|
||||
with Timer("记忆激活"):
|
||||
interested_rate = await hippocampus_manager.get_activate_from_text(
|
||||
message.processed_plain_text,
|
||||
max_depth= 5,
|
||||
fast_retrieval=False,
|
||||
)
|
||||
logger.debug(f"记忆激活率: {interested_rate:.2f}")
|
||||
|
||||
text_len = len(message.processed_plain_text)
|
||||
# 根据文本长度调整兴趣度,长度越大兴趣度越高,但增长率递减,最低0.01,最高0.05
|
||||
# 采用对数函数实现递减增长
|
||||
|
||||
base_interest = 0.01 + (0.05 - 0.01) * (math.log10(text_len + 1) / math.log10(1000 + 1))
|
||||
base_interest = min(max(base_interest, 0.01), 0.05)
|
||||
# 根据文本长度分布调整兴趣度,采用分段函数实现更精确的兴趣度计算
|
||||
# 基于实际分布:0-5字符(26.57%), 6-10字符(27.18%), 11-20字符(22.76%), 21-30字符(10.33%), 31+字符(13.86%)
|
||||
|
||||
if text_len == 0:
|
||||
base_interest = 0.01 # 空消息最低兴趣度
|
||||
elif text_len <= 5:
|
||||
# 1-5字符:线性增长 0.01 -> 0.03
|
||||
base_interest = 0.01 + (text_len - 1) * (0.03 - 0.01) / 4
|
||||
elif text_len <= 10:
|
||||
# 6-10字符:线性增长 0.03 -> 0.06
|
||||
base_interest = 0.03 + (text_len - 5) * (0.06 - 0.03) / 5
|
||||
elif text_len <= 20:
|
||||
# 11-20字符:线性增长 0.06 -> 0.12
|
||||
base_interest = 0.06 + (text_len - 10) * (0.12 - 0.06) / 10
|
||||
elif text_len <= 30:
|
||||
# 21-30字符:线性增长 0.12 -> 0.18
|
||||
base_interest = 0.12 + (text_len - 20) * (0.18 - 0.12) / 10
|
||||
elif text_len <= 50:
|
||||
# 31-50字符:线性增长 0.18 -> 0.22
|
||||
base_interest = 0.18 + (text_len - 30) * (0.22 - 0.18) / 20
|
||||
elif text_len <= 100:
|
||||
# 51-100字符:线性增长 0.22 -> 0.26
|
||||
base_interest = 0.22 + (text_len - 50) * (0.26 - 0.22) / 50
|
||||
else:
|
||||
# 100+字符:对数增长 0.26 -> 0.3,增长率递减
|
||||
base_interest = 0.26 + (0.3 - 0.26) * (math.log10(text_len - 99) / math.log10(901)) # 1000-99=901
|
||||
|
||||
# 确保在范围内
|
||||
base_interest = min(max(base_interest, 0.01), 0.3)
|
||||
|
||||
interested_rate += base_interest
|
||||
|
||||
@@ -123,8 +149,15 @@ class HeartFCMessageReceiver:
|
||||
# 如果消息中包含图片标识,则将 [picid:...] 替换为 [图片]
|
||||
picid_pattern = r"\[picid:([^\]]+)\]"
|
||||
processed_plain_text = re.sub(picid_pattern, "[图片]", message.processed_plain_text)
|
||||
|
||||
# 应用用户引用格式替换,将回复<aaa:bbb>和@<aaa:bbb>格式转换为可读格式
|
||||
processed_plain_text = replace_user_references_sync(
|
||||
processed_plain_text,
|
||||
message.message_info.platform, # type: ignore
|
||||
replace_bot_name=True
|
||||
)
|
||||
|
||||
logger.info(f"[{mes_name}]{userinfo.user_nickname}:{processed_plain_text}") # type: ignore
|
||||
logger.info(f"[{mes_name}]{userinfo.user_nickname}:{processed_plain_text}[兴趣度:{interested_rate:.2f}]") # type: ignore
|
||||
|
||||
logger.debug(f"[{mes_name}][当前时段回复频率: {current_talk_frequency}]")
|
||||
|
||||
|
||||
@@ -224,10 +224,16 @@ class Hippocampus:
|
||||
return hash((source, target))
|
||||
|
||||
@staticmethod
|
||||
def find_topic_llm(text, topic_num):
|
||||
def find_topic_llm(text: str, topic_num: int | list[int]):
|
||||
# sourcery skip: inline-immediately-returned-variable
|
||||
topic_num_str = ""
|
||||
if isinstance(topic_num, list):
|
||||
topic_num_str = f"{topic_num[0]}-{topic_num[1]}"
|
||||
else:
|
||||
topic_num_str = topic_num
|
||||
|
||||
prompt = (
|
||||
f"这是一段文字:\n{text}\n\n请你从这段话中总结出最多{topic_num}个关键的概念,可以是名词,动词,或者特定人物,帮我列出来,"
|
||||
f"这是一段文字:\n{text}\n\n请你从这段话中总结出最多{topic_num_str}个关键的概念,可以是名词,动词,或者特定人物,帮我列出来,"
|
||||
f"将主题用逗号隔开,并加上<>,例如<主题1>,<主题2>......尽可能精简。只需要列举最多{topic_num}个话题就好,不要有序号,不要告诉我其他内容。"
|
||||
f"如果确定找不出主题或者没有明显主题,返回<none>。"
|
||||
)
|
||||
@@ -300,6 +306,60 @@ class Hippocampus:
|
||||
memories.sort(key=lambda x: x[2], reverse=True)
|
||||
return memories
|
||||
|
||||
async def get_keywords_from_text(self, text: str) -> list:
|
||||
"""从文本中提取关键词。
|
||||
|
||||
Args:
|
||||
text (str): 输入文本
|
||||
fast_retrieval (bool, optional): 是否使用快速检索。默认为False。
|
||||
如果为True,使用jieba分词提取关键词,速度更快但可能不够准确。
|
||||
如果为False,使用LLM提取关键词,速度较慢但更准确。
|
||||
"""
|
||||
if not text:
|
||||
return []
|
||||
|
||||
# 使用LLM提取关键词 - 根据详细文本长度分布优化topic_num计算
|
||||
text_length = len(text)
|
||||
topic_num: int | list[int] = 0
|
||||
if text_length <= 5:
|
||||
words = jieba.cut(text)
|
||||
keywords = [word for word in words if len(word) > 1]
|
||||
keywords = list(set(keywords))[:3] # 限制最多3个关键词
|
||||
if keywords:
|
||||
logger.info(f"提取关键词: {keywords}")
|
||||
return keywords
|
||||
elif text_length <= 10:
|
||||
topic_num = [1, 3] # 6-10字符: 1个关键词 (27.18%的文本)
|
||||
elif text_length <= 20:
|
||||
topic_num = [2, 4] # 11-20字符: 2个关键词 (22.76%的文本)
|
||||
elif text_length <= 30:
|
||||
topic_num = [3, 5] # 21-30字符: 3个关键词 (10.33%的文本)
|
||||
elif text_length <= 50:
|
||||
topic_num = [4, 5] # 31-50字符: 4个关键词 (9.79%的文本)
|
||||
else:
|
||||
topic_num = 5 # 51+字符: 5个关键词 (其余长文本)
|
||||
|
||||
topics_response, (reasoning_content, model_name) = await self.model_summary.generate_response_async(
|
||||
self.find_topic_llm(text, topic_num)
|
||||
)
|
||||
|
||||
# 提取关键词
|
||||
keywords = re.findall(r"<([^>]+)>", topics_response)
|
||||
if not keywords:
|
||||
keywords = []
|
||||
else:
|
||||
keywords = [
|
||||
keyword.strip()
|
||||
for keyword in ",".join(keywords).replace(",", ",").replace("、", ",").replace(" ", ",").split(",")
|
||||
if keyword.strip()
|
||||
]
|
||||
|
||||
if keywords:
|
||||
logger.info(f"提取关键词: {keywords}")
|
||||
|
||||
return keywords
|
||||
|
||||
|
||||
async def get_memory_from_text(
|
||||
self,
|
||||
text: str,
|
||||
@@ -325,39 +385,7 @@ class Hippocampus:
|
||||
- memory_items: list, 该主题下的记忆项列表
|
||||
- similarity: float, 与文本的相似度
|
||||
"""
|
||||
if not text:
|
||||
return []
|
||||
|
||||
if fast_retrieval:
|
||||
# 使用jieba分词提取关键词
|
||||
words = jieba.cut(text)
|
||||
# 过滤掉停用词和单字词
|
||||
keywords = [word for word in words if len(word) > 1]
|
||||
# 去重
|
||||
keywords = list(set(keywords))
|
||||
# 限制关键词数量
|
||||
logger.debug(f"提取关键词: {keywords}")
|
||||
|
||||
else:
|
||||
# 使用LLM提取关键词
|
||||
topic_num = min(5, max(1, int(len(text) * 0.1))) # 根据文本长度动态调整关键词数量
|
||||
# logger.info(f"提取关键词数量: {topic_num}")
|
||||
topics_response, (reasoning_content, model_name) = await self.model_summary.generate_response_async(
|
||||
self.find_topic_llm(text, topic_num)
|
||||
)
|
||||
|
||||
# 提取关键词
|
||||
keywords = re.findall(r"<([^>]+)>", topics_response)
|
||||
if not keywords:
|
||||
keywords = []
|
||||
else:
|
||||
keywords = [
|
||||
keyword.strip()
|
||||
for keyword in ",".join(keywords).replace(",", ",").replace("、", ",").replace(" ", ",").split(",")
|
||||
if keyword.strip()
|
||||
]
|
||||
|
||||
# logger.info(f"提取的关键词: {', '.join(keywords)}")
|
||||
keywords = await self.get_keywords_from_text(text)
|
||||
|
||||
# 过滤掉不存在于记忆图中的关键词
|
||||
valid_keywords = [keyword for keyword in keywords if keyword in self.memory_graph.G]
|
||||
@@ -679,38 +707,7 @@ class Hippocampus:
|
||||
Returns:
|
||||
float: 激活节点数与总节点数的比值
|
||||
"""
|
||||
if not text:
|
||||
return 0
|
||||
|
||||
if fast_retrieval:
|
||||
# 使用jieba分词提取关键词
|
||||
words = jieba.cut(text)
|
||||
# 过滤掉停用词和单字词
|
||||
keywords = [word for word in words if len(word) > 1]
|
||||
# 去重
|
||||
keywords = list(set(keywords))
|
||||
# 限制关键词数量
|
||||
keywords = keywords[:5]
|
||||
else:
|
||||
# 使用LLM提取关键词
|
||||
topic_num = min(5, max(1, int(len(text) * 0.1))) # 根据文本长度动态调整关键词数量
|
||||
# logger.info(f"提取关键词数量: {topic_num}")
|
||||
topics_response, (reasoning_content, model_name) = await self.model_summary.generate_response_async(
|
||||
self.find_topic_llm(text, topic_num)
|
||||
)
|
||||
|
||||
# 提取关键词
|
||||
keywords = re.findall(r"<([^>]+)>", topics_response)
|
||||
if not keywords:
|
||||
keywords = []
|
||||
else:
|
||||
keywords = [
|
||||
keyword.strip()
|
||||
for keyword in ",".join(keywords).replace(",", ",").replace("、", ",").replace(" ", ",").split(",")
|
||||
if keyword.strip()
|
||||
]
|
||||
|
||||
# logger.info(f"提取的关键词: {', '.join(keywords)}")
|
||||
keywords = await self.get_keywords_from_text(text)
|
||||
|
||||
# 过滤掉不存在于记忆图中的关键词
|
||||
valid_keywords = [keyword for keyword in keywords if keyword in self.memory_graph.G]
|
||||
@@ -727,7 +724,7 @@ class Hippocampus:
|
||||
for keyword in valid_keywords:
|
||||
logger.debug(f"开始以关键词 '{keyword}' 为中心进行扩散检索 (最大深度: {max_depth}):")
|
||||
# 初始化激活值
|
||||
activation_values = {keyword: 1.0}
|
||||
activation_values = {keyword: 1.5}
|
||||
# 记录已访问的节点
|
||||
visited_nodes = {keyword}
|
||||
# 待处理的节点队列,每个元素是(节点, 激活值, 当前深度)
|
||||
@@ -1315,6 +1312,7 @@ class ParahippocampalGyrus:
|
||||
return compressed_memory, similar_topics_dict
|
||||
|
||||
async def operation_build_memory(self):
|
||||
# sourcery skip: merge-list-appends-into-extend
|
||||
logger.info("------------------------------------开始构建记忆--------------------------------------")
|
||||
start_time = time.time()
|
||||
memory_samples = self.hippocampus.entorhinal_cortex.get_memory_sample()
|
||||
|
||||
@@ -444,7 +444,7 @@ class MessageSending(MessageProcessBase):
|
||||
is_emoji: bool = False,
|
||||
thinking_start_time: float = 0,
|
||||
apply_set_reply_logic: bool = False,
|
||||
reply_to: str = None, # type: ignore
|
||||
reply_to: Optional[str] = None,
|
||||
):
|
||||
# 调用父类初始化
|
||||
super().__init__(
|
||||
|
||||
@@ -3,7 +3,7 @@ from src.plugin_system.base.base_action import BaseAction
|
||||
from src.chat.message_receive.chat_stream import ChatStream
|
||||
from src.common.logger import get_logger
|
||||
from src.plugin_system.core.component_registry import component_registry
|
||||
from src.plugin_system.base.component_types import ComponentType, ActionActivationType, ChatMode, ActionInfo
|
||||
from src.plugin_system.base.component_types import ComponentType, ActionInfo
|
||||
|
||||
logger = get_logger("action_manager")
|
||||
|
||||
@@ -15,11 +15,6 @@ class ActionManager:
|
||||
现在统一使用新插件系统,简化了原有的新旧兼容逻辑。
|
||||
"""
|
||||
|
||||
# 类常量
|
||||
DEFAULT_RANDOM_PROBABILITY = 0.3
|
||||
DEFAULT_MODE = ChatMode.ALL
|
||||
DEFAULT_ACTIVATION_TYPE = ActionActivationType.ALWAYS
|
||||
|
||||
def __init__(self):
|
||||
"""初始化动作管理器"""
|
||||
|
||||
|
||||
@@ -174,7 +174,7 @@ class ActionModifier:
|
||||
continue # 总是激活,无需处理
|
||||
|
||||
elif activation_type == ActionActivationType.RANDOM:
|
||||
probability = action_info.random_activation_probability or ActionManager.DEFAULT_RANDOM_PROBABILITY
|
||||
probability = action_info.random_activation_probability
|
||||
if random.random() >= probability:
|
||||
reason = f"RANDOM类型未触发(概率{probability})"
|
||||
deactivated_actions.append((action_name, reason))
|
||||
|
||||
@@ -33,10 +33,11 @@ def init_prompt():
|
||||
{time_block}
|
||||
{identity_block}
|
||||
你现在需要根据聊天内容,选择的合适的action来参与聊天。
|
||||
{chat_context_description},以下是具体的聊天内容:
|
||||
{chat_context_description},以下是具体的聊天内容
|
||||
{chat_content_block}
|
||||
|
||||
|
||||
|
||||
{moderation_prompt}
|
||||
|
||||
现在请你根据{by_what}选择合适的action和触发action的消息:
|
||||
@@ -45,7 +46,7 @@ def init_prompt():
|
||||
{no_action_block}
|
||||
{action_options_text}
|
||||
|
||||
你必须从上面列出的可用action中选择一个,并说明触发action的消息id和原因。
|
||||
你必须从上面列出的可用action中选择一个,并说明触发action的消息id(不是消息原文)和选择该action的原因。
|
||||
|
||||
请根据动作示例,以严格的 JSON 格式输出,且仅包含 JSON 内容:
|
||||
""",
|
||||
@@ -128,20 +129,6 @@ class ActionPlanner:
|
||||
else:
|
||||
logger.warning(f"{self.log_prefix}使用中的动作 {action_name} 未在已注册动作中找到")
|
||||
|
||||
# 如果没有可用动作或只有no_reply动作,直接返回no_reply
|
||||
# 因为现在reply是永远激活,所以不需要空跳判定
|
||||
# if not current_available_actions:
|
||||
# action = "no_reply" if mode == ChatMode.FOCUS else "no_action"
|
||||
# reasoning = "没有可用的动作"
|
||||
# logger.info(f"{self.log_prefix}{reasoning}")
|
||||
# return {
|
||||
# "action_result": {
|
||||
# "action_type": action,
|
||||
# "action_data": action_data,
|
||||
# "reasoning": reasoning,
|
||||
# },
|
||||
# }, None
|
||||
|
||||
# --- 构建提示词 (调用修改后的 PromptBuilder 方法) ---
|
||||
prompt, message_id_list = await self.build_planner_prompt(
|
||||
is_group_chat=is_group_chat, # <-- Pass HFC state
|
||||
@@ -224,7 +211,7 @@ class ActionPlanner:
|
||||
reasoning = f"Planner 内部处理错误: {outer_e}"
|
||||
|
||||
is_parallel = False
|
||||
if action in current_available_actions:
|
||||
if mode == ChatMode.NORMAL and action in current_available_actions:
|
||||
is_parallel = current_available_actions[action].parallel_action
|
||||
|
||||
action_result = {
|
||||
@@ -268,7 +255,7 @@ class ActionPlanner:
|
||||
|
||||
actions_before_now = get_actions_by_timestamp_with_chat(
|
||||
chat_id=self.chat_id,
|
||||
timestamp_start=time.time()-3600,
|
||||
timestamp_start=time.time() - 3600,
|
||||
timestamp_end=time.time(),
|
||||
limit=5,
|
||||
)
|
||||
@@ -276,7 +263,7 @@ class ActionPlanner:
|
||||
actions_before_now_block = build_readable_actions(
|
||||
actions=actions_before_now,
|
||||
)
|
||||
|
||||
|
||||
actions_before_now_block = f"你刚刚选择并执行过的action是:\n{actions_before_now_block}"
|
||||
|
||||
self.last_obs_time_mark = time.time()
|
||||
@@ -288,7 +275,6 @@ class ActionPlanner:
|
||||
if global_config.chat.at_bot_inevitable_reply:
|
||||
mentioned_bonus = "\n- 有人提到你,或者at你"
|
||||
|
||||
|
||||
by_what = "聊天内容"
|
||||
target_prompt = '\n "target_message_id":"触发action的消息id"'
|
||||
no_action_block = f"""重要说明:
|
||||
@@ -311,7 +297,7 @@ class ActionPlanner:
|
||||
by_what = "聊天内容和用户的最新消息"
|
||||
target_prompt = ""
|
||||
no_action_block = """重要说明:
|
||||
- 'no_action' 表示只进行普通聊天回复,不执行任何额外动作
|
||||
- 'reply' 表示只进行普通聊天回复,不执行任何额外动作
|
||||
- 其他action表示在普通回复的基础上,执行相应的额外动作"""
|
||||
|
||||
chat_context_description = "你现在正在一个群聊中"
|
||||
|
||||
@@ -17,7 +17,11 @@ from src.chat.message_receive.uni_message_sender import HeartFCSender
|
||||
from src.chat.utils.timer_calculator import Timer # <--- Import Timer
|
||||
from src.chat.utils.utils import get_chat_type_and_target_info
|
||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||
from src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_before_timestamp_with_chat
|
||||
from src.chat.utils.chat_message_builder import (
|
||||
build_readable_messages,
|
||||
get_raw_msg_before_timestamp_with_chat,
|
||||
replace_user_references_sync,
|
||||
)
|
||||
from src.chat.express.expression_selector import expression_selector
|
||||
from src.chat.knowledge.knowledge_lib import qa_manager
|
||||
from src.chat.memory_system.memory_activator import MemoryActivator
|
||||
@@ -25,42 +29,16 @@ from src.chat.memory_system.instant_memory import InstantMemory
|
||||
from src.mood.mood_manager import mood_manager
|
||||
from src.person_info.relationship_fetcher import relationship_fetcher_manager
|
||||
from src.person_info.person_info import get_person_info_manager
|
||||
from src.tools.tool_executor import ToolExecutor
|
||||
from src.plugin_system.base.component_types import ActionInfo
|
||||
|
||||
logger = get_logger("replyer")
|
||||
|
||||
|
||||
def init_prompt():
|
||||
Prompt("你正在qq群里聊天,下面是群里在聊的内容:", "chat_target_group1")
|
||||
Prompt("你正在和{sender_name}聊天,这是你们之前聊的内容:", "chat_target_private1")
|
||||
Prompt("在群里聊天", "chat_target_group2")
|
||||
Prompt("和{sender_name}聊天", "chat_target_private2")
|
||||
Prompt("\n你有以下这些**知识**:\n{prompt_info}\n请你**记住上面的知识**,之后可能会用到。\n", "knowledge_prompt")
|
||||
|
||||
Prompt(
|
||||
"""
|
||||
{expression_habits_block}
|
||||
{tool_info_block}
|
||||
{knowledge_prompt}
|
||||
{memory_block}
|
||||
{relation_info_block}
|
||||
{extra_info_block}
|
||||
|
||||
{chat_target}
|
||||
{time_block}
|
||||
{chat_info}
|
||||
{reply_target_block}
|
||||
{identity}
|
||||
|
||||
{action_descriptions}
|
||||
你正在{chat_target_2},你现在的心情是:{mood_state}
|
||||
现在请你读读之前的聊天记录,并给出回复
|
||||
{config_expression_style}。注意不要复读你说过的话
|
||||
{keywords_reaction_prompt}
|
||||
{moderation_prompt}
|
||||
不要浮夸,不要夸张修辞,不要输出多余内容(包括前后缀,冒号和引号,括号(),表情包,at或 @等 )。只输出回复内容""",
|
||||
"default_generator_prompt",
|
||||
)
|
||||
|
||||
Prompt(
|
||||
"""
|
||||
@@ -109,7 +87,8 @@ def init_prompt():
|
||||
{core_dialogue_prompt}
|
||||
|
||||
{reply_target_block}
|
||||
对方最新发送的内容:{message_txt}
|
||||
|
||||
|
||||
你现在的心情是:{mood_state}
|
||||
{config_expression_style}
|
||||
注意不要复读你说过的话
|
||||
@@ -159,6 +138,8 @@ class DefaultReplyer:
|
||||
self.heart_fc_sender = HeartFCSender()
|
||||
self.memory_activator = MemoryActivator()
|
||||
self.instant_memory = InstantMemory(chat_id=self.chat_stream.stream_id)
|
||||
|
||||
from src.plugin_system.core.tool_use import ToolExecutor # 延迟导入ToolExecutor,不然会循环依赖
|
||||
self.tool_executor = ToolExecutor(chat_id=self.chat_stream.stream_id, enable_cache=True, cache_ttl=3)
|
||||
|
||||
def _select_weighted_model_config(self) -> Dict[str, Any]:
|
||||
@@ -171,67 +152,49 @@ class DefaultReplyer:
|
||||
|
||||
async def generate_reply_with_context(
|
||||
self,
|
||||
reply_data: Optional[Dict[str, Any]] = None,
|
||||
reply_to: str = "",
|
||||
extra_info: str = "",
|
||||
available_actions: Optional[Dict[str, ActionInfo]] = None,
|
||||
enable_tool: bool = True,
|
||||
enable_timeout: bool = False,
|
||||
) -> Tuple[bool, Optional[str], Optional[str]]:
|
||||
"""
|
||||
回复器 (Replier): 核心逻辑,负责生成回复文本。
|
||||
(已整合原 HeartFCGenerator 的功能)
|
||||
回复器 (Replier): 负责生成回复文本的核心逻辑。
|
||||
|
||||
Args:
|
||||
reply_to: 回复对象,格式为 "发送者:消息内容"
|
||||
extra_info: 额外信息,用于补充上下文
|
||||
available_actions: 可用的动作信息字典
|
||||
enable_tool: 是否启用工具调用
|
||||
|
||||
Returns:
|
||||
Tuple[bool, Optional[str], Optional[str]]: (是否成功, 生成的回复内容, 使用的prompt)
|
||||
"""
|
||||
prompt = None
|
||||
if available_actions is None:
|
||||
available_actions = {}
|
||||
try:
|
||||
if not reply_data:
|
||||
reply_data = {
|
||||
"reply_to": reply_to,
|
||||
"extra_info": extra_info,
|
||||
}
|
||||
for key, value in reply_data.items():
|
||||
if not value:
|
||||
logger.debug(f"回复数据跳过{key},生成回复时将忽略。")
|
||||
|
||||
# 3. 构建 Prompt
|
||||
with Timer("构建Prompt", {}): # 内部计时器,可选保留
|
||||
prompt = await self.build_prompt_reply_context(
|
||||
reply_data=reply_data, # 传递action_data
|
||||
reply_to=reply_to,
|
||||
extra_info=extra_info,
|
||||
available_actions=available_actions,
|
||||
enable_timeout=enable_timeout,
|
||||
enable_tool=enable_tool,
|
||||
)
|
||||
|
||||
if not prompt:
|
||||
logger.warning("构建prompt失败,跳过回复生成")
|
||||
return False, None, None
|
||||
|
||||
# 4. 调用 LLM 生成回复
|
||||
content = None
|
||||
reasoning_content = None
|
||||
model_name = "unknown_model"
|
||||
# TODO: 复活这里
|
||||
# reasoning_content = None
|
||||
# model_name = "unknown_model"
|
||||
|
||||
try:
|
||||
with Timer("LLM生成", {}): # 内部计时器,可选保留
|
||||
# 加权随机选择一个模型配置
|
||||
selected_model_config = self._select_weighted_model_config()
|
||||
# 兼容新旧格式的模型名称获取
|
||||
model_display_name = selected_model_config.get('model_name', selected_model_config.get('name', 'N/A'))
|
||||
logger.info(
|
||||
f"使用模型生成回复: {model_display_name} (选中概率: {selected_model_config.get('weight', 1.0)})"
|
||||
)
|
||||
|
||||
express_model = LLMRequest(
|
||||
model=selected_model_config,
|
||||
request_type=self.request_type,
|
||||
)
|
||||
|
||||
if global_config.debug.show_prompt:
|
||||
logger.info(f"\n{prompt}\n")
|
||||
else:
|
||||
logger.debug(f"\n{prompt}\n")
|
||||
|
||||
content, (reasoning_content, model_name) = await express_model.generate_response_async(prompt)
|
||||
|
||||
logger.debug(f"replyer生成内容: {content}")
|
||||
content = await self.llm_generate_content(prompt)
|
||||
logger.debug(f"replyer生成内容: {content}")
|
||||
|
||||
except Exception as llm_e:
|
||||
# 精简报错信息
|
||||
@@ -247,73 +210,62 @@ class DefaultReplyer:
|
||||
|
||||
async def rewrite_reply_with_context(
|
||||
self,
|
||||
reply_data: Dict[str, Any],
|
||||
raw_reply: str = "",
|
||||
reason: str = "",
|
||||
reply_to: str = "",
|
||||
relation_info: str = "",
|
||||
) -> Tuple[bool, Optional[str]]:
|
||||
return_prompt: bool = False,
|
||||
) -> Tuple[bool, Optional[str], Optional[str]]:
|
||||
"""
|
||||
表达器 (Expressor): 核心逻辑,负责生成回复文本。
|
||||
表达器 (Expressor): 负责重写和优化回复文本。
|
||||
|
||||
Args:
|
||||
raw_reply: 原始回复内容
|
||||
reason: 回复原因
|
||||
reply_to: 回复对象,格式为 "发送者:消息内容"
|
||||
relation_info: 关系信息
|
||||
|
||||
Returns:
|
||||
Tuple[bool, Optional[str]]: (是否成功, 重写后的回复内容)
|
||||
"""
|
||||
try:
|
||||
if not reply_data:
|
||||
reply_data = {
|
||||
"reply_to": reply_to,
|
||||
"relation_info": relation_info,
|
||||
}
|
||||
|
||||
with Timer("构建Prompt", {}): # 内部计时器,可选保留
|
||||
prompt = await self.build_prompt_rewrite_context(
|
||||
reply_data=reply_data,
|
||||
raw_reply=raw_reply,
|
||||
reason=reason,
|
||||
reply_to=reply_to,
|
||||
)
|
||||
|
||||
content = None
|
||||
reasoning_content = None
|
||||
model_name = "unknown_model"
|
||||
# TODO: 复活这里
|
||||
# reasoning_content = None
|
||||
# model_name = "unknown_model"
|
||||
if not prompt:
|
||||
logger.error("Prompt 构建失败,无法生成回复。")
|
||||
return False, None
|
||||
return False, None, None
|
||||
|
||||
try:
|
||||
with Timer("LLM生成", {}): # 内部计时器,可选保留
|
||||
# 加权随机选择一个模型配置
|
||||
selected_model_config = self._select_weighted_model_config()
|
||||
# 兼容新旧格式的模型名称获取
|
||||
model_display_name = selected_model_config.get('model_name', selected_model_config.get('name', 'N/A'))
|
||||
logger.info(
|
||||
f"使用模型重写回复: {model_display_name} (选中概率: {selected_model_config.get('weight', 1.0)})"
|
||||
)
|
||||
|
||||
express_model = LLMRequest(
|
||||
model=selected_model_config,
|
||||
request_type=self.request_type,
|
||||
)
|
||||
|
||||
content, (reasoning_content, model_name) = await express_model.generate_response_async(prompt)
|
||||
|
||||
logger.info(f"想要表达:{raw_reply}||理由:{reason}||生成回复: {content}\n")
|
||||
content = await self.llm_generate_content(prompt)
|
||||
logger.info(f"想要表达:{raw_reply}||理由:{reason}||生成回复: {content}\n")
|
||||
|
||||
except Exception as llm_e:
|
||||
# 精简报错信息
|
||||
logger.error(f"LLM 生成失败: {llm_e}")
|
||||
return False, None # LLM 调用失败则无法生成回复
|
||||
return False, None, prompt if return_prompt else None # LLM 调用失败则无法生成回复
|
||||
|
||||
return True, content
|
||||
return True, content, prompt if return_prompt else None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"回复生成意外失败: {e}")
|
||||
traceback.print_exc()
|
||||
return False, None
|
||||
return False, None, prompt if return_prompt else None
|
||||
|
||||
async def build_relation_info(self, reply_data=None):
|
||||
async def build_relation_info(self, reply_to: str = ""):
|
||||
if not global_config.relationship.enable_relationship:
|
||||
return ""
|
||||
|
||||
relationship_fetcher = relationship_fetcher_manager.get_fetcher(self.chat_stream.stream_id)
|
||||
if not reply_data:
|
||||
if not reply_to:
|
||||
return ""
|
||||
reply_to = reply_data.get("reply_to", "")
|
||||
sender, text = self._parse_reply_target(reply_to)
|
||||
if not sender or not text:
|
||||
return ""
|
||||
@@ -327,7 +279,16 @@ class DefaultReplyer:
|
||||
|
||||
return await relationship_fetcher.build_relation_info(person_id, points_num=5)
|
||||
|
||||
async def build_expression_habits(self, chat_history, target):
|
||||
async def build_expression_habits(self, chat_history: str, target: str) -> str:
|
||||
"""构建表达习惯块
|
||||
|
||||
Args:
|
||||
chat_history: 聊天历史记录
|
||||
target: 目标消息内容
|
||||
|
||||
Returns:
|
||||
str: 表达习惯信息字符串
|
||||
"""
|
||||
if not global_config.expression.enable_expression:
|
||||
return ""
|
||||
|
||||
@@ -360,54 +321,65 @@ class DefaultReplyer:
|
||||
expression_habits_block = ""
|
||||
expression_habits_title = ""
|
||||
if style_habits_str.strip():
|
||||
expression_habits_title = "你可以参考以下的语言习惯,当情景合适就使用,但不要生硬使用,以合理的方式结合到你的回复中:"
|
||||
expression_habits_title = (
|
||||
"你可以参考以下的语言习惯,当情景合适就使用,但不要生硬使用,以合理的方式结合到你的回复中:"
|
||||
)
|
||||
expression_habits_block += f"{style_habits_str}\n"
|
||||
if grammar_habits_str.strip():
|
||||
expression_habits_title = "你可以选择下面的句法进行回复,如果情景合适就使用,不要盲目使用,不要生硬使用,以合理的方式使用:"
|
||||
expression_habits_title = (
|
||||
"你可以选择下面的句法进行回复,如果情景合适就使用,不要盲目使用,不要生硬使用,以合理的方式使用:"
|
||||
)
|
||||
expression_habits_block += f"{grammar_habits_str}\n"
|
||||
|
||||
|
||||
if style_habits_str.strip() and grammar_habits_str.strip():
|
||||
expression_habits_title = "你可以参考以下的语言习惯和句法,如果情景合适就使用,不要盲目使用,不要生硬使用,以合理的方式结合到你的回复中:"
|
||||
|
||||
expression_habits_block = f"{expression_habits_title}\n{expression_habits_block}"
|
||||
|
||||
|
||||
return expression_habits_block
|
||||
return f"{expression_habits_title}\n{expression_habits_block}"
|
||||
|
||||
async def build_memory_block(self, chat_history, target):
|
||||
async def build_memory_block(self, chat_history: str, target: str) -> str:
|
||||
"""构建记忆块
|
||||
|
||||
Args:
|
||||
chat_history: 聊天历史记录
|
||||
target: 目标消息内容
|
||||
|
||||
Returns:
|
||||
str: 记忆信息字符串
|
||||
"""
|
||||
if not global_config.memory.enable_memory:
|
||||
return ""
|
||||
|
||||
instant_memory = None
|
||||
|
||||
|
||||
running_memories = await self.memory_activator.activate_memory_with_chat_history(
|
||||
target_message=target, chat_history_prompt=chat_history
|
||||
)
|
||||
|
||||
|
||||
if global_config.memory.enable_instant_memory:
|
||||
asyncio.create_task(self.instant_memory.create_and_store_memory(chat_history))
|
||||
|
||||
instant_memory = await self.instant_memory.get_memory(target)
|
||||
logger.info(f"即时记忆:{instant_memory}")
|
||||
|
||||
|
||||
if not running_memories:
|
||||
return ""
|
||||
|
||||
memory_str = "以下是当前在聊天中,你回忆起的记忆:\n"
|
||||
for running_memory in running_memories:
|
||||
memory_str += f"- {running_memory['content']}\n"
|
||||
|
||||
|
||||
if instant_memory:
|
||||
memory_str += f"- {instant_memory}\n"
|
||||
|
||||
|
||||
return memory_str
|
||||
|
||||
async def build_tool_info(self, chat_history, reply_data: Optional[Dict], enable_tool: bool = True):
|
||||
async def build_tool_info(self, chat_history: str, reply_to: str = "", enable_tool: bool = True) -> str:
|
||||
"""构建工具信息块
|
||||
|
||||
Args:
|
||||
reply_data: 回复数据,包含要回复的消息内容
|
||||
chat_history: 聊天历史
|
||||
chat_history: 聊天历史记录
|
||||
reply_to: 回复对象,格式为 "发送者:消息内容"
|
||||
enable_tool: 是否启用工具调用
|
||||
|
||||
Returns:
|
||||
str: 工具信息字符串
|
||||
@@ -416,10 +388,9 @@ class DefaultReplyer:
|
||||
if not enable_tool:
|
||||
return ""
|
||||
|
||||
if not reply_data:
|
||||
if not reply_to:
|
||||
return ""
|
||||
|
||||
reply_to = reply_data.get("reply_to", "")
|
||||
sender, text = self._parse_reply_target(reply_to)
|
||||
|
||||
if not text:
|
||||
@@ -442,7 +413,7 @@ class DefaultReplyer:
|
||||
|
||||
tool_info_str += "以上是你获取到的实时信息,请在回复时参考这些信息。"
|
||||
logger.info(f"获取到 {len(tool_results)} 个工具结果")
|
||||
|
||||
|
||||
return tool_info_str
|
||||
else:
|
||||
logger.debug("未获取到任何工具结果")
|
||||
@@ -452,7 +423,15 @@ class DefaultReplyer:
|
||||
logger.error(f"工具信息获取失败: {e}")
|
||||
return ""
|
||||
|
||||
def _parse_reply_target(self, target_message: str) -> tuple:
|
||||
def _parse_reply_target(self, target_message: str) -> Tuple[str, str]:
|
||||
"""解析回复目标消息
|
||||
|
||||
Args:
|
||||
target_message: 目标消息,格式为 "发送者:消息内容" 或 "发送者:消息内容"
|
||||
|
||||
Returns:
|
||||
Tuple[str, str]: (发送者名称, 消息内容)
|
||||
"""
|
||||
sender = ""
|
||||
target = ""
|
||||
# 添加None检查,防止NoneType错误
|
||||
@@ -466,14 +445,22 @@ class DefaultReplyer:
|
||||
target = parts[1].strip()
|
||||
return sender, target
|
||||
|
||||
async def build_keywords_reaction_prompt(self, target):
|
||||
async def build_keywords_reaction_prompt(self, target: Optional[str]) -> str:
|
||||
"""构建关键词反应提示
|
||||
|
||||
Args:
|
||||
target: 目标消息内容
|
||||
|
||||
Returns:
|
||||
str: 关键词反应提示字符串
|
||||
"""
|
||||
# 关键词检测与反应
|
||||
keywords_reaction_prompt = ""
|
||||
try:
|
||||
# 添加None检查,防止NoneType错误
|
||||
if target is None:
|
||||
return keywords_reaction_prompt
|
||||
|
||||
|
||||
# 处理关键词规则
|
||||
for rule in global_config.keyword_reaction.keyword_rules:
|
||||
if any(keyword in target for keyword in rule.keywords):
|
||||
@@ -500,15 +487,25 @@ class DefaultReplyer:
|
||||
|
||||
return keywords_reaction_prompt
|
||||
|
||||
async def _time_and_run_task(self, coroutine, name: str):
|
||||
"""一个简单的帮助函数,用于计时和运行异步任务,返回任务名、结果和耗时"""
|
||||
async def _time_and_run_task(self, coroutine, name: str) -> Tuple[str, Any, float]:
|
||||
"""计时并运行异步任务的辅助函数
|
||||
|
||||
Args:
|
||||
coroutine: 要执行的协程
|
||||
name: 任务名称
|
||||
|
||||
Returns:
|
||||
Tuple[str, Any, float]: (任务名称, 任务结果, 执行耗时)
|
||||
"""
|
||||
start_time = time.time()
|
||||
result = await coroutine
|
||||
end_time = time.time()
|
||||
duration = end_time - start_time
|
||||
return name, result, duration
|
||||
|
||||
def build_s4u_chat_history_prompts(self, message_list_before_now: list, target_user_id: str) -> tuple[str, str]:
|
||||
def build_s4u_chat_history_prompts(
|
||||
self, message_list_before_now: List[Dict[str, Any]], target_user_id: str
|
||||
) -> Tuple[str, str]:
|
||||
"""
|
||||
构建 s4u 风格的分离对话 prompt
|
||||
|
||||
@@ -517,7 +514,7 @@ class DefaultReplyer:
|
||||
target_user_id: 目标用户ID(当前对话对象)
|
||||
|
||||
Returns:
|
||||
tuple: (核心对话prompt, 背景对话prompt)
|
||||
Tuple[str, str]: (核心对话prompt, 背景对话prompt)
|
||||
"""
|
||||
core_dialogue_list = []
|
||||
background_dialogue_list = []
|
||||
@@ -536,7 +533,7 @@ class DefaultReplyer:
|
||||
# 其他用户的对话
|
||||
background_dialogue_list.append(msg_dict)
|
||||
except Exception as e:
|
||||
logger.error(f"记录: {msg_dict}, 错误: {e}")
|
||||
logger.error(f"处理消息记录时出错: {msg_dict}, 错误: {e}")
|
||||
|
||||
# 构建背景对话 prompt
|
||||
background_dialogue_prompt = ""
|
||||
@@ -581,8 +578,25 @@ class DefaultReplyer:
|
||||
sender: str,
|
||||
target: str,
|
||||
chat_info: str,
|
||||
):
|
||||
"""构建 mai_think 上下文信息"""
|
||||
) -> Any:
|
||||
"""构建 mai_think 上下文信息
|
||||
|
||||
Args:
|
||||
chat_id: 聊天ID
|
||||
memory_block: 记忆块内容
|
||||
relation_info: 关系信息
|
||||
time_block: 时间块内容
|
||||
chat_target_1: 聊天目标1
|
||||
chat_target_2: 聊天目标2
|
||||
mood_prompt: 情绪提示
|
||||
identity_block: 身份块内容
|
||||
sender: 发送者名称
|
||||
target: 目标消息内容
|
||||
chat_info: 聊天信息
|
||||
|
||||
Returns:
|
||||
Any: mai_think 实例
|
||||
"""
|
||||
mai_think = mai_thinking_manager.get_mai_think(chat_id)
|
||||
mai_think.memory_block = memory_block
|
||||
mai_think.relation_info_block = relation_info
|
||||
@@ -598,21 +612,20 @@ class DefaultReplyer:
|
||||
|
||||
async def build_prompt_reply_context(
|
||||
self,
|
||||
reply_data: Dict[str, Any],
|
||||
reply_to: str,
|
||||
extra_info: str = "",
|
||||
available_actions: Optional[Dict[str, ActionInfo]] = None,
|
||||
enable_timeout: bool = False,
|
||||
enable_tool: bool = True,
|
||||
) -> str: # sourcery skip: merge-else-if-into-elif, remove-redundant-if
|
||||
"""
|
||||
构建回复器上下文
|
||||
|
||||
Args:
|
||||
reply_data: 回复数据
|
||||
replay_data 包含以下字段:
|
||||
structured_info: 结构化信息,一般是工具调用获得的信息
|
||||
reply_to: 回复对象
|
||||
extra_info/extra_info_block: 额外信息
|
||||
reply_to: 回复对象,格式为 "发送者:消息内容"
|
||||
extra_info: 额外信息,用于补充上下文
|
||||
available_actions: 可用动作
|
||||
enable_timeout: 是否启用超时处理
|
||||
enable_tool: 是否启用工具调用
|
||||
|
||||
Returns:
|
||||
str: 构建好的上下文
|
||||
@@ -623,9 +636,7 @@ class DefaultReplyer:
|
||||
chat_id = chat_stream.stream_id
|
||||
person_info_manager = get_person_info_manager()
|
||||
is_group_chat = bool(chat_stream.group_info)
|
||||
reply_to = reply_data.get("reply_to", "none")
|
||||
extra_info_block = reply_data.get("extra_info", "") or reply_data.get("extra_info_block", "")
|
||||
|
||||
|
||||
if global_config.mood.enable_mood:
|
||||
chat_mood = mood_manager.get_mood_by_chat_id(chat_id)
|
||||
mood_prompt = chat_mood.mood_state
|
||||
@@ -633,6 +644,15 @@ class DefaultReplyer:
|
||||
mood_prompt = ""
|
||||
|
||||
sender, target = self._parse_reply_target(reply_to)
|
||||
person_info_manager = get_person_info_manager()
|
||||
person_id = person_info_manager.get_person_id_by_person_name(sender)
|
||||
user_id = person_info_manager.get_value_sync(person_id, "user_id")
|
||||
platform = chat_stream.platform
|
||||
if user_id == global_config.bot.qq_account and platform == global_config.bot.platform:
|
||||
logger.warning("选取了自身作为回复对象,跳过构建prompt")
|
||||
return ""
|
||||
|
||||
target = replace_user_references_sync(target, chat_stream.platform, replace_bot_name=True)
|
||||
|
||||
# 构建action描述 (如果启用planner)
|
||||
action_descriptions = ""
|
||||
@@ -649,21 +669,6 @@ class DefaultReplyer:
|
||||
limit=global_config.chat.max_context_size * 2,
|
||||
)
|
||||
|
||||
message_list_before_now = get_raw_msg_before_timestamp_with_chat(
|
||||
chat_id=chat_id,
|
||||
timestamp=time.time(),
|
||||
limit=global_config.chat.max_context_size,
|
||||
)
|
||||
chat_talking_prompt = build_readable_messages(
|
||||
message_list_before_now,
|
||||
replace_bot_name=True,
|
||||
merge_messages=False,
|
||||
timestamp_mode="normal_no_YMD",
|
||||
read_mark=0.0,
|
||||
truncate=True,
|
||||
show_actions=True,
|
||||
)
|
||||
|
||||
message_list_before_short = get_raw_msg_before_timestamp_with_chat(
|
||||
chat_id=chat_id,
|
||||
timestamp=time.time(),
|
||||
@@ -683,25 +688,21 @@ class DefaultReplyer:
|
||||
self._time_and_run_task(
|
||||
self.build_expression_habits(chat_talking_prompt_short, target), "expression_habits"
|
||||
),
|
||||
self._time_and_run_task(
|
||||
self.build_relation_info(reply_data), "relation_info"
|
||||
),
|
||||
self._time_and_run_task(self.build_relation_info(reply_to), "relation_info"),
|
||||
self._time_and_run_task(self.build_memory_block(chat_talking_prompt_short, target), "memory_block"),
|
||||
self._time_and_run_task(
|
||||
self.build_tool_info(chat_talking_prompt_short, reply_data, enable_tool=enable_tool), "tool_info"
|
||||
),
|
||||
self._time_and_run_task(
|
||||
get_prompt_info(target, threshold=0.38), "prompt_info"
|
||||
self.build_tool_info(chat_talking_prompt_short, reply_to, enable_tool=enable_tool), "tool_info"
|
||||
),
|
||||
self._time_and_run_task(get_prompt_info(target, threshold=0.38), "prompt_info"),
|
||||
)
|
||||
|
||||
# 任务名称中英文映射
|
||||
task_name_mapping = {
|
||||
"expression_habits": "选取表达方式",
|
||||
"relation_info": "感受关系",
|
||||
"relation_info": "感受关系",
|
||||
"memory_block": "回忆",
|
||||
"tool_info": "使用工具",
|
||||
"prompt_info": "获取知识"
|
||||
"prompt_info": "获取知识",
|
||||
}
|
||||
|
||||
# 处理结果
|
||||
@@ -723,8 +724,8 @@ class DefaultReplyer:
|
||||
|
||||
keywords_reaction_prompt = await self.build_keywords_reaction_prompt(target)
|
||||
|
||||
if extra_info_block:
|
||||
extra_info_block = f"以下是你在回复时需要参考的信息,现在请你阅读以下内容,进行决策\n{extra_info_block}\n以上是你在回复时需要参考的信息,现在请你阅读以下内容,进行决策"
|
||||
if extra_info:
|
||||
extra_info_block = f"以下是你在回复时需要参考的信息,现在请你阅读以下内容,进行决策\n{extra_info}\n以上是你在回复时需要参考的信息,现在请你阅读以下内容,进行决策"
|
||||
else:
|
||||
extra_info_block = ""
|
||||
|
||||
@@ -779,116 +780,74 @@ class DefaultReplyer:
|
||||
# 根据sender通过person_info_manager反向查找person_id,再获取user_id
|
||||
person_id = person_info_manager.get_person_id_by_person_name(sender)
|
||||
|
||||
# 根据配置选择使用哪种 prompt 构建模式
|
||||
if global_config.chat.use_s4u_prompt_mode and person_id:
|
||||
# 使用 s4u 对话构建模式:分离当前对话对象和其他对话
|
||||
try:
|
||||
user_id_value = await person_info_manager.get_value(person_id, "user_id")
|
||||
if user_id_value:
|
||||
target_user_id = str(user_id_value)
|
||||
except Exception as e:
|
||||
logger.warning(f"无法从person_id {person_id} 获取user_id: {e}")
|
||||
target_user_id = ""
|
||||
# 使用 s4u 对话构建模式:分离当前对话对象和其他对话
|
||||
try:
|
||||
user_id_value = await person_info_manager.get_value(person_id, "user_id")
|
||||
if user_id_value:
|
||||
target_user_id = str(user_id_value)
|
||||
except Exception as e:
|
||||
logger.warning(f"无法从person_id {person_id} 获取user_id: {e}")
|
||||
target_user_id = ""
|
||||
|
||||
# 构建分离的对话 prompt
|
||||
core_dialogue_prompt, background_dialogue_prompt = self.build_s4u_chat_history_prompts(
|
||||
message_list_before_now_long, target_user_id
|
||||
)
|
||||
|
||||
self.build_mai_think_context(
|
||||
chat_id=chat_id,
|
||||
memory_block=memory_block,
|
||||
relation_info=relation_info,
|
||||
time_block=time_block,
|
||||
chat_target_1=chat_target_1,
|
||||
chat_target_2=chat_target_2,
|
||||
mood_prompt=mood_prompt,
|
||||
identity_block=identity_block,
|
||||
sender=sender,
|
||||
target=target,
|
||||
chat_info=f"""
|
||||
# 构建分离的对话 prompt
|
||||
core_dialogue_prompt, background_dialogue_prompt = self.build_s4u_chat_history_prompts(
|
||||
message_list_before_now_long, target_user_id
|
||||
)
|
||||
|
||||
self.build_mai_think_context(
|
||||
chat_id=chat_id,
|
||||
memory_block=memory_block,
|
||||
relation_info=relation_info,
|
||||
time_block=time_block,
|
||||
chat_target_1=chat_target_1,
|
||||
chat_target_2=chat_target_2,
|
||||
mood_prompt=mood_prompt,
|
||||
identity_block=identity_block,
|
||||
sender=sender,
|
||||
target=target,
|
||||
chat_info=f"""
|
||||
{background_dialogue_prompt}
|
||||
--------------------------------
|
||||
{time_block}
|
||||
这是你和{sender}的对话,你们正在交流中:
|
||||
{core_dialogue_prompt}"""
|
||||
)
|
||||
|
||||
{core_dialogue_prompt}""",
|
||||
)
|
||||
|
||||
# 使用 s4u 风格的模板
|
||||
template_name = "s4u_style_prompt"
|
||||
# 使用 s4u 风格的模板
|
||||
template_name = "s4u_style_prompt"
|
||||
|
||||
return await global_prompt_manager.format_prompt(
|
||||
template_name,
|
||||
expression_habits_block=expression_habits_block,
|
||||
tool_info_block=tool_info,
|
||||
knowledge_prompt=prompt_info,
|
||||
memory_block=memory_block,
|
||||
relation_info_block=relation_info,
|
||||
extra_info_block=extra_info_block,
|
||||
identity=identity_block,
|
||||
action_descriptions=action_descriptions,
|
||||
sender_name=sender,
|
||||
mood_state=mood_prompt,
|
||||
background_dialogue_prompt=background_dialogue_prompt,
|
||||
time_block=time_block,
|
||||
core_dialogue_prompt=core_dialogue_prompt,
|
||||
reply_target_block=reply_target_block,
|
||||
message_txt=target,
|
||||
config_expression_style=global_config.expression.expression_style,
|
||||
keywords_reaction_prompt=keywords_reaction_prompt,
|
||||
moderation_prompt=moderation_prompt_block,
|
||||
)
|
||||
else:
|
||||
self.build_mai_think_context(
|
||||
chat_id=chat_id,
|
||||
memory_block=memory_block,
|
||||
relation_info=relation_info,
|
||||
time_block=time_block,
|
||||
chat_target_1=chat_target_1,
|
||||
chat_target_2=chat_target_2,
|
||||
mood_prompt=mood_prompt,
|
||||
identity_block=identity_block,
|
||||
sender=sender,
|
||||
target=target,
|
||||
chat_info=chat_talking_prompt
|
||||
)
|
||||
|
||||
# 使用原有的模式
|
||||
return await global_prompt_manager.format_prompt(
|
||||
template_name,
|
||||
expression_habits_block=expression_habits_block,
|
||||
chat_target=chat_target_1,
|
||||
chat_info=chat_talking_prompt,
|
||||
memory_block=memory_block,
|
||||
tool_info_block=tool_info,
|
||||
knowledge_prompt=prompt_info,
|
||||
extra_info_block=extra_info_block,
|
||||
relation_info_block=relation_info,
|
||||
time_block=time_block,
|
||||
reply_target_block=reply_target_block,
|
||||
moderation_prompt=moderation_prompt_block,
|
||||
keywords_reaction_prompt=keywords_reaction_prompt,
|
||||
identity=identity_block,
|
||||
target_message=target,
|
||||
sender_name=sender,
|
||||
config_expression_style=global_config.expression.expression_style,
|
||||
action_descriptions=action_descriptions,
|
||||
chat_target_2=chat_target_2,
|
||||
mood_state=mood_prompt,
|
||||
)
|
||||
return await global_prompt_manager.format_prompt(
|
||||
template_name,
|
||||
expression_habits_block=expression_habits_block,
|
||||
tool_info_block=tool_info,
|
||||
knowledge_prompt=prompt_info,
|
||||
memory_block=memory_block,
|
||||
relation_info_block=relation_info,
|
||||
extra_info_block=extra_info_block,
|
||||
identity=identity_block,
|
||||
action_descriptions=action_descriptions,
|
||||
sender_name=sender,
|
||||
mood_state=mood_prompt,
|
||||
background_dialogue_prompt=background_dialogue_prompt,
|
||||
time_block=time_block,
|
||||
core_dialogue_prompt=core_dialogue_prompt,
|
||||
reply_target_block=reply_target_block,
|
||||
message_txt=target,
|
||||
config_expression_style=global_config.expression.expression_style,
|
||||
keywords_reaction_prompt=keywords_reaction_prompt,
|
||||
moderation_prompt=moderation_prompt_block,
|
||||
)
|
||||
|
||||
async def build_prompt_rewrite_context(
|
||||
self,
|
||||
reply_data: Dict[str, Any],
|
||||
raw_reply: str,
|
||||
reason: str,
|
||||
reply_to: str,
|
||||
) -> str:
|
||||
chat_stream = self.chat_stream
|
||||
chat_id = chat_stream.stream_id
|
||||
is_group_chat = bool(chat_stream.group_info)
|
||||
|
||||
reply_to = reply_data.get("reply_to", "none")
|
||||
raw_reply = reply_data.get("raw_reply", "")
|
||||
reason = reply_data.get("reason", "")
|
||||
sender, target = self._parse_reply_target(reply_to)
|
||||
|
||||
# 添加情绪状态获取
|
||||
@@ -915,7 +874,7 @@ class DefaultReplyer:
|
||||
# 并行执行2个构建任务
|
||||
expression_habits_block, relation_info = await asyncio.gather(
|
||||
self.build_expression_habits(chat_talking_prompt_half, target),
|
||||
self.build_relation_info(reply_data),
|
||||
self.build_relation_info(reply_to),
|
||||
)
|
||||
|
||||
keywords_reaction_prompt = await self.build_keywords_reaction_prompt(target)
|
||||
@@ -1018,6 +977,31 @@ class DefaultReplyer:
|
||||
display_message=display_message,
|
||||
)
|
||||
|
||||
async def llm_generate_content(self, prompt: str) -> str:
|
||||
with Timer("LLM生成", {}): # 内部计时器,可选保留
|
||||
# 加权随机选择一个模型配置
|
||||
selected_model_config = self._select_weighted_model_config()
|
||||
model_display_name = selected_model_config.get('model_name') or selected_model_config.get('name', 'N/A')
|
||||
logger.info(
|
||||
f"使用模型生成回复: {model_display_name} (选中概率: {selected_model_config.get('weight', 1.0)})"
|
||||
)
|
||||
|
||||
express_model = LLMRequest(
|
||||
model=selected_model_config,
|
||||
request_type=self.request_type,
|
||||
)
|
||||
|
||||
if global_config.debug.show_prompt:
|
||||
logger.info(f"\n{prompt}\n")
|
||||
else:
|
||||
logger.debug(f"\n{prompt}\n")
|
||||
|
||||
# TODO: 这里的_应该做出替换
|
||||
content, _ = await express_model.generate_response_async(prompt)
|
||||
|
||||
logger.debug(f"replyer生成内容: {content}")
|
||||
return content
|
||||
|
||||
|
||||
def weighted_sample_no_replacement(items, weights, k) -> list:
|
||||
"""
|
||||
@@ -1075,10 +1059,8 @@ async def get_prompt_info(message: str, threshold: float):
|
||||
related_info += found_knowledge_from_lpmm
|
||||
logger.debug(f"获取知识库内容耗时: {(end_time - start_time):.3f}秒")
|
||||
logger.debug(f"获取知识库内容,相关信息:{related_info[:100]}...,信息长度: {len(related_info)}")
|
||||
|
||||
# 格式化知识信息
|
||||
formatted_prompt_info = await global_prompt_manager.format_prompt("knowledge_prompt", prompt_info=related_info)
|
||||
return formatted_prompt_info
|
||||
|
||||
return f"你有以下这些**知识**:\n{related_info}\n请你**记住上面的知识**,之后可能会用到。\n"
|
||||
else:
|
||||
logger.debug("从LPMM知识库获取知识失败,可能是从未导入过知识,返回空知识...")
|
||||
return ""
|
||||
|
||||
@@ -2,7 +2,7 @@ import time # 导入 time 模块以获取当前时间
|
||||
import random
|
||||
import re
|
||||
|
||||
from typing import List, Dict, Any, Tuple, Optional
|
||||
from typing import List, Dict, Any, Tuple, Optional, Callable
|
||||
from rich.traceback import install
|
||||
|
||||
from src.config.config import global_config
|
||||
@@ -10,11 +10,161 @@ from src.common.message_repository import find_messages, count_messages
|
||||
from src.common.database.database_model import ActionRecords
|
||||
from src.common.database.database_model import Images
|
||||
from src.person_info.person_info import PersonInfoManager, get_person_info_manager
|
||||
from src.chat.utils.utils import translate_timestamp_to_human_readable,assign_message_ids
|
||||
from src.chat.utils.utils import translate_timestamp_to_human_readable, assign_message_ids
|
||||
|
||||
install(extra_lines=3)
|
||||
|
||||
|
||||
def replace_user_references_sync(
|
||||
content: str,
|
||||
platform: str,
|
||||
name_resolver: Optional[Callable[[str, str], str]] = None,
|
||||
replace_bot_name: bool = True,
|
||||
) -> str:
|
||||
"""
|
||||
替换内容中的用户引用格式,包括回复<aaa:bbb>和@<aaa:bbb>格式
|
||||
|
||||
Args:
|
||||
content: 要处理的内容字符串
|
||||
platform: 平台标识
|
||||
name_resolver: 名称解析函数,接收(platform, user_id)参数,返回用户名称
|
||||
如果为None,则使用默认的person_info_manager
|
||||
replace_bot_name: 是否将机器人的user_id替换为"机器人昵称(你)"
|
||||
|
||||
Returns:
|
||||
str: 处理后的内容字符串
|
||||
"""
|
||||
if name_resolver is None:
|
||||
person_info_manager = get_person_info_manager()
|
||||
|
||||
def default_resolver(platform: str, user_id: str) -> str:
|
||||
# 检查是否是机器人自己
|
||||
if replace_bot_name and user_id == global_config.bot.qq_account:
|
||||
return f"{global_config.bot.nickname}(你)"
|
||||
person_id = PersonInfoManager.get_person_id(platform, user_id)
|
||||
return person_info_manager.get_value_sync(person_id, "person_name") or user_id # type: ignore
|
||||
|
||||
name_resolver = default_resolver
|
||||
|
||||
# 处理回复<aaa:bbb>格式
|
||||
reply_pattern = r"回复<([^:<>]+):([^:<>]+)>"
|
||||
match = re.search(reply_pattern, content)
|
||||
if match:
|
||||
aaa = match[1]
|
||||
bbb = match[2]
|
||||
try:
|
||||
# 检查是否是机器人自己
|
||||
if replace_bot_name and bbb == global_config.bot.qq_account:
|
||||
reply_person_name = f"{global_config.bot.nickname}(你)"
|
||||
else:
|
||||
reply_person_name = name_resolver(platform, bbb) or aaa
|
||||
content = re.sub(reply_pattern, f"回复 {reply_person_name}", content, count=1)
|
||||
except Exception:
|
||||
# 如果解析失败,使用原始昵称
|
||||
content = re.sub(reply_pattern, f"回复 {aaa}", content, count=1)
|
||||
|
||||
# 处理@<aaa:bbb>格式
|
||||
at_pattern = r"@<([^:<>]+):([^:<>]+)>"
|
||||
at_matches = list(re.finditer(at_pattern, content))
|
||||
if at_matches:
|
||||
new_content = ""
|
||||
last_end = 0
|
||||
for m in at_matches:
|
||||
new_content += content[last_end : m.start()]
|
||||
aaa = m.group(1)
|
||||
bbb = m.group(2)
|
||||
try:
|
||||
# 检查是否是机器人自己
|
||||
if replace_bot_name and bbb == global_config.bot.qq_account:
|
||||
at_person_name = f"{global_config.bot.nickname}(你)"
|
||||
else:
|
||||
at_person_name = name_resolver(platform, bbb) or aaa
|
||||
new_content += f"@{at_person_name}"
|
||||
except Exception:
|
||||
# 如果解析失败,使用原始昵称
|
||||
new_content += f"@{aaa}"
|
||||
last_end = m.end()
|
||||
new_content += content[last_end:]
|
||||
content = new_content
|
||||
|
||||
return content
|
||||
|
||||
|
||||
async def replace_user_references_async(
|
||||
content: str,
|
||||
platform: str,
|
||||
name_resolver: Optional[Callable[[str, str], Any]] = None,
|
||||
replace_bot_name: bool = True,
|
||||
) -> str:
|
||||
"""
|
||||
替换内容中的用户引用格式,包括回复<aaa:bbb>和@<aaa:bbb>格式
|
||||
|
||||
Args:
|
||||
content: 要处理的内容字符串
|
||||
platform: 平台标识
|
||||
name_resolver: 名称解析函数,接收(platform, user_id)参数,返回用户名称
|
||||
如果为None,则使用默认的person_info_manager
|
||||
replace_bot_name: 是否将机器人的user_id替换为"机器人昵称(你)"
|
||||
|
||||
Returns:
|
||||
str: 处理后的内容字符串
|
||||
"""
|
||||
if name_resolver is None:
|
||||
person_info_manager = get_person_info_manager()
|
||||
|
||||
async def default_resolver(platform: str, user_id: str) -> str:
|
||||
# 检查是否是机器人自己
|
||||
if replace_bot_name and user_id == global_config.bot.qq_account:
|
||||
return f"{global_config.bot.nickname}(你)"
|
||||
person_id = PersonInfoManager.get_person_id(platform, user_id)
|
||||
return await person_info_manager.get_value(person_id, "person_name") or user_id # type: ignore
|
||||
|
||||
name_resolver = default_resolver
|
||||
|
||||
# 处理回复<aaa:bbb>格式
|
||||
reply_pattern = r"回复<([^:<>]+):([^:<>]+)>"
|
||||
match = re.search(reply_pattern, content)
|
||||
if match:
|
||||
aaa = match.group(1)
|
||||
bbb = match.group(2)
|
||||
try:
|
||||
# 检查是否是机器人自己
|
||||
if replace_bot_name and bbb == global_config.bot.qq_account:
|
||||
reply_person_name = f"{global_config.bot.nickname}(你)"
|
||||
else:
|
||||
reply_person_name = await name_resolver(platform, bbb) or aaa
|
||||
content = re.sub(reply_pattern, f"回复 {reply_person_name}", content, count=1)
|
||||
except Exception:
|
||||
# 如果解析失败,使用原始昵称
|
||||
content = re.sub(reply_pattern, f"回复 {aaa}", content, count=1)
|
||||
|
||||
# 处理@<aaa:bbb>格式
|
||||
at_pattern = r"@<([^:<>]+):([^:<>]+)>"
|
||||
at_matches = list(re.finditer(at_pattern, content))
|
||||
if at_matches:
|
||||
new_content = ""
|
||||
last_end = 0
|
||||
for m in at_matches:
|
||||
new_content += content[last_end : m.start()]
|
||||
aaa = m.group(1)
|
||||
bbb = m.group(2)
|
||||
try:
|
||||
# 检查是否是机器人自己
|
||||
if replace_bot_name and bbb == global_config.bot.qq_account:
|
||||
at_person_name = f"{global_config.bot.nickname}(你)"
|
||||
else:
|
||||
at_person_name = await name_resolver(platform, bbb) or aaa
|
||||
new_content += f"@{at_person_name}"
|
||||
except Exception:
|
||||
# 如果解析失败,使用原始昵称
|
||||
new_content += f"@{aaa}"
|
||||
last_end = m.end()
|
||||
new_content += content[last_end:]
|
||||
content = new_content
|
||||
|
||||
return content
|
||||
|
||||
|
||||
def get_raw_msg_by_timestamp(
|
||||
timestamp_start: float, timestamp_end: float, limit: int = 0, limit_mode: str = "latest"
|
||||
) -> List[Dict[str, Any]]:
|
||||
@@ -374,33 +524,8 @@ def _build_readable_messages_internal(
|
||||
else:
|
||||
person_name = "某人"
|
||||
|
||||
# 检查是否有 回复<aaa:bbb> 字段
|
||||
reply_pattern = r"回复<([^:<>]+):([^:<>]+)>"
|
||||
match = re.search(reply_pattern, content)
|
||||
if match:
|
||||
aaa: str = match[1]
|
||||
bbb: str = match[2]
|
||||
reply_person_id = PersonInfoManager.get_person_id(platform, bbb)
|
||||
reply_person_name = person_info_manager.get_value_sync(reply_person_id, "person_name") or aaa
|
||||
# 在内容前加上回复信息
|
||||
content = re.sub(reply_pattern, lambda m, name=reply_person_name: f"回复 {name}", content, count=1)
|
||||
|
||||
# 检查是否有 @<aaa:bbb> 字段 @<{member_info.get('nickname')}:{member_info.get('user_id')}>
|
||||
at_pattern = r"@<([^:<>]+):([^:<>]+)>"
|
||||
at_matches = list(re.finditer(at_pattern, content))
|
||||
if at_matches:
|
||||
new_content = ""
|
||||
last_end = 0
|
||||
for m in at_matches:
|
||||
new_content += content[last_end : m.start()]
|
||||
aaa = m.group(1)
|
||||
bbb = m.group(2)
|
||||
at_person_id = PersonInfoManager.get_person_id(platform, bbb)
|
||||
at_person_name = person_info_manager.get_value_sync(at_person_id, "person_name") or aaa
|
||||
new_content += f"@{at_person_name}"
|
||||
last_end = m.end()
|
||||
new_content += content[last_end:]
|
||||
content = new_content
|
||||
# 使用独立函数处理用户引用格式
|
||||
content = replace_user_references_sync(content, platform, replace_bot_name=replace_bot_name)
|
||||
|
||||
target_str = "这是QQ的一个功能,用于提及某人,但没那么明显"
|
||||
if target_str in content and random.random() < 0.6:
|
||||
@@ -654,6 +779,7 @@ async def build_readable_messages_with_list(
|
||||
|
||||
return formatted_string, details_list
|
||||
|
||||
|
||||
def build_readable_messages_with_id(
|
||||
messages: List[Dict[str, Any]],
|
||||
replace_bot_name: bool = True,
|
||||
@@ -669,9 +795,9 @@ def build_readable_messages_with_id(
|
||||
允许通过参数控制格式化行为。
|
||||
"""
|
||||
message_id_list = assign_message_ids(messages)
|
||||
|
||||
|
||||
formatted_string = build_readable_messages(
|
||||
messages = messages,
|
||||
messages=messages,
|
||||
replace_bot_name=replace_bot_name,
|
||||
merge_messages=merge_messages,
|
||||
timestamp_mode=timestamp_mode,
|
||||
@@ -682,10 +808,7 @@ def build_readable_messages_with_id(
|
||||
message_id_list=message_id_list,
|
||||
)
|
||||
|
||||
|
||||
|
||||
|
||||
return formatted_string , message_id_list
|
||||
return formatted_string, message_id_list
|
||||
|
||||
|
||||
def build_readable_messages(
|
||||
@@ -770,7 +893,13 @@ def build_readable_messages(
|
||||
if read_mark <= 0:
|
||||
# 没有有效的 read_mark,直接格式化所有消息
|
||||
formatted_string, _, pic_id_mapping, _ = _build_readable_messages_internal(
|
||||
copy_messages, replace_bot_name, merge_messages, timestamp_mode, truncate, show_pic=show_pic, message_id_list=message_id_list
|
||||
copy_messages,
|
||||
replace_bot_name,
|
||||
merge_messages,
|
||||
timestamp_mode,
|
||||
truncate,
|
||||
show_pic=show_pic,
|
||||
message_id_list=message_id_list,
|
||||
)
|
||||
|
||||
# 生成图片映射信息并添加到最前面
|
||||
@@ -893,7 +1022,7 @@ async def build_anonymous_messages(messages: List[Dict[str, Any]]) -> str:
|
||||
|
||||
for msg in messages:
|
||||
try:
|
||||
platform = msg.get("chat_info_platform")
|
||||
platform: str = msg.get("chat_info_platform") # type: ignore
|
||||
user_id = msg.get("user_id")
|
||||
_timestamp = msg.get("time")
|
||||
content: str = ""
|
||||
@@ -916,38 +1045,14 @@ async def build_anonymous_messages(messages: List[Dict[str, Any]]) -> str:
|
||||
anon_name = get_anon_name(platform, user_id)
|
||||
# print(f"anon_name:{anon_name}")
|
||||
|
||||
# 处理 回复<aaa:bbb>
|
||||
reply_pattern = r"回复<([^:<>]+):([^:<>]+)>"
|
||||
match = re.search(reply_pattern, content)
|
||||
if match:
|
||||
# print(f"发现回复match:{match}")
|
||||
bbb = match.group(2)
|
||||
# 使用独立函数处理用户引用格式,传入自定义的匿名名称解析器
|
||||
def anon_name_resolver(platform: str, user_id: str) -> str:
|
||||
try:
|
||||
anon_reply = get_anon_name(platform, bbb)
|
||||
# print(f"anon_reply:{anon_reply}")
|
||||
return get_anon_name(platform, user_id)
|
||||
except Exception:
|
||||
anon_reply = "?"
|
||||
content = re.sub(reply_pattern, f"回复 {anon_reply}", content, count=1)
|
||||
return "?"
|
||||
|
||||
# 处理 @<aaa:bbb>,无嵌套def
|
||||
at_pattern = r"@<([^:<>]+):([^:<>]+)>"
|
||||
at_matches = list(re.finditer(at_pattern, content))
|
||||
if at_matches:
|
||||
# print(f"发现@match:{at_matches}")
|
||||
new_content = ""
|
||||
last_end = 0
|
||||
for m in at_matches:
|
||||
new_content += content[last_end : m.start()]
|
||||
bbb = m.group(2)
|
||||
try:
|
||||
anon_at = get_anon_name(platform, bbb)
|
||||
# print(f"anon_at:{anon_at}")
|
||||
except Exception:
|
||||
anon_at = "?"
|
||||
new_content += f"@{anon_at}"
|
||||
last_end = m.end()
|
||||
new_content += content[last_end:]
|
||||
content = new_content
|
||||
content = replace_user_references_sync(content, platform, anon_name_resolver, replace_bot_name=False)
|
||||
|
||||
header = f"{anon_name}说 "
|
||||
output_lines.append(header)
|
||||
|
||||
@@ -37,7 +37,7 @@ class ImageManager:
|
||||
self._ensure_image_dir()
|
||||
|
||||
self._initialized = True
|
||||
self._llm = LLMRequest(model=global_config.model.vlm, temperature=0.4, max_tokens=300, request_type="image")
|
||||
self.vlm = LLMRequest(model=global_config.model.vlm, temperature=0.4, max_tokens=300, request_type="image")
|
||||
|
||||
try:
|
||||
db.connect(reuse_if_open=True)
|
||||
@@ -94,7 +94,7 @@ class ImageManager:
|
||||
logger.error(f"保存描述到数据库失败 (Peewee): {str(e)}")
|
||||
|
||||
async def get_emoji_description(self, image_base64: str) -> str:
|
||||
"""获取表情包描述,使用二步走识别并带缓存优化"""
|
||||
"""获取表情包描述,优先使用Emoji表中的缓存数据"""
|
||||
try:
|
||||
# 计算图片哈希
|
||||
# 确保base64字符串只包含ASCII字符
|
||||
@@ -104,9 +104,21 @@ class ImageManager:
|
||||
image_hash = hashlib.md5(image_bytes).hexdigest()
|
||||
image_format = Image.open(io.BytesIO(image_bytes)).format.lower() # type: ignore
|
||||
|
||||
# 查询缓存的描述
|
||||
# 优先使用EmojiManager查询已注册表情包的描述
|
||||
try:
|
||||
from src.chat.emoji_system.emoji_manager import get_emoji_manager
|
||||
emoji_manager = get_emoji_manager()
|
||||
cached_emoji_description = await emoji_manager.get_emoji_description_by_hash(image_hash)
|
||||
if cached_emoji_description:
|
||||
logger.info(f"[缓存命中] 使用已注册表情包描述: {cached_emoji_description[:50]}...")
|
||||
return cached_emoji_description
|
||||
except Exception as e:
|
||||
logger.debug(f"查询EmojiManager时出错: {e}")
|
||||
|
||||
# 查询ImageDescriptions表的缓存描述
|
||||
cached_description = self._get_description_from_db(image_hash, "emoji")
|
||||
if cached_description:
|
||||
logger.info(f"[缓存命中] 使用ImageDescriptions表中的描述: {cached_description[:50]}...")
|
||||
return f"[表情包:{cached_description}]"
|
||||
|
||||
# === 二步走识别流程 ===
|
||||
@@ -118,10 +130,10 @@ class ImageManager:
|
||||
logger.warning("GIF转换失败,无法获取描述")
|
||||
return "[表情包(GIF处理失败)]"
|
||||
vlm_prompt = "这是一个动态图表情包,每一张图代表了动态图的某一帧,黑色背景代表透明,描述一下表情包表达的情感和内容,描述细节,从互联网梗,meme的角度去分析"
|
||||
detailed_description, _ = await self._llm.generate_response_for_image(vlm_prompt, image_base64_processed, "jpg")
|
||||
detailed_description, _ = await self.vlm.generate_response_for_image(vlm_prompt, image_base64_processed, "jpg")
|
||||
else:
|
||||
vlm_prompt = "这是一个表情包,请详细描述一下表情包所表达的情感和内容,描述细节,从互联网梗,meme的角度去分析"
|
||||
detailed_description, _ = await self._llm.generate_response_for_image(vlm_prompt, image_base64, image_format)
|
||||
detailed_description, _ = await self.vlm.generate_response_for_image(vlm_prompt, image_base64, image_format)
|
||||
|
||||
if detailed_description is None:
|
||||
logger.warning("VLM未能生成表情包详细描述")
|
||||
@@ -158,7 +170,7 @@ class ImageManager:
|
||||
if len(emotions) > 1 and emotions[1] != emotions[0]:
|
||||
final_emotion = f"{emotions[0]},{emotions[1]}"
|
||||
|
||||
logger.info(f"[二步走识别] 详细描述: {detailed_description[:50]}... -> 情感标签: {final_emotion}")
|
||||
logger.info(f"[emoji识别] 详细描述: {detailed_description[:50]}... -> 情感标签: {final_emotion}")
|
||||
|
||||
# 再次检查缓存,防止并发写入时重复生成
|
||||
cached_description = self._get_description_from_db(image_hash, "emoji")
|
||||
@@ -201,13 +213,13 @@ class ImageManager:
|
||||
self._save_description_to_db(image_hash, final_emotion, "emoji")
|
||||
|
||||
return f"[表情包:{final_emotion}]"
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取表情包描述失败: {str(e)}")
|
||||
return "[表情包]"
|
||||
return "[表情包(处理失败)]"
|
||||
|
||||
async def get_image_description(self, image_base64: str) -> str:
|
||||
"""获取普通图片描述,带查重和保存功能"""
|
||||
"""获取普通图片描述,优先使用Images表中的缓存数据"""
|
||||
try:
|
||||
# 计算图片哈希
|
||||
if isinstance(image_base64, str):
|
||||
@@ -215,7 +227,7 @@ class ImageManager:
|
||||
image_bytes = base64.b64decode(image_base64)
|
||||
image_hash = hashlib.md5(image_bytes).hexdigest()
|
||||
|
||||
# 检查图片是否已存在
|
||||
# 优先检查Images表中是否已有完整的描述
|
||||
existing_image = Images.get_or_none(Images.emoji_hash == image_hash)
|
||||
if existing_image:
|
||||
# 更新计数
|
||||
@@ -227,18 +239,20 @@ class ImageManager:
|
||||
|
||||
# 如果已有描述,直接返回
|
||||
if existing_image.description:
|
||||
logger.debug(f"[缓存命中] 使用Images表中的图片描述: {existing_image.description[:50]}...")
|
||||
return f"[图片:{existing_image.description}]"
|
||||
|
||||
# 查询缓存的描述
|
||||
# 查询ImageDescriptions表的缓存描述
|
||||
cached_description = self._get_description_from_db(image_hash, "image")
|
||||
if cached_description:
|
||||
logger.debug(f"图片描述缓存中 {cached_description}")
|
||||
logger.debug(f"[缓存命中] 使用ImageDescriptions表中的描述: {cached_description[:50]}...")
|
||||
return f"[图片:{cached_description}]"
|
||||
|
||||
# 调用AI获取描述
|
||||
image_format = Image.open(io.BytesIO(image_bytes)).format.lower() # type: ignore
|
||||
prompt = global_config.custom_prompt.image_prompt
|
||||
description, _ = await self._llm.generate_response_for_image(prompt, image_base64, image_format)
|
||||
logger.info(f"[VLM调用] 为图片生成新描述 (Hash: {image_hash[:8]}...)")
|
||||
description, _ = await self.vlm.generate_response_for_image(prompt, image_base64, image_format)
|
||||
|
||||
if description is None:
|
||||
logger.warning("AI未能生成图片描述")
|
||||
@@ -266,6 +280,7 @@ class ImageManager:
|
||||
if not hasattr(existing_image, "vlm_processed") or existing_image.vlm_processed is None:
|
||||
existing_image.vlm_processed = True
|
||||
existing_image.save()
|
||||
logger.debug(f"[数据库] 更新已有图片记录: {image_hash[:8]}...")
|
||||
else:
|
||||
Images.create(
|
||||
image_id=str(uuid.uuid4()),
|
||||
@@ -277,16 +292,18 @@ class ImageManager:
|
||||
vlm_processed=True,
|
||||
count=1,
|
||||
)
|
||||
logger.debug(f"[数据库] 创建新图片记录: {image_hash[:8]}...")
|
||||
except Exception as e:
|
||||
logger.error(f"保存图片文件或元数据失败: {str(e)}")
|
||||
|
||||
# 保存描述到ImageDescriptions表
|
||||
# 保存描述到ImageDescriptions表作为备用缓存
|
||||
self._save_description_to_db(image_hash, description, "image")
|
||||
|
||||
logger.info(f"[VLM完成] 图片描述生成: {description[:50]}...")
|
||||
return f"[图片:{description}]"
|
||||
except Exception as e:
|
||||
logger.error(f"获取图片描述失败: {str(e)}")
|
||||
return "[图片]"
|
||||
return "[图片(处理失败)]"
|
||||
|
||||
@staticmethod
|
||||
def transform_gif(gif_base64: str, similarity_threshold: float = 1000.0, max_frames: int = 15) -> Optional[str]:
|
||||
@@ -502,12 +519,28 @@ class ImageManager:
|
||||
image_bytes = base64.b64decode(image_base64)
|
||||
image_hash = hashlib.md5(image_bytes).hexdigest()
|
||||
|
||||
# 先检查缓存的描述
|
||||
# 获取当前图片记录
|
||||
image = Images.get(Images.image_id == image_id)
|
||||
|
||||
# 优先检查是否已有其他相同哈希的图片记录包含描述
|
||||
existing_with_description = Images.get_or_none(
|
||||
(Images.emoji_hash == image_hash) &
|
||||
(Images.description.is_null(False)) &
|
||||
(Images.description != "")
|
||||
)
|
||||
if existing_with_description and existing_with_description.id != image.id:
|
||||
logger.debug(f"[缓存复用] 从其他相同图片记录复用描述: {existing_with_description.description[:50]}...")
|
||||
image.description = existing_with_description.description
|
||||
image.vlm_processed = True
|
||||
image.save()
|
||||
# 同时保存到ImageDescriptions表作为备用缓存
|
||||
self._save_description_to_db(image_hash, existing_with_description.description, "image")
|
||||
return
|
||||
|
||||
# 检查ImageDescriptions表的缓存描述
|
||||
cached_description = self._get_description_from_db(image_hash, "image")
|
||||
if cached_description:
|
||||
logger.debug(f"VLM处理时发现缓存描述: {cached_description}")
|
||||
# 更新数据库
|
||||
image = Images.get(Images.image_id == image_id)
|
||||
logger.debug(f"[缓存复用] 从ImageDescriptions表复用描述: {cached_description[:50]}...")
|
||||
image.description = cached_description
|
||||
image.vlm_processed = True
|
||||
image.save()
|
||||
@@ -520,7 +553,8 @@ class ImageManager:
|
||||
prompt = global_config.custom_prompt.image_prompt
|
||||
|
||||
# 获取VLM描述
|
||||
description, _ = await self._llm.generate_response_for_image(prompt, image_base64, image_format)
|
||||
logger.info(f"[VLM异步调用] 为图片生成描述 (ID: {image_id}, Hash: {image_hash[:8]}...)")
|
||||
description, _ = await self.vlm.generate_response_for_image(prompt, image_base64, image_format)
|
||||
|
||||
if description is None:
|
||||
logger.warning("VLM未能生成图片描述")
|
||||
@@ -533,14 +567,15 @@ class ImageManager:
|
||||
description = cached_description
|
||||
|
||||
# 更新数据库
|
||||
image = Images.get(Images.image_id == image_id)
|
||||
image.description = description
|
||||
image.vlm_processed = True
|
||||
image.save()
|
||||
|
||||
# 保存描述到ImageDescriptions表
|
||||
# 保存描述到ImageDescriptions表作为备用缓存
|
||||
self._save_description_to_db(image_hash, description, "image")
|
||||
|
||||
logger.info(f"[VLM异步完成] 图片描述生成: {description[:50]}...")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"VLM处理图片失败: {str(e)}")
|
||||
|
||||
|
||||
@@ -28,7 +28,7 @@ class ClassicalWillingManager(BaseWillingManager):
|
||||
|
||||
# print(f"[{chat_id}] 回复意愿: {current_willing}")
|
||||
|
||||
interested_rate = willing_info.interested_rate * global_config.normal_chat.response_interested_rate_amplifier
|
||||
interested_rate = willing_info.interested_rate
|
||||
|
||||
# print(f"[{chat_id}] 兴趣值: {interested_rate}")
|
||||
|
||||
@@ -36,20 +36,18 @@ class ClassicalWillingManager(BaseWillingManager):
|
||||
current_willing += interested_rate - 0.2
|
||||
|
||||
if willing_info.is_mentioned_bot and global_config.chat.mentioned_bot_inevitable_reply and current_willing < 2:
|
||||
current_willing += 1 if current_willing < 1.0 else 0.05
|
||||
current_willing += 1 if current_willing < 1.0 else 0.2
|
||||
|
||||
self.chat_reply_willing[chat_id] = min(current_willing, 1.0)
|
||||
|
||||
reply_probability = min(max((current_willing - 0.5), 0.01) * 2, 1)
|
||||
reply_probability = min(max((current_willing - 0.5), 0.01) * 2, 1.5)
|
||||
|
||||
# print(f"[{chat_id}] 回复概率: {reply_probability}")
|
||||
|
||||
return reply_probability
|
||||
|
||||
async def before_generate_reply_handle(self, message_id):
|
||||
chat_id = self.ongoing_messages[message_id].chat_id
|
||||
current_willing = self.chat_reply_willing.get(chat_id, 0)
|
||||
self.chat_reply_willing[chat_id] = max(0.0, current_willing - 1.8)
|
||||
pass
|
||||
|
||||
async def after_generate_reply_handle(self, message_id):
|
||||
if message_id not in self.ongoing_messages:
|
||||
@@ -58,7 +56,7 @@ class ClassicalWillingManager(BaseWillingManager):
|
||||
chat_id = self.ongoing_messages[message_id].chat_id
|
||||
current_willing = self.chat_reply_willing.get(chat_id, 0)
|
||||
if current_willing < 1:
|
||||
self.chat_reply_willing[chat_id] = min(1.0, current_willing + 0.4)
|
||||
self.chat_reply_willing[chat_id] = min(1.0, current_willing + 0.3)
|
||||
|
||||
async def not_reply_handle(self, message_id):
|
||||
return await super().not_reply_handle(message_id)
|
||||
|
||||
@@ -390,7 +390,7 @@ MODULE_COLORS = {
|
||||
"tts_action": "\033[38;5;58m", # 深黄色
|
||||
"doubao_pic_plugin": "\033[38;5;64m", # 深绿色
|
||||
# Action组件
|
||||
"no_reply_action": "\033[38;5;196m", # 亮红色,更显眼
|
||||
"no_reply_action": "\033[38;5;214m", # 亮橙色,显眼但不像警告
|
||||
"reply_action": "\033[38;5;46m", # 亮绿色
|
||||
"base_action": "\033[38;5;250m", # 浅灰色
|
||||
# 数据库和消息
|
||||
|
||||
@@ -69,6 +69,8 @@ class ChatConfig(ConfigBase):
|
||||
|
||||
max_context_size: int = 18
|
||||
"""上下文长度"""
|
||||
|
||||
willing_amplifier: float = 1.0
|
||||
|
||||
replyer_random_probability: float = 0.5
|
||||
"""
|
||||
@@ -76,15 +78,12 @@ class ChatConfig(ConfigBase):
|
||||
选择普通模型的概率为 1 - reasoning_normal_model_probability
|
||||
"""
|
||||
|
||||
thinking_timeout: int = 30
|
||||
thinking_timeout: int = 40
|
||||
"""麦麦最长思考规划时间,超过这个时间的思考会放弃(往往是api反应太慢)"""
|
||||
|
||||
talk_frequency: float = 1
|
||||
"""回复频率阈值"""
|
||||
|
||||
use_s4u_prompt_mode: bool = False
|
||||
"""是否使用 s4u 对话构建模式,该模式会分开处理当前对话对象和其他所有对话的内容进行 prompt 构建"""
|
||||
|
||||
mentioned_bot_inevitable_reply: bool = False
|
||||
"""提及 bot 必然回复"""
|
||||
|
||||
@@ -274,12 +273,6 @@ class NormalChatConfig(ConfigBase):
|
||||
willing_mode: str = "classical"
|
||||
"""意愿模式"""
|
||||
|
||||
response_interested_rate_amplifier: float = 1.0
|
||||
"""回复兴趣度放大系数"""
|
||||
|
||||
|
||||
|
||||
|
||||
@dataclass
|
||||
class ExpressionConfig(ConfigBase):
|
||||
"""表达配置类"""
|
||||
@@ -307,11 +300,8 @@ class ExpressionConfig(ConfigBase):
|
||||
class ToolConfig(ConfigBase):
|
||||
"""工具配置类"""
|
||||
|
||||
enable_in_normal_chat: bool = False
|
||||
"""是否在普通聊天中启用工具"""
|
||||
|
||||
enable_in_focus_chat: bool = True
|
||||
"""是否在专注聊天中启用工具"""
|
||||
enable_tool: bool = False
|
||||
"""是否在聊天中启用工具"""
|
||||
|
||||
@dataclass
|
||||
class VoiceConfig(ConfigBase):
|
||||
|
||||
@@ -273,15 +273,19 @@ class Individuality:
|
||||
prompt=prompt,
|
||||
)
|
||||
|
||||
if response.strip():
|
||||
if response and response.strip():
|
||||
personality_parts.append(response.strip())
|
||||
logger.info(f"精简人格侧面: {response.strip()}")
|
||||
else:
|
||||
logger.error(f"使用LLM压缩人设时出错: {response}")
|
||||
# 压缩失败时使用原始内容
|
||||
if personality_side:
|
||||
personality_parts.append(personality_side)
|
||||
|
||||
if personality_parts:
|
||||
personality_result = "。".join(personality_parts)
|
||||
else:
|
||||
personality_result = personality_core
|
||||
personality_result = personality_core or "友好活泼"
|
||||
else:
|
||||
personality_result = personality_core
|
||||
if personality_side:
|
||||
@@ -308,13 +312,14 @@ class Individuality:
|
||||
prompt=prompt,
|
||||
)
|
||||
|
||||
if response.strip():
|
||||
if response and response.strip():
|
||||
identity_result = response.strip()
|
||||
logger.info(f"精简身份: {identity_result}")
|
||||
else:
|
||||
logger.error(f"使用LLM压缩身份时出错: {response}")
|
||||
identity_result = identity
|
||||
else:
|
||||
identity_result = "。".join(identity)
|
||||
identity_result = identity
|
||||
|
||||
return identity_result
|
||||
|
||||
|
||||
@@ -47,11 +47,35 @@ async def _calculate_interest(message: MessageRecv) -> Tuple[float, bool]:
|
||||
logger.debug(f"记忆激活率: {interested_rate:.2f}")
|
||||
|
||||
text_len = len(message.processed_plain_text)
|
||||
# 根据文本长度调整兴趣度,长度越大兴趣度越高,但增长率递减,最低0.01,最高0.05
|
||||
# 采用对数函数实现递减增长
|
||||
|
||||
base_interest = 0.01 + (0.05 - 0.01) * (math.log10(text_len + 1) / math.log10(1000 + 1))
|
||||
base_interest = min(max(base_interest, 0.01), 0.05)
|
||||
# 根据文本长度分布调整兴趣度,采用分段函数实现更精确的兴趣度计算
|
||||
# 基于实际分布:0-5字符(26.57%), 6-10字符(27.18%), 11-20字符(22.76%), 21-30字符(10.33%), 31+字符(13.86%)
|
||||
|
||||
if text_len == 0:
|
||||
base_interest = 0.01 # 空消息最低兴趣度
|
||||
elif text_len <= 5:
|
||||
# 1-5字符:线性增长 0.01 -> 0.03
|
||||
base_interest = 0.01 + (text_len - 1) * (0.03 - 0.01) / 4
|
||||
elif text_len <= 10:
|
||||
# 6-10字符:线性增长 0.03 -> 0.06
|
||||
base_interest = 0.03 + (text_len - 5) * (0.06 - 0.03) / 5
|
||||
elif text_len <= 20:
|
||||
# 11-20字符:线性增长 0.06 -> 0.12
|
||||
base_interest = 0.06 + (text_len - 10) * (0.12 - 0.06) / 10
|
||||
elif text_len <= 30:
|
||||
# 21-30字符:线性增长 0.12 -> 0.18
|
||||
base_interest = 0.12 + (text_len - 20) * (0.18 - 0.12) / 10
|
||||
elif text_len <= 50:
|
||||
# 31-50字符:线性增长 0.18 -> 0.22
|
||||
base_interest = 0.18 + (text_len - 30) * (0.22 - 0.18) / 20
|
||||
elif text_len <= 100:
|
||||
# 51-100字符:线性增长 0.22 -> 0.26
|
||||
base_interest = 0.22 + (text_len - 50) * (0.26 - 0.22) / 50
|
||||
else:
|
||||
# 100+字符:对数增长 0.26 -> 0.3,增长率递减
|
||||
base_interest = 0.26 + (0.3 - 0.26) * (math.log10(text_len - 99) / math.log10(901)) # 1000-99=901
|
||||
|
||||
# 确保在范围内
|
||||
base_interest = min(max(base_interest, 0.01), 0.3)
|
||||
|
||||
interested_rate += base_interest
|
||||
|
||||
|
||||
@@ -78,7 +78,7 @@ class ChatMood:
|
||||
if interested_rate <= 0:
|
||||
interest_multiplier = 0
|
||||
else:
|
||||
interest_multiplier = 3 * math.pow(interested_rate, 0.25)
|
||||
interest_multiplier = 2 * math.pow(interested_rate, 0.25)
|
||||
|
||||
logger.debug(
|
||||
f"base_probability: {base_probability}, time_multiplier: {time_multiplier}, interest_multiplier: {interest_multiplier}"
|
||||
|
||||
@@ -139,7 +139,7 @@ class RelationshipManager:
|
||||
请用json格式输出,引起了你的兴趣,或者有什么需要你记忆的点。
|
||||
并为每个点赋予1-10的权重,权重越高,表示越重要。
|
||||
格式如下:
|
||||
{{
|
||||
[
|
||||
{{
|
||||
"point": "{person_name}想让我记住他的生日,我回答确认了,他的生日是11月23日",
|
||||
"weight": 10
|
||||
@@ -156,13 +156,10 @@ class RelationshipManager:
|
||||
"point": "{person_name}喜欢吃辣,具体来说,没有辣的食物ta都不喜欢吃,可能是因为ta是湖南人。",
|
||||
"weight": 7
|
||||
}}
|
||||
}}
|
||||
]
|
||||
|
||||
如果没有,就输出none,或points为空:
|
||||
{{
|
||||
"point": "none",
|
||||
"weight": 0
|
||||
}}
|
||||
如果没有,就输出none,或返回空数组:
|
||||
[]
|
||||
"""
|
||||
|
||||
# 调用LLM生成印象
|
||||
@@ -184,17 +181,25 @@ class RelationshipManager:
|
||||
try:
|
||||
points = repair_json(points)
|
||||
points_data = json.loads(points)
|
||||
if points_data == "none" or not points_data or points_data.get("point") == "none":
|
||||
|
||||
# 只处理正确的格式,错误格式直接跳过
|
||||
if points_data == "none" or not points_data:
|
||||
points_list = []
|
||||
elif isinstance(points_data, str) and points_data.lower() == "none":
|
||||
points_list = []
|
||||
elif isinstance(points_data, list):
|
||||
# 正确格式:数组格式 [{"point": "...", "weight": 10}, ...]
|
||||
if not points_data: # 空数组
|
||||
points_list = []
|
||||
else:
|
||||
points_list = [(item["point"], float(item["weight"]), current_time) for item in points_data]
|
||||
else:
|
||||
# logger.info(f"points_data: {points_data}")
|
||||
if isinstance(points_data, dict) and "points" in points_data:
|
||||
points_data = points_data["points"]
|
||||
if not isinstance(points_data, list):
|
||||
points_data = [points_data]
|
||||
# 添加可读时间到每个point
|
||||
points_list = [(item["point"], float(item["weight"]), current_time) for item in points_data]
|
||||
# 错误格式,直接跳过不解析
|
||||
logger.warning(f"LLM返回了错误的JSON格式,跳过解析: {type(points_data)}, 内容: {points_data}")
|
||||
points_list = []
|
||||
|
||||
# 权重过滤逻辑
|
||||
if points_list:
|
||||
original_points_list = list(points_list)
|
||||
points_list.clear()
|
||||
discarded_count = 0
|
||||
|
||||
@@ -9,6 +9,7 @@ from .base import (
|
||||
BasePlugin,
|
||||
BaseAction,
|
||||
BaseCommand,
|
||||
BaseTool,
|
||||
ConfigField,
|
||||
ComponentType,
|
||||
ActionActivationType,
|
||||
@@ -34,6 +35,7 @@ from .utils import (
|
||||
|
||||
from .apis import (
|
||||
chat_api,
|
||||
tool_api,
|
||||
component_manage_api,
|
||||
config_api,
|
||||
database_api,
|
||||
@@ -44,17 +46,17 @@ from .apis import (
|
||||
person_api,
|
||||
plugin_manage_api,
|
||||
send_api,
|
||||
utils_api,
|
||||
register_plugin,
|
||||
get_logger,
|
||||
)
|
||||
|
||||
|
||||
__version__ = "1.0.0"
|
||||
__version__ = "2.0.0"
|
||||
|
||||
__all__ = [
|
||||
# API 模块
|
||||
"chat_api",
|
||||
"tool_api",
|
||||
"component_manage_api",
|
||||
"config_api",
|
||||
"database_api",
|
||||
@@ -65,13 +67,13 @@ __all__ = [
|
||||
"person_api",
|
||||
"plugin_manage_api",
|
||||
"send_api",
|
||||
"utils_api",
|
||||
"register_plugin",
|
||||
"get_logger",
|
||||
# 基础类
|
||||
"BasePlugin",
|
||||
"BaseAction",
|
||||
"BaseCommand",
|
||||
"BaseTool",
|
||||
"BaseEventHandler",
|
||||
# 类型定义
|
||||
"ComponentType",
|
||||
|
||||
@@ -17,7 +17,7 @@ from src.plugin_system.apis import (
|
||||
person_api,
|
||||
plugin_manage_api,
|
||||
send_api,
|
||||
utils_api,
|
||||
tool_api,
|
||||
)
|
||||
from .logging_api import get_logger
|
||||
from .plugin_register_api import register_plugin
|
||||
@@ -35,7 +35,7 @@ __all__ = [
|
||||
"person_api",
|
||||
"plugin_manage_api",
|
||||
"send_api",
|
||||
"utils_api",
|
||||
"get_logger",
|
||||
"register_plugin",
|
||||
"tool_api",
|
||||
]
|
||||
|
||||
@@ -32,6 +32,7 @@ class ChatManager:
|
||||
|
||||
@staticmethod
|
||||
def get_all_streams(platform: Optional[str] | SpecialTypes = "qq") -> List[ChatStream]:
|
||||
# sourcery skip: for-append-to-extend
|
||||
"""获取所有聊天流
|
||||
|
||||
Args:
|
||||
@@ -57,6 +58,7 @@ class ChatManager:
|
||||
|
||||
@staticmethod
|
||||
def get_group_streams(platform: Optional[str] | SpecialTypes = "qq") -> List[ChatStream]:
|
||||
# sourcery skip: for-append-to-extend
|
||||
"""获取所有群聊聊天流
|
||||
|
||||
Args:
|
||||
@@ -79,6 +81,7 @@ class ChatManager:
|
||||
|
||||
@staticmethod
|
||||
def get_private_streams(platform: Optional[str] | SpecialTypes = "qq") -> List[ChatStream]:
|
||||
# sourcery skip: for-append-to-extend
|
||||
"""获取所有私聊聊天流
|
||||
|
||||
Args:
|
||||
@@ -105,7 +108,7 @@ class ChatManager:
|
||||
@staticmethod
|
||||
def get_group_stream_by_group_id(
|
||||
group_id: str, platform: Optional[str] | SpecialTypes = "qq"
|
||||
) -> Optional[ChatStream]:
|
||||
) -> Optional[ChatStream]: # sourcery skip: remove-unnecessary-cast
|
||||
"""根据群ID获取聊天流
|
||||
|
||||
Args:
|
||||
@@ -142,7 +145,7 @@ class ChatManager:
|
||||
@staticmethod
|
||||
def get_private_stream_by_user_id(
|
||||
user_id: str, platform: Optional[str] | SpecialTypes = "qq"
|
||||
) -> Optional[ChatStream]:
|
||||
) -> Optional[ChatStream]: # sourcery skip: remove-unnecessary-cast
|
||||
"""根据用户ID获取私聊流
|
||||
|
||||
Args:
|
||||
@@ -207,7 +210,7 @@ class ChatManager:
|
||||
chat_stream: 聊天流对象
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: 聊天流信息字典
|
||||
Dict ({str: Any}): 聊天流信息字典
|
||||
|
||||
Raises:
|
||||
TypeError: 如果 chat_stream 不是 ChatStream 类型
|
||||
@@ -282,41 +285,41 @@ class ChatManager:
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def get_all_streams(platform: Optional[str] | SpecialTypes = "qq"):
|
||||
def get_all_streams(platform: Optional[str] | SpecialTypes = "qq") -> List[ChatStream]:
|
||||
"""获取所有聊天流的便捷函数"""
|
||||
return ChatManager.get_all_streams(platform)
|
||||
|
||||
|
||||
def get_group_streams(platform: Optional[str] | SpecialTypes = "qq"):
|
||||
def get_group_streams(platform: Optional[str] | SpecialTypes = "qq") -> List[ChatStream]:
|
||||
"""获取群聊聊天流的便捷函数"""
|
||||
return ChatManager.get_group_streams(platform)
|
||||
|
||||
|
||||
def get_private_streams(platform: Optional[str] | SpecialTypes = "qq"):
|
||||
def get_private_streams(platform: Optional[str] | SpecialTypes = "qq") -> List[ChatStream]:
|
||||
"""获取私聊聊天流的便捷函数"""
|
||||
return ChatManager.get_private_streams(platform)
|
||||
|
||||
|
||||
def get_stream_by_group_id(group_id: str, platform: Optional[str] | SpecialTypes = "qq"):
|
||||
def get_stream_by_group_id(group_id: str, platform: Optional[str] | SpecialTypes = "qq") -> Optional[ChatStream]:
|
||||
"""根据群ID获取聊天流的便捷函数"""
|
||||
return ChatManager.get_group_stream_by_group_id(group_id, platform)
|
||||
|
||||
|
||||
def get_stream_by_user_id(user_id: str, platform: Optional[str] | SpecialTypes = "qq"):
|
||||
def get_stream_by_user_id(user_id: str, platform: Optional[str] | SpecialTypes = "qq") -> Optional[ChatStream]:
|
||||
"""根据用户ID获取私聊流的便捷函数"""
|
||||
return ChatManager.get_private_stream_by_user_id(user_id, platform)
|
||||
|
||||
|
||||
def get_stream_type(chat_stream: ChatStream):
|
||||
def get_stream_type(chat_stream: ChatStream) -> str:
|
||||
"""获取聊天流类型的便捷函数"""
|
||||
return ChatManager.get_stream_type(chat_stream)
|
||||
|
||||
|
||||
def get_stream_info(chat_stream: ChatStream):
|
||||
def get_stream_info(chat_stream: ChatStream) -> Dict[str, Any]:
|
||||
"""获取聊天流信息的便捷函数"""
|
||||
return ChatManager.get_stream_info(chat_stream)
|
||||
|
||||
|
||||
def get_streams_summary():
|
||||
def get_streams_summary() -> Dict[str, int]:
|
||||
"""获取聊天流统计摘要的便捷函数"""
|
||||
return ChatManager.get_streams_summary()
|
||||
|
||||
@@ -5,6 +5,7 @@ from src.plugin_system.base.component_types import (
|
||||
EventHandlerInfo,
|
||||
PluginInfo,
|
||||
ComponentType,
|
||||
ToolInfo,
|
||||
)
|
||||
|
||||
|
||||
@@ -119,6 +120,21 @@ def get_registered_command_info(command_name: str) -> Optional[CommandInfo]:
|
||||
return component_registry.get_registered_command_info(command_name)
|
||||
|
||||
|
||||
def get_registered_tool_info(tool_name: str) -> Optional[ToolInfo]:
|
||||
"""
|
||||
获取指定 Tool 的注册信息。
|
||||
|
||||
Args:
|
||||
tool_name (str): Tool 名称。
|
||||
|
||||
Returns:
|
||||
ToolInfo: Tool 信息对象,如果 Tool 不存在则返回 None。
|
||||
"""
|
||||
from src.plugin_system.core.component_registry import component_registry
|
||||
|
||||
return component_registry.get_registered_tool_info(tool_name)
|
||||
|
||||
|
||||
# === EventHandler 特定查询方法 ===
|
||||
def get_registered_event_handler_info(
|
||||
event_handler_name: str,
|
||||
@@ -191,6 +207,8 @@ def locally_enable_component(component_name: str, component_type: ComponentType,
|
||||
return global_announcement_manager.enable_specific_chat_action(stream_id, component_name)
|
||||
case ComponentType.COMMAND:
|
||||
return global_announcement_manager.enable_specific_chat_command(stream_id, component_name)
|
||||
case ComponentType.TOOL:
|
||||
return global_announcement_manager.enable_specific_chat_tool(stream_id, component_name)
|
||||
case ComponentType.EVENT_HANDLER:
|
||||
return global_announcement_manager.enable_specific_chat_event_handler(stream_id, component_name)
|
||||
case _:
|
||||
@@ -216,11 +234,14 @@ def locally_disable_component(component_name: str, component_type: ComponentType
|
||||
return global_announcement_manager.disable_specific_chat_action(stream_id, component_name)
|
||||
case ComponentType.COMMAND:
|
||||
return global_announcement_manager.disable_specific_chat_command(stream_id, component_name)
|
||||
case ComponentType.TOOL:
|
||||
return global_announcement_manager.disable_specific_chat_tool(stream_id, component_name)
|
||||
case ComponentType.EVENT_HANDLER:
|
||||
return global_announcement_manager.disable_specific_chat_event_handler(stream_id, component_name)
|
||||
case _:
|
||||
raise ValueError(f"未知 component type: {component_type}")
|
||||
|
||||
|
||||
def get_locally_disabled_components(stream_id: str, component_type: ComponentType) -> list[str]:
|
||||
"""
|
||||
获取指定消息流中禁用的组件列表。
|
||||
@@ -239,7 +260,9 @@ def get_locally_disabled_components(stream_id: str, component_type: ComponentTyp
|
||||
return global_announcement_manager.get_disabled_chat_actions(stream_id)
|
||||
case ComponentType.COMMAND:
|
||||
return global_announcement_manager.get_disabled_chat_commands(stream_id)
|
||||
case ComponentType.TOOL:
|
||||
return global_announcement_manager.get_disabled_chat_tools(stream_id)
|
||||
case ComponentType.EVENT_HANDLER:
|
||||
return global_announcement_manager.get_disabled_chat_event_handlers(stream_id)
|
||||
case _:
|
||||
raise ValueError(f"未知 component type: {component_type}")
|
||||
raise ValueError(f"未知 component type: {component_type}")
|
||||
|
||||
@@ -10,7 +10,6 @@
|
||||
from typing import Any
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.person_info.person_info import get_person_info_manager
|
||||
|
||||
logger = get_logger("config_api")
|
||||
|
||||
@@ -26,7 +25,7 @@ def get_global_config(key: str, default: Any = None) -> Any:
|
||||
插件应使用此方法读取全局配置,以保证只读和隔离性。
|
||||
|
||||
Args:
|
||||
key: 命名空间式配置键名,支持嵌套访问,如 "section.subsection.key",大小写敏感
|
||||
key: 命名空间式配置键名,使用嵌套访问,如 "section.subsection.key",大小写敏感
|
||||
default: 如果配置不存在时返回的默认值
|
||||
|
||||
Returns:
|
||||
@@ -76,50 +75,3 @@ def get_plugin_config(plugin_config: dict, key: str, default: Any = None) -> Any
|
||||
except Exception as e:
|
||||
logger.warning(f"[ConfigAPI] 获取插件配置 {key} 失败: {e}")
|
||||
return default
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 用户信息API函数
|
||||
# =============================================================================
|
||||
|
||||
|
||||
async def get_user_id_by_person_name(person_name: str) -> tuple[str, str]:
|
||||
"""根据内部用户名获取用户ID
|
||||
|
||||
Args:
|
||||
person_name: 用户名
|
||||
|
||||
Returns:
|
||||
tuple[str, str]: (平台, 用户ID)
|
||||
"""
|
||||
try:
|
||||
person_info_manager = get_person_info_manager()
|
||||
person_id = person_info_manager.get_person_id_by_person_name(person_name)
|
||||
user_id: str = await person_info_manager.get_value(person_id, "user_id") # type: ignore
|
||||
platform: str = await person_info_manager.get_value(person_id, "platform") # type: ignore
|
||||
return platform, user_id
|
||||
except Exception as e:
|
||||
logger.error(f"[ConfigAPI] 根据用户名获取用户ID失败: {e}")
|
||||
return "", ""
|
||||
|
||||
|
||||
async def get_person_info(person_id: str, key: str, default: Any = None) -> Any:
|
||||
"""获取用户信息
|
||||
|
||||
Args:
|
||||
person_id: 用户ID
|
||||
key: 信息键名
|
||||
default: 默认值
|
||||
|
||||
Returns:
|
||||
Any: 用户信息值或默认值
|
||||
"""
|
||||
try:
|
||||
person_info_manager = get_person_info_manager()
|
||||
response = await person_info_manager.get_value(person_id, key)
|
||||
if not response:
|
||||
raise ValueError(f"[ConfigAPI] 获取用户 {person_id} 的信息 '{key}' 失败,返回默认值")
|
||||
return response
|
||||
except Exception as e:
|
||||
logger.error(f"[ConfigAPI] 获取用户信息失败: {e}")
|
||||
return default
|
||||
|
||||
@@ -152,10 +152,7 @@ async def db_query(
|
||||
|
||||
except DoesNotExist:
|
||||
# 记录不存在
|
||||
if query_type == "get" and single_result:
|
||||
return None
|
||||
return []
|
||||
|
||||
return None if query_type == "get" and single_result else []
|
||||
except Exception as e:
|
||||
logger.error(f"[DatabaseAPI] 数据库操作出错: {e}")
|
||||
traceback.print_exc()
|
||||
@@ -170,7 +167,8 @@ async def db_query(
|
||||
|
||||
async def db_save(
|
||||
model_class: Type[Model], data: Dict[str, Any], key_field: Optional[str] = None, key_value: Optional[Any] = None
|
||||
) -> Union[Dict[str, Any], None]:
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
# sourcery skip: inline-immediately-returned-variable
|
||||
"""保存数据到数据库(创建或更新)
|
||||
|
||||
如果提供了key_field和key_value,会先尝试查找匹配的记录进行更新;
|
||||
@@ -203,10 +201,9 @@ async def db_save(
|
||||
try:
|
||||
# 如果提供了key_field和key_value,尝试更新现有记录
|
||||
if key_field and key_value is not None:
|
||||
# 查找现有记录
|
||||
existing_records = list(model_class.select().where(getattr(model_class, key_field) == key_value).limit(1))
|
||||
|
||||
if existing_records:
|
||||
if existing_records := list(
|
||||
model_class.select().where(getattr(model_class, key_field) == key_value).limit(1)
|
||||
):
|
||||
# 更新现有记录
|
||||
existing_record = existing_records[0]
|
||||
for field, value in data.items():
|
||||
@@ -244,8 +241,8 @@ async def db_get(
|
||||
Args:
|
||||
model_class: Peewee模型类
|
||||
filters: 过滤条件,字段名和值的字典
|
||||
order_by: 排序字段,前缀'-'表示降序,例如'-time'表示按时间字段(即time字段)降序
|
||||
limit: 结果数量限制
|
||||
order_by: 排序字段,前缀'-'表示降序,例如'-time'表示按时间字段(即time字段)降序
|
||||
single_result: 是否只返回单个结果,如果为True,则返回单个记录字典或None;否则返回记录字典列表或空列表
|
||||
|
||||
Returns:
|
||||
@@ -310,7 +307,7 @@ async def store_action_info(
|
||||
thinking_id: str = "",
|
||||
action_data: Optional[dict] = None,
|
||||
action_name: str = "",
|
||||
) -> Union[Dict[str, Any], None]:
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""存储动作信息到数据库
|
||||
|
||||
将Action执行的相关信息保存到ActionRecords表中,用于后续的记忆和上下文构建。
|
||||
|
||||
@@ -65,14 +65,14 @@ async def get_by_description(description: str) -> Optional[Tuple[str, str, str]]
|
||||
return None
|
||||
|
||||
|
||||
async def get_random(count: Optional[int] = 1) -> Optional[List[Tuple[str, str, str]]]:
|
||||
async def get_random(count: Optional[int] = 1) -> List[Tuple[str, str, str]]:
|
||||
"""随机获取指定数量的表情包
|
||||
|
||||
Args:
|
||||
count: 要获取的表情包数量,默认为1
|
||||
|
||||
Returns:
|
||||
Optional[List[Tuple[str, str, str]]]: 包含(base64编码, 表情包描述, 随机情感标签)的元组列表,如果失败则为None
|
||||
List[Tuple[str, str, str]]: 包含(base64编码, 表情包描述, 随机情感标签)的元组列表,失败则返回空列表
|
||||
|
||||
Raises:
|
||||
TypeError: 如果count不是整数类型
|
||||
@@ -94,13 +94,13 @@ async def get_random(count: Optional[int] = 1) -> Optional[List[Tuple[str, str,
|
||||
|
||||
if not all_emojis:
|
||||
logger.warning("[EmojiAPI] 没有可用的表情包")
|
||||
return None
|
||||
return []
|
||||
|
||||
# 过滤有效表情包
|
||||
valid_emojis = [emoji for emoji in all_emojis if not emoji.is_deleted]
|
||||
if not valid_emojis:
|
||||
logger.warning("[EmojiAPI] 没有有效的表情包")
|
||||
return None
|
||||
return []
|
||||
|
||||
if len(valid_emojis) < count:
|
||||
logger.warning(
|
||||
@@ -127,14 +127,14 @@ async def get_random(count: Optional[int] = 1) -> Optional[List[Tuple[str, str,
|
||||
|
||||
if not results and count > 0:
|
||||
logger.warning("[EmojiAPI] 随机获取表情包失败,没有一个可以成功处理")
|
||||
return None
|
||||
return []
|
||||
|
||||
logger.info(f"[EmojiAPI] 成功获取 {len(results)} 个随机表情包")
|
||||
return results
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[EmojiAPI] 获取随机表情包失败: {e}")
|
||||
return None
|
||||
return []
|
||||
|
||||
|
||||
async def get_by_emotion(emotion: str) -> Optional[Tuple[str, str, str]]:
|
||||
@@ -162,10 +162,11 @@ async def get_by_emotion(emotion: str) -> Optional[Tuple[str, str, str]]:
|
||||
|
||||
# 筛选匹配情感的表情包
|
||||
matching_emojis = []
|
||||
for emoji_obj in all_emojis:
|
||||
if not emoji_obj.is_deleted and emotion.lower() in [e.lower() for e in emoji_obj.emotion]:
|
||||
matching_emojis.append(emoji_obj)
|
||||
|
||||
matching_emojis.extend(
|
||||
emoji_obj
|
||||
for emoji_obj in all_emojis
|
||||
if not emoji_obj.is_deleted and emotion.lower() in [e.lower() for e in emoji_obj.emotion]
|
||||
)
|
||||
if not matching_emojis:
|
||||
logger.warning(f"[EmojiAPI] 未找到匹配情感 '{emotion}' 的表情包")
|
||||
return None
|
||||
@@ -256,10 +257,11 @@ def get_descriptions() -> List[str]:
|
||||
emoji_manager = get_emoji_manager()
|
||||
descriptions = []
|
||||
|
||||
for emoji_obj in emoji_manager.emoji_objects:
|
||||
if not emoji_obj.is_deleted and emoji_obj.description:
|
||||
descriptions.append(emoji_obj.description)
|
||||
|
||||
descriptions.extend(
|
||||
emoji_obj.description
|
||||
for emoji_obj in emoji_manager.emoji_objects
|
||||
if not emoji_obj.is_deleted and emoji_obj.description
|
||||
)
|
||||
return descriptions
|
||||
except Exception as e:
|
||||
logger.error(f"[EmojiAPI] 获取表情包描述失败: {e}")
|
||||
|
||||
@@ -84,18 +84,23 @@ async def generate_reply(
|
||||
enable_chinese_typo: bool = True,
|
||||
return_prompt: bool = False,
|
||||
model_configs: Optional[List[Dict[str, Any]]] = None,
|
||||
request_type: str = "",
|
||||
enable_timeout: bool = False,
|
||||
request_type: str = "generator_api",
|
||||
) -> Tuple[bool, List[Tuple[str, Any]], Optional[str]]:
|
||||
"""生成回复
|
||||
|
||||
Args:
|
||||
chat_stream: 聊天流对象(优先)
|
||||
chat_id: 聊天ID(备用)
|
||||
action_data: 动作数据
|
||||
action_data: 动作数据(向下兼容,包含reply_to和extra_info)
|
||||
reply_to: 回复对象,格式为 "发送者:消息内容"
|
||||
extra_info: 额外信息,用于补充上下文
|
||||
available_actions: 可用动作
|
||||
enable_tool: 是否启用工具调用
|
||||
enable_splitter: 是否启用消息分割器
|
||||
enable_chinese_typo: 是否启用错字生成器
|
||||
return_prompt: 是否返回提示词
|
||||
model_configs: 模型配置列表
|
||||
request_type: 请求类型(可选,记录LLM使用)
|
||||
Returns:
|
||||
Tuple[bool, List[Tuple[str, Any]], Optional[str]]: (是否成功, 回复集合, 提示词)
|
||||
"""
|
||||
@@ -108,13 +113,16 @@ async def generate_reply(
|
||||
|
||||
logger.debug("[GeneratorAPI] 开始生成回复")
|
||||
|
||||
if not reply_to and action_data:
|
||||
reply_to = action_data.get("reply_to", "")
|
||||
if not extra_info and action_data:
|
||||
extra_info = action_data.get("extra_info", "")
|
||||
|
||||
# 调用回复器生成回复
|
||||
success, content, prompt = await replyer.generate_reply_with_context(
|
||||
reply_data=action_data or {},
|
||||
reply_to=reply_to,
|
||||
extra_info=extra_info,
|
||||
available_actions=available_actions,
|
||||
enable_timeout=enable_timeout,
|
||||
enable_tool=enable_tool,
|
||||
)
|
||||
reply_set = []
|
||||
@@ -136,6 +144,7 @@ async def generate_reply(
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[GeneratorAPI] 生成回复时出错: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
return False, [], None
|
||||
|
||||
|
||||
@@ -146,15 +155,24 @@ async def rewrite_reply(
|
||||
enable_splitter: bool = True,
|
||||
enable_chinese_typo: bool = True,
|
||||
model_configs: Optional[List[Dict[str, Any]]] = None,
|
||||
) -> Tuple[bool, List[Tuple[str, Any]]]:
|
||||
raw_reply: str = "",
|
||||
reason: str = "",
|
||||
reply_to: str = "",
|
||||
return_prompt: bool = False,
|
||||
) -> Tuple[bool, List[Tuple[str, Any]], Optional[str]]:
|
||||
"""重写回复
|
||||
|
||||
Args:
|
||||
chat_stream: 聊天流对象(优先)
|
||||
reply_data: 回复数据
|
||||
reply_data: 回复数据字典(向下兼容备用,当其他参数缺失时从此获取)
|
||||
chat_id: 聊天ID(备用)
|
||||
enable_splitter: 是否启用消息分割器
|
||||
enable_chinese_typo: 是否启用错字生成器
|
||||
model_configs: 模型配置列表
|
||||
raw_reply: 原始回复内容
|
||||
reason: 回复原因
|
||||
reply_to: 回复对象
|
||||
return_prompt: 是否返回提示词
|
||||
|
||||
Returns:
|
||||
Tuple[bool, List[Tuple[str, Any]]]: (是否成功, 回复集合)
|
||||
@@ -164,12 +182,23 @@ async def rewrite_reply(
|
||||
replyer = get_replyer(chat_stream, chat_id, model_configs=model_configs)
|
||||
if not replyer:
|
||||
logger.error("[GeneratorAPI] 无法获取回复器")
|
||||
return False, []
|
||||
return False, [], None
|
||||
|
||||
logger.info("[GeneratorAPI] 开始重写回复")
|
||||
|
||||
# 如果参数缺失,从reply_data中获取
|
||||
if reply_data:
|
||||
raw_reply = raw_reply or reply_data.get("raw_reply", "")
|
||||
reason = reason or reply_data.get("reason", "")
|
||||
reply_to = reply_to or reply_data.get("reply_to", "")
|
||||
|
||||
# 调用回复器重写回复
|
||||
success, content = await replyer.rewrite_reply_with_context(reply_data=reply_data or {})
|
||||
success, content, prompt = await replyer.rewrite_reply_with_context(
|
||||
raw_reply=raw_reply,
|
||||
reason=reason,
|
||||
reply_to=reply_to,
|
||||
return_prompt=return_prompt,
|
||||
)
|
||||
reply_set = []
|
||||
if content:
|
||||
reply_set = await process_human_text(content, enable_splitter, enable_chinese_typo)
|
||||
@@ -179,14 +208,14 @@ async def rewrite_reply(
|
||||
else:
|
||||
logger.warning("[GeneratorAPI] 重写回复失败")
|
||||
|
||||
return success, reply_set
|
||||
return success, reply_set, prompt if return_prompt else None
|
||||
|
||||
except ValueError as ve:
|
||||
raise ve
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[GeneratorAPI] 重写回复时出错: {e}")
|
||||
return False, []
|
||||
return False, [], None
|
||||
|
||||
|
||||
async def process_human_text(content: str, enable_splitter: bool, enable_chinese_typo: bool) -> List[Tuple[str, Any]]:
|
||||
@@ -212,3 +241,27 @@ async def process_human_text(content: str, enable_splitter: bool, enable_chinese
|
||||
except Exception as e:
|
||||
logger.error(f"[GeneratorAPI] 处理人形文本时出错: {e}")
|
||||
return []
|
||||
|
||||
async def generate_response_custom(
|
||||
chat_stream: Optional[ChatStream] = None,
|
||||
chat_id: Optional[str] = None,
|
||||
model_configs: Optional[List[Dict[str, Any]]] = None,
|
||||
prompt: str = "",
|
||||
) -> Optional[str]:
|
||||
replyer = get_replyer(chat_stream, chat_id, model_configs=model_configs)
|
||||
if not replyer:
|
||||
logger.error("[GeneratorAPI] 无法获取回复器")
|
||||
return None
|
||||
|
||||
try:
|
||||
logger.debug("[GeneratorAPI] 开始生成自定义回复")
|
||||
response = await replyer.llm_generate_content(prompt)
|
||||
if response:
|
||||
logger.debug("[GeneratorAPI] 自定义回复生成成功")
|
||||
return response
|
||||
else:
|
||||
logger.warning("[GeneratorAPI] 自定义回复生成失败")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"[GeneratorAPI] 生成自定义回复时出错: {e}")
|
||||
return None
|
||||
@@ -54,7 +54,7 @@ def get_available_models() -> Dict[str, Any]:
|
||||
|
||||
async def generate_with_model(
|
||||
prompt: str, model_config: Dict[str, Any], request_type: str = "plugin.generate", **kwargs
|
||||
) -> Tuple[bool, str, str, str]:
|
||||
) -> Tuple[bool, str]:
|
||||
"""使用指定模型生成内容
|
||||
|
||||
Args:
|
||||
@@ -73,10 +73,11 @@ async def generate_with_model(
|
||||
|
||||
llm_request = LLMRequest(model=model_config, request_type=request_type, **kwargs)
|
||||
|
||||
response, (reasoning, model_name) = await llm_request.generate_response_async(prompt)
|
||||
return True, response, reasoning, model_name
|
||||
# TODO: 复活这个_
|
||||
response, _ = await llm_request.generate_response_async(prompt)
|
||||
return True, response
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"生成内容时出错: {str(e)}"
|
||||
logger.error(f"[LLMAPI] {error_msg}")
|
||||
return False, error_msg, "", ""
|
||||
return False, error_msg
|
||||
|
||||
@@ -207,7 +207,7 @@ def get_random_chat_messages(
|
||||
|
||||
|
||||
def get_messages_by_time_for_users(
|
||||
start_time: float, end_time: float, person_ids: list, limit: int = 0, limit_mode: str = "latest"
|
||||
start_time: float, end_time: float, person_ids: List[str], limit: int = 0, limit_mode: str = "latest"
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
获取指定用户在所有聊天中指定时间范围内的消息
|
||||
@@ -287,7 +287,7 @@ def get_messages_before_time_in_chat(
|
||||
return get_raw_msg_before_timestamp_with_chat(chat_id, timestamp, limit)
|
||||
|
||||
|
||||
def get_messages_before_time_for_users(timestamp: float, person_ids: list, limit: int = 0) -> List[Dict[str, Any]]:
|
||||
def get_messages_before_time_for_users(timestamp: float, person_ids: List[str], limit: int = 0) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
获取指定用户在指定时间戳之前的消息
|
||||
|
||||
@@ -372,7 +372,7 @@ def count_new_messages(chat_id: str, start_time: float = 0.0, end_time: Optional
|
||||
return num_new_messages_since(chat_id, start_time, end_time)
|
||||
|
||||
|
||||
def count_new_messages_for_users(chat_id: str, start_time: float, end_time: float, person_ids: list) -> int:
|
||||
def count_new_messages_for_users(chat_id: str, start_time: float, end_time: float, person_ids: List[str]) -> int:
|
||||
"""
|
||||
计算指定聊天中指定用户从开始时间到结束时间的新消息数量
|
||||
|
||||
|
||||
@@ -1,10 +1,12 @@
|
||||
from typing import Tuple, List
|
||||
|
||||
|
||||
def list_loaded_plugins() -> List[str]:
|
||||
"""
|
||||
列出所有当前加载的插件。
|
||||
|
||||
Returns:
|
||||
list: 当前加载的插件名称列表。
|
||||
List[str]: 当前加载的插件名称列表。
|
||||
"""
|
||||
from src.plugin_system.core.plugin_manager import plugin_manager
|
||||
|
||||
@@ -16,17 +18,38 @@ def list_registered_plugins() -> List[str]:
|
||||
列出所有已注册的插件。
|
||||
|
||||
Returns:
|
||||
list: 已注册的插件名称列表。
|
||||
List[str]: 已注册的插件名称列表。
|
||||
"""
|
||||
from src.plugin_system.core.plugin_manager import plugin_manager
|
||||
|
||||
return plugin_manager.list_registered_plugins()
|
||||
|
||||
|
||||
def get_plugin_path(plugin_name: str) -> str:
|
||||
"""
|
||||
获取指定插件的路径。
|
||||
|
||||
Args:
|
||||
plugin_name (str): 插件名称。
|
||||
|
||||
Returns:
|
||||
str: 插件目录的绝对路径。
|
||||
|
||||
Raises:
|
||||
ValueError: 如果插件不存在。
|
||||
"""
|
||||
from src.plugin_system.core.plugin_manager import plugin_manager
|
||||
|
||||
if plugin_path := plugin_manager.get_plugin_path(plugin_name):
|
||||
return plugin_path
|
||||
else:
|
||||
raise ValueError(f"插件 '{plugin_name}' 不存在。")
|
||||
|
||||
|
||||
async def remove_plugin(plugin_name: str) -> bool:
|
||||
"""
|
||||
卸载指定的插件。
|
||||
|
||||
|
||||
**此函数是异步的,确保在异步环境中调用。**
|
||||
|
||||
Args:
|
||||
@@ -43,7 +66,7 @@ async def remove_plugin(plugin_name: str) -> bool:
|
||||
async def reload_plugin(plugin_name: str) -> bool:
|
||||
"""
|
||||
重新加载指定的插件。
|
||||
|
||||
|
||||
**此函数是异步的,确保在异步环境中调用。**
|
||||
|
||||
Args:
|
||||
@@ -71,6 +94,7 @@ def load_plugin(plugin_name: str) -> Tuple[bool, int]:
|
||||
|
||||
return plugin_manager.load_registered_plugin_classes(plugin_name)
|
||||
|
||||
|
||||
def add_plugin_directory(plugin_directory: str) -> bool:
|
||||
"""
|
||||
添加插件目录。
|
||||
@@ -84,6 +108,7 @@ def add_plugin_directory(plugin_directory: str) -> bool:
|
||||
|
||||
return plugin_manager.add_plugin_directory(plugin_directory)
|
||||
|
||||
|
||||
def rescan_plugin_directory() -> Tuple[int, int]:
|
||||
"""
|
||||
重新扫描插件目录,加载新插件。
|
||||
@@ -92,4 +117,4 @@ def rescan_plugin_directory() -> Tuple[int, int]:
|
||||
"""
|
||||
from src.plugin_system.core.plugin_manager import plugin_manager
|
||||
|
||||
return plugin_manager.rescan_plugin_directory()
|
||||
return plugin_manager.rescan_plugin_directory()
|
||||
|
||||
@@ -22,7 +22,6 @@
|
||||
import traceback
|
||||
import time
|
||||
import difflib
|
||||
import re
|
||||
from typing import Optional, Union
|
||||
from src.common.logger import get_logger
|
||||
|
||||
@@ -30,7 +29,7 @@ from src.common.logger import get_logger
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from src.chat.message_receive.uni_message_sender import HeartFCSender
|
||||
from src.chat.message_receive.message import MessageSending, MessageRecv
|
||||
from src.chat.utils.chat_message_builder import get_raw_msg_before_timestamp_with_chat
|
||||
from src.chat.utils.chat_message_builder import get_raw_msg_before_timestamp_with_chat, replace_user_references_async
|
||||
from src.person_info.person_info import get_person_info_manager
|
||||
from maim_message import Seg, UserInfo
|
||||
from src.config.config import global_config
|
||||
@@ -50,7 +49,7 @@ async def _send_to_target(
|
||||
display_message: str = "",
|
||||
typing: bool = False,
|
||||
reply_to: str = "",
|
||||
reply_to_platform_id: str = "",
|
||||
reply_to_platform_id: Optional[str] = None,
|
||||
storage_message: bool = True,
|
||||
show_log: bool = True,
|
||||
) -> bool:
|
||||
@@ -61,8 +60,11 @@ async def _send_to_target(
|
||||
content: 消息内容
|
||||
stream_id: 目标流ID
|
||||
display_message: 显示消息
|
||||
typing: 是否显示正在输入
|
||||
reply_to: 回复消息的格式,如"发送者:消息内容"
|
||||
typing: 是否模拟打字等待。
|
||||
reply_to: 回复消息,格式为"发送者:消息内容"
|
||||
reply_to_platform_id: 回复消息,格式为"平台:用户ID",如果不提供则自动查找(插件开发者禁用!)
|
||||
storage_message: 是否存储消息到数据库
|
||||
show_log: 发送是否显示日志
|
||||
|
||||
Returns:
|
||||
bool: 是否发送成功
|
||||
@@ -98,6 +100,10 @@ async def _send_to_target(
|
||||
anchor_message = None
|
||||
if reply_to:
|
||||
anchor_message = await _find_reply_message(target_stream, reply_to)
|
||||
if anchor_message and anchor_message.message_info.user_info and not reply_to_platform_id:
|
||||
reply_to_platform_id = (
|
||||
f"{anchor_message.message_info.platform}:{anchor_message.message_info.user_info.user_id}"
|
||||
)
|
||||
|
||||
# 构建发送消息对象
|
||||
bot_message = MessageSending(
|
||||
@@ -183,32 +189,8 @@ async def _find_reply_message(target_stream, reply_to: str) -> Optional[MessageR
|
||||
if person_name == sender:
|
||||
translate_text = message["processed_plain_text"]
|
||||
|
||||
# 检查是否有 回复<aaa:bbb> 字段
|
||||
reply_pattern = r"回复<([^:<>]+):([^:<>]+)>"
|
||||
if match := re.search(reply_pattern, translate_text):
|
||||
aaa = match.group(1)
|
||||
bbb = match.group(2)
|
||||
reply_person_id = get_person_info_manager().get_person_id(platform, bbb)
|
||||
reply_person_name = await get_person_info_manager().get_value(reply_person_id, "person_name") or aaa
|
||||
# 在内容前加上回复信息
|
||||
translate_text = re.sub(reply_pattern, f"回复 {reply_person_name}", translate_text, count=1)
|
||||
|
||||
# 检查是否有 @<aaa:bbb> 字段
|
||||
at_pattern = r"@<([^:<>]+):([^:<>]+)>"
|
||||
at_matches = list(re.finditer(at_pattern, translate_text))
|
||||
if at_matches:
|
||||
new_content = ""
|
||||
last_end = 0
|
||||
for m in at_matches:
|
||||
new_content += translate_text[last_end : m.start()]
|
||||
aaa = m.group(1)
|
||||
bbb = m.group(2)
|
||||
at_person_id = get_person_info_manager().get_person_id(platform, bbb)
|
||||
at_person_name = await get_person_info_manager().get_value(at_person_id, "person_name") or aaa
|
||||
new_content += f"@{at_person_name}"
|
||||
last_end = m.end()
|
||||
new_content += translate_text[last_end:]
|
||||
translate_text = new_content
|
||||
# 使用独立函数处理用户引用格式
|
||||
translate_text = await replace_user_references_async(translate_text, platform)
|
||||
|
||||
similarity = difflib.SequenceMatcher(None, text, translate_text).ratio()
|
||||
if similarity >= 0.9:
|
||||
@@ -287,12 +269,22 @@ async def text_to_stream(
|
||||
stream_id: 聊天流ID
|
||||
typing: 是否显示正在输入
|
||||
reply_to: 回复消息,格式为"发送者:消息内容"
|
||||
reply_to_platform_id: 回复消息,格式为"平台:用户ID",如果不提供则自动查找(插件开发者禁用!)
|
||||
storage_message: 是否存储消息到数据库
|
||||
|
||||
Returns:
|
||||
bool: 是否发送成功
|
||||
"""
|
||||
return await _send_to_target("text", text, stream_id, "", typing, reply_to, reply_to_platform_id, storage_message)
|
||||
return await _send_to_target(
|
||||
"text",
|
||||
text,
|
||||
stream_id,
|
||||
"",
|
||||
typing,
|
||||
reply_to,
|
||||
reply_to_platform_id=reply_to_platform_id,
|
||||
storage_message=storage_message,
|
||||
)
|
||||
|
||||
|
||||
async def emoji_to_stream(emoji_base64: str, stream_id: str, storage_message: bool = True) -> bool:
|
||||
@@ -375,249 +367,3 @@ async def custom_to_stream(
|
||||
storage_message=storage_message,
|
||||
show_log=show_log,
|
||||
)
|
||||
|
||||
|
||||
async def text_to_group(
|
||||
text: str,
|
||||
group_id: str,
|
||||
platform: str = "qq",
|
||||
typing: bool = False,
|
||||
reply_to: str = "",
|
||||
storage_message: bool = True,
|
||||
) -> bool:
|
||||
"""向群聊发送文本消息
|
||||
|
||||
Args:
|
||||
text: 要发送的文本内容
|
||||
group_id: 群聊ID
|
||||
platform: 平台,默认为"qq"
|
||||
typing: 是否显示正在输入
|
||||
reply_to: 回复消息,格式为"发送者:消息内容"
|
||||
|
||||
Returns:
|
||||
bool: 是否发送成功
|
||||
"""
|
||||
stream_id = get_chat_manager().get_stream_id(platform, group_id, True)
|
||||
|
||||
return await _send_to_target("text", text, stream_id, "", typing, reply_to, storage_message=storage_message)
|
||||
|
||||
|
||||
async def text_to_user(
|
||||
text: str,
|
||||
user_id: str,
|
||||
platform: str = "qq",
|
||||
typing: bool = False,
|
||||
reply_to: str = "",
|
||||
storage_message: bool = True,
|
||||
) -> bool:
|
||||
"""向用户发送私聊文本消息
|
||||
|
||||
Args:
|
||||
text: 要发送的文本内容
|
||||
user_id: 用户ID
|
||||
platform: 平台,默认为"qq"
|
||||
typing: 是否显示正在输入
|
||||
reply_to: 回复消息,格式为"发送者:消息内容"
|
||||
|
||||
Returns:
|
||||
bool: 是否发送成功
|
||||
"""
|
||||
stream_id = get_chat_manager().get_stream_id(platform, user_id, False)
|
||||
return await _send_to_target("text", text, stream_id, "", typing, reply_to, storage_message=storage_message)
|
||||
|
||||
|
||||
async def emoji_to_group(emoji_base64: str, group_id: str, platform: str = "qq", storage_message: bool = True) -> bool:
|
||||
"""向群聊发送表情包
|
||||
|
||||
Args:
|
||||
emoji_base64: 表情包的base64编码
|
||||
group_id: 群聊ID
|
||||
platform: 平台,默认为"qq"
|
||||
|
||||
Returns:
|
||||
bool: 是否发送成功
|
||||
"""
|
||||
stream_id = get_chat_manager().get_stream_id(platform, group_id, True)
|
||||
return await _send_to_target("emoji", emoji_base64, stream_id, "", typing=False, storage_message=storage_message)
|
||||
|
||||
|
||||
async def emoji_to_user(emoji_base64: str, user_id: str, platform: str = "qq", storage_message: bool = True) -> bool:
|
||||
"""向用户发送表情包
|
||||
|
||||
Args:
|
||||
emoji_base64: 表情包的base64编码
|
||||
user_id: 用户ID
|
||||
platform: 平台,默认为"qq"
|
||||
|
||||
Returns:
|
||||
bool: 是否发送成功
|
||||
"""
|
||||
stream_id = get_chat_manager().get_stream_id(platform, user_id, False)
|
||||
return await _send_to_target("emoji", emoji_base64, stream_id, "", typing=False, storage_message=storage_message)
|
||||
|
||||
|
||||
async def image_to_group(image_base64: str, group_id: str, platform: str = "qq", storage_message: bool = True) -> bool:
|
||||
"""向群聊发送图片
|
||||
|
||||
Args:
|
||||
image_base64: 图片的base64编码
|
||||
group_id: 群聊ID
|
||||
platform: 平台,默认为"qq"
|
||||
|
||||
Returns:
|
||||
bool: 是否发送成功
|
||||
"""
|
||||
stream_id = get_chat_manager().get_stream_id(platform, group_id, True)
|
||||
return await _send_to_target("image", image_base64, stream_id, "", typing=False, storage_message=storage_message)
|
||||
|
||||
|
||||
async def image_to_user(image_base64: str, user_id: str, platform: str = "qq", storage_message: bool = True) -> bool:
|
||||
"""向用户发送图片
|
||||
|
||||
Args:
|
||||
image_base64: 图片的base64编码
|
||||
user_id: 用户ID
|
||||
platform: 平台,默认为"qq"
|
||||
|
||||
Returns:
|
||||
bool: 是否发送成功
|
||||
"""
|
||||
stream_id = get_chat_manager().get_stream_id(platform, user_id, False)
|
||||
return await _send_to_target("image", image_base64, stream_id, "", typing=False)
|
||||
|
||||
|
||||
async def command_to_group(command: str, group_id: str, platform: str = "qq", storage_message: bool = True) -> bool:
|
||||
"""向群聊发送命令
|
||||
|
||||
Args:
|
||||
command: 命令
|
||||
group_id: 群聊ID
|
||||
platform: 平台,默认为"qq"
|
||||
|
||||
Returns:
|
||||
bool: 是否发送成功
|
||||
"""
|
||||
stream_id = get_chat_manager().get_stream_id(platform, group_id, True)
|
||||
return await _send_to_target("command", command, stream_id, "", typing=False, storage_message=storage_message)
|
||||
|
||||
|
||||
async def command_to_user(command: str, user_id: str, platform: str = "qq", storage_message: bool = True) -> bool:
|
||||
"""向用户发送命令
|
||||
|
||||
Args:
|
||||
command: 命令
|
||||
user_id: 用户ID
|
||||
platform: 平台,默认为"qq"
|
||||
|
||||
Returns:
|
||||
bool: 是否发送成功
|
||||
"""
|
||||
stream_id = get_chat_manager().get_stream_id(platform, user_id, False)
|
||||
return await _send_to_target("command", command, stream_id, "", typing=False, storage_message=storage_message)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 通用发送函数 - 支持任意消息类型
|
||||
# =============================================================================
|
||||
|
||||
|
||||
async def custom_to_group(
|
||||
message_type: str,
|
||||
content: str,
|
||||
group_id: str,
|
||||
platform: str = "qq",
|
||||
display_message: str = "",
|
||||
typing: bool = False,
|
||||
reply_to: str = "",
|
||||
storage_message: bool = True,
|
||||
) -> bool:
|
||||
"""向群聊发送自定义类型消息
|
||||
|
||||
Args:
|
||||
message_type: 消息类型,如"text"、"image"、"emoji"、"video"、"file"等
|
||||
content: 消息内容(通常是base64编码或文本)
|
||||
group_id: 群聊ID
|
||||
platform: 平台,默认为"qq"
|
||||
display_message: 显示消息
|
||||
typing: 是否显示正在输入
|
||||
reply_to: 回复消息,格式为"发送者:消息内容"
|
||||
|
||||
Returns:
|
||||
bool: 是否发送成功
|
||||
"""
|
||||
stream_id = get_chat_manager().get_stream_id(platform, group_id, True)
|
||||
return await _send_to_target(
|
||||
message_type, content, stream_id, display_message, typing, reply_to, storage_message=storage_message
|
||||
)
|
||||
|
||||
|
||||
async def custom_to_user(
|
||||
message_type: str,
|
||||
content: str,
|
||||
user_id: str,
|
||||
platform: str = "qq",
|
||||
display_message: str = "",
|
||||
typing: bool = False,
|
||||
reply_to: str = "",
|
||||
storage_message: bool = True,
|
||||
) -> bool:
|
||||
"""向用户发送自定义类型消息
|
||||
|
||||
Args:
|
||||
message_type: 消息类型,如"text"、"image"、"emoji"、"video"、"file"等
|
||||
content: 消息内容(通常是base64编码或文本)
|
||||
user_id: 用户ID
|
||||
platform: 平台,默认为"qq"
|
||||
display_message: 显示消息
|
||||
typing: 是否显示正在输入
|
||||
reply_to: 回复消息,格式为"发送者:消息内容"
|
||||
|
||||
Returns:
|
||||
bool: 是否发送成功
|
||||
"""
|
||||
stream_id = get_chat_manager().get_stream_id(platform, user_id, False)
|
||||
return await _send_to_target(
|
||||
message_type, content, stream_id, display_message, typing, reply_to, storage_message=storage_message
|
||||
)
|
||||
|
||||
|
||||
async def custom_message(
|
||||
message_type: str,
|
||||
content: str,
|
||||
target_id: str,
|
||||
is_group: bool = True,
|
||||
platform: str = "qq",
|
||||
display_message: str = "",
|
||||
typing: bool = False,
|
||||
reply_to: str = "",
|
||||
storage_message: bool = True,
|
||||
) -> bool:
|
||||
"""发送自定义消息的通用接口
|
||||
|
||||
Args:
|
||||
message_type: 消息类型,如"text"、"image"、"emoji"、"video"、"file"、"audio"等
|
||||
content: 消息内容
|
||||
target_id: 目标ID(群ID或用户ID)
|
||||
is_group: 是否为群聊,True为群聊,False为私聊
|
||||
platform: 平台,默认为"qq"
|
||||
display_message: 显示消息
|
||||
typing: 是否显示正在输入
|
||||
reply_to: 回复消息,格式为"发送者:消息内容"
|
||||
|
||||
Returns:
|
||||
bool: 是否发送成功
|
||||
|
||||
示例:
|
||||
# 发送视频到群聊
|
||||
await send_api.custom_message("video", video_base64, "123456", True)
|
||||
|
||||
# 发送文件到用户
|
||||
await send_api.custom_message("file", file_base64, "987654", False)
|
||||
|
||||
# 发送音频到群聊并回复特定消息
|
||||
await send_api.custom_message("audio", audio_base64, "123456", True, reply_to="张三:你好")
|
||||
"""
|
||||
stream_id = get_chat_manager().get_stream_id(platform, target_id, is_group)
|
||||
return await _send_to_target(
|
||||
message_type, content, stream_id, display_message, typing, reply_to, storage_message=storage_message
|
||||
)
|
||||
|
||||
27
src/plugin_system/apis/tool_api.py
Normal file
27
src/plugin_system/apis/tool_api.py
Normal file
@@ -0,0 +1,27 @@
|
||||
from typing import Optional, Type
|
||||
from src.plugin_system.base.base_tool import BaseTool
|
||||
from src.plugin_system.base.component_types import ComponentType
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("tool_api")
|
||||
|
||||
|
||||
def get_tool_instance(tool_name: str) -> Optional[BaseTool]:
|
||||
"""获取公开工具实例"""
|
||||
from src.plugin_system.core import component_registry
|
||||
|
||||
tool_class: Type[BaseTool] = component_registry.get_component_class(tool_name, ComponentType.TOOL) # type: ignore
|
||||
return tool_class() if tool_class else None
|
||||
|
||||
|
||||
def get_llm_available_tool_definitions():
|
||||
"""获取LLM可用的工具定义列表
|
||||
|
||||
Returns:
|
||||
List[Tuple[str, Dict[str, Any]]]: 工具定义列表,为[("tool_name", 定义)]
|
||||
"""
|
||||
from src.plugin_system.core import component_registry
|
||||
|
||||
llm_available_tools = component_registry.get_llm_available_tools()
|
||||
return [(name, tool_class.get_tool_definition()) for name, tool_class in llm_available_tools.items()]
|
||||
@@ -1,168 +0,0 @@
|
||||
"""工具类API模块
|
||||
|
||||
提供了各种辅助功能
|
||||
使用方式:
|
||||
from src.plugin_system.apis import utils_api
|
||||
plugin_path = utils_api.get_plugin_path()
|
||||
data = utils_api.read_json_file("data.json")
|
||||
timestamp = utils_api.get_timestamp()
|
||||
"""
|
||||
|
||||
import os
|
||||
import json
|
||||
import time
|
||||
import inspect
|
||||
import datetime
|
||||
import uuid
|
||||
from typing import Any, Optional
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("utils_api")
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 文件操作API函数
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def get_plugin_path(caller_frame=None) -> str:
|
||||
"""获取调用者插件的路径
|
||||
|
||||
Args:
|
||||
caller_frame: 调用者的栈帧,默认为None(自动获取)
|
||||
|
||||
Returns:
|
||||
str: 插件目录的绝对路径
|
||||
"""
|
||||
try:
|
||||
if caller_frame is None:
|
||||
caller_frame = inspect.currentframe().f_back # type: ignore
|
||||
|
||||
plugin_module_path = inspect.getfile(caller_frame) # type: ignore
|
||||
plugin_dir = os.path.dirname(plugin_module_path)
|
||||
return plugin_dir
|
||||
except Exception as e:
|
||||
logger.error(f"[UtilsAPI] 获取插件路径失败: {e}")
|
||||
return ""
|
||||
|
||||
|
||||
def read_json_file(file_path: str, default: Any = None) -> Any:
|
||||
"""读取JSON文件
|
||||
|
||||
Args:
|
||||
file_path: 文件路径,可以是相对于插件目录的路径
|
||||
default: 如果文件不存在或读取失败时返回的默认值
|
||||
|
||||
Returns:
|
||||
Any: JSON数据或默认值
|
||||
"""
|
||||
try:
|
||||
# 如果是相对路径,则相对于调用者的插件目录
|
||||
if not os.path.isabs(file_path):
|
||||
caller_frame = inspect.currentframe().f_back # type: ignore
|
||||
plugin_dir = get_plugin_path(caller_frame)
|
||||
file_path = os.path.join(plugin_dir, file_path)
|
||||
|
||||
if not os.path.exists(file_path):
|
||||
logger.warning(f"[UtilsAPI] 文件不存在: {file_path}")
|
||||
return default
|
||||
|
||||
with open(file_path, "r", encoding="utf-8") as f:
|
||||
return json.load(f)
|
||||
except Exception as e:
|
||||
logger.error(f"[UtilsAPI] 读取JSON文件出错: {e}")
|
||||
return default
|
||||
|
||||
|
||||
def write_json_file(file_path: str, data: Any, indent: int = 2) -> bool:
|
||||
"""写入JSON文件
|
||||
|
||||
Args:
|
||||
file_path: 文件路径,可以是相对于插件目录的路径
|
||||
data: 要写入的数据
|
||||
indent: JSON缩进
|
||||
|
||||
Returns:
|
||||
bool: 是否写入成功
|
||||
"""
|
||||
try:
|
||||
# 如果是相对路径,则相对于调用者的插件目录
|
||||
if not os.path.isabs(file_path):
|
||||
caller_frame = inspect.currentframe().f_back # type: ignore
|
||||
plugin_dir = get_plugin_path(caller_frame)
|
||||
file_path = os.path.join(plugin_dir, file_path)
|
||||
|
||||
# 确保目录存在
|
||||
os.makedirs(os.path.dirname(file_path), exist_ok=True)
|
||||
|
||||
with open(file_path, "w", encoding="utf-8") as f:
|
||||
json.dump(data, f, ensure_ascii=False, indent=indent)
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"[UtilsAPI] 写入JSON文件出错: {e}")
|
||||
return False
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 时间相关API函数
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def get_timestamp() -> int:
|
||||
"""获取当前时间戳
|
||||
|
||||
Returns:
|
||||
int: 当前时间戳(秒)
|
||||
"""
|
||||
return int(time.time())
|
||||
|
||||
|
||||
def format_time(timestamp: Optional[int | float] = None, format_str: str = "%Y-%m-%d %H:%M:%S") -> str:
|
||||
"""格式化时间
|
||||
|
||||
Args:
|
||||
timestamp: 时间戳,如果为None则使用当前时间
|
||||
format_str: 时间格式字符串
|
||||
|
||||
Returns:
|
||||
str: 格式化后的时间字符串
|
||||
"""
|
||||
try:
|
||||
if timestamp is None:
|
||||
timestamp = time.time()
|
||||
return datetime.datetime.fromtimestamp(timestamp).strftime(format_str)
|
||||
except Exception as e:
|
||||
logger.error(f"[UtilsAPI] 格式化时间失败: {e}")
|
||||
return ""
|
||||
|
||||
|
||||
def parse_time(time_str: str, format_str: str = "%Y-%m-%d %H:%M:%S") -> int:
|
||||
"""解析时间字符串为时间戳
|
||||
|
||||
Args:
|
||||
time_str: 时间字符串
|
||||
format_str: 时间格式字符串
|
||||
|
||||
Returns:
|
||||
int: 时间戳(秒)
|
||||
"""
|
||||
try:
|
||||
dt = datetime.datetime.strptime(time_str, format_str)
|
||||
return int(dt.timestamp())
|
||||
except Exception as e:
|
||||
logger.error(f"[UtilsAPI] 解析时间失败: {e}")
|
||||
return 0
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 其他工具函数
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def generate_unique_id() -> str:
|
||||
"""生成唯一ID
|
||||
|
||||
Returns:
|
||||
str: 唯一ID
|
||||
"""
|
||||
return str(uuid.uuid4())
|
||||
@@ -6,6 +6,7 @@
|
||||
|
||||
from .base_plugin import BasePlugin
|
||||
from .base_action import BaseAction
|
||||
from .base_tool import BaseTool
|
||||
from .base_command import BaseCommand
|
||||
from .base_events_handler import BaseEventHandler
|
||||
from .component_types import (
|
||||
@@ -15,6 +16,7 @@ from .component_types import (
|
||||
ComponentInfo,
|
||||
ActionInfo,
|
||||
CommandInfo,
|
||||
ToolInfo,
|
||||
PluginInfo,
|
||||
PythonDependency,
|
||||
EventHandlerInfo,
|
||||
@@ -27,12 +29,14 @@ __all__ = [
|
||||
"BasePlugin",
|
||||
"BaseAction",
|
||||
"BaseCommand",
|
||||
"BaseTool",
|
||||
"ComponentType",
|
||||
"ActionActivationType",
|
||||
"ChatMode",
|
||||
"ComponentInfo",
|
||||
"ActionInfo",
|
||||
"CommandInfo",
|
||||
"ToolInfo",
|
||||
"PluginInfo",
|
||||
"PythonDependency",
|
||||
"ConfigField",
|
||||
|
||||
@@ -208,7 +208,7 @@ class BaseAction(ABC):
|
||||
return False, f"等待新消息失败: {str(e)}"
|
||||
|
||||
async def send_text(
|
||||
self, content: str, reply_to: str = "", reply_to_platform_id: str = "", typing: bool = False
|
||||
self, content: str, reply_to: str = "", typing: bool = False
|
||||
) -> bool:
|
||||
"""发送文本消息
|
||||
|
||||
@@ -227,7 +227,6 @@ class BaseAction(ABC):
|
||||
text=content,
|
||||
stream_id=self.chat_id,
|
||||
reply_to=reply_to,
|
||||
reply_to_platform_id=reply_to_platform_id,
|
||||
typing=typing,
|
||||
)
|
||||
|
||||
@@ -384,7 +383,7 @@ class BaseAction(ABC):
|
||||
keyword_case_sensitive=getattr(cls, "keyword_case_sensitive", False),
|
||||
mode_enable=getattr(cls, "mode_enable", ChatMode.ALL),
|
||||
parallel_action=getattr(cls, "parallel_action", True),
|
||||
random_activation_probability=getattr(cls, "random_activation_probability", 0.3),
|
||||
random_activation_probability=getattr(cls, "random_activation_probability", 0.0),
|
||||
llm_judge_prompt=getattr(cls, "llm_judge_prompt", ""),
|
||||
# 使用正确的字段名
|
||||
action_parameters=getattr(cls, "action_parameters", {}).copy(),
|
||||
|
||||
@@ -60,10 +60,10 @@ class BaseCommand(ABC):
|
||||
pass
|
||||
|
||||
def get_config(self, key: str, default=None):
|
||||
"""获取插件配置值,支持嵌套键访问
|
||||
"""获取插件配置值,使用嵌套键访问
|
||||
|
||||
Args:
|
||||
key: 配置键名,支持嵌套访问如 "section.subsection.key"
|
||||
key: 配置键名,使用嵌套访问如 "section.subsection.key"
|
||||
default: 默认值
|
||||
|
||||
Returns:
|
||||
|
||||
85
src/plugin_system/base/base_tool.py
Normal file
85
src/plugin_system/base/base_tool.py
Normal file
@@ -0,0 +1,85 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict
|
||||
from rich.traceback import install
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.plugin_system.base.component_types import ComponentType, ToolInfo
|
||||
|
||||
install(extra_lines=3)
|
||||
|
||||
logger = get_logger("base_tool")
|
||||
|
||||
|
||||
class BaseTool(ABC):
|
||||
"""所有工具的基类"""
|
||||
|
||||
name: str = ""
|
||||
"""工具的名称"""
|
||||
description: str = ""
|
||||
"""工具的描述"""
|
||||
parameters: Dict[str, Any] = {}
|
||||
"""工具的参数定义"""
|
||||
available_for_llm: bool = False
|
||||
"""是否可供LLM使用"""
|
||||
|
||||
@classmethod
|
||||
def get_tool_definition(cls) -> dict[str, Any]:
|
||||
"""获取工具定义,用于LLM工具调用
|
||||
|
||||
Returns:
|
||||
dict: 工具定义字典
|
||||
"""
|
||||
if not cls.name or not cls.description or not cls.parameters:
|
||||
raise NotImplementedError(f"工具类 {cls.__name__} 必须定义 name, description 和 parameters 属性")
|
||||
|
||||
return {
|
||||
"type": "function",
|
||||
"function": {"name": cls.name, "description": cls.description, "parameters": cls.parameters},
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def get_tool_info(cls) -> ToolInfo:
|
||||
"""获取工具信息"""
|
||||
if not cls.name or not cls.description or not cls.parameters:
|
||||
raise NotImplementedError(f"工具类 {cls.__name__} 必须定义 name, description 和 parameters 属性")
|
||||
|
||||
return ToolInfo(
|
||||
name=cls.name,
|
||||
tool_description=cls.description,
|
||||
enabled=cls.available_for_llm,
|
||||
tool_parameters=cls.parameters,
|
||||
component_type=ComponentType.TOOL,
|
||||
)
|
||||
|
||||
@abstractmethod
|
||||
async def execute(self, function_args: dict[str, Any]) -> dict[str, Any]:
|
||||
"""执行工具函数(供llm调用)
|
||||
通过该方法,maicore会通过llm的tool call来调用工具
|
||||
传入的是json格式的参数,符合parameters定义的格式
|
||||
|
||||
Args:
|
||||
function_args: 工具调用参数
|
||||
|
||||
Returns:
|
||||
dict: 工具执行结果
|
||||
"""
|
||||
raise NotImplementedError("子类必须实现execute方法")
|
||||
|
||||
async def direct_execute(self, **function_args: dict[str, Any]) -> dict[str, Any]:
|
||||
"""直接执行工具函数(供插件调用)
|
||||
通过该方法,插件可以直接调用工具,而不需要传入字典格式的参数
|
||||
插件可以直接调用此方法,用更加明了的方式传入参数
|
||||
示例: result = await tool.direct_execute(arg1="参数",arg2="参数2")
|
||||
|
||||
工具开发者可以重写此方法以实现与llm调用差异化的执行逻辑
|
||||
|
||||
Args:
|
||||
**function_args: 工具调用参数
|
||||
|
||||
Returns:
|
||||
dict: 工具执行结果
|
||||
"""
|
||||
if self.parameters and (missing := [p for p in self.parameters.get("required", []) if p not in function_args]):
|
||||
raise ValueError(f"工具类 {self.__class__.__name__} 缺少必要参数: {', '.join(missing)}")
|
||||
|
||||
return await self.execute(function_args)
|
||||
@@ -10,6 +10,7 @@ class ComponentType(Enum):
|
||||
|
||||
ACTION = "action" # 动作组件
|
||||
COMMAND = "command" # 命令组件
|
||||
TOOL = "tool" # 服务组件(预留)
|
||||
SCHEDULER = "scheduler" # 定时任务组件(预留)
|
||||
EVENT_HANDLER = "event_handler" # 事件处理组件(预留)
|
||||
|
||||
@@ -144,7 +145,17 @@ class CommandInfo(ComponentInfo):
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
self.component_type = ComponentType.COMMAND
|
||||
|
||||
@dataclass
|
||||
class ToolInfo(ComponentInfo):
|
||||
"""工具组件信息"""
|
||||
|
||||
tool_parameters: Dict[str, Any] = field(default_factory=dict) # 工具参数定义
|
||||
tool_description: str = "" # 工具描述
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
self.component_type = ComponentType.TOOL
|
||||
|
||||
@dataclass
|
||||
class EventHandlerInfo(ComponentInfo):
|
||||
|
||||
@@ -6,14 +6,12 @@
|
||||
|
||||
from src.plugin_system.core.plugin_manager import plugin_manager
|
||||
from src.plugin_system.core.component_registry import component_registry
|
||||
from src.plugin_system.core.dependency_manager import dependency_manager
|
||||
from src.plugin_system.core.events_manager import events_manager
|
||||
from src.plugin_system.core.global_announcement_manager import global_announcement_manager
|
||||
|
||||
__all__ = [
|
||||
"plugin_manager",
|
||||
"component_registry",
|
||||
"dependency_manager",
|
||||
"events_manager",
|
||||
"global_announcement_manager",
|
||||
]
|
||||
|
||||
@@ -6,6 +6,7 @@ from src.common.logger import get_logger
|
||||
from src.plugin_system.base.component_types import (
|
||||
ComponentInfo,
|
||||
ActionInfo,
|
||||
ToolInfo,
|
||||
CommandInfo,
|
||||
EventHandlerInfo,
|
||||
PluginInfo,
|
||||
@@ -13,6 +14,7 @@ from src.plugin_system.base.component_types import (
|
||||
)
|
||||
from src.plugin_system.base.base_command import BaseCommand
|
||||
from src.plugin_system.base.base_action import BaseAction
|
||||
from src.plugin_system.base.base_tool import BaseTool
|
||||
from src.plugin_system.base.base_events_handler import BaseEventHandler
|
||||
|
||||
logger = get_logger("component_registry")
|
||||
@@ -30,7 +32,7 @@ class ComponentRegistry:
|
||||
"""组件注册表 命名空间式组件名 -> 组件信息"""
|
||||
self._components_by_type: Dict[ComponentType, Dict[str, ComponentInfo]] = {types: {} for types in ComponentType}
|
||||
"""类型 -> 组件原名称 -> 组件信息"""
|
||||
self._components_classes: Dict[str, Type[Union[BaseCommand, BaseAction, BaseEventHandler]]] = {}
|
||||
self._components_classes: Dict[str, Type[Union[BaseCommand, BaseAction, BaseTool, BaseEventHandler]]] = {}
|
||||
"""命名空间式组件名 -> 组件类"""
|
||||
|
||||
# 插件注册表
|
||||
@@ -49,6 +51,10 @@ class ComponentRegistry:
|
||||
self._command_patterns: Dict[Pattern, str] = {}
|
||||
"""编译后的正则 -> command名"""
|
||||
|
||||
# 工具特定注册表
|
||||
self._tool_registry: Dict[str, Type[BaseTool]] = {} # 工具名 -> 工具类
|
||||
self._llm_available_tools: Dict[str, Type[BaseTool]] = {} # llm可用的工具名 -> 工具类
|
||||
|
||||
# EventHandler特定注册表
|
||||
self._event_handler_registry: Dict[str, Type[BaseEventHandler]] = {}
|
||||
"""event_handler名 -> event_handler类"""
|
||||
@@ -79,7 +85,9 @@ class ComponentRegistry:
|
||||
return True
|
||||
|
||||
def register_component(
|
||||
self, component_info: ComponentInfo, component_class: Type[Union[BaseCommand, BaseAction, BaseEventHandler]]
|
||||
self,
|
||||
component_info: ComponentInfo,
|
||||
component_class: Type[Union[BaseCommand, BaseAction, BaseEventHandler, BaseTool]],
|
||||
) -> bool:
|
||||
"""注册组件
|
||||
|
||||
@@ -125,6 +133,10 @@ class ComponentRegistry:
|
||||
assert isinstance(component_info, CommandInfo)
|
||||
assert issubclass(component_class, BaseCommand)
|
||||
ret = self._register_command_component(component_info, component_class)
|
||||
case ComponentType.TOOL:
|
||||
assert isinstance(component_info, ToolInfo)
|
||||
assert issubclass(component_class, BaseTool)
|
||||
ret = self._register_tool_component(component_info, component_class)
|
||||
case ComponentType.EVENT_HANDLER:
|
||||
assert isinstance(component_info, EventHandlerInfo)
|
||||
assert issubclass(component_class, BaseEventHandler)
|
||||
@@ -180,6 +192,18 @@ class ComponentRegistry:
|
||||
|
||||
return True
|
||||
|
||||
def _register_tool_component(self, tool_info: ToolInfo, tool_class: Type[BaseTool]) -> bool:
|
||||
"""注册Tool组件到Tool特定注册表"""
|
||||
tool_name = tool_info.name
|
||||
|
||||
self._tool_registry[tool_name] = tool_class
|
||||
|
||||
# 如果是llm可用的且启用的工具,添加到 llm可用工具列表
|
||||
if tool_info.enabled:
|
||||
self._llm_available_tools[tool_name] = tool_class
|
||||
|
||||
return True
|
||||
|
||||
def _register_event_handler_component(
|
||||
self, handler_info: EventHandlerInfo, handler_class: Type[BaseEventHandler]
|
||||
) -> bool:
|
||||
@@ -222,6 +246,9 @@ class ComponentRegistry:
|
||||
keys_to_remove = [k for k, v in self._command_patterns.items() if v == component_name]
|
||||
for key in keys_to_remove:
|
||||
self._command_patterns.pop(key)
|
||||
case ComponentType.TOOL:
|
||||
self._tool_registry.pop(component_name)
|
||||
self._llm_available_tools.pop(component_name)
|
||||
case ComponentType.EVENT_HANDLER:
|
||||
from .events_manager import events_manager # 延迟导入防止循环导入问题
|
||||
|
||||
@@ -234,13 +261,13 @@ class ComponentRegistry:
|
||||
self._components_classes.pop(namespaced_name)
|
||||
logger.info(f"组件 {component_name} 已移除")
|
||||
return True
|
||||
except KeyError:
|
||||
logger.warning(f"移除组件时未找到组件: {component_name}")
|
||||
except KeyError as e:
|
||||
logger.warning(f"移除组件时未找到组件: {component_name}, 发生错误: {e}")
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(f"移除组件 {component_name} 时发生错误: {e}")
|
||||
return False
|
||||
|
||||
|
||||
def remove_plugin_registry(self, plugin_name: str) -> bool:
|
||||
"""移除插件注册信息
|
||||
|
||||
@@ -281,6 +308,10 @@ class ComponentRegistry:
|
||||
assert isinstance(target_component_info, CommandInfo)
|
||||
pattern = target_component_info.command_pattern
|
||||
self._command_patterns[re.compile(pattern)] = component_name
|
||||
case ComponentType.TOOL:
|
||||
assert isinstance(target_component_info, ToolInfo)
|
||||
assert issubclass(target_component_class, BaseTool)
|
||||
self._llm_available_tools[component_name] = target_component_class
|
||||
case ComponentType.EVENT_HANDLER:
|
||||
assert isinstance(target_component_info, EventHandlerInfo)
|
||||
assert issubclass(target_component_class, BaseEventHandler)
|
||||
@@ -308,20 +339,29 @@ class ComponentRegistry:
|
||||
logger.warning(f"组件 {component_name} 未注册,无法禁用")
|
||||
return False
|
||||
target_component_info.enabled = False
|
||||
match component_type:
|
||||
case ComponentType.ACTION:
|
||||
self._default_actions.pop(component_name, None)
|
||||
case ComponentType.COMMAND:
|
||||
self._command_patterns = {k: v for k, v in self._command_patterns.items() if v != component_name}
|
||||
case ComponentType.EVENT_HANDLER:
|
||||
self._enabled_event_handlers.pop(component_name, None)
|
||||
from .events_manager import events_manager # 延迟导入防止循环导入问题
|
||||
try:
|
||||
match component_type:
|
||||
case ComponentType.ACTION:
|
||||
self._default_actions.pop(component_name)
|
||||
case ComponentType.COMMAND:
|
||||
self._command_patterns = {k: v for k, v in self._command_patterns.items() if v != component_name}
|
||||
case ComponentType.TOOL:
|
||||
self._llm_available_tools.pop(component_name)
|
||||
case ComponentType.EVENT_HANDLER:
|
||||
self._enabled_event_handlers.pop(component_name)
|
||||
from .events_manager import events_manager # 延迟导入防止循环导入问题
|
||||
|
||||
await events_manager.unregister_event_subscriber(component_name)
|
||||
self._components[component_name].enabled = False
|
||||
self._components_by_type[component_type][component_name].enabled = False
|
||||
logger.info(f"组件 {component_name} 已禁用")
|
||||
return True
|
||||
await events_manager.unregister_event_subscriber(component_name)
|
||||
self._components[component_name].enabled = False
|
||||
self._components_by_type[component_type][component_name].enabled = False
|
||||
logger.info(f"组件 {component_name} 已禁用")
|
||||
return True
|
||||
except KeyError as e:
|
||||
logger.warning(f"禁用组件时未找到组件或已禁用: {component_name}, 发生错误: {e}")
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(f"禁用组件 {component_name} 时发生错误: {e}")
|
||||
return False
|
||||
|
||||
# === 组件查询方法 ===
|
||||
def get_component_info(
|
||||
@@ -371,7 +411,7 @@ class ComponentRegistry:
|
||||
self,
|
||||
component_name: str,
|
||||
component_type: Optional[ComponentType] = None,
|
||||
) -> Optional[Union[Type[BaseCommand], Type[BaseAction], Type[BaseEventHandler]]]:
|
||||
) -> Optional[Union[Type[BaseCommand], Type[BaseAction], Type[BaseEventHandler], Type[BaseTool]]]:
|
||||
"""获取组件类,支持自动命名空间解析
|
||||
|
||||
Args:
|
||||
@@ -476,6 +516,27 @@ class ComponentRegistry:
|
||||
command_info,
|
||||
)
|
||||
|
||||
# === Tool 特定查询方法 ===
|
||||
def get_tool_registry(self) -> Dict[str, Type[BaseTool]]:
|
||||
"""获取Tool注册表"""
|
||||
return self._tool_registry.copy()
|
||||
|
||||
def get_llm_available_tools(self) -> Dict[str, Type[BaseTool]]:
|
||||
"""获取LLM可用的Tool列表"""
|
||||
return self._llm_available_tools.copy()
|
||||
|
||||
def get_registered_tool_info(self, tool_name: str) -> Optional[ToolInfo]:
|
||||
"""获取Tool信息
|
||||
|
||||
Args:
|
||||
tool_name: 工具名称
|
||||
|
||||
Returns:
|
||||
ToolInfo: 工具信息对象,如果工具不存在则返回 None
|
||||
"""
|
||||
info = self.get_component_info(tool_name, ComponentType.TOOL)
|
||||
return info if isinstance(info, ToolInfo) else None
|
||||
|
||||
# === EventHandler 特定查询方法 ===
|
||||
|
||||
def get_event_handler_registry(self) -> Dict[str, Type[BaseEventHandler]]:
|
||||
@@ -529,17 +590,21 @@ class ComponentRegistry:
|
||||
"""获取注册中心统计信息"""
|
||||
action_components: int = 0
|
||||
command_components: int = 0
|
||||
tool_components: int = 0
|
||||
events_handlers: int = 0
|
||||
for component in self._components.values():
|
||||
if component.component_type == ComponentType.ACTION:
|
||||
action_components += 1
|
||||
elif component.component_type == ComponentType.COMMAND:
|
||||
command_components += 1
|
||||
elif component.component_type == ComponentType.TOOL:
|
||||
tool_components += 1
|
||||
elif component.component_type == ComponentType.EVENT_HANDLER:
|
||||
events_handlers += 1
|
||||
return {
|
||||
"action_components": action_components,
|
||||
"command_components": command_components,
|
||||
"tool_components": tool_components,
|
||||
"event_handlers": events_handlers,
|
||||
"total_components": len(self._components),
|
||||
"total_plugins": len(self._plugins),
|
||||
|
||||
@@ -1,190 +0,0 @@
|
||||
"""
|
||||
插件依赖管理器
|
||||
|
||||
负责检查和安装插件的Python包依赖
|
||||
"""
|
||||
|
||||
import subprocess
|
||||
import sys
|
||||
import importlib
|
||||
from typing import List, Dict, Tuple, Any
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.plugin_system.base.component_types import PythonDependency
|
||||
|
||||
logger = get_logger("dependency_manager")
|
||||
|
||||
|
||||
class DependencyManager:
|
||||
"""依赖管理器"""
|
||||
|
||||
def __init__(self):
|
||||
self.install_log: List[str] = []
|
||||
self.failed_installs: Dict[str, str] = {}
|
||||
|
||||
def check_dependencies(
|
||||
self, dependencies: List[PythonDependency]
|
||||
) -> Tuple[List[PythonDependency], List[PythonDependency]]:
|
||||
"""检查依赖包状态
|
||||
|
||||
Args:
|
||||
dependencies: 依赖包列表
|
||||
|
||||
Returns:
|
||||
Tuple[List[PythonDependency], List[PythonDependency]]: (缺失的依赖, 可选缺失的依赖)
|
||||
"""
|
||||
missing_required = []
|
||||
missing_optional = []
|
||||
|
||||
for dep in dependencies:
|
||||
if self._is_package_available(dep.package_name):
|
||||
logger.debug(f"依赖包已存在: {dep.package_name}")
|
||||
elif dep.optional:
|
||||
missing_optional.append(dep)
|
||||
logger.warning(f"可选依赖包缺失: {dep.package_name} - {dep.description}")
|
||||
else:
|
||||
missing_required.append(dep)
|
||||
logger.error(f"必需依赖包缺失: {dep.package_name} - {dep.description}")
|
||||
return missing_required, missing_optional
|
||||
|
||||
def _is_package_available(self, package_name: str) -> bool:
|
||||
"""检查包是否可用"""
|
||||
try:
|
||||
importlib.import_module(package_name)
|
||||
return True
|
||||
except ImportError:
|
||||
return False
|
||||
|
||||
def install_dependencies(self, dependencies: List[PythonDependency], auto_install: bool = False) -> bool:
|
||||
"""安装依赖包
|
||||
|
||||
Args:
|
||||
dependencies: 需要安装的依赖包列表
|
||||
auto_install: 是否自动安装(True时不询问用户)
|
||||
|
||||
Returns:
|
||||
bool: 安装是否成功
|
||||
"""
|
||||
if not dependencies:
|
||||
return True
|
||||
|
||||
logger.info(f"需要安装 {len(dependencies)} 个依赖包")
|
||||
|
||||
# 显示将要安装的包
|
||||
for dep in dependencies:
|
||||
install_cmd = dep.get_pip_requirement()
|
||||
logger.info(f" - {install_cmd} {'(可选)' if dep.optional else '(必需)'}")
|
||||
if dep.description:
|
||||
logger.info(f" 说明: {dep.description}")
|
||||
|
||||
if not auto_install:
|
||||
# 这里可以添加用户确认逻辑
|
||||
logger.warning("手动安装模式:请手动运行 pip install 命令安装依赖包")
|
||||
return False
|
||||
|
||||
# 执行安装
|
||||
success_count = 0
|
||||
for dep in dependencies:
|
||||
if self._install_single_package(dep):
|
||||
success_count += 1
|
||||
else:
|
||||
self.failed_installs[dep.package_name] = f"安装失败: {dep.get_pip_requirement()}"
|
||||
|
||||
logger.info(f"依赖安装完成: {success_count}/{len(dependencies)} 个成功")
|
||||
return success_count == len(dependencies)
|
||||
|
||||
def _install_single_package(self, dependency: PythonDependency) -> bool:
|
||||
"""安装单个包"""
|
||||
pip_requirement = dependency.get_pip_requirement()
|
||||
|
||||
try:
|
||||
logger.info(f"正在安装: {pip_requirement}")
|
||||
|
||||
# 使用subprocess安装包
|
||||
cmd = [sys.executable, "-m", "pip", "install", pip_requirement]
|
||||
result = subprocess.run(
|
||||
cmd,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=300, # 5分钟超时
|
||||
)
|
||||
|
||||
if result.returncode == 0:
|
||||
logger.info(f"✅ 成功安装: {pip_requirement}")
|
||||
self.install_log.append(f"成功安装: {pip_requirement}")
|
||||
return True
|
||||
else:
|
||||
logger.error(f"❌ 安装失败: {pip_requirement}")
|
||||
logger.error(f"错误输出: {result.stderr}")
|
||||
self.install_log.append(f"安装失败: {pip_requirement} - {result.stderr}")
|
||||
return False
|
||||
|
||||
except subprocess.TimeoutExpired:
|
||||
logger.error(f"❌ 安装超时: {pip_requirement}")
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 安装异常: {pip_requirement} - {str(e)}")
|
||||
return False
|
||||
|
||||
def generate_requirements_file(
|
||||
self, plugins_dependencies: List[List[PythonDependency]], output_path: str = "plugin_requirements.txt"
|
||||
) -> bool:
|
||||
"""生成插件依赖的requirements文件
|
||||
|
||||
Args:
|
||||
plugins_dependencies: 所有插件的依赖列表
|
||||
output_path: 输出文件路径
|
||||
|
||||
Returns:
|
||||
bool: 生成是否成功
|
||||
"""
|
||||
try:
|
||||
all_deps = {}
|
||||
|
||||
# 合并所有插件的依赖
|
||||
for plugin_deps in plugins_dependencies:
|
||||
for dep in plugin_deps:
|
||||
key = dep.install_name
|
||||
if key in all_deps:
|
||||
# 如果已存在,可以添加版本兼容性检查逻辑
|
||||
existing = all_deps[key]
|
||||
if dep.version and existing.version != dep.version:
|
||||
logger.warning(f"依赖版本冲突: {key} ({existing.version} vs {dep.version})")
|
||||
else:
|
||||
all_deps[key] = dep
|
||||
|
||||
# 写入requirements文件
|
||||
with open(output_path, "w", encoding="utf-8") as f:
|
||||
f.write("# 插件依赖包自动生成\n")
|
||||
f.write("# Auto-generated plugin dependencies\n\n")
|
||||
|
||||
# 按包名排序
|
||||
sorted_deps = sorted(all_deps.values(), key=lambda x: x.install_name)
|
||||
|
||||
for dep in sorted_deps:
|
||||
requirement = dep.get_pip_requirement()
|
||||
if dep.description:
|
||||
f.write(f"# {dep.description}\n")
|
||||
if dep.optional:
|
||||
f.write("# Optional dependency\n")
|
||||
f.write(f"{requirement}\n\n")
|
||||
|
||||
logger.info(f"已生成插件依赖文件: {output_path} ({len(all_deps)} 个包)")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"生成requirements文件失败: {str(e)}")
|
||||
return False
|
||||
|
||||
def get_install_summary(self) -> Dict[str, Any]:
|
||||
"""获取安装摘要"""
|
||||
return {
|
||||
"install_log": self.install_log.copy(),
|
||||
"failed_installs": self.failed_installs.copy(),
|
||||
"total_attempts": len(self.install_log),
|
||||
"failed_count": len(self.failed_installs),
|
||||
}
|
||||
|
||||
|
||||
# 全局依赖管理器实例
|
||||
dependency_manager = DependencyManager()
|
||||
@@ -13,6 +13,8 @@ class GlobalAnnouncementManager:
|
||||
self._user_disabled_commands: Dict[str, List[str]] = {}
|
||||
# 用户禁用的事件处理器,chat_id -> [handler_name]
|
||||
self._user_disabled_event_handlers: Dict[str, List[str]] = {}
|
||||
# 用户禁用的工具,chat_id -> [tool_name]
|
||||
self._user_disabled_tools: Dict[str, List[str]] = {}
|
||||
|
||||
def disable_specific_chat_action(self, chat_id: str, action_name: str) -> bool:
|
||||
"""禁用特定聊天的某个动作"""
|
||||
@@ -77,6 +79,27 @@ class GlobalAnnouncementManager:
|
||||
return False
|
||||
return False
|
||||
|
||||
def disable_specific_chat_tool(self, chat_id: str, tool_name: str) -> bool:
|
||||
"""禁用特定聊天的某个工具"""
|
||||
if chat_id not in self._user_disabled_tools:
|
||||
self._user_disabled_tools[chat_id] = []
|
||||
if tool_name in self._user_disabled_tools[chat_id]:
|
||||
logger.warning(f"工具 {tool_name} 已经被禁用")
|
||||
return False
|
||||
self._user_disabled_tools[chat_id].append(tool_name)
|
||||
return True
|
||||
|
||||
def enable_specific_chat_tool(self, chat_id: str, tool_name: str) -> bool:
|
||||
"""启用特定聊天的某个工具"""
|
||||
if chat_id in self._user_disabled_tools:
|
||||
try:
|
||||
self._user_disabled_tools[chat_id].remove(tool_name)
|
||||
return True
|
||||
except ValueError:
|
||||
logger.warning(f"工具 {tool_name} 不在禁用列表中")
|
||||
return False
|
||||
return False
|
||||
|
||||
def get_disabled_chat_actions(self, chat_id: str) -> List[str]:
|
||||
"""获取特定聊天禁用的所有动作"""
|
||||
return self._user_disabled_actions.get(chat_id, []).copy()
|
||||
@@ -88,6 +111,10 @@ class GlobalAnnouncementManager:
|
||||
def get_disabled_chat_event_handlers(self, chat_id: str) -> List[str]:
|
||||
"""获取特定聊天禁用的所有事件处理器"""
|
||||
return self._user_disabled_event_handlers.get(chat_id, []).copy()
|
||||
|
||||
def get_disabled_chat_tools(self, chat_id: str) -> List[str]:
|
||||
"""获取特定聊天禁用的所有工具"""
|
||||
return self._user_disabled_tools.get(chat_id, []).copy()
|
||||
|
||||
|
||||
global_announcement_manager = GlobalAnnouncementManager()
|
||||
|
||||
@@ -8,10 +8,9 @@ from pathlib import Path
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.plugin_system.base.plugin_base import PluginBase
|
||||
from src.plugin_system.base.component_types import ComponentType, PythonDependency
|
||||
from src.plugin_system.base.component_types import ComponentType
|
||||
from src.plugin_system.utils.manifest_utils import VersionComparator
|
||||
from .component_registry import component_registry
|
||||
from .dependency_manager import dependency_manager
|
||||
|
||||
logger = get_logger("plugin_manager")
|
||||
|
||||
@@ -207,104 +206,6 @@ class PluginManager:
|
||||
"""
|
||||
return self.loaded_plugins.get(plugin_name)
|
||||
|
||||
def check_all_dependencies(self, auto_install: bool = False) -> Dict[str, Any]:
|
||||
"""检查所有插件的Python依赖包
|
||||
|
||||
Args:
|
||||
auto_install: 是否自动安装缺失的依赖包
|
||||
|
||||
Returns:
|
||||
Dict[str, any]: 检查结果摘要
|
||||
"""
|
||||
logger.info("开始检查所有插件的Python依赖包...")
|
||||
|
||||
all_required_missing: List[PythonDependency] = []
|
||||
all_optional_missing: List[PythonDependency] = []
|
||||
plugin_status = {}
|
||||
|
||||
for plugin_name in self.loaded_plugins:
|
||||
plugin_info = component_registry.get_plugin_info(plugin_name)
|
||||
if not plugin_info or not plugin_info.python_dependencies:
|
||||
plugin_status[plugin_name] = {"status": "no_dependencies", "missing": []}
|
||||
continue
|
||||
|
||||
logger.info(f"检查插件 {plugin_name} 的依赖...")
|
||||
|
||||
missing_required, missing_optional = dependency_manager.check_dependencies(plugin_info.python_dependencies)
|
||||
|
||||
if missing_required:
|
||||
all_required_missing.extend(missing_required)
|
||||
plugin_status[plugin_name] = {
|
||||
"status": "missing_required",
|
||||
"missing": [dep.package_name for dep in missing_required],
|
||||
"optional_missing": [dep.package_name for dep in missing_optional],
|
||||
}
|
||||
logger.error(f"插件 {plugin_name} 缺少必需依赖: {[dep.package_name for dep in missing_required]}")
|
||||
elif missing_optional:
|
||||
all_optional_missing.extend(missing_optional)
|
||||
plugin_status[plugin_name] = {
|
||||
"status": "missing_optional",
|
||||
"missing": [],
|
||||
"optional_missing": [dep.package_name for dep in missing_optional],
|
||||
}
|
||||
logger.warning(f"插件 {plugin_name} 缺少可选依赖: {[dep.package_name for dep in missing_optional]}")
|
||||
else:
|
||||
plugin_status[plugin_name] = {"status": "ok", "missing": []}
|
||||
logger.info(f"插件 {plugin_name} 依赖检查通过")
|
||||
|
||||
# 汇总结果
|
||||
total_missing = len({dep.package_name for dep in all_required_missing})
|
||||
total_optional_missing = len({dep.package_name for dep in all_optional_missing})
|
||||
|
||||
logger.info(f"依赖检查完成 - 缺少必需包: {total_missing}个, 缺少可选包: {total_optional_missing}个")
|
||||
|
||||
# 如果需要自动安装
|
||||
install_success = True
|
||||
if auto_install and all_required_missing:
|
||||
unique_required = {dep.package_name: dep for dep in all_required_missing}
|
||||
logger.info(f"开始自动安装 {len(unique_required)} 个必需依赖包...")
|
||||
install_success = dependency_manager.install_dependencies(list(unique_required.values()), auto_install=True)
|
||||
|
||||
return {
|
||||
"total_plugins_checked": len(plugin_status),
|
||||
"plugins_with_missing_required": len(
|
||||
[p for p in plugin_status.values() if p["status"] == "missing_required"]
|
||||
),
|
||||
"plugins_with_missing_optional": len(
|
||||
[p for p in plugin_status.values() if p["status"] == "missing_optional"]
|
||||
),
|
||||
"total_missing_required": total_missing,
|
||||
"total_missing_optional": total_optional_missing,
|
||||
"plugin_status": plugin_status,
|
||||
"auto_install_attempted": auto_install and bool(all_required_missing),
|
||||
"auto_install_success": install_success,
|
||||
"install_summary": dependency_manager.get_install_summary(),
|
||||
}
|
||||
|
||||
def generate_plugin_requirements(self, output_path: str = "plugin_requirements.txt") -> bool:
|
||||
"""生成所有插件依赖的requirements文件
|
||||
|
||||
Args:
|
||||
output_path: 输出文件路径
|
||||
|
||||
Returns:
|
||||
bool: 生成是否成功
|
||||
"""
|
||||
logger.info("开始生成插件依赖requirements文件...")
|
||||
|
||||
all_dependencies = []
|
||||
|
||||
for plugin_name in self.loaded_plugins:
|
||||
plugin_info = component_registry.get_plugin_info(plugin_name)
|
||||
if plugin_info and plugin_info.python_dependencies:
|
||||
all_dependencies.append(plugin_info.python_dependencies)
|
||||
|
||||
if not all_dependencies:
|
||||
logger.info("没有找到任何插件依赖")
|
||||
return False
|
||||
|
||||
return dependency_manager.generate_requirements_file(all_dependencies, output_path)
|
||||
|
||||
# === 查询方法 ===
|
||||
def list_loaded_plugins(self) -> List[str]:
|
||||
"""
|
||||
@@ -323,6 +224,18 @@ class PluginManager:
|
||||
list: 已注册的插件类名称列表。
|
||||
"""
|
||||
return list(self.plugin_classes.keys())
|
||||
|
||||
def get_plugin_path(self, plugin_name: str) -> Optional[str]:
|
||||
"""
|
||||
获取指定插件的路径。
|
||||
|
||||
Args:
|
||||
plugin_name: 插件名称
|
||||
|
||||
Returns:
|
||||
Optional[str]: 插件目录的绝对路径,如果插件不存在则返回None。
|
||||
"""
|
||||
return self.plugin_paths.get(plugin_name)
|
||||
|
||||
# === 私有方法 ===
|
||||
# == 目录管理 ==
|
||||
@@ -388,6 +301,7 @@ class PluginManager:
|
||||
return False
|
||||
|
||||
module = module_from_spec(spec)
|
||||
module.__package__ = module_name # 设置模块包名
|
||||
spec.loader.exec_module(module)
|
||||
|
||||
logger.debug(f"插件模块加载成功: {plugin_file}")
|
||||
@@ -444,6 +358,7 @@ class PluginManager:
|
||||
stats = component_registry.get_registry_stats()
|
||||
action_count = stats.get("action_components", 0)
|
||||
command_count = stats.get("command_components", 0)
|
||||
tool_count = stats.get("tool_components", 0)
|
||||
event_handler_count = stats.get("event_handlers", 0)
|
||||
total_components = stats.get("total_components", 0)
|
||||
|
||||
@@ -451,7 +366,7 @@ class PluginManager:
|
||||
if total_registered > 0:
|
||||
logger.info("🎉 插件系统加载完成!")
|
||||
logger.info(
|
||||
f"📊 总览: {total_registered}个插件, {total_components}个组件 (Action: {action_count}, Command: {command_count}, EventHandler: {event_handler_count})"
|
||||
f"📊 总览: {total_registered}个插件, {total_components}个组件 (Action: {action_count}, Command: {command_count}, Tool: {tool_count}, EventHandler: {event_handler_count})"
|
||||
)
|
||||
|
||||
# 显示详细的插件列表
|
||||
@@ -486,6 +401,9 @@ class PluginManager:
|
||||
command_components = [
|
||||
c for c in plugin_info.components if c.component_type == ComponentType.COMMAND
|
||||
]
|
||||
tool_components = [
|
||||
c for c in plugin_info.components if c.component_type == ComponentType.TOOL
|
||||
]
|
||||
event_handler_components = [
|
||||
c for c in plugin_info.components if c.component_type == ComponentType.EVENT_HANDLER
|
||||
]
|
||||
@@ -497,7 +415,9 @@ class PluginManager:
|
||||
if command_components:
|
||||
command_names = [c.name for c in command_components]
|
||||
logger.info(f" ⚡ Command组件: {', '.join(command_names)}")
|
||||
|
||||
if tool_components:
|
||||
tool_names = [c.name for c in tool_components]
|
||||
logger.info(f" 🛠️ Tool组件: {', '.join(tool_names)}")
|
||||
if event_handler_components:
|
||||
event_handler_names = [c.name for c in event_handler_components]
|
||||
logger.info(f" 📢 EventHandler组件: {', '.join(event_handler_names)}")
|
||||
|
||||
@@ -1,14 +1,16 @@
|
||||
import json
|
||||
import time
|
||||
from typing import List, Dict, Tuple, Optional, Any
|
||||
from src.plugin_system.apis.tool_api import get_llm_available_tool_definitions, get_tool_instance
|
||||
from src.plugin_system.core.global_announcement_manager import global_announcement_manager
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import global_config
|
||||
import time
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||
from src.tools.tool_use import ToolUser
|
||||
from src.chat.utils.json_utils import process_llm_tool_calls
|
||||
from typing import List, Dict, Tuple, Optional
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("tool_executor")
|
||||
logger = get_logger("tool_use")
|
||||
|
||||
|
||||
def init_tool_executor_prompt():
|
||||
@@ -28,6 +30,10 @@ If you need to use a tool, please directly call the corresponding tool function.
|
||||
Prompt(tool_executor_prompt, "tool_executor_prompt")
|
||||
|
||||
|
||||
# 初始化提示词
|
||||
init_tool_executor_prompt()
|
||||
|
||||
|
||||
class ToolExecutor:
|
||||
"""独立的工具执行器组件
|
||||
|
||||
@@ -51,9 +57,6 @@ class ToolExecutor:
|
||||
request_type="tool_executor",
|
||||
)
|
||||
|
||||
# 初始化工具实例
|
||||
self.tool_instance = ToolUser()
|
||||
|
||||
# 缓存配置
|
||||
self.enable_cache = enable_cache
|
||||
self.cache_ttl = cache_ttl
|
||||
@@ -73,7 +76,7 @@ class ToolExecutor:
|
||||
return_details: 是否返回详细信息(使用的工具列表和提示词)
|
||||
|
||||
Returns:
|
||||
如果return_details为False: List[Dict] - 工具执行结果列表
|
||||
如果return_details为False: Tuple[List[Dict], List[str], str] - (工具执行结果列表, 空, 空)
|
||||
如果return_details为True: Tuple[List[Dict], List[str], str] - (结果列表, 使用的工具, 提示词)
|
||||
"""
|
||||
|
||||
@@ -82,15 +85,15 @@ class ToolExecutor:
|
||||
if cached_result := self._get_from_cache(cache_key):
|
||||
logger.info(f"{self.log_prefix}使用缓存结果,跳过工具执行")
|
||||
if not return_details:
|
||||
return cached_result, [], "使用缓存结果"
|
||||
return cached_result, [], ""
|
||||
|
||||
# 从缓存结果中提取工具名称
|
||||
used_tools = [result.get("tool_name", "unknown") for result in cached_result]
|
||||
return cached_result, used_tools, "使用缓存结果"
|
||||
return cached_result, used_tools, ""
|
||||
|
||||
# 缓存未命中,执行工具调用
|
||||
# 获取可用工具
|
||||
tools = self.tool_instance._define_tools()
|
||||
tools = self._get_tool_definitions()
|
||||
|
||||
# 获取当前时间
|
||||
time_now = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
|
||||
@@ -112,6 +115,7 @@ class ToolExecutor:
|
||||
# 调用LLM进行工具决策
|
||||
response, other_info = await self.llm_model.generate_response_async(prompt=prompt, tools=tools)
|
||||
|
||||
# TODO: 在APIADA加入后完全修复这里!
|
||||
# 解析LLM响应
|
||||
if len(other_info) == 3:
|
||||
reasoning_content, model_name, tool_calls = other_info
|
||||
@@ -133,6 +137,11 @@ class ToolExecutor:
|
||||
return tool_results, used_tools, prompt
|
||||
else:
|
||||
return tool_results, [], ""
|
||||
|
||||
def _get_tool_definitions(self) -> List[Dict[str, Any]]:
|
||||
all_tools = get_llm_available_tool_definitions()
|
||||
user_disabled_tools = global_announcement_manager.get_disabled_chat_tools(self.chat_id)
|
||||
return [parameters for name, parameters in all_tools if name not in user_disabled_tools]
|
||||
|
||||
async def _execute_tool_calls(self, tool_calls) -> Tuple[List[Dict], List[str]]:
|
||||
"""执行工具调用
|
||||
@@ -172,7 +181,7 @@ class ToolExecutor:
|
||||
logger.debug(f"{self.log_prefix}执行工具: {tool_name}")
|
||||
|
||||
# 执行工具
|
||||
result = await self.tool_instance._execute_tool_call(tool_call)
|
||||
result = await self._execute_tool_call(tool_call)
|
||||
|
||||
if result:
|
||||
tool_info = {
|
||||
@@ -205,6 +214,45 @@ class ToolExecutor:
|
||||
|
||||
return tool_results, used_tools
|
||||
|
||||
async def _execute_tool_call(self, tool_call: Dict[str, Any]) -> Optional[Dict]:
|
||||
# sourcery skip: use-assigned-variable
|
||||
"""执行单个工具调用
|
||||
|
||||
Args:
|
||||
tool_call: 工具调用对象
|
||||
|
||||
Returns:
|
||||
Optional[Dict]: 工具调用结果,如果失败则返回None
|
||||
"""
|
||||
try:
|
||||
function_name = tool_call["function"]["name"]
|
||||
function_args = json.loads(tool_call["function"]["arguments"])
|
||||
function_args["llm_called"] = True # 标记为LLM调用
|
||||
|
||||
# 获取对应工具实例
|
||||
tool_instance = get_tool_instance(function_name)
|
||||
if not tool_instance:
|
||||
logger.warning(f"未知工具名称: {function_name}")
|
||||
return None
|
||||
|
||||
# 执行工具
|
||||
result = await tool_instance.execute(function_args)
|
||||
if result:
|
||||
# 直接使用 function_name 作为 tool_type
|
||||
tool_type = function_name
|
||||
|
||||
return {
|
||||
"tool_call_id": tool_call["id"],
|
||||
"role": "tool",
|
||||
"name": function_name,
|
||||
"type": tool_type,
|
||||
"content": result["content"],
|
||||
}
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"执行工具调用时发生错误: {str(e)}")
|
||||
return None
|
||||
|
||||
def _generate_cache_key(self, target_message: str, chat_history: str, sender: str) -> str:
|
||||
"""生成缓存键
|
||||
|
||||
@@ -272,15 +320,6 @@ class ToolExecutor:
|
||||
if expired_keys:
|
||||
logger.debug(f"{self.log_prefix}清理了{len(expired_keys)}个过期缓存")
|
||||
|
||||
def get_available_tools(self) -> List[str]:
|
||||
"""获取可用工具列表
|
||||
|
||||
Returns:
|
||||
List[str]: 可用工具名称列表
|
||||
"""
|
||||
tools = self.tool_instance._define_tools()
|
||||
return [tool.get("function", {}).get("name", "unknown") for tool in tools]
|
||||
|
||||
async def execute_specific_tool(
|
||||
self, tool_name: str, tool_args: Dict, validate_args: bool = True
|
||||
) -> Optional[Dict]:
|
||||
@@ -299,7 +338,7 @@ class ToolExecutor:
|
||||
|
||||
logger.info(f"{self.log_prefix}直接执行工具: {tool_name}")
|
||||
|
||||
result = await self.tool_instance._execute_tool_call(tool_call)
|
||||
result = await self._execute_tool_call(tool_call)
|
||||
|
||||
if result:
|
||||
tool_info = {
|
||||
@@ -366,12 +405,8 @@ class ToolExecutor:
|
||||
logger.info(f"{self.log_prefix}缓存TTL修改为: {cache_ttl}")
|
||||
|
||||
|
||||
# 初始化提示词
|
||||
init_tool_executor_prompt()
|
||||
|
||||
|
||||
"""
|
||||
使用示例:
|
||||
ToolExecutor使用示例:
|
||||
|
||||
# 1. 基础使用 - 从聊天消息执行工具(启用缓存,默认TTL=3)
|
||||
executor = ToolExecutor(executor_id="my_executor")
|
||||
@@ -400,7 +435,6 @@ result = await executor.execute_specific_tool(
|
||||
)
|
||||
|
||||
# 6. 缓存管理
|
||||
available_tools = executor.get_available_tools()
|
||||
cache_status = executor.get_cache_status() # 查看缓存状态
|
||||
executor.clear_cache() # 清空缓存
|
||||
executor.set_cache_config(cache_ttl=5) # 动态修改缓存配置
|
||||
@@ -163,12 +163,11 @@ class VersionComparator:
|
||||
version_normalized, max_normalized
|
||||
)
|
||||
|
||||
if is_compatible:
|
||||
logger.info(f"版本兼容性检查:{compat_msg}")
|
||||
return True, compat_msg
|
||||
else:
|
||||
if not is_compatible:
|
||||
return False, f"版本 {version_normalized} 高于最大支持版本 {max_normalized},且无兼容性映射"
|
||||
|
||||
logger.info(f"版本兼容性检查:{compat_msg}")
|
||||
return True, compat_msg
|
||||
return True, ""
|
||||
|
||||
@staticmethod
|
||||
@@ -358,14 +357,10 @@ class ManifestValidator:
|
||||
|
||||
if self.validation_errors:
|
||||
report.append("❌ 验证错误:")
|
||||
for error in self.validation_errors:
|
||||
report.append(f" - {error}")
|
||||
|
||||
report.extend(f" - {error}" for error in self.validation_errors)
|
||||
if self.validation_warnings:
|
||||
report.append("⚠️ 验证警告:")
|
||||
for warning in self.validation_warnings:
|
||||
report.append(f" - {warning}")
|
||||
|
||||
report.extend(f" - {warning}" for warning in self.validation_warnings)
|
||||
if not self.validation_errors and not self.validation_warnings:
|
||||
report.append("✅ Manifest文件验证通过")
|
||||
|
||||
|
||||
@@ -24,11 +24,6 @@
|
||||
"is_built_in": true,
|
||||
"plugin_type": "action_provider",
|
||||
"components": [
|
||||
{
|
||||
"type": "action",
|
||||
"name": "reply",
|
||||
"description": "参与聊天回复,发送文本进行表达"
|
||||
},
|
||||
{
|
||||
"type": "action",
|
||||
"name": "no_reply",
|
||||
|
||||
@@ -9,7 +9,8 @@ from src.common.logger import get_logger
|
||||
|
||||
# 导入API模块 - 标准Python包方式
|
||||
from src.plugin_system.apis import emoji_api, llm_api, message_api
|
||||
from src.plugins.built_in.core_actions.no_reply import NoReplyAction
|
||||
# 注释:不再需要导入NoReplyAction,因为计数器管理已移至heartFC_chat.py
|
||||
# from src.plugins.built_in.core_actions.no_reply import NoReplyAction
|
||||
from src.config.config import global_config
|
||||
|
||||
|
||||
@@ -20,10 +21,14 @@ class EmojiAction(BaseAction):
|
||||
"""表情动作 - 发送表情包"""
|
||||
|
||||
# 激活设置
|
||||
activation_type = ActionActivationType.RANDOM
|
||||
if global_config.emoji.emoji_activate_type == "llm":
|
||||
activation_type = ActionActivationType.LLM_JUDGE
|
||||
random_activation_probability = 0
|
||||
else:
|
||||
activation_type = ActionActivationType.RANDOM
|
||||
random_activation_probability = global_config.emoji.emoji_chance
|
||||
mode_enable = ChatMode.ALL
|
||||
parallel_action = True
|
||||
random_activation_probability = 0.2 # 默认值,可通过配置覆盖
|
||||
|
||||
# 动作基本信息
|
||||
action_name = "emoji"
|
||||
@@ -115,7 +120,7 @@ class EmojiAction(BaseAction):
|
||||
logger.error(f"{self.log_prefix} 未找到'utils_small'模型配置,无法调用LLM")
|
||||
return False, "未找到'utils_small'模型配置"
|
||||
|
||||
success, chosen_emotion, _, _ = await llm_api.generate_with_model(
|
||||
success, chosen_emotion = await llm_api.generate_with_model(
|
||||
prompt, model_config=chat_model_config, request_type="emoji"
|
||||
)
|
||||
|
||||
@@ -143,8 +148,8 @@ class EmojiAction(BaseAction):
|
||||
logger.error(f"{self.log_prefix} 表情包发送失败")
|
||||
return False, "表情包发送失败"
|
||||
|
||||
# 重置NoReplyAction的连续计数器
|
||||
NoReplyAction.reset_consecutive_count()
|
||||
# 注释:重置NoReplyAction的连续计数器现在由heartFC_chat.py统一管理
|
||||
# NoReplyAction.reset_consecutive_count()
|
||||
|
||||
return True, f"发送表情包: {emoji_description}"
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import random
|
||||
import time
|
||||
from typing import Tuple
|
||||
from typing import Tuple, List
|
||||
from collections import deque
|
||||
|
||||
# 导入新插件系统
|
||||
from src.plugin_system import BaseAction, ActionActivationType, ChatMode
|
||||
@@ -17,11 +18,15 @@ logger = get_logger("no_reply_action")
|
||||
|
||||
|
||||
class NoReplyAction(BaseAction):
|
||||
"""不回复动作,根据新消息的兴趣值或数量决定何时结束等待.
|
||||
"""不回复动作,支持waiting和breaking两种形式.
|
||||
|
||||
新的等待逻辑:
|
||||
1. 新消息累计兴趣值超过阈值 (默认10) 则结束等待
|
||||
2. 累计新消息数量达到随机阈值 (默认5-10条) 则结束等待
|
||||
waiting形式:
|
||||
- 只要有新消息就结束动作
|
||||
- 记录新消息的兴趣度到列表(最多保留最近三项)
|
||||
- 如果最近三次动作都是no_reply,且最近新消息列表兴趣度之和小于阈值,就进入breaking形式
|
||||
|
||||
breaking形式:
|
||||
- 和原有逻辑一致,需要消息满足一定数量或累计一定兴趣值才结束动作
|
||||
"""
|
||||
|
||||
focus_activation_type = ActionActivationType.NEVER
|
||||
@@ -35,112 +40,45 @@ class NoReplyAction(BaseAction):
|
||||
|
||||
# 连续no_reply计数器
|
||||
_consecutive_count = 0
|
||||
|
||||
# 最近三次no_reply的新消息兴趣度记录
|
||||
_recent_interest_records: deque = deque(maxlen=3)
|
||||
|
||||
# 新增:兴趣值退出阈值
|
||||
# 兴趣值退出阈值
|
||||
_interest_exit_threshold = 3.0
|
||||
# 新增:消息数量退出阈值
|
||||
_min_exit_message_count = 5
|
||||
_max_exit_message_count = 10
|
||||
# 消息数量退出阈值
|
||||
_min_exit_message_count = 3
|
||||
_max_exit_message_count = 6
|
||||
|
||||
# 动作参数定义
|
||||
action_parameters = {}
|
||||
|
||||
# 动作使用场景
|
||||
action_require = ["你发送了消息,目前无人回复"]
|
||||
action_require = [""]
|
||||
|
||||
# 关联类型
|
||||
associated_types = []
|
||||
|
||||
async def execute(self) -> Tuple[bool, str]:
|
||||
"""执行不回复动作"""
|
||||
import asyncio
|
||||
|
||||
try:
|
||||
# 增加连续计数
|
||||
NoReplyAction._consecutive_count += 1
|
||||
count = NoReplyAction._consecutive_count
|
||||
|
||||
reason = self.action_data.get("reason", "")
|
||||
start_time = self.action_data.get("loop_start_time", time.time())
|
||||
check_interval = 0.6 # 每秒检查一次
|
||||
check_interval = 0.6
|
||||
|
||||
# 随机生成本次等待需要的新消息数量阈值
|
||||
exit_message_count_threshold = random.randint(self._min_exit_message_count, self._max_exit_message_count)
|
||||
logger.info(
|
||||
f"{self.log_prefix} 本次no_reply需要 {exit_message_count_threshold} 条新消息或累计兴趣值超过 {self._interest_exit_threshold} 才能打断"
|
||||
)
|
||||
# 判断使用哪种形式
|
||||
form_type = self._determine_form_type()
|
||||
|
||||
logger.info(f"{self.log_prefix} 选择不回复(第{NoReplyAction._consecutive_count + 1}次),使用{form_type}形式,原因: {reason}")
|
||||
|
||||
logger.info(f"{self.log_prefix} 选择不回复(第{count}次),开始摸鱼,原因: {reason}")
|
||||
# 增加连续计数(在确定要执行no_reply时才增加)
|
||||
NoReplyAction._consecutive_count += 1
|
||||
|
||||
# 进入等待状态
|
||||
while True:
|
||||
current_time = time.time()
|
||||
elapsed_time = current_time - start_time
|
||||
|
||||
# 1. 检查新消息
|
||||
recent_messages_dict = message_api.get_messages_by_time_in_chat(
|
||||
chat_id=self.chat_id,
|
||||
start_time=start_time,
|
||||
end_time=current_time,
|
||||
filter_mai=True,
|
||||
filter_command=True,
|
||||
)
|
||||
new_message_count = len(recent_messages_dict)
|
||||
|
||||
# 2. 检查消息数量是否达到阈值
|
||||
talk_frequency = global_config.chat.get_current_talk_frequency(self.chat_id)
|
||||
if new_message_count >= exit_message_count_threshold / talk_frequency:
|
||||
logger.info(
|
||||
f"{self.log_prefix} 累计消息数量达到{new_message_count}条(>{exit_message_count_threshold / talk_frequency}),结束等待"
|
||||
)
|
||||
exit_reason = f"{global_config.bot.nickname}(你)看到了{new_message_count}条新消息,可以考虑一下是否要进行回复"
|
||||
await self.store_action_info(
|
||||
action_build_into_prompt=False,
|
||||
action_prompt_display=exit_reason,
|
||||
action_done=True,
|
||||
)
|
||||
return True, f"累计消息数量达到{new_message_count}条,结束等待 (等待时间: {elapsed_time:.1f}秒)"
|
||||
|
||||
# 3. 检查累计兴趣值
|
||||
if new_message_count > 0:
|
||||
accumulated_interest = 0.0
|
||||
for msg_dict in recent_messages_dict:
|
||||
text = msg_dict.get("processed_plain_text", "")
|
||||
interest_value = msg_dict.get("interest_value", 0.0)
|
||||
if text:
|
||||
accumulated_interest += interest_value
|
||||
|
||||
talk_frequency = global_config.chat.get_current_talk_frequency(self.chat_id)
|
||||
# 只在兴趣值变化时输出log
|
||||
if not hasattr(self, "_last_accumulated_interest") or accumulated_interest != self._last_accumulated_interest:
|
||||
logger.info(f"{self.log_prefix} 当前累计兴趣值: {accumulated_interest:.2f}, 当前聊天频率: {talk_frequency:.2f}")
|
||||
self._last_accumulated_interest = accumulated_interest
|
||||
|
||||
if accumulated_interest >= self._interest_exit_threshold / talk_frequency:
|
||||
logger.info(
|
||||
f"{self.log_prefix} 累计兴趣值达到{accumulated_interest:.2f}(>{self._interest_exit_threshold / talk_frequency}),结束等待"
|
||||
)
|
||||
exit_reason = f"{global_config.bot.nickname}(你)感觉到了大家浓厚的兴趣(兴趣值{accumulated_interest:.1f}),决定重新加入讨论"
|
||||
await self.store_action_info(
|
||||
action_build_into_prompt=False,
|
||||
action_prompt_display=exit_reason,
|
||||
action_done=True,
|
||||
)
|
||||
return (
|
||||
True,
|
||||
f"累计兴趣值达到{accumulated_interest:.2f},结束等待 (等待时间: {elapsed_time:.1f}秒)",
|
||||
)
|
||||
|
||||
# 每10秒输出一次等待状态
|
||||
if int(elapsed_time) > 0 and int(elapsed_time) % 10 == 0:
|
||||
logger.debug(
|
||||
f"{self.log_prefix} 已等待{elapsed_time:.0f}秒,累计{new_message_count}条消息,继续等待..."
|
||||
)
|
||||
# 使用 asyncio.sleep(1) 来避免在同一秒内重复打印日志
|
||||
await asyncio.sleep(1)
|
||||
|
||||
# 短暂等待后继续检查
|
||||
await asyncio.sleep(check_interval)
|
||||
if form_type == "waiting":
|
||||
return await self._execute_waiting_form(start_time, check_interval)
|
||||
else:
|
||||
return await self._execute_breaking_form(start_time, check_interval)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 不回复动作执行失败: {e}")
|
||||
@@ -153,8 +91,191 @@ class NoReplyAction(BaseAction):
|
||||
)
|
||||
return False, f"不回复动作执行失败: {e}"
|
||||
|
||||
def _determine_form_type(self) -> str:
|
||||
"""判断使用哪种形式的no_reply"""
|
||||
# 如果连续no_reply次数少于3次,使用waiting形式
|
||||
if NoReplyAction._consecutive_count < 3:
|
||||
return "waiting"
|
||||
|
||||
# 如果最近三次记录不足,使用waiting形式
|
||||
if len(NoReplyAction._recent_interest_records) < 3:
|
||||
return "waiting"
|
||||
|
||||
# 计算最近三次记录的兴趣度总和
|
||||
total_recent_interest = sum(NoReplyAction._recent_interest_records)
|
||||
|
||||
# 获取当前聊天频率和意愿系数
|
||||
talk_frequency = global_config.chat.get_current_talk_frequency(self.chat_id)
|
||||
willing_amplifier = global_config.chat.willing_amplifier
|
||||
|
||||
# 计算调整后的阈值
|
||||
adjusted_threshold = self._interest_exit_threshold / talk_frequency / willing_amplifier
|
||||
|
||||
logger.info(f"{self.log_prefix} 最近三次兴趣度总和: {total_recent_interest:.2f}, 调整后阈值: {adjusted_threshold:.2f}")
|
||||
|
||||
# 如果兴趣度总和小于阈值,进入breaking形式
|
||||
if total_recent_interest < adjusted_threshold:
|
||||
logger.info(f"{self.log_prefix} 兴趣度不足,进入breaking形式")
|
||||
return "breaking"
|
||||
else:
|
||||
logger.info(f"{self.log_prefix} 兴趣度充足,继续使用waiting形式")
|
||||
return "waiting"
|
||||
|
||||
async def _execute_waiting_form(self, start_time: float, check_interval: float) -> Tuple[bool, str]:
|
||||
"""执行waiting形式的no_reply"""
|
||||
import asyncio
|
||||
|
||||
logger.info(f"{self.log_prefix} 进入waiting形式,等待任何新消息")
|
||||
|
||||
while True:
|
||||
current_time = time.time()
|
||||
elapsed_time = current_time - start_time
|
||||
|
||||
# 检查新消息
|
||||
recent_messages_dict = message_api.get_messages_by_time_in_chat(
|
||||
chat_id=self.chat_id,
|
||||
start_time=start_time,
|
||||
end_time=current_time,
|
||||
filter_mai=True,
|
||||
filter_command=True,
|
||||
)
|
||||
new_message_count = len(recent_messages_dict)
|
||||
|
||||
# waiting形式:只要有新消息就结束
|
||||
if new_message_count > 0:
|
||||
# 计算新消息的总兴趣度
|
||||
total_interest = 0.0
|
||||
for msg_dict in recent_messages_dict:
|
||||
interest_value = msg_dict.get("interest_value", 0.0)
|
||||
if msg_dict.get("processed_plain_text", ""):
|
||||
total_interest += interest_value * global_config.chat.willing_amplifier
|
||||
|
||||
# 记录到最近兴趣度列表
|
||||
NoReplyAction._recent_interest_records.append(total_interest)
|
||||
|
||||
logger.info(
|
||||
f"{self.log_prefix} waiting形式检测到{new_message_count}条新消息,总兴趣度: {total_interest:.2f},结束等待"
|
||||
)
|
||||
|
||||
exit_reason = f"{global_config.bot.nickname}(你)看到了{new_message_count}条新消息,可以考虑一下是否要进行回复"
|
||||
await self.store_action_info(
|
||||
action_build_into_prompt=False,
|
||||
action_prompt_display=exit_reason,
|
||||
action_done=True,
|
||||
)
|
||||
return True, f"waiting形式检测到{new_message_count}条新消息,结束等待 (等待时间: {elapsed_time:.1f}秒)"
|
||||
|
||||
# 每10秒输出一次等待状态
|
||||
if int(elapsed_time) > 0 and int(elapsed_time) % 10 == 0:
|
||||
logger.debug(f"{self.log_prefix} waiting形式已等待{elapsed_time:.0f}秒,继续等待新消息...")
|
||||
await asyncio.sleep(1)
|
||||
|
||||
# 短暂等待后继续检查
|
||||
await asyncio.sleep(check_interval)
|
||||
|
||||
async def _execute_breaking_form(self, start_time: float, check_interval: float) -> Tuple[bool, str]:
|
||||
"""执行breaking形式的no_reply(原有逻辑)"""
|
||||
import asyncio
|
||||
|
||||
# 随机生成本次等待需要的新消息数量阈值
|
||||
exit_message_count_threshold = random.randint(self._min_exit_message_count, self._max_exit_message_count)
|
||||
|
||||
logger.info(f"{self.log_prefix} 进入breaking形式,需要{exit_message_count_threshold}条消息或足够兴趣度")
|
||||
|
||||
while True:
|
||||
current_time = time.time()
|
||||
elapsed_time = current_time - start_time
|
||||
|
||||
# 检查新消息
|
||||
recent_messages_dict = message_api.get_messages_by_time_in_chat(
|
||||
chat_id=self.chat_id,
|
||||
start_time=start_time,
|
||||
end_time=current_time,
|
||||
filter_mai=True,
|
||||
filter_command=True,
|
||||
)
|
||||
new_message_count = len(recent_messages_dict)
|
||||
|
||||
# 检查消息数量是否达到阈值
|
||||
talk_frequency = global_config.chat.get_current_talk_frequency(self.chat_id)
|
||||
modified_exit_count_threshold = (exit_message_count_threshold / talk_frequency) / global_config.chat.willing_amplifier
|
||||
|
||||
if new_message_count >= modified_exit_count_threshold:
|
||||
# 记录兴趣度到列表
|
||||
total_interest = 0.0
|
||||
for msg_dict in recent_messages_dict:
|
||||
interest_value = msg_dict.get("interest_value", 0.0)
|
||||
if msg_dict.get("processed_plain_text", ""):
|
||||
total_interest += interest_value * global_config.chat.willing_amplifier
|
||||
|
||||
NoReplyAction._recent_interest_records.append(total_interest)
|
||||
|
||||
logger.info(
|
||||
f"{self.log_prefix} breaking形式累计消息数量达到{new_message_count}条(>{modified_exit_count_threshold}),结束等待"
|
||||
)
|
||||
exit_reason = f"{global_config.bot.nickname}(你)看到了{new_message_count}条新消息,可以考虑一下是否要进行回复"
|
||||
await self.store_action_info(
|
||||
action_build_into_prompt=False,
|
||||
action_prompt_display=exit_reason,
|
||||
action_done=True,
|
||||
)
|
||||
return True, f"breaking形式累计消息数量达到{new_message_count}条,结束等待 (等待时间: {elapsed_time:.1f}秒)"
|
||||
|
||||
# 检查累计兴趣值
|
||||
if new_message_count > 0:
|
||||
accumulated_interest = 0.0
|
||||
for msg_dict in recent_messages_dict:
|
||||
text = msg_dict.get("processed_plain_text", "")
|
||||
interest_value = msg_dict.get("interest_value", 0.0)
|
||||
if text:
|
||||
accumulated_interest += interest_value * global_config.chat.willing_amplifier
|
||||
|
||||
# 只在兴趣值变化时输出log
|
||||
if not hasattr(self, "_last_accumulated_interest") or accumulated_interest != self._last_accumulated_interest:
|
||||
logger.info(f"{self.log_prefix} breaking形式当前累计兴趣值: {accumulated_interest:.2f}, 当前聊天频率: {talk_frequency:.2f}")
|
||||
self._last_accumulated_interest = accumulated_interest
|
||||
|
||||
if accumulated_interest >= self._interest_exit_threshold / talk_frequency:
|
||||
# 记录兴趣度到列表
|
||||
NoReplyAction._recent_interest_records.append(accumulated_interest)
|
||||
|
||||
logger.info(
|
||||
f"{self.log_prefix} breaking形式累计兴趣值达到{accumulated_interest:.2f}(>{self._interest_exit_threshold / talk_frequency}),结束等待"
|
||||
)
|
||||
exit_reason = f"{global_config.bot.nickname}(你)感觉到了大家浓厚的兴趣(兴趣值{accumulated_interest:.1f}),决定重新加入讨论"
|
||||
await self.store_action_info(
|
||||
action_build_into_prompt=False,
|
||||
action_prompt_display=exit_reason,
|
||||
action_done=True,
|
||||
)
|
||||
return (
|
||||
True,
|
||||
f"breaking形式累计兴趣值达到{accumulated_interest:.2f},结束等待 (等待时间: {elapsed_time:.1f}秒)",
|
||||
)
|
||||
|
||||
# 每10秒输出一次等待状态
|
||||
if int(elapsed_time) > 0 and int(elapsed_time) % 10 == 0:
|
||||
logger.debug(
|
||||
f"{self.log_prefix} breaking形式已等待{elapsed_time:.0f}秒,累计{new_message_count}条消息,继续等待..."
|
||||
)
|
||||
await asyncio.sleep(1)
|
||||
|
||||
# 短暂等待后继续检查
|
||||
await asyncio.sleep(check_interval)
|
||||
|
||||
@classmethod
|
||||
def reset_consecutive_count(cls):
|
||||
"""重置连续计数器"""
|
||||
"""重置连续计数器和兴趣度记录"""
|
||||
cls._consecutive_count = 0
|
||||
logger.debug("NoReplyAction连续计数器已重置")
|
||||
cls._recent_interest_records.clear()
|
||||
logger.debug("NoReplyAction连续计数器和兴趣度记录已重置")
|
||||
|
||||
@classmethod
|
||||
def get_recent_interest_records(cls) -> List[float]:
|
||||
"""获取最近的兴趣度记录"""
|
||||
return list(cls._recent_interest_records)
|
||||
|
||||
@classmethod
|
||||
def get_consecutive_count(cls) -> int:
|
||||
"""获取连续计数"""
|
||||
return cls._consecutive_count
|
||||
|
||||
@@ -8,9 +8,8 @@
|
||||
from typing import List, Tuple, Type
|
||||
|
||||
# 导入新插件系统
|
||||
from src.plugin_system import BasePlugin, register_plugin, ComponentInfo, ActionActivationType
|
||||
from src.plugin_system import BasePlugin, register_plugin, ComponentInfo
|
||||
from src.plugin_system.base.config_types import ConfigField
|
||||
from src.config.config import global_config
|
||||
|
||||
# 导入依赖的系统组件
|
||||
from src.common.logger import get_logger
|
||||
@@ -18,7 +17,6 @@ from src.common.logger import get_logger
|
||||
# 导入API模块 - 标准Python包方式
|
||||
from src.plugins.built_in.core_actions.no_reply import NoReplyAction
|
||||
from src.plugins.built_in.core_actions.emoji import EmojiAction
|
||||
from src.plugins.built_in.core_actions.reply import ReplyAction
|
||||
|
||||
logger = get_logger("core_actions")
|
||||
|
||||
@@ -52,10 +50,9 @@ class CoreActionsPlugin(BasePlugin):
|
||||
config_schema: dict = {
|
||||
"plugin": {
|
||||
"enabled": ConfigField(type=bool, default=True, description="是否启用插件"),
|
||||
"config_version": ConfigField(type=str, default="0.4.0", description="配置文件版本"),
|
||||
"config_version": ConfigField(type=str, default="0.5.0", description="配置文件版本"),
|
||||
},
|
||||
"components": {
|
||||
"enable_reply": ConfigField(type=bool, default=True, description="是否启用回复动作"),
|
||||
"enable_no_reply": ConfigField(type=bool, default=True, description="是否启用不回复动作"),
|
||||
"enable_emoji": ConfigField(type=bool, default=True, description="是否启用发送表情/图片动作"),
|
||||
},
|
||||
@@ -64,24 +61,12 @@ class CoreActionsPlugin(BasePlugin):
|
||||
def get_plugin_components(self) -> List[Tuple[ComponentInfo, Type]]:
|
||||
"""返回插件包含的组件列表"""
|
||||
|
||||
if global_config.emoji.emoji_activate_type == "llm":
|
||||
EmojiAction.random_activation_probability = 0.0
|
||||
EmojiAction.focus_activation_type = ActionActivationType.LLM_JUDGE
|
||||
EmojiAction.normal_activation_type = ActionActivationType.LLM_JUDGE
|
||||
|
||||
elif global_config.emoji.emoji_activate_type == "random":
|
||||
EmojiAction.random_activation_probability = global_config.emoji.emoji_chance
|
||||
EmojiAction.focus_activation_type = ActionActivationType.RANDOM
|
||||
EmojiAction.normal_activation_type = ActionActivationType.RANDOM
|
||||
# --- 根据配置注册组件 ---
|
||||
components = []
|
||||
if self.get_config("components.enable_reply", True):
|
||||
components.append((ReplyAction.get_action_info(), ReplyAction))
|
||||
if self.get_config("components.enable_no_reply", True):
|
||||
components.append((NoReplyAction.get_action_info(), NoReplyAction))
|
||||
if self.get_config("components.enable_emoji", True):
|
||||
components.append((EmojiAction.get_action_info(), EmojiAction))
|
||||
|
||||
# components.append((DeepReplyAction.get_action_info(), DeepReplyAction))
|
||||
|
||||
return components
|
||||
|
||||
@@ -1,149 +0,0 @@
|
||||
# 导入新插件系统
|
||||
from src.plugin_system import BaseAction, ActionActivationType, ChatMode
|
||||
from src.config.config import global_config
|
||||
import random
|
||||
import time
|
||||
from typing import Tuple
|
||||
import asyncio
|
||||
import re
|
||||
import traceback
|
||||
|
||||
# 导入依赖的系统组件
|
||||
from src.common.logger import get_logger
|
||||
|
||||
# 导入API模块 - 标准Python包方式
|
||||
from src.plugin_system.apis import generator_api, message_api
|
||||
from src.plugins.built_in.core_actions.no_reply import NoReplyAction
|
||||
from src.person_info.person_info import get_person_info_manager
|
||||
from src.mais4u.mai_think import mai_thinking_manager
|
||||
from src.mais4u.constant_s4u import ENABLE_S4U
|
||||
|
||||
logger = get_logger("reply_action")
|
||||
|
||||
|
||||
class ReplyAction(BaseAction):
|
||||
"""回复动作 - 参与聊天回复"""
|
||||
|
||||
# 激活设置
|
||||
focus_activation_type = ActionActivationType.NEVER
|
||||
normal_activation_type = ActionActivationType.NEVER
|
||||
mode_enable = ChatMode.FOCUS
|
||||
parallel_action = False
|
||||
|
||||
# 动作基本信息
|
||||
action_name = "reply"
|
||||
action_description = ""
|
||||
|
||||
# 动作参数定义
|
||||
action_parameters = {}
|
||||
|
||||
# 动作使用场景
|
||||
action_require = [""]
|
||||
|
||||
# 关联类型
|
||||
associated_types = ["text"]
|
||||
|
||||
def _parse_reply_target(self, target_message: str) -> tuple:
|
||||
sender = ""
|
||||
target = ""
|
||||
# 添加None检查,防止NoneType错误
|
||||
if target_message is None:
|
||||
return sender, target
|
||||
if ":" in target_message or ":" in target_message:
|
||||
# 使用正则表达式匹配中文或英文冒号
|
||||
parts = re.split(pattern=r"[::]", string=target_message, maxsplit=1)
|
||||
if len(parts) == 2:
|
||||
sender = parts[0].strip()
|
||||
target = parts[1].strip()
|
||||
return sender, target
|
||||
|
||||
async def execute(self) -> Tuple[bool, str]:
|
||||
"""执行回复动作"""
|
||||
logger.debug(f"{self.log_prefix} 决定进行回复")
|
||||
start_time = self.action_data.get("loop_start_time", time.time())
|
||||
|
||||
user_id = self.user_id
|
||||
platform = self.platform
|
||||
# logger.info(f"{self.log_prefix} 用户ID: {user_id}, 平台: {platform}")
|
||||
person_id = get_person_info_manager().get_person_id(platform, user_id) # type: ignore
|
||||
# logger.info(f"{self.log_prefix} 人物ID: {person_id}")
|
||||
person_name = get_person_info_manager().get_value_sync(person_id, "person_name")
|
||||
reply_to = f"{person_name}:{self.action_message.get('processed_plain_text', '')}" # type: ignore
|
||||
logger.info(f"{self.log_prefix} 决定进行回复,目标: {reply_to}")
|
||||
|
||||
try:
|
||||
if prepared_reply := self.action_data.get("prepared_reply", ""):
|
||||
reply_text = prepared_reply
|
||||
else:
|
||||
try:
|
||||
success, reply_set, _ = await asyncio.wait_for(
|
||||
generator_api.generate_reply(
|
||||
extra_info="",
|
||||
reply_to=reply_to,
|
||||
chat_id=self.chat_id,
|
||||
request_type="chat.replyer.focus",
|
||||
enable_tool=global_config.tool.enable_in_focus_chat,
|
||||
),
|
||||
timeout=global_config.chat.thinking_timeout,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning(f"{self.log_prefix} 回复生成超时 ({global_config.chat.thinking_timeout}s)")
|
||||
return False, "timeout"
|
||||
|
||||
# 检查从start_time以来的新消息数量
|
||||
# 获取动作触发时间或使用默认值
|
||||
current_time = time.time()
|
||||
new_message_count = message_api.count_new_messages(
|
||||
chat_id=self.chat_id, start_time=start_time, end_time=current_time
|
||||
)
|
||||
|
||||
# 根据新消息数量决定是否使用reply_to
|
||||
need_reply = new_message_count >= random.randint(2, 4)
|
||||
if need_reply:
|
||||
logger.info(
|
||||
f"{self.log_prefix} 从思考到回复,共有{new_message_count}条新消息,使用引用回复"
|
||||
)
|
||||
else:
|
||||
logger.debug(
|
||||
f"{self.log_prefix} 从思考到回复,共有{new_message_count}条新消息,不使用引用回复"
|
||||
)
|
||||
|
||||
# 构建回复文本
|
||||
reply_text = ""
|
||||
first_replied = False
|
||||
reply_to_platform_id = f"{platform}:{user_id}"
|
||||
for reply_seg in reply_set:
|
||||
data = reply_seg[1]
|
||||
if not first_replied:
|
||||
if need_reply:
|
||||
await self.send_text(
|
||||
content=data, reply_to=reply_to, reply_to_platform_id=reply_to_platform_id, typing=False
|
||||
)
|
||||
else:
|
||||
await self.send_text(content=data, reply_to_platform_id=reply_to_platform_id, typing=False)
|
||||
first_replied = True
|
||||
else:
|
||||
await self.send_text(content=data, reply_to_platform_id=reply_to_platform_id, typing=True)
|
||||
reply_text += data
|
||||
|
||||
# 存储动作记录
|
||||
reply_text = f"你对{person_name}进行了回复:{reply_text}"
|
||||
|
||||
if ENABLE_S4U:
|
||||
await mai_thinking_manager.get_mai_think(self.chat_id).do_think_after_response(reply_text)
|
||||
|
||||
await self.store_action_info(
|
||||
action_build_into_prompt=False,
|
||||
action_prompt_display=reply_text,
|
||||
action_done=True,
|
||||
)
|
||||
|
||||
# 重置NoReplyAction的连续计数器
|
||||
NoReplyAction.reset_consecutive_count()
|
||||
|
||||
return success, reply_text
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 回复动作执行失败: {e}")
|
||||
traceback.print_exc()
|
||||
return False, f"回复失败: {str(e)}"
|
||||
@@ -1,4 +1,4 @@
|
||||
from src.tools.tool_can_use.base_tool import BaseTool
|
||||
from src.plugin_system.base.base_tool import BaseTool
|
||||
from src.chat.utils.utils import get_embedding
|
||||
from src.common.database.database_model import Knowledges # Updated import
|
||||
from src.common.logger import get_logger
|
||||
@@ -77,7 +77,7 @@ class SearchKnowledgeTool(BaseTool):
|
||||
Union[str, list]: 格式化的信息字符串或原始结果列表
|
||||
"""
|
||||
if not query_embedding:
|
||||
return "" if not return_raw else []
|
||||
return [] if return_raw else ""
|
||||
|
||||
similar_items = []
|
||||
try:
|
||||
@@ -115,10 +115,10 @@ class SearchKnowledgeTool(BaseTool):
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"从 Peewee 数据库获取知识信息失败: {str(e)}")
|
||||
return "" if not return_raw else []
|
||||
return [] if return_raw else ""
|
||||
|
||||
if not results:
|
||||
return "" if not return_raw else []
|
||||
return [] if return_raw else ""
|
||||
|
||||
if return_raw:
|
||||
# Peewee 模型实例不能直接序列化为 JSON,如果需要原始模型,调用者需要处理
|
||||
@@ -1,4 +1,4 @@
|
||||
from src.tools.tool_can_use.base_tool import BaseTool
|
||||
from src.plugin_system.base.base_tool import BaseTool
|
||||
|
||||
# from src.common.database import db
|
||||
from src.common.logger import get_logger
|
||||
@@ -9,7 +9,7 @@
|
||||
},
|
||||
"license": "GPL-v3.0-or-later",
|
||||
"host_application": {
|
||||
"min_version": "0.9.0"
|
||||
"min_version": "0.9.1"
|
||||
},
|
||||
"homepage_url": "https://github.com/MaiM-with-u/maibot",
|
||||
"repository_url": "https://github.com/MaiM-with-u/maibot",
|
||||
|
||||
@@ -11,6 +11,7 @@ from src.plugin_system import (
|
||||
component_manage_api,
|
||||
ComponentInfo,
|
||||
ComponentType,
|
||||
send_api,
|
||||
)
|
||||
|
||||
|
||||
@@ -27,8 +28,15 @@ class ManagementCommand(BaseCommand):
|
||||
or not self.message.message_info.user_info
|
||||
or str(self.message.message_info.user_info.user_id) not in self.get_config("plugin.permission", []) # type: ignore
|
||||
):
|
||||
await self.send_text("你没有权限使用插件管理命令")
|
||||
await self._send_message("你没有权限使用插件管理命令")
|
||||
return False, "没有权限", True
|
||||
if not self.message.chat_stream:
|
||||
await self._send_message("无法获取聊天流信息")
|
||||
return False, "无法获取聊天流信息", True
|
||||
self.stream_id = self.message.chat_stream.stream_id
|
||||
if not self.stream_id:
|
||||
await self._send_message("无法获取聊天流信息")
|
||||
return False, "无法获取聊天流信息", True
|
||||
command_list = self.matched_groups["manage_command"].strip().split(" ")
|
||||
if len(command_list) == 1:
|
||||
await self.show_help("all")
|
||||
@@ -42,7 +50,7 @@ class ManagementCommand(BaseCommand):
|
||||
case "help":
|
||||
await self.show_help("all")
|
||||
case _:
|
||||
await self.send_text("插件管理命令不合法")
|
||||
await self._send_message("插件管理命令不合法")
|
||||
return False, "命令不合法", True
|
||||
if len(command_list) == 3:
|
||||
if command_list[1] == "plugin":
|
||||
@@ -56,7 +64,7 @@ class ManagementCommand(BaseCommand):
|
||||
case "rescan":
|
||||
await self._rescan_plugin_dirs()
|
||||
case _:
|
||||
await self.send_text("插件管理命令不合法")
|
||||
await self._send_message("插件管理命令不合法")
|
||||
return False, "命令不合法", True
|
||||
elif command_list[1] == "component":
|
||||
if command_list[2] == "list":
|
||||
@@ -64,10 +72,10 @@ class ManagementCommand(BaseCommand):
|
||||
elif command_list[2] == "help":
|
||||
await self.show_help("component")
|
||||
else:
|
||||
await self.send_text("插件管理命令不合法")
|
||||
await self._send_message("插件管理命令不合法")
|
||||
return False, "命令不合法", True
|
||||
else:
|
||||
await self.send_text("插件管理命令不合法")
|
||||
await self._send_message("插件管理命令不合法")
|
||||
return False, "命令不合法", True
|
||||
if len(command_list) == 4:
|
||||
if command_list[1] == "plugin":
|
||||
@@ -81,28 +89,28 @@ class ManagementCommand(BaseCommand):
|
||||
case "add_dir":
|
||||
await self._add_dir(command_list[3])
|
||||
case _:
|
||||
await self.send_text("插件管理命令不合法")
|
||||
await self._send_message("插件管理命令不合法")
|
||||
return False, "命令不合法", True
|
||||
elif command_list[1] == "component":
|
||||
if command_list[2] != "list":
|
||||
await self.send_text("插件管理命令不合法")
|
||||
await self._send_message("插件管理命令不合法")
|
||||
return False, "命令不合法", True
|
||||
if command_list[3] == "enabled":
|
||||
await self._list_enabled_components()
|
||||
elif command_list[3] == "disabled":
|
||||
await self._list_disabled_components()
|
||||
else:
|
||||
await self.send_text("插件管理命令不合法")
|
||||
await self._send_message("插件管理命令不合法")
|
||||
return False, "命令不合法", True
|
||||
else:
|
||||
await self.send_text("插件管理命令不合法")
|
||||
await self._send_message("插件管理命令不合法")
|
||||
return False, "命令不合法", True
|
||||
if len(command_list) == 5:
|
||||
if command_list[1] != "component":
|
||||
await self.send_text("插件管理命令不合法")
|
||||
await self._send_message("插件管理命令不合法")
|
||||
return False, "命令不合法", True
|
||||
if command_list[2] != "list":
|
||||
await self.send_text("插件管理命令不合法")
|
||||
await self._send_message("插件管理命令不合法")
|
||||
return False, "命令不合法", True
|
||||
if command_list[3] == "enabled":
|
||||
await self._list_enabled_components(target_type=command_list[4])
|
||||
@@ -111,11 +119,11 @@ class ManagementCommand(BaseCommand):
|
||||
elif command_list[3] == "type":
|
||||
await self._list_registered_components_by_type(command_list[4])
|
||||
else:
|
||||
await self.send_text("插件管理命令不合法")
|
||||
await self._send_message("插件管理命令不合法")
|
||||
return False, "命令不合法", True
|
||||
if len(command_list) == 6:
|
||||
if command_list[1] != "component":
|
||||
await self.send_text("插件管理命令不合法")
|
||||
await self._send_message("插件管理命令不合法")
|
||||
return False, "命令不合法", True
|
||||
if command_list[2] == "enable":
|
||||
if command_list[3] == "global":
|
||||
@@ -123,7 +131,7 @@ class ManagementCommand(BaseCommand):
|
||||
elif command_list[3] == "local":
|
||||
await self._locally_enable_component(command_list[4], command_list[5])
|
||||
else:
|
||||
await self.send_text("插件管理命令不合法")
|
||||
await self._send_message("插件管理命令不合法")
|
||||
return False, "命令不合法", True
|
||||
elif command_list[2] == "disable":
|
||||
if command_list[3] == "global":
|
||||
@@ -131,10 +139,10 @@ class ManagementCommand(BaseCommand):
|
||||
elif command_list[3] == "local":
|
||||
await self._locally_disable_component(command_list[4], command_list[5])
|
||||
else:
|
||||
await self.send_text("插件管理命令不合法")
|
||||
await self._send_message("插件管理命令不合法")
|
||||
return False, "命令不合法", True
|
||||
else:
|
||||
await self.send_text("插件管理命令不合法")
|
||||
await self._send_message("插件管理命令不合法")
|
||||
return False, "命令不合法", True
|
||||
|
||||
return True, "命令执行完成", True
|
||||
@@ -180,51 +188,51 @@ class ManagementCommand(BaseCommand):
|
||||
)
|
||||
case _:
|
||||
return
|
||||
await self.send_text(help_msg)
|
||||
await self._send_message(help_msg)
|
||||
|
||||
async def _list_loaded_plugins(self):
|
||||
plugins = plugin_manage_api.list_loaded_plugins()
|
||||
await self.send_text(f"已加载的插件: {', '.join(plugins)}")
|
||||
await self._send_message(f"已加载的插件: {', '.join(plugins)}")
|
||||
|
||||
async def _list_registered_plugins(self):
|
||||
plugins = plugin_manage_api.list_registered_plugins()
|
||||
await self.send_text(f"已注册的插件: {', '.join(plugins)}")
|
||||
await self._send_message(f"已注册的插件: {', '.join(plugins)}")
|
||||
|
||||
async def _rescan_plugin_dirs(self):
|
||||
plugin_manage_api.rescan_plugin_directory()
|
||||
await self.send_text("插件目录重新扫描执行中")
|
||||
await self._send_message("插件目录重新扫描执行中")
|
||||
|
||||
async def _load_plugin(self, plugin_name: str):
|
||||
success, count = plugin_manage_api.load_plugin(plugin_name)
|
||||
if success:
|
||||
await self.send_text(f"插件加载成功: {plugin_name}")
|
||||
await self._send_message(f"插件加载成功: {plugin_name}")
|
||||
else:
|
||||
if count == 0:
|
||||
await self.send_text(f"插件{plugin_name}为禁用状态")
|
||||
await self.send_text(f"插件加载失败: {plugin_name}")
|
||||
await self._send_message(f"插件{plugin_name}为禁用状态")
|
||||
await self._send_message(f"插件加载失败: {plugin_name}")
|
||||
|
||||
async def _unload_plugin(self, plugin_name: str):
|
||||
success = await plugin_manage_api.remove_plugin(plugin_name)
|
||||
if success:
|
||||
await self.send_text(f"插件卸载成功: {plugin_name}")
|
||||
await self._send_message(f"插件卸载成功: {plugin_name}")
|
||||
else:
|
||||
await self.send_text(f"插件卸载失败: {plugin_name}")
|
||||
await self._send_message(f"插件卸载失败: {plugin_name}")
|
||||
|
||||
async def _reload_plugin(self, plugin_name: str):
|
||||
success = await plugin_manage_api.reload_plugin(plugin_name)
|
||||
if success:
|
||||
await self.send_text(f"插件重新加载成功: {plugin_name}")
|
||||
await self._send_message(f"插件重新加载成功: {plugin_name}")
|
||||
else:
|
||||
await self.send_text(f"插件重新加载失败: {plugin_name}")
|
||||
await self._send_message(f"插件重新加载失败: {plugin_name}")
|
||||
|
||||
async def _add_dir(self, dir_path: str):
|
||||
await self.send_text(f"正在添加插件目录: {dir_path}")
|
||||
await self._send_message(f"正在添加插件目录: {dir_path}")
|
||||
success = plugin_manage_api.add_plugin_directory(dir_path)
|
||||
await asyncio.sleep(0.5) # 防止乱序发送
|
||||
if success:
|
||||
await self.send_text(f"插件目录添加成功: {dir_path}")
|
||||
await self._send_message(f"插件目录添加成功: {dir_path}")
|
||||
else:
|
||||
await self.send_text(f"插件目录添加失败: {dir_path}")
|
||||
await self._send_message(f"插件目录添加失败: {dir_path}")
|
||||
|
||||
def _fetch_all_registered_components(self) -> List[ComponentInfo]:
|
||||
all_plugin_info = component_manage_api.get_all_plugin_info()
|
||||
@@ -255,29 +263,29 @@ class ManagementCommand(BaseCommand):
|
||||
async def _list_all_registered_components(self):
|
||||
components_info = self._fetch_all_registered_components()
|
||||
if not components_info:
|
||||
await self.send_text("没有注册的组件")
|
||||
await self._send_message("没有注册的组件")
|
||||
return
|
||||
|
||||
all_components_str = ", ".join(
|
||||
f"{component.name} ({component.component_type})" for component in components_info
|
||||
)
|
||||
await self.send_text(f"已注册的组件: {all_components_str}")
|
||||
await self._send_message(f"已注册的组件: {all_components_str}")
|
||||
|
||||
async def _list_enabled_components(self, target_type: str = "global"):
|
||||
components_info = self._fetch_all_registered_components()
|
||||
if not components_info:
|
||||
await self.send_text("没有注册的组件")
|
||||
await self._send_message("没有注册的组件")
|
||||
return
|
||||
|
||||
if target_type == "global":
|
||||
enabled_components = [component for component in components_info if component.enabled]
|
||||
if not enabled_components:
|
||||
await self.send_text("没有满足条件的已启用全局组件")
|
||||
await self._send_message("没有满足条件的已启用全局组件")
|
||||
return
|
||||
enabled_components_str = ", ".join(
|
||||
f"{component.name} ({component.component_type})" for component in enabled_components
|
||||
)
|
||||
await self.send_text(f"满足条件的已启用全局组件: {enabled_components_str}")
|
||||
await self._send_message(f"满足条件的已启用全局组件: {enabled_components_str}")
|
||||
elif target_type == "local":
|
||||
locally_disabled_components = self._fetch_locally_disabled_components()
|
||||
enabled_components = [
|
||||
@@ -286,28 +294,28 @@ class ManagementCommand(BaseCommand):
|
||||
if (component.name not in locally_disabled_components and component.enabled)
|
||||
]
|
||||
if not enabled_components:
|
||||
await self.send_text("本聊天没有满足条件的已启用组件")
|
||||
await self._send_message("本聊天没有满足条件的已启用组件")
|
||||
return
|
||||
enabled_components_str = ", ".join(
|
||||
f"{component.name} ({component.component_type})" for component in enabled_components
|
||||
)
|
||||
await self.send_text(f"本聊天满足条件的已启用组件: {enabled_components_str}")
|
||||
await self._send_message(f"本聊天满足条件的已启用组件: {enabled_components_str}")
|
||||
|
||||
async def _list_disabled_components(self, target_type: str = "global"):
|
||||
components_info = self._fetch_all_registered_components()
|
||||
if not components_info:
|
||||
await self.send_text("没有注册的组件")
|
||||
await self._send_message("没有注册的组件")
|
||||
return
|
||||
|
||||
if target_type == "global":
|
||||
disabled_components = [component for component in components_info if not component.enabled]
|
||||
if not disabled_components:
|
||||
await self.send_text("没有满足条件的已禁用全局组件")
|
||||
await self._send_message("没有满足条件的已禁用全局组件")
|
||||
return
|
||||
disabled_components_str = ", ".join(
|
||||
f"{component.name} ({component.component_type})" for component in disabled_components
|
||||
)
|
||||
await self.send_text(f"满足条件的已禁用全局组件: {disabled_components_str}")
|
||||
await self._send_message(f"满足条件的已禁用全局组件: {disabled_components_str}")
|
||||
elif target_type == "local":
|
||||
locally_disabled_components = self._fetch_locally_disabled_components()
|
||||
disabled_components = [
|
||||
@@ -316,12 +324,12 @@ class ManagementCommand(BaseCommand):
|
||||
if (component.name in locally_disabled_components or not component.enabled)
|
||||
]
|
||||
if not disabled_components:
|
||||
await self.send_text("本聊天没有满足条件的已禁用组件")
|
||||
await self._send_message("本聊天没有满足条件的已禁用组件")
|
||||
return
|
||||
disabled_components_str = ", ".join(
|
||||
f"{component.name} ({component.component_type})" for component in disabled_components
|
||||
)
|
||||
await self.send_text(f"本聊天满足条件的已禁用组件: {disabled_components_str}")
|
||||
await self._send_message(f"本聊天满足条件的已禁用组件: {disabled_components_str}")
|
||||
|
||||
async def _list_registered_components_by_type(self, target_type: str):
|
||||
match target_type:
|
||||
@@ -332,18 +340,18 @@ class ManagementCommand(BaseCommand):
|
||||
case "event_handler":
|
||||
component_type = ComponentType.EVENT_HANDLER
|
||||
case _:
|
||||
await self.send_text(f"未知组件类型: {target_type}")
|
||||
await self._send_message(f"未知组件类型: {target_type}")
|
||||
return
|
||||
|
||||
components_info = component_manage_api.get_components_info_by_type(component_type)
|
||||
if not components_info:
|
||||
await self.send_text(f"没有注册的 {target_type} 组件")
|
||||
await self._send_message(f"没有注册的 {target_type} 组件")
|
||||
return
|
||||
|
||||
components_str = ", ".join(
|
||||
f"{name} ({component.component_type})" for name, component in components_info.items()
|
||||
)
|
||||
await self.send_text(f"注册的 {target_type} 组件: {components_str}")
|
||||
await self._send_message(f"注册的 {target_type} 组件: {components_str}")
|
||||
|
||||
async def _globally_enable_component(self, component_name: str, component_type: str):
|
||||
match component_type:
|
||||
@@ -354,12 +362,12 @@ class ManagementCommand(BaseCommand):
|
||||
case "event_handler":
|
||||
target_component_type = ComponentType.EVENT_HANDLER
|
||||
case _:
|
||||
await self.send_text(f"未知组件类型: {component_type}")
|
||||
await self._send_message(f"未知组件类型: {component_type}")
|
||||
return
|
||||
if component_manage_api.globally_enable_component(component_name, target_component_type):
|
||||
await self.send_text(f"全局启用组件成功: {component_name}")
|
||||
await self._send_message(f"全局启用组件成功: {component_name}")
|
||||
else:
|
||||
await self.send_text(f"全局启用组件失败: {component_name}")
|
||||
await self._send_message(f"全局启用组件失败: {component_name}")
|
||||
|
||||
async def _globally_disable_component(self, component_name: str, component_type: str):
|
||||
match component_type:
|
||||
@@ -370,13 +378,13 @@ class ManagementCommand(BaseCommand):
|
||||
case "event_handler":
|
||||
target_component_type = ComponentType.EVENT_HANDLER
|
||||
case _:
|
||||
await self.send_text(f"未知组件类型: {component_type}")
|
||||
await self._send_message(f"未知组件类型: {component_type}")
|
||||
return
|
||||
success = await component_manage_api.globally_disable_component(component_name, target_component_type)
|
||||
if success:
|
||||
await self.send_text(f"全局禁用组件成功: {component_name}")
|
||||
await self._send_message(f"全局禁用组件成功: {component_name}")
|
||||
else:
|
||||
await self.send_text(f"全局禁用组件失败: {component_name}")
|
||||
await self._send_message(f"全局禁用组件失败: {component_name}")
|
||||
|
||||
async def _locally_enable_component(self, component_name: str, component_type: str):
|
||||
match component_type:
|
||||
@@ -387,16 +395,16 @@ class ManagementCommand(BaseCommand):
|
||||
case "event_handler":
|
||||
target_component_type = ComponentType.EVENT_HANDLER
|
||||
case _:
|
||||
await self.send_text(f"未知组件类型: {component_type}")
|
||||
await self._send_message(f"未知组件类型: {component_type}")
|
||||
return
|
||||
if component_manage_api.locally_enable_component(
|
||||
component_name,
|
||||
target_component_type,
|
||||
self.message.chat_stream.stream_id,
|
||||
):
|
||||
await self.send_text(f"本地启用组件成功: {component_name}")
|
||||
await self._send_message(f"本地启用组件成功: {component_name}")
|
||||
else:
|
||||
await self.send_text(f"本地启用组件失败: {component_name}")
|
||||
await self._send_message(f"本地启用组件失败: {component_name}")
|
||||
|
||||
async def _locally_disable_component(self, component_name: str, component_type: str):
|
||||
match component_type:
|
||||
@@ -407,34 +415,40 @@ class ManagementCommand(BaseCommand):
|
||||
case "event_handler":
|
||||
target_component_type = ComponentType.EVENT_HANDLER
|
||||
case _:
|
||||
await self.send_text(f"未知组件类型: {component_type}")
|
||||
await self._send_message(f"未知组件类型: {component_type}")
|
||||
return
|
||||
if component_manage_api.locally_disable_component(
|
||||
component_name,
|
||||
target_component_type,
|
||||
self.message.chat_stream.stream_id,
|
||||
):
|
||||
await self.send_text(f"本地禁用组件成功: {component_name}")
|
||||
await self._send_message(f"本地禁用组件成功: {component_name}")
|
||||
else:
|
||||
await self.send_text(f"本地禁用组件失败: {component_name}")
|
||||
await self._send_message(f"本地禁用组件失败: {component_name}")
|
||||
|
||||
async def _send_message(self, message: str):
|
||||
await send_api.text_to_stream(message, self.stream_id, typing=False, storage_message=False)
|
||||
|
||||
|
||||
@register_plugin
|
||||
class PluginManagementPlugin(BasePlugin):
|
||||
plugin_name: str = "plugin_management_plugin"
|
||||
enable_plugin: bool = True
|
||||
enable_plugin: bool = False
|
||||
dependencies: list[str] = []
|
||||
python_dependencies: list[str] = []
|
||||
config_file_name: str = "config.toml"
|
||||
config_schema: dict = {
|
||||
"plugin": {
|
||||
"enable": ConfigField(bool, default=True, description="是否启用插件"),
|
||||
"permission": ConfigField(list, default=[], description="有权限使用插件管理命令的用户列表"),
|
||||
"enabled": ConfigField(bool, default=False, description="是否启用插件"),
|
||||
"config_version": ConfigField(type=str, default="1.1.0", description="配置文件版本"),
|
||||
"permission": ConfigField(
|
||||
list, default=[], description="有权限使用插件管理命令的用户列表,请填写字符串形式的用户ID"
|
||||
),
|
||||
},
|
||||
}
|
||||
|
||||
def get_plugin_components(self) -> List[Tuple[CommandInfo, Type[BaseCommand]]]:
|
||||
components = []
|
||||
if self.get_config("plugin.enable", True):
|
||||
if self.get_config("plugin.enabled", True):
|
||||
components.append((ManagementCommand.get_command_info(), ManagementCommand))
|
||||
return components
|
||||
|
||||
@@ -1,20 +0,0 @@
|
||||
from src.tools.tool_can_use.base_tool import (
|
||||
BaseTool,
|
||||
register_tool,
|
||||
discover_tools,
|
||||
get_all_tool_definitions,
|
||||
get_tool_instance,
|
||||
TOOL_REGISTRY,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"BaseTool",
|
||||
"register_tool",
|
||||
"discover_tools",
|
||||
"get_all_tool_definitions",
|
||||
"get_tool_instance",
|
||||
"TOOL_REGISTRY",
|
||||
]
|
||||
|
||||
# 自动发现并注册工具
|
||||
discover_tools()
|
||||
@@ -1,115 +0,0 @@
|
||||
from typing import List, Any, Optional, Type
|
||||
import inspect
|
||||
import importlib
|
||||
import pkgutil
|
||||
import os
|
||||
from src.common.logger import get_logger
|
||||
from rich.traceback import install
|
||||
|
||||
install(extra_lines=3)
|
||||
|
||||
logger = get_logger("base_tool")
|
||||
|
||||
# 工具注册表
|
||||
TOOL_REGISTRY = {}
|
||||
|
||||
|
||||
class BaseTool:
|
||||
"""所有工具的基类"""
|
||||
|
||||
# 工具名称,子类必须重写
|
||||
name = None
|
||||
# 工具描述,子类必须重写
|
||||
description = None
|
||||
# 工具参数定义,子类必须重写
|
||||
parameters = None
|
||||
|
||||
@classmethod
|
||||
def get_tool_definition(cls) -> dict[str, Any]:
|
||||
"""获取工具定义,用于LLM工具调用
|
||||
|
||||
Returns:
|
||||
dict: 工具定义字典
|
||||
"""
|
||||
if not cls.name or not cls.description or not cls.parameters:
|
||||
raise NotImplementedError(f"工具类 {cls.__name__} 必须定义 name, description 和 parameters 属性")
|
||||
|
||||
return {
|
||||
"type": "function",
|
||||
"function": {"name": cls.name, "description": cls.description, "parameters": cls.parameters},
|
||||
}
|
||||
|
||||
async def execute(self, function_args: dict[str, Any]) -> dict[str, Any]:
|
||||
"""执行工具函数
|
||||
|
||||
Args:
|
||||
function_args: 工具调用参数
|
||||
|
||||
Returns:
|
||||
dict: 工具执行结果
|
||||
"""
|
||||
raise NotImplementedError("子类必须实现execute方法")
|
||||
|
||||
|
||||
def register_tool(tool_class: Type[BaseTool]):
|
||||
"""注册工具到全局注册表
|
||||
|
||||
Args:
|
||||
tool_class: 工具类
|
||||
"""
|
||||
if not issubclass(tool_class, BaseTool):
|
||||
raise TypeError(f"{tool_class.__name__} 不是 BaseTool 的子类")
|
||||
|
||||
tool_name = tool_class.name
|
||||
if not tool_name:
|
||||
raise ValueError(f"工具类 {tool_class.__name__} 没有定义 name 属性")
|
||||
|
||||
TOOL_REGISTRY[tool_name] = tool_class
|
||||
logger.info(f"已注册: {tool_name}")
|
||||
|
||||
|
||||
def discover_tools():
|
||||
"""自动发现并注册tool_can_use目录下的所有工具"""
|
||||
# 获取当前目录路径
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
package_name = os.path.basename(current_dir)
|
||||
|
||||
# 遍历包中的所有模块
|
||||
for _, module_name, _ in pkgutil.iter_modules([current_dir]):
|
||||
# 跳过当前模块和__pycache__
|
||||
if module_name == "base_tool" or module_name.startswith("__"):
|
||||
continue
|
||||
|
||||
# 导入模块
|
||||
module = importlib.import_module(f"src.tools.{package_name}.{module_name}")
|
||||
|
||||
# 查找模块中的工具类
|
||||
for _, obj in inspect.getmembers(module):
|
||||
if inspect.isclass(obj) and issubclass(obj, BaseTool) and obj != BaseTool:
|
||||
register_tool(obj)
|
||||
|
||||
logger.info(f"工具发现完成,共注册 {len(TOOL_REGISTRY)} 个工具")
|
||||
|
||||
|
||||
def get_all_tool_definitions() -> List[dict[str, Any]]:
|
||||
"""获取所有已注册工具的定义
|
||||
|
||||
Returns:
|
||||
List[dict]: 工具定义列表
|
||||
"""
|
||||
return [tool_class().get_tool_definition() for tool_class in TOOL_REGISTRY.values()]
|
||||
|
||||
|
||||
def get_tool_instance(tool_name: str) -> Optional[BaseTool]:
|
||||
"""获取指定名称的工具实例
|
||||
|
||||
Args:
|
||||
tool_name: 工具名称
|
||||
|
||||
Returns:
|
||||
Optional[BaseTool]: 工具实例,如果找不到则返回None
|
||||
"""
|
||||
tool_class = TOOL_REGISTRY.get(tool_name)
|
||||
if not tool_class:
|
||||
return None
|
||||
return tool_class()
|
||||
@@ -1,49 +0,0 @@
|
||||
from src.tools.tool_can_use.base_tool import BaseTool
|
||||
from src.common.logger import get_logger
|
||||
from typing import Any
|
||||
|
||||
logger = get_logger("compare_numbers_tool")
|
||||
|
||||
|
||||
class CompareNumbersTool(BaseTool):
|
||||
"""比较两个数大小的工具"""
|
||||
|
||||
name = "compare_numbers"
|
||||
description = "使用工具 比较两个数的大小,返回较大的数"
|
||||
parameters = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"num1": {"type": "number", "description": "第一个数字"},
|
||||
"num2": {"type": "number", "description": "第二个数字"},
|
||||
},
|
||||
"required": ["num1", "num2"],
|
||||
}
|
||||
|
||||
async def execute(self, function_args: dict[str, Any]) -> dict[str, Any]:
|
||||
"""执行比较两个数的大小
|
||||
|
||||
Args:
|
||||
function_args: 工具参数
|
||||
|
||||
Returns:
|
||||
dict: 工具执行结果
|
||||
"""
|
||||
num1: int | float = function_args.get("num1") # type: ignore
|
||||
num2: int | float = function_args.get("num2") # type: ignore
|
||||
|
||||
try:
|
||||
if num1 > num2:
|
||||
result = f"{num1} 大于 {num2}"
|
||||
elif num1 < num2:
|
||||
result = f"{num1} 小于 {num2}"
|
||||
else:
|
||||
result = f"{num1} 等于 {num2}"
|
||||
|
||||
return {"type": "comparison_result", "id": f"{num1}_vs_{num2}", "content": result}
|
||||
except Exception as e:
|
||||
logger.error(f"比较数字失败: {str(e)}")
|
||||
return {"type": "info", "id": f"{num1}_vs_{num2}", "content": f"比较数字失败,炸了: {str(e)}"}
|
||||
|
||||
|
||||
# 注册工具
|
||||
# register_tool(CompareNumbersTool)
|
||||
@@ -1,104 +0,0 @@
|
||||
from src.tools.tool_can_use.base_tool import BaseTool
|
||||
from src.person_info.person_info import get_person_info_manager
|
||||
from src.common.logger import get_logger
|
||||
import time
|
||||
|
||||
|
||||
logger = get_logger("rename_person_tool")
|
||||
|
||||
|
||||
class RenamePersonTool(BaseTool):
|
||||
name = "rename_person"
|
||||
description = (
|
||||
"这个工具可以改变用户的昵称。你可以选择改变对他人的称呼。你想给人改名,叫别人别的称呼,需要调用这个工具。"
|
||||
)
|
||||
parameters = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"person_name": {"type": "string", "description": "需要重新取名的用户的当前昵称"},
|
||||
"message_content": {
|
||||
"type": "string",
|
||||
"description": "当前的聊天内容或特定要求,用于提供取名建议的上下文,尽可能详细。",
|
||||
},
|
||||
},
|
||||
"required": ["person_name"],
|
||||
}
|
||||
|
||||
async def execute(self, function_args: dict, message_txt=""):
|
||||
"""
|
||||
执行取名工具逻辑
|
||||
|
||||
Args:
|
||||
function_args (dict): 包含 'person_name' 和可选 'message_content' 的字典
|
||||
message_txt (str): 原始消息文本 (这里未使用,因为 message_content 更明确)
|
||||
|
||||
Returns:
|
||||
dict: 包含执行结果的字典
|
||||
"""
|
||||
person_name_to_find = function_args.get("person_name")
|
||||
request_context = function_args.get("message_content", "") # 如果没有提供,则为空字符串
|
||||
|
||||
if not person_name_to_find:
|
||||
return {"name": self.name, "content": "错误:必须提供需要重命名的用户昵称 (person_name)。"}
|
||||
person_info_manager = get_person_info_manager()
|
||||
try:
|
||||
# 1. 根据昵称查找用户信息
|
||||
logger.debug(f"尝试根据昵称 '{person_name_to_find}' 查找用户...")
|
||||
person_info = await person_info_manager.get_person_info_by_name(person_name_to_find)
|
||||
|
||||
if not person_info:
|
||||
logger.info(f"未找到昵称为 '{person_name_to_find}' 的用户。")
|
||||
return {
|
||||
"name": self.name,
|
||||
"content": f"找不到昵称为 '{person_name_to_find}' 的用户。请确保输入的是我之前为该用户取的昵称。",
|
||||
}
|
||||
|
||||
person_id = person_info.get("person_id")
|
||||
user_nickname = person_info.get("nickname") # 这是用户原始昵称
|
||||
user_cardname = person_info.get("user_cardname")
|
||||
user_avatar = person_info.get("user_avatar")
|
||||
|
||||
if not person_id:
|
||||
logger.error(f"找到了用户 '{person_name_to_find}' 但无法获取 person_id")
|
||||
return {"name": self.name, "content": f"找到了用户 '{person_name_to_find}' 但获取内部ID时出错。"}
|
||||
|
||||
# 2. 调用 qv_person_name 进行取名
|
||||
logger.debug(
|
||||
f"为用户 {person_id} (原昵称: {person_name_to_find}) 调用 qv_person_name,请求上下文: '{request_context}'"
|
||||
)
|
||||
result = await person_info_manager.qv_person_name(
|
||||
person_id=person_id,
|
||||
user_nickname=user_nickname, # type: ignore
|
||||
user_cardname=user_cardname, # type: ignore
|
||||
user_avatar=user_avatar, # type: ignore
|
||||
request=request_context,
|
||||
)
|
||||
|
||||
# 3. 处理结果
|
||||
if result and result.get("nickname"):
|
||||
new_name = result["nickname"]
|
||||
# reason = result.get("reason", "未提供理由")
|
||||
logger.info(f"成功为用户 {person_id} 取了新昵称: {new_name}")
|
||||
|
||||
content = f"已成功将用户 {person_name_to_find} 的备注名更新为 {new_name}"
|
||||
logger.info(content)
|
||||
return {"type": "info", "id": f"rename_success_{time.time()}", "content": content}
|
||||
else:
|
||||
logger.warning(f"为用户 {person_id} 调用 qv_person_name 后未能成功获取新昵称。")
|
||||
# 尝试从内存中获取可能已经更新的名字
|
||||
current_name = await person_info_manager.get_value(person_id, "person_name")
|
||||
if current_name and current_name != person_name_to_find:
|
||||
return {
|
||||
"name": self.name,
|
||||
"content": f"尝试取新昵称时遇到一点小问题,但我已经将 '{person_name_to_find}' 的昵称更新为 '{current_name}' 了。",
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"name": self.name,
|
||||
"content": f"尝试为 '{person_name_to_find}' 取新昵称时遇到了问题,未能成功生成。可能需要稍后再试。",
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"重命名失败: {str(e)}"
|
||||
logger.error(error_msg, exc_info=True)
|
||||
return {"type": "info_error", "id": f"rename_error_{time.time()}", "content": error_msg}
|
||||
@@ -1,55 +0,0 @@
|
||||
import json
|
||||
from src.common.logger import get_logger
|
||||
from src.tools.tool_can_use import get_all_tool_definitions, get_tool_instance
|
||||
|
||||
logger = get_logger("tool_use")
|
||||
|
||||
|
||||
class ToolUser:
|
||||
@staticmethod
|
||||
def _define_tools():
|
||||
"""获取所有已注册工具的定义
|
||||
|
||||
Returns:
|
||||
list: 工具定义列表
|
||||
"""
|
||||
return get_all_tool_definitions()
|
||||
|
||||
@staticmethod
|
||||
async def _execute_tool_call(tool_call):
|
||||
"""执行特定的工具调用
|
||||
|
||||
Args:
|
||||
tool_call: 工具调用对象
|
||||
message_txt: 原始消息文本
|
||||
|
||||
Returns:
|
||||
dict: 工具调用结果
|
||||
"""
|
||||
try:
|
||||
function_name = tool_call["function"]["name"]
|
||||
function_args = json.loads(tool_call["function"]["arguments"])
|
||||
|
||||
# 获取对应工具实例
|
||||
tool_instance = get_tool_instance(function_name)
|
||||
if not tool_instance:
|
||||
logger.warning(f"未知工具名称: {function_name}")
|
||||
return None
|
||||
|
||||
# 执行工具
|
||||
result = await tool_instance.execute(function_args)
|
||||
if result:
|
||||
# 直接使用 function_name 作为 tool_type
|
||||
tool_type = function_name
|
||||
|
||||
return {
|
||||
"tool_call_id": tool_call["id"],
|
||||
"role": "tool",
|
||||
"name": function_name,
|
||||
"type": tool_type,
|
||||
"content": result["content"],
|
||||
}
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"执行工具调用时发生错误: {str(e)}")
|
||||
return None
|
||||
Reference in New Issue
Block a user