Merge branch 'dev' into dev-api-ada to resolve conflicts

This commit is contained in:
UnCLAS-Prommer
2025-07-29 10:22:43 +08:00
92 changed files with 3980 additions and 6492 deletions

View File

@@ -2,7 +2,7 @@ import asyncio
import time
import traceback
import random
from typing import List, Optional, Dict, Any
from typing import List, Optional, Dict, Any, Tuple
from rich.traceback import install
from src.config.config import global_config
@@ -18,11 +18,12 @@ from src.chat.chat_loop.hfc_utils import CycleDetail
from src.person_info.relationship_builder_manager import relationship_builder_manager
from src.person_info.person_info import get_person_info_manager
from src.plugin_system.base.component_types import ActionInfo, ChatMode
from src.plugin_system.apis import generator_api, send_api, message_api
from src.plugin_system.apis import generator_api, send_api, message_api, database_api
from src.chat.willing.willing_manager import get_willing_manager
from src.mais4u.mai_think import mai_thinking_manager
from maim_message.message_base import GroupInfo
from src.mais4u.constant_s4u import ENABLE_S4U
from src.plugins.built_in.core_actions.no_reply import NoReplyAction
from src.chat.chat_loop.hfc_utils import send_typing, stop_typing
ERROR_LOOP_INFO = {
"loop_plan_info": {
@@ -88,11 +89,6 @@ class HeartFChatting:
self.loop_mode = ChatMode.NORMAL # 初始循环模式为普通模式
# 新增:消息计数器和疲惫阈值
self._message_count = 0 # 发送的消息计数
self._message_threshold = max(10, int(30 * global_config.chat.focus_value))
self._fatigue_triggered = False # 是否已触发疲惫退出
self.action_manager = ActionManager()
self.action_planner = ActionPlanner(chat_id=self.stream_id, action_manager=self.action_manager)
self.action_modifier = ActionModifier(action_manager=self.action_manager, chat_id=self.stream_id)
@@ -112,7 +108,6 @@ class HeartFChatting:
self.last_read_time = time.time() - 1
self.willing_amplifier = 1
self.willing_manager = get_willing_manager()
logger.info(f"{self.log_prefix} HeartFChatting 初始化完成")
@@ -182,6 +177,9 @@ class HeartFChatting:
if self.loop_mode == ChatMode.NORMAL:
self.energy_value -= 0.3
self.energy_value = max(self.energy_value, 0.3)
if self.loop_mode == ChatMode.FOCUS:
self.energy_value -= 0.6
self.energy_value = max(self.energy_value, 0.3)
def print_cycle_info(self, cycle_timers):
# 记录循环信息和计时器结果
@@ -200,9 +198,9 @@ class HeartFChatting:
async def _loopbody(self):
if self.loop_mode == ChatMode.FOCUS:
if await self._observe():
self.energy_value -= 1 * global_config.chat.focus_value
self.energy_value -= 1 / global_config.chat.focus_value
else:
self.energy_value -= 3 * global_config.chat.focus_value
self.energy_value -= 3 / global_config.chat.focus_value
if self.energy_value <= 1:
self.energy_value = 1
self.loop_mode = ChatMode.NORMAL
@@ -218,15 +216,17 @@ class HeartFChatting:
limit_mode="earliest",
filter_bot=True,
)
if global_config.chat.focus_value != 0:
if len(new_messages_data) > 3 / pow(global_config.chat.focus_value, 0.5):
self.loop_mode = ChatMode.FOCUS
self.energy_value = (
10 + (len(new_messages_data) / (3 / pow(global_config.chat.focus_value, 0.5))) * 10
)
return True
if len(new_messages_data) > 3 * global_config.chat.focus_value:
self.loop_mode = ChatMode.FOCUS
self.energy_value = 10 + (len(new_messages_data) / (3 * global_config.chat.focus_value)) * 10
return True
if self.energy_value >= 30 * global_config.chat.focus_value:
self.loop_mode = ChatMode.FOCUS
return True
if self.energy_value >= 30:
self.loop_mode = ChatMode.FOCUS
return True
if new_messages_data:
earliest_messages_data = new_messages_data[0]
@@ -235,10 +235,10 @@ class HeartFChatting:
if_think = await self.normal_response(earliest_messages_data)
if if_think:
factor = max(global_config.chat.focus_value, 0.1)
self.energy_value *= 1.1 / factor
self.energy_value *= 1.1 * factor
logger.info(f"{self.log_prefix} 进行了思考,能量值按倍数增加,当前能量值:{self.energy_value:.1f}")
else:
self.energy_value += 0.1 / global_config.chat.focus_value
self.energy_value += 0.1 * global_config.chat.focus_value
logger.debug(f"{self.log_prefix} 没有进行思考,能量值线性增加,当前能量值:{self.energy_value:.1f}")
logger.debug(f"{self.log_prefix} 当前能量值:{self.energy_value:.1f}")
@@ -257,44 +257,69 @@ class HeartFChatting:
person_name = await person_info_manager.get_value(person_id, "person_name")
return f"{person_name}:{message_data.get('processed_plain_text')}"
async def send_typing(self):
group_info = GroupInfo(platform="amaidesu_default", group_id="114514", group_name="内心")
async def _send_and_store_reply(
self,
response_set,
reply_to_str,
loop_start_time,
action_message,
cycle_timers: Dict[str, float],
thinking_id,
plan_result,
) -> Tuple[Dict[str, Any], str, Dict[str, float]]:
with Timer("回复发送", cycle_timers):
reply_text = await self._send_response(response_set, reply_to_str, loop_start_time, action_message)
chat = await get_chat_manager().get_or_create_stream(
platform="amaidesu_default",
user_info=None,
group_info=group_info,
# 存储reply action信息
person_info_manager = get_person_info_manager()
person_id = person_info_manager.get_person_id(
action_message.get("chat_info_platform", ""),
action_message.get("user_id", ""),
)
person_name = await person_info_manager.get_value(person_id, "person_name")
action_prompt_display = f"你对{person_name}进行了回复:{reply_text}"
await database_api.store_action_info(
chat_stream=self.chat_stream,
action_build_into_prompt=False,
action_prompt_display=action_prompt_display,
action_done=True,
thinking_id=thinking_id,
action_data={"reply_text": reply_text, "reply_to": reply_to_str},
action_name="reply",
)
await send_api.custom_to_stream(
message_type="state", content="typing", stream_id=chat.stream_id, storage_message=False
)
# 构建循环信息
loop_info: Dict[str, Any] = {
"loop_plan_info": {
"action_result": plan_result.get("action_result", {}),
},
"loop_action_info": {
"action_taken": True,
"reply_text": reply_text,
"command": "",
"taken_time": time.time(),
},
}
async def stop_typing(self):
group_info = GroupInfo(platform="amaidesu_default", group_id="114514", group_name="内心")
chat = await get_chat_manager().get_or_create_stream(
platform="amaidesu_default",
user_info=None,
group_info=group_info,
)
await send_api.custom_to_stream(
message_type="state", content="stop_typing", stream_id=chat.stream_id, storage_message=False
)
return loop_info, reply_text, cycle_timers
async def _observe(self, message_data: Optional[Dict[str, Any]] = None):
# sourcery skip: hoist-statement-from-if, merge-comparisons, reintroduce-else
if not message_data:
message_data = {}
action_type = "no_action"
reply_text = "" # 初始化reply_text变量避免UnboundLocalError
gen_task = None # 初始化gen_task变量避免UnboundLocalError
reply_to_str = "" # 初始化reply_to_str变量
# 创建新的循环信息
cycle_timers, thinking_id = self.start_cycle()
logger.info(f"{self.log_prefix} 开始第{self._cycle_counter}次思考[模式:{self.loop_mode}]")
if ENABLE_S4U:
await self.send_typing()
await send_typing()
async with global_prompt_manager.async_message_scope(self.chat_stream.context.get_template_name()):
loop_start_time = time.time()
@@ -310,95 +335,254 @@ class HeartFChatting:
except Exception as e:
logger.error(f"{self.log_prefix} 动作修改失败: {e}")
# 如果是normal模式，先启动一个回复生成任务，提前准备好回复（其实是和planner同时进行的）
# 检查是否在normal模式下没有可用动作（除了reply相关动作）
skip_planner = False
if self.loop_mode == ChatMode.NORMAL:
reply_to_str = await self.build_reply_to_str(message_data)
gen_task = asyncio.create_task(self._generate_response(message_data, available_actions, reply_to_str))
# 过滤掉reply相关的动作，检查是否还有其他动作
non_reply_actions = {
k: v for k, v in available_actions.items() if k not in ["reply", "no_reply", "no_action"]
}
with Timer("规划器", cycle_timers):
plan_result, target_message = await self.action_planner.plan(mode=self.loop_mode)
if not non_reply_actions:
skip_planner = True
logger.info(f"{self.log_prefix} Normal模式下没有可用动作直接回复")
action_result: dict = plan_result.get("action_result", {}) # type: ignore
action_type, action_data, reasoning, is_parallel = (
action_result.get("action_type", "error"),
action_result.get("action_data", {}),
action_result.get("reasoning", "未提供理由"),
action_result.get("is_parallel", True),
)
# 直接设置为reply动作
action_type = "reply"
reasoning = ""
action_data = {"loop_start_time": loop_start_time}
is_parallel = False
action_data["loop_start_time"] = loop_start_time
# 构建plan_result用于后续处理
plan_result = {
"action_result": {
"action_type": action_type,
"action_data": action_data,
"reasoning": reasoning,
"timestamp": time.time(),
"is_parallel": is_parallel,
},
"action_prompt": "",
}
target_message = message_data
if self.loop_mode == ChatMode.NORMAL:
if action_type == "no_action":
logger.info(f"[{self.log_prefix}] {global_config.bot.nickname} 决定进行回复")
elif is_parallel:
logger.info(
f"[{self.log_prefix}] {global_config.bot.nickname} 决定进行回复, 同时执行{action_type}动作"
# 如果是normal模式且不跳过规划器，先启动一个回复生成任务，提前准备好回复（其实是和planner同时进行的）
if not skip_planner:
reply_to_str = await self.build_reply_to_str(message_data)
gen_task = asyncio.create_task(
self._generate_response(
message_data=message_data,
available_actions=available_actions,
reply_to=reply_to_str,
request_type="chat.replyer.normal",
)
)
else:
logger.info(f"[{self.log_prefix}] {global_config.bot.nickname} 决定执行{action_type}动作")
if action_type == "no_action":
if not skip_planner:
with Timer("规划器", cycle_timers):
plan_result, target_message = await self.action_planner.plan(mode=self.loop_mode)
action_result: Dict[str, Any] = plan_result.get("action_result", {}) # type: ignore
action_type, action_data, reasoning, is_parallel = (
action_result.get("action_type", "error"),
action_result.get("action_data", {}),
action_result.get("reasoning", "未提供理由"),
action_result.get("is_parallel", True),
)
action_data["loop_start_time"] = loop_start_time
if action_type == "reply":
logger.info(f"{self.log_prefix}{global_config.bot.nickname} 决定进行回复")
elif is_parallel:
logger.info(f"{self.log_prefix}{global_config.bot.nickname} 决定进行回复, 同时执行{action_type}动作")
else:
# 只有在gen_task存在时才进行相关操作
if gen_task:
if not gen_task.done():
gen_task.cancel()
logger.debug(f"{self.log_prefix} 已取消预生成的回复任务")
logger.info(
f"{self.log_prefix}{global_config.bot.nickname} 原本想要回复,但选择执行{action_type},不发表回复"
)
elif generation_result := gen_task.result():
content = " ".join([item[1] for item in generation_result if item[0] == "text"])
logger.debug(f"{self.log_prefix} 预生成的回复任务已完成")
logger.info(
f"{self.log_prefix}{global_config.bot.nickname} 原本想要回复:{content},但选择执行{action_type},不发表回复"
)
else:
logger.warning(f"{self.log_prefix} 预生成的回复任务未生成有效内容")
action_message: Dict[str, Any] = message_data or target_message # type: ignore
if action_type == "reply":
# 等待回复生成完毕
gather_timeout = global_config.chat.thinking_timeout
try:
response_set = await asyncio.wait_for(gen_task, timeout=gather_timeout)
except asyncio.TimeoutError:
response_set = None
if self.loop_mode == ChatMode.NORMAL:
# 只有在gen_task存在时才等待
if not gen_task:
reply_to_str = await self.build_reply_to_str(message_data)
gen_task = asyncio.create_task(
self._generate_response(
message_data=message_data,
available_actions=available_actions,
reply_to=reply_to_str,
request_type="chat.replyer.normal",
)
)
if response_set:
content = " ".join([item[1] for item in response_set if item[0] == "text"])
gather_timeout = global_config.chat.thinking_timeout
try:
response_set = await asyncio.wait_for(gen_task, timeout=gather_timeout)
except asyncio.TimeoutError:
logger.warning(f"{self.log_prefix} 回复生成超时>{global_config.chat.thinking_timeout}s已跳过")
response_set = None
# 模型炸了,没有回复内容生成
if not response_set:
logger.warning(f"[{self.log_prefix}] 模型未生成回复内容")
return False
elif action_type not in ["no_action"] and not is_parallel:
logger.info(
f"[{self.log_prefix}] {global_config.bot.nickname} 原本想要回复:{content},但选择执行{action_type},不发表回复"
)
return False
# 模型炸了或超时,没有回复内容生成
if not response_set:
logger.warning(f"{self.log_prefix}模型未生成回复内容")
return False
else:
logger.info(f"{self.log_prefix}{global_config.bot.nickname} 决定进行回复 (focus模式)")
logger.info(f"[{self.log_prefix}] {global_config.bot.nickname} 决定的回复内容: {content}")
# 构建reply_to字符串
reply_to_str = await self.build_reply_to_str(action_message)
# 发送回复 (不再需要传入 chat)
reply_text = await self._send_response(response_set, reply_to_str, loop_start_time, message_data)
if ENABLE_S4U:
await self.stop_typing()
await mai_thinking_manager.get_mai_think(self.stream_id).do_think_after_response(reply_text)
# 生成回复
with Timer("回复生成", cycle_timers):
response_set = await self._generate_response(
message_data=action_message,
available_actions=available_actions,
reply_to=reply_to_str,
request_type="chat.replyer.focus",
)
if not response_set:
logger.warning(f"{self.log_prefix}模型未生成回复内容")
return False
loop_info, reply_text, cycle_timers = await self._send_and_store_reply(
response_set, reply_to_str, loop_start_time, action_message, cycle_timers, thinking_id, plan_result
)
return True
else:
action_message: Dict[str, Any] = message_data or target_message # type: ignore
# 并行执行:同时进行回复发送和动作执行
# 先置空防止未定义错误
background_reply_task = None
background_action_task = None
# 如果是并行执行且在normal模式下，需要等待预生成的回复任务完成并发送回复
if self.loop_mode == ChatMode.NORMAL and is_parallel and gen_task:
# 动作执行计时
with Timer("动作执行", cycle_timers):
success, reply_text, command = await self._handle_action(
action_type, reasoning, action_data, cycle_timers, thinking_id, action_message
async def handle_reply_task() -> Tuple[Optional[Dict[str, Any]], str, Dict[str, float]]:
# 等待预生成的回复任务完成
gather_timeout = global_config.chat.thinking_timeout
try:
response_set = await asyncio.wait_for(gen_task, timeout=gather_timeout)
except asyncio.TimeoutError:
logger.warning(
f"{self.log_prefix} 并行执行:回复生成超时>{global_config.chat.thinking_timeout}s已跳过"
)
return None, "", {}
except asyncio.CancelledError:
logger.debug(f"{self.log_prefix} 并行执行:回复生成任务已被取消")
return None, "", {}
if not response_set:
logger.warning(f"{self.log_prefix} 模型超时或生成回复内容为空")
return None, "", {}
reply_to_str = await self.build_reply_to_str(action_message)
loop_info, reply_text, cycle_timers_reply = await self._send_and_store_reply(
response_set,
reply_to_str,
loop_start_time,
action_message,
cycle_timers,
thinking_id,
plan_result,
)
return loop_info, reply_text, cycle_timers_reply
# 执行回复任务并赋值到变量
background_reply_task = asyncio.create_task(handle_reply_task())
# 动作执行任务
async def handle_action_task():
with Timer("动作执行", cycle_timers):
success, reply_text, command = await self._handle_action(
action_type, reasoning, action_data, cycle_timers, thinking_id, action_message
)
return success, reply_text, command
# 执行动作任务并赋值到变量
background_action_task = asyncio.create_task(handle_action_task())
reply_loop_info = None
reply_text_from_reply = ""
action_success = False
action_reply_text = ""
action_command = ""
# 并行执行所有任务
if background_reply_task:
results = await asyncio.gather(
background_reply_task, background_action_task, return_exceptions=True
)
# 处理回复任务结果
reply_result = results[0]
if isinstance(reply_result, BaseException):
logger.error(f"{self.log_prefix} 回复任务执行异常: {reply_result}")
elif reply_result and reply_result[0] is not None:
reply_loop_info, reply_text_from_reply, _ = reply_result
loop_info = {
"loop_plan_info": {
"action_result": plan_result.get("action_result", {}),
},
"loop_action_info": {
"action_taken": success,
"reply_text": reply_text,
"command": command,
"taken_time": time.time(),
},
}
# 处理动作任务结果
action_task_result = results[1]
if isinstance(action_task_result, BaseException):
logger.error(f"{self.log_prefix} 动作任务执行异常: {action_task_result}")
else:
action_success, action_reply_text, action_command = action_task_result
else:
results = await asyncio.gather(background_action_task, return_exceptions=True)
# 只有动作任务
action_task_result = results[0]
if isinstance(action_task_result, BaseException):
logger.error(f"{self.log_prefix} 动作任务执行异常: {action_task_result}")
else:
action_success, action_reply_text, action_command = action_task_result
if loop_info["loop_action_info"]["command"] == "stop_focus_chat":
logger.info(f"{self.log_prefix} 麦麦决定停止专注聊天")
return False
# 停止该聊天模式的循环
# 构建最终的循环信息
if reply_loop_info:
# 如果有回复信息使用回复的loop_info作为基础
loop_info = reply_loop_info
# 更新动作执行信息
loop_info["loop_action_info"].update(
{
"action_taken": action_success,
"command": action_command,
"taken_time": time.time(),
}
)
reply_text = reply_text_from_reply
else:
# 没有回复信息构建纯动作的loop_info
loop_info = {
"loop_plan_info": {
"action_result": plan_result.get("action_result", {}),
},
"loop_action_info": {
"action_taken": action_success,
"reply_text": action_reply_text,
"command": action_command,
"taken_time": time.time(),
},
}
reply_text = action_reply_text
if ENABLE_S4U:
await stop_typing()
await mai_thinking_manager.get_mai_think(self.stream_id).do_think_after_response(reply_text)
self.end_cycle(loop_info, cycle_timers)
self.print_cycle_info(cycle_timers)
@@ -406,8 +590,16 @@ class HeartFChatting:
if self.loop_mode == ChatMode.NORMAL:
await self.willing_manager.after_generate_reply_handle(message_data.get("message_id", ""))
# 管理no_reply计数器：当执行了非no_reply动作时重置计数器
if action_type != "no_reply" and action_type != "no_action":
# 导入NoReplyAction并重置计数器
NoReplyAction.reset_consecutive_count()
logger.info(f"{self.log_prefix} 执行了{action_type}动作重置no_reply计数器")
return True
elif action_type == "no_action":
# 当执行回复动作时也重置no_reply计数器
NoReplyAction.reset_consecutive_count()
logger.info(f"{self.log_prefix} 执行了回复动作，重置no_reply计数器")
return True
@@ -435,7 +627,7 @@ class HeartFChatting:
action: str,
reasoning: str,
action_data: dict,
cycle_timers: dict,
cycle_timers: Dict[str, float],
thinking_id: str,
action_message: dict,
) -> tuple[bool, str, str]:
@@ -501,7 +693,7 @@ class HeartFChatting:
"兴趣"模式下,判断是否回复并生成内容。
"""
interested_rate = (message_data.get("interest_value") or 0.0) * self.willing_amplifier
interested_rate = (message_data.get("interest_value") or 0.0) * global_config.chat.willing_amplifier
self.willing_manager.setup(message_data, self.chat_stream)
@@ -515,8 +707,8 @@ class HeartFChatting:
reply_probability += additional_config["maimcore_reply_probability_gain"]
reply_probability = min(max(reply_probability, 0), 1) # 确保概率在 0-1 之间
talk_frequency = global_config.chat.get_current_talk_frequency(self.stream_id)
reply_probability = talk_frequency * reply_probability
talk_frequency = global_config.chat.get_current_talk_frequency(self.stream_id)
reply_probability = talk_frequency * reply_probability
# 处理表情包
if message_data.get("is_emoji") or message_data.get("is_picid"):
@@ -544,7 +736,11 @@ class HeartFChatting:
return False
async def _generate_response(
self, message_data: dict, available_actions: Optional[Dict[str, ActionInfo]], reply_to: str
self,
message_data: dict,
available_actions: Optional[Dict[str, ActionInfo]],
reply_to: str,
request_type: str = "chat.replyer.normal",
) -> Optional[list]:
"""生成普通回复"""
try:
@@ -552,8 +748,8 @@ class HeartFChatting:
chat_stream=self.chat_stream,
reply_to=reply_to,
available_actions=available_actions,
enable_tool=global_config.tool.enable_in_normal_chat,
request_type="chat.replyer.normal",
enable_tool=global_config.tool.enable_tool,
request_type=request_type,
)
if not success or not reply_set:
@@ -563,10 +759,10 @@ class HeartFChatting:
return reply_set
except Exception as e:
logger.error(f"[{self.log_prefix}] 回复生成出现错误:{str(e)} {traceback.format_exc()}")
logger.error(f"{self.log_prefix}回复生成出现错误:{str(e)} {traceback.format_exc()}")
return None
async def _send_response(self, reply_set, reply_to, thinking_start_time, message_data):
async def _send_response(self, reply_set, reply_to, thinking_start_time, message_data) -> str:
current_time = time.time()
new_message_count = message_api.count_new_messages(
chat_id=self.chat_stream.stream_id, start_time=thinking_start_time, end_time=current_time
@@ -578,13 +774,9 @@ class HeartFChatting:
need_reply = new_message_count >= random.randint(2, 4)
if need_reply:
logger.info(
f"{self.log_prefix} 从思考到回复,共有{new_message_count}条新消息,使用引用回复"
)
logger.info(f"{self.log_prefix} 从思考到回复,共有{new_message_count}条新消息,使用引用回复")
else:
logger.debug(
f"{self.log_prefix} 从思考到回复,共有{new_message_count}条新消息,不使用引用回复"
)
logger.info(f"{self.log_prefix} 从思考到回复,共有{new_message_count}条新消息,不使用引用回复")
reply_text = ""
first_replied = False

View File

@@ -1,10 +1,13 @@
import time
from typing import Optional, Dict, Any
from src.config.config import global_config
from src.common.message_repository import count_messages
from src.common.logger import get_logger
from src.chat.message_receive.chat_stream import get_chat_manager
from src.plugin_system.apis import send_api
from maim_message.message_base import GroupInfo
from src.common.message_repository import count_messages
logger = get_logger(__name__)
@@ -106,3 +109,30 @@ def get_recent_message_stats(minutes: float = 30, chat_id: Optional[str] = None)
bot_reply_count = count_messages(bot_filter)
return {"bot_reply_count": bot_reply_count, "total_message_count": total_message_count}
async def send_typing():
group_info = GroupInfo(platform="amaidesu_default", group_id="114514", group_name="内心")
chat = await get_chat_manager().get_or_create_stream(
platform="amaidesu_default",
user_info=None,
group_info=group_info,
)
await send_api.custom_to_stream(
message_type="state", content="typing", stream_id=chat.stream_id, storage_message=False
)
async def stop_typing():
group_info = GroupInfo(platform="amaidesu_default", group_id="114514", group_name="内心")
chat = await get_chat_manager().get_or_create_stream(
platform="amaidesu_default",
user_info=None,
group_info=group_info,
)
await send_api.custom_to_stream(
message_type="state", content="stop_typing", stream_id=chat.stream_id, storage_message=False
)
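The typing-indicator helpers previously defined as methods on HeartFChatting now live here as module-level coroutines, which is why the chat loop above imports send_typing and stop_typing from hfc_utils and calls them without a self argument. A minimal usage sketch under those assumptions (generate_reply is a hypothetical callable, not from the diff):

from src.chat.chat_loop.hfc_utils import send_typing, stop_typing
from src.mais4u.constant_s4u import ENABLE_S4U

async def reply_with_typing_indicator(generate_reply):
    if ENABLE_S4U:
        await send_typing()      # broadcast "typing" to the fixed internal stream
    try:
        return await generate_reply()
    finally:
        if ENABLE_S4U:
            await stop_typing()  # always clear the indicator, even if generation fails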

View File

@@ -525,9 +525,9 @@ class EmojiManager:
如果文件已被删除,则执行对象的删除方法并从列表中移除
"""
try:
if not self.emoji_objects:
logger.warning("[检查] emoji_objects为空跳过完整性检查")
return
# if not self.emoji_objects:
# logger.warning("[检查] emoji_objects为空跳过完整性检查")
# return
total_count = len(self.emoji_objects)
self.emoji_num = total_count
@@ -707,6 +707,38 @@ class EmojiManager:
return emoji
return None # 如果循环结束还没找到,则返回 None
async def get_emoji_description_by_hash(self, emoji_hash: str) -> Optional[str]:
"""根据哈希值获取已注册表情包的描述
Args:
emoji_hash: 表情包的哈希值
Returns:
Optional[str]: 表情包描述如果未找到则返回None
"""
try:
# 先从内存中查找
emoji = await self.get_emoji_from_manager(emoji_hash)
if emoji and emoji.description:
logger.info(f"[缓存命中] 从内存获取表情包描述: {emoji.description[:50]}...")
return emoji.description
# 如果内存中没有,从数据库查找
self._ensure_db()
try:
emoji_record = Emoji.get_or_none(Emoji.emoji_hash == emoji_hash)
if emoji_record and emoji_record.description:
logger.info(f"[缓存命中] 从数据库获取表情包描述: {emoji_record.description[:50]}...")
return emoji_record.description
except Exception as e:
logger.error(f"从数据库查询表情包描述时出错: {e}")
return None
except Exception as e:
logger.error(f"获取表情包描述失败 (Hash: {emoji_hash}): {str(e)}")
return None
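A short usage sketch of the new lookup: the method checks the in-memory emoji objects first and only falls back to the Emoji table on a miss. Here emoji_manager is an assumed initialized EmojiManager instance, and hashing the raw bytes with MD5 is an illustration rather than something shown in this diff:

import hashlib

async def describe_emoji(emoji_manager, image_bytes: bytes) -> str:
    # hypothetical caller: hash the raw image, then ask the manager for a cached description
    emoji_hash = hashlib.md5(image_bytes).hexdigest()
    description = await emoji_manager.get_emoji_description_by_hash(emoji_hash)
    return description or "[未注册的表情包]"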
async def delete_emoji(self, emoji_hash: str) -> bool:
"""根据哈希值删除表情包

View File

@@ -51,7 +51,7 @@ def init_prompt() -> None:
"想说明某个具体的事实观点,但懒得明说,或者不便明说,或表达一种默契",使用"懂的都懂"
"当涉及游戏相关时,表示意外的夸赞,略带戏谑意味"时,使用"这么强!"
注意不要总结你自己SELF的发言
注意不要总结你自己SELF的发言
现在请你概括
"""
Prompt(learn_style_prompt, "learn_style_prompt")
@@ -330,48 +330,8 @@ class ExpressionLearner:
"""
current_time = time.time()
# 全局衰减所有已存储的表达方式
for type in ["style", "grammar"]:
base_dir = os.path.join("data", "expression", f"learnt_{type}")
if not os.path.exists(base_dir):
logger.debug(f"目录不存在,跳过衰减: {base_dir}")
continue
try:
chat_ids = os.listdir(base_dir)
logger.debug(f"{base_dir} 中找到 {len(chat_ids)} 个聊天ID目录进行衰减")
except Exception as e:
logger.error(f"读取目录失败 {base_dir}: {e}")
continue
for chat_id in chat_ids:
file_path = os.path.join(base_dir, chat_id, "expressions.json")
if not os.path.exists(file_path):
continue
try:
with open(file_path, "r", encoding="utf-8") as f:
expressions = json.load(f)
if not isinstance(expressions, list):
logger.warning(f"表达方式文件格式错误,跳过衰减: {file_path}")
continue
# 应用全局衰减
decayed_expressions = self.apply_decay_to_expressions(expressions, current_time)
# 保存衰减后的结果
with open(file_path, "w", encoding="utf-8") as f:
json.dump(decayed_expressions, f, ensure_ascii=False, indent=2)
logger.debug(f"已对 {file_path} 应用衰减,剩余 {len(decayed_expressions)} 个表达方式")
except json.JSONDecodeError as e:
logger.error(f"JSON解析失败跳过衰减 {file_path}: {e}")
except PermissionError as e:
logger.error(f"权限不足,无法更新 {file_path}: {e}")
except Exception as e:
logger.error(f"全局衰减{type}表达方式失败 {file_path}: {e}")
continue
# 全局衰减所有已存储的表达方式(直接操作数据库)
self._apply_global_decay_to_database(current_time)
learnt_style: Optional[List[Tuple[str, str, str]]] = []
learnt_grammar: Optional[List[Tuple[str, str, str]]] = []
@@ -388,6 +348,42 @@ class ExpressionLearner:
return learnt_style, learnt_grammar
def _apply_global_decay_to_database(self, current_time: float) -> None:
"""
对数据库中的所有表达方式应用全局衰减
"""
try:
# 获取所有表达方式
all_expressions = Expression.select()
updated_count = 0
deleted_count = 0
for expr in all_expressions:
# 计算时间差
last_active = expr.last_active_time
time_diff_days = (current_time - last_active) / (24 * 3600) # 转换为天
# 计算衰减值
decay_value = self.calculate_decay_factor(time_diff_days)
new_count = max(0.01, expr.count - decay_value)
if new_count <= 0.01:
# 如果count太小，删除这个表达方式
expr.delete_instance()
deleted_count += 1
else:
# 更新count
expr.count = new_count
expr.save()
updated_count += 1
if updated_count > 0 or deleted_count > 0:
logger.info(f"全局衰减完成:更新了 {updated_count} 个表达方式,删除了 {deleted_count} 个表达方式")
except Exception as e:
logger.error(f"数据库全局衰减失败: {e}")
def calculate_decay_factor(self, time_diff_days: float) -> float:
"""
计算衰减值
@@ -410,30 +406,6 @@ class ExpressionLearner:
return min(0.01, decay)
def apply_decay_to_expressions(
self, expressions: List[Dict[str, Any]], current_time: float
) -> List[Dict[str, Any]]:
"""
对表达式列表应用衰减
返回衰减后的表达式列表移除count小于0的项
"""
result = []
for expr in expressions:
# 确保last_active_time存在如果不存在则使用current_time
if "last_active_time" not in expr:
expr["last_active_time"] = current_time
last_active = expr["last_active_time"]
time_diff_days = (current_time - last_active) / (24 * 3600) # 转换为天
decay_value = self.calculate_decay_factor(time_diff_days)
expr["count"] = max(0.01, expr.get("count", 1) - decay_value)
if expr["count"] > 0:
result.append(expr)
return result
async def learn_and_store(self, type: str, num: int = 10) -> List[Tuple[str, str, str]]:
# sourcery skip: use-join
"""

View File

@@ -2,7 +2,7 @@ import json
import time
import random
from typing import List, Dict, Tuple, Optional
from typing import List, Dict, Tuple, Optional, Any
from json_repair import repair_json
from src.llm_models.utils_model import LLMRequest
@@ -117,36 +117,42 @@ class ExpressionSelector:
def get_random_expressions(
self, chat_id: str, total_num: int, style_percentage: float, grammar_percentage: float
) -> Tuple[List[Dict[str, str]], List[Dict[str, str]]]:
) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
# 支持多chat_id合并抽选
related_chat_ids = self.get_related_chat_ids(chat_id)
style_exprs = []
grammar_exprs = []
for cid in related_chat_ids:
style_query = Expression.select().where((Expression.chat_id == cid) & (Expression.type == "style"))
grammar_query = Expression.select().where((Expression.chat_id == cid) & (Expression.type == "grammar"))
style_exprs.extend([
{
"situation": expr.situation,
"style": expr.style,
"count": expr.count,
"last_active_time": expr.last_active_time,
"source_id": cid,
"type": "style",
"create_date": expr.create_date if expr.create_date is not None else expr.last_active_time,
} for expr in style_query
])
grammar_exprs.extend([
{
"situation": expr.situation,
"style": expr.style,
"count": expr.count,
"last_active_time": expr.last_active_time,
"source_id": cid,
"type": "grammar",
"create_date": expr.create_date if expr.create_date is not None else expr.last_active_time,
} for expr in grammar_query
])
# 优化一次性查询所有相关chat_id的表达方式
style_query = Expression.select().where(
(Expression.chat_id.in_(related_chat_ids)) & (Expression.type == "style")
)
grammar_query = Expression.select().where(
(Expression.chat_id.in_(related_chat_ids)) & (Expression.type == "grammar")
)
style_exprs = [
{
"situation": expr.situation,
"style": expr.style,
"count": expr.count,
"last_active_time": expr.last_active_time,
"source_id": expr.chat_id,
"type": "style",
"create_date": expr.create_date if expr.create_date is not None else expr.last_active_time,
} for expr in style_query
]
grammar_exprs = [
{
"situation": expr.situation,
"style": expr.style,
"count": expr.count,
"last_active_time": expr.last_active_time,
"source_id": expr.chat_id,
"type": "grammar",
"create_date": expr.create_date if expr.create_date is not None else expr.last_active_time,
} for expr in grammar_query
]
style_num = int(total_num * style_percentage)
grammar_num = int(total_num * grammar_percentage)
# 按权重抽样，使用count作为权重
@@ -162,7 +168,7 @@ class ExpressionSelector:
selected_grammar = []
return selected_style, selected_grammar
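The sampling step itself is elided from this hunk; a plausible weighted draw using count as the weight might look like the following (illustrative only, the real selection logic is not shown here, and random.choices samples with replacement, which may differ from the project's behaviour):

import random
from typing import Any, Dict, List

def weighted_sample(exprs: List[Dict[str, Any]], num: int) -> List[Dict[str, Any]]:
    # draw up to `num` expressions, favouring frequently reinforced ones (higher count)
    if not exprs or num <= 0:
        return []
    weights = [max(expr.get("count", 1), 0.01) for expr in exprs]
    return random.choices(exprs, weights=weights, k=min(num, len(exprs)))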
def update_expressions_count_batch(self, expressions_to_update: List[Dict[str, str]], increment: float = 0.1):
def update_expressions_count_batch(self, expressions_to_update: List[Dict[str, Any]], increment: float = 0.1):
"""对一批表达方式更新count值，按chat_id+type分组后一次性写入数据库"""
if not expressions_to_update:
return
@@ -203,7 +209,7 @@ class ExpressionSelector:
max_num: int = 10,
min_num: int = 5,
target_message: Optional[str] = None,
) -> List[Dict[str, str]]:
) -> List[Dict[str, Any]]:
# sourcery skip: inline-variable, list-comprehension
"""使用LLM选择适合的表达方式"""
@@ -273,6 +279,7 @@ class ExpressionSelector:
if not isinstance(result, dict) or "selected_situations" not in result:
logger.error("LLM返回格式错误")
logger.info(f"LLM返回结果: \n{content}")
return []
selected_indices = result["selected_situations"]

View File

@@ -12,6 +12,7 @@ from src.chat.message_receive.storage import MessageStorage
from src.chat.heart_flow.heartflow import heartflow
from src.chat.utils.utils import is_mentioned_bot_in_message
from src.chat.utils.timer_calculator import Timer
from src.chat.utils.chat_message_builder import replace_user_references_sync
from src.common.logger import get_logger
from src.person_info.relationship_manager import get_relationship_manager
from src.mood.mood_manager import mood_manager
@@ -56,16 +57,41 @@ async def _calculate_interest(message: MessageRecv) -> Tuple[float, bool]:
with Timer("记忆激活"):
interested_rate = await hippocampus_manager.get_activate_from_text(
message.processed_plain_text,
max_depth=5,
fast_retrieval=False,
)
logger.debug(f"记忆激活率: {interested_rate:.2f}")
text_len = len(message.processed_plain_text)
# 根据文本长度调整兴趣度,长度越大兴趣度越高但增长率递减最低0.01最高0.05
# 采用对数函数实现递减增长
base_interest = 0.01 + (0.05 - 0.01) * (math.log10(text_len + 1) / math.log10(1000 + 1))
base_interest = min(max(base_interest, 0.01), 0.05)
# 根据文本长度分布调整兴趣度,采用分段函数实现更精确的兴趣度计算
# 基于实际分布0-5字符(26.57%), 6-10字符(27.18%), 11-20字符(22.76%), 21-30字符(10.33%), 31+字符(13.86%)
if text_len == 0:
base_interest = 0.01 # 空消息最低兴趣度
elif text_len <= 5:
# 1-5字符线性增长 0.01 -> 0.03
base_interest = 0.01 + (text_len - 1) * (0.03 - 0.01) / 4
elif text_len <= 10:
# 6-10字符线性增长 0.03 -> 0.06
base_interest = 0.03 + (text_len - 5) * (0.06 - 0.03) / 5
elif text_len <= 20:
# 11-20字符线性增长 0.06 -> 0.12
base_interest = 0.06 + (text_len - 10) * (0.12 - 0.06) / 10
elif text_len <= 30:
# 21-30字符线性增长 0.12 -> 0.18
base_interest = 0.12 + (text_len - 20) * (0.18 - 0.12) / 10
elif text_len <= 50:
# 31-50字符线性增长 0.18 -> 0.22
base_interest = 0.18 + (text_len - 30) * (0.22 - 0.18) / 20
elif text_len <= 100:
# 51-100字符线性增长 0.22 -> 0.26
base_interest = 0.22 + (text_len - 50) * (0.26 - 0.22) / 50
else:
# 100+字符:对数增长 0.26 -> 0.3,增长率递减
base_interest = 0.26 + (0.3 - 0.26) * (math.log10(text_len - 99) / math.log10(901)) # 1000-99=901
# 确保在范围内
base_interest = min(max(base_interest, 0.01), 0.3)
interested_rate += base_interest
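The same piecewise mapping, pulled out as a standalone function so the bracket boundaries can be sanity-checked in isolation (a sketch mirroring the thresholds above, not the original method):

import math

def base_interest_for_length(text_len: int) -> float:
    # piecewise-linear up to 100 chars, logarithmic beyond, clamped to [0.01, 0.3]
    if text_len == 0:
        value = 0.01
    elif text_len <= 5:
        value = 0.01 + (text_len - 1) * (0.03 - 0.01) / 4
    elif text_len <= 10:
        value = 0.03 + (text_len - 5) * (0.06 - 0.03) / 5
    elif text_len <= 20:
        value = 0.06 + (text_len - 10) * (0.12 - 0.06) / 10
    elif text_len <= 30:
        value = 0.12 + (text_len - 20) * (0.18 - 0.12) / 10
    elif text_len <= 50:
        value = 0.18 + (text_len - 30) * (0.22 - 0.18) / 20
    elif text_len <= 100:
        value = 0.22 + (text_len - 50) * (0.26 - 0.22) / 50
    else:
        value = 0.26 + (0.3 - 0.26) * (math.log10(text_len - 99) / math.log10(901))
    return min(max(value, 0.01), 0.3)

assert abs(base_interest_for_length(10) - 0.06) < 1e-9
assert abs(base_interest_for_length(1000) - 0.30) < 1e-9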
@@ -123,8 +149,15 @@ class HeartFCMessageReceiver:
# 如果消息中包含图片标识,则将 [picid:...] 替换为 [图片]
picid_pattern = r"\[picid:([^\]]+)\]"
processed_plain_text = re.sub(picid_pattern, "[图片]", message.processed_plain_text)
# 应用用户引用格式替换,将回复<aaa:bbb>和@<aaa:bbb>格式转换为可读格式
processed_plain_text = replace_user_references_sync(
processed_plain_text,
message.message_info.platform, # type: ignore
replace_bot_name=True
)
logger.info(f"[{mes_name}]{userinfo.user_nickname}:{processed_plain_text}") # type: ignore
logger.info(f"[{mes_name}]{userinfo.user_nickname}:{processed_plain_text}[兴趣度:{interested_rate:.2f}]") # type: ignore
logger.debug(f"[{mes_name}][当前时段回复频率: {current_talk_frequency}]")

View File

@@ -224,10 +224,16 @@ class Hippocampus:
return hash((source, target))
@staticmethod
def find_topic_llm(text, topic_num):
def find_topic_llm(text: str, topic_num: int | list[int]):
# sourcery skip: inline-immediately-returned-variable
topic_num_str = ""
if isinstance(topic_num, list):
topic_num_str = f"{topic_num[0]}-{topic_num[1]}"
else:
topic_num_str = topic_num
prompt = (
f"这是一段文字:\n{text}\n\n请你从这段话中总结出最多{topic_num}个关键的概念,可以是名词,动词,或者特定人物,帮我列出来,"
f"这是一段文字:\n{text}\n\n请你从这段话中总结出最多{topic_num_str}个关键的概念,可以是名词,动词,或者特定人物,帮我列出来,"
f"将主题用逗号隔开,并加上<>,例如<主题1>,<主题2>......尽可能精简。只需要列举最多{topic_num}个话题就好,不要有序号,不要告诉我其他内容。"
f"如果确定找不出主题或者没有明显主题,返回<none>。"
)
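find_topic_llm now accepts either a fixed count or a [min, max] range, and the range is rendered as "min-max" inside the prompt. A tiny illustration of both call shapes, assuming Hippocampus has been imported and using made-up text:

prompt_a = Hippocampus.find_topic_llm("今天天气不错,适合出去散步", 5)       # prompt asks for 最多5个
prompt_b = Hippocampus.find_topic_llm("今天天气不错,适合出去散步", [2, 4])  # prompt asks for 最多2-4个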
@@ -300,6 +306,60 @@ class Hippocampus:
memories.sort(key=lambda x: x[2], reverse=True)
return memories
async def get_keywords_from_text(self, text: str) -> list:
"""从文本中提取关键词。
Args:
text (str): 输入文本
Returns:
list: 提取出的关键词列表。短文本（≤5字符）直接用jieba分词，较长文本用LLM按长度区间提取。
"""
if not text:
return []
# 使用LLM提取关键词 - 根据详细文本长度分布优化topic_num计算
text_length = len(text)
topic_num: int | list[int] = 0
if text_length <= 5:
words = jieba.cut(text)
keywords = [word for word in words if len(word) > 1]
keywords = list(set(keywords))[:3] # 限制最多3个关键词
if keywords:
logger.info(f"提取关键词: {keywords}")
return keywords
elif text_length <= 10:
topic_num = [1, 3] # 6-10字符: 1-3个关键词 (27.18%的文本)
elif text_length <= 20:
topic_num = [2, 4] # 11-20字符: 2-4个关键词 (22.76%的文本)
elif text_length <= 30:
topic_num = [3, 5] # 21-30字符: 3-5个关键词 (10.33%的文本)
elif text_length <= 50:
topic_num = [4, 5] # 31-50字符: 4-5个关键词 (9.79%的文本)
else:
topic_num = 5 # 51+字符: 5个关键词 (其余长文本)
topics_response, (reasoning_content, model_name) = await self.model_summary.generate_response_async(
self.find_topic_llm(text, topic_num)
)
# 提取关键词
keywords = re.findall(r"<([^>]+)>", topics_response)
if not keywords:
keywords = []
else:
keywords = [
keyword.strip()
for keyword in ",".join(keywords).replace("", ",").replace("", ",").replace(" ", ",").split(",")
if keyword.strip()
]
if keywords:
logger.info(f"提取关键词: {keywords}")
return keywords
async def get_memory_from_text(
self,
text: str,
@@ -325,39 +385,7 @@ class Hippocampus:
- memory_items: list, 该主题下的记忆项列表
- similarity: float, 与文本的相似度
"""
if not text:
return []
if fast_retrieval:
# 使用jieba分词提取关键词
words = jieba.cut(text)
# 过滤掉停用词和单字词
keywords = [word for word in words if len(word) > 1]
# 去重
keywords = list(set(keywords))
# 限制关键词数量
logger.debug(f"提取关键词: {keywords}")
else:
# 使用LLM提取关键词
topic_num = min(5, max(1, int(len(text) * 0.1))) # 根据文本长度动态调整关键词数量
# logger.info(f"提取关键词数量: {topic_num}")
topics_response, (reasoning_content, model_name) = await self.model_summary.generate_response_async(
self.find_topic_llm(text, topic_num)
)
# 提取关键词
keywords = re.findall(r"<([^>]+)>", topics_response)
if not keywords:
keywords = []
else:
keywords = [
keyword.strip()
for keyword in ",".join(keywords).replace("", ",").replace("", ",").replace(" ", ",").split(",")
if keyword.strip()
]
# logger.info(f"提取的关键词: {', '.join(keywords)}")
keywords = await self.get_keywords_from_text(text)
# 过滤掉不存在于记忆图中的关键词
valid_keywords = [keyword for keyword in keywords if keyword in self.memory_graph.G]
@@ -679,38 +707,7 @@ class Hippocampus:
Returns:
float: 激活节点数与总节点数的比值
"""
if not text:
return 0
if fast_retrieval:
# 使用jieba分词提取关键词
words = jieba.cut(text)
# 过滤掉停用词和单字词
keywords = [word for word in words if len(word) > 1]
# 去重
keywords = list(set(keywords))
# 限制关键词数量
keywords = keywords[:5]
else:
# 使用LLM提取关键词
topic_num = min(5, max(1, int(len(text) * 0.1))) # 根据文本长度动态调整关键词数量
# logger.info(f"提取关键词数量: {topic_num}")
topics_response, (reasoning_content, model_name) = await self.model_summary.generate_response_async(
self.find_topic_llm(text, topic_num)
)
# 提取关键词
keywords = re.findall(r"<([^>]+)>", topics_response)
if not keywords:
keywords = []
else:
keywords = [
keyword.strip()
for keyword in ",".join(keywords).replace("", ",").replace("", ",").replace(" ", ",").split(",")
if keyword.strip()
]
# logger.info(f"提取的关键词: {', '.join(keywords)}")
keywords = await self.get_keywords_from_text(text)
# 过滤掉不存在于记忆图中的关键词
valid_keywords = [keyword for keyword in keywords if keyword in self.memory_graph.G]
@@ -727,7 +724,7 @@ class Hippocampus:
for keyword in valid_keywords:
logger.debug(f"开始以关键词 '{keyword}' 为中心进行扩散检索 (最大深度: {max_depth}):")
# 初始化激活值
activation_values = {keyword: 1.0}
activation_values = {keyword: 1.5}
# 记录已访问的节点
visited_nodes = {keyword}
# 待处理的节点队列,每个元素是(节点, 激活值, 当前深度)
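The breadth-first expansion itself is outside this hunk; a generic spreading-activation sketch over a networkx graph, assuming activation attenuates on every hop and stops at max_depth (illustrative, not the project's exact damping rule, though the 1.5 seed matches the diff):

from collections import deque
import networkx as nx

def spread_activation(G: nx.Graph, seed: str, max_depth: int = 3, damping: float = 0.5) -> dict:
    # each node passes a damped share of its activation to unvisited neighbours
    activation = {seed: 1.5}
    visited = {seed}
    queue = deque([(seed, 1.5, 0)])
    while queue:
        node, value, depth = queue.popleft()
        if depth >= max_depth:
            continue
        for neighbor in G.neighbors(node):
            if neighbor in visited:
                continue
            visited.add(neighbor)
            activation[neighbor] = value * damping
            queue.append((neighbor, value * damping, depth + 1))
    return activation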
@@ -1315,6 +1312,7 @@ class ParahippocampalGyrus:
return compressed_memory, similar_topics_dict
async def operation_build_memory(self):
# sourcery skip: merge-list-appends-into-extend
logger.info("------------------------------------开始构建记忆--------------------------------------")
start_time = time.time()
memory_samples = self.hippocampus.entorhinal_cortex.get_memory_sample()

View File

@@ -444,7 +444,7 @@ class MessageSending(MessageProcessBase):
is_emoji: bool = False,
thinking_start_time: float = 0,
apply_set_reply_logic: bool = False,
reply_to: str = None, # type: ignore
reply_to: Optional[str] = None,
):
# 调用父类初始化
super().__init__(

View File

@@ -3,7 +3,7 @@ from src.plugin_system.base.base_action import BaseAction
from src.chat.message_receive.chat_stream import ChatStream
from src.common.logger import get_logger
from src.plugin_system.core.component_registry import component_registry
from src.plugin_system.base.component_types import ComponentType, ActionActivationType, ChatMode, ActionInfo
from src.plugin_system.base.component_types import ComponentType, ActionInfo
logger = get_logger("action_manager")
@@ -15,11 +15,6 @@ class ActionManager:
现在统一使用新插件系统,简化了原有的新旧兼容逻辑。
"""
# 类常量
DEFAULT_RANDOM_PROBABILITY = 0.3
DEFAULT_MODE = ChatMode.ALL
DEFAULT_ACTIVATION_TYPE = ActionActivationType.ALWAYS
def __init__(self):
"""初始化动作管理器"""

View File

@@ -174,7 +174,7 @@ class ActionModifier:
continue # 总是激活,无需处理
elif activation_type == ActionActivationType.RANDOM:
probability = action_info.random_activation_probability or ActionManager.DEFAULT_RANDOM_PROBABILITY
probability = action_info.random_activation_probability
if random.random() >= probability:
reason = f"RANDOM类型未触发概率{probability}"
deactivated_actions.append((action_name, reason))

View File

@@ -33,10 +33,11 @@ def init_prompt():
{time_block}
{identity_block}
你现在需要根据聊天内容选择的合适的action来参与聊天。
{chat_context_description},以下是具体的聊天内容
{chat_context_description},以下是具体的聊天内容
{chat_content_block}
{moderation_prompt}
现在请你根据{by_what}选择合适的action和触发action的消息:
@@ -45,7 +46,7 @@ def init_prompt():
{no_action_block}
{action_options_text}
你必须从上面列出的可用action中选择一个并说明触发action的消息id原因。
你必须从上面列出的可用action中选择一个，并说明触发action的消息id（不是消息原文）和选择该action的原因。
请根据动作示例,以严格的 JSON 格式输出,且仅包含 JSON 内容:
""",
@@ -128,20 +129,6 @@ class ActionPlanner:
else:
logger.warning(f"{self.log_prefix}使用中的动作 {action_name} 未在已注册动作中找到")
# 如果没有可用动作或只有no_reply动作，直接返回no_reply
# 因为现在reply是永远激活，所以不需要空跳判定
# if not current_available_actions:
# action = "no_reply" if mode == ChatMode.FOCUS else "no_action"
# reasoning = "没有可用的动作"
# logger.info(f"{self.log_prefix}{reasoning}")
# return {
# "action_result": {
# "action_type": action,
# "action_data": action_data,
# "reasoning": reasoning,
# },
# }, None
# --- 构建提示词 (调用修改后的 PromptBuilder 方法) ---
prompt, message_id_list = await self.build_planner_prompt(
is_group_chat=is_group_chat, # <-- Pass HFC state
@@ -224,7 +211,7 @@ class ActionPlanner:
reasoning = f"Planner 内部处理错误: {outer_e}"
is_parallel = False
if action in current_available_actions:
if mode == ChatMode.NORMAL and action in current_available_actions:
is_parallel = current_available_actions[action].parallel_action
action_result = {
@@ -268,7 +255,7 @@ class ActionPlanner:
actions_before_now = get_actions_by_timestamp_with_chat(
chat_id=self.chat_id,
timestamp_start=time.time()-3600,
timestamp_start=time.time() - 3600,
timestamp_end=time.time(),
limit=5,
)
@@ -276,7 +263,7 @@ class ActionPlanner:
actions_before_now_block = build_readable_actions(
actions=actions_before_now,
)
actions_before_now_block = f"你刚刚选择并执行过的action是\n{actions_before_now_block}"
self.last_obs_time_mark = time.time()
@@ -288,7 +275,6 @@ class ActionPlanner:
if global_config.chat.at_bot_inevitable_reply:
mentioned_bonus = "\n- 有人提到你或者at你"
by_what = "聊天内容"
target_prompt = '\n "target_message_id":"触发action的消息id"'
no_action_block = f"""重要说明:
@@ -311,7 +297,7 @@ class ActionPlanner:
by_what = "聊天内容和用户的最新消息"
target_prompt = ""
no_action_block = """重要说明:
- 'no_action' 表示只进行普通聊天回复,不执行任何额外动作
- 'reply' 表示只进行普通聊天回复,不执行任何额外动作
- 其他action表示在普通回复的基础上执行相应的额外动作"""
chat_context_description = "你现在正在一个群聊中"

View File

@@ -17,7 +17,11 @@ from src.chat.message_receive.uni_message_sender import HeartFCSender
from src.chat.utils.timer_calculator import Timer # <--- Import Timer
from src.chat.utils.utils import get_chat_type_and_target_info
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_before_timestamp_with_chat
from src.chat.utils.chat_message_builder import (
build_readable_messages,
get_raw_msg_before_timestamp_with_chat,
replace_user_references_sync,
)
from src.chat.express.expression_selector import expression_selector
from src.chat.knowledge.knowledge_lib import qa_manager
from src.chat.memory_system.memory_activator import MemoryActivator
@@ -25,42 +29,16 @@ from src.chat.memory_system.instant_memory import InstantMemory
from src.mood.mood_manager import mood_manager
from src.person_info.relationship_fetcher import relationship_fetcher_manager
from src.person_info.person_info import get_person_info_manager
from src.tools.tool_executor import ToolExecutor
from src.plugin_system.base.component_types import ActionInfo
logger = get_logger("replyer")
def init_prompt():
Prompt("你正在qq群里聊天下面是群里在聊的内容", "chat_target_group1")
Prompt("你正在和{sender_name}聊天,这是你们之前聊的内容:", "chat_target_private1")
Prompt("在群里聊天", "chat_target_group2")
Prompt("{sender_name}聊天", "chat_target_private2")
Prompt("\n你有以下这些**知识**\n{prompt_info}\n请你**记住上面的知识**,之后可能会用到。\n", "knowledge_prompt")
Prompt(
"""
{expression_habits_block}
{tool_info_block}
{knowledge_prompt}
{memory_block}
{relation_info_block}
{extra_info_block}
{chat_target}
{time_block}
{chat_info}
{reply_target_block}
{identity}
{action_descriptions}
你正在{chat_target_2},你现在的心情是:{mood_state}
现在请你读读之前的聊天记录,并给出回复
{config_expression_style}。注意不要复读你说过的话
{keywords_reaction_prompt}
{moderation_prompt}
不要浮夸,不要夸张修辞,不要输出多余内容(包括前后缀,冒号和引号,括号()表情包at或 @等 )。只输出回复内容""",
"default_generator_prompt",
)
Prompt(
"""
@@ -109,7 +87,8 @@ def init_prompt():
{core_dialogue_prompt}
{reply_target_block}
对方最新发送的内容:{message_txt}
你现在的心情是:{mood_state}
{config_expression_style}
注意不要复读你说过的话
@@ -159,6 +138,8 @@ class DefaultReplyer:
self.heart_fc_sender = HeartFCSender()
self.memory_activator = MemoryActivator()
self.instant_memory = InstantMemory(chat_id=self.chat_stream.stream_id)
from src.plugin_system.core.tool_use import ToolExecutor # 延迟导入ToolExecutor，不然会循环依赖
self.tool_executor = ToolExecutor(chat_id=self.chat_stream.stream_id, enable_cache=True, cache_ttl=3)
def _select_weighted_model_config(self) -> Dict[str, Any]:
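The body of _select_weighted_model_config is not part of this diff, but the call sites below treat its return value as a model config dict carrying an optional "weight" key (default 1.0). A plausible standalone version under that assumption, purely as a sketch:

import random
from typing import Any, Dict, List

def select_weighted_model_config(model_configs: List[Dict[str, Any]]) -> Dict[str, Any]:
    # pick one config, biased by its optional "weight" field (defaults to 1.0)
    weights = [cfg.get("weight", 1.0) for cfg in model_configs]
    return random.choices(model_configs, weights=weights, k=1)[0]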
@@ -171,67 +152,49 @@ class DefaultReplyer:
async def generate_reply_with_context(
self,
reply_data: Optional[Dict[str, Any]] = None,
reply_to: str = "",
extra_info: str = "",
available_actions: Optional[Dict[str, ActionInfo]] = None,
enable_tool: bool = True,
enable_timeout: bool = False,
) -> Tuple[bool, Optional[str], Optional[str]]:
"""
回复器 (Replier): 核心逻辑,负责生成回复文本。
(已整合原 HeartFCGenerator 的功能)
回复器 (Replier): 负责生成回复文本的核心逻辑
Args:
reply_to: 回复对象,格式为 "发送者:消息内容"
extra_info: 额外信息,用于补充上下文
available_actions: 可用的动作信息字典
enable_tool: 是否启用工具调用
Returns:
Tuple[bool, Optional[str], Optional[str]]: (是否成功, 生成的回复内容, 使用的prompt)
"""
prompt = None
if available_actions is None:
available_actions = {}
try:
if not reply_data:
reply_data = {
"reply_to": reply_to,
"extra_info": extra_info,
}
for key, value in reply_data.items():
if not value:
logger.debug(f"回复数据跳过{key},生成回复时将忽略。")
# 3. 构建 Prompt
with Timer("构建Prompt", {}): # 内部计时器,可选保留
prompt = await self.build_prompt_reply_context(
reply_data=reply_data, # 传递action_data
reply_to=reply_to,
extra_info=extra_info,
available_actions=available_actions,
enable_timeout=enable_timeout,
enable_tool=enable_tool,
)
if not prompt:
logger.warning("构建prompt失败跳过回复生成")
return False, None, None
# 4. 调用 LLM 生成回复
content = None
reasoning_content = None
model_name = "unknown_model"
# TODO: 复活这里
# reasoning_content = None
# model_name = "unknown_model"
try:
with Timer("LLM生成", {}): # 内部计时器,可选保留
# 加权随机选择一个模型配置
selected_model_config = self._select_weighted_model_config()
# 兼容新旧格式的模型名称获取
model_display_name = selected_model_config.get('model_name', selected_model_config.get('name', 'N/A'))
logger.info(
f"使用模型生成回复: {model_display_name} (选中概率: {selected_model_config.get('weight', 1.0)})"
)
express_model = LLMRequest(
model=selected_model_config,
request_type=self.request_type,
)
if global_config.debug.show_prompt:
logger.info(f"\n{prompt}\n")
else:
logger.debug(f"\n{prompt}\n")
content, (reasoning_content, model_name) = await express_model.generate_response_async(prompt)
logger.debug(f"replyer生成内容: {content}")
content = await self.llm_generate_content(prompt)
logger.debug(f"replyer生成内容: {content}")
except Exception as llm_e:
# 精简报错信息
@@ -247,73 +210,62 @@ class DefaultReplyer:
async def rewrite_reply_with_context(
self,
reply_data: Dict[str, Any],
raw_reply: str = "",
reason: str = "",
reply_to: str = "",
relation_info: str = "",
) -> Tuple[bool, Optional[str]]:
return_prompt: bool = False,
) -> Tuple[bool, Optional[str], Optional[str]]:
"""
表达器 (Expressor): 核心逻辑,负责生成回复文本。
表达器 (Expressor): 负责重写和优化回复文本。
Args:
raw_reply: 原始回复内容
reason: 回复原因
reply_to: 回复对象,格式为 "发送者:消息内容"
relation_info: 关系信息
Returns:
Tuple[bool, Optional[str]]: (是否成功, 重写后的回复内容)
"""
try:
if not reply_data:
reply_data = {
"reply_to": reply_to,
"relation_info": relation_info,
}
with Timer("构建Prompt", {}): # 内部计时器,可选保留
prompt = await self.build_prompt_rewrite_context(
reply_data=reply_data,
raw_reply=raw_reply,
reason=reason,
reply_to=reply_to,
)
content = None
reasoning_content = None
model_name = "unknown_model"
# TODO: 复活这里
# reasoning_content = None
# model_name = "unknown_model"
if not prompt:
logger.error("Prompt 构建失败,无法生成回复。")
return False, None
return False, None, None
try:
with Timer("LLM生成", {}): # 内部计时器,可选保留
# 加权随机选择一个模型配置
selected_model_config = self._select_weighted_model_config()
# 兼容新旧格式的模型名称获取
model_display_name = selected_model_config.get('model_name', selected_model_config.get('name', 'N/A'))
logger.info(
f"使用模型重写回复: {model_display_name} (选中概率: {selected_model_config.get('weight', 1.0)})"
)
express_model = LLMRequest(
model=selected_model_config,
request_type=self.request_type,
)
content, (reasoning_content, model_name) = await express_model.generate_response_async(prompt)
logger.info(f"想要表达:{raw_reply}||理由:{reason}||生成回复: {content}\n")
content = await self.llm_generate_content(prompt)
logger.info(f"想要表达:{raw_reply}||理由:{reason}||生成回复: {content}\n")
except Exception as llm_e:
# 精简报错信息
logger.error(f"LLM 生成失败: {llm_e}")
return False, None # LLM 调用失败则无法生成回复
return False, None, prompt if return_prompt else None # LLM 调用失败则无法生成回复
return True, content
return True, content, prompt if return_prompt else None
except Exception as e:
logger.error(f"回复生成意外失败: {e}")
traceback.print_exc()
return False, None
return False, None, prompt if return_prompt else None
async def build_relation_info(self, reply_data=None):
async def build_relation_info(self, reply_to: str = ""):
if not global_config.relationship.enable_relationship:
return ""
relationship_fetcher = relationship_fetcher_manager.get_fetcher(self.chat_stream.stream_id)
if not reply_data:
if not reply_to:
return ""
reply_to = reply_data.get("reply_to", "")
sender, text = self._parse_reply_target(reply_to)
if not sender or not text:
return ""
@@ -327,7 +279,16 @@ class DefaultReplyer:
return await relationship_fetcher.build_relation_info(person_id, points_num=5)
async def build_expression_habits(self, chat_history, target):
async def build_expression_habits(self, chat_history: str, target: str) -> str:
"""构建表达习惯块
Args:
chat_history: 聊天历史记录
target: 目标消息内容
Returns:
str: 表达习惯信息字符串
"""
if not global_config.expression.enable_expression:
return ""
@@ -360,54 +321,65 @@ class DefaultReplyer:
expression_habits_block = ""
expression_habits_title = ""
if style_habits_str.strip():
expression_habits_title = "你可以参考以下的语言习惯,当情景合适就使用,但不要生硬使用,以合理的方式结合到你的回复中:"
expression_habits_title = (
"你可以参考以下的语言习惯,当情景合适就使用,但不要生硬使用,以合理的方式结合到你的回复中:"
)
expression_habits_block += f"{style_habits_str}\n"
if grammar_habits_str.strip():
expression_habits_title = "你可以选择下面的句法进行回复,如果情景合适就使用,不要盲目使用,不要生硬使用,以合理的方式使用:"
expression_habits_title = (
"你可以选择下面的句法进行回复,如果情景合适就使用,不要盲目使用,不要生硬使用,以合理的方式使用:"
)
expression_habits_block += f"{grammar_habits_str}\n"
if style_habits_str.strip() and grammar_habits_str.strip():
expression_habits_title = "你可以参考以下的语言习惯和句法,如果情景合适就使用,不要盲目使用,不要生硬使用,以合理的方式结合到你的回复中:"
expression_habits_block = f"{expression_habits_title}\n{expression_habits_block}"
return expression_habits_block
return f"{expression_habits_title}\n{expression_habits_block}"
async def build_memory_block(self, chat_history, target):
async def build_memory_block(self, chat_history: str, target: str) -> str:
"""构建记忆块
Args:
chat_history: 聊天历史记录
target: 目标消息内容
Returns:
str: 记忆信息字符串
"""
if not global_config.memory.enable_memory:
return ""
instant_memory = None
running_memories = await self.memory_activator.activate_memory_with_chat_history(
target_message=target, chat_history_prompt=chat_history
)
if global_config.memory.enable_instant_memory:
asyncio.create_task(self.instant_memory.create_and_store_memory(chat_history))
instant_memory = await self.instant_memory.get_memory(target)
logger.info(f"即时记忆:{instant_memory}")
if not running_memories:
return ""
memory_str = "以下是当前在聊天中,你回忆起的记忆:\n"
for running_memory in running_memories:
memory_str += f"- {running_memory['content']}\n"
if instant_memory:
memory_str += f"- {instant_memory}\n"
return memory_str
async def build_tool_info(self, chat_history, reply_data: Optional[Dict], enable_tool: bool = True):
async def build_tool_info(self, chat_history: str, reply_to: str = "", enable_tool: bool = True) -> str:
"""构建工具信息块
Args:
reply_data: 回复数据,包含要回复的消息内容
chat_history: 聊天历史
chat_history: 聊天历史记录
reply_to: 回复对象,格式为 "发送者:消息内容"
enable_tool: 是否启用工具调用
Returns:
str: 工具信息字符串
@@ -416,10 +388,9 @@ class DefaultReplyer:
if not enable_tool:
return ""
if not reply_data:
if not reply_to:
return ""
reply_to = reply_data.get("reply_to", "")
sender, text = self._parse_reply_target(reply_to)
if not text:
@@ -442,7 +413,7 @@ class DefaultReplyer:
tool_info_str += "以上是你获取到的实时信息,请在回复时参考这些信息。"
logger.info(f"获取到 {len(tool_results)} 个工具结果")
return tool_info_str
else:
logger.debug("未获取到任何工具结果")
@@ -452,7 +423,15 @@ class DefaultReplyer:
logger.error(f"工具信息获取失败: {e}")
return ""
def _parse_reply_target(self, target_message: str) -> tuple:
def _parse_reply_target(self, target_message: str) -> Tuple[str, str]:
"""解析回复目标消息
Args:
target_message: 目标消息,格式为 "发送者:消息内容" 或 "发送者：消息内容"
Returns:
Tuple[str, str]: (发送者名称, 消息内容)
"""
sender = ""
target = ""
# 添加None检查防止NoneType错误
@@ -466,14 +445,22 @@ class DefaultReplyer:
target = parts[1].strip()
return sender, target
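Given the docstring above, the expected behaviour is simply splitting on the first colon; a couple of illustrative expectations (hypothetical inputs, not executed here):

# _parse_reply_target("小明:今天吃什么")  -> ("小明", "今天吃什么")
# _parse_reply_target("没有冒号的消息")    -> ("", "")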
async def build_keywords_reaction_prompt(self, target):
async def build_keywords_reaction_prompt(self, target: Optional[str]) -> str:
"""构建关键词反应提示
Args:
target: 目标消息内容
Returns:
str: 关键词反应提示字符串
"""
# 关键词检测与反应
keywords_reaction_prompt = ""
try:
# 添加None检查防止NoneType错误
if target is None:
return keywords_reaction_prompt
# 处理关键词规则
for rule in global_config.keyword_reaction.keyword_rules:
if any(keyword in target for keyword in rule.keywords):
@@ -500,15 +487,25 @@ class DefaultReplyer:
return keywords_reaction_prompt
async def _time_and_run_task(self, coroutine, name: str):
"""一个简单的帮助函数,用于计时运行异步任务,返回任务名、结果和耗时"""
async def _time_and_run_task(self, coroutine, name: str) -> Tuple[str, Any, float]:
"""计时运行异步任务的辅助函数
Args:
coroutine: 要执行的协程
name: 任务名称
Returns:
Tuple[str, Any, float]: (任务名称, 任务结果, 执行耗时)
"""
start_time = time.time()
result = await coroutine
end_time = time.time()
duration = end_time - start_time
return name, result, duration
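This helper is what the gather block further down feeds: each context-building coroutine is wrapped so the replyer can log per-task durations. A reduced usage sketch with a couple of the builders shown in this diff (build_context itself is a hypothetical wrapper):

import asyncio

async def build_context(replyer, chat_history: str, target: str, reply_to: str):
    # run independent context builders concurrently, keeping (name, result, duration) triples
    results = await asyncio.gather(
        replyer._time_and_run_task(replyer.build_memory_block(chat_history, target), "memory_block"),
        replyer._time_and_run_task(replyer.build_relation_info(reply_to), "relation_info"),
    )
    return {name: result for name, result, _duration in results}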
def build_s4u_chat_history_prompts(self, message_list_before_now: list, target_user_id: str) -> tuple[str, str]:
def build_s4u_chat_history_prompts(
self, message_list_before_now: List[Dict[str, Any]], target_user_id: str
) -> Tuple[str, str]:
"""
构建 s4u 风格的分离对话 prompt
@@ -517,7 +514,7 @@ class DefaultReplyer:
target_user_id: 目标用户ID当前对话对象
Returns:
tuple: (核心对话prompt, 背景对话prompt)
Tuple[str, str]: (核心对话prompt, 背景对话prompt)
"""
core_dialogue_list = []
background_dialogue_list = []
@@ -536,7 +533,7 @@ class DefaultReplyer:
# 其他用户的对话
background_dialogue_list.append(msg_dict)
except Exception as e:
logger.error(f"![1753364551656](image/default_generator/1753364551656.png)记录: {msg_dict}, 错误: {e}")
logger.error(f"处理消息记录时出错: {msg_dict}, 错误: {e}")
# 构建背景对话 prompt
background_dialogue_prompt = ""
@@ -581,8 +578,25 @@ class DefaultReplyer:
sender: str,
target: str,
chat_info: str,
):
"""构建 mai_think 上下文信息"""
) -> Any:
"""构建 mai_think 上下文信息
Args:
chat_id: 聊天ID
memory_block: 记忆块内容
relation_info: 关系信息
time_block: 时间块内容
chat_target_1: 聊天目标1
chat_target_2: 聊天目标2
mood_prompt: 情绪提示
identity_block: 身份块内容
sender: 发送者名称
target: 目标消息内容
chat_info: 聊天信息
Returns:
Any: mai_think 实例
"""
mai_think = mai_thinking_manager.get_mai_think(chat_id)
mai_think.memory_block = memory_block
mai_think.relation_info_block = relation_info
@@ -598,21 +612,20 @@ class DefaultReplyer:
async def build_prompt_reply_context(
self,
reply_data: Dict[str, Any],
reply_to: str,
extra_info: str = "",
available_actions: Optional[Dict[str, ActionInfo]] = None,
enable_timeout: bool = False,
enable_tool: bool = True,
) -> str: # sourcery skip: merge-else-if-into-elif, remove-redundant-if
"""
构建回复器上下文
Args:
reply_data: 回复数据
replay_data 包含以下字段:
structured_info: 结构化信息,一般是工具调用获得的信息
reply_to: 回复对象
extra_info/extra_info_block: 额外信息
reply_to: 回复对象,格式为 "发送者:消息内容"
extra_info: 额外信息,用于补充上下文
available_actions: 可用动作
enable_timeout: 是否启用超时处理
enable_tool: 是否启用工具调用
Returns:
str: 构建好的上下文
@@ -623,9 +636,7 @@ class DefaultReplyer:
chat_id = chat_stream.stream_id
person_info_manager = get_person_info_manager()
is_group_chat = bool(chat_stream.group_info)
reply_to = reply_data.get("reply_to", "none")
extra_info_block = reply_data.get("extra_info", "") or reply_data.get("extra_info_block", "")
if global_config.mood.enable_mood:
chat_mood = mood_manager.get_mood_by_chat_id(chat_id)
mood_prompt = chat_mood.mood_state
@@ -633,6 +644,15 @@ class DefaultReplyer:
mood_prompt = ""
sender, target = self._parse_reply_target(reply_to)
person_info_manager = get_person_info_manager()
person_id = person_info_manager.get_person_id_by_person_name(sender)
user_id = person_info_manager.get_value_sync(person_id, "user_id")
platform = chat_stream.platform
if user_id == global_config.bot.qq_account and platform == global_config.bot.platform:
logger.warning("选取了自身作为回复对象跳过构建prompt")
return ""
target = replace_user_references_sync(target, chat_stream.platform, replace_bot_name=True)
# 构建action描述 (如果启用planner)
action_descriptions = ""
@@ -649,21 +669,6 @@ class DefaultReplyer:
limit=global_config.chat.max_context_size * 2,
)
message_list_before_now = get_raw_msg_before_timestamp_with_chat(
chat_id=chat_id,
timestamp=time.time(),
limit=global_config.chat.max_context_size,
)
chat_talking_prompt = build_readable_messages(
message_list_before_now,
replace_bot_name=True,
merge_messages=False,
timestamp_mode="normal_no_YMD",
read_mark=0.0,
truncate=True,
show_actions=True,
)
message_list_before_short = get_raw_msg_before_timestamp_with_chat(
chat_id=chat_id,
timestamp=time.time(),
@@ -683,25 +688,21 @@ class DefaultReplyer:
self._time_and_run_task(
self.build_expression_habits(chat_talking_prompt_short, target), "expression_habits"
),
self._time_and_run_task(
self.build_relation_info(reply_data), "relation_info"
),
self._time_and_run_task(self.build_relation_info(reply_to), "relation_info"),
self._time_and_run_task(self.build_memory_block(chat_talking_prompt_short, target), "memory_block"),
self._time_and_run_task(
self.build_tool_info(chat_talking_prompt_short, reply_data, enable_tool=enable_tool), "tool_info"
),
self._time_and_run_task(
get_prompt_info(target, threshold=0.38), "prompt_info"
self.build_tool_info(chat_talking_prompt_short, reply_to, enable_tool=enable_tool), "tool_info"
),
self._time_and_run_task(get_prompt_info(target, threshold=0.38), "prompt_info"),
)
# 任务名称中英文映射
task_name_mapping = {
"expression_habits": "选取表达方式",
"relation_info": "感受关系",
"relation_info": "感受关系",
"memory_block": "回忆",
"tool_info": "使用工具",
"prompt_info": "获取知识"
"prompt_info": "获取知识",
}
# 处理结果
@@ -723,8 +724,8 @@ class DefaultReplyer:
keywords_reaction_prompt = await self.build_keywords_reaction_prompt(target)
if extra_info_block:
extra_info_block = f"以下是你在回复时需要参考的信息,现在请你阅读以下内容,进行决策\n{extra_info_block}\n以上是你在回复时需要参考的信息,现在请你阅读以下内容,进行决策"
if extra_info:
extra_info_block = f"以下是你在回复时需要参考的信息,现在请你阅读以下内容,进行决策\n{extra_info}\n以上是你在回复时需要参考的信息,现在请你阅读以下内容,进行决策"
else:
extra_info_block = ""
@@ -779,116 +780,74 @@ class DefaultReplyer:
# 根据sender通过person_info_manager反向查找person_id再获取user_id
person_id = person_info_manager.get_person_id_by_person_name(sender)
# 根据配置选择使用哪种 prompt 构建模式
if global_config.chat.use_s4u_prompt_mode and person_id:
# 使用 s4u 对话构建模式:分离当前对话对象和其他对话
try:
user_id_value = await person_info_manager.get_value(person_id, "user_id")
if user_id_value:
target_user_id = str(user_id_value)
except Exception as e:
logger.warning(f"无法从person_id {person_id} 获取user_id: {e}")
target_user_id = ""
# 使用 s4u 对话构建模式:分离当前对话对象和其他对话
try:
user_id_value = await person_info_manager.get_value(person_id, "user_id")
if user_id_value:
target_user_id = str(user_id_value)
except Exception as e:
logger.warning(f"无法从person_id {person_id} 获取user_id: {e}")
target_user_id = ""
# 构建分离的对话 prompt
core_dialogue_prompt, background_dialogue_prompt = self.build_s4u_chat_history_prompts(
message_list_before_now_long, target_user_id
)
self.build_mai_think_context(
chat_id=chat_id,
memory_block=memory_block,
relation_info=relation_info,
time_block=time_block,
chat_target_1=chat_target_1,
chat_target_2=chat_target_2,
mood_prompt=mood_prompt,
identity_block=identity_block,
sender=sender,
target=target,
chat_info=f"""
# 构建分离的对话 prompt
core_dialogue_prompt, background_dialogue_prompt = self.build_s4u_chat_history_prompts(
message_list_before_now_long, target_user_id
)
self.build_mai_think_context(
chat_id=chat_id,
memory_block=memory_block,
relation_info=relation_info,
time_block=time_block,
chat_target_1=chat_target_1,
chat_target_2=chat_target_2,
mood_prompt=mood_prompt,
identity_block=identity_block,
sender=sender,
target=target,
chat_info=f"""
{background_dialogue_prompt}
--------------------------------
{time_block}
这是你和{sender}的对话,你们正在交流中:
{core_dialogue_prompt}"""
)
{core_dialogue_prompt}""",
)
# 使用 s4u 风格的模板
template_name = "s4u_style_prompt"
# 使用 s4u 风格的模板
template_name = "s4u_style_prompt"
return await global_prompt_manager.format_prompt(
template_name,
expression_habits_block=expression_habits_block,
tool_info_block=tool_info,
knowledge_prompt=prompt_info,
memory_block=memory_block,
relation_info_block=relation_info,
extra_info_block=extra_info_block,
identity=identity_block,
action_descriptions=action_descriptions,
sender_name=sender,
mood_state=mood_prompt,
background_dialogue_prompt=background_dialogue_prompt,
time_block=time_block,
core_dialogue_prompt=core_dialogue_prompt,
reply_target_block=reply_target_block,
message_txt=target,
config_expression_style=global_config.expression.expression_style,
keywords_reaction_prompt=keywords_reaction_prompt,
moderation_prompt=moderation_prompt_block,
)
else:
self.build_mai_think_context(
chat_id=chat_id,
memory_block=memory_block,
relation_info=relation_info,
time_block=time_block,
chat_target_1=chat_target_1,
chat_target_2=chat_target_2,
mood_prompt=mood_prompt,
identity_block=identity_block,
sender=sender,
target=target,
chat_info=chat_talking_prompt
)
# 使用原有的模式
return await global_prompt_manager.format_prompt(
template_name,
expression_habits_block=expression_habits_block,
chat_target=chat_target_1,
chat_info=chat_talking_prompt,
memory_block=memory_block,
tool_info_block=tool_info,
knowledge_prompt=prompt_info,
extra_info_block=extra_info_block,
relation_info_block=relation_info,
time_block=time_block,
reply_target_block=reply_target_block,
moderation_prompt=moderation_prompt_block,
keywords_reaction_prompt=keywords_reaction_prompt,
identity=identity_block,
target_message=target,
sender_name=sender,
config_expression_style=global_config.expression.expression_style,
action_descriptions=action_descriptions,
chat_target_2=chat_target_2,
mood_state=mood_prompt,
)
return await global_prompt_manager.format_prompt(
template_name,
expression_habits_block=expression_habits_block,
tool_info_block=tool_info,
knowledge_prompt=prompt_info,
memory_block=memory_block,
relation_info_block=relation_info,
extra_info_block=extra_info_block,
identity=identity_block,
action_descriptions=action_descriptions,
sender_name=sender,
mood_state=mood_prompt,
background_dialogue_prompt=background_dialogue_prompt,
time_block=time_block,
core_dialogue_prompt=core_dialogue_prompt,
reply_target_block=reply_target_block,
message_txt=target,
config_expression_style=global_config.expression.expression_style,
keywords_reaction_prompt=keywords_reaction_prompt,
moderation_prompt=moderation_prompt_block,
)
async def build_prompt_rewrite_context(
self,
reply_data: Dict[str, Any],
raw_reply: str,
reason: str,
reply_to: str,
) -> str:
chat_stream = self.chat_stream
chat_id = chat_stream.stream_id
is_group_chat = bool(chat_stream.group_info)
reply_to = reply_data.get("reply_to", "none")
raw_reply = reply_data.get("raw_reply", "")
reason = reply_data.get("reason", "")
sender, target = self._parse_reply_target(reply_to)
# 添加情绪状态获取
@@ -915,7 +874,7 @@ class DefaultReplyer:
# 并行执行2个构建任务
expression_habits_block, relation_info = await asyncio.gather(
self.build_expression_habits(chat_talking_prompt_half, target),
self.build_relation_info(reply_data),
self.build_relation_info(reply_to),
)
keywords_reaction_prompt = await self.build_keywords_reaction_prompt(target)
@@ -1018,6 +977,31 @@ class DefaultReplyer:
display_message=display_message,
)
async def llm_generate_content(self, prompt: str) -> str:
with Timer("LLM生成", {}): # 内部计时器,可选保留
# 加权随机选择一个模型配置
selected_model_config = self._select_weighted_model_config()
model_display_name = selected_model_config.get('model_name') or selected_model_config.get('name', 'N/A')
logger.info(
f"使用模型生成回复: {model_display_name} (选中概率: {selected_model_config.get('weight', 1.0)})"
)
express_model = LLMRequest(
model=selected_model_config,
request_type=self.request_type,
)
if global_config.debug.show_prompt:
logger.info(f"\n{prompt}\n")
else:
logger.debug(f"\n{prompt}\n")
# TODO: 这里的_应该做出替换
content, _ = await express_model.generate_response_async(prompt)
logger.debug(f"replyer生成内容: {content}")
return content
def weighted_sample_no_replacement(items, weights, k) -> list:
"""
@@ -1075,10 +1059,8 @@ async def get_prompt_info(message: str, threshold: float):
related_info += found_knowledge_from_lpmm
logger.debug(f"获取知识库内容耗时: {(end_time - start_time):.3f}")
logger.debug(f"获取知识库内容,相关信息:{related_info[:100]}...,信息长度: {len(related_info)}")
# 格式化知识信息
formatted_prompt_info = await global_prompt_manager.format_prompt("knowledge_prompt", prompt_info=related_info)
return formatted_prompt_info
return f"你有以下这些**知识**\n{related_info}\n请你**记住上面的知识**,之后可能会用到。\n"
else:
logger.debug("从LPMM知识库获取知识失败可能是从未导入过知识返回空知识...")
return ""

View File

@@ -2,7 +2,7 @@ import time # 导入 time 模块以获取当前时间
import random
import re
from typing import List, Dict, Any, Tuple, Optional
from typing import List, Dict, Any, Tuple, Optional, Callable
from rich.traceback import install
from src.config.config import global_config
@@ -10,11 +10,161 @@ from src.common.message_repository import find_messages, count_messages
from src.common.database.database_model import ActionRecords
from src.common.database.database_model import Images
from src.person_info.person_info import PersonInfoManager, get_person_info_manager
from src.chat.utils.utils import translate_timestamp_to_human_readable,assign_message_ids
from src.chat.utils.utils import translate_timestamp_to_human_readable, assign_message_ids
install(extra_lines=3)
def replace_user_references_sync(
content: str,
platform: str,
name_resolver: Optional[Callable[[str, str], str]] = None,
replace_bot_name: bool = True,
) -> str:
"""
替换内容中的用户引用格式,包括回复<aaa:bbb>和@<aaa:bbb>格式
Args:
content: 要处理的内容字符串
platform: 平台标识
name_resolver: 名称解析函数,接收(platform, user_id)参数,返回用户名称
如果为None则使用默认的person_info_manager
replace_bot_name: 是否将机器人的user_id替换为"机器人昵称(你)"
Returns:
str: 处理后的内容字符串
"""
if name_resolver is None:
person_info_manager = get_person_info_manager()
def default_resolver(platform: str, user_id: str) -> str:
# 检查是否是机器人自己
if replace_bot_name and user_id == global_config.bot.qq_account:
return f"{global_config.bot.nickname}(你)"
person_id = PersonInfoManager.get_person_id(platform, user_id)
return person_info_manager.get_value_sync(person_id, "person_name") or user_id # type: ignore
name_resolver = default_resolver
# 处理回复<aaa:bbb>格式
reply_pattern = r"回复<([^:<>]+):([^:<>]+)>"
match = re.search(reply_pattern, content)
if match:
aaa = match[1]
bbb = match[2]
try:
# 检查是否是机器人自己
if replace_bot_name and bbb == global_config.bot.qq_account:
reply_person_name = f"{global_config.bot.nickname}(你)"
else:
reply_person_name = name_resolver(platform, bbb) or aaa
content = re.sub(reply_pattern, f"回复 {reply_person_name}", content, count=1)
except Exception:
# 如果解析失败,使用原始昵称
content = re.sub(reply_pattern, f"回复 {aaa}", content, count=1)
# 处理@<aaa:bbb>格式
at_pattern = r"@<([^:<>]+):([^:<>]+)>"
at_matches = list(re.finditer(at_pattern, content))
if at_matches:
new_content = ""
last_end = 0
for m in at_matches:
new_content += content[last_end : m.start()]
aaa = m.group(1)
bbb = m.group(2)
try:
# 检查是否是机器人自己
if replace_bot_name and bbb == global_config.bot.qq_account:
at_person_name = f"{global_config.bot.nickname}(你)"
else:
at_person_name = name_resolver(platform, bbb) or aaa
new_content += f"@{at_person_name}"
except Exception:
# 如果解析失败,使用原始昵称
new_content += f"@{aaa}"
last_end = m.end()
new_content += content[last_end:]
content = new_content
return content
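# A minimal usage sketch for replace_user_references_sync above: demo_resolver and
# the user ids are made-up stand-ins for the person_info_manager lookup, only to
# show how 回复<name:id> and @<name:id> references are rewritten.
def demo_resolver(platform: str, user_id: str) -> str:
    return {"10001": "Alice", "10002": "Bob"}.get(user_id, user_id)

raw = "回复<alice:10001> 你好 @<bob:10002>"
print(replace_user_references_sync(raw, "qq", name_resolver=demo_resolver, replace_bot_name=False))
# expected output: "回复 Alice 你好 @Bob"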
async def replace_user_references_async(
content: str,
platform: str,
name_resolver: Optional[Callable[[str, str], Any]] = None,
replace_bot_name: bool = True,
) -> str:
"""
替换内容中的用户引用格式,包括回复<aaa:bbb>和@<aaa:bbb>格式
Args:
content: 要处理的内容字符串
platform: 平台标识
name_resolver: 名称解析函数,接收(platform, user_id)参数,返回用户名称
如果为None则使用默认的person_info_manager
replace_bot_name: 是否将机器人的user_id替换为"机器人昵称(你)"
Returns:
str: 处理后的内容字符串
"""
if name_resolver is None:
person_info_manager = get_person_info_manager()
async def default_resolver(platform: str, user_id: str) -> str:
# 检查是否是机器人自己
if replace_bot_name and user_id == global_config.bot.qq_account:
return f"{global_config.bot.nickname}(你)"
person_id = PersonInfoManager.get_person_id(platform, user_id)
return await person_info_manager.get_value(person_id, "person_name") or user_id # type: ignore
name_resolver = default_resolver
# 处理回复<aaa:bbb>格式
reply_pattern = r"回复<([^:<>]+):([^:<>]+)>"
match = re.search(reply_pattern, content)
if match:
aaa = match.group(1)
bbb = match.group(2)
try:
# 检查是否是机器人自己
if replace_bot_name and bbb == global_config.bot.qq_account:
reply_person_name = f"{global_config.bot.nickname}(你)"
else:
reply_person_name = await name_resolver(platform, bbb) or aaa
content = re.sub(reply_pattern, f"回复 {reply_person_name}", content, count=1)
except Exception:
# 如果解析失败,使用原始昵称
content = re.sub(reply_pattern, f"回复 {aaa}", content, count=1)
# 处理@<aaa:bbb>格式
at_pattern = r"@<([^:<>]+):([^:<>]+)>"
at_matches = list(re.finditer(at_pattern, content))
if at_matches:
new_content = ""
last_end = 0
for m in at_matches:
new_content += content[last_end : m.start()]
aaa = m.group(1)
bbb = m.group(2)
try:
# 检查是否是机器人自己
if replace_bot_name and bbb == global_config.bot.qq_account:
at_person_name = f"{global_config.bot.nickname}(你)"
else:
at_person_name = await name_resolver(platform, bbb) or aaa
new_content += f"@{at_person_name}"
except Exception:
# 如果解析失败,使用原始昵称
new_content += f"@{aaa}"
last_end = m.end()
new_content += content[last_end:]
content = new_content
return content
def get_raw_msg_by_timestamp(
timestamp_start: float, timestamp_end: float, limit: int = 0, limit_mode: str = "latest"
) -> List[Dict[str, Any]]:
@@ -374,33 +524,8 @@ def _build_readable_messages_internal(
else:
person_name = "某人"
# 检查是否有 回复<aaa:bbb> 字段
reply_pattern = r"回复<([^:<>]+):([^:<>]+)>"
match = re.search(reply_pattern, content)
if match:
aaa: str = match[1]
bbb: str = match[2]
reply_person_id = PersonInfoManager.get_person_id(platform, bbb)
reply_person_name = person_info_manager.get_value_sync(reply_person_id, "person_name") or aaa
# 在内容前加上回复信息
content = re.sub(reply_pattern, lambda m, name=reply_person_name: f"回复 {name}", content, count=1)
# 检查是否有 @<aaa:bbb> 字段 @<{member_info.get('nickname')}:{member_info.get('user_id')}>
at_pattern = r"@<([^:<>]+):([^:<>]+)>"
at_matches = list(re.finditer(at_pattern, content))
if at_matches:
new_content = ""
last_end = 0
for m in at_matches:
new_content += content[last_end : m.start()]
aaa = m.group(1)
bbb = m.group(2)
at_person_id = PersonInfoManager.get_person_id(platform, bbb)
at_person_name = person_info_manager.get_value_sync(at_person_id, "person_name") or aaa
new_content += f"@{at_person_name}"
last_end = m.end()
new_content += content[last_end:]
content = new_content
# 使用独立函数处理用户引用格式
content = replace_user_references_sync(content, platform, replace_bot_name=replace_bot_name)
target_str = "这是QQ的一个功能用于提及某人但没那么明显"
if target_str in content and random.random() < 0.6:
@@ -654,6 +779,7 @@ async def build_readable_messages_with_list(
return formatted_string, details_list
def build_readable_messages_with_id(
messages: List[Dict[str, Any]],
replace_bot_name: bool = True,
@@ -669,9 +795,9 @@ def build_readable_messages_with_id(
允许通过参数控制格式化行为。
"""
message_id_list = assign_message_ids(messages)
formatted_string = build_readable_messages(
messages = messages,
messages=messages,
replace_bot_name=replace_bot_name,
merge_messages=merge_messages,
timestamp_mode=timestamp_mode,
@@ -682,10 +808,7 @@ def build_readable_messages_with_id(
message_id_list=message_id_list,
)
return formatted_string , message_id_list
return formatted_string, message_id_list
def build_readable_messages(
@@ -770,7 +893,13 @@ def build_readable_messages(
if read_mark <= 0:
# 没有有效的 read_mark直接格式化所有消息
formatted_string, _, pic_id_mapping, _ = _build_readable_messages_internal(
copy_messages, replace_bot_name, merge_messages, timestamp_mode, truncate, show_pic=show_pic, message_id_list=message_id_list
copy_messages,
replace_bot_name,
merge_messages,
timestamp_mode,
truncate,
show_pic=show_pic,
message_id_list=message_id_list,
)
# 生成图片映射信息并添加到最前面
@@ -893,7 +1022,7 @@ async def build_anonymous_messages(messages: List[Dict[str, Any]]) -> str:
for msg in messages:
try:
platform = msg.get("chat_info_platform")
platform: str = msg.get("chat_info_platform") # type: ignore
user_id = msg.get("user_id")
_timestamp = msg.get("time")
content: str = ""
@@ -916,38 +1045,14 @@ async def build_anonymous_messages(messages: List[Dict[str, Any]]) -> str:
anon_name = get_anon_name(platform, user_id)
# print(f"anon_name:{anon_name}")
# 处理 回复<aaa:bbb>
reply_pattern = r"回复<([^:<>]+):([^:<>]+)>"
match = re.search(reply_pattern, content)
if match:
# print(f"发现回复match:{match}")
bbb = match.group(2)
# 使用独立函数处理用户引用格式,传入自定义的匿名名称解析器
def anon_name_resolver(platform: str, user_id: str) -> str:
try:
anon_reply = get_anon_name(platform, bbb)
# print(f"anon_reply:{anon_reply}")
return get_anon_name(platform, user_id)
except Exception:
anon_reply = "?"
content = re.sub(reply_pattern, f"回复 {anon_reply}", content, count=1)
return "?"
# 处理 @<aaa:bbb>无嵌套def
at_pattern = r"@<([^:<>]+):([^:<>]+)>"
at_matches = list(re.finditer(at_pattern, content))
if at_matches:
# print(f"发现@match:{at_matches}")
new_content = ""
last_end = 0
for m in at_matches:
new_content += content[last_end : m.start()]
bbb = m.group(2)
try:
anon_at = get_anon_name(platform, bbb)
# print(f"anon_at:{anon_at}")
except Exception:
anon_at = "?"
new_content += f"@{anon_at}"
last_end = m.end()
new_content += content[last_end:]
content = new_content
content = replace_user_references_sync(content, platform, anon_name_resolver, replace_bot_name=False)
header = f"{anon_name}"
output_lines.append(header)

View File

@@ -37,7 +37,7 @@ class ImageManager:
self._ensure_image_dir()
self._initialized = True
self._llm = LLMRequest(model=global_config.model.vlm, temperature=0.4, max_tokens=300, request_type="image")
self.vlm = LLMRequest(model=global_config.model.vlm, temperature=0.4, max_tokens=300, request_type="image")
try:
db.connect(reuse_if_open=True)
@@ -94,7 +94,7 @@ class ImageManager:
logger.error(f"保存描述到数据库失败 (Peewee): {str(e)}")
async def get_emoji_description(self, image_base64: str) -> str:
"""获取表情包描述,使用二步走识别并带缓存优化"""
"""获取表情包描述,优先使用Emoji表中的缓存数据"""
try:
# 计算图片哈希
# 确保base64字符串只包含ASCII字符
@@ -104,9 +104,21 @@ class ImageManager:
image_hash = hashlib.md5(image_bytes).hexdigest()
image_format = Image.open(io.BytesIO(image_bytes)).format.lower() # type: ignore
# 查询缓存的描述
# 优先使用EmojiManager查询已注册表情包的描述
try:
from src.chat.emoji_system.emoji_manager import get_emoji_manager
emoji_manager = get_emoji_manager()
cached_emoji_description = await emoji_manager.get_emoji_description_by_hash(image_hash)
if cached_emoji_description:
logger.info(f"[缓存命中] 使用已注册表情包描述: {cached_emoji_description[:50]}...")
return cached_emoji_description
except Exception as e:
logger.debug(f"查询EmojiManager时出错: {e}")
# 查询ImageDescriptions表的缓存描述
cached_description = self._get_description_from_db(image_hash, "emoji")
if cached_description:
logger.info(f"[缓存命中] 使用ImageDescriptions表中的描述: {cached_description[:50]}...")
return f"[表情包:{cached_description}]"
# === 二步走识别流程 ===
@@ -118,10 +130,10 @@ class ImageManager:
logger.warning("GIF转换失败无法获取描述")
return "[表情包(GIF处理失败)]"
vlm_prompt = "这是一个动态图表情包,每一张图代表了动态图的某一帧,黑色背景代表透明,描述一下表情包表达的情感和内容,描述细节,从互联网梗,meme的角度去分析"
detailed_description, _ = await self._llm.generate_response_for_image(vlm_prompt, image_base64_processed, "jpg")
detailed_description, _ = await self.vlm.generate_response_for_image(vlm_prompt, image_base64_processed, "jpg")
else:
vlm_prompt = "这是一个表情包,请详细描述一下表情包所表达的情感和内容,描述细节,从互联网梗,meme的角度去分析"
detailed_description, _ = await self._llm.generate_response_for_image(vlm_prompt, image_base64, image_format)
detailed_description, _ = await self.vlm.generate_response_for_image(vlm_prompt, image_base64, image_format)
if detailed_description is None:
logger.warning("VLM未能生成表情包详细描述")
@@ -158,7 +170,7 @@ class ImageManager:
if len(emotions) > 1 and emotions[1] != emotions[0]:
final_emotion = f"{emotions[0]}{emotions[1]}"
logger.info(f"[二步走识别] 详细描述: {detailed_description[:50]}... -> 情感标签: {final_emotion}")
logger.info(f"[emoji识别] 详细描述: {detailed_description[:50]}... -> 情感标签: {final_emotion}")
# 再次检查缓存,防止并发写入时重复生成
cached_description = self._get_description_from_db(image_hash, "emoji")
@@ -201,13 +213,13 @@ class ImageManager:
self._save_description_to_db(image_hash, final_emotion, "emoji")
return f"[表情包:{final_emotion}]"
except Exception as e:
logger.error(f"获取表情包描述失败: {str(e)}")
return "[表情包]"
return "[表情包(处理失败)]"
async def get_image_description(self, image_base64: str) -> str:
"""获取普通图片描述,带查重和保存功能"""
"""获取普通图片描述,优先使用Images表中的缓存数据"""
try:
# 计算图片哈希
if isinstance(image_base64, str):
@@ -215,7 +227,7 @@ class ImageManager:
image_bytes = base64.b64decode(image_base64)
image_hash = hashlib.md5(image_bytes).hexdigest()
# 检查图片是否已存在
# 优先检查Images表中是否已有完整的描述
existing_image = Images.get_or_none(Images.emoji_hash == image_hash)
if existing_image:
# 更新计数
@@ -227,18 +239,20 @@ class ImageManager:
# 如果已有描述,直接返回
if existing_image.description:
logger.debug(f"[缓存命中] 使用Images表中的图片描述: {existing_image.description[:50]}...")
return f"[图片:{existing_image.description}]"
# 查询缓存描述
# 查询ImageDescriptions表的缓存描述
cached_description = self._get_description_from_db(image_hash, "image")
if cached_description:
logger.debug(f"图片描述缓存中 {cached_description}")
logger.debug(f"[缓存命中] 使用ImageDescriptions表中的描述: {cached_description[:50]}...")
return f"[图片:{cached_description}]"
# 调用AI获取描述
image_format = Image.open(io.BytesIO(image_bytes)).format.lower() # type: ignore
prompt = global_config.custom_prompt.image_prompt
description, _ = await self._llm.generate_response_for_image(prompt, image_base64, image_format)
logger.info(f"[VLM调用] 为图片生成新描述 (Hash: {image_hash[:8]}...)")
description, _ = await self.vlm.generate_response_for_image(prompt, image_base64, image_format)
if description is None:
logger.warning("AI未能生成图片描述")
@@ -266,6 +280,7 @@ class ImageManager:
if not hasattr(existing_image, "vlm_processed") or existing_image.vlm_processed is None:
existing_image.vlm_processed = True
existing_image.save()
logger.debug(f"[数据库] 更新已有图片记录: {image_hash[:8]}...")
else:
Images.create(
image_id=str(uuid.uuid4()),
@@ -277,16 +292,18 @@ class ImageManager:
vlm_processed=True,
count=1,
)
logger.debug(f"[数据库] 创建新图片记录: {image_hash[:8]}...")
except Exception as e:
logger.error(f"保存图片文件或元数据失败: {str(e)}")
# 保存描述到ImageDescriptions表
# 保存描述到ImageDescriptions表作为备用缓存
self._save_description_to_db(image_hash, description, "image")
logger.info(f"[VLM完成] 图片描述生成: {description[:50]}...")
return f"[图片:{description}]"
except Exception as e:
logger.error(f"获取图片描述失败: {str(e)}")
return "[图片]"
return "[图片(处理失败)]"
@staticmethod
def transform_gif(gif_base64: str, similarity_threshold: float = 1000.0, max_frames: int = 15) -> Optional[str]:
@@ -502,12 +519,28 @@ class ImageManager:
image_bytes = base64.b64decode(image_base64)
image_hash = hashlib.md5(image_bytes).hexdigest()
# 先检查缓存的描述
# 获取当前图片记录
image = Images.get(Images.image_id == image_id)
# 优先检查是否已有其他相同哈希的图片记录包含描述
existing_with_description = Images.get_or_none(
(Images.emoji_hash == image_hash) &
(Images.description.is_null(False)) &
(Images.description != "")
)
if existing_with_description and existing_with_description.id != image.id:
logger.debug(f"[缓存复用] 从其他相同图片记录复用描述: {existing_with_description.description[:50]}...")
image.description = existing_with_description.description
image.vlm_processed = True
image.save()
# 同时保存到ImageDescriptions表作为备用缓存
self._save_description_to_db(image_hash, existing_with_description.description, "image")
return
# 检查ImageDescriptions表的缓存描述
cached_description = self._get_description_from_db(image_hash, "image")
if cached_description:
logger.debug(f"VLM处理时发现缓存描述: {cached_description}")
# 更新数据库
image = Images.get(Images.image_id == image_id)
logger.debug(f"[缓存复用] 从ImageDescriptions表复用描述: {cached_description[:50]}...")
image.description = cached_description
image.vlm_processed = True
image.save()
@@ -520,7 +553,8 @@ class ImageManager:
prompt = global_config.custom_prompt.image_prompt
# 获取VLM描述
description, _ = await self._llm.generate_response_for_image(prompt, image_base64, image_format)
logger.info(f"[VLM异步调用] 为图片生成描述 (ID: {image_id}, Hash: {image_hash[:8]}...)")
description, _ = await self.vlm.generate_response_for_image(prompt, image_base64, image_format)
if description is None:
logger.warning("VLM未能生成图片描述")
@@ -533,14 +567,15 @@ class ImageManager:
description = cached_description
# 更新数据库
image = Images.get(Images.image_id == image_id)
image.description = description
image.vlm_processed = True
image.save()
# 保存描述到ImageDescriptions表
# 保存描述到ImageDescriptions表作为备用缓存
self._save_description_to_db(image_hash, description, "image")
logger.info(f"[VLM异步完成] 图片描述生成: {description[:50]}...")
except Exception as e:
logger.error(f"VLM处理图片失败: {str(e)}")

View File

@@ -28,7 +28,7 @@ class ClassicalWillingManager(BaseWillingManager):
# print(f"[{chat_id}] 回复意愿: {current_willing}")
interested_rate = willing_info.interested_rate * global_config.normal_chat.response_interested_rate_amplifier
interested_rate = willing_info.interested_rate
# print(f"[{chat_id}] 兴趣值: {interested_rate}")
@@ -36,20 +36,18 @@ class ClassicalWillingManager(BaseWillingManager):
current_willing += interested_rate - 0.2
if willing_info.is_mentioned_bot and global_config.chat.mentioned_bot_inevitable_reply and current_willing < 2:
current_willing += 1 if current_willing < 1.0 else 0.05
current_willing += 1 if current_willing < 1.0 else 0.2
self.chat_reply_willing[chat_id] = min(current_willing, 1.0)
reply_probability = min(max((current_willing - 0.5), 0.01) * 2, 1)
reply_probability = min(max((current_willing - 0.5), 0.01) * 2, 1.5)
# print(f"[{chat_id}] 回复概率: {reply_probability}")
return reply_probability
async def before_generate_reply_handle(self, message_id):
chat_id = self.ongoing_messages[message_id].chat_id
current_willing = self.chat_reply_willing.get(chat_id, 0)
self.chat_reply_willing[chat_id] = max(0.0, current_willing - 1.8)
pass
async def after_generate_reply_handle(self, message_id):
if message_id not in self.ongoing_messages:
@@ -58,7 +56,7 @@ class ClassicalWillingManager(BaseWillingManager):
chat_id = self.ongoing_messages[message_id].chat_id
current_willing = self.chat_reply_willing.get(chat_id, 0)
if current_willing < 1:
self.chat_reply_willing[chat_id] = min(1.0, current_willing + 0.4)
self.chat_reply_willing[chat_id] = min(1.0, current_willing + 0.3)
async def not_reply_handle(self, message_id):
return await super().not_reply_handle(message_id)

View File

@@ -390,7 +390,7 @@ MODULE_COLORS = {
"tts_action": "\033[38;5;58m", # 深黄色
"doubao_pic_plugin": "\033[38;5;64m", # 深绿色
# Action组件
"no_reply_action": "\033[38;5;196m", # 亮色,显眼
"no_reply_action": "\033[38;5;214m", # 亮色,显眼但不像警告
"reply_action": "\033[38;5;46m", # 亮绿色
"base_action": "\033[38;5;250m", # 浅灰色
# 数据库和消息

View File

@@ -69,6 +69,8 @@ class ChatConfig(ConfigBase):
max_context_size: int = 18
"""上下文长度"""
willing_amplifier: float = 1.0
replyer_random_probability: float = 0.5
"""
@@ -76,15 +78,12 @@ class ChatConfig(ConfigBase):
选择普通模型的概率为 1 - reasoning_normal_model_probability
"""
thinking_timeout: int = 30
thinking_timeout: int = 40
"""麦麦最长思考规划时间超过这个时间的思考会放弃往往是api反应太慢"""
talk_frequency: float = 1
"""回复频率阈值"""
use_s4u_prompt_mode: bool = False
"""是否使用 s4u 对话构建模式,该模式会分开处理当前对话对象和其他所有对话的内容进行 prompt 构建"""
mentioned_bot_inevitable_reply: bool = False
"""提及 bot 必然回复"""
@@ -274,12 +273,6 @@ class NormalChatConfig(ConfigBase):
willing_mode: str = "classical"
"""意愿模式"""
response_interested_rate_amplifier: float = 1.0
"""回复兴趣度放大系数"""
@dataclass
class ExpressionConfig(ConfigBase):
"""表达配置类"""
@@ -307,11 +300,8 @@ class ExpressionConfig(ConfigBase):
class ToolConfig(ConfigBase):
"""工具配置类"""
enable_in_normal_chat: bool = False
"""是否在普通聊天中启用工具"""
enable_in_focus_chat: bool = True
"""是否在专注聊天中启用工具"""
enable_tool: bool = False
"""是否在聊天中启用工具"""
@dataclass
class VoiceConfig(ConfigBase):

View File

@@ -273,15 +273,19 @@ class Individuality:
prompt=prompt,
)
if response.strip():
if response and response.strip():
personality_parts.append(response.strip())
logger.info(f"精简人格侧面: {response.strip()}")
else:
logger.error(f"使用LLM压缩人设时出错: {response}")
# 压缩失败时使用原始内容
if personality_side:
personality_parts.append(personality_side)
if personality_parts:
personality_result = "".join(personality_parts)
else:
personality_result = personality_core
personality_result = personality_core or "友好活泼"
else:
personality_result = personality_core
if personality_side:
@@ -308,13 +312,14 @@ class Individuality:
prompt=prompt,
)
if response.strip():
if response and response.strip():
identity_result = response.strip()
logger.info(f"精简身份: {identity_result}")
else:
logger.error(f"使用LLM压缩身份时出错: {response}")
identity_result = identity
else:
identity_result = "".join(identity)
identity_result = identity
return identity_result

View File

@@ -47,11 +47,35 @@ async def _calculate_interest(message: MessageRecv) -> Tuple[float, bool]:
logger.debug(f"记忆激活率: {interested_rate:.2f}")
text_len = len(message.processed_plain_text)
# 根据文本长度调整兴趣度,长度越大兴趣度越高但增长率递减最低0.01最高0.05
# 采用对数函数实现递减增长
base_interest = 0.01 + (0.05 - 0.01) * (math.log10(text_len + 1) / math.log10(1000 + 1))
base_interest = min(max(base_interest, 0.01), 0.05)
# 根据文本长度分布调整兴趣度,采用分段函数实现更精确的兴趣度计算
# 基于实际分布0-5字符(26.57%), 6-10字符(27.18%), 11-20字符(22.76%), 21-30字符(10.33%), 31+字符(13.86%)
if text_len == 0:
base_interest = 0.01 # 空消息最低兴趣度
elif text_len <= 5:
# 1-5字符线性增长 0.01 -> 0.03
base_interest = 0.01 + (text_len - 1) * (0.03 - 0.01) / 4
elif text_len <= 10:
# 6-10字符线性增长 0.03 -> 0.06
base_interest = 0.03 + (text_len - 5) * (0.06 - 0.03) / 5
elif text_len <= 20:
# 11-20字符线性增长 0.06 -> 0.12
base_interest = 0.06 + (text_len - 10) * (0.12 - 0.06) / 10
elif text_len <= 30:
# 21-30字符线性增长 0.12 -> 0.18
base_interest = 0.12 + (text_len - 20) * (0.18 - 0.12) / 10
elif text_len <= 50:
# 31-50字符线性增长 0.18 -> 0.22
base_interest = 0.18 + (text_len - 30) * (0.22 - 0.18) / 20
elif text_len <= 100:
# 51-100字符线性增长 0.22 -> 0.26
base_interest = 0.22 + (text_len - 50) * (0.26 - 0.22) / 50
else:
# 100+字符:对数增长 0.26 -> 0.3,增长率递减
base_interest = 0.26 + (0.3 - 0.26) * (math.log10(text_len - 99) / math.log10(901)) # 1000-99=901
# 确保在范围内
base_interest = min(max(base_interest, 0.01), 0.3)
interested_rate += base_interest
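# A standalone sanity-check sketch for the piecewise curve above: it repeats the
# same branches in one helper and prints the value at each segment boundary,
# confirming the segments join up and the result stays within [0.01, 0.3].
import math

def base_interest_for(text_len: int) -> float:
    if text_len == 0:
        v = 0.01
    elif text_len <= 5:
        v = 0.01 + (text_len - 1) * (0.03 - 0.01) / 4
    elif text_len <= 10:
        v = 0.03 + (text_len - 5) * (0.06 - 0.03) / 5
    elif text_len <= 20:
        v = 0.06 + (text_len - 10) * (0.12 - 0.06) / 10
    elif text_len <= 30:
        v = 0.12 + (text_len - 20) * (0.18 - 0.12) / 10
    elif text_len <= 50:
        v = 0.18 + (text_len - 30) * (0.22 - 0.18) / 20
    elif text_len <= 100:
        v = 0.22 + (text_len - 50) * (0.26 - 0.22) / 50
    else:
        v = 0.26 + (0.3 - 0.26) * (math.log10(text_len - 99) / math.log10(901))
    return min(max(v, 0.01), 0.3)

for n in (0, 5, 10, 20, 30, 50, 100, 1000):
    print(n, round(base_interest_for(n), 3))  # 0.01, 0.03, 0.06, 0.12, 0.18, 0.22, 0.26, 0.3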

View File

@@ -78,7 +78,7 @@ class ChatMood:
if interested_rate <= 0:
interest_multiplier = 0
else:
interest_multiplier = 3 * math.pow(interested_rate, 0.25)
interest_multiplier = 2 * math.pow(interested_rate, 0.25)
logger.debug(
f"base_probability: {base_probability}, time_multiplier: {time_multiplier}, interest_multiplier: {interest_multiplier}"

View File

@@ -139,7 +139,7 @@ class RelationshipManager:
请用json格式输出引起了你的兴趣或者有什么需要你记忆的点。
并为每个点赋予1-10的权重权重越高表示越重要。
格式如下:
{{
[
{{
"point": "{person_name}想让我记住他的生日我回答确认了他的生日是11月23日",
"weight": 10
@@ -156,13 +156,10 @@ class RelationshipManager:
"point": "{person_name}喜欢吃辣具体来说没有辣的食物ta都不喜欢吃可能是因为ta是湖南人。",
"weight": 7
}}
}}
]
如果没有就输出none,或points为空
{{
"point": "none",
"weight": 0
}}
如果没有就输出none,或返回空数组
[]
"""
# 调用LLM生成印象
@@ -184,17 +181,25 @@ class RelationshipManager:
try:
points = repair_json(points)
points_data = json.loads(points)
if points_data == "none" or not points_data or points_data.get("point") == "none":
# 只处理正确的格式,错误格式直接跳过
if points_data == "none" or not points_data:
points_list = []
elif isinstance(points_data, str) and points_data.lower() == "none":
points_list = []
elif isinstance(points_data, list):
# 正确格式:数组格式 [{"point": "...", "weight": 10}, ...]
if not points_data: # 空数组
points_list = []
else:
points_list = [(item["point"], float(item["weight"]), current_time) for item in points_data]
else:
# logger.info(f"points_data: {points_data}")
if isinstance(points_data, dict) and "points" in points_data:
points_data = points_data["points"]
if not isinstance(points_data, list):
points_data = [points_data]
# 添加可读时间到每个point
points_list = [(item["point"], float(item["weight"]), current_time) for item in points_data]
# 错误格式,直接跳过不解析
logger.warning(f"LLM返回了错误的JSON格式跳过解析: {type(points_data)}, 内容: {points_data}")
points_list = []
# 权重过滤逻辑
if points_list:
original_points_list = list(points_list)
points_list.clear()
discarded_count = 0

View File

@@ -9,6 +9,7 @@ from .base import (
BasePlugin,
BaseAction,
BaseCommand,
BaseTool,
ConfigField,
ComponentType,
ActionActivationType,
@@ -34,6 +35,7 @@ from .utils import (
from .apis import (
chat_api,
tool_api,
component_manage_api,
config_api,
database_api,
@@ -44,17 +46,17 @@ from .apis import (
person_api,
plugin_manage_api,
send_api,
utils_api,
register_plugin,
get_logger,
)
__version__ = "1.0.0"
__version__ = "2.0.0"
__all__ = [
# API 模块
"chat_api",
"tool_api",
"component_manage_api",
"config_api",
"database_api",
@@ -65,13 +67,13 @@ __all__ = [
"person_api",
"plugin_manage_api",
"send_api",
"utils_api",
"register_plugin",
"get_logger",
# 基础类
"BasePlugin",
"BaseAction",
"BaseCommand",
"BaseTool",
"BaseEventHandler",
# 类型定义
"ComponentType",

View File

@@ -17,7 +17,7 @@ from src.plugin_system.apis import (
person_api,
plugin_manage_api,
send_api,
utils_api,
tool_api,
)
from .logging_api import get_logger
from .plugin_register_api import register_plugin
@@ -35,7 +35,7 @@ __all__ = [
"person_api",
"plugin_manage_api",
"send_api",
"utils_api",
"get_logger",
"register_plugin",
"tool_api",
]

View File

@@ -32,6 +32,7 @@ class ChatManager:
@staticmethod
def get_all_streams(platform: Optional[str] | SpecialTypes = "qq") -> List[ChatStream]:
# sourcery skip: for-append-to-extend
"""获取所有聊天流
Args:
@@ -57,6 +58,7 @@ class ChatManager:
@staticmethod
def get_group_streams(platform: Optional[str] | SpecialTypes = "qq") -> List[ChatStream]:
# sourcery skip: for-append-to-extend
"""获取所有群聊聊天流
Args:
@@ -79,6 +81,7 @@ class ChatManager:
@staticmethod
def get_private_streams(platform: Optional[str] | SpecialTypes = "qq") -> List[ChatStream]:
# sourcery skip: for-append-to-extend
"""获取所有私聊聊天流
Args:
@@ -105,7 +108,7 @@ class ChatManager:
@staticmethod
def get_group_stream_by_group_id(
group_id: str, platform: Optional[str] | SpecialTypes = "qq"
) -> Optional[ChatStream]:
) -> Optional[ChatStream]: # sourcery skip: remove-unnecessary-cast
"""根据群ID获取聊天流
Args:
@@ -142,7 +145,7 @@ class ChatManager:
@staticmethod
def get_private_stream_by_user_id(
user_id: str, platform: Optional[str] | SpecialTypes = "qq"
) -> Optional[ChatStream]:
) -> Optional[ChatStream]: # sourcery skip: remove-unnecessary-cast
"""根据用户ID获取私聊流
Args:
@@ -207,7 +210,7 @@ class ChatManager:
chat_stream: 聊天流对象
Returns:
Dict[str, Any]: 聊天流信息字典
Dict ({str: Any}): 聊天流信息字典
Raises:
TypeError: 如果 chat_stream 不是 ChatStream 类型
@@ -282,41 +285,41 @@ class ChatManager:
# =============================================================================
def get_all_streams(platform: Optional[str] | SpecialTypes = "qq"):
def get_all_streams(platform: Optional[str] | SpecialTypes = "qq") -> List[ChatStream]:
"""获取所有聊天流的便捷函数"""
return ChatManager.get_all_streams(platform)
def get_group_streams(platform: Optional[str] | SpecialTypes = "qq"):
def get_group_streams(platform: Optional[str] | SpecialTypes = "qq") -> List[ChatStream]:
"""获取群聊聊天流的便捷函数"""
return ChatManager.get_group_streams(platform)
def get_private_streams(platform: Optional[str] | SpecialTypes = "qq"):
def get_private_streams(platform: Optional[str] | SpecialTypes = "qq") -> List[ChatStream]:
"""获取私聊聊天流的便捷函数"""
return ChatManager.get_private_streams(platform)
def get_stream_by_group_id(group_id: str, platform: Optional[str] | SpecialTypes = "qq"):
def get_stream_by_group_id(group_id: str, platform: Optional[str] | SpecialTypes = "qq") -> Optional[ChatStream]:
"""根据群ID获取聊天流的便捷函数"""
return ChatManager.get_group_stream_by_group_id(group_id, platform)
def get_stream_by_user_id(user_id: str, platform: Optional[str] | SpecialTypes = "qq"):
def get_stream_by_user_id(user_id: str, platform: Optional[str] | SpecialTypes = "qq") -> Optional[ChatStream]:
"""根据用户ID获取私聊流的便捷函数"""
return ChatManager.get_private_stream_by_user_id(user_id, platform)
def get_stream_type(chat_stream: ChatStream):
def get_stream_type(chat_stream: ChatStream) -> str:
"""获取聊天流类型的便捷函数"""
return ChatManager.get_stream_type(chat_stream)
def get_stream_info(chat_stream: ChatStream):
def get_stream_info(chat_stream: ChatStream) -> Dict[str, Any]:
"""获取聊天流信息的便捷函数"""
return ChatManager.get_stream_info(chat_stream)
def get_streams_summary():
def get_streams_summary() -> Dict[str, int]:
"""获取聊天流统计摘要的便捷函数"""
return ChatManager.get_streams_summary()

View File

@@ -5,6 +5,7 @@ from src.plugin_system.base.component_types import (
EventHandlerInfo,
PluginInfo,
ComponentType,
ToolInfo,
)
@@ -119,6 +120,21 @@ def get_registered_command_info(command_name: str) -> Optional[CommandInfo]:
return component_registry.get_registered_command_info(command_name)
def get_registered_tool_info(tool_name: str) -> Optional[ToolInfo]:
"""
获取指定 Tool 的注册信息。
Args:
tool_name (str): Tool 名称。
Returns:
ToolInfo: Tool 信息对象,如果 Tool 不存在则返回 None。
"""
from src.plugin_system.core.component_registry import component_registry
return component_registry.get_registered_tool_info(tool_name)
# === EventHandler 特定查询方法 ===
def get_registered_event_handler_info(
event_handler_name: str,
@@ -191,6 +207,8 @@ def locally_enable_component(component_name: str, component_type: ComponentType,
return global_announcement_manager.enable_specific_chat_action(stream_id, component_name)
case ComponentType.COMMAND:
return global_announcement_manager.enable_specific_chat_command(stream_id, component_name)
case ComponentType.TOOL:
return global_announcement_manager.enable_specific_chat_tool(stream_id, component_name)
case ComponentType.EVENT_HANDLER:
return global_announcement_manager.enable_specific_chat_event_handler(stream_id, component_name)
case _:
@@ -216,11 +234,14 @@ def locally_disable_component(component_name: str, component_type: ComponentType
return global_announcement_manager.disable_specific_chat_action(stream_id, component_name)
case ComponentType.COMMAND:
return global_announcement_manager.disable_specific_chat_command(stream_id, component_name)
case ComponentType.TOOL:
return global_announcement_manager.disable_specific_chat_tool(stream_id, component_name)
case ComponentType.EVENT_HANDLER:
return global_announcement_manager.disable_specific_chat_event_handler(stream_id, component_name)
case _:
raise ValueError(f"未知 component type: {component_type}")
def get_locally_disabled_components(stream_id: str, component_type: ComponentType) -> list[str]:
"""
获取指定消息流中禁用的组件列表。
@@ -239,7 +260,9 @@ def get_locally_disabled_components(stream_id: str, component_type: ComponentTyp
return global_announcement_manager.get_disabled_chat_actions(stream_id)
case ComponentType.COMMAND:
return global_announcement_manager.get_disabled_chat_commands(stream_id)
case ComponentType.TOOL:
return global_announcement_manager.get_disabled_chat_tools(stream_id)
case ComponentType.EVENT_HANDLER:
return global_announcement_manager.get_disabled_chat_event_handlers(stream_id)
case _:
raise ValueError(f"未知 component type: {component_type}")
raise ValueError(f"未知 component type: {component_type}")

View File

@@ -10,7 +10,6 @@
from typing import Any
from src.common.logger import get_logger
from src.config.config import global_config
from src.person_info.person_info import get_person_info_manager
logger = get_logger("config_api")
@@ -26,7 +25,7 @@ def get_global_config(key: str, default: Any = None) -> Any:
插件应使用此方法读取全局配置,以保证只读和隔离性。
Args:
key: 命名空间式配置键名,支持嵌套访问,如 "section.subsection.key",大小写敏感
key: 命名空间式配置键名,使用嵌套访问,如 "section.subsection.key",大小写敏感
default: 如果配置不存在时返回的默认值
Returns:
@@ -76,50 +75,3 @@ def get_plugin_config(plugin_config: dict, key: str, default: Any = None) -> Any
except Exception as e:
logger.warning(f"[ConfigAPI] 获取插件配置 {key} 失败: {e}")
return default
# =============================================================================
# 用户信息API函数
# =============================================================================
async def get_user_id_by_person_name(person_name: str) -> tuple[str, str]:
"""根据内部用户名获取用户ID
Args:
person_name: 用户名
Returns:
tuple[str, str]: (平台, 用户ID)
"""
try:
person_info_manager = get_person_info_manager()
person_id = person_info_manager.get_person_id_by_person_name(person_name)
user_id: str = await person_info_manager.get_value(person_id, "user_id") # type: ignore
platform: str = await person_info_manager.get_value(person_id, "platform") # type: ignore
return platform, user_id
except Exception as e:
logger.error(f"[ConfigAPI] 根据用户名获取用户ID失败: {e}")
return "", ""
async def get_person_info(person_id: str, key: str, default: Any = None) -> Any:
"""获取用户信息
Args:
person_id: 用户ID
key: 信息键名
default: 默认值
Returns:
Any: 用户信息值或默认值
"""
try:
person_info_manager = get_person_info_manager()
response = await person_info_manager.get_value(person_id, key)
if not response:
raise ValueError(f"[ConfigAPI] 获取用户 {person_id} 的信息 '{key}' 失败,返回默认值")
return response
except Exception as e:
logger.error(f"[ConfigAPI] 获取用户信息失败: {e}")
return default

View File

@@ -152,10 +152,7 @@ async def db_query(
except DoesNotExist:
# 记录不存在
if query_type == "get" and single_result:
return None
return []
return None if query_type == "get" and single_result else []
except Exception as e:
logger.error(f"[DatabaseAPI] 数据库操作出错: {e}")
traceback.print_exc()
@@ -170,7 +167,8 @@ async def db_query(
async def db_save(
model_class: Type[Model], data: Dict[str, Any], key_field: Optional[str] = None, key_value: Optional[Any] = None
) -> Union[Dict[str, Any], None]:
) -> Optional[Dict[str, Any]]:
# sourcery skip: inline-immediately-returned-variable
"""保存数据到数据库(创建或更新)
如果提供了key_field和key_value会先尝试查找匹配的记录进行更新
@@ -203,10 +201,9 @@ async def db_save(
try:
# 如果提供了key_field和key_value尝试更新现有记录
if key_field and key_value is not None:
# 查找现有记录
existing_records = list(model_class.select().where(getattr(model_class, key_field) == key_value).limit(1))
if existing_records:
if existing_records := list(
model_class.select().where(getattr(model_class, key_field) == key_value).limit(1)
):
# 更新现有记录
existing_record = existing_records[0]
for field, value in data.items():
@@ -244,8 +241,8 @@ async def db_get(
Args:
model_class: Peewee模型类
filters: 过滤条件,字段名和值的字典
order_by: 排序字段,前缀'-'表示降序,例如'-time'表示按时间字段即time字段降序
limit: 结果数量限制
order_by: 排序字段,前缀'-'表示降序,例如'-time'表示按时间字段即time字段降序
single_result: 是否只返回单个结果如果为True则返回单个记录字典或None否则返回记录字典列表或空列表
Returns:
@@ -310,7 +307,7 @@ async def store_action_info(
thinking_id: str = "",
action_data: Optional[dict] = None,
action_name: str = "",
) -> Union[Dict[str, Any], None]:
) -> Optional[Dict[str, Any]]:
"""存储动作信息到数据库
将Action执行的相关信息保存到ActionRecords表中用于后续的记忆和上下文构建。

View File

@@ -65,14 +65,14 @@ async def get_by_description(description: str) -> Optional[Tuple[str, str, str]]
return None
async def get_random(count: Optional[int] = 1) -> Optional[List[Tuple[str, str, str]]]:
async def get_random(count: Optional[int] = 1) -> List[Tuple[str, str, str]]:
"""随机获取指定数量的表情包
Args:
count: 要获取的表情包数量默认为1
Returns:
Optional[List[Tuple[str, str, str]]]: 包含(base64编码, 表情包描述, 随机情感标签)的元组列表,如果失败则为None
List[Tuple[str, str, str]]: 包含(base64编码, 表情包描述, 随机情感标签)的元组列表,失败则返回空列表
Raises:
TypeError: 如果count不是整数类型
@@ -94,13 +94,13 @@ async def get_random(count: Optional[int] = 1) -> Optional[List[Tuple[str, str,
if not all_emojis:
logger.warning("[EmojiAPI] 没有可用的表情包")
return None
return []
# 过滤有效表情包
valid_emojis = [emoji for emoji in all_emojis if not emoji.is_deleted]
if not valid_emojis:
logger.warning("[EmojiAPI] 没有有效的表情包")
return None
return []
if len(valid_emojis) < count:
logger.warning(
@@ -127,14 +127,14 @@ async def get_random(count: Optional[int] = 1) -> Optional[List[Tuple[str, str,
if not results and count > 0:
logger.warning("[EmojiAPI] 随机获取表情包失败,没有一个可以成功处理")
return None
return []
logger.info(f"[EmojiAPI] 成功获取 {len(results)} 个随机表情包")
return results
except Exception as e:
logger.error(f"[EmojiAPI] 获取随机表情包失败: {e}")
return None
return []
async def get_by_emotion(emotion: str) -> Optional[Tuple[str, str, str]]:
@@ -162,10 +162,11 @@ async def get_by_emotion(emotion: str) -> Optional[Tuple[str, str, str]]:
# 筛选匹配情感的表情包
matching_emojis = []
for emoji_obj in all_emojis:
if not emoji_obj.is_deleted and emotion.lower() in [e.lower() for e in emoji_obj.emotion]:
matching_emojis.append(emoji_obj)
matching_emojis.extend(
emoji_obj
for emoji_obj in all_emojis
if not emoji_obj.is_deleted and emotion.lower() in [e.lower() for e in emoji_obj.emotion]
)
if not matching_emojis:
logger.warning(f"[EmojiAPI] 未找到匹配情感 '{emotion}' 的表情包")
return None
@@ -256,10 +257,11 @@ def get_descriptions() -> List[str]:
emoji_manager = get_emoji_manager()
descriptions = []
for emoji_obj in emoji_manager.emoji_objects:
if not emoji_obj.is_deleted and emoji_obj.description:
descriptions.append(emoji_obj.description)
descriptions.extend(
emoji_obj.description
for emoji_obj in emoji_manager.emoji_objects
if not emoji_obj.is_deleted and emoji_obj.description
)
return descriptions
except Exception as e:
logger.error(f"[EmojiAPI] 获取表情包描述失败: {e}")

View File

@@ -84,18 +84,23 @@ async def generate_reply(
enable_chinese_typo: bool = True,
return_prompt: bool = False,
model_configs: Optional[List[Dict[str, Any]]] = None,
request_type: str = "",
enable_timeout: bool = False,
request_type: str = "generator_api",
) -> Tuple[bool, List[Tuple[str, Any]], Optional[str]]:
"""生成回复
Args:
chat_stream: 聊天流对象(优先)
chat_id: 聊天ID备用
action_data: 动作数据
action_data: 动作数据向下兼容包含reply_to和extra_info
reply_to: 回复对象,格式为 "发送者:消息内容"
extra_info: 额外信息,用于补充上下文
available_actions: 可用动作
enable_tool: 是否启用工具调用
enable_splitter: 是否启用消息分割器
enable_chinese_typo: 是否启用错字生成器
return_prompt: 是否返回提示词
model_configs: 模型配置列表
request_type: 请求类型可选记录LLM使用
Returns:
Tuple[bool, List[Tuple[str, Any]], Optional[str]]: (是否成功, 回复集合, 提示词)
"""
@@ -108,13 +113,16 @@ async def generate_reply(
logger.debug("[GeneratorAPI] 开始生成回复")
if not reply_to and action_data:
reply_to = action_data.get("reply_to", "")
if not extra_info and action_data:
extra_info = action_data.get("extra_info", "")
# 调用回复器生成回复
success, content, prompt = await replyer.generate_reply_with_context(
reply_data=action_data or {},
reply_to=reply_to,
extra_info=extra_info,
available_actions=available_actions,
enable_timeout=enable_timeout,
enable_tool=enable_tool,
)
reply_set = []
@@ -136,6 +144,7 @@ async def generate_reply(
except Exception as e:
logger.error(f"[GeneratorAPI] 生成回复时出错: {e}")
logger.error(traceback.format_exc())
return False, [], None
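# A caller-side sketch for the updated generate_reply parameters documented above:
# reply_to and extra_info are passed explicitly, with action_data kept only as a
# backwards-compatible fallback. The chat_id and texts are made up.
from src.plugin_system.apis import generator_api

async def reply_demo():
    success, reply_set, prompt = await generator_api.generate_reply(
        chat_id="some_stream_id",
        reply_to="Alice:今天天气怎么样?",
        extra_info="用户在询问天气,可以简单回应",
        return_prompt=True,
    )
    if success:
        for reply_type, content in reply_set:
            print(reply_type, content)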
@@ -146,15 +155,24 @@ async def rewrite_reply(
enable_splitter: bool = True,
enable_chinese_typo: bool = True,
model_configs: Optional[List[Dict[str, Any]]] = None,
) -> Tuple[bool, List[Tuple[str, Any]]]:
raw_reply: str = "",
reason: str = "",
reply_to: str = "",
return_prompt: bool = False,
) -> Tuple[bool, List[Tuple[str, Any]], Optional[str]]:
"""重写回复
Args:
chat_stream: 聊天流对象(优先)
reply_data: 回复数据
reply_data: 回复数据字典(向下兼容备用,当其他参数缺失时从此获取)
chat_id: 聊天ID备用
enable_splitter: 是否启用消息分割器
enable_chinese_typo: 是否启用错字生成器
model_configs: 模型配置列表
raw_reply: 原始回复内容
reason: 回复原因
reply_to: 回复对象
return_prompt: 是否返回提示词
Returns:
Tuple[bool, List[Tuple[str, Any]]]: (是否成功, 回复集合)
@@ -164,12 +182,23 @@ async def rewrite_reply(
replyer = get_replyer(chat_stream, chat_id, model_configs=model_configs)
if not replyer:
logger.error("[GeneratorAPI] 无法获取回复器")
return False, []
return False, [], None
logger.info("[GeneratorAPI] 开始重写回复")
# 如果参数缺失从reply_data中获取
if reply_data:
raw_reply = raw_reply or reply_data.get("raw_reply", "")
reason = reason or reply_data.get("reason", "")
reply_to = reply_to or reply_data.get("reply_to", "")
# 调用回复器重写回复
success, content = await replyer.rewrite_reply_with_context(reply_data=reply_data or {})
success, content, prompt = await replyer.rewrite_reply_with_context(
raw_reply=raw_reply,
reason=reason,
reply_to=reply_to,
return_prompt=return_prompt,
)
reply_set = []
if content:
reply_set = await process_human_text(content, enable_splitter, enable_chinese_typo)
@@ -179,14 +208,14 @@ async def rewrite_reply(
else:
logger.warning("[GeneratorAPI] 重写回复失败")
return success, reply_set
return success, reply_set, prompt if return_prompt else None
except ValueError as ve:
raise ve
except Exception as e:
logger.error(f"[GeneratorAPI] 重写回复时出错: {e}")
return False, []
return False, [], None
async def process_human_text(content: str, enable_splitter: bool, enable_chinese_typo: bool) -> List[Tuple[str, Any]]:
@@ -212,3 +241,27 @@ async def process_human_text(content: str, enable_splitter: bool, enable_chinese
except Exception as e:
logger.error(f"[GeneratorAPI] 处理人形文本时出错: {e}")
return []
async def generate_response_custom(
chat_stream: Optional[ChatStream] = None,
chat_id: Optional[str] = None,
model_configs: Optional[List[Dict[str, Any]]] = None,
prompt: str = "",
) -> Optional[str]:
replyer = get_replyer(chat_stream, chat_id, model_configs=model_configs)
if not replyer:
logger.error("[GeneratorAPI] 无法获取回复器")
return None
try:
logger.debug("[GeneratorAPI] 开始生成自定义回复")
response = await replyer.llm_generate_content(prompt)
if response:
logger.debug("[GeneratorAPI] 自定义回复生成成功")
return response
else:
logger.warning("[GeneratorAPI] 自定义回复生成失败")
return None
except Exception as e:
logger.error(f"[GeneratorAPI] 生成自定义回复时出错: {e}")
return None
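# A minimal usage sketch for the new generate_response_custom helper above: it
# runs a raw prompt through the replyer's weighted model selection and returns
# the generated text or None. The chat_id and prompt are illustrative.
from src.plugin_system.apis import generator_api

async def custom_generate_demo() -> str:
    text = await generator_api.generate_response_custom(
        chat_id="some_stream_id",
        prompt="用一句话总结今天群里的讨论",
    )
    return text or ""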

View File

@@ -54,7 +54,7 @@ def get_available_models() -> Dict[str, Any]:
async def generate_with_model(
prompt: str, model_config: Dict[str, Any], request_type: str = "plugin.generate", **kwargs
) -> Tuple[bool, str, str, str]:
) -> Tuple[bool, str]:
"""使用指定模型生成内容
Args:
@@ -73,10 +73,11 @@ async def generate_with_model(
llm_request = LLMRequest(model=model_config, request_type=request_type, **kwargs)
response, (reasoning, model_name) = await llm_request.generate_response_async(prompt)
return True, response, reasoning, model_name
# TODO: 复活这个_
response, _ = await llm_request.generate_response_async(prompt)
return True, response
except Exception as e:
error_msg = f"生成内容时出错: {str(e)}"
logger.error(f"[LLMAPI] {error_msg}")
return False, error_msg, "", ""
return False, error_msg
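# A caller-side sketch for the narrowed return value above: generate_with_model
# now yields (ok, text) instead of (ok, text, reasoning, model_name). The
# "replyer_1" key is an assumption about what get_available_models exposes.
from src.plugin_system.apis import llm_api

async def llm_demo() -> str:
    models = llm_api.get_available_models()
    if not models:
        return ""
    model_config = models.get("replyer_1") or next(iter(models.values()))
    ok, text = await llm_api.generate_with_model("你好,请简单介绍一下自己", model_config)
    return text if ok else ""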

View File

@@ -207,7 +207,7 @@ def get_random_chat_messages(
def get_messages_by_time_for_users(
start_time: float, end_time: float, person_ids: list, limit: int = 0, limit_mode: str = "latest"
start_time: float, end_time: float, person_ids: List[str], limit: int = 0, limit_mode: str = "latest"
) -> List[Dict[str, Any]]:
"""
获取指定用户在所有聊天中指定时间范围内的消息
@@ -287,7 +287,7 @@ def get_messages_before_time_in_chat(
return get_raw_msg_before_timestamp_with_chat(chat_id, timestamp, limit)
def get_messages_before_time_for_users(timestamp: float, person_ids: list, limit: int = 0) -> List[Dict[str, Any]]:
def get_messages_before_time_for_users(timestamp: float, person_ids: List[str], limit: int = 0) -> List[Dict[str, Any]]:
"""
获取指定用户在指定时间戳之前的消息
@@ -372,7 +372,7 @@ def count_new_messages(chat_id: str, start_time: float = 0.0, end_time: Optional
return num_new_messages_since(chat_id, start_time, end_time)
def count_new_messages_for_users(chat_id: str, start_time: float, end_time: float, person_ids: list) -> int:
def count_new_messages_for_users(chat_id: str, start_time: float, end_time: float, person_ids: List[str]) -> int:
"""
计算指定聊天中指定用户从开始时间到结束时间的新消息数量

View File

@@ -1,10 +1,12 @@
from typing import Tuple, List
def list_loaded_plugins() -> List[str]:
"""
列出所有当前加载的插件。
Returns:
list: 当前加载的插件名称列表。
List[str]: 当前加载的插件名称列表。
"""
from src.plugin_system.core.plugin_manager import plugin_manager
@@ -16,17 +18,38 @@ def list_registered_plugins() -> List[str]:
列出所有已注册的插件。
Returns:
list: 已注册的插件名称列表。
List[str]: 已注册的插件名称列表。
"""
from src.plugin_system.core.plugin_manager import plugin_manager
return plugin_manager.list_registered_plugins()
def get_plugin_path(plugin_name: str) -> str:
"""
获取指定插件的路径。
Args:
plugin_name (str): 插件名称。
Returns:
str: 插件目录的绝对路径。
Raises:
ValueError: 如果插件不存在。
"""
from src.plugin_system.core.plugin_manager import plugin_manager
if plugin_path := plugin_manager.get_plugin_path(plugin_name):
return plugin_path
else:
raise ValueError(f"插件 '{plugin_name}' 不存在。")
async def remove_plugin(plugin_name: str) -> bool:
"""
卸载指定的插件。
**此函数是异步的,确保在异步环境中调用。**
Args:
@@ -43,7 +66,7 @@ async def remove_plugin(plugin_name: str) -> bool:
async def reload_plugin(plugin_name: str) -> bool:
"""
重新加载指定的插件。
**此函数是异步的,确保在异步环境中调用。**
Args:
@@ -71,6 +94,7 @@ def load_plugin(plugin_name: str) -> Tuple[bool, int]:
return plugin_manager.load_registered_plugin_classes(plugin_name)
def add_plugin_directory(plugin_directory: str) -> bool:
"""
添加插件目录。
@@ -84,6 +108,7 @@ def add_plugin_directory(plugin_directory: str) -> bool:
return plugin_manager.add_plugin_directory(plugin_directory)
def rescan_plugin_directory() -> Tuple[int, int]:
"""
重新扫描插件目录,加载新插件。
@@ -92,4 +117,4 @@ def rescan_plugin_directory() -> Tuple[int, int]:
"""
from src.plugin_system.core.plugin_manager import plugin_manager
return plugin_manager.rescan_plugin_directory()
return plugin_manager.rescan_plugin_directory()

View File

@@ -22,7 +22,6 @@
import traceback
import time
import difflib
import re
from typing import Optional, Union
from src.common.logger import get_logger
@@ -30,7 +29,7 @@ from src.common.logger import get_logger
from src.chat.message_receive.chat_stream import get_chat_manager
from src.chat.message_receive.uni_message_sender import HeartFCSender
from src.chat.message_receive.message import MessageSending, MessageRecv
from src.chat.utils.chat_message_builder import get_raw_msg_before_timestamp_with_chat
from src.chat.utils.chat_message_builder import get_raw_msg_before_timestamp_with_chat, replace_user_references_async
from src.person_info.person_info import get_person_info_manager
from maim_message import Seg, UserInfo
from src.config.config import global_config
@@ -50,7 +49,7 @@ async def _send_to_target(
display_message: str = "",
typing: bool = False,
reply_to: str = "",
reply_to_platform_id: str = "",
reply_to_platform_id: Optional[str] = None,
storage_message: bool = True,
show_log: bool = True,
) -> bool:
@@ -61,8 +60,11 @@ async def _send_to_target(
content: 消息内容
stream_id: 目标流ID
display_message: 显示消息
typing: 是否显示正在输入
reply_to: 回复消息格式,如"发送者:消息内容"
typing: 是否模拟打字等待。
reply_to: 回复消息格式"发送者:消息内容"
reply_to_platform_id: 回复消息,格式为"平台:用户ID",如果不提供则自动查找(插件开发者禁用!)
storage_message: 是否存储消息到数据库
show_log: 发送是否显示日志
Returns:
bool: 是否发送成功
@@ -98,6 +100,10 @@ async def _send_to_target(
anchor_message = None
if reply_to:
anchor_message = await _find_reply_message(target_stream, reply_to)
if anchor_message and anchor_message.message_info.user_info and not reply_to_platform_id:
reply_to_platform_id = (
f"{anchor_message.message_info.platform}:{anchor_message.message_info.user_info.user_id}"
)
# 构建发送消息对象
bot_message = MessageSending(
@@ -183,32 +189,8 @@ async def _find_reply_message(target_stream, reply_to: str) -> Optional[MessageR
if person_name == sender:
translate_text = message["processed_plain_text"]
# 检查是否有 回复<aaa:bbb> 字段
reply_pattern = r"回复<([^:<>]+):([^:<>]+)>"
if match := re.search(reply_pattern, translate_text):
aaa = match.group(1)
bbb = match.group(2)
reply_person_id = get_person_info_manager().get_person_id(platform, bbb)
reply_person_name = await get_person_info_manager().get_value(reply_person_id, "person_name") or aaa
# 在内容前加上回复信息
translate_text = re.sub(reply_pattern, f"回复 {reply_person_name}", translate_text, count=1)
# 检查是否有 @<aaa:bbb> 字段
at_pattern = r"@<([^:<>]+):([^:<>]+)>"
at_matches = list(re.finditer(at_pattern, translate_text))
if at_matches:
new_content = ""
last_end = 0
for m in at_matches:
new_content += translate_text[last_end : m.start()]
aaa = m.group(1)
bbb = m.group(2)
at_person_id = get_person_info_manager().get_person_id(platform, bbb)
at_person_name = await get_person_info_manager().get_value(at_person_id, "person_name") or aaa
new_content += f"@{at_person_name}"
last_end = m.end()
new_content += translate_text[last_end:]
translate_text = new_content
# 使用独立函数处理用户引用格式
translate_text = await replace_user_references_async(translate_text, platform)
similarity = difflib.SequenceMatcher(None, text, translate_text).ratio()
if similarity >= 0.9:
@@ -287,12 +269,22 @@ async def text_to_stream(
stream_id: 聊天流ID
typing: 是否显示正在输入
reply_to: 回复消息,格式为"发送者:消息内容"
reply_to_platform_id: 回复消息,格式为"平台:用户ID",如果不提供则自动查找(插件开发者禁用!)
storage_message: 是否存储消息到数据库
Returns:
bool: 是否发送成功
"""
return await _send_to_target("text", text, stream_id, "", typing, reply_to, reply_to_platform_id, storage_message)
return await _send_to_target(
"text",
text,
stream_id,
"",
typing,
reply_to,
reply_to_platform_id=reply_to_platform_id,
storage_message=storage_message,
)
async def emoji_to_stream(emoji_base64: str, stream_id: str, storage_message: bool = True) -> bool:
@@ -375,249 +367,3 @@ async def custom_to_stream(
storage_message=storage_message,
show_log=show_log,
)
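# A sketch of sending text to a group with the stream-based API kept here:
# resolve the stream id via get_chat_manager(), then call text_to_stream. This
# covers the same use case as the group/user wrappers below; the platform and
# group id are made up.
from src.chat.message_receive.chat_stream import get_chat_manager
from src.plugin_system.apis import send_api

async def text_to_group_demo(text: str) -> bool:
    stream_id = get_chat_manager().get_stream_id("qq", "123456", True)
    return await send_api.text_to_stream(text, stream_id, typing=False)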
async def text_to_group(
text: str,
group_id: str,
platform: str = "qq",
typing: bool = False,
reply_to: str = "",
storage_message: bool = True,
) -> bool:
"""向群聊发送文本消息
Args:
text: 要发送的文本内容
group_id: 群聊ID
platform: 平台,默认为"qq"
typing: 是否显示正在输入
reply_to: 回复消息,格式为"发送者:消息内容"
Returns:
bool: 是否发送成功
"""
stream_id = get_chat_manager().get_stream_id(platform, group_id, True)
return await _send_to_target("text", text, stream_id, "", typing, reply_to, storage_message=storage_message)
async def text_to_user(
text: str,
user_id: str,
platform: str = "qq",
typing: bool = False,
reply_to: str = "",
storage_message: bool = True,
) -> bool:
"""向用户发送私聊文本消息
Args:
text: 要发送的文本内容
user_id: 用户ID
platform: 平台,默认为"qq"
typing: 是否显示正在输入
reply_to: 回复消息,格式为"发送者:消息内容"
Returns:
bool: 是否发送成功
"""
stream_id = get_chat_manager().get_stream_id(platform, user_id, False)
return await _send_to_target("text", text, stream_id, "", typing, reply_to, storage_message=storage_message)
async def emoji_to_group(emoji_base64: str, group_id: str, platform: str = "qq", storage_message: bool = True) -> bool:
"""向群聊发送表情包
Args:
emoji_base64: 表情包的base64编码
group_id: 群聊ID
platform: 平台,默认为"qq"
Returns:
bool: 是否发送成功
"""
stream_id = get_chat_manager().get_stream_id(platform, group_id, True)
return await _send_to_target("emoji", emoji_base64, stream_id, "", typing=False, storage_message=storage_message)
async def emoji_to_user(emoji_base64: str, user_id: str, platform: str = "qq", storage_message: bool = True) -> bool:
"""向用户发送表情包
Args:
emoji_base64: 表情包的base64编码
user_id: 用户ID
platform: 平台,默认为"qq"
Returns:
bool: 是否发送成功
"""
stream_id = get_chat_manager().get_stream_id(platform, user_id, False)
return await _send_to_target("emoji", emoji_base64, stream_id, "", typing=False, storage_message=storage_message)
async def image_to_group(image_base64: str, group_id: str, platform: str = "qq", storage_message: bool = True) -> bool:
"""向群聊发送图片
Args:
image_base64: 图片的base64编码
group_id: 群聊ID
platform: 平台,默认为"qq"
Returns:
bool: 是否发送成功
"""
stream_id = get_chat_manager().get_stream_id(platform, group_id, True)
return await _send_to_target("image", image_base64, stream_id, "", typing=False, storage_message=storage_message)
async def image_to_user(image_base64: str, user_id: str, platform: str = "qq", storage_message: bool = True) -> bool:
"""向用户发送图片
Args:
image_base64: 图片的base64编码
user_id: 用户ID
platform: 平台,默认为"qq"
Returns:
bool: 是否发送成功
"""
stream_id = get_chat_manager().get_stream_id(platform, user_id, False)
return await _send_to_target("image", image_base64, stream_id, "", typing=False)
async def command_to_group(command: str, group_id: str, platform: str = "qq", storage_message: bool = True) -> bool:
"""向群聊发送命令
Args:
command: 命令
group_id: 群聊ID
platform: 平台,默认为"qq"
Returns:
bool: 是否发送成功
"""
stream_id = get_chat_manager().get_stream_id(platform, group_id, True)
return await _send_to_target("command", command, stream_id, "", typing=False, storage_message=storage_message)
async def command_to_user(command: str, user_id: str, platform: str = "qq", storage_message: bool = True) -> bool:
"""向用户发送命令
Args:
command: 命令
user_id: 用户ID
platform: 平台,默认为"qq"
Returns:
bool: 是否发送成功
"""
stream_id = get_chat_manager().get_stream_id(platform, user_id, False)
return await _send_to_target("command", command, stream_id, "", typing=False, storage_message=storage_message)
# =============================================================================
# 通用发送函数 - 支持任意消息类型
# =============================================================================
async def custom_to_group(
message_type: str,
content: str,
group_id: str,
platform: str = "qq",
display_message: str = "",
typing: bool = False,
reply_to: str = "",
storage_message: bool = True,
) -> bool:
"""向群聊发送自定义类型消息
Args:
message_type: 消息类型,如"text""image""emoji""video""file"
content: 消息内容通常是base64编码或文本
group_id: 群聊ID
platform: 平台,默认为"qq"
display_message: 显示消息
typing: 是否显示正在输入
reply_to: 回复消息,格式为"发送者:消息内容"
Returns:
bool: 是否发送成功
"""
stream_id = get_chat_manager().get_stream_id(platform, group_id, True)
return await _send_to_target(
message_type, content, stream_id, display_message, typing, reply_to, storage_message=storage_message
)
async def custom_to_user(
message_type: str,
content: str,
user_id: str,
platform: str = "qq",
display_message: str = "",
typing: bool = False,
reply_to: str = "",
storage_message: bool = True,
) -> bool:
"""向用户发送自定义类型消息
Args:
message_type: 消息类型,如"text""image""emoji""video""file"
content: 消息内容通常是base64编码或文本
user_id: 用户ID
platform: 平台,默认为"qq"
display_message: 显示消息
typing: 是否显示正在输入
reply_to: 回复消息,格式为"发送者:消息内容"
Returns:
bool: 是否发送成功
"""
stream_id = get_chat_manager().get_stream_id(platform, user_id, False)
return await _send_to_target(
message_type, content, stream_id, display_message, typing, reply_to, storage_message=storage_message
)
async def custom_message(
message_type: str,
content: str,
target_id: str,
is_group: bool = True,
platform: str = "qq",
display_message: str = "",
typing: bool = False,
reply_to: str = "",
storage_message: bool = True,
) -> bool:
"""发送自定义消息的通用接口
Args:
message_type: 消息类型,如"text""image""emoji""video""file""audio"
content: 消息内容
target_id: 目标ID群ID或用户ID
is_group: 是否为群聊True为群聊False为私聊
platform: 平台,默认为"qq"
display_message: 显示消息
typing: 是否显示正在输入
reply_to: 回复消息,格式为"发送者:消息内容"
Returns:
bool: 是否发送成功
示例:
# 发送视频到群聊
await send_api.custom_message("video", video_base64, "123456", True)
# 发送文件到用户
await send_api.custom_message("file", file_base64, "987654", False)
# 发送音频到群聊并回复特定消息
await send_api.custom_message("audio", audio_base64, "123456", True, reply_to="张三:你好")
"""
stream_id = get_chat_manager().get_stream_id(platform, target_id, is_group)
return await _send_to_target(
message_type, content, stream_id, display_message, typing, reply_to, storage_message=storage_message
)

View File

@@ -0,0 +1,27 @@
from typing import Optional, Type
from src.plugin_system.base.base_tool import BaseTool
from src.plugin_system.base.component_types import ComponentType
from src.common.logger import get_logger
logger = get_logger("tool_api")
def get_tool_instance(tool_name: str) -> Optional[BaseTool]:
"""获取公开工具实例"""
from src.plugin_system.core import component_registry
tool_class: Type[BaseTool] = component_registry.get_component_class(tool_name, ComponentType.TOOL) # type: ignore
return tool_class() if tool_class else None
def get_llm_available_tool_definitions():
"""获取LLM可用的工具定义列表
Returns:
List[Tuple[str, Dict[str, Any]]]: 工具定义列表,为[("tool_name", 定义)]
"""
from src.plugin_system.core import component_registry
llm_available_tools = component_registry.get_llm_available_tools()
return [(name, tool_class.get_tool_definition()) for name, tool_class in llm_available_tools.items()]
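# A minimal usage sketch of the two helpers above; the tool name "echo" and its
# keyword argument are illustrative assumptions, not part of this module.
async def _demo_tool_api() -> None:
    # Enumerate (name, OpenAI-style definition) pairs for every LLM-enabled tool.
    for name, definition in get_llm_available_tool_definitions():
        logger.info(f"{name}: {definition['function']['description']}")
    # Instantiate a registered tool by name and call it directly from plugin code.
    if tool := get_tool_instance("echo"):
        result = await tool.direct_execute(text="hello")
        # "content" is the key ToolExecutor conventionally reads from tool results.
        logger.info(result.get("content"))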

View File

@@ -1,168 +0,0 @@
"""工具类API模块
提供了各种辅助功能
使用方式:
from src.plugin_system.apis import utils_api
plugin_path = utils_api.get_plugin_path()
data = utils_api.read_json_file("data.json")
timestamp = utils_api.get_timestamp()
"""
import os
import json
import time
import inspect
import datetime
import uuid
from typing import Any, Optional
from src.common.logger import get_logger
logger = get_logger("utils_api")
# =============================================================================
# 文件操作API函数
# =============================================================================
def get_plugin_path(caller_frame=None) -> str:
"""获取调用者插件的路径
Args:
caller_frame: 调用者的栈帧默认为None自动获取
Returns:
str: 插件目录的绝对路径
"""
try:
if caller_frame is None:
caller_frame = inspect.currentframe().f_back # type: ignore
plugin_module_path = inspect.getfile(caller_frame) # type: ignore
plugin_dir = os.path.dirname(plugin_module_path)
return plugin_dir
except Exception as e:
logger.error(f"[UtilsAPI] 获取插件路径失败: {e}")
return ""
def read_json_file(file_path: str, default: Any = None) -> Any:
"""读取JSON文件
Args:
file_path: 文件路径,可以是相对于插件目录的路径
default: 如果文件不存在或读取失败时返回的默认值
Returns:
Any: JSON数据或默认值
"""
try:
# 如果是相对路径,则相对于调用者的插件目录
if not os.path.isabs(file_path):
caller_frame = inspect.currentframe().f_back # type: ignore
plugin_dir = get_plugin_path(caller_frame)
file_path = os.path.join(plugin_dir, file_path)
if not os.path.exists(file_path):
logger.warning(f"[UtilsAPI] 文件不存在: {file_path}")
return default
with open(file_path, "r", encoding="utf-8") as f:
return json.load(f)
except Exception as e:
logger.error(f"[UtilsAPI] 读取JSON文件出错: {e}")
return default
def write_json_file(file_path: str, data: Any, indent: int = 2) -> bool:
"""写入JSON文件
Args:
file_path: 文件路径,可以是相对于插件目录的路径
data: 要写入的数据
indent: JSON缩进
Returns:
bool: 是否写入成功
"""
try:
# 如果是相对路径,则相对于调用者的插件目录
if not os.path.isabs(file_path):
caller_frame = inspect.currentframe().f_back # type: ignore
plugin_dir = get_plugin_path(caller_frame)
file_path = os.path.join(plugin_dir, file_path)
# 确保目录存在
os.makedirs(os.path.dirname(file_path), exist_ok=True)
with open(file_path, "w", encoding="utf-8") as f:
json.dump(data, f, ensure_ascii=False, indent=indent)
return True
except Exception as e:
logger.error(f"[UtilsAPI] 写入JSON文件出错: {e}")
return False
# =============================================================================
# 时间相关API函数
# =============================================================================
def get_timestamp() -> int:
"""获取当前时间戳
Returns:
int: 当前时间戳(秒)
"""
return int(time.time())
def format_time(timestamp: Optional[int | float] = None, format_str: str = "%Y-%m-%d %H:%M:%S") -> str:
"""格式化时间
Args:
timestamp: 时间戳如果为None则使用当前时间
format_str: 时间格式字符串
Returns:
str: 格式化后的时间字符串
"""
try:
if timestamp is None:
timestamp = time.time()
return datetime.datetime.fromtimestamp(timestamp).strftime(format_str)
except Exception as e:
logger.error(f"[UtilsAPI] 格式化时间失败: {e}")
return ""
def parse_time(time_str: str, format_str: str = "%Y-%m-%d %H:%M:%S") -> int:
"""解析时间字符串为时间戳
Args:
time_str: 时间字符串
format_str: 时间格式字符串
Returns:
int: 时间戳(秒)
"""
try:
dt = datetime.datetime.strptime(time_str, format_str)
return int(dt.timestamp())
except Exception as e:
logger.error(f"[UtilsAPI] 解析时间失败: {e}")
return 0
# =============================================================================
# 其他工具函数
# =============================================================================
def generate_unique_id() -> str:
"""生成唯一ID
Returns:
str: 唯一ID
"""
return str(uuid.uuid4())

View File

@@ -6,6 +6,7 @@
from .base_plugin import BasePlugin
from .base_action import BaseAction
from .base_tool import BaseTool
from .base_command import BaseCommand
from .base_events_handler import BaseEventHandler
from .component_types import (
@@ -15,6 +16,7 @@ from .component_types import (
ComponentInfo,
ActionInfo,
CommandInfo,
ToolInfo,
PluginInfo,
PythonDependency,
EventHandlerInfo,
@@ -27,12 +29,14 @@ __all__ = [
"BasePlugin",
"BaseAction",
"BaseCommand",
"BaseTool",
"ComponentType",
"ActionActivationType",
"ChatMode",
"ComponentInfo",
"ActionInfo",
"CommandInfo",
"ToolInfo",
"PluginInfo",
"PythonDependency",
"ConfigField",

View File

@@ -208,7 +208,7 @@ class BaseAction(ABC):
return False, f"等待新消息失败: {str(e)}"
async def send_text(
self, content: str, reply_to: str = "", reply_to_platform_id: str = "", typing: bool = False
self, content: str, reply_to: str = "", typing: bool = False
) -> bool:
"""发送文本消息
@@ -227,7 +227,6 @@ class BaseAction(ABC):
text=content,
stream_id=self.chat_id,
reply_to=reply_to,
reply_to_platform_id=reply_to_platform_id,
typing=typing,
)
@@ -384,7 +383,7 @@ class BaseAction(ABC):
keyword_case_sensitive=getattr(cls, "keyword_case_sensitive", False),
mode_enable=getattr(cls, "mode_enable", ChatMode.ALL),
parallel_action=getattr(cls, "parallel_action", True),
random_activation_probability=getattr(cls, "random_activation_probability", 0.3),
random_activation_probability=getattr(cls, "random_activation_probability", 0.0),
llm_judge_prompt=getattr(cls, "llm_judge_prompt", ""),
# 使用正确的字段名
action_parameters=getattr(cls, "action_parameters", {}).copy(),

View File

@@ -60,10 +60,10 @@ class BaseCommand(ABC):
pass
def get_config(self, key: str, default=None):
"""获取插件配置值,支持嵌套键访问
"""获取插件配置值,使用嵌套键访问
Args:
key: 配置键名,支持嵌套访问如 "section.subsection.key"
key: 配置键名,使用嵌套访问如 "section.subsection.key"
default: 默认值
Returns:

View File

@@ -0,0 +1,85 @@
from abc import ABC, abstractmethod
from typing import Any, Dict
from rich.traceback import install
from src.common.logger import get_logger
from src.plugin_system.base.component_types import ComponentType, ToolInfo
install(extra_lines=3)
logger = get_logger("base_tool")
class BaseTool(ABC):
"""所有工具的基类"""
name: str = ""
"""工具的名称"""
description: str = ""
"""工具的描述"""
parameters: Dict[str, Any] = {}
"""工具的参数定义"""
available_for_llm: bool = False
"""是否可供LLM使用"""
@classmethod
def get_tool_definition(cls) -> dict[str, Any]:
"""获取工具定义用于LLM工具调用
Returns:
dict: 工具定义字典
"""
if not cls.name or not cls.description or not cls.parameters:
raise NotImplementedError(f"工具类 {cls.__name__} 必须定义 name, description 和 parameters 属性")
return {
"type": "function",
"function": {"name": cls.name, "description": cls.description, "parameters": cls.parameters},
}
@classmethod
def get_tool_info(cls) -> ToolInfo:
"""获取工具信息"""
if not cls.name or not cls.description or not cls.parameters:
raise NotImplementedError(f"工具类 {cls.__name__} 必须定义 name, description 和 parameters 属性")
return ToolInfo(
name=cls.name,
tool_description=cls.description,
enabled=cls.available_for_llm,
tool_parameters=cls.parameters,
component_type=ComponentType.TOOL,
)
@abstractmethod
async def execute(self, function_args: dict[str, Any]) -> dict[str, Any]:
"""执行工具函数(供llm调用)
通过该方法maicore会通过llm的tool call来调用工具
传入的是json格式的参数符合parameters定义的格式
Args:
function_args: 工具调用参数
Returns:
dict: 工具执行结果
"""
raise NotImplementedError("子类必须实现execute方法")
async def direct_execute(self, **function_args: dict[str, Any]) -> dict[str, Any]:
"""直接执行工具函数(供插件调用)
通过该方法,插件可以直接调用工具,而不需要传入字典格式的参数
插件可以直接调用此方法,用更加明了的方式传入参数
示例: result = await tool.direct_execute(arg1="参数",arg2="参数2")
工具开发者可以重写此方法以实现与llm调用差异化的执行逻辑
Args:
**function_args: 工具调用参数
Returns:
dict: 工具执行结果
"""
if self.parameters and (missing := [p for p in self.parameters.get("required", []) if p not in function_args]):
raise ValueError(f"工具类 {self.__class__.__name__} 缺少必要参数: {', '.join(missing)}")
return await self.execute(function_args)
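# A minimal concrete tool sketched against the BaseTool interface above; the name,
# the JSON-schema style of `parameters` (assumed from how get_tool_definition wraps
# it into an OpenAI-style function definition) and the return shape are illustrative.
class EchoTool(BaseTool):
    """Example tool: returns the given text unchanged."""

    name = "echo"
    description = "Return the given text unchanged (demo tool)"
    parameters = {
        "type": "object",
        "properties": {"text": {"type": "string", "description": "text to echo back"}},
        "required": ["text"],
    }
    available_for_llm = True  # once registered, shows up in the LLM-available tool list

    async def execute(self, function_args: dict[str, Any]) -> dict[str, Any]:
        # ToolExecutor reads the "content" field from the returned dict.
        return {"name": self.name, "content": function_args.get("text", "")}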

View File

@@ -10,6 +10,7 @@ class ComponentType(Enum):
ACTION = "action" # 动作组件
COMMAND = "command" # 命令组件
TOOL = "tool" # 服务组件(预留)
SCHEDULER = "scheduler" # 定时任务组件(预留)
EVENT_HANDLER = "event_handler" # 事件处理组件(预留)
@@ -144,7 +145,17 @@ class CommandInfo(ComponentInfo):
def __post_init__(self):
super().__post_init__()
self.component_type = ComponentType.COMMAND
@dataclass
class ToolInfo(ComponentInfo):
"""工具组件信息"""
tool_parameters: Dict[str, Any] = field(default_factory=dict) # 工具参数定义
tool_description: str = "" # 工具描述
def __post_init__(self):
super().__post_init__()
self.component_type = ComponentType.TOOL
@dataclass
class EventHandlerInfo(ComponentInfo):

View File

@@ -6,14 +6,12 @@
from src.plugin_system.core.plugin_manager import plugin_manager
from src.plugin_system.core.component_registry import component_registry
from src.plugin_system.core.dependency_manager import dependency_manager
from src.plugin_system.core.events_manager import events_manager
from src.plugin_system.core.global_announcement_manager import global_announcement_manager
__all__ = [
"plugin_manager",
"component_registry",
"dependency_manager",
"events_manager",
"global_announcement_manager",
]

View File

@@ -6,6 +6,7 @@ from src.common.logger import get_logger
from src.plugin_system.base.component_types import (
ComponentInfo,
ActionInfo,
ToolInfo,
CommandInfo,
EventHandlerInfo,
PluginInfo,
@@ -13,6 +14,7 @@ from src.plugin_system.base.component_types import (
)
from src.plugin_system.base.base_command import BaseCommand
from src.plugin_system.base.base_action import BaseAction
from src.plugin_system.base.base_tool import BaseTool
from src.plugin_system.base.base_events_handler import BaseEventHandler
logger = get_logger("component_registry")
@@ -30,7 +32,7 @@ class ComponentRegistry:
"""组件注册表 命名空间式组件名 -> 组件信息"""
self._components_by_type: Dict[ComponentType, Dict[str, ComponentInfo]] = {types: {} for types in ComponentType}
"""类型 -> 组件原名称 -> 组件信息"""
self._components_classes: Dict[str, Type[Union[BaseCommand, BaseAction, BaseEventHandler]]] = {}
self._components_classes: Dict[str, Type[Union[BaseCommand, BaseAction, BaseTool, BaseEventHandler]]] = {}
"""命名空间式组件名 -> 组件类"""
# 插件注册表
@@ -49,6 +51,10 @@ class ComponentRegistry:
self._command_patterns: Dict[Pattern, str] = {}
"""编译后的正则 -> command名"""
# 工具特定注册表
self._tool_registry: Dict[str, Type[BaseTool]] = {} # 工具名 -> 工具类
self._llm_available_tools: Dict[str, Type[BaseTool]] = {} # llm可用的工具名 -> 工具类
# EventHandler特定注册表
self._event_handler_registry: Dict[str, Type[BaseEventHandler]] = {}
"""event_handler名 -> event_handler类"""
@@ -79,7 +85,9 @@ class ComponentRegistry:
return True
def register_component(
self, component_info: ComponentInfo, component_class: Type[Union[BaseCommand, BaseAction, BaseEventHandler]]
self,
component_info: ComponentInfo,
component_class: Type[Union[BaseCommand, BaseAction, BaseEventHandler, BaseTool]],
) -> bool:
"""注册组件
@@ -125,6 +133,10 @@ class ComponentRegistry:
assert isinstance(component_info, CommandInfo)
assert issubclass(component_class, BaseCommand)
ret = self._register_command_component(component_info, component_class)
case ComponentType.TOOL:
assert isinstance(component_info, ToolInfo)
assert issubclass(component_class, BaseTool)
ret = self._register_tool_component(component_info, component_class)
case ComponentType.EVENT_HANDLER:
assert isinstance(component_info, EventHandlerInfo)
assert issubclass(component_class, BaseEventHandler)
@@ -180,6 +192,18 @@ class ComponentRegistry:
return True
def _register_tool_component(self, tool_info: ToolInfo, tool_class: Type[BaseTool]) -> bool:
"""注册Tool组件到Tool特定注册表"""
tool_name = tool_info.name
self._tool_registry[tool_name] = tool_class
# 如果是llm可用的且启用的工具,添加到 llm可用工具列表
if tool_info.enabled:
self._llm_available_tools[tool_name] = tool_class
return True
def _register_event_handler_component(
self, handler_info: EventHandlerInfo, handler_class: Type[BaseEventHandler]
) -> bool:
@@ -222,6 +246,9 @@ class ComponentRegistry:
keys_to_remove = [k for k, v in self._command_patterns.items() if v == component_name]
for key in keys_to_remove:
self._command_patterns.pop(key)
case ComponentType.TOOL:
self._tool_registry.pop(component_name)
self._llm_available_tools.pop(component_name)
case ComponentType.EVENT_HANDLER:
from .events_manager import events_manager # 延迟导入防止循环导入问题
@@ -234,13 +261,13 @@ class ComponentRegistry:
self._components_classes.pop(namespaced_name)
logger.info(f"组件 {component_name} 已移除")
return True
except KeyError:
logger.warning(f"移除组件时未找到组件: {component_name}")
except KeyError as e:
logger.warning(f"移除组件时未找到组件: {component_name}, 发生错误: {e}")
return False
except Exception as e:
logger.error(f"移除组件 {component_name} 时发生错误: {e}")
return False
def remove_plugin_registry(self, plugin_name: str) -> bool:
"""移除插件注册信息
@@ -281,6 +308,10 @@ class ComponentRegistry:
assert isinstance(target_component_info, CommandInfo)
pattern = target_component_info.command_pattern
self._command_patterns[re.compile(pattern)] = component_name
case ComponentType.TOOL:
assert isinstance(target_component_info, ToolInfo)
assert issubclass(target_component_class, BaseTool)
self._llm_available_tools[component_name] = target_component_class
case ComponentType.EVENT_HANDLER:
assert isinstance(target_component_info, EventHandlerInfo)
assert issubclass(target_component_class, BaseEventHandler)
@@ -308,20 +339,29 @@ class ComponentRegistry:
logger.warning(f"组件 {component_name} 未注册,无法禁用")
return False
target_component_info.enabled = False
match component_type:
case ComponentType.ACTION:
self._default_actions.pop(component_name, None)
case ComponentType.COMMAND:
self._command_patterns = {k: v for k, v in self._command_patterns.items() if v != component_name}
case ComponentType.EVENT_HANDLER:
self._enabled_event_handlers.pop(component_name, None)
from .events_manager import events_manager # 延迟导入防止循环导入问题
try:
match component_type:
case ComponentType.ACTION:
self._default_actions.pop(component_name)
case ComponentType.COMMAND:
self._command_patterns = {k: v for k, v in self._command_patterns.items() if v != component_name}
case ComponentType.TOOL:
self._llm_available_tools.pop(component_name)
case ComponentType.EVENT_HANDLER:
self._enabled_event_handlers.pop(component_name)
from .events_manager import events_manager # 延迟导入防止循环导入问题
await events_manager.unregister_event_subscriber(component_name)
self._components[component_name].enabled = False
self._components_by_type[component_type][component_name].enabled = False
logger.info(f"组件 {component_name} 已禁用")
return True
await events_manager.unregister_event_subscriber(component_name)
self._components[component_name].enabled = False
self._components_by_type[component_type][component_name].enabled = False
logger.info(f"组件 {component_name} 已禁用")
return True
except KeyError as e:
logger.warning(f"禁用组件时未找到组件或已禁用: {component_name}, 发生错误: {e}")
return False
except Exception as e:
logger.error(f"禁用组件 {component_name} 时发生错误: {e}")
return False
# === 组件查询方法 ===
def get_component_info(
@@ -371,7 +411,7 @@ class ComponentRegistry:
self,
component_name: str,
component_type: Optional[ComponentType] = None,
) -> Optional[Union[Type[BaseCommand], Type[BaseAction], Type[BaseEventHandler]]]:
) -> Optional[Union[Type[BaseCommand], Type[BaseAction], Type[BaseEventHandler], Type[BaseTool]]]:
"""获取组件类,支持自动命名空间解析
Args:
@@ -476,6 +516,27 @@ class ComponentRegistry:
command_info,
)
# === Tool 特定查询方法 ===
def get_tool_registry(self) -> Dict[str, Type[BaseTool]]:
"""获取Tool注册表"""
return self._tool_registry.copy()
def get_llm_available_tools(self) -> Dict[str, Type[BaseTool]]:
"""获取LLM可用的Tool列表"""
return self._llm_available_tools.copy()
def get_registered_tool_info(self, tool_name: str) -> Optional[ToolInfo]:
"""获取Tool信息
Args:
tool_name: 工具名称
Returns:
ToolInfo: 工具信息对象,如果工具不存在则返回 None
"""
info = self.get_component_info(tool_name, ComponentType.TOOL)
return info if isinstance(info, ToolInfo) else None
# === EventHandler 特定查询方法 ===
def get_event_handler_registry(self) -> Dict[str, Type[BaseEventHandler]]:
@@ -529,17 +590,21 @@ class ComponentRegistry:
"""获取注册中心统计信息"""
action_components: int = 0
command_components: int = 0
tool_components: int = 0
events_handlers: int = 0
for component in self._components.values():
if component.component_type == ComponentType.ACTION:
action_components += 1
elif component.component_type == ComponentType.COMMAND:
command_components += 1
elif component.component_type == ComponentType.TOOL:
tool_components += 1
elif component.component_type == ComponentType.EVENT_HANDLER:
events_handlers += 1
return {
"action_components": action_components,
"command_components": command_components,
"tool_components": tool_components,
"event_handlers": events_handlers,
"total_components": len(self._components),
"total_plugins": len(self._plugins),

View File

@@ -1,190 +0,0 @@
"""
插件依赖管理器
负责检查和安装插件的Python包依赖
"""
import subprocess
import sys
import importlib
from typing import List, Dict, Tuple, Any
from src.common.logger import get_logger
from src.plugin_system.base.component_types import PythonDependency
logger = get_logger("dependency_manager")
class DependencyManager:
"""依赖管理器"""
def __init__(self):
self.install_log: List[str] = []
self.failed_installs: Dict[str, str] = {}
def check_dependencies(
self, dependencies: List[PythonDependency]
) -> Tuple[List[PythonDependency], List[PythonDependency]]:
"""检查依赖包状态
Args:
dependencies: 依赖包列表
Returns:
Tuple[List[PythonDependency], List[PythonDependency]]: (缺失的依赖, 可选缺失的依赖)
"""
missing_required = []
missing_optional = []
for dep in dependencies:
if self._is_package_available(dep.package_name):
logger.debug(f"依赖包已存在: {dep.package_name}")
elif dep.optional:
missing_optional.append(dep)
logger.warning(f"可选依赖包缺失: {dep.package_name} - {dep.description}")
else:
missing_required.append(dep)
logger.error(f"必需依赖包缺失: {dep.package_name} - {dep.description}")
return missing_required, missing_optional
def _is_package_available(self, package_name: str) -> bool:
"""检查包是否可用"""
try:
importlib.import_module(package_name)
return True
except ImportError:
return False
def install_dependencies(self, dependencies: List[PythonDependency], auto_install: bool = False) -> bool:
"""安装依赖包
Args:
dependencies: 需要安装的依赖包列表
auto_install: 是否自动安装True时不询问用户
Returns:
bool: 安装是否成功
"""
if not dependencies:
return True
logger.info(f"需要安装 {len(dependencies)} 个依赖包")
# 显示将要安装的包
for dep in dependencies:
install_cmd = dep.get_pip_requirement()
logger.info(f" - {install_cmd} {'(可选)' if dep.optional else '(必需)'}")
if dep.description:
logger.info(f" 说明: {dep.description}")
if not auto_install:
# 这里可以添加用户确认逻辑
logger.warning("手动安装模式:请手动运行 pip install 命令安装依赖包")
return False
# 执行安装
success_count = 0
for dep in dependencies:
if self._install_single_package(dep):
success_count += 1
else:
self.failed_installs[dep.package_name] = f"安装失败: {dep.get_pip_requirement()}"
logger.info(f"依赖安装完成: {success_count}/{len(dependencies)} 个成功")
return success_count == len(dependencies)
def _install_single_package(self, dependency: PythonDependency) -> bool:
"""安装单个包"""
pip_requirement = dependency.get_pip_requirement()
try:
logger.info(f"正在安装: {pip_requirement}")
# 使用subprocess安装包
cmd = [sys.executable, "-m", "pip", "install", pip_requirement]
result = subprocess.run(
cmd,
capture_output=True,
text=True,
timeout=300, # 5分钟超时
)
if result.returncode == 0:
logger.info(f"✅ 成功安装: {pip_requirement}")
self.install_log.append(f"成功安装: {pip_requirement}")
return True
else:
logger.error(f"❌ 安装失败: {pip_requirement}")
logger.error(f"错误输出: {result.stderr}")
self.install_log.append(f"安装失败: {pip_requirement} - {result.stderr}")
return False
except subprocess.TimeoutExpired:
logger.error(f"❌ 安装超时: {pip_requirement}")
return False
except Exception as e:
logger.error(f"❌ 安装异常: {pip_requirement} - {str(e)}")
return False
def generate_requirements_file(
self, plugins_dependencies: List[List[PythonDependency]], output_path: str = "plugin_requirements.txt"
) -> bool:
"""生成插件依赖的requirements文件
Args:
plugins_dependencies: 所有插件的依赖列表
output_path: 输出文件路径
Returns:
bool: 生成是否成功
"""
try:
all_deps = {}
# 合并所有插件的依赖
for plugin_deps in plugins_dependencies:
for dep in plugin_deps:
key = dep.install_name
if key in all_deps:
# 如果已存在,可以添加版本兼容性检查逻辑
existing = all_deps[key]
if dep.version and existing.version != dep.version:
logger.warning(f"依赖版本冲突: {key} ({existing.version} vs {dep.version})")
else:
all_deps[key] = dep
# 写入requirements文件
with open(output_path, "w", encoding="utf-8") as f:
f.write("# 插件依赖包自动生成\n")
f.write("# Auto-generated plugin dependencies\n\n")
# 按包名排序
sorted_deps = sorted(all_deps.values(), key=lambda x: x.install_name)
for dep in sorted_deps:
requirement = dep.get_pip_requirement()
if dep.description:
f.write(f"# {dep.description}\n")
if dep.optional:
f.write("# Optional dependency\n")
f.write(f"{requirement}\n\n")
logger.info(f"已生成插件依赖文件: {output_path} ({len(all_deps)} 个包)")
return True
except Exception as e:
logger.error(f"生成requirements文件失败: {str(e)}")
return False
def get_install_summary(self) -> Dict[str, Any]:
"""获取安装摘要"""
return {
"install_log": self.install_log.copy(),
"failed_installs": self.failed_installs.copy(),
"total_attempts": len(self.install_log),
"failed_count": len(self.failed_installs),
}
# 全局依赖管理器实例
dependency_manager = DependencyManager()

View File

@@ -13,6 +13,8 @@ class GlobalAnnouncementManager:
self._user_disabled_commands: Dict[str, List[str]] = {}
# 用户禁用的事件处理器chat_id -> [handler_name]
self._user_disabled_event_handlers: Dict[str, List[str]] = {}
# 用户禁用的工具chat_id -> [tool_name]
self._user_disabled_tools: Dict[str, List[str]] = {}
def disable_specific_chat_action(self, chat_id: str, action_name: str) -> bool:
"""禁用特定聊天的某个动作"""
@@ -77,6 +79,27 @@ class GlobalAnnouncementManager:
return False
return False
def disable_specific_chat_tool(self, chat_id: str, tool_name: str) -> bool:
"""禁用特定聊天的某个工具"""
if chat_id not in self._user_disabled_tools:
self._user_disabled_tools[chat_id] = []
if tool_name in self._user_disabled_tools[chat_id]:
logger.warning(f"工具 {tool_name} 已经被禁用")
return False
self._user_disabled_tools[chat_id].append(tool_name)
return True
def enable_specific_chat_tool(self, chat_id: str, tool_name: str) -> bool:
"""启用特定聊天的某个工具"""
if chat_id in self._user_disabled_tools:
try:
self._user_disabled_tools[chat_id].remove(tool_name)
return True
except ValueError:
logger.warning(f"工具 {tool_name} 不在禁用列表中")
return False
return False
def get_disabled_chat_actions(self, chat_id: str) -> List[str]:
"""获取特定聊天禁用的所有动作"""
return self._user_disabled_actions.get(chat_id, []).copy()
@@ -88,6 +111,10 @@ class GlobalAnnouncementManager:
def get_disabled_chat_event_handlers(self, chat_id: str) -> List[str]:
"""获取特定聊天禁用的所有事件处理器"""
return self._user_disabled_event_handlers.get(chat_id, []).copy()
def get_disabled_chat_tools(self, chat_id: str) -> List[str]:
"""获取特定聊天禁用的所有工具"""
return self._user_disabled_tools.get(chat_id, []).copy()
global_announcement_manager = GlobalAnnouncementManager()
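# A minimal sketch of per-chat tool gating with the singleton above; the chat id
# "chat_123" and tool name "echo" are illustrative placeholders:
#
#     global_announcement_manager.disable_specific_chat_tool("chat_123", "echo")   # -> True
#     global_announcement_manager.get_disabled_chat_tools("chat_123")              # -> ["echo"]
#     global_announcement_manager.enable_specific_chat_tool("chat_123", "echo")    # -> True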

View File

@@ -8,10 +8,9 @@ from pathlib import Path
from src.common.logger import get_logger
from src.plugin_system.base.plugin_base import PluginBase
from src.plugin_system.base.component_types import ComponentType, PythonDependency
from src.plugin_system.base.component_types import ComponentType
from src.plugin_system.utils.manifest_utils import VersionComparator
from .component_registry import component_registry
from .dependency_manager import dependency_manager
logger = get_logger("plugin_manager")
@@ -207,104 +206,6 @@ class PluginManager:
"""
return self.loaded_plugins.get(plugin_name)
def check_all_dependencies(self, auto_install: bool = False) -> Dict[str, Any]:
"""检查所有插件的Python依赖包
Args:
auto_install: 是否自动安装缺失的依赖包
Returns:
Dict[str, any]: 检查结果摘要
"""
logger.info("开始检查所有插件的Python依赖包...")
all_required_missing: List[PythonDependency] = []
all_optional_missing: List[PythonDependency] = []
plugin_status = {}
for plugin_name in self.loaded_plugins:
plugin_info = component_registry.get_plugin_info(plugin_name)
if not plugin_info or not plugin_info.python_dependencies:
plugin_status[plugin_name] = {"status": "no_dependencies", "missing": []}
continue
logger.info(f"检查插件 {plugin_name} 的依赖...")
missing_required, missing_optional = dependency_manager.check_dependencies(plugin_info.python_dependencies)
if missing_required:
all_required_missing.extend(missing_required)
plugin_status[plugin_name] = {
"status": "missing_required",
"missing": [dep.package_name for dep in missing_required],
"optional_missing": [dep.package_name for dep in missing_optional],
}
logger.error(f"插件 {plugin_name} 缺少必需依赖: {[dep.package_name for dep in missing_required]}")
elif missing_optional:
all_optional_missing.extend(missing_optional)
plugin_status[plugin_name] = {
"status": "missing_optional",
"missing": [],
"optional_missing": [dep.package_name for dep in missing_optional],
}
logger.warning(f"插件 {plugin_name} 缺少可选依赖: {[dep.package_name for dep in missing_optional]}")
else:
plugin_status[plugin_name] = {"status": "ok", "missing": []}
logger.info(f"插件 {plugin_name} 依赖检查通过")
# 汇总结果
total_missing = len({dep.package_name for dep in all_required_missing})
total_optional_missing = len({dep.package_name for dep in all_optional_missing})
logger.info(f"依赖检查完成 - 缺少必需包: {total_missing}个, 缺少可选包: {total_optional_missing}")
# 如果需要自动安装
install_success = True
if auto_install and all_required_missing:
unique_required = {dep.package_name: dep for dep in all_required_missing}
logger.info(f"开始自动安装 {len(unique_required)} 个必需依赖包...")
install_success = dependency_manager.install_dependencies(list(unique_required.values()), auto_install=True)
return {
"total_plugins_checked": len(plugin_status),
"plugins_with_missing_required": len(
[p for p in plugin_status.values() if p["status"] == "missing_required"]
),
"plugins_with_missing_optional": len(
[p for p in plugin_status.values() if p["status"] == "missing_optional"]
),
"total_missing_required": total_missing,
"total_missing_optional": total_optional_missing,
"plugin_status": plugin_status,
"auto_install_attempted": auto_install and bool(all_required_missing),
"auto_install_success": install_success,
"install_summary": dependency_manager.get_install_summary(),
}
def generate_plugin_requirements(self, output_path: str = "plugin_requirements.txt") -> bool:
"""生成所有插件依赖的requirements文件
Args:
output_path: 输出文件路径
Returns:
bool: 生成是否成功
"""
logger.info("开始生成插件依赖requirements文件...")
all_dependencies = []
for plugin_name in self.loaded_plugins:
plugin_info = component_registry.get_plugin_info(plugin_name)
if plugin_info and plugin_info.python_dependencies:
all_dependencies.append(plugin_info.python_dependencies)
if not all_dependencies:
logger.info("没有找到任何插件依赖")
return False
return dependency_manager.generate_requirements_file(all_dependencies, output_path)
# === 查询方法 ===
def list_loaded_plugins(self) -> List[str]:
"""
@@ -323,6 +224,18 @@ class PluginManager:
list: 已注册的插件类名称列表。
"""
return list(self.plugin_classes.keys())
def get_plugin_path(self, plugin_name: str) -> Optional[str]:
"""
获取指定插件的路径。
Args:
plugin_name: 插件名称
Returns:
Optional[str]: 插件目录的绝对路径如果插件不存在则返回None。
"""
return self.plugin_paths.get(plugin_name)
# === 私有方法 ===
# == 目录管理 ==
@@ -388,6 +301,7 @@ class PluginManager:
return False
module = module_from_spec(spec)
module.__package__ = module_name # 设置模块包名
spec.loader.exec_module(module)
logger.debug(f"插件模块加载成功: {plugin_file}")
@@ -444,6 +358,7 @@ class PluginManager:
stats = component_registry.get_registry_stats()
action_count = stats.get("action_components", 0)
command_count = stats.get("command_components", 0)
tool_count = stats.get("tool_components", 0)
event_handler_count = stats.get("event_handlers", 0)
total_components = stats.get("total_components", 0)
@@ -451,7 +366,7 @@ class PluginManager:
if total_registered > 0:
logger.info("🎉 插件系统加载完成!")
logger.info(
f"📊 总览: {total_registered}个插件, {total_components}个组件 (Action: {action_count}, Command: {command_count}, EventHandler: {event_handler_count})"
f"📊 总览: {total_registered}个插件, {total_components}个组件 (Action: {action_count}, Command: {command_count}, Tool: {tool_count}, EventHandler: {event_handler_count})"
)
# 显示详细的插件列表
@@ -486,6 +401,9 @@ class PluginManager:
command_components = [
c for c in plugin_info.components if c.component_type == ComponentType.COMMAND
]
tool_components = [
c for c in plugin_info.components if c.component_type == ComponentType.TOOL
]
event_handler_components = [
c for c in plugin_info.components if c.component_type == ComponentType.EVENT_HANDLER
]
@@ -497,7 +415,9 @@ class PluginManager:
if command_components:
command_names = [c.name for c in command_components]
logger.info(f" ⚡ Command组件: {', '.join(command_names)}")
if tool_components:
tool_names = [c.name for c in tool_components]
logger.info(f" 🛠️ Tool组件: {', '.join(tool_names)}")
if event_handler_components:
event_handler_names = [c.name for c in event_handler_components]
logger.info(f" 📢 EventHandler组件: {', '.join(event_handler_names)}")

View File

@@ -1,14 +1,16 @@
import json
import time
from typing import List, Dict, Tuple, Optional, Any
from src.plugin_system.apis.tool_api import get_llm_available_tool_definitions, get_tool_instance
from src.plugin_system.core.global_announcement_manager import global_announcement_manager
from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config
import time
from src.common.logger import get_logger
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.tools.tool_use import ToolUser
from src.chat.utils.json_utils import process_llm_tool_calls
from typing import List, Dict, Tuple, Optional
from src.chat.message_receive.chat_stream import get_chat_manager
from src.common.logger import get_logger
logger = get_logger("tool_executor")
logger = get_logger("tool_use")
def init_tool_executor_prompt():
@@ -28,6 +30,10 @@ If you need to use a tool, please directly call the corresponding tool function.
Prompt(tool_executor_prompt, "tool_executor_prompt")
# 初始化提示词
init_tool_executor_prompt()
class ToolExecutor:
"""独立的工具执行器组件
@@ -51,9 +57,6 @@ class ToolExecutor:
request_type="tool_executor",
)
# 初始化工具实例
self.tool_instance = ToolUser()
# 缓存配置
self.enable_cache = enable_cache
self.cache_ttl = cache_ttl
@@ -73,7 +76,7 @@ class ToolExecutor:
return_details: 是否返回详细信息(使用的工具列表和提示词)
Returns:
如果return_details为False: List[Dict] - 工具执行结果列表
如果return_details为False: Tuple[List[Dict], List[str], str] - (工具执行结果列表, , )
如果return_details为True: Tuple[List[Dict], List[str], str] - (结果列表, 使用的工具, 提示词)
"""
@@ -82,15 +85,15 @@ class ToolExecutor:
if cached_result := self._get_from_cache(cache_key):
logger.info(f"{self.log_prefix}使用缓存结果,跳过工具执行")
if not return_details:
return cached_result, [], "使用缓存结果"
return cached_result, [], ""
# 从缓存结果中提取工具名称
used_tools = [result.get("tool_name", "unknown") for result in cached_result]
return cached_result, used_tools, "使用缓存结果"
return cached_result, used_tools, ""
# 缓存未命中,执行工具调用
# 获取可用工具
tools = self.tool_instance._define_tools()
tools = self._get_tool_definitions()
# 获取当前时间
time_now = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
@@ -112,6 +115,7 @@ class ToolExecutor:
# 调用LLM进行工具决策
response, other_info = await self.llm_model.generate_response_async(prompt=prompt, tools=tools)
# TODO: 在APIADA加入后完全修复这里
# 解析LLM响应
if len(other_info) == 3:
reasoning_content, model_name, tool_calls = other_info
@@ -133,6 +137,11 @@ class ToolExecutor:
return tool_results, used_tools, prompt
else:
return tool_results, [], ""
def _get_tool_definitions(self) -> List[Dict[str, Any]]:
all_tools = get_llm_available_tool_definitions()
user_disabled_tools = global_announcement_manager.get_disabled_chat_tools(self.chat_id)
return [parameters for name, parameters in all_tools if name not in user_disabled_tools]
async def _execute_tool_calls(self, tool_calls) -> Tuple[List[Dict], List[str]]:
"""执行工具调用
@@ -172,7 +181,7 @@ class ToolExecutor:
logger.debug(f"{self.log_prefix}执行工具: {tool_name}")
# 执行工具
result = await self.tool_instance._execute_tool_call(tool_call)
result = await self._execute_tool_call(tool_call)
if result:
tool_info = {
@@ -205,6 +214,45 @@ class ToolExecutor:
return tool_results, used_tools
async def _execute_tool_call(self, tool_call: Dict[str, Any]) -> Optional[Dict]:
# sourcery skip: use-assigned-variable
"""执行单个工具调用
Args:
tool_call: 工具调用对象
Returns:
Optional[Dict]: 工具调用结果如果失败则返回None
"""
try:
function_name = tool_call["function"]["name"]
function_args = json.loads(tool_call["function"]["arguments"])
function_args["llm_called"] = True # 标记为LLM调用
# 获取对应工具实例
tool_instance = get_tool_instance(function_name)
if not tool_instance:
logger.warning(f"未知工具名称: {function_name}")
return None
# 执行工具
result = await tool_instance.execute(function_args)
if result:
# 直接使用 function_name 作为 tool_type
tool_type = function_name
return {
"tool_call_id": tool_call["id"],
"role": "tool",
"name": function_name,
"type": tool_type,
"content": result["content"],
}
return None
except Exception as e:
logger.error(f"执行工具调用时发生错误: {str(e)}")
return None
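# Shape of a single tool_call consumed by _execute_tool_call above, assumed from the
# OpenAI-style fields it reads; the id, name and arguments are illustrative:
#
#     {"id": "call_0", "function": {"name": "echo", "arguments": '{"text": "hi"}'}}
#
# On success it returns:
#     {"tool_call_id": "call_0", "role": "tool", "name": "echo", "type": "echo",
#      "content": <tool result content>}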
def _generate_cache_key(self, target_message: str, chat_history: str, sender: str) -> str:
"""生成缓存键
@@ -272,15 +320,6 @@ class ToolExecutor:
if expired_keys:
logger.debug(f"{self.log_prefix}清理了{len(expired_keys)}个过期缓存")
def get_available_tools(self) -> List[str]:
"""获取可用工具列表
Returns:
List[str]: 可用工具名称列表
"""
tools = self.tool_instance._define_tools()
return [tool.get("function", {}).get("name", "unknown") for tool in tools]
async def execute_specific_tool(
self, tool_name: str, tool_args: Dict, validate_args: bool = True
) -> Optional[Dict]:
@@ -299,7 +338,7 @@ class ToolExecutor:
logger.info(f"{self.log_prefix}直接执行工具: {tool_name}")
result = await self.tool_instance._execute_tool_call(tool_call)
result = await self._execute_tool_call(tool_call)
if result:
tool_info = {
@@ -366,12 +405,8 @@ class ToolExecutor:
logger.info(f"{self.log_prefix}缓存TTL修改为: {cache_ttl}")
# 初始化提示词
init_tool_executor_prompt()
"""
使用示例
ToolExecutor使用示例
# 1. 基础使用 - 从聊天消息执行工具启用缓存默认TTL=3
executor = ToolExecutor(executor_id="my_executor")
@@ -400,7 +435,6 @@ result = await executor.execute_specific_tool(
)
# 6. 缓存管理
available_tools = executor.get_available_tools()
cache_status = executor.get_cache_status() # 查看缓存状态
executor.clear_cache() # 清空缓存
executor.set_cache_config(cache_ttl=5) # 动态修改缓存配置

View File

@@ -163,12 +163,11 @@ class VersionComparator:
version_normalized, max_normalized
)
if is_compatible:
logger.info(f"版本兼容性检查:{compat_msg}")
return True, compat_msg
else:
if not is_compatible:
return False, f"版本 {version_normalized} 高于最大支持版本 {max_normalized},且无兼容性映射"
logger.info(f"版本兼容性检查:{compat_msg}")
return True, compat_msg
return True, ""
@staticmethod
@@ -358,14 +357,10 @@ class ManifestValidator:
if self.validation_errors:
report.append("❌ 验证错误:")
for error in self.validation_errors:
report.append(f" - {error}")
report.extend(f" - {error}" for error in self.validation_errors)
if self.validation_warnings:
report.append("⚠️ 验证警告:")
for warning in self.validation_warnings:
report.append(f" - {warning}")
report.extend(f" - {warning}" for warning in self.validation_warnings)
if not self.validation_errors and not self.validation_warnings:
report.append("✅ Manifest文件验证通过")

View File

@@ -24,11 +24,6 @@
"is_built_in": true,
"plugin_type": "action_provider",
"components": [
{
"type": "action",
"name": "reply",
"description": "参与聊天回复,发送文本进行表达"
},
{
"type": "action",
"name": "no_reply",

View File

@@ -9,7 +9,8 @@ from src.common.logger import get_logger
# 导入API模块 - 标准Python包方式
from src.plugin_system.apis import emoji_api, llm_api, message_api
from src.plugins.built_in.core_actions.no_reply import NoReplyAction
# 注释不再需要导入NoReplyAction因为计数器管理已移至heartFC_chat.py
# from src.plugins.built_in.core_actions.no_reply import NoReplyAction
from src.config.config import global_config
@@ -20,10 +21,14 @@ class EmojiAction(BaseAction):
"""表情动作 - 发送表情包"""
# 激活设置
activation_type = ActionActivationType.RANDOM
if global_config.emoji.emoji_activate_type == "llm":
activation_type = ActionActivationType.LLM_JUDGE
random_activation_probability = 0
else:
activation_type = ActionActivationType.RANDOM
random_activation_probability = global_config.emoji.emoji_chance
mode_enable = ChatMode.ALL
parallel_action = True
random_activation_probability = 0.2 # 默认值,可通过配置覆盖
# 动作基本信息
action_name = "emoji"
@@ -115,7 +120,7 @@ class EmojiAction(BaseAction):
logger.error(f"{self.log_prefix} 未找到'utils_small'模型配置无法调用LLM")
return False, "未找到'utils_small'模型配置"
success, chosen_emotion, _, _ = await llm_api.generate_with_model(
success, chosen_emotion = await llm_api.generate_with_model(
prompt, model_config=chat_model_config, request_type="emoji"
)
@@ -143,8 +148,8 @@ class EmojiAction(BaseAction):
logger.error(f"{self.log_prefix} 表情包发送失败")
return False, "表情包发送失败"
# 重置NoReplyAction的连续计数器
NoReplyAction.reset_consecutive_count()
# 注释:重置NoReplyAction的连续计数器现在由heartFC_chat.py统一管理
# NoReplyAction.reset_consecutive_count()
return True, f"发送表情包: {emoji_description}"

View File

@@ -1,6 +1,7 @@
import random
import time
from typing import Tuple
from typing import Tuple, List
from collections import deque
# 导入新插件系统
from src.plugin_system import BaseAction, ActionActivationType, ChatMode
@@ -17,11 +18,15 @@ logger = get_logger("no_reply_action")
class NoReplyAction(BaseAction):
"""不回复动作,根据新消息的兴趣值或数量决定何时结束等待.
"""不回复动作,支持waiting和breaking两种形式.
新的等待逻辑:
1. 新消息累计兴趣值超过阈值 (默认10) 则结束等待
2. 累计新消息数量达到随机阈值 (默认5-10条) 则结束等待
waiting形式:
- 只要有新消息就结束动作
- 记录新消息的兴趣度到列表(最多保留最近三项)
- 如果最近三次动作都是no_reply且最近新消息列表兴趣度之和小于阈值就进入breaking形式
breaking形式:
- 和原有逻辑一致,需要消息满足一定数量或累计一定兴趣值才结束动作
"""
focus_activation_type = ActionActivationType.NEVER
@@ -35,112 +40,45 @@ class NoReplyAction(BaseAction):
# 连续no_reply计数器
_consecutive_count = 0
# 最近三次no_reply的新消息兴趣度记录
_recent_interest_records: deque = deque(maxlen=3)
# 新增:兴趣值退出阈值
# 兴趣值退出阈值
_interest_exit_threshold = 3.0
# 新增:消息数量退出阈值
_min_exit_message_count = 5
_max_exit_message_count = 10
# 消息数量退出阈值
_min_exit_message_count = 3
_max_exit_message_count = 6
# 动作参数定义
action_parameters = {}
# 动作使用场景
action_require = ["你发送了消息,目前无人回复"]
action_require = [""]
# 关联类型
associated_types = []
async def execute(self) -> Tuple[bool, str]:
"""执行不回复动作"""
import asyncio
try:
# 增加连续计数
NoReplyAction._consecutive_count += 1
count = NoReplyAction._consecutive_count
reason = self.action_data.get("reason", "")
start_time = self.action_data.get("loop_start_time", time.time())
check_interval = 0.6 # 每秒检查一次
check_interval = 0.6
# 随机生成本次等待需要的新消息数量阈值
exit_message_count_threshold = random.randint(self._min_exit_message_count, self._max_exit_message_count)
logger.info(
f"{self.log_prefix} 本次no_reply需要 {exit_message_count_threshold} 条新消息或累计兴趣值超过 {self._interest_exit_threshold} 才能打断"
)
# 判断使用哪种形式
form_type = self._determine_form_type()
logger.info(f"{self.log_prefix} 选择不回复(第{NoReplyAction._consecutive_count + 1}次),使用{form_type}形式,原因: {reason}")
logger.info(f"{self.log_prefix} 选择不回复(第{count}次),开始摸鱼,原因: {reason}")
# 增加连续计数在确定要执行no_reply时才增加
NoReplyAction._consecutive_count += 1
# 进入等待状态
while True:
current_time = time.time()
elapsed_time = current_time - start_time
# 1. 检查新消息
recent_messages_dict = message_api.get_messages_by_time_in_chat(
chat_id=self.chat_id,
start_time=start_time,
end_time=current_time,
filter_mai=True,
filter_command=True,
)
new_message_count = len(recent_messages_dict)
# 2. 检查消息数量是否达到阈值
talk_frequency = global_config.chat.get_current_talk_frequency(self.chat_id)
if new_message_count >= exit_message_count_threshold / talk_frequency:
logger.info(
f"{self.log_prefix} 累计消息数量达到{new_message_count}条(>{exit_message_count_threshold / talk_frequency}),结束等待"
)
exit_reason = f"{global_config.bot.nickname}(你)看到了{new_message_count}条新消息,可以考虑一下是否要进行回复"
await self.store_action_info(
action_build_into_prompt=False,
action_prompt_display=exit_reason,
action_done=True,
)
return True, f"累计消息数量达到{new_message_count}条,结束等待 (等待时间: {elapsed_time:.1f}秒)"
# 3. 检查累计兴趣值
if new_message_count > 0:
accumulated_interest = 0.0
for msg_dict in recent_messages_dict:
text = msg_dict.get("processed_plain_text", "")
interest_value = msg_dict.get("interest_value", 0.0)
if text:
accumulated_interest += interest_value
talk_frequency = global_config.chat.get_current_talk_frequency(self.chat_id)
# 只在兴趣值变化时输出log
if not hasattr(self, "_last_accumulated_interest") or accumulated_interest != self._last_accumulated_interest:
logger.info(f"{self.log_prefix} 当前累计兴趣值: {accumulated_interest:.2f}, 当前聊天频率: {talk_frequency:.2f}")
self._last_accumulated_interest = accumulated_interest
if accumulated_interest >= self._interest_exit_threshold / talk_frequency:
logger.info(
f"{self.log_prefix} 累计兴趣值达到{accumulated_interest:.2f}(>{self._interest_exit_threshold / talk_frequency}),结束等待"
)
exit_reason = f"{global_config.bot.nickname}(你)感觉到了大家浓厚的兴趣(兴趣值{accumulated_interest:.1f}),决定重新加入讨论"
await self.store_action_info(
action_build_into_prompt=False,
action_prompt_display=exit_reason,
action_done=True,
)
return (
True,
f"累计兴趣值达到{accumulated_interest:.2f},结束等待 (等待时间: {elapsed_time:.1f}秒)",
)
# 每10秒输出一次等待状态
if int(elapsed_time) > 0 and int(elapsed_time) % 10 == 0:
logger.debug(
f"{self.log_prefix} 已等待{elapsed_time:.0f}秒,累计{new_message_count}条消息,继续等待..."
)
# 使用 asyncio.sleep(1) 来避免在同一秒内重复打印日志
await asyncio.sleep(1)
# 短暂等待后继续检查
await asyncio.sleep(check_interval)
if form_type == "waiting":
return await self._execute_waiting_form(start_time, check_interval)
else:
return await self._execute_breaking_form(start_time, check_interval)
except Exception as e:
logger.error(f"{self.log_prefix} 不回复动作执行失败: {e}")
@@ -153,8 +91,191 @@ class NoReplyAction(BaseAction):
)
return False, f"不回复动作执行失败: {e}"
def _determine_form_type(self) -> str:
"""判断使用哪种形式的no_reply"""
# 如果连续no_reply次数少于3次使用waiting形式
if NoReplyAction._consecutive_count < 3:
return "waiting"
# 如果最近三次记录不足使用waiting形式
if len(NoReplyAction._recent_interest_records) < 3:
return "waiting"
# 计算最近三次记录的兴趣度总和
total_recent_interest = sum(NoReplyAction._recent_interest_records)
# 获取当前聊天频率和意愿系数
talk_frequency = global_config.chat.get_current_talk_frequency(self.chat_id)
willing_amplifier = global_config.chat.willing_amplifier
# 计算调整后的阈值
adjusted_threshold = self._interest_exit_threshold / talk_frequency / willing_amplifier
logger.info(f"{self.log_prefix} 最近三次兴趣度总和: {total_recent_interest:.2f}, 调整后阈值: {adjusted_threshold:.2f}")
# 如果兴趣度总和小于阈值进入breaking形式
if total_recent_interest < adjusted_threshold:
logger.info(f"{self.log_prefix} 兴趣度不足进入breaking形式")
return "breaking"
else:
logger.info(f"{self.log_prefix} 兴趣度充足继续使用waiting形式")
return "waiting"
async def _execute_waiting_form(self, start_time: float, check_interval: float) -> Tuple[bool, str]:
"""执行waiting形式的no_reply"""
import asyncio
logger.info(f"{self.log_prefix} 进入waiting形式等待任何新消息")
while True:
current_time = time.time()
elapsed_time = current_time - start_time
# 检查新消息
recent_messages_dict = message_api.get_messages_by_time_in_chat(
chat_id=self.chat_id,
start_time=start_time,
end_time=current_time,
filter_mai=True,
filter_command=True,
)
new_message_count = len(recent_messages_dict)
# waiting形式只要有新消息就结束
if new_message_count > 0:
# 计算新消息的总兴趣度
total_interest = 0.0
for msg_dict in recent_messages_dict:
interest_value = msg_dict.get("interest_value", 0.0)
if msg_dict.get("processed_plain_text", ""):
total_interest += interest_value * global_config.chat.willing_amplifier
# 记录到最近兴趣度列表
NoReplyAction._recent_interest_records.append(total_interest)
logger.info(
f"{self.log_prefix} waiting形式检测到{new_message_count}条新消息,总兴趣度: {total_interest:.2f},结束等待"
)
exit_reason = f"{global_config.bot.nickname}(你)看到了{new_message_count}条新消息,可以考虑一下是否要进行回复"
await self.store_action_info(
action_build_into_prompt=False,
action_prompt_display=exit_reason,
action_done=True,
)
return True, f"waiting形式检测到{new_message_count}条新消息,结束等待 (等待时间: {elapsed_time:.1f}秒)"
# 每10秒输出一次等待状态
if int(elapsed_time) > 0 and int(elapsed_time) % 10 == 0:
logger.debug(f"{self.log_prefix} waiting形式已等待{elapsed_time:.0f}秒,继续等待新消息...")
await asyncio.sleep(1)
# 短暂等待后继续检查
await asyncio.sleep(check_interval)
async def _execute_breaking_form(self, start_time: float, check_interval: float) -> Tuple[bool, str]:
"""执行breaking形式的no_reply原有逻辑"""
import asyncio
# 随机生成本次等待需要的新消息数量阈值
exit_message_count_threshold = random.randint(self._min_exit_message_count, self._max_exit_message_count)
logger.info(f"{self.log_prefix} 进入breaking形式需要{exit_message_count_threshold}条消息或足够兴趣度")
while True:
current_time = time.time()
elapsed_time = current_time - start_time
# 检查新消息
recent_messages_dict = message_api.get_messages_by_time_in_chat(
chat_id=self.chat_id,
start_time=start_time,
end_time=current_time,
filter_mai=True,
filter_command=True,
)
new_message_count = len(recent_messages_dict)
# 检查消息数量是否达到阈值
talk_frequency = global_config.chat.get_current_talk_frequency(self.chat_id)
modified_exit_count_threshold = (exit_message_count_threshold / talk_frequency) / global_config.chat.willing_amplifier
if new_message_count >= modified_exit_count_threshold:
# 记录兴趣度到列表
total_interest = 0.0
for msg_dict in recent_messages_dict:
interest_value = msg_dict.get("interest_value", 0.0)
if msg_dict.get("processed_plain_text", ""):
total_interest += interest_value * global_config.chat.willing_amplifier
NoReplyAction._recent_interest_records.append(total_interest)
logger.info(
f"{self.log_prefix} breaking形式累计消息数量达到{new_message_count}条(>{modified_exit_count_threshold}),结束等待"
)
exit_reason = f"{global_config.bot.nickname}(你)看到了{new_message_count}条新消息,可以考虑一下是否要进行回复"
await self.store_action_info(
action_build_into_prompt=False,
action_prompt_display=exit_reason,
action_done=True,
)
return True, f"breaking形式累计消息数量达到{new_message_count}条,结束等待 (等待时间: {elapsed_time:.1f}秒)"
# 检查累计兴趣值
if new_message_count > 0:
accumulated_interest = 0.0
for msg_dict in recent_messages_dict:
text = msg_dict.get("processed_plain_text", "")
interest_value = msg_dict.get("interest_value", 0.0)
if text:
accumulated_interest += interest_value * global_config.chat.willing_amplifier
# 只在兴趣值变化时输出log
if not hasattr(self, "_last_accumulated_interest") or accumulated_interest != self._last_accumulated_interest:
logger.info(f"{self.log_prefix} breaking形式当前累计兴趣值: {accumulated_interest:.2f}, 当前聊天频率: {talk_frequency:.2f}")
self._last_accumulated_interest = accumulated_interest
if accumulated_interest >= self._interest_exit_threshold / talk_frequency:
# 记录兴趣度到列表
NoReplyAction._recent_interest_records.append(accumulated_interest)
logger.info(
f"{self.log_prefix} breaking形式累计兴趣值达到{accumulated_interest:.2f}(>{self._interest_exit_threshold / talk_frequency}),结束等待"
)
exit_reason = f"{global_config.bot.nickname}(你)感觉到了大家浓厚的兴趣(兴趣值{accumulated_interest:.1f}),决定重新加入讨论"
await self.store_action_info(
action_build_into_prompt=False,
action_prompt_display=exit_reason,
action_done=True,
)
return (
True,
f"breaking形式累计兴趣值达到{accumulated_interest:.2f},结束等待 (等待时间: {elapsed_time:.1f}秒)",
)
# 每10秒输出一次等待状态
if int(elapsed_time) > 0 and int(elapsed_time) % 10 == 0:
logger.debug(
f"{self.log_prefix} breaking形式已等待{elapsed_time:.0f}秒,累计{new_message_count}条消息,继续等待..."
)
await asyncio.sleep(1)
# 短暂等待后继续检查
await asyncio.sleep(check_interval)
@classmethod
def reset_consecutive_count(cls):
"""重置连续计数器"""
"""重置连续计数器和兴趣度记录"""
cls._consecutive_count = 0
logger.debug("NoReplyAction连续计数器已重置")
cls._recent_interest_records.clear()
logger.debug("NoReplyAction连续计数器和兴趣度记录已重置")
@classmethod
def get_recent_interest_records(cls) -> List[float]:
"""获取最近的兴趣度记录"""
return list(cls._recent_interest_records)
@classmethod
def get_consecutive_count(cls) -> int:
"""获取连续计数"""
return cls._consecutive_count
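# Worked example of the waiting/breaking decision in _determine_form_type above
# (all numbers are illustrative): with at least three prior no_reply rounds and
# recent interest records [0.4, 0.6, 0.8], the total is 1.8; with talk_frequency = 1.0
# and willing_amplifier = 1.0 the adjusted threshold is 3.0 / 1.0 / 1.0 = 3.0, so
# 1.8 < 3.0 and the next round runs in "breaking" form.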

View File

@@ -8,9 +8,8 @@
from typing import List, Tuple, Type
# 导入新插件系统
from src.plugin_system import BasePlugin, register_plugin, ComponentInfo, ActionActivationType
from src.plugin_system import BasePlugin, register_plugin, ComponentInfo
from src.plugin_system.base.config_types import ConfigField
from src.config.config import global_config
# 导入依赖的系统组件
from src.common.logger import get_logger
@@ -18,7 +17,6 @@ from src.common.logger import get_logger
# 导入API模块 - 标准Python包方式
from src.plugins.built_in.core_actions.no_reply import NoReplyAction
from src.plugins.built_in.core_actions.emoji import EmojiAction
from src.plugins.built_in.core_actions.reply import ReplyAction
logger = get_logger("core_actions")
@@ -52,10 +50,9 @@ class CoreActionsPlugin(BasePlugin):
config_schema: dict = {
"plugin": {
"enabled": ConfigField(type=bool, default=True, description="是否启用插件"),
"config_version": ConfigField(type=str, default="0.4.0", description="配置文件版本"),
"config_version": ConfigField(type=str, default="0.5.0", description="配置文件版本"),
},
"components": {
"enable_reply": ConfigField(type=bool, default=True, description="是否启用回复动作"),
"enable_no_reply": ConfigField(type=bool, default=True, description="是否启用不回复动作"),
"enable_emoji": ConfigField(type=bool, default=True, description="是否启用发送表情/图片动作"),
},
@@ -64,24 +61,12 @@ class CoreActionsPlugin(BasePlugin):
def get_plugin_components(self) -> List[Tuple[ComponentInfo, Type]]:
"""返回插件包含的组件列表"""
if global_config.emoji.emoji_activate_type == "llm":
EmojiAction.random_activation_probability = 0.0
EmojiAction.focus_activation_type = ActionActivationType.LLM_JUDGE
EmojiAction.normal_activation_type = ActionActivationType.LLM_JUDGE
elif global_config.emoji.emoji_activate_type == "random":
EmojiAction.random_activation_probability = global_config.emoji.emoji_chance
EmojiAction.focus_activation_type = ActionActivationType.RANDOM
EmojiAction.normal_activation_type = ActionActivationType.RANDOM
# --- 根据配置注册组件 ---
components = []
if self.get_config("components.enable_reply", True):
components.append((ReplyAction.get_action_info(), ReplyAction))
if self.get_config("components.enable_no_reply", True):
components.append((NoReplyAction.get_action_info(), NoReplyAction))
if self.get_config("components.enable_emoji", True):
components.append((EmojiAction.get_action_info(), EmojiAction))
# components.append((DeepReplyAction.get_action_info(), DeepReplyAction))
return components

View File

@@ -1,149 +0,0 @@
# 导入新插件系统
from src.plugin_system import BaseAction, ActionActivationType, ChatMode
from src.config.config import global_config
import random
import time
from typing import Tuple
import asyncio
import re
import traceback
# 导入依赖的系统组件
from src.common.logger import get_logger
# 导入API模块 - 标准Python包方式
from src.plugin_system.apis import generator_api, message_api
from src.plugins.built_in.core_actions.no_reply import NoReplyAction
from src.person_info.person_info import get_person_info_manager
from src.mais4u.mai_think import mai_thinking_manager
from src.mais4u.constant_s4u import ENABLE_S4U
logger = get_logger("reply_action")
class ReplyAction(BaseAction):
"""回复动作 - 参与聊天回复"""
# 激活设置
focus_activation_type = ActionActivationType.NEVER
normal_activation_type = ActionActivationType.NEVER
mode_enable = ChatMode.FOCUS
parallel_action = False
# 动作基本信息
action_name = "reply"
action_description = ""
# 动作参数定义
action_parameters = {}
# 动作使用场景
action_require = [""]
# 关联类型
associated_types = ["text"]
def _parse_reply_target(self, target_message: str) -> tuple:
sender = ""
target = ""
# 添加None检查防止NoneType错误
if target_message is None:
return sender, target
if ":" in target_message or "" in target_message:
# 使用正则表达式匹配中文或英文冒号
parts = re.split(pattern=r"[:]", string=target_message, maxsplit=1)
if len(parts) == 2:
sender = parts[0].strip()
target = parts[1].strip()
return sender, target
async def execute(self) -> Tuple[bool, str]:
"""执行回复动作"""
logger.debug(f"{self.log_prefix} 决定进行回复")
start_time = self.action_data.get("loop_start_time", time.time())
user_id = self.user_id
platform = self.platform
# logger.info(f"{self.log_prefix} 用户ID: {user_id}, 平台: {platform}")
person_id = get_person_info_manager().get_person_id(platform, user_id) # type: ignore
# logger.info(f"{self.log_prefix} 人物ID: {person_id}")
person_name = get_person_info_manager().get_value_sync(person_id, "person_name")
reply_to = f"{person_name}:{self.action_message.get('processed_plain_text', '')}" # type: ignore
logger.info(f"{self.log_prefix} 决定进行回复,目标: {reply_to}")
try:
if prepared_reply := self.action_data.get("prepared_reply", ""):
reply_text = prepared_reply
else:
try:
success, reply_set, _ = await asyncio.wait_for(
generator_api.generate_reply(
extra_info="",
reply_to=reply_to,
chat_id=self.chat_id,
request_type="chat.replyer.focus",
enable_tool=global_config.tool.enable_in_focus_chat,
),
timeout=global_config.chat.thinking_timeout,
)
except asyncio.TimeoutError:
logger.warning(f"{self.log_prefix} 回复生成超时 ({global_config.chat.thinking_timeout}s)")
return False, "timeout"
# 检查从start_time以来的新消息数量
# 获取动作触发时间或使用默认值
current_time = time.time()
new_message_count = message_api.count_new_messages(
chat_id=self.chat_id, start_time=start_time, end_time=current_time
)
# 根据新消息数量决定是否使用reply_to
need_reply = new_message_count >= random.randint(2, 4)
if need_reply:
logger.info(
f"{self.log_prefix} 从思考到回复,共有{new_message_count}条新消息,使用引用回复"
)
else:
logger.debug(
f"{self.log_prefix} 从思考到回复,共有{new_message_count}条新消息,不使用引用回复"
)
# 构建回复文本
reply_text = ""
first_replied = False
reply_to_platform_id = f"{platform}:{user_id}"
for reply_seg in reply_set:
data = reply_seg[1]
if not first_replied:
if need_reply:
await self.send_text(
content=data, reply_to=reply_to, reply_to_platform_id=reply_to_platform_id, typing=False
)
else:
await self.send_text(content=data, reply_to_platform_id=reply_to_platform_id, typing=False)
first_replied = True
else:
await self.send_text(content=data, reply_to_platform_id=reply_to_platform_id, typing=True)
reply_text += data
# 存储动作记录
reply_text = f"你对{person_name}进行了回复:{reply_text}"
if ENABLE_S4U:
await mai_thinking_manager.get_mai_think(self.chat_id).do_think_after_response(reply_text)
await self.store_action_info(
action_build_into_prompt=False,
action_prompt_display=reply_text,
action_done=True,
)
# 重置NoReplyAction的连续计数器
NoReplyAction.reset_consecutive_count()
return success, reply_text
except Exception as e:
logger.error(f"{self.log_prefix} 回复动作执行失败: {e}")
traceback.print_exc()
return False, f"回复失败: {str(e)}"

View File

@@ -1,4 +1,4 @@
from src.tools.tool_can_use.base_tool import BaseTool
from src.plugin_system.base.base_tool import BaseTool
from src.chat.utils.utils import get_embedding
from src.common.database.database_model import Knowledges # Updated import
from src.common.logger import get_logger
@@ -77,7 +77,7 @@ class SearchKnowledgeTool(BaseTool):
Union[str, list]: 格式化的信息字符串或原始结果列表
"""
if not query_embedding:
return "" if not return_raw else []
return [] if return_raw else ""
similar_items = []
try:
@@ -115,10 +115,10 @@ class SearchKnowledgeTool(BaseTool):
except Exception as e:
logger.error(f"从 Peewee 数据库获取知识信息失败: {str(e)}")
return "" if not return_raw else []
return [] if return_raw else ""
if not results:
return "" if not return_raw else []
return [] if return_raw else ""
if return_raw:
# Peewee 模型实例不能直接序列化为 JSON，如果需要原始模型，调用者需要处理

View File

@@ -1,4 +1,4 @@
from src.tools.tool_can_use.base_tool import BaseTool
from src.plugin_system.base.base_tool import BaseTool
# from src.common.database import db
from src.common.logger import get_logger

View File

@@ -9,7 +9,7 @@
},
"license": "GPL-v3.0-or-later",
"host_application": {
"min_version": "0.9.0"
"min_version": "0.9.1"
},
"homepage_url": "https://github.com/MaiM-with-u/maibot",
"repository_url": "https://github.com/MaiM-with-u/maibot",

View File

@@ -11,6 +11,7 @@ from src.plugin_system import (
component_manage_api,
ComponentInfo,
ComponentType,
send_api,
)
@@ -27,8 +28,15 @@ class ManagementCommand(BaseCommand):
or not self.message.message_info.user_info
or str(self.message.message_info.user_info.user_id) not in self.get_config("plugin.permission", []) # type: ignore
):
await self.send_text("你没有权限使用插件管理命令")
await self._send_message("你没有权限使用插件管理命令")
return False, "没有权限", True
if not self.message.chat_stream:
await self._send_message("无法获取聊天流信息")
return False, "无法获取聊天流信息", True
self.stream_id = self.message.chat_stream.stream_id
if not self.stream_id:
await self._send_message("无法获取聊天流信息")
return False, "无法获取聊天流信息", True
command_list = self.matched_groups["manage_command"].strip().split(" ")
if len(command_list) == 1:
await self.show_help("all")
@@ -42,7 +50,7 @@ class ManagementCommand(BaseCommand):
case "help":
await self.show_help("all")
case _:
await self.send_text("插件管理命令不合法")
await self._send_message("插件管理命令不合法")
return False, "命令不合法", True
if len(command_list) == 3:
if command_list[1] == "plugin":
@@ -56,7 +64,7 @@ class ManagementCommand(BaseCommand):
case "rescan":
await self._rescan_plugin_dirs()
case _:
await self.send_text("插件管理命令不合法")
await self._send_message("插件管理命令不合法")
return False, "命令不合法", True
elif command_list[1] == "component":
if command_list[2] == "list":
@@ -64,10 +72,10 @@ class ManagementCommand(BaseCommand):
elif command_list[2] == "help":
await self.show_help("component")
else:
await self.send_text("插件管理命令不合法")
await self._send_message("插件管理命令不合法")
return False, "命令不合法", True
else:
await self.send_text("插件管理命令不合法")
await self._send_message("插件管理命令不合法")
return False, "命令不合法", True
if len(command_list) == 4:
if command_list[1] == "plugin":
@@ -81,28 +89,28 @@ class ManagementCommand(BaseCommand):
case "add_dir":
await self._add_dir(command_list[3])
case _:
await self.send_text("插件管理命令不合法")
await self._send_message("插件管理命令不合法")
return False, "命令不合法", True
elif command_list[1] == "component":
if command_list[2] != "list":
await self.send_text("插件管理命令不合法")
await self._send_message("插件管理命令不合法")
return False, "命令不合法", True
if command_list[3] == "enabled":
await self._list_enabled_components()
elif command_list[3] == "disabled":
await self._list_disabled_components()
else:
await self.send_text("插件管理命令不合法")
await self._send_message("插件管理命令不合法")
return False, "命令不合法", True
else:
await self.send_text("插件管理命令不合法")
await self._send_message("插件管理命令不合法")
return False, "命令不合法", True
if len(command_list) == 5:
if command_list[1] != "component":
await self.send_text("插件管理命令不合法")
await self._send_message("插件管理命令不合法")
return False, "命令不合法", True
if command_list[2] != "list":
await self.send_text("插件管理命令不合法")
await self._send_message("插件管理命令不合法")
return False, "命令不合法", True
if command_list[3] == "enabled":
await self._list_enabled_components(target_type=command_list[4])
@@ -111,11 +119,11 @@ class ManagementCommand(BaseCommand):
elif command_list[3] == "type":
await self._list_registered_components_by_type(command_list[4])
else:
await self.send_text("插件管理命令不合法")
await self._send_message("插件管理命令不合法")
return False, "命令不合法", True
if len(command_list) == 6:
if command_list[1] != "component":
await self.send_text("插件管理命令不合法")
await self._send_message("插件管理命令不合法")
return False, "命令不合法", True
if command_list[2] == "enable":
if command_list[3] == "global":
@@ -123,7 +131,7 @@ class ManagementCommand(BaseCommand):
elif command_list[3] == "local":
await self._locally_enable_component(command_list[4], command_list[5])
else:
await self.send_text("插件管理命令不合法")
await self._send_message("插件管理命令不合法")
return False, "命令不合法", True
elif command_list[2] == "disable":
if command_list[3] == "global":
@@ -131,10 +139,10 @@ class ManagementCommand(BaseCommand):
elif command_list[3] == "local":
await self._locally_disable_component(command_list[4], command_list[5])
else:
await self.send_text("插件管理命令不合法")
await self._send_message("插件管理命令不合法")
return False, "命令不合法", True
else:
await self.send_text("插件管理命令不合法")
await self._send_message("插件管理命令不合法")
return False, "命令不合法", True
return True, "命令执行完成", True
@@ -180,51 +188,51 @@ class ManagementCommand(BaseCommand):
)
case _:
return
await self.send_text(help_msg)
await self._send_message(help_msg)
async def _list_loaded_plugins(self):
plugins = plugin_manage_api.list_loaded_plugins()
await self.send_text(f"已加载的插件: {', '.join(plugins)}")
await self._send_message(f"已加载的插件: {', '.join(plugins)}")
async def _list_registered_plugins(self):
plugins = plugin_manage_api.list_registered_plugins()
await self.send_text(f"已注册的插件: {', '.join(plugins)}")
await self._send_message(f"已注册的插件: {', '.join(plugins)}")
async def _rescan_plugin_dirs(self):
plugin_manage_api.rescan_plugin_directory()
await self.send_text("插件目录重新扫描执行中")
await self._send_message("插件目录重新扫描执行中")
async def _load_plugin(self, plugin_name: str):
success, count = plugin_manage_api.load_plugin(plugin_name)
if success:
await self.send_text(f"插件加载成功: {plugin_name}")
await self._send_message(f"插件加载成功: {plugin_name}")
else:
if count == 0:
await self.send_text(f"插件{plugin_name}为禁用状态")
await self.send_text(f"插件加载失败: {plugin_name}")
await self._send_message(f"插件{plugin_name}为禁用状态")
await self._send_message(f"插件加载失败: {plugin_name}")
async def _unload_plugin(self, plugin_name: str):
success = await plugin_manage_api.remove_plugin(plugin_name)
if success:
await self.send_text(f"插件卸载成功: {plugin_name}")
await self._send_message(f"插件卸载成功: {plugin_name}")
else:
await self.send_text(f"插件卸载失败: {plugin_name}")
await self._send_message(f"插件卸载失败: {plugin_name}")
async def _reload_plugin(self, plugin_name: str):
success = await plugin_manage_api.reload_plugin(plugin_name)
if success:
await self.send_text(f"插件重新加载成功: {plugin_name}")
await self._send_message(f"插件重新加载成功: {plugin_name}")
else:
await self.send_text(f"插件重新加载失败: {plugin_name}")
await self._send_message(f"插件重新加载失败: {plugin_name}")
async def _add_dir(self, dir_path: str):
await self.send_text(f"正在添加插件目录: {dir_path}")
await self._send_message(f"正在添加插件目录: {dir_path}")
success = plugin_manage_api.add_plugin_directory(dir_path)
await asyncio.sleep(0.5) # 防止乱序发送
if success:
await self.send_text(f"插件目录添加成功: {dir_path}")
await self._send_message(f"插件目录添加成功: {dir_path}")
else:
await self.send_text(f"插件目录添加失败: {dir_path}")
await self._send_message(f"插件目录添加失败: {dir_path}")
def _fetch_all_registered_components(self) -> List[ComponentInfo]:
all_plugin_info = component_manage_api.get_all_plugin_info()
@@ -255,29 +263,29 @@ class ManagementCommand(BaseCommand):
async def _list_all_registered_components(self):
components_info = self._fetch_all_registered_components()
if not components_info:
await self.send_text("没有注册的组件")
await self._send_message("没有注册的组件")
return
all_components_str = ", ".join(
f"{component.name} ({component.component_type})" for component in components_info
)
await self.send_text(f"已注册的组件: {all_components_str}")
await self._send_message(f"已注册的组件: {all_components_str}")
async def _list_enabled_components(self, target_type: str = "global"):
components_info = self._fetch_all_registered_components()
if not components_info:
await self.send_text("没有注册的组件")
await self._send_message("没有注册的组件")
return
if target_type == "global":
enabled_components = [component for component in components_info if component.enabled]
if not enabled_components:
await self.send_text("没有满足条件的已启用全局组件")
await self._send_message("没有满足条件的已启用全局组件")
return
enabled_components_str = ", ".join(
f"{component.name} ({component.component_type})" for component in enabled_components
)
await self.send_text(f"满足条件的已启用全局组件: {enabled_components_str}")
await self._send_message(f"满足条件的已启用全局组件: {enabled_components_str}")
elif target_type == "local":
locally_disabled_components = self._fetch_locally_disabled_components()
enabled_components = [
@@ -286,28 +294,28 @@ class ManagementCommand(BaseCommand):
if (component.name not in locally_disabled_components and component.enabled)
]
if not enabled_components:
await self.send_text("本聊天没有满足条件的已启用组件")
await self._send_message("本聊天没有满足条件的已启用组件")
return
enabled_components_str = ", ".join(
f"{component.name} ({component.component_type})" for component in enabled_components
)
await self.send_text(f"本聊天满足条件的已启用组件: {enabled_components_str}")
await self._send_message(f"本聊天满足条件的已启用组件: {enabled_components_str}")
async def _list_disabled_components(self, target_type: str = "global"):
components_info = self._fetch_all_registered_components()
if not components_info:
await self.send_text("没有注册的组件")
await self._send_message("没有注册的组件")
return
if target_type == "global":
disabled_components = [component for component in components_info if not component.enabled]
if not disabled_components:
await self.send_text("没有满足条件的已禁用全局组件")
await self._send_message("没有满足条件的已禁用全局组件")
return
disabled_components_str = ", ".join(
f"{component.name} ({component.component_type})" for component in disabled_components
)
await self.send_text(f"满足条件的已禁用全局组件: {disabled_components_str}")
await self._send_message(f"满足条件的已禁用全局组件: {disabled_components_str}")
elif target_type == "local":
locally_disabled_components = self._fetch_locally_disabled_components()
disabled_components = [
@@ -316,12 +324,12 @@ class ManagementCommand(BaseCommand):
if (component.name in locally_disabled_components or not component.enabled)
]
if not disabled_components:
await self.send_text("本聊天没有满足条件的已禁用组件")
await self._send_message("本聊天没有满足条件的已禁用组件")
return
disabled_components_str = ", ".join(
f"{component.name} ({component.component_type})" for component in disabled_components
)
await self.send_text(f"本聊天满足条件的已禁用组件: {disabled_components_str}")
await self._send_message(f"本聊天满足条件的已禁用组件: {disabled_components_str}")
async def _list_registered_components_by_type(self, target_type: str):
match target_type:
@@ -332,18 +340,18 @@ class ManagementCommand(BaseCommand):
case "event_handler":
component_type = ComponentType.EVENT_HANDLER
case _:
await self.send_text(f"未知组件类型: {target_type}")
await self._send_message(f"未知组件类型: {target_type}")
return
components_info = component_manage_api.get_components_info_by_type(component_type)
if not components_info:
await self.send_text(f"没有注册的 {target_type} 组件")
await self._send_message(f"没有注册的 {target_type} 组件")
return
components_str = ", ".join(
f"{name} ({component.component_type})" for name, component in components_info.items()
)
await self.send_text(f"注册的 {target_type} 组件: {components_str}")
await self._send_message(f"注册的 {target_type} 组件: {components_str}")
async def _globally_enable_component(self, component_name: str, component_type: str):
match component_type:
@@ -354,12 +362,12 @@ class ManagementCommand(BaseCommand):
case "event_handler":
target_component_type = ComponentType.EVENT_HANDLER
case _:
await self.send_text(f"未知组件类型: {component_type}")
await self._send_message(f"未知组件类型: {component_type}")
return
if component_manage_api.globally_enable_component(component_name, target_component_type):
await self.send_text(f"全局启用组件成功: {component_name}")
await self._send_message(f"全局启用组件成功: {component_name}")
else:
await self.send_text(f"全局启用组件失败: {component_name}")
await self._send_message(f"全局启用组件失败: {component_name}")
async def _globally_disable_component(self, component_name: str, component_type: str):
match component_type:
@@ -370,13 +378,13 @@ class ManagementCommand(BaseCommand):
case "event_handler":
target_component_type = ComponentType.EVENT_HANDLER
case _:
await self.send_text(f"未知组件类型: {component_type}")
await self._send_message(f"未知组件类型: {component_type}")
return
success = await component_manage_api.globally_disable_component(component_name, target_component_type)
if success:
await self.send_text(f"全局禁用组件成功: {component_name}")
await self._send_message(f"全局禁用组件成功: {component_name}")
else:
await self.send_text(f"全局禁用组件失败: {component_name}")
await self._send_message(f"全局禁用组件失败: {component_name}")
async def _locally_enable_component(self, component_name: str, component_type: str):
match component_type:
@@ -387,16 +395,16 @@ class ManagementCommand(BaseCommand):
case "event_handler":
target_component_type = ComponentType.EVENT_HANDLER
case _:
await self.send_text(f"未知组件类型: {component_type}")
await self._send_message(f"未知组件类型: {component_type}")
return
if component_manage_api.locally_enable_component(
component_name,
target_component_type,
self.message.chat_stream.stream_id,
):
await self.send_text(f"本地启用组件成功: {component_name}")
await self._send_message(f"本地启用组件成功: {component_name}")
else:
await self.send_text(f"本地启用组件失败: {component_name}")
await self._send_message(f"本地启用组件失败: {component_name}")
async def _locally_disable_component(self, component_name: str, component_type: str):
match component_type:
@@ -407,34 +415,40 @@ class ManagementCommand(BaseCommand):
case "event_handler":
target_component_type = ComponentType.EVENT_HANDLER
case _:
await self.send_text(f"未知组件类型: {component_type}")
await self._send_message(f"未知组件类型: {component_type}")
return
if component_manage_api.locally_disable_component(
component_name,
target_component_type,
self.message.chat_stream.stream_id,
):
await self.send_text(f"本地禁用组件成功: {component_name}")
await self._send_message(f"本地禁用组件成功: {component_name}")
else:
await self.send_text(f"本地禁用组件失败: {component_name}")
await self._send_message(f"本地禁用组件失败: {component_name}")
async def _send_message(self, message: str):
await send_api.text_to_stream(message, self.stream_id, typing=False, storage_message=False)
@register_plugin
class PluginManagementPlugin(BasePlugin):
plugin_name: str = "plugin_management_plugin"
enable_plugin: bool = True
enable_plugin: bool = False
dependencies: list[str] = []
python_dependencies: list[str] = []
config_file_name: str = "config.toml"
config_schema: dict = {
"plugin": {
"enable": ConfigField(bool, default=True, description="是否启用插件"),
"permission": ConfigField(list, default=[], description="有权限使用插件管理命令的用户列表"),
"enabled": ConfigField(bool, default=False, description="是否启用插件"),
"config_version": ConfigField(type=str, default="1.1.0", description="配置文件版本"),
"permission": ConfigField(
list, default=[], description="有权限使用插件管理命令的用户列表，请填写字符串形式的用户ID"
),
},
}
def get_plugin_components(self) -> List[Tuple[CommandInfo, Type[BaseCommand]]]:
components = []
if self.get_config("plugin.enable", True):
if self.get_config("plugin.enabled", True):
components.append((ManagementCommand.get_command_info(), ManagementCommand))
return components
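The refactor above routes all command output through the stream API instead of self.send_text; a minimal sketch of that pattern, assuming only send_api.text_to_stream as called in _send_message (illustrative, not part of the diff):
from src.plugin_system import send_api


async def reply_in_stream(command, text: str) -> bool:
    """Send command output to the command's chat stream, mirroring the guard added at the top of execute()."""
    chat_stream = command.message.chat_stream
    if not chat_stream or not chat_stream.stream_id:
        return False
    await send_api.text_to_stream(text, chat_stream.stream_id, typing=False, storage_message=False)
    return True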

View File

@@ -1,20 +0,0 @@
from src.tools.tool_can_use.base_tool import (
BaseTool,
register_tool,
discover_tools,
get_all_tool_definitions,
get_tool_instance,
TOOL_REGISTRY,
)
__all__ = [
"BaseTool",
"register_tool",
"discover_tools",
"get_all_tool_definitions",
"get_tool_instance",
"TOOL_REGISTRY",
]
# 自动发现并注册工具
discover_tools()

View File

@@ -1,115 +0,0 @@
from typing import List, Any, Optional, Type
import inspect
import importlib
import pkgutil
import os
from src.common.logger import get_logger
from rich.traceback import install
install(extra_lines=3)
logger = get_logger("base_tool")
# 工具注册表
TOOL_REGISTRY = {}
class BaseTool:
"""所有工具的基类"""
# 工具名称,子类必须重写
name = None
# 工具描述,子类必须重写
description = None
# 工具参数定义,子类必须重写
parameters = None
@classmethod
def get_tool_definition(cls) -> dict[str, Any]:
"""获取工具定义用于LLM工具调用
Returns:
dict: 工具定义字典
"""
if not cls.name or not cls.description or not cls.parameters:
raise NotImplementedError(f"工具类 {cls.__name__} 必须定义 name, description 和 parameters 属性")
return {
"type": "function",
"function": {"name": cls.name, "description": cls.description, "parameters": cls.parameters},
}
async def execute(self, function_args: dict[str, Any]) -> dict[str, Any]:
"""执行工具函数
Args:
function_args: 工具调用参数
Returns:
dict: 工具执行结果
"""
raise NotImplementedError("子类必须实现execute方法")
def register_tool(tool_class: Type[BaseTool]):
"""注册工具到全局注册表
Args:
tool_class: 工具类
"""
if not issubclass(tool_class, BaseTool):
raise TypeError(f"{tool_class.__name__} 不是 BaseTool 的子类")
tool_name = tool_class.name
if not tool_name:
raise ValueError(f"工具类 {tool_class.__name__} 没有定义 name 属性")
TOOL_REGISTRY[tool_name] = tool_class
logger.info(f"已注册: {tool_name}")
def discover_tools():
"""自动发现并注册tool_can_use目录下的所有工具"""
# 获取当前目录路径
current_dir = os.path.dirname(os.path.abspath(__file__))
package_name = os.path.basename(current_dir)
# 遍历包中的所有模块
for _, module_name, _ in pkgutil.iter_modules([current_dir]):
# 跳过当前模块和__pycache__
if module_name == "base_tool" or module_name.startswith("__"):
continue
# 导入模块
module = importlib.import_module(f"src.tools.{package_name}.{module_name}")
# 查找模块中的工具类
for _, obj in inspect.getmembers(module):
if inspect.isclass(obj) and issubclass(obj, BaseTool) and obj != BaseTool:
register_tool(obj)
logger.info(f"工具发现完成,共注册 {len(TOOL_REGISTRY)} 个工具")
def get_all_tool_definitions() -> List[dict[str, Any]]:
"""获取所有已注册工具的定义
Returns:
List[dict]: 工具定义列表
"""
return [tool_class().get_tool_definition() for tool_class in TOOL_REGISTRY.values()]
def get_tool_instance(tool_name: str) -> Optional[BaseTool]:
"""获取指定名称的工具实例
Args:
tool_name: 工具名称
Returns:
Optional[BaseTool]: 工具实例如果找不到则返回None
"""
tool_class = TOOL_REGISTRY.get(tool_name)
if not tool_class:
return None
return tool_class()
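For context on what this removed registry expected, a hypothetical minimal tool using only the BaseTool contract and register_tool defined above (EchoTool is illustrative, not from the codebase):
from typing import Any


class EchoTool(BaseTool):
    """Hypothetical example tool, shown only to illustrate the removed registry contract."""

    name = "echo"
    description = "原样返回输入文本"
    parameters = {
        "type": "object",
        "properties": {"text": {"type": "string", "description": "要回显的文本"}},
        "required": ["text"],
    }

    async def execute(self, function_args: dict[str, Any]) -> dict[str, Any]:
        return {"type": "info", "id": "echo", "content": function_args.get("text", "")}


# Explicit registration; discover_tools() would also pick it up automatically
# if the module lived inside the tool_can_use package.
register_tool(EchoTool)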

View File

@@ -1,49 +0,0 @@
from src.tools.tool_can_use.base_tool import BaseTool
from src.common.logger import get_logger
from typing import Any
logger = get_logger("compare_numbers_tool")
class CompareNumbersTool(BaseTool):
"""比较两个数大小的工具"""
name = "compare_numbers"
description = "使用工具 比较两个数的大小,返回较大的数"
parameters = {
"type": "object",
"properties": {
"num1": {"type": "number", "description": "第一个数字"},
"num2": {"type": "number", "description": "第二个数字"},
},
"required": ["num1", "num2"],
}
async def execute(self, function_args: dict[str, Any]) -> dict[str, Any]:
"""执行比较两个数的大小
Args:
function_args: 工具参数
Returns:
dict: 工具执行结果
"""
num1: int | float = function_args.get("num1") # type: ignore
num2: int | float = function_args.get("num2") # type: ignore
try:
if num1 > num2:
result = f"{num1} 大于 {num2}"
elif num1 < num2:
result = f"{num1} 小于 {num2}"
else:
result = f"{num1} 等于 {num2}"
return {"type": "comparison_result", "id": f"{num1}_vs_{num2}", "content": result}
except Exception as e:
logger.error(f"比较数字失败: {str(e)}")
return {"type": "info", "id": f"{num1}_vs_{num2}", "content": f"比较数字失败,炸了: {str(e)}"}
# 注册工具
# register_tool(CompareNumbersTool)
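A direct-call usage sketch for the tool above, outside the LLM tool-call path (assumes the class is importable as-is):
import asyncio


async def _demo() -> None:
    tool = CompareNumbersTool()
    result = await tool.execute({"num1": 3, "num2": 5})
    print(result["content"])  # -> "3 小于 5"


asyncio.run(_demo())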

View File

@@ -1,104 +0,0 @@
from src.tools.tool_can_use.base_tool import BaseTool
from src.person_info.person_info import get_person_info_manager
from src.common.logger import get_logger
import time
logger = get_logger("rename_person_tool")
class RenamePersonTool(BaseTool):
name = "rename_person"
description = (
"这个工具可以改变用户的昵称。你可以选择改变对他人的称呼。你想给人改名,叫别人别的称呼,需要调用这个工具。"
)
parameters = {
"type": "object",
"properties": {
"person_name": {"type": "string", "description": "需要重新取名的用户的当前昵称"},
"message_content": {
"type": "string",
"description": "当前的聊天内容或特定要求,用于提供取名建议的上下文,尽可能详细。",
},
},
"required": ["person_name"],
}
async def execute(self, function_args: dict, message_txt=""):
"""
执行取名工具逻辑
Args:
function_args (dict): 包含 'person_name' 和可选 'message_content' 的字典
message_txt (str): 原始消息文本 (这里未使用,因为 message_content 更明确)
Returns:
dict: 包含执行结果的字典
"""
person_name_to_find = function_args.get("person_name")
request_context = function_args.get("message_content", "") # 如果没有提供,则为空字符串
if not person_name_to_find:
return {"name": self.name, "content": "错误:必须提供需要重命名的用户昵称 (person_name)。"}
person_info_manager = get_person_info_manager()
try:
# 1. 根据昵称查找用户信息
logger.debug(f"尝试根据昵称 '{person_name_to_find}' 查找用户...")
person_info = await person_info_manager.get_person_info_by_name(person_name_to_find)
if not person_info:
logger.info(f"未找到昵称为 '{person_name_to_find}' 的用户。")
return {
"name": self.name,
"content": f"找不到昵称为 '{person_name_to_find}' 的用户。请确保输入的是我之前为该用户取的昵称。",
}
person_id = person_info.get("person_id")
user_nickname = person_info.get("nickname") # 这是用户原始昵称
user_cardname = person_info.get("user_cardname")
user_avatar = person_info.get("user_avatar")
if not person_id:
logger.error(f"找到了用户 '{person_name_to_find}' 但无法获取 person_id")
return {"name": self.name, "content": f"找到了用户 '{person_name_to_find}' 但获取内部ID时出错。"}
# 2. 调用 qv_person_name 进行取名
logger.debug(
f"为用户 {person_id} (原昵称: {person_name_to_find}) 调用 qv_person_name请求上下文: '{request_context}'"
)
result = await person_info_manager.qv_person_name(
person_id=person_id,
user_nickname=user_nickname, # type: ignore
user_cardname=user_cardname, # type: ignore
user_avatar=user_avatar, # type: ignore
request=request_context,
)
# 3. 处理结果
if result and result.get("nickname"):
new_name = result["nickname"]
# reason = result.get("reason", "未提供理由")
logger.info(f"成功为用户 {person_id} 取了新昵称: {new_name}")
content = f"已成功将用户 {person_name_to_find} 的备注名更新为 {new_name}"
logger.info(content)
return {"type": "info", "id": f"rename_success_{time.time()}", "content": content}
else:
logger.warning(f"为用户 {person_id} 调用 qv_person_name 后未能成功获取新昵称。")
# 尝试从内存中获取可能已经更新的名字
current_name = await person_info_manager.get_value(person_id, "person_name")
if current_name and current_name != person_name_to_find:
return {
"name": self.name,
"content": f"尝试取新昵称时遇到一点小问题,但我已经将 '{person_name_to_find}' 的昵称更新为 '{current_name}' 了。",
}
else:
return {
"name": self.name,
"content": f"尝试为 '{person_name_to_find}' 取新昵称时遇到了问题,未能成功生成。可能需要稍后再试。",
}
except Exception as e:
error_msg = f"重命名失败: {str(e)}"
logger.error(error_msg, exc_info=True)
return {"type": "info_error", "id": f"rename_error_{time.time()}", "content": error_msg}

View File

@@ -1,55 +0,0 @@
import json
from src.common.logger import get_logger
from src.tools.tool_can_use import get_all_tool_definitions, get_tool_instance
logger = get_logger("tool_use")
class ToolUser:
@staticmethod
def _define_tools():
"""获取所有已注册工具的定义
Returns:
list: 工具定义列表
"""
return get_all_tool_definitions()
@staticmethod
async def _execute_tool_call(tool_call):
"""执行特定的工具调用
Args:
tool_call: 工具调用对象
Returns:
dict: 工具调用结果
"""
try:
function_name = tool_call["function"]["name"]
function_args = json.loads(tool_call["function"]["arguments"])
# 获取对应工具实例
tool_instance = get_tool_instance(function_name)
if not tool_instance:
logger.warning(f"未知工具名称: {function_name}")
return None
# 执行工具
result = await tool_instance.execute(function_args)
if result:
# 直接使用 function_name 作为 tool_type
tool_type = function_name
return {
"tool_call_id": tool_call["id"],
"role": "tool",
"name": function_name,
"type": tool_type,
"content": result["content"],
}
return None
except Exception as e:
logger.error(f"执行工具调用时发生错误: {str(e)}")
return None
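The tool_call argument above appears to follow the OpenAI-style function-calling shape; an illustrative payload and result (values are made up, field names match the accesses in the method):
example_tool_call = {
    "id": "call_0",
    "function": {
        "name": "compare_numbers",
        "arguments": '{"num1": 3, "num2": 5}',
    },
}
# await ToolUser._execute_tool_call(example_tool_call) would return roughly:
# {"tool_call_id": "call_0", "role": "tool", "name": "compare_numbers",
#  "type": "compare_numbers", "content": "3 小于 5"}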