@@ -20,10 +20,9 @@ from src.person_info.person_info import get_person_info_manager
|
||||
from src.plugin_system.base.component_types import ActionInfo, ChatMode
|
||||
from src.plugin_system.apis import generator_api, send_api, message_api
|
||||
from src.chat.willing.willing_manager import get_willing_manager
|
||||
from src.chat.mai_thinking.mai_think import mai_thinking_manager
|
||||
from maim_message.message_base import GroupInfo,UserInfo
|
||||
|
||||
ENABLE_THINKING = False
|
||||
from src.mais4u.mai_think import mai_thinking_manager
|
||||
from maim_message.message_base import GroupInfo
|
||||
from src.mais4u.constant_s4u import ENABLE_S4U
|
||||
|
||||
ERROR_LOOP_INFO = {
|
||||
"loop_plan_info": {
|
||||
@@ -237,12 +236,12 @@ class HeartFChatting:
|
||||
if if_think:
|
||||
factor = max(global_config.chat.focus_value, 0.1)
|
||||
self.energy_value *= 1.1 / factor
|
||||
logger.info(f"{self.log_prefix} 麦麦进行了思考,能量值按倍数增加,当前能量值:{self.energy_value}")
|
||||
logger.info(f"{self.log_prefix} 进行了思考,能量值按倍数增加,当前能量值:{self.energy_value:.1f}")
|
||||
else:
|
||||
self.energy_value += 0.1 / global_config.chat.focus_value
|
||||
logger.info(f"{self.log_prefix} 麦麦没有进行思考,能量值线性增加,当前能量值:{self.energy_value}")
|
||||
logger.debug(f"{self.log_prefix} 没有进行思考,能量值线性增加,当前能量值:{self.energy_value:.1f}")
|
||||
|
||||
logger.debug(f"{self.log_prefix} 当前能量值:{self.energy_value}")
|
||||
logger.debug(f"{self.log_prefix} 当前能量值:{self.energy_value:.1f}")
|
||||
return True
|
||||
|
||||
await asyncio.sleep(1)
|
||||
@@ -257,31 +256,29 @@ class HeartFChatting:
|
||||
)
|
||||
person_name = await person_info_manager.get_value(person_id, "person_name")
|
||||
return f"{person_name}:{message_data.get('processed_plain_text')}"
|
||||
|
||||
|
||||
async def send_typing(self):
|
||||
group_info = GroupInfo(platform = "amaidesu_default",group_id = 114514,group_name = "内心")
|
||||
|
||||
chat = await get_chat_manager().get_or_create_stream(
|
||||
platform = "amaidesu_default",
|
||||
user_info = None,
|
||||
group_info = group_info
|
||||
group_info = GroupInfo(platform="amaidesu_default", group_id="114514", group_name="内心")
|
||||
|
||||
chat = await get_chat_manager().get_or_create_stream(
|
||||
platform="amaidesu_default",
|
||||
user_info=None,
|
||||
group_info=group_info,
|
||||
)
|
||||
|
||||
|
||||
|
||||
await send_api.custom_to_stream(
|
||||
message_type="state", content="typing", stream_id=chat.stream_id, storage_message=False
|
||||
)
|
||||
|
||||
|
||||
async def stop_typing(self):
|
||||
group_info = GroupInfo(platform = "amaidesu_default",group_id = 114514,group_name = "内心")
|
||||
|
||||
chat = await get_chat_manager().get_or_create_stream(
|
||||
platform = "amaidesu_default",
|
||||
user_info = None,
|
||||
group_info = group_info
|
||||
group_info = GroupInfo(platform="amaidesu_default", group_id="114514", group_name="内心")
|
||||
|
||||
chat = await get_chat_manager().get_or_create_stream(
|
||||
platform="amaidesu_default",
|
||||
user_info=None,
|
||||
group_info=group_info,
|
||||
)
|
||||
|
||||
|
||||
|
||||
await send_api.custom_to_stream(
|
||||
message_type="state", content="stop_typing", stream_id=chat.stream_id, storage_message=False
|
||||
)
|
||||
@@ -296,7 +293,8 @@ class HeartFChatting:
|
||||
|
||||
logger.info(f"{self.log_prefix} 开始第{self._cycle_counter}次思考[模式:{self.loop_mode}]")
|
||||
|
||||
await self.send_typing()
|
||||
if ENABLE_S4U:
|
||||
await self.send_typing()
|
||||
|
||||
async with global_prompt_manager.async_message_scope(self.chat_stream.context.get_template_name()):
|
||||
loop_start_time = time.time()
|
||||
@@ -366,13 +364,13 @@ class HeartFChatting:
|
||||
# 发送回复 (不再需要传入 chat)
|
||||
reply_text = await self._send_response(response_set, reply_to_str, loop_start_time,message_data)
|
||||
|
||||
await self.stop_typing()
|
||||
|
||||
|
||||
|
||||
if ENABLE_THINKING:
|
||||
|
||||
if ENABLE_S4U:
|
||||
await self.stop_typing()
|
||||
await mai_thinking_manager.get_mai_think(self.stream_id).do_think_after_response(reply_text)
|
||||
|
||||
|
||||
return True
|
||||
|
||||
@@ -504,10 +502,9 @@ class HeartFChatting:
|
||||
"""
|
||||
|
||||
interested_rate = (message_data.get("interest_value") or 0.0) * self.willing_amplifier
|
||||
|
||||
|
||||
self.willing_manager.setup(message_data, self.chat_stream)
|
||||
|
||||
|
||||
|
||||
reply_probability = await self.willing_manager.get_reply_probability(message_data.get("message_id", ""))
|
||||
|
||||
talk_frequency = -1.00
|
||||
@@ -517,7 +514,7 @@ class HeartFChatting:
|
||||
if additional_config and "maimcore_reply_probability_gain" in additional_config:
|
||||
reply_probability += additional_config["maimcore_reply_probability_gain"]
|
||||
reply_probability = min(max(reply_probability, 0), 1) # 确保概率在 0-1 之间
|
||||
|
||||
|
||||
talk_frequency = global_config.chat.get_current_talk_frequency(self.stream_id)
|
||||
reply_probability = talk_frequency * reply_probability
|
||||
|
||||
@@ -527,9 +524,9 @@ class HeartFChatting:
|
||||
|
||||
# 打印消息信息
|
||||
mes_name = self.chat_stream.group_info.group_name if self.chat_stream.group_info else "私聊"
|
||||
|
||||
|
||||
# logger.info(f"[{mes_name}] 当前聊天频率: {talk_frequency:.2f},兴趣值: {interested_rate:.2f},回复概率: {reply_probability * 100:.1f}%")
|
||||
|
||||
|
||||
if reply_probability > 0.05:
|
||||
logger.info(
|
||||
f"[{mes_name}]"
|
||||
@@ -545,7 +542,6 @@ class HeartFChatting:
|
||||
# 意愿管理器:注销当前message信息 (无论是否回复,只要处理过就删除)
|
||||
self.willing_manager.delete(message_data.get("message_id", ""))
|
||||
return False
|
||||
|
||||
|
||||
async def _generate_response(
|
||||
self, message_data: dict, available_actions: Optional[Dict[str, ActionInfo]], reply_to: str
|
||||
@@ -570,7 +566,7 @@ class HeartFChatting:
|
||||
logger.error(f"[{self.log_prefix}] 回复生成出现错误:{str(e)} {traceback.format_exc()}")
|
||||
return None
|
||||
|
||||
async def _send_response(self, reply_set, reply_to, thinking_start_time,message_data):
|
||||
async def _send_response(self, reply_set, reply_to, thinking_start_time, message_data):
|
||||
current_time = time.time()
|
||||
new_message_count = message_api.count_new_messages(
|
||||
chat_id=self.chat_stream.stream_id, start_time=thinking_start_time, end_time=current_time
|
||||
@@ -581,9 +577,14 @@ class HeartFChatting:
|
||||
|
||||
need_reply = new_message_count >= random.randint(2, 4)
|
||||
|
||||
logger.info(
|
||||
f"{self.log_prefix} 从思考到回复,共有{new_message_count}条新消息,{'使用' if need_reply else '不使用'}引用回复"
|
||||
)
|
||||
if need_reply:
|
||||
logger.info(
|
||||
f"{self.log_prefix} 从思考到回复,共有{new_message_count}条新消息,使用引用回复"
|
||||
)
|
||||
else:
|
||||
logger.debug(
|
||||
f"{self.log_prefix} 从思考到回复,共有{new_message_count}条新消息,不使用引用回复"
|
||||
)
|
||||
|
||||
reply_text = ""
|
||||
first_replied = False
|
||||
@@ -592,13 +593,27 @@ class HeartFChatting:
|
||||
if not first_replied:
|
||||
if need_reply:
|
||||
await send_api.text_to_stream(
|
||||
text=data, stream_id=self.chat_stream.stream_id, reply_to=reply_to, reply_to_platform_id=reply_to_platform_id, typing=False
|
||||
text=data,
|
||||
stream_id=self.chat_stream.stream_id,
|
||||
reply_to=reply_to,
|
||||
reply_to_platform_id=reply_to_platform_id,
|
||||
typing=False,
|
||||
)
|
||||
else:
|
||||
await send_api.text_to_stream(text=data, stream_id=self.chat_stream.stream_id, reply_to_platform_id=reply_to_platform_id, typing=False)
|
||||
await send_api.text_to_stream(
|
||||
text=data,
|
||||
stream_id=self.chat_stream.stream_id,
|
||||
reply_to_platform_id=reply_to_platform_id,
|
||||
typing=False,
|
||||
)
|
||||
first_replied = True
|
||||
else:
|
||||
await send_api.text_to_stream(text=data, stream_id=self.chat_stream.stream_id, reply_to_platform_id=reply_to_platform_id, typing=True)
|
||||
await send_api.text_to_stream(
|
||||
text=data,
|
||||
stream_id=self.chat_stream.stream_id,
|
||||
reply_to_platform_id=reply_to_platform_id,
|
||||
typing=True,
|
||||
)
|
||||
reply_text += data
|
||||
|
||||
return reply_text
|
||||
|
||||
@@ -836,7 +836,7 @@ class EmojiManager:
|
||||
return False
|
||||
|
||||
async def build_emoji_description(self, image_base64: str) -> Tuple[str, List[str]]:
|
||||
"""获取表情包描述和情感列表
|
||||
"""获取表情包描述和情感列表,优化复用已有描述
|
||||
|
||||
Args:
|
||||
image_base64: 图片的base64编码
|
||||
@@ -850,18 +850,35 @@ class EmojiManager:
|
||||
if isinstance(image_base64, str):
|
||||
image_base64 = image_base64.encode("ascii", errors="ignore").decode("ascii")
|
||||
image_bytes = base64.b64decode(image_base64)
|
||||
image_hash = hashlib.md5(image_bytes).hexdigest()
|
||||
image_format = Image.open(io.BytesIO(image_bytes)).format.lower() # type: ignore
|
||||
|
||||
# 调用AI获取描述
|
||||
if image_format == "gif" or image_format == "GIF":
|
||||
image_base64 = get_image_manager().transform_gif(image_base64) # type: ignore
|
||||
if not image_base64:
|
||||
raise RuntimeError("GIF表情包转换失败")
|
||||
prompt = "这是一个动态图表情包,每一张图代表了动态图的某一帧,黑色背景代表透明,描述一下表情包表达的情感和内容,描述细节,从互联网梗,meme的角度去分析"
|
||||
description, _ = await self.vlm.generate_response_for_image(prompt, image_base64, "jpg")
|
||||
# 尝试从Images表获取已有的详细描述(可能在收到表情包时已生成)
|
||||
existing_description = None
|
||||
try:
|
||||
from src.common.database.database_model import Images
|
||||
existing_image = Images.get_or_none((Images.emoji_hash == image_hash) & (Images.type == "emoji"))
|
||||
if existing_image and existing_image.description:
|
||||
existing_description = existing_image.description
|
||||
logger.info(f"[复用描述] 找到已有详细描述: {existing_description[:50]}...")
|
||||
except Exception as e:
|
||||
logger.debug(f"查询已有描述时出错: {e}")
|
||||
|
||||
# 第一步:VLM视觉分析(如果没有已有描述才调用)
|
||||
if existing_description:
|
||||
description = existing_description
|
||||
logger.info("[优化] 复用已有的详细描述,跳过VLM调用")
|
||||
else:
|
||||
prompt = "这是一个表情包,请详细描述一下表情包所表达的情感和内容,描述细节,从互联网梗,meme的角度去分析"
|
||||
description, _ = await self.vlm.generate_response_for_image(prompt, image_base64, image_format)
|
||||
logger.info("[VLM分析] 生成新的详细描述")
|
||||
if image_format == "gif" or image_format == "GIF":
|
||||
image_base64 = get_image_manager().transform_gif(image_base64) # type: ignore
|
||||
if not image_base64:
|
||||
raise RuntimeError("GIF表情包转换失败")
|
||||
prompt = "这是一个动态图表情包,每一张图代表了动态图的某一帧,黑色背景代表透明,描述一下表情包表达的情感和内容,描述细节,从互联网梗,meme的角度去分析"
|
||||
description, _ = await self.vlm.generate_response_for_image(prompt, image_base64, "jpg")
|
||||
else:
|
||||
prompt = "这是一个表情包,请详细描述一下表情包所表达的情感和内容,描述细节,从互联网梗,meme的角度去分析"
|
||||
description, _ = await self.vlm.generate_response_for_image(prompt, image_base64, image_format)
|
||||
|
||||
# 审核表情包
|
||||
if global_config.emoji.content_filtration:
|
||||
@@ -877,7 +894,7 @@ class EmojiManager:
|
||||
if content == "否":
|
||||
return "", []
|
||||
|
||||
# 分析情感含义
|
||||
# 第二步:LLM情感分析 - 基于详细描述生成情感标签列表
|
||||
emotion_prompt = f"""
|
||||
请你识别这个表情包的含义和适用场景,给我简短的描述,每个描述不要超过15个字
|
||||
这是一个基于这个表情包的描述:'{description}'
|
||||
@@ -889,12 +906,14 @@ class EmojiManager:
|
||||
# 处理情感列表
|
||||
emotions = [e.strip() for e in emotions_text.split(",") if e.strip()]
|
||||
|
||||
# 根据情感标签数量随机选择喵~超过5个选3个,超过2个选2个
|
||||
# 根据情感标签数量随机选择 - 超过5个选3个,超过2个选2个
|
||||
if len(emotions) > 5:
|
||||
emotions = random.sample(emotions, 3)
|
||||
elif len(emotions) > 2:
|
||||
emotions = random.sample(emotions, 2)
|
||||
|
||||
logger.info(f"[注册分析] 详细描述: {description[:50]}... -> 情感标签: {emotions}")
|
||||
|
||||
return f"[表情包:{description}]", emotions
|
||||
|
||||
except Exception as e:
|
||||
|
||||
@@ -2,6 +2,7 @@ import time
|
||||
import random
|
||||
import json
|
||||
import os
|
||||
from datetime import datetime
|
||||
|
||||
from typing import List, Dict, Optional, Any, Tuple
|
||||
|
||||
@@ -21,6 +22,16 @@ DECAY_MIN = 0.01 # 最小衰减值
|
||||
logger = get_logger("expressor")
|
||||
|
||||
|
||||
def format_create_date(timestamp: float) -> str:
|
||||
"""
|
||||
将时间戳格式化为可读的日期字符串
|
||||
"""
|
||||
try:
|
||||
return datetime.fromtimestamp(timestamp).strftime("%Y-%m-%d %H:%M:%S")
|
||||
except (ValueError, OSError):
|
||||
return "未知时间"
|
||||
|
||||
|
||||
def init_prompt() -> None:
|
||||
learn_style_prompt = """
|
||||
{chat_str}
|
||||
@@ -76,35 +87,90 @@ class ExpressionLearner:
|
||||
request_type="expressor.learner",
|
||||
)
|
||||
self.llm_model = None
|
||||
self._ensure_expression_directories()
|
||||
self._auto_migrate_json_to_db()
|
||||
self._migrate_old_data_create_date()
|
||||
|
||||
def _ensure_expression_directories(self):
|
||||
"""
|
||||
确保表达方式相关的目录结构存在
|
||||
"""
|
||||
base_dir = os.path.join("data", "expression")
|
||||
directories_to_create = [
|
||||
base_dir,
|
||||
os.path.join(base_dir, "learnt_style"),
|
||||
os.path.join(base_dir, "learnt_grammar"),
|
||||
]
|
||||
|
||||
for directory in directories_to_create:
|
||||
try:
|
||||
os.makedirs(directory, exist_ok=True)
|
||||
logger.debug(f"确保目录存在: {directory}")
|
||||
except Exception as e:
|
||||
logger.error(f"创建目录失败 {directory}: {e}")
|
||||
|
||||
def _auto_migrate_json_to_db(self):
|
||||
"""
|
||||
自动将/data/expression/learnt_style 和 learnt_grammar 下所有expressions.json迁移到数据库。
|
||||
迁移完成后在/data/expression/done.done写入标记文件,存在则跳过。
|
||||
"""
|
||||
done_flag = os.path.join("data", "expression", "done.done")
|
||||
base_dir = os.path.join("data", "expression")
|
||||
done_flag = os.path.join(base_dir, "done.done")
|
||||
|
||||
# 确保基础目录存在
|
||||
try:
|
||||
os.makedirs(base_dir, exist_ok=True)
|
||||
logger.debug(f"确保目录存在: {base_dir}")
|
||||
except Exception as e:
|
||||
logger.error(f"创建表达方式目录失败: {e}")
|
||||
return
|
||||
|
||||
if os.path.exists(done_flag):
|
||||
logger.info("表达方式JSON已迁移,无需重复迁移。")
|
||||
return
|
||||
base_dir = os.path.join("data", "expression")
|
||||
|
||||
logger.info("开始迁移表达方式JSON到数据库...")
|
||||
migrated_count = 0
|
||||
|
||||
for type in ["learnt_style", "learnt_grammar"]:
|
||||
type_str = "style" if type == "learnt_style" else "grammar"
|
||||
type_dir = os.path.join(base_dir, type)
|
||||
if not os.path.exists(type_dir):
|
||||
logger.debug(f"目录不存在,跳过: {type_dir}")
|
||||
continue
|
||||
for chat_id in os.listdir(type_dir):
|
||||
|
||||
try:
|
||||
chat_ids = os.listdir(type_dir)
|
||||
logger.debug(f"在 {type_dir} 中找到 {len(chat_ids)} 个聊天ID目录")
|
||||
except Exception as e:
|
||||
logger.error(f"读取目录失败 {type_dir}: {e}")
|
||||
continue
|
||||
|
||||
for chat_id in chat_ids:
|
||||
expr_file = os.path.join(type_dir, chat_id, "expressions.json")
|
||||
if not os.path.exists(expr_file):
|
||||
continue
|
||||
try:
|
||||
with open(expr_file, "r", encoding="utf-8") as f:
|
||||
expressions = json.load(f)
|
||||
|
||||
if not isinstance(expressions, list):
|
||||
logger.warning(f"表达方式文件格式错误,跳过: {expr_file}")
|
||||
continue
|
||||
|
||||
for expr in expressions:
|
||||
if not isinstance(expr, dict):
|
||||
continue
|
||||
|
||||
situation = expr.get("situation")
|
||||
style_val = expr.get("style")
|
||||
count = expr.get("count", 1)
|
||||
last_active_time = expr.get("last_active_time", time.time())
|
||||
|
||||
if not situation or not style_val:
|
||||
logger.warning(f"表达方式缺少必要字段,跳过: {expr}")
|
||||
continue
|
||||
|
||||
# 查重:同chat_id+type+situation+style
|
||||
from src.common.database.database_model import Expression
|
||||
|
||||
@@ -127,18 +193,54 @@ class ExpressionLearner:
|
||||
last_active_time=last_active_time,
|
||||
chat_id=chat_id,
|
||||
type=type_str,
|
||||
create_date=last_active_time, # 迁移时使用last_active_time作为创建时间
|
||||
)
|
||||
logger.info(f"已迁移 {expr_file} 到数据库")
|
||||
migrated_count += 1
|
||||
logger.info(f"已迁移 {expr_file} 到数据库,包含 {len(expressions)} 个表达方式")
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(f"JSON解析失败 {expr_file}: {e}")
|
||||
except Exception as e:
|
||||
logger.error(f"迁移表达方式 {expr_file} 失败: {e}")
|
||||
|
||||
# 标记迁移完成
|
||||
try:
|
||||
# 确保done.done文件的父目录存在
|
||||
done_parent_dir = os.path.dirname(done_flag)
|
||||
if not os.path.exists(done_parent_dir):
|
||||
os.makedirs(done_parent_dir, exist_ok=True)
|
||||
logger.debug(f"为done.done创建父目录: {done_parent_dir}")
|
||||
|
||||
with open(done_flag, "w", encoding="utf-8") as f:
|
||||
f.write("done\n")
|
||||
logger.info("表达方式JSON迁移已完成,已写入done.done标记文件")
|
||||
logger.info(f"表达方式JSON迁移已完成,共迁移 {migrated_count} 个表达方式,已写入done.done标记文件")
|
||||
except PermissionError as e:
|
||||
logger.error(f"权限不足,无法写入done.done标记文件: {e}")
|
||||
except OSError as e:
|
||||
logger.error(f"文件系统错误,无法写入done.done标记文件: {e}")
|
||||
except Exception as e:
|
||||
logger.error(f"写入done.done标记文件失败: {e}")
|
||||
|
||||
def _migrate_old_data_create_date(self):
|
||||
"""
|
||||
为没有create_date的老数据设置创建日期
|
||||
使用last_active_time作为create_date的默认值
|
||||
"""
|
||||
try:
|
||||
# 查找所有create_date为空的表达方式
|
||||
old_expressions = Expression.select().where(Expression.create_date.is_null())
|
||||
updated_count = 0
|
||||
|
||||
for expr in old_expressions:
|
||||
# 使用last_active_time作为create_date
|
||||
expr.create_date = expr.last_active_time
|
||||
expr.save()
|
||||
updated_count += 1
|
||||
|
||||
if updated_count > 0:
|
||||
logger.info(f"已为 {updated_count} 个老的表达方式设置创建日期")
|
||||
except Exception as e:
|
||||
logger.error(f"迁移老数据创建日期失败: {e}")
|
||||
|
||||
def get_expression_by_chat_id(self, chat_id: str) -> Tuple[List[Dict[str, float]], List[Dict[str, float]]]:
|
||||
"""
|
||||
获取指定chat_id的style和grammar表达方式
|
||||
@@ -150,6 +252,8 @@ class ExpressionLearner:
|
||||
# 直接从数据库查询
|
||||
style_query = Expression.select().where((Expression.chat_id == chat_id) & (Expression.type == "style"))
|
||||
for expr in style_query:
|
||||
# 确保create_date存在,如果不存在则使用last_active_time
|
||||
create_date = expr.create_date if expr.create_date is not None else expr.last_active_time
|
||||
learnt_style_expressions.append(
|
||||
{
|
||||
"situation": expr.situation,
|
||||
@@ -158,10 +262,13 @@ class ExpressionLearner:
|
||||
"last_active_time": expr.last_active_time,
|
||||
"source_id": chat_id,
|
||||
"type": "style",
|
||||
"create_date": create_date,
|
||||
}
|
||||
)
|
||||
grammar_query = Expression.select().where((Expression.chat_id == chat_id) & (Expression.type == "grammar"))
|
||||
for expr in grammar_query:
|
||||
# 确保create_date存在,如果不存在则使用last_active_time
|
||||
create_date = expr.create_date if expr.create_date is not None else expr.last_active_time
|
||||
learnt_grammar_expressions.append(
|
||||
{
|
||||
"situation": expr.situation,
|
||||
@@ -170,10 +277,40 @@ class ExpressionLearner:
|
||||
"last_active_time": expr.last_active_time,
|
||||
"source_id": chat_id,
|
||||
"type": "grammar",
|
||||
"create_date": create_date,
|
||||
}
|
||||
)
|
||||
return learnt_style_expressions, learnt_grammar_expressions
|
||||
|
||||
def get_expression_create_info(self, chat_id: str, limit: int = 10) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
获取指定chat_id的表达方式创建信息,按创建日期排序
|
||||
"""
|
||||
try:
|
||||
expressions = (Expression.select()
|
||||
.where(Expression.chat_id == chat_id)
|
||||
.order_by(Expression.create_date.desc())
|
||||
.limit(limit))
|
||||
|
||||
result = []
|
||||
for expr in expressions:
|
||||
create_date = expr.create_date if expr.create_date is not None else expr.last_active_time
|
||||
result.append({
|
||||
"situation": expr.situation,
|
||||
"style": expr.style,
|
||||
"type": expr.type,
|
||||
"count": expr.count,
|
||||
"create_date": create_date,
|
||||
"create_date_formatted": format_create_date(create_date),
|
||||
"last_active_time": expr.last_active_time,
|
||||
"last_active_formatted": format_create_date(expr.last_active_time),
|
||||
})
|
||||
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error(f"获取表达方式创建信息失败: {e}")
|
||||
return []
|
||||
|
||||
def is_similar(self, s1: str, s2: str) -> bool:
|
||||
"""
|
||||
判断两个字符串是否相似(只考虑长度大于5且有80%以上重合,不考虑子串)
|
||||
@@ -197,9 +334,17 @@ class ExpressionLearner:
|
||||
for type in ["style", "grammar"]:
|
||||
base_dir = os.path.join("data", "expression", f"learnt_{type}")
|
||||
if not os.path.exists(base_dir):
|
||||
logger.debug(f"目录不存在,跳过衰减: {base_dir}")
|
||||
continue
|
||||
|
||||
for chat_id in os.listdir(base_dir):
|
||||
try:
|
||||
chat_ids = os.listdir(base_dir)
|
||||
logger.debug(f"在 {base_dir} 中找到 {len(chat_ids)} 个聊天ID目录进行衰减")
|
||||
except Exception as e:
|
||||
logger.error(f"读取目录失败 {base_dir}: {e}")
|
||||
continue
|
||||
|
||||
for chat_id in chat_ids:
|
||||
file_path = os.path.join(base_dir, chat_id, "expressions.json")
|
||||
if not os.path.exists(file_path):
|
||||
continue
|
||||
@@ -208,14 +353,24 @@ class ExpressionLearner:
|
||||
with open(file_path, "r", encoding="utf-8") as f:
|
||||
expressions = json.load(f)
|
||||
|
||||
if not isinstance(expressions, list):
|
||||
logger.warning(f"表达方式文件格式错误,跳过衰减: {file_path}")
|
||||
continue
|
||||
|
||||
# 应用全局衰减
|
||||
decayed_expressions = self.apply_decay_to_expressions(expressions, current_time)
|
||||
|
||||
# 保存衰减后的结果
|
||||
with open(file_path, "w", encoding="utf-8") as f:
|
||||
json.dump(decayed_expressions, f, ensure_ascii=False, indent=2)
|
||||
|
||||
logger.debug(f"已对 {file_path} 应用衰减,剩余 {len(decayed_expressions)} 个表达方式")
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(f"JSON解析失败,跳过衰减 {file_path}: {e}")
|
||||
except PermissionError as e:
|
||||
logger.error(f"权限不足,无法更新 {file_path}: {e}")
|
||||
except Exception as e:
|
||||
logger.error(f"全局衰减{type}表达方式失败: {e}")
|
||||
logger.error(f"全局衰减{type}表达方式失败 {file_path}: {e}")
|
||||
continue
|
||||
|
||||
learnt_style: Optional[List[Tuple[str, str, str]]] = []
|
||||
@@ -350,6 +505,7 @@ class ExpressionLearner:
|
||||
last_active_time=current_time,
|
||||
chat_id=chat_id,
|
||||
type=type,
|
||||
create_date=current_time, # 手动设置创建日期
|
||||
)
|
||||
# 限制最大数量
|
||||
exprs = list(
|
||||
|
||||
@@ -132,7 +132,8 @@ class ExpressionSelector:
|
||||
"count": expr.count,
|
||||
"last_active_time": expr.last_active_time,
|
||||
"source_id": cid,
|
||||
"type": "style"
|
||||
"type": "style",
|
||||
"create_date": expr.create_date if expr.create_date is not None else expr.last_active_time,
|
||||
} for expr in style_query
|
||||
])
|
||||
grammar_exprs.extend([
|
||||
@@ -142,7 +143,8 @@ class ExpressionSelector:
|
||||
"count": expr.count,
|
||||
"last_active_time": expr.last_active_time,
|
||||
"source_id": cid,
|
||||
"type": "grammar"
|
||||
"type": "grammar",
|
||||
"create_date": expr.create_date if expr.create_date is not None else expr.last_active_time,
|
||||
} for expr in grammar_query
|
||||
])
|
||||
style_num = int(total_num * style_percentage)
|
||||
|
||||
@@ -111,9 +111,9 @@ class HeartFCMessageReceiver:
|
||||
subheartflow: SubHeartflow = await heartflow.get_or_create_subheartflow(chat.stream_id) # type: ignore
|
||||
|
||||
# subheartflow.add_message_to_normal_chat_cache(message, interested_rate, is_mentioned)
|
||||
|
||||
chat_mood = mood_manager.get_mood_by_chat_id(subheartflow.chat_id)
|
||||
asyncio.create_task(chat_mood.update_mood_by_message(message, interested_rate))
|
||||
if global_config.mood.enable_mood:
|
||||
chat_mood = mood_manager.get_mood_by_chat_id(subheartflow.chat_id)
|
||||
asyncio.create_task(chat_mood.update_mood_by_message(message, interested_rate))
|
||||
|
||||
# 3. 日志记录
|
||||
mes_name = chat.group_info.group_name if chat.group_info else "私聊"
|
||||
|
||||
@@ -13,10 +13,9 @@ from src.chat.message_receive.message import MessageRecv, MessageRecvS4U
|
||||
from src.chat.message_receive.storage import MessageStorage
|
||||
from src.chat.heart_flow.heartflow_message_processor import HeartFCMessageReceiver
|
||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||
from src.plugin_system.core import component_registry, events_manager # 导入新插件系统
|
||||
from src.plugin_system.core import component_registry, events_manager, global_announcement_manager
|
||||
from src.plugin_system.base import BaseCommand, EventType
|
||||
from src.mais4u.mais4u_chat.s4u_msg_processor import S4UMessageProcessor
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
|
||||
# 定义日志配置
|
||||
|
||||
@@ -92,8 +91,19 @@ class ChatBot:
|
||||
# 使用新的组件注册中心查找命令
|
||||
command_result = component_registry.find_command_by_text(text)
|
||||
if command_result:
|
||||
command_class, matched_groups, command_info = command_result
|
||||
plugin_name = command_info.plugin_name
|
||||
command_name = command_info.name
|
||||
if (
|
||||
message.chat_stream
|
||||
and message.chat_stream.stream_id
|
||||
and command_name
|
||||
in global_announcement_manager.get_disabled_chat_commands(message.chat_stream.stream_id)
|
||||
):
|
||||
logger.info("用户禁用的命令,跳过处理")
|
||||
return False, None, True
|
||||
|
||||
message.is_command = True
|
||||
command_class, matched_groups, intercept_message, plugin_name = command_result
|
||||
|
||||
# 获取插件配置
|
||||
plugin_config = component_registry.get_plugin_config(plugin_name)
|
||||
@@ -104,7 +114,7 @@ class ChatBot:
|
||||
|
||||
try:
|
||||
# 执行命令
|
||||
success, response = await command_instance.execute()
|
||||
success, response, intercept_message = await command_instance.execute()
|
||||
|
||||
# 记录命令执行结果
|
||||
if success:
|
||||
@@ -117,8 +127,6 @@ class ChatBot:
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"执行命令时出错: {command_class.__name__} - {e}")
|
||||
import traceback
|
||||
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
try:
|
||||
@@ -127,7 +135,7 @@ class ChatBot:
|
||||
logger.error(f"发送错误消息失败: {send_error}")
|
||||
|
||||
# 命令出错时,根据命令的拦截设置决定是否继续处理消息
|
||||
return True, str(e), not intercept_message
|
||||
return True, str(e), False # 出错时继续处理消息
|
||||
|
||||
# 没有找到命令,继续处理消息
|
||||
return False, None, True
|
||||
@@ -135,13 +143,12 @@ class ChatBot:
|
||||
except Exception as e:
|
||||
logger.error(f"处理命令时出错: {e}")
|
||||
return False, None, True # 出错时继续处理消息
|
||||
|
||||
|
||||
async def hanle_notice_message(self, message: MessageRecv):
|
||||
if message.message_info.message_id == "notice":
|
||||
logger.info("收到notice消息,暂时不支持处理")
|
||||
return True
|
||||
|
||||
|
||||
|
||||
async def do_s4u(self, message_data: Dict[str, Any]):
|
||||
message = MessageRecvS4U(message_data)
|
||||
group_info = message.message_info.group_info
|
||||
@@ -163,7 +170,6 @@ class ChatBot:
|
||||
|
||||
return
|
||||
|
||||
|
||||
async def message_process(self, message_data: Dict[str, Any]) -> None:
|
||||
"""处理转化后的统一格式消息
|
||||
这个函数本质是预处理一些数据,根据配置信息和消息内容,预处理消息,并分发到合适的消息处理器中
|
||||
@@ -179,8 +185,6 @@ class ChatBot:
|
||||
- 性能计时
|
||||
"""
|
||||
try:
|
||||
|
||||
|
||||
# 确保所有任务已启动
|
||||
await self._ensure_started()
|
||||
|
||||
@@ -201,11 +205,10 @@ class ChatBot:
|
||||
# print(message_data)
|
||||
# logger.debug(str(message_data))
|
||||
message = MessageRecv(message_data)
|
||||
|
||||
|
||||
if await self.hanle_notice_message(message):
|
||||
return
|
||||
|
||||
|
||||
|
||||
group_info = message.message_info.group_info
|
||||
user_info = message.message_info.user_info
|
||||
if message.message_info.additional_config:
|
||||
@@ -214,9 +217,6 @@ class ChatBot:
|
||||
await MessageStorage.update_message(message)
|
||||
return
|
||||
|
||||
if not await events_manager.handle_mai_events(EventType.ON_MESSAGE, message):
|
||||
return
|
||||
|
||||
get_chat_manager().register_message(message)
|
||||
|
||||
chat = await get_chat_manager().get_or_create_stream(
|
||||
@@ -229,11 +229,10 @@ class ChatBot:
|
||||
|
||||
# 处理消息内容,生成纯文本
|
||||
await message.process()
|
||||
|
||||
|
||||
# if await self.check_ban_content(message):
|
||||
# logger.warning(f"检测到消息中含有违法,色情,暴力,反动,敏感内容,消息内容:{message.processed_plain_text},发送者:{message.message_info.user_info.user_nickname}")
|
||||
# return
|
||||
|
||||
|
||||
# 过滤检查
|
||||
if _check_ban_words(message.processed_plain_text, chat, user_info) or _check_ban_regex( # type: ignore
|
||||
@@ -252,6 +251,9 @@ class ChatBot:
|
||||
logger.info(f"命令处理完成,跳过后续消息处理: {cmd_result}")
|
||||
return
|
||||
|
||||
if not await events_manager.handle_mai_events(EventType.ON_MESSAGE, message):
|
||||
return
|
||||
|
||||
# 确认从接口发来的message是否有自定义的prompt模板信息
|
||||
if message.message_info.template_info and not message.message_info.template_info.template_default:
|
||||
template_group_name: Optional[str] = message.message_info.template_info.template_name # type: ignore
|
||||
|
||||
@@ -163,20 +163,25 @@ class ChatManager:
|
||||
"""注册消息到聊天流"""
|
||||
stream_id = self._generate_stream_id(
|
||||
message.message_info.platform, # type: ignore
|
||||
message.message_info.user_info, # type: ignore
|
||||
message.message_info.user_info,
|
||||
message.message_info.group_info,
|
||||
)
|
||||
self.last_messages[stream_id] = message
|
||||
# logger.debug(f"注册消息到聊天流: {stream_id}")
|
||||
|
||||
@staticmethod
|
||||
def _generate_stream_id(platform: str, user_info: UserInfo, group_info: Optional[GroupInfo] = None) -> str:
|
||||
def _generate_stream_id(
|
||||
platform: str, user_info: Optional[UserInfo], group_info: Optional[GroupInfo] = None
|
||||
) -> str:
|
||||
"""生成聊天流唯一ID"""
|
||||
if not user_info and not group_info:
|
||||
raise ValueError("用户信息或群组信息必须提供")
|
||||
|
||||
if group_info:
|
||||
# 组合关键信息
|
||||
components = [platform, str(group_info.group_id)]
|
||||
else:
|
||||
components = [platform, str(user_info.user_id), "private"]
|
||||
components = [platform, str(user_info.user_id), "private"] # type: ignore
|
||||
|
||||
# 使用MD5生成唯一ID
|
||||
key = "_".join(components)
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import Dict, List, Optional, Type
|
||||
from typing import Dict, Optional, Type
|
||||
from src.plugin_system.base.base_action import BaseAction
|
||||
from src.chat.message_receive.chat_stream import ChatStream
|
||||
from src.common.logger import get_logger
|
||||
@@ -22,53 +22,14 @@ class ActionManager:
|
||||
|
||||
def __init__(self):
|
||||
"""初始化动作管理器"""
|
||||
# 所有注册的动作集合
|
||||
self._registered_actions: Dict[str, ActionInfo] = {}
|
||||
|
||||
# 当前正在使用的动作集合,默认加载默认动作
|
||||
self._using_actions: Dict[str, ActionInfo] = {}
|
||||
|
||||
# 加载插件动作
|
||||
self._load_plugin_actions()
|
||||
|
||||
# 初始化时将默认动作加载到使用中的动作
|
||||
self._using_actions = component_registry.get_default_actions()
|
||||
|
||||
def _load_plugin_actions(self) -> None:
|
||||
"""
|
||||
加载所有插件系统中的动作
|
||||
"""
|
||||
try:
|
||||
# 从新插件系统获取Action组件
|
||||
self._load_plugin_system_actions()
|
||||
logger.debug("从插件系统加载Action组件成功")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"加载插件动作失败: {e}")
|
||||
|
||||
def _load_plugin_system_actions(self) -> None:
|
||||
"""从插件系统的component_registry加载Action组件"""
|
||||
try:
|
||||
# 获取所有Action组件
|
||||
action_components: Dict[str, ActionInfo] = component_registry.get_components_by_type(ComponentType.ACTION) # type: ignore
|
||||
|
||||
for action_name, action_info in action_components.items():
|
||||
if action_name in self._registered_actions:
|
||||
logger.debug(f"Action组件 {action_name} 已存在,跳过")
|
||||
continue
|
||||
|
||||
self._registered_actions[action_name] = action_info
|
||||
|
||||
logger.debug(
|
||||
f"从插件系统加载Action组件: {action_name} (插件: {getattr(action_info, 'plugin_name', 'unknown')})"
|
||||
)
|
||||
|
||||
logger.info(f"加载了 {len(action_components)} 个Action动作")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"从插件系统加载Action组件失败: {e}")
|
||||
import traceback
|
||||
|
||||
logger.error(traceback.format_exc())
|
||||
# === 执行Action方法 ===
|
||||
|
||||
def create_action(
|
||||
self,
|
||||
@@ -139,36 +100,11 @@ class ActionManager:
|
||||
logger.error(traceback.format_exc())
|
||||
return None
|
||||
|
||||
def get_registered_actions(self) -> Dict[str, ActionInfo]:
|
||||
"""获取所有已注册的动作集"""
|
||||
return self._registered_actions.copy()
|
||||
|
||||
def get_using_actions(self) -> Dict[str, ActionInfo]:
|
||||
"""获取当前正在使用的动作集合"""
|
||||
return self._using_actions.copy()
|
||||
|
||||
def add_action_to_using(self, action_name: str) -> bool:
|
||||
"""
|
||||
添加已注册的动作到当前使用的动作集
|
||||
|
||||
Args:
|
||||
action_name: 动作名称
|
||||
|
||||
Returns:
|
||||
bool: 添加是否成功
|
||||
"""
|
||||
if action_name not in self._registered_actions:
|
||||
logger.warning(f"添加失败: 动作 {action_name} 未注册")
|
||||
return False
|
||||
|
||||
if action_name in self._using_actions:
|
||||
logger.info(f"动作 {action_name} 已经在使用中")
|
||||
return True
|
||||
|
||||
self._using_actions[action_name] = self._registered_actions[action_name]
|
||||
logger.info(f"添加动作 {action_name} 到使用集")
|
||||
return True
|
||||
|
||||
# === Modify相关方法 ===
|
||||
def remove_action_from_using(self, action_name: str) -> bool:
|
||||
"""
|
||||
从当前使用的动作集中移除指定动作
|
||||
@@ -187,79 +123,8 @@ class ActionManager:
|
||||
logger.debug(f"已从使用集中移除动作 {action_name}")
|
||||
return True
|
||||
|
||||
# def add_action(self, action_name: str, description: str, parameters: Dict = None, require: List = None) -> bool:
|
||||
# """
|
||||
# 添加新的动作到注册集
|
||||
|
||||
# Args:
|
||||
# action_name: 动作名称
|
||||
# description: 动作描述
|
||||
# parameters: 动作参数定义,默认为空字典
|
||||
# require: 动作依赖项,默认为空列表
|
||||
|
||||
# Returns:
|
||||
# bool: 添加是否成功
|
||||
# """
|
||||
# if action_name in self._registered_actions:
|
||||
# return False
|
||||
|
||||
# if parameters is None:
|
||||
# parameters = {}
|
||||
# if require is None:
|
||||
# require = []
|
||||
|
||||
# action_info = {"description": description, "parameters": parameters, "require": require}
|
||||
|
||||
# self._registered_actions[action_name] = action_info
|
||||
# return True
|
||||
|
||||
def remove_action(self, action_name: str) -> bool:
|
||||
"""从注册集移除指定动作"""
|
||||
if action_name not in self._registered_actions:
|
||||
return False
|
||||
del self._registered_actions[action_name]
|
||||
# 如果在使用集中也存在,一并移除
|
||||
if action_name in self._using_actions:
|
||||
del self._using_actions[action_name]
|
||||
return True
|
||||
|
||||
def temporarily_remove_actions(self, actions_to_remove: List[str]) -> None:
|
||||
"""临时移除使用集中的指定动作"""
|
||||
for name in actions_to_remove:
|
||||
self._using_actions.pop(name, None)
|
||||
|
||||
def restore_actions(self) -> None:
|
||||
"""恢复到默认动作集"""
|
||||
actions_to_restore = list(self._using_actions.keys())
|
||||
self._using_actions = component_registry.get_default_actions()
|
||||
logger.debug(f"恢复动作集: 从 {actions_to_restore} 恢复到默认动作集 {list(self._using_actions.keys())}")
|
||||
|
||||
def add_system_action_if_needed(self, action_name: str) -> bool:
|
||||
"""
|
||||
根据需要添加系统动作到使用集
|
||||
|
||||
Args:
|
||||
action_name: 动作名称
|
||||
|
||||
Returns:
|
||||
bool: 是否成功添加
|
||||
"""
|
||||
if action_name in self._registered_actions and action_name not in self._using_actions:
|
||||
self._using_actions[action_name] = self._registered_actions[action_name]
|
||||
logger.info(f"临时添加系统动作到使用集: {action_name}")
|
||||
return True
|
||||
return False
|
||||
|
||||
def get_action(self, action_name: str) -> Optional[Type[BaseAction]]:
|
||||
"""
|
||||
获取指定动作的处理器类
|
||||
|
||||
Args:
|
||||
action_name: 动作名称
|
||||
|
||||
Returns:
|
||||
Optional[Type[BaseAction]]: 动作处理器类,如果不存在则返回None
|
||||
"""
|
||||
from src.plugin_system.core.component_registry import component_registry
|
||||
|
||||
return component_registry.get_component_class(action_name, ComponentType.ACTION) # type: ignore
|
||||
|
||||
@@ -2,7 +2,7 @@ import random
|
||||
import asyncio
|
||||
import hashlib
|
||||
import time
|
||||
from typing import List, Any, Dict, TYPE_CHECKING
|
||||
from typing import List, Any, Dict, TYPE_CHECKING, Tuple
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
@@ -11,6 +11,7 @@ from src.chat.message_receive.chat_stream import get_chat_manager, ChatMessageCo
|
||||
from src.chat.planner_actions.action_manager import ActionManager
|
||||
from src.chat.utils.chat_message_builder import get_raw_msg_before_timestamp_with_chat, build_readable_messages
|
||||
from src.plugin_system.base.component_types import ActionInfo, ActionActivationType
|
||||
from src.plugin_system.core.global_announcement_manager import global_announcement_manager
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from src.chat.message_receive.chat_stream import ChatStream
|
||||
@@ -47,7 +48,6 @@ class ActionModifier:
|
||||
|
||||
async def modify_actions(
|
||||
self,
|
||||
history_loop=None,
|
||||
message_content: str = "",
|
||||
): # sourcery skip: use-named-expression
|
||||
"""
|
||||
@@ -61,8 +61,9 @@ class ActionModifier:
|
||||
"""
|
||||
logger.debug(f"{self.log_prefix}开始完整动作修改流程")
|
||||
|
||||
removals_s1 = []
|
||||
removals_s2 = []
|
||||
removals_s1: List[Tuple[str, str]] = []
|
||||
removals_s2: List[Tuple[str, str]] = []
|
||||
removals_s3: List[Tuple[str, str]] = []
|
||||
|
||||
self.action_manager.restore_actions()
|
||||
all_actions = self.action_manager.get_using_actions()
|
||||
@@ -84,25 +85,28 @@ class ActionModifier:
|
||||
if message_content:
|
||||
chat_content = chat_content + "\n" + f"现在,最新的消息是:{message_content}"
|
||||
|
||||
# === 第一阶段:传统观察处理 ===
|
||||
# if history_loop:
|
||||
# removals_from_loop = await self.analyze_loop_actions(history_loop)
|
||||
# if removals_from_loop:
|
||||
# removals_s1.extend(removals_from_loop)
|
||||
# === 第一阶段:去除用户自行禁用的 ===
|
||||
disabled_actions = global_announcement_manager.get_disabled_chat_actions(self.chat_id)
|
||||
if disabled_actions:
|
||||
for disabled_action_name in disabled_actions:
|
||||
if disabled_action_name in all_actions:
|
||||
removals_s1.append((disabled_action_name, "用户自行禁用"))
|
||||
self.action_manager.remove_action_from_using(disabled_action_name)
|
||||
logger.debug(f"{self.log_prefix}阶段一移除动作: {disabled_action_name},原因: 用户自行禁用")
|
||||
|
||||
# 检查动作的关联类型
|
||||
# === 第二阶段:检查动作的关联类型 ===
|
||||
chat_context = self.chat_stream.context
|
||||
type_mismatched_actions = self._check_action_associated_types(all_actions, chat_context)
|
||||
|
||||
if type_mismatched_actions:
|
||||
removals_s1.extend(type_mismatched_actions)
|
||||
removals_s2.extend(type_mismatched_actions)
|
||||
|
||||
# 应用第一阶段的移除
|
||||
for action_name, reason in removals_s1:
|
||||
# 应用第二阶段的移除
|
||||
for action_name, reason in removals_s2:
|
||||
self.action_manager.remove_action_from_using(action_name)
|
||||
logger.debug(f"{self.log_prefix}阶段一移除动作: {action_name},原因: {reason}")
|
||||
logger.debug(f"{self.log_prefix}阶段二移除动作: {action_name},原因: {reason}")
|
||||
|
||||
# === 第二阶段:激活类型判定 ===
|
||||
# === 第三阶段:激活类型判定 ===
|
||||
if chat_content is not None:
|
||||
logger.debug(f"{self.log_prefix}开始激活类型判定阶段")
|
||||
|
||||
@@ -110,18 +114,18 @@ class ActionModifier:
|
||||
current_using_actions = self.action_manager.get_using_actions()
|
||||
|
||||
# 获取因激活类型判定而需要移除的动作
|
||||
removals_s2 = await self._get_deactivated_actions_by_type(
|
||||
removals_s3 = await self._get_deactivated_actions_by_type(
|
||||
current_using_actions,
|
||||
chat_content,
|
||||
)
|
||||
|
||||
# 应用第二阶段的移除
|
||||
for action_name, reason in removals_s2:
|
||||
# 应用第三阶段的移除
|
||||
for action_name, reason in removals_s3:
|
||||
self.action_manager.remove_action_from_using(action_name)
|
||||
logger.debug(f"{self.log_prefix}阶段二移除动作: {action_name},原因: {reason}")
|
||||
logger.debug(f"{self.log_prefix}阶段三移除动作: {action_name},原因: {reason}")
|
||||
|
||||
# === 统一日志记录 ===
|
||||
all_removals = removals_s1 + removals_s2
|
||||
all_removals = removals_s1 + removals_s2 + removals_s3
|
||||
removals_summary: str = ""
|
||||
if all_removals:
|
||||
removals_summary = " | ".join([f"{name}({reason})" for name, reason in all_removals])
|
||||
@@ -131,7 +135,7 @@ class ActionModifier:
|
||||
)
|
||||
|
||||
def _check_action_associated_types(self, all_actions: Dict[str, ActionInfo], chat_context: ChatMessageContext):
|
||||
type_mismatched_actions = []
|
||||
type_mismatched_actions: List[Tuple[str, str]] = []
|
||||
for action_name, action_info in all_actions.items():
|
||||
if action_info.associated_types and not chat_context.check_types(action_info.associated_types):
|
||||
associated_types_str = ", ".join(action_info.associated_types)
|
||||
@@ -318,7 +322,7 @@ class ActionModifier:
|
||||
action_name: str,
|
||||
action_info: ActionInfo,
|
||||
chat_content: str = "",
|
||||
) -> bool:
|
||||
) -> bool: # sourcery skip: move-assign-in-block, use-named-expression
|
||||
"""
|
||||
使用LLM判定是否应该激活某个action
|
||||
|
||||
|
||||
@@ -19,8 +19,8 @@ from src.chat.utils.chat_message_builder import (
|
||||
from src.chat.utils.utils import get_chat_type_and_target_info
|
||||
from src.chat.planner_actions.action_manager import ActionManager
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from src.plugin_system.base.component_types import ActionInfo, ChatMode
|
||||
|
||||
from src.plugin_system.base.component_types import ActionInfo, ChatMode, ComponentType
|
||||
from src.plugin_system.core.component_registry import component_registry
|
||||
|
||||
logger = get_logger("planner")
|
||||
|
||||
@@ -99,7 +99,7 @@ class ActionPlanner:
|
||||
|
||||
async def plan(
|
||||
self, mode: ChatMode = ChatMode.FOCUS
|
||||
) -> Tuple[Dict[str, Dict[str, Any] | str], Optional[Dict[str, Any]]]: # sourcery skip: dict-comprehension
|
||||
) -> Tuple[Dict[str, Dict[str, Any] | str], Optional[Dict[str, Any]]]:
|
||||
"""
|
||||
规划器 (Planner): 使用LLM根据上下文决定做出什么动作。
|
||||
"""
|
||||
@@ -113,16 +113,17 @@ class ActionPlanner:
|
||||
|
||||
try:
|
||||
is_group_chat = True
|
||||
|
||||
is_group_chat, chat_target_info = get_chat_type_and_target_info(self.chat_id)
|
||||
logger.debug(f"{self.log_prefix}获取到聊天信息 - 群聊: {is_group_chat}, 目标信息: {chat_target_info}")
|
||||
|
||||
current_available_actions_dict = self.action_manager.get_using_actions()
|
||||
|
||||
# 获取完整的动作信息
|
||||
all_registered_actions = self.action_manager.get_registered_actions()
|
||||
|
||||
for action_name in current_available_actions_dict.keys():
|
||||
all_registered_actions: Dict[str, ActionInfo] = component_registry.get_components_by_type( # type: ignore
|
||||
ComponentType.ACTION
|
||||
)
|
||||
current_available_actions = {}
|
||||
for action_name in current_available_actions_dict:
|
||||
if action_name in all_registered_actions:
|
||||
current_available_actions[action_name] = all_registered_actions[action_name]
|
||||
else:
|
||||
@@ -234,10 +235,13 @@ class ActionPlanner:
|
||||
"is_parallel": is_parallel,
|
||||
}
|
||||
|
||||
return {
|
||||
"action_result": action_result,
|
||||
"action_prompt": prompt,
|
||||
}, target_message
|
||||
return (
|
||||
{
|
||||
"action_result": action_result,
|
||||
"action_prompt": prompt,
|
||||
},
|
||||
target_message,
|
||||
)
|
||||
|
||||
async def build_planner_prompt(
|
||||
self,
|
||||
@@ -275,23 +279,29 @@ class ActionPlanner:
|
||||
self.last_obs_time_mark = time.time()
|
||||
|
||||
if mode == ChatMode.FOCUS:
|
||||
mentioned_bonus = ""
|
||||
if global_config.chat.mentioned_bot_inevitable_reply:
|
||||
mentioned_bonus = "\n- 有人提到你"
|
||||
if global_config.chat.at_bot_inevitable_reply:
|
||||
mentioned_bonus = "\n- 有人提到你,或者at你"
|
||||
|
||||
|
||||
by_what = "聊天内容"
|
||||
target_prompt = '\n "target_message_id":"触发action的消息id"'
|
||||
no_action_block = """重要说明1:
|
||||
no_action_block = f"""重要说明1:
|
||||
- 'no_reply' 表示只进行不进行回复,等待合适的回复时机
|
||||
- 当你刚刚发送了消息,没有人回复时,选择no_reply
|
||||
- 当你一次发送了太多消息,为了避免打扰聊天节奏,选择no_reply
|
||||
|
||||
动作:reply
|
||||
动作描述:参与聊天回复,发送文本进行表达
|
||||
- 你想要闲聊或者随便附和
|
||||
- 有人提到你
|
||||
- 你想要闲聊或者随便附和{mentioned_bonus}
|
||||
- 如果你刚刚进行了回复,不要对同一个话题重复回应
|
||||
{
|
||||
{{
|
||||
"action": "reply",
|
||||
"target_message_id":"触发action的消息id",
|
||||
"reason":"回复的原因"
|
||||
}
|
||||
}}
|
||||
|
||||
"""
|
||||
else:
|
||||
|
||||
@@ -6,7 +6,7 @@ import re
|
||||
|
||||
from typing import List, Optional, Dict, Any, Tuple
|
||||
from datetime import datetime
|
||||
from src.chat.mai_thinking.mai_think import mai_thinking_manager
|
||||
from src.mais4u.mai_think import mai_thinking_manager
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.individuality.individuality import get_individuality
|
||||
@@ -30,9 +30,6 @@ from src.plugin_system.base.component_types import ActionInfo
|
||||
|
||||
logger = get_logger("replyer")
|
||||
|
||||
ENABLE_S2S_MODE = True
|
||||
|
||||
|
||||
def init_prompt():
|
||||
Prompt("你正在qq群里聊天,下面是群里在聊的内容:", "chat_target_group1")
|
||||
Prompt("你正在和{sender_name}聊天,这是你们之前聊的内容:", "chat_target_private1")
|
||||
@@ -60,7 +57,6 @@ def init_prompt():
|
||||
现在请你读读之前的聊天记录,并给出回复
|
||||
{config_expression_style}。注意不要复读你说过的话
|
||||
{keywords_reaction_prompt}
|
||||
请注意不要输出多余内容(包括前后缀,冒号和引号,at或 @等 )。只输出回复内容。
|
||||
{moderation_prompt}
|
||||
不要浮夸,不要夸张修辞,不要输出多余内容(包括前后缀,冒号和引号,括号(),表情包,at或 @等 )。只输出回复内容""",
|
||||
"default_generator_prompt",
|
||||
@@ -78,6 +74,7 @@ def init_prompt():
|
||||
|
||||
你正在{chat_target_2},{reply_target_block}
|
||||
对这句话,你想表达,原句:{raw_reply},原因是:{reason}。你现在要思考怎么组织回复
|
||||
你现在的心情是:{mood_state}
|
||||
你需要使用合适的语法和句法,参考聊天内容,组织一条日常且口语化的回复。请你修改你想表达的原句,符合你的表达风格和语言习惯
|
||||
{config_expression_style},你可以完全重组回复,保留最基本的表达含义就好,但重组后保持语意通顺。
|
||||
{keywords_reaction_prompt}
|
||||
@@ -98,29 +95,29 @@ def init_prompt():
|
||||
{relation_info_block}
|
||||
{extra_info_block}
|
||||
|
||||
你是一个AI虚拟主播,正在直播QQ聊天,同时也在直播间回复弹幕,不过回复的时候不用过多提及这点
|
||||
|
||||
{identity}
|
||||
|
||||
{action_descriptions}
|
||||
你现在的主要任务是和 {sender_name} 聊天。同时,也有其他用户会参与你们的聊天,你可以参考他们的回复内容,但是你主要还是关注你和{sender_name}的聊天内容。你现在的心情是:{mood_state}
|
||||
你现在的主要任务是和 {sender_name} 聊天。同时,也有其他用户会参与你们的聊天,你可以参考他们的回复内容,但是你主要还是关注你和{sender_name}的聊天内容。
|
||||
|
||||
{background_dialogue_prompt}
|
||||
--------------------------------
|
||||
{time_block}
|
||||
这是你和{sender_name}的对话,你们正在交流中:
|
||||
|
||||
{core_dialogue_prompt}
|
||||
|
||||
{reply_target_block}
|
||||
对方最新发送的内容:{message_txt}
|
||||
回复可以简短一些。可以参考贴吧,知乎和微博的回复风格,回复不要浮夸,不要用夸张修辞,平淡一些。
|
||||
{config_expression_style}。注意不要复读你说过的话
|
||||
你现在的心情是:{mood_state}
|
||||
{config_expression_style}
|
||||
注意不要复读你说过的话
|
||||
{keywords_reaction_prompt}
|
||||
请注意不要输出多余内容(包括前后缀,冒号和引号,at或 @等 )。只输出回复内容。
|
||||
{moderation_prompt}
|
||||
不要浮夸,不要夸张修辞,不要输出多余内容(包括前后缀,冒号和引号,括号(),表情包,at或 @等 )。只输出回复内容,现在{sender_name}正在等待你的回复。
|
||||
你的回复风格不要浮夸,有逻辑和条理,请你继续回复{sender_name}。
|
||||
你的发言:
|
||||
不要浮夸,不要夸张修辞,不要输出多余内容(包括前后缀,冒号和引号,括号(),表情包,at或 @等 )。只输出一条回复内容就好
|
||||
现在,你说:
|
||||
""",
|
||||
"s4u_style_prompt",
|
||||
)
|
||||
@@ -133,7 +130,6 @@ class DefaultReplyer:
|
||||
model_configs: Optional[List[Dict[str, Any]]] = None,
|
||||
request_type: str = "focus.replyer",
|
||||
):
|
||||
self.log_prefix = "replyer"
|
||||
self.request_type = request_type
|
||||
|
||||
if model_configs:
|
||||
@@ -197,7 +193,7 @@ class DefaultReplyer:
|
||||
}
|
||||
for key, value in reply_data.items():
|
||||
if not value:
|
||||
logger.debug(f"{self.log_prefix} 回复数据跳过{key},生成回复时将忽略。")
|
||||
logger.debug(f"回复数据跳过{key},生成回复时将忽略。")
|
||||
|
||||
# 3. 构建 Prompt
|
||||
with Timer("构建Prompt", {}): # 内部计时器,可选保留
|
||||
@@ -218,7 +214,7 @@ class DefaultReplyer:
|
||||
# 加权随机选择一个模型配置
|
||||
selected_model_config = self._select_weighted_model_config()
|
||||
logger.info(
|
||||
f"{self.log_prefix} 使用模型配置: {selected_model_config.get('name', 'N/A')} (权重: {selected_model_config.get('weight', 1.0)})"
|
||||
f"使用模型生成回复: {selected_model_config.get('name', 'N/A')} (选中概率: {selected_model_config.get('weight', 1.0)})"
|
||||
)
|
||||
|
||||
express_model = LLMRequest(
|
||||
@@ -227,9 +223,9 @@ class DefaultReplyer:
|
||||
)
|
||||
|
||||
if global_config.debug.show_prompt:
|
||||
logger.info(f"{self.log_prefix}\n{prompt}\n")
|
||||
logger.info(f"\n{prompt}\n")
|
||||
else:
|
||||
logger.debug(f"{self.log_prefix}\n{prompt}\n")
|
||||
logger.debug(f"\n{prompt}\n")
|
||||
|
||||
content, (reasoning_content, model_name) = await express_model.generate_response_async(prompt)
|
||||
|
||||
@@ -237,13 +233,13 @@ class DefaultReplyer:
|
||||
|
||||
except Exception as llm_e:
|
||||
# 精简报错信息
|
||||
logger.error(f"{self.log_prefix}LLM 生成失败: {llm_e}")
|
||||
logger.error(f"LLM 生成失败: {llm_e}")
|
||||
return False, None, prompt # LLM 调用失败则无法生成回复
|
||||
|
||||
return True, content, prompt
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix}回复生成意外失败: {e}")
|
||||
logger.error(f"回复生成意外失败: {e}")
|
||||
traceback.print_exc()
|
||||
return False, None, prompt
|
||||
|
||||
@@ -274,7 +270,7 @@ class DefaultReplyer:
|
||||
reasoning_content = None
|
||||
model_name = "unknown_model"
|
||||
if not prompt:
|
||||
logger.error(f"{self.log_prefix}Prompt 构建失败,无法生成回复。")
|
||||
logger.error("Prompt 构建失败,无法生成回复。")
|
||||
return False, None
|
||||
|
||||
try:
|
||||
@@ -282,7 +278,7 @@ class DefaultReplyer:
|
||||
# 加权随机选择一个模型配置
|
||||
selected_model_config = self._select_weighted_model_config()
|
||||
logger.info(
|
||||
f"{self.log_prefix} 使用模型配置进行重写: {selected_model_config.get('name', 'N/A')} (权重: {selected_model_config.get('weight', 1.0)})"
|
||||
f"使用模型重写回复: {selected_model_config.get('name', 'N/A')} (选中概率: {selected_model_config.get('weight', 1.0)})"
|
||||
)
|
||||
|
||||
express_model = LLMRequest(
|
||||
@@ -296,13 +292,13 @@ class DefaultReplyer:
|
||||
|
||||
except Exception as llm_e:
|
||||
# 精简报错信息
|
||||
logger.error(f"{self.log_prefix}LLM 生成失败: {llm_e}")
|
||||
logger.error(f"LLM 生成失败: {llm_e}")
|
||||
return False, None # LLM 调用失败则无法生成回复
|
||||
|
||||
return True, content
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix}回复生成意外失败: {e}")
|
||||
logger.error(f"回复生成意外失败: {e}")
|
||||
traceback.print_exc()
|
||||
return False, None
|
||||
|
||||
@@ -322,7 +318,7 @@ class DefaultReplyer:
|
||||
person_info_manager = get_person_info_manager()
|
||||
person_id = person_info_manager.get_person_id_by_person_name(sender)
|
||||
if not person_id:
|
||||
logger.warning(f"{self.log_prefix} 未找到用户 {sender} 的ID,跳过信息提取")
|
||||
logger.warning(f"未找到用户 {sender} 的ID,跳过信息提取")
|
||||
return f"你完全不认识{sender},不理解ta的相关信息。"
|
||||
|
||||
return await relationship_fetcher.build_relation_info(person_id, points_num=5)
|
||||
@@ -341,7 +337,7 @@ class DefaultReplyer:
|
||||
)
|
||||
|
||||
if selected_expressions:
|
||||
logger.debug(f"{self.log_prefix} 使用处理器选中的{len(selected_expressions)}个表达方式")
|
||||
logger.debug(f"使用处理器选中的{len(selected_expressions)}个表达方式")
|
||||
for expr in selected_expressions:
|
||||
if isinstance(expr, dict) and "situation" in expr and "style" in expr:
|
||||
expr_type = expr.get("type", "style")
|
||||
@@ -350,7 +346,7 @@ class DefaultReplyer:
|
||||
else:
|
||||
style_habits.append(f"当{expr['situation']}时,使用 {expr['style']}")
|
||||
else:
|
||||
logger.debug(f"{self.log_prefix} 没有从处理器获得表达方式,将使用空的表达方式")
|
||||
logger.debug("没有从处理器获得表达方式,将使用空的表达方式")
|
||||
# 不再在replyer中进行随机选择,全部交给处理器处理
|
||||
|
||||
style_habits_str = "\n".join(style_habits)
|
||||
@@ -358,10 +354,19 @@ class DefaultReplyer:
|
||||
|
||||
# 动态构建expression habits块
|
||||
expression_habits_block = ""
|
||||
expression_habits_title = ""
|
||||
if style_habits_str.strip():
|
||||
expression_habits_block += f"你可以参考以下的语言习惯,如果情景合适就使用,不要盲目使用,不要生硬使用,而是结合到表达中:\n{style_habits_str}\n\n"
|
||||
expression_habits_title = "你可以参考以下的语言习惯,当情景合适就使用,但不要生硬使用,以合理的方式结合到你的回复中:"
|
||||
expression_habits_block += f"{style_habits_str}\n"
|
||||
if grammar_habits_str.strip():
|
||||
expression_habits_block += f"请你根据情景使用以下句法:\n{grammar_habits_str}\n"
|
||||
expression_habits_title = "你可以选择下面的句法进行回复,如果情景合适就使用,不要盲目使用,不要生硬使用,以合理的方式使用:"
|
||||
expression_habits_block += f"{grammar_habits_str}\n"
|
||||
|
||||
if style_habits_str.strip() and grammar_habits_str.strip():
|
||||
expression_habits_title = "你可以参考以下的语言习惯和句法,如果情景合适就使用,不要盲目使用,不要生硬使用,以合理的方式结合到你的回复中:"
|
||||
|
||||
expression_habits_block = f"{expression_habits_title}\n{expression_habits_block}"
|
||||
|
||||
|
||||
return expression_habits_block
|
||||
|
||||
@@ -432,19 +437,23 @@ class DefaultReplyer:
|
||||
tool_info_str += f"- 【{tool_name}】{result_type}: {content}\n"
|
||||
|
||||
tool_info_str += "以上是你获取到的实时信息,请在回复时参考这些信息。"
|
||||
logger.info(f"{self.log_prefix} 获取到 {len(tool_results)} 个工具结果")
|
||||
logger.info(f"获取到 {len(tool_results)} 个工具结果")
|
||||
|
||||
return tool_info_str
|
||||
else:
|
||||
logger.debug(f"{self.log_prefix} 未获取到任何工具结果")
|
||||
logger.debug("未获取到任何工具结果")
|
||||
return ""
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 工具信息获取失败: {e}")
|
||||
logger.error(f"工具信息获取失败: {e}")
|
||||
return ""
|
||||
|
||||
def _parse_reply_target(self, target_message: str) -> tuple:
|
||||
sender = ""
|
||||
target = ""
|
||||
# 添加None检查,防止NoneType错误
|
||||
if target_message is None:
|
||||
return sender, target
|
||||
if ":" in target_message or ":" in target_message:
|
||||
# 使用正则表达式匹配中文或英文冒号
|
||||
parts = re.split(pattern=r"[::]", string=target_message, maxsplit=1)
|
||||
@@ -457,6 +466,10 @@ class DefaultReplyer:
|
||||
# 关键词检测与反应
|
||||
keywords_reaction_prompt = ""
|
||||
try:
|
||||
# 添加None检查,防止NoneType错误
|
||||
if target is None:
|
||||
return keywords_reaction_prompt
|
||||
|
||||
# 处理关键词规则
|
||||
for rule in global_config.keyword_reaction.keyword_rules:
|
||||
if any(keyword in target for keyword in rule.keywords):
|
||||
@@ -510,19 +523,21 @@ class DefaultReplyer:
|
||||
for msg_dict in message_list_before_now:
|
||||
try:
|
||||
msg_user_id = str(msg_dict.get("user_id"))
|
||||
if msg_user_id == bot_id or msg_user_id == target_user_id:
|
||||
reply_to = msg_dict.get("reply_to", "")
|
||||
_platform, reply_to_user_id = self._parse_reply_target(reply_to)
|
||||
if (msg_user_id == bot_id and reply_to_user_id == target_user_id) or msg_user_id == target_user_id:
|
||||
# bot 和目标用户的对话
|
||||
core_dialogue_list.append(msg_dict)
|
||||
else:
|
||||
# 其他用户的对话
|
||||
background_dialogue_list.append(msg_dict)
|
||||
except Exception as e:
|
||||
logger.error(f"无法处理历史消息记录: {msg_dict}, 错误: {e}")
|
||||
logger.error(f"记录: {msg_dict}, 错误: {e}")
|
||||
|
||||
# 构建背景对话 prompt
|
||||
background_dialogue_prompt = ""
|
||||
if background_dialogue_list:
|
||||
latest_25_msgs = background_dialogue_list[-int(global_config.chat.max_context_size * 0.6) :]
|
||||
latest_25_msgs = background_dialogue_list[-int(global_config.chat.max_context_size * 0.5) :]
|
||||
background_dialogue_prompt_str = build_readable_messages(
|
||||
latest_25_msgs,
|
||||
replace_bot_name=True,
|
||||
@@ -549,6 +564,34 @@ class DefaultReplyer:
|
||||
|
||||
return core_dialogue_prompt, background_dialogue_prompt
|
||||
|
||||
def build_mai_think_context(
|
||||
self,
|
||||
chat_id: str,
|
||||
memory_block: str,
|
||||
relation_info: str,
|
||||
time_block: str,
|
||||
chat_target_1: str,
|
||||
chat_target_2: str,
|
||||
mood_prompt: str,
|
||||
identity_block: str,
|
||||
sender: str,
|
||||
target: str,
|
||||
chat_info: str,
|
||||
):
|
||||
"""构建 mai_think 上下文信息"""
|
||||
mai_think = mai_thinking_manager.get_mai_think(chat_id)
|
||||
mai_think.memory_block = memory_block
|
||||
mai_think.relation_info_block = relation_info
|
||||
mai_think.time_block = time_block
|
||||
mai_think.chat_target = chat_target_1
|
||||
mai_think.chat_target_2 = chat_target_2
|
||||
mai_think.chat_info = chat_info
|
||||
mai_think.mood_state = mood_prompt
|
||||
mai_think.identity = identity_block
|
||||
mai_think.sender = sender
|
||||
mai_think.target = target
|
||||
return mai_think
|
||||
|
||||
async def build_prompt_reply_context(
|
||||
self,
|
||||
reply_data: Dict[str, Any],
|
||||
@@ -578,9 +621,12 @@ class DefaultReplyer:
|
||||
is_group_chat = bool(chat_stream.group_info)
|
||||
reply_to = reply_data.get("reply_to", "none")
|
||||
extra_info_block = reply_data.get("extra_info", "") or reply_data.get("extra_info_block", "")
|
||||
|
||||
chat_mood = mood_manager.get_mood_by_chat_id(chat_id)
|
||||
mood_prompt = chat_mood.mood_state
|
||||
|
||||
if global_config.mood.enable_mood:
|
||||
chat_mood = mood_manager.get_mood_by_chat_id(chat_id)
|
||||
mood_prompt = chat_mood.mood_state
|
||||
else:
|
||||
mood_prompt = ""
|
||||
|
||||
sender, target = self._parse_reply_target(reply_to)
|
||||
|
||||
@@ -628,44 +674,51 @@ class DefaultReplyer:
|
||||
show_actions=True,
|
||||
)
|
||||
|
||||
# 并行执行四个构建任务
|
||||
# 并行执行五个构建任务
|
||||
task_results = await asyncio.gather(
|
||||
self._time_and_run_task(
|
||||
self.build_expression_habits(chat_talking_prompt_short, target), "build_expression_habits"
|
||||
self.build_expression_habits(chat_talking_prompt_short, target), "expression_habits"
|
||||
),
|
||||
self._time_and_run_task(
|
||||
self.build_relation_info(reply_data), "build_relation_info"
|
||||
self.build_relation_info(reply_data), "relation_info"
|
||||
),
|
||||
self._time_and_run_task(self.build_memory_block(chat_talking_prompt_short, target), "build_memory_block"),
|
||||
self._time_and_run_task(self.build_memory_block(chat_talking_prompt_short, target), "memory_block"),
|
||||
self._time_and_run_task(
|
||||
self.build_tool_info(chat_talking_prompt_short, reply_data, enable_tool=enable_tool), "build_tool_info"
|
||||
self.build_tool_info(chat_talking_prompt_short, reply_data, enable_tool=enable_tool), "tool_info"
|
||||
),
|
||||
self._time_and_run_task(
|
||||
get_prompt_info(target, threshold=0.38), "prompt_info"
|
||||
),
|
||||
)
|
||||
|
||||
# 任务名称中英文映射
|
||||
task_name_mapping = {
|
||||
"expression_habits": "选取表达方式",
|
||||
"relation_info": "感受关系",
|
||||
"memory_block": "回忆",
|
||||
"tool_info": "使用工具",
|
||||
"prompt_info": "获取知识"
|
||||
}
|
||||
|
||||
# 处理结果
|
||||
timing_logs = []
|
||||
results_dict = {}
|
||||
for name, result, duration in task_results:
|
||||
results_dict[name] = result
|
||||
timing_logs.append(f"{name}: {duration:.4f}s")
|
||||
chinese_name = task_name_mapping.get(name, name)
|
||||
timing_logs.append(f"{chinese_name}: {duration:.1f}s")
|
||||
if duration > 8:
|
||||
logger.warning(f"回复生成前信息获取耗时过长: {name} 耗时: {duration:.4f}s,请使用更快的模型")
|
||||
logger.info(f"回复生成前信息获取耗时: {'; '.join(timing_logs)}")
|
||||
logger.warning(f"回复生成前信息获取耗时过长: {chinese_name} 耗时: {duration:.1f}s,请使用更快的模型")
|
||||
logger.info(f"在回复前的步骤耗时: {'; '.join(timing_logs)}")
|
||||
|
||||
expression_habits_block = results_dict["build_expression_habits"]
|
||||
relation_info = results_dict["build_relation_info"]
|
||||
memory_block = results_dict["build_memory_block"]
|
||||
tool_info = results_dict["build_tool_info"]
|
||||
expression_habits_block = results_dict["expression_habits"]
|
||||
relation_info = results_dict["relation_info"]
|
||||
memory_block = results_dict["memory_block"]
|
||||
tool_info = results_dict["tool_info"]
|
||||
prompt_info = results_dict["prompt_info"] # 直接使用格式化后的结果
|
||||
|
||||
keywords_reaction_prompt = await self.build_keywords_reaction_prompt(target)
|
||||
|
||||
if tool_info:
|
||||
tool_info_block = (
|
||||
f"以下是你了解的额外信息信息,现在请你阅读以下内容,进行决策\n{tool_info}\n以上是一些额外的信息。"
|
||||
)
|
||||
else:
|
||||
tool_info_block = ""
|
||||
|
||||
if extra_info_block:
|
||||
extra_info_block = f"以下是你在回复时需要参考的信息,现在请你阅读以下内容,进行决策\n{extra_info_block}\n以上是你在回复时需要参考的信息,现在请你阅读以下内容,进行决策"
|
||||
else:
|
||||
@@ -699,10 +752,6 @@ class DefaultReplyer:
|
||||
else:
|
||||
reply_target_block = ""
|
||||
|
||||
prompt_info = await get_prompt_info(target, threshold=0.38)
|
||||
if prompt_info:
|
||||
prompt_info = await global_prompt_manager.format_prompt("knowledge_prompt", prompt_info=prompt_info)
|
||||
|
||||
template_name = "default_generator_prompt"
|
||||
if is_group_chat:
|
||||
chat_target_1 = await global_prompt_manager.get_prompt_async("chat_target_group1")
|
||||
@@ -742,24 +791,24 @@ class DefaultReplyer:
|
||||
message_list_before_now_long, target_user_id
|
||||
)
|
||||
|
||||
mai_think = mai_thinking_manager.get_mai_think(chat_id)
|
||||
mai_think.memory_block = memory_block
|
||||
mai_think.relation_info_block = relation_info
|
||||
mai_think.time_block = time_block
|
||||
mai_think.chat_target = chat_target_1
|
||||
mai_think.chat_target_2 = chat_target_2
|
||||
# mai_think.chat_info = chat_talking_prompt
|
||||
mai_think.mood_state = mood_prompt
|
||||
mai_think.identity = identity_block
|
||||
mai_think.sender = sender
|
||||
mai_think.target = target
|
||||
|
||||
mai_think.chat_info = f"""
|
||||
self.build_mai_think_context(
|
||||
chat_id=chat_id,
|
||||
memory_block=memory_block,
|
||||
relation_info=relation_info,
|
||||
time_block=time_block,
|
||||
chat_target_1=chat_target_1,
|
||||
chat_target_2=chat_target_2,
|
||||
mood_prompt=mood_prompt,
|
||||
identity_block=identity_block,
|
||||
sender=sender,
|
||||
target=target,
|
||||
chat_info=f"""
|
||||
{background_dialogue_prompt}
|
||||
--------------------------------
|
||||
{time_block}
|
||||
这是你和{sender}的对话,你们正在交流中:
|
||||
{core_dialogue_prompt}"""
|
||||
)
|
||||
|
||||
|
||||
# 使用 s4u 风格的模板
|
||||
@@ -768,7 +817,7 @@ class DefaultReplyer:
|
||||
return await global_prompt_manager.format_prompt(
|
||||
template_name,
|
||||
expression_habits_block=expression_habits_block,
|
||||
tool_info_block=tool_info_block,
|
||||
tool_info_block=tool_info,
|
||||
knowledge_prompt=prompt_info,
|
||||
memory_block=memory_block,
|
||||
relation_info_block=relation_info,
|
||||
@@ -787,17 +836,19 @@ class DefaultReplyer:
|
||||
moderation_prompt=moderation_prompt_block,
|
||||
)
|
||||
else:
|
||||
mai_think = mai_thinking_manager.get_mai_think(chat_id)
|
||||
mai_think.memory_block = memory_block
|
||||
mai_think.relation_info_block = relation_info
|
||||
mai_think.time_block = time_block
|
||||
mai_think.chat_target = chat_target_1
|
||||
mai_think.chat_target_2 = chat_target_2
|
||||
mai_think.chat_info = chat_talking_prompt
|
||||
mai_think.mood_state = mood_prompt
|
||||
mai_think.identity = identity_block
|
||||
mai_think.sender = sender
|
||||
mai_think.target = target
|
||||
self.build_mai_think_context(
|
||||
chat_id=chat_id,
|
||||
memory_block=memory_block,
|
||||
relation_info=relation_info,
|
||||
time_block=time_block,
|
||||
chat_target_1=chat_target_1,
|
||||
chat_target_2=chat_target_2,
|
||||
mood_prompt=mood_prompt,
|
||||
identity_block=identity_block,
|
||||
sender=sender,
|
||||
target=target,
|
||||
chat_info=chat_talking_prompt
|
||||
)
|
||||
|
||||
# 使用原有的模式
|
||||
return await global_prompt_manager.format_prompt(
|
||||
@@ -806,7 +857,7 @@ class DefaultReplyer:
|
||||
chat_target=chat_target_1,
|
||||
chat_info=chat_talking_prompt,
|
||||
memory_block=memory_block,
|
||||
tool_info_block=tool_info_block,
|
||||
tool_info_block=tool_info,
|
||||
knowledge_prompt=prompt_info,
|
||||
extra_info_block=extra_info_block,
|
||||
relation_info_block=relation_info,
|
||||
@@ -836,6 +887,13 @@ class DefaultReplyer:
|
||||
reason = reply_data.get("reason", "")
|
||||
sender, target = self._parse_reply_target(reply_to)
|
||||
|
||||
# 添加情绪状态获取
|
||||
if global_config.mood.enable_mood:
|
||||
chat_mood = mood_manager.get_mood_by_chat_id(chat_id)
|
||||
mood_prompt = chat_mood.mood_state
|
||||
else:
|
||||
mood_prompt = ""
|
||||
|
||||
message_list_before_now_half = get_raw_msg_before_timestamp_with_chat(
|
||||
chat_id=chat_id,
|
||||
timestamp=time.time(),
|
||||
@@ -916,6 +974,7 @@ class DefaultReplyer:
|
||||
reply_target_block=reply_target_block,
|
||||
raw_reply=raw_reply,
|
||||
reason=reason,
|
||||
mood_state=mood_prompt, # 添加情绪状态参数
|
||||
config_expression_style=global_config.expression.expression_style,
|
||||
keywords_reaction_prompt=keywords_reaction_prompt,
|
||||
moderation_prompt=moderation_prompt_block,
|
||||
@@ -1012,7 +1071,10 @@ async def get_prompt_info(message: str, threshold: float):
|
||||
related_info += found_knowledge_from_lpmm
|
||||
logger.debug(f"获取知识库内容耗时: {(end_time - start_time):.3f}秒")
|
||||
logger.debug(f"获取知识库内容,相关信息:{related_info[:100]}...,信息长度: {len(related_info)}")
|
||||
return related_info
|
||||
|
||||
# 格式化知识信息
|
||||
formatted_prompt_info = await global_prompt_manager.format_prompt("knowledge_prompt", prompt_info=related_info)
|
||||
return formatted_prompt_info
|
||||
else:
|
||||
logger.debug("从LPMM知识库获取知识失败,可能是从未导入过知识,返回空知识...")
|
||||
return ""
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -78,7 +78,7 @@ def is_mentioned_bot_in_message(message: MessageRecv) -> tuple[bool, float]:
|
||||
# print(f"is_mentioned: {is_mentioned}")
|
||||
# print(f"is_at: {is_at}")
|
||||
|
||||
if is_at and global_config.normal_chat.at_bot_inevitable_reply:
|
||||
if is_at and global_config.chat.at_bot_inevitable_reply:
|
||||
reply_probability = 1.0
|
||||
logger.debug("被@,回复概率设置为100%")
|
||||
else:
|
||||
@@ -103,7 +103,7 @@ def is_mentioned_bot_in_message(message: MessageRecv) -> tuple[bool, float]:
|
||||
for nickname in nicknames:
|
||||
if nickname in message_content:
|
||||
is_mentioned = True
|
||||
if is_mentioned and global_config.normal_chat.mentioned_bot_inevitable_reply:
|
||||
if is_mentioned and global_config.chat.mentioned_bot_inevitable_reply:
|
||||
reply_probability = 1.0
|
||||
logger.debug("被提及,回复概率设置为100%")
|
||||
return is_mentioned, reply_probability
|
||||
@@ -619,9 +619,7 @@ def get_chat_type_and_target_info(chat_id: str) -> Tuple[bool, Optional[Dict]]:
|
||||
chat_target_info = None
|
||||
|
||||
try:
|
||||
chat_stream = get_chat_manager().get_stream(chat_id)
|
||||
|
||||
if chat_stream:
|
||||
if chat_stream := get_chat_manager().get_stream(chat_id):
|
||||
if chat_stream.group_info:
|
||||
is_group_chat = True
|
||||
chat_target_info = None # Explicitly None for group chat
|
||||
@@ -660,8 +658,6 @@ def get_chat_type_and_target_info(chat_id: str) -> Tuple[bool, Optional[Dict]]:
|
||||
chat_target_info = target_info
|
||||
else:
|
||||
logger.warning(f"无法获取 chat_stream for {chat_id} in utils")
|
||||
# Keep defaults: is_group_chat=False, chat_target_info=None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取聊天类型和目标信息时出错 for {chat_id}: {e}", exc_info=True)
|
||||
# Keep defaults on error
|
||||
|
||||
@@ -94,7 +94,7 @@ class ImageManager:
|
||||
logger.error(f"保存描述到数据库失败 (Peewee): {str(e)}")
|
||||
|
||||
async def get_emoji_description(self, image_base64: str) -> str:
|
||||
"""获取表情包描述,带查重和保存功能"""
|
||||
"""获取表情包描述,使用二步走识别并带缓存优化"""
|
||||
try:
|
||||
# 计算图片哈希
|
||||
# 确保base64字符串只包含ASCII字符
|
||||
@@ -107,33 +107,66 @@ class ImageManager:
|
||||
# 查询缓存的描述
|
||||
cached_description = self._get_description_from_db(image_hash, "emoji")
|
||||
if cached_description:
|
||||
return f"[表情包,含义看起来是:{cached_description}]"
|
||||
return f"[表情包:{cached_description}]"
|
||||
|
||||
# 调用AI获取描述
|
||||
# === 二步走识别流程 ===
|
||||
|
||||
# 第一步:VLM视觉分析 - 生成详细描述
|
||||
if image_format in ["gif", "GIF"]:
|
||||
image_base64_processed = self.transform_gif(image_base64)
|
||||
if image_base64_processed is None:
|
||||
logger.warning("GIF转换失败,无法获取描述")
|
||||
return "[表情包(GIF处理失败)]"
|
||||
prompt = "这是一个动态图表情包,每一张图代表了动态图的某一帧,黑色背景代表透明,使用1-2个词描述一下表情包表达的情感和内容,简短一些,输出一段平文本,只输出1-2个词就好,不要输出其他内容"
|
||||
description, _ = await self._llm.generate_response_for_image(prompt, image_base64_processed, "jpg")
|
||||
vlm_prompt = "这是一个动态图表情包,每一张图代表了动态图的某一帧,黑色背景代表透明,描述一下表情包表达的情感和内容,描述细节,从互联网梗,meme的角度去分析"
|
||||
detailed_description, _ = await self._llm.generate_response_for_image(vlm_prompt, image_base64_processed, "jpg")
|
||||
else:
|
||||
prompt = "图片是一个表情包,请用使用1-2个词描述一下表情包所表达的情感和内容,简短一些,输出一段平文本,只输出1-2个词就好,不要输出其他内容"
|
||||
description, _ = await self._llm.generate_response_for_image(prompt, image_base64, image_format)
|
||||
vlm_prompt = "这是一个表情包,请详细描述一下表情包所表达的情感和内容,描述细节,从互联网梗,meme的角度去分析"
|
||||
detailed_description, _ = await self._llm.generate_response_for_image(vlm_prompt, image_base64, image_format)
|
||||
|
||||
if description is None:
|
||||
logger.warning("AI未能生成表情包描述")
|
||||
return "[表情包(描述生成失败)]"
|
||||
if detailed_description is None:
|
||||
logger.warning("VLM未能生成表情包详细描述")
|
||||
return "[表情包(VLM描述生成失败)]"
|
||||
|
||||
# 第二步:LLM情感分析 - 基于详细描述生成简短的情感标签
|
||||
emotion_prompt = f"""
|
||||
请你基于这个表情包的详细描述,提取出最核心的情感含义,用1-2个词概括。
|
||||
详细描述:'{detailed_description}'
|
||||
|
||||
要求:
|
||||
1. 只输出1-2个最核心的情感词汇
|
||||
2. 从互联网梗、meme的角度理解
|
||||
3. 输出简短精准,不要解释
|
||||
4. 如果有多个词用逗号分隔
|
||||
"""
|
||||
|
||||
# 使用较低温度确保输出稳定
|
||||
emotion_llm = LLMRequest(model=global_config.model.utils, temperature=0.3, max_tokens=50, request_type="emoji")
|
||||
emotion_result, _ = await emotion_llm.generate_response_async(emotion_prompt)
|
||||
|
||||
if emotion_result is None:
|
||||
logger.warning("LLM未能生成情感标签,使用详细描述的前几个词")
|
||||
# 降级处理:从详细描述中提取关键词
|
||||
import jieba
|
||||
words = list(jieba.cut(detailed_description))
|
||||
emotion_result = ",".join(words[:2]) if len(words) >= 2 else (words[0] if words else "表情")
|
||||
|
||||
# 处理情感结果,取前1-2个最重要的标签
|
||||
emotions = [e.strip() for e in emotion_result.replace(",", ",").split(",") if e.strip()]
|
||||
final_emotion = emotions[0] if emotions else "表情"
|
||||
|
||||
# 如果有第二个情感且不重复,也包含进来
|
||||
if len(emotions) > 1 and emotions[1] != emotions[0]:
|
||||
final_emotion = f"{emotions[0]},{emotions[1]}"
|
||||
|
||||
logger.info(f"[二步走识别] 详细描述: {detailed_description[:50]}... -> 情感标签: {final_emotion}")
|
||||
|
||||
# 再次检查缓存,防止并发写入时重复生成
|
||||
cached_description = self._get_description_from_db(image_hash, "emoji")
|
||||
if cached_description:
|
||||
logger.warning(f"虽然生成了描述,但是找到缓存表情包描述: {cached_description}")
|
||||
return f"[表情包,含义看起来是:{cached_description}]"
|
||||
return f"[表情包:{cached_description}]"
|
||||
|
||||
# 根据配置决定是否保存图片
|
||||
# if global_config.emoji.save_emoji:
|
||||
# 生成文件名和路径
|
||||
# 保存表情包文件和元数据(用于可能的后续分析)
|
||||
logger.debug(f"保存表情包: {image_hash}")
|
||||
current_timestamp = time.time()
|
||||
filename = f"{int(current_timestamp)}_{image_hash[:8]}.{image_format}"
|
||||
@@ -146,11 +179,11 @@ class ImageManager:
|
||||
with open(file_path, "wb") as f:
|
||||
f.write(image_bytes)
|
||||
|
||||
# 保存到数据库 (Images表)
|
||||
# 保存到数据库 (Images表) - 包含详细描述用于可能的注册流程
|
||||
try:
|
||||
img_obj = Images.get((Images.emoji_hash == image_hash) & (Images.type == "emoji"))
|
||||
img_obj.path = file_path
|
||||
img_obj.description = description
|
||||
img_obj.description = detailed_description # 保存详细描述
|
||||
img_obj.timestamp = current_timestamp
|
||||
img_obj.save()
|
||||
except Images.DoesNotExist: # type: ignore
|
||||
@@ -158,17 +191,17 @@ class ImageManager:
|
||||
emoji_hash=image_hash,
|
||||
path=file_path,
|
||||
type="emoji",
|
||||
description=description,
|
||||
description=detailed_description, # 保存详细描述
|
||||
timestamp=current_timestamp,
|
||||
)
|
||||
# logger.debug(f"保存表情包元数据: {file_path}")
|
||||
except Exception as e:
|
||||
logger.error(f"保存表情包文件或元数据失败: {str(e)}")
|
||||
|
||||
# 保存描述到数据库 (ImageDescriptions表)
|
||||
self._save_description_to_db(image_hash, description, "emoji")
|
||||
# 保存最终的情感标签到缓存 (ImageDescriptions表)
|
||||
self._save_description_to_db(image_hash, final_emotion, "emoji")
|
||||
|
||||
return f"[表情包:{description}]"
|
||||
return f"[表情包:{final_emotion}]"
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取表情包描述失败: {str(e)}")
|
||||
return "[表情包]"
|
||||
|
||||
@@ -11,7 +11,7 @@ logger = get_logger("chat_voice")
|
||||
|
||||
async def get_voice_text(voice_base64: str) -> str:
|
||||
"""获取音频文件描述"""
|
||||
if not global_config.chat.enable_asr:
|
||||
if not global_config.voice.enable_asr:
|
||||
logger.warning("语音识别未启用,无法处理语音消息")
|
||||
return "[语音]"
|
||||
try:
|
||||
|
||||
@@ -35,7 +35,7 @@ class ClassicalWillingManager(BaseWillingManager):
|
||||
if interested_rate > 0.2:
|
||||
current_willing += interested_rate - 0.2
|
||||
|
||||
if willing_info.is_mentioned_bot and global_config.normal_chat.mentioned_bot_inevitable_reply and current_willing < 2:
|
||||
if willing_info.is_mentioned_bot and global_config.chat.mentioned_bot_inevitable_reply and current_willing < 2:
|
||||
current_willing += 1 if current_willing < 1.0 else 0.05
|
||||
|
||||
self.chat_reply_willing[chat_id] = min(current_willing, 1.0)
|
||||
|
||||
@@ -306,6 +306,7 @@ class Expression(BaseModel):
|
||||
last_active_time = FloatField()
|
||||
chat_id = TextField(index=True)
|
||||
type = TextField()
|
||||
create_date = FloatField(null=True) # 创建日期,允许为空以兼容老数据
|
||||
|
||||
class Meta:
|
||||
table_name = "expression"
|
||||
@@ -449,9 +450,12 @@ def initialize_database():
|
||||
alter_sql = f"ALTER TABLE {table_name} ADD COLUMN {field_name} {sql_type}"
|
||||
alter_sql += " NULL" if field_obj.null else " NOT NULL"
|
||||
if hasattr(field_obj, "default") and field_obj.default is not None:
|
||||
# 正确处理不同类型的默认值
|
||||
# 正确处理不同类型的默认值,跳过lambda函数
|
||||
default_value = field_obj.default
|
||||
if isinstance(default_value, str):
|
||||
if callable(default_value):
|
||||
# 跳过lambda函数或其他可调用对象,这些无法在SQL中表示
|
||||
pass
|
||||
elif isinstance(default_value, str):
|
||||
alter_sql += f" DEFAULT '{default_value}'"
|
||||
elif isinstance(default_value, bool):
|
||||
alter_sql += f" DEFAULT {int(default_value)}"
|
||||
|
||||
@@ -321,7 +321,7 @@ MODULE_COLORS = {
|
||||
# 核心模块
|
||||
"main": "\033[1;97m", # 亮白色+粗体 (主程序)
|
||||
"api": "\033[92m", # 亮绿色
|
||||
"emoji": "\033[33m", # 亮绿色
|
||||
"emoji": "\033[38;5;214m", # 橙黄色,偏向橙色但与replyer和action_manager不同
|
||||
"chat": "\033[92m", # 亮蓝色
|
||||
"config": "\033[93m", # 亮黄色
|
||||
"common": "\033[95m", # 亮紫色
|
||||
@@ -329,35 +329,33 @@ MODULE_COLORS = {
|
||||
"lpmm": "\033[96m",
|
||||
"plugin_system": "\033[91m", # 亮红色
|
||||
"person_info": "\033[32m", # 绿色
|
||||
"individuality": "\033[34m", # 蓝色
|
||||
"individuality": "\033[94m", # 显眼的亮蓝色
|
||||
"manager": "\033[35m", # 紫色
|
||||
"llm_models": "\033[36m", # 青色
|
||||
"plugins": "\033[31m", # 红色
|
||||
"plugin_api": "\033[33m", # 黄色
|
||||
"remote": "\033[38;5;93m", # 紫蓝色
|
||||
"remote": "\033[38;5;242m", # 深灰色,更不显眼
|
||||
"planner": "\033[36m",
|
||||
"memory": "\033[34m",
|
||||
"hfc": "\033[96m",
|
||||
"action_manager": "\033[38;5;166m",
|
||||
"hfc": "\033[38;5;81m", # 稍微暗一些的青色,保持可读
|
||||
"action_manager": "\033[38;5;208m", # 橙色,不与replyer重复
|
||||
# 关系系统
|
||||
"relation": "\033[38;5;201m", # 深粉色
|
||||
"relation": "\033[38;5;139m", # 柔和的紫色,不刺眼
|
||||
# 聊天相关模块
|
||||
"normal_chat": "\033[38;5;81m", # 亮蓝绿色
|
||||
"normal_chat_response": "\033[38;5;123m", # 青绿色
|
||||
"heartflow": "\033[38;5;213m", # 粉色
|
||||
"heartflow": "\033[38;5;175m", # 柔和的粉色,不显眼但保持粉色系
|
||||
"sub_heartflow": "\033[38;5;207m", # 粉紫色
|
||||
"subheartflow_manager": "\033[38;5;201m", # 深粉色
|
||||
"background_tasks": "\033[38;5;240m", # 灰色
|
||||
"chat_message": "\033[38;5;45m", # 青色
|
||||
"chat_stream": "\033[38;5;51m", # 亮青色
|
||||
"sender": "\033[38;5;39m", # 蓝色
|
||||
"sender": "\033[38;5;67m", # 稍微暗一些的蓝色,不显眼
|
||||
"message_storage": "\033[38;5;33m", # 深蓝色
|
||||
"expressor": "\033[38;5;166m", # 橙色
|
||||
# 专注聊天模块
|
||||
"replyer": "\033[38;5;166m", # 橙色
|
||||
"base_processor": "\033[38;5;190m", # 绿黄色
|
||||
"working_memory": "\033[38;5;22m", # 深绿色
|
||||
"memory_activator": "\033[34m", # 绿色
|
||||
# 插件系统
|
||||
"plugins": "\033[31m", # 红色
|
||||
"plugin_api": "\033[33m", # 黄色
|
||||
"plugin_manager": "\033[38;5;208m", # 红色
|
||||
"base_plugin": "\033[38;5;202m", # 橙红色
|
||||
"send_api": "\033[38;5;208m", # 橙色
|
||||
@@ -378,9 +376,9 @@ MODULE_COLORS = {
|
||||
"local_storage": "\033[38;5;141m", # 紫色
|
||||
"willing": "\033[38;5;147m", # 浅紫色
|
||||
# 工具模块
|
||||
"tool_use": "\033[38;5;64m", # 深绿色
|
||||
"tool_executor": "\033[38;5;64m", # 深绿色
|
||||
"base_tool": "\033[38;5;70m", # 绿色
|
||||
"tool_use": "\033[38;5;172m", # 橙褐色
|
||||
"tool_executor": "\033[38;5;172m", # 橙褐色
|
||||
"base_tool": "\033[38;5;178m", # 金黄色
|
||||
# 工具和实用模块
|
||||
"prompt_build": "\033[38;5;105m", # 紫色
|
||||
"chat_utils": "\033[38;5;111m", # 蓝色
|
||||
@@ -388,14 +386,16 @@ MODULE_COLORS = {
|
||||
"maibot_statistic": "\033[38;5;129m", # 紫色
|
||||
# 特殊功能插件
|
||||
"mute_plugin": "\033[38;5;240m", # 灰色
|
||||
"example_comprehensive": "\033[38;5;246m", # 浅灰色
|
||||
"core_actions": "\033[38;5;117m", # 深红色
|
||||
"tts_action": "\033[38;5;58m", # 深黄色
|
||||
"doubao_pic_plugin": "\033[38;5;64m", # 深绿色
|
||||
"vtb_action": "\033[38;5;70m", # 绿色
|
||||
# Action组件
|
||||
"no_reply_action": "\033[38;5;196m", # 亮红色,更显眼
|
||||
"reply_action": "\033[38;5;46m", # 亮绿色
|
||||
"base_action": "\033[38;5;250m", # 浅灰色
|
||||
# 数据库和消息
|
||||
"database_model": "\033[38;5;94m", # 橙褐色
|
||||
"maim_message": "\033[38;5;100m", # 绿褐色
|
||||
"maim_message": "\033[38;5;140m", # 紫褐色
|
||||
# 日志系统
|
||||
"logger": "\033[38;5;8m", # 深灰色
|
||||
"confirm": "\033[1;93m", # 黄色+粗体
|
||||
@@ -409,6 +409,34 @@ MODULE_COLORS = {
|
||||
"S4U_chat": "\033[92m", # 深灰色
|
||||
}
|
||||
|
||||
# 定义模块别名映射 - 将真实的logger名称映射到显示的别名
|
||||
MODULE_ALIASES = {
|
||||
# 示例映射
|
||||
"individuality": "人格特质",
|
||||
"emoji": "表情包",
|
||||
"no_reply_action": "摸鱼",
|
||||
"reply_action": "回复",
|
||||
"action_manager": "动作",
|
||||
"memory_activator": "记忆",
|
||||
"tool_use": "工具",
|
||||
"expressor": "表达方式",
|
||||
"database_model": "数据库",
|
||||
"mood": "情绪",
|
||||
"memory": "记忆",
|
||||
"tool_executor": "工具",
|
||||
"hfc": "聊天节奏",
|
||||
"chat": "所见",
|
||||
"plugin_manager": "插件",
|
||||
"relationship_builder": "关系",
|
||||
"llm_models": "模型",
|
||||
"person_info": "人物",
|
||||
"chat_stream": "聊天流",
|
||||
"planner": "规划器",
|
||||
"replyer": "言语",
|
||||
"config": "配置",
|
||||
"main": "主程序",
|
||||
}
|
||||
|
||||
RESET_COLOR = "\033[0m"
|
||||
|
||||
|
||||
@@ -497,15 +525,18 @@ class ModuleColoredConsoleRenderer:
|
||||
if self._colors and self._enable_module_colors and logger_name:
|
||||
module_color = MODULE_COLORS.get(logger_name, "")
|
||||
|
||||
# 模块名称(带颜色)
|
||||
# 模块名称(带颜色和别名支持)
|
||||
if logger_name:
|
||||
# 获取别名,如果没有别名则使用原名称
|
||||
display_name = MODULE_ALIASES.get(logger_name, logger_name)
|
||||
|
||||
if self._colors and self._enable_module_colors:
|
||||
if module_color:
|
||||
module_part = f"{module_color}[{logger_name}]{RESET_COLOR}"
|
||||
module_part = f"{module_color}[{display_name}]{RESET_COLOR}"
|
||||
else:
|
||||
module_part = f"[{logger_name}]"
|
||||
module_part = f"[{display_name}]"
|
||||
else:
|
||||
module_part = f"[{logger_name}]"
|
||||
module_part = f"[{display_name}]"
|
||||
parts.append(module_part)
|
||||
|
||||
# 消息内容(确保转换为字符串)
|
||||
@@ -715,19 +746,7 @@ def configure_logging(
|
||||
root_logger.setLevel(getattr(logging, level.upper()))
|
||||
|
||||
|
||||
def set_module_color(module_name: str, color_code: str):
|
||||
"""为指定模块设置颜色
|
||||
|
||||
Args:
|
||||
module_name: 模块名称
|
||||
color_code: ANSI颜色代码,例如 '\033[92m' 表示亮绿色
|
||||
"""
|
||||
MODULE_COLORS[module_name] = color_code
|
||||
|
||||
|
||||
def get_module_colors():
|
||||
"""获取当前模块颜色配置"""
|
||||
return MODULE_COLORS.copy()
|
||||
|
||||
|
||||
def reload_log_config():
|
||||
@@ -918,9 +937,20 @@ def show_module_colors():
|
||||
for module_name, _color_code in MODULE_COLORS.items():
|
||||
# 临时创建一个该模块的logger来展示颜色
|
||||
demo_logger = structlog.get_logger(module_name).bind(logger_name=module_name)
|
||||
demo_logger.info(f"这是 {module_name} 模块的颜色效果")
|
||||
alias = MODULE_ALIASES.get(module_name, module_name)
|
||||
if alias != module_name:
|
||||
demo_logger.info(f"这是 {module_name} 模块的颜色效果 (显示为: {alias})")
|
||||
else:
|
||||
demo_logger.info(f"这是 {module_name} 模块的颜色效果")
|
||||
|
||||
print("=== 颜色展示结束 ===\n")
|
||||
|
||||
# 显示别名映射表
|
||||
if MODULE_ALIASES:
|
||||
print("=== 当前别名映射 ===")
|
||||
for module_name, alias in MODULE_ALIASES.items():
|
||||
print(f" {module_name} -> {alias}")
|
||||
print("=== 别名映射结束 ===\n")
|
||||
|
||||
|
||||
def format_json_for_logging(data, indent=2, ensure_ascii=False):
|
||||
|
||||
@@ -36,6 +36,7 @@ from src.config.official_configs import (
|
||||
LPMMKnowledgeConfig,
|
||||
RelationshipConfig,
|
||||
ToolConfig,
|
||||
VoiceConfig,
|
||||
DebugConfig,
|
||||
CustomPromptConfig,
|
||||
)
|
||||
@@ -64,7 +65,7 @@ TEMPLATE_DIR = os.path.join(PROJECT_ROOT, "template")
|
||||
|
||||
# 考虑到,实际上配置文件中的mai_version是不会自动更新的,所以采用硬编码
|
||||
# 对该字段的更新,请严格参照语义化版本规范:https://semver.org/lang/zh-CN/
|
||||
MMC_VERSION = "0.9.0-snapshot.2"
|
||||
MMC_VERSION = "0.9.1"
|
||||
|
||||
|
||||
|
||||
@@ -616,7 +617,7 @@ class Config(ConfigBase):
|
||||
tool: ToolConfig
|
||||
debug: DebugConfig
|
||||
custom_prompt: CustomPromptConfig
|
||||
|
||||
voice: VoiceConfig
|
||||
|
||||
def load_config(config_path: str) -> Config:
|
||||
"""
|
||||
|
||||
@@ -18,6 +18,9 @@ from packaging.version import Version
|
||||
@dataclass
|
||||
class BotConfig(ConfigBase):
|
||||
"""QQ机器人配置类"""
|
||||
|
||||
platform: str
|
||||
"""平台"""
|
||||
|
||||
qq_account: str
|
||||
"""QQ账号"""
|
||||
@@ -82,6 +85,12 @@ class ChatConfig(ConfigBase):
|
||||
use_s4u_prompt_mode: bool = False
|
||||
"""是否使用 s4u 对话构建模式,该模式会分开处理当前对话对象和其他所有对话的内容进行 prompt 构建"""
|
||||
|
||||
mentioned_bot_inevitable_reply: bool = False
|
||||
"""提及 bot 必然回复"""
|
||||
|
||||
at_bot_inevitable_reply: bool = False
|
||||
"""@bot 必然回复"""
|
||||
|
||||
# 修改:基于时段的回复频率配置,改为数组格式
|
||||
time_based_talk_frequency: list[str] = field(default_factory=lambda: [])
|
||||
"""
|
||||
@@ -107,9 +116,6 @@ class ChatConfig(ConfigBase):
|
||||
focus_value: float = 1.0
|
||||
"""麦麦的专注思考能力,越低越容易专注,消耗token也越多"""
|
||||
|
||||
enable_asr: bool = False
|
||||
"""是否启用语音识别"""
|
||||
|
||||
def get_current_talk_frequency(self, chat_stream_id: Optional[str] = None) -> float:
|
||||
"""
|
||||
根据当前时间和聊天流获取对应的 talk_frequency
|
||||
@@ -271,11 +277,7 @@ class NormalChatConfig(ConfigBase):
|
||||
response_interested_rate_amplifier: float = 1.0
|
||||
"""回复兴趣度放大系数"""
|
||||
|
||||
mentioned_bot_inevitable_reply: bool = False
|
||||
"""提及 bot 必然回复"""
|
||||
|
||||
at_bot_inevitable_reply: bool = False
|
||||
"""@bot 必然回复"""
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -310,6 +312,13 @@ class ToolConfig(ConfigBase):
|
||||
|
||||
enable_in_focus_chat: bool = True
|
||||
"""是否在专注聊天中启用工具"""
|
||||
|
||||
@dataclass
|
||||
class VoiceConfig(ConfigBase):
|
||||
"""语音识别配置类"""
|
||||
|
||||
enable_asr: bool = False
|
||||
"""是否启用语音识别"""
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -400,15 +409,9 @@ class MoodConfig(ConfigBase):
|
||||
|
||||
enable_mood: bool = False
|
||||
"""是否启用情绪系统"""
|
||||
|
||||
mood_update_interval: int = 1
|
||||
"""情绪更新间隔(秒)"""
|
||||
|
||||
mood_decay_rate: float = 0.95
|
||||
"""情绪衰减率"""
|
||||
|
||||
mood_intensity_factor: float = 0.7
|
||||
"""情绪强度因子"""
|
||||
|
||||
mood_update_threshold: float = 1.0
|
||||
"""情绪更新阈值,越高,更新越慢"""
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import ast
|
||||
import json
|
||||
import os
|
||||
import hashlib
|
||||
import time
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
@@ -9,8 +9,6 @@ from src.llm_models.utils_model import LLMRequest
|
||||
from src.person_info.person_info import get_person_info_manager
|
||||
from rich.traceback import install
|
||||
|
||||
from .personality import Personality
|
||||
|
||||
install(extra_lines=3)
|
||||
|
||||
logger = get_logger("individuality")
|
||||
@@ -20,12 +18,10 @@ class Individuality:
|
||||
"""个体特征管理类"""
|
||||
|
||||
def __init__(self):
|
||||
# 正常初始化实例属性
|
||||
self.personality: Personality = None # type: ignore
|
||||
|
||||
self.name = ""
|
||||
self.bot_person_id = ""
|
||||
self.meta_info_file_path = "data/personality/meta.json"
|
||||
self.personality_data_file_path = "data/personality/personality_data.json"
|
||||
|
||||
self.model = LLMRequest(
|
||||
model=global_config.model.utils,
|
||||
@@ -33,20 +29,13 @@ class Individuality:
|
||||
)
|
||||
|
||||
async def initialize(self) -> None:
|
||||
"""初始化个体特征
|
||||
|
||||
Args:
|
||||
bot_nickname: 机器人昵称
|
||||
personality_core: 人格核心特点
|
||||
personality_side: 人格侧面描述
|
||||
identity: 身份细节描述
|
||||
"""
|
||||
"""初始化个体特征"""
|
||||
bot_nickname = global_config.bot.nickname
|
||||
personality_core = global_config.personality.personality_core
|
||||
personality_side = global_config.personality.personality_side
|
||||
identity = global_config.personality.identity
|
||||
|
||||
logger.info("正在初始化个体特征")
|
||||
|
||||
person_info_manager = get_person_info_manager()
|
||||
self.bot_person_id = person_info_manager.get_person_id("system", "bot_id")
|
||||
self.name = bot_nickname
|
||||
@@ -56,129 +45,61 @@ class Individuality:
|
||||
bot_nickname, personality_core, personality_side, identity
|
||||
)
|
||||
|
||||
# 初始化人格(现在包含身份)
|
||||
self.personality = Personality.initialize(
|
||||
bot_nickname=bot_nickname,
|
||||
personality_core=personality_core,
|
||||
personality_side=personality_side,
|
||||
identity=identity,
|
||||
compress_personality=global_config.personality.compress_personality,
|
||||
compress_identity=global_config.personality.compress_identity,
|
||||
)
|
||||
logger.info("正在构建人设信息")
|
||||
|
||||
logger.info("正在将所有人设写入impression")
|
||||
# 将所有人设写入impression
|
||||
impression_parts = []
|
||||
if personality_core:
|
||||
impression_parts.append(f"核心人格: {personality_core}")
|
||||
if personality_side:
|
||||
impression_parts.append(f"人格侧面: {personality_side}")
|
||||
if identity:
|
||||
impression_parts.append(f"身份: {identity}")
|
||||
logger.info(f"impression_parts: {impression_parts}")
|
||||
# 如果配置有变化,重新生成压缩版本
|
||||
if personality_changed or identity_changed:
|
||||
logger.info("检测到配置变化,重新生成压缩版本")
|
||||
personality_result = await self._create_personality(personality_core, personality_side)
|
||||
identity_result = await self._create_identity(identity)
|
||||
else:
|
||||
logger.info("配置未变化,使用缓存版本")
|
||||
# 从文件中获取已有的结果
|
||||
personality_result, identity_result = self._get_personality_from_file()
|
||||
if not personality_result or not identity_result:
|
||||
logger.info("未找到有效缓存,重新生成")
|
||||
personality_result = await self._create_personality(personality_core, personality_side)
|
||||
identity_result = await self._create_identity(identity)
|
||||
|
||||
impression_text = "。".join(impression_parts)
|
||||
if impression_text:
|
||||
impression_text += "。"
|
||||
# 保存到文件
|
||||
if personality_result and identity_result:
|
||||
self._save_personality_to_file(personality_result, identity_result)
|
||||
logger.info("已将人设构建并保存到文件")
|
||||
else:
|
||||
logger.error("人设构建失败")
|
||||
|
||||
if impression_text:
|
||||
# 如果任何一个发生变化,都需要清空数据库中的info_list(因为这影响整体人设)
|
||||
if personality_changed or identity_changed:
|
||||
logger.info("将清空数据库中原有的关键词缓存")
|
||||
update_data = {
|
||||
"platform": "system",
|
||||
"user_id": "bot_id",
|
||||
"person_name": self.name,
|
||||
"nickname": self.name,
|
||||
}
|
||||
|
||||
await person_info_manager.update_one_field(
|
||||
self.bot_person_id, "impression", impression_text, data=update_data
|
||||
)
|
||||
logger.debug("已将完整人设更新到bot的impression中")
|
||||
|
||||
# 根据变化情况决定是否重新创建
|
||||
personality_result = None
|
||||
identity_result = None
|
||||
|
||||
if personality_changed:
|
||||
logger.info("检测到人格配置变化,重新生成压缩版本")
|
||||
personality_result = await self._create_personality(personality_core, personality_side)
|
||||
else:
|
||||
logger.info("人格配置未变化,使用缓存版本")
|
||||
# 从缓存中获取已有的personality结果
|
||||
existing_short_impression = await person_info_manager.get_value(self.bot_person_id, "short_impression")
|
||||
if existing_short_impression:
|
||||
try:
|
||||
existing_data = ast.literal_eval(existing_short_impression) # type: ignore
|
||||
if isinstance(existing_data, list) and len(existing_data) >= 1:
|
||||
personality_result = existing_data[0]
|
||||
except (json.JSONDecodeError, TypeError, IndexError):
|
||||
logger.warning("无法解析现有的short_impression,将重新生成人格部分")
|
||||
personality_result = await self._create_personality(personality_core, personality_side)
|
||||
else:
|
||||
logger.info("未找到现有的人格缓存,重新生成")
|
||||
personality_result = await self._create_personality(personality_core, personality_side)
|
||||
|
||||
if identity_changed:
|
||||
logger.info("检测到身份配置变化,重新生成压缩版本")
|
||||
identity_result = await self._create_identity(identity)
|
||||
else:
|
||||
logger.info("身份配置未变化,使用缓存版本")
|
||||
# 从缓存中获取已有的identity结果
|
||||
existing_short_impression = await person_info_manager.get_value(self.bot_person_id, "short_impression")
|
||||
if existing_short_impression:
|
||||
try:
|
||||
existing_data = ast.literal_eval(existing_short_impression) # type: ignore
|
||||
if isinstance(existing_data, list) and len(existing_data) >= 2:
|
||||
identity_result = existing_data[1]
|
||||
except (json.JSONDecodeError, TypeError, IndexError):
|
||||
logger.warning("无法解析现有的short_impression,将重新生成身份部分")
|
||||
identity_result = await self._create_identity(identity)
|
||||
else:
|
||||
logger.info("未找到现有的身份缓存,重新生成")
|
||||
identity_result = await self._create_identity(identity)
|
||||
|
||||
result = [personality_result, identity_result]
|
||||
|
||||
# 更新short_impression字段
|
||||
if personality_result and identity_result:
|
||||
person_info_manager = get_person_info_manager()
|
||||
await person_info_manager.update_one_field(self.bot_person_id, "short_impression", result)
|
||||
logger.info("已将人设构建")
|
||||
else:
|
||||
logger.error("人设构建失败")
|
||||
await person_info_manager.update_one_field(self.bot_person_id, "info_list", [], data=update_data)
|
||||
|
||||
async def get_personality_block(self) -> str:
|
||||
person_info_manager = get_person_info_manager()
|
||||
bot_person_id = person_info_manager.get_person_id("system", "bot_id")
|
||||
|
||||
bot_name = global_config.bot.nickname
|
||||
if global_config.bot.alias_names:
|
||||
bot_nickname = f",也有人叫你{','.join(global_config.bot.alias_names)}"
|
||||
else:
|
||||
bot_nickname = ""
|
||||
short_impression = await person_info_manager.get_value(bot_person_id, "short_impression")
|
||||
# 解析字符串形式的Python列表
|
||||
try:
|
||||
if isinstance(short_impression, str) and short_impression.strip():
|
||||
short_impression = ast.literal_eval(short_impression)
|
||||
elif not short_impression:
|
||||
logger.warning("short_impression为空,使用默认值")
|
||||
short_impression = ["友好活泼", "人类"]
|
||||
except (ValueError, SyntaxError) as e:
|
||||
logger.error(f"解析short_impression失败: {e}, 原始值: {short_impression}")
|
||||
short_impression = ["友好活泼", "人类"]
|
||||
|
||||
# 从文件获取 short_impression
|
||||
personality, identity = self._get_personality_from_file()
|
||||
|
||||
# 确保short_impression是列表格式且有足够的元素
|
||||
if not isinstance(short_impression, list) or len(short_impression) < 2:
|
||||
logger.warning(f"short_impression格式不正确: {short_impression}, 使用默认值")
|
||||
short_impression = ["友好活泼", "人类"]
|
||||
personality = short_impression[0]
|
||||
identity = short_impression[1]
|
||||
prompt_personality = f"{personality},{identity}"
|
||||
identity_block = f"你的名字是{bot_name}{bot_nickname},你{prompt_personality}:"
|
||||
|
||||
return identity_block
|
||||
if not personality or not identity:
|
||||
logger.warning(f"personality或identity为空: {personality}, {identity}, 使用默认值")
|
||||
personality = "友好活泼"
|
||||
identity = "人类"
|
||||
|
||||
prompt_personality = f"{personality}\n{identity}"
|
||||
return f"你的名字是{bot_name}{bot_nickname},你{prompt_personality}"
|
||||
|
||||
def _get_config_hash(
|
||||
self, bot_nickname: str, personality_core: str, personality_side: str, identity: list
|
||||
self, bot_nickname: str, personality_core: str, personality_side: str, identity: str
|
||||
) -> tuple[str, str]:
|
||||
"""获取personality和identity配置的哈希值
|
||||
|
||||
@@ -190,15 +111,15 @@ class Individuality:
|
||||
"nickname": bot_nickname,
|
||||
"personality_core": personality_core,
|
||||
"personality_side": personality_side,
|
||||
"compress_personality": self.personality.compress_personality if self.personality else True,
|
||||
"compress_personality": global_config.personality.compress_personality,
|
||||
}
|
||||
personality_str = json.dumps(personality_config, sort_keys=True)
|
||||
personality_hash = hashlib.md5(personality_str.encode("utf-8")).hexdigest()
|
||||
|
||||
# 身份配置哈希
|
||||
identity_config = {
|
||||
"identity": sorted(identity),
|
||||
"compress_identity": self.personality.compress_identity if self.personality else True,
|
||||
"identity": identity,
|
||||
"compress_identity": global_config.personality.compress_identity,
|
||||
}
|
||||
identity_str = json.dumps(identity_config, sort_keys=True)
|
||||
identity_hash = hashlib.md5(identity_str.encode("utf-8")).hexdigest()
|
||||
@@ -206,7 +127,7 @@ class Individuality:
|
||||
return personality_hash, identity_hash
|
||||
|
||||
async def _check_config_and_clear_if_changed(
|
||||
self, bot_nickname: str, personality_core: str, personality_side: str, identity: list
|
||||
self, bot_nickname: str, personality_core: str, personality_side: str, identity: str
|
||||
) -> tuple[bool, bool]:
|
||||
"""检查配置是否发生变化,如果变化则清空相应缓存
|
||||
|
||||
@@ -271,6 +192,53 @@ class Individuality:
|
||||
except IOError as e:
|
||||
logger.error(f"保存meta_info文件失败: {e}")
|
||||
|
||||
def _load_personality_data(self) -> dict:
|
||||
"""从JSON文件中加载personality数据"""
|
||||
if os.path.exists(self.personality_data_file_path):
|
||||
try:
|
||||
with open(self.personality_data_file_path, "r", encoding="utf-8") as f:
|
||||
return json.load(f)
|
||||
except (json.JSONDecodeError, IOError) as e:
|
||||
logger.error(f"读取personality_data文件失败: {e}, 将创建新文件。")
|
||||
return {}
|
||||
return {}
|
||||
|
||||
def _save_personality_data(self, personality_data: dict):
|
||||
"""将personality数据保存到JSON文件"""
|
||||
try:
|
||||
os.makedirs(os.path.dirname(self.personality_data_file_path), exist_ok=True)
|
||||
with open(self.personality_data_file_path, "w", encoding="utf-8") as f:
|
||||
json.dump(personality_data, f, ensure_ascii=False, indent=2)
|
||||
logger.debug(f"已保存personality数据到文件: {self.personality_data_file_path}")
|
||||
except IOError as e:
|
||||
logger.error(f"保存personality_data文件失败: {e}")
|
||||
|
||||
def _get_personality_from_file(self) -> tuple[str, str]:
|
||||
"""从文件获取personality数据
|
||||
|
||||
Returns:
|
||||
tuple: (personality, identity)
|
||||
"""
|
||||
personality_data = self._load_personality_data()
|
||||
personality = personality_data.get("personality", "友好活泼")
|
||||
identity = personality_data.get("identity", "人类")
|
||||
return personality, identity
|
||||
|
||||
def _save_personality_to_file(self, personality: str, identity: str):
|
||||
"""保存personality数据到文件
|
||||
|
||||
Args:
|
||||
personality: 压缩后的人格描述
|
||||
identity: 压缩后的身份描述
|
||||
"""
|
||||
personality_data = {
|
||||
"personality": personality,
|
||||
"identity": identity,
|
||||
"bot_nickname": self.name,
|
||||
"last_updated": int(time.time())
|
||||
}
|
||||
self._save_personality_data(personality_data)
|
||||
|
||||
async def _create_personality(self, personality_core: str, personality_side: str) -> str:
|
||||
# sourcery skip: merge-list-append, move-assign
|
||||
"""使用LLM创建压缩版本的impression
|
||||
@@ -290,7 +258,7 @@ class Individuality:
|
||||
personality_parts.append(f"{personality_core}")
|
||||
|
||||
# 准备需要压缩的内容
|
||||
if self.personality.compress_personality:
|
||||
if global_config.personality.compress_personality:
|
||||
personality_to_compress = f"人格特质: {personality_side}"
|
||||
|
||||
prompt = f"""请将以下人格信息进行简洁压缩,保留主要内容,用简练的中文表达:
|
||||
@@ -321,11 +289,11 @@ class Individuality:
|
||||
|
||||
return personality_result
|
||||
|
||||
async def _create_identity(self, identity: list) -> str:
|
||||
async def _create_identity(self, identity: str) -> str:
|
||||
"""使用LLM创建压缩版本的impression"""
|
||||
logger.info("正在构建身份.........")
|
||||
|
||||
if self.personality.compress_identity:
|
||||
if global_config.personality.compress_identity:
|
||||
identity_to_compress = f"身份背景: {identity}"
|
||||
|
||||
prompt = f"""请将以下身份信息进行简洁压缩,保留主要内容,用简练的中文表达:
|
||||
|
||||
@@ -1,91 +0,0 @@
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List
|
||||
|
||||
|
||||
@dataclass
|
||||
class Personality:
|
||||
"""人格特质类"""
|
||||
|
||||
bot_nickname: str # 机器人昵称
|
||||
personality_core: str # 人格核心特点
|
||||
personality_side: str # 人格侧面描述
|
||||
identity: List[str] # 身份细节描述
|
||||
compress_personality: bool # 是否压缩人格
|
||||
compress_identity: bool # 是否压缩身份
|
||||
|
||||
_instance = None
|
||||
|
||||
def __new__(cls, *args, **kwargs):
|
||||
if cls._instance is None:
|
||||
cls._instance = super().__new__(cls)
|
||||
return cls._instance
|
||||
|
||||
def __init__(self, personality_core: str = "", personality_side: str = "", identity: List[str] = None):
|
||||
self.personality_core = personality_core
|
||||
self.personality_side = personality_side
|
||||
self.identity = identity
|
||||
self.compress_personality = True
|
||||
self.compress_identity = True
|
||||
|
||||
@classmethod
|
||||
def get_instance(cls) -> "Personality":
|
||||
"""获取Personality单例实例
|
||||
|
||||
Returns:
|
||||
Personality: 单例实例
|
||||
"""
|
||||
if cls._instance is None:
|
||||
cls._instance = cls()
|
||||
return cls._instance
|
||||
|
||||
@classmethod
|
||||
def initialize(
|
||||
cls,
|
||||
bot_nickname: str,
|
||||
personality_core: str,
|
||||
personality_side: str,
|
||||
identity: List[str] = None,
|
||||
compress_personality: bool = True,
|
||||
compress_identity: bool = True,
|
||||
) -> "Personality":
|
||||
"""初始化人格特质
|
||||
|
||||
Args:
|
||||
bot_nickname: 机器人昵称
|
||||
personality_core: 人格核心特点
|
||||
personality_side: 人格侧面描述
|
||||
identity: 身份细节描述
|
||||
compress_personality: 是否压缩人格
|
||||
compress_identity: 是否压缩身份
|
||||
|
||||
Returns:
|
||||
Personality: 初始化后的人格特质实例
|
||||
"""
|
||||
instance = cls.get_instance()
|
||||
instance.bot_nickname = bot_nickname
|
||||
instance.personality_core = personality_core
|
||||
instance.personality_side = personality_side
|
||||
instance.identity = identity
|
||||
instance.compress_personality = compress_personality
|
||||
instance.compress_identity = compress_identity
|
||||
return instance
|
||||
|
||||
def to_dict(self) -> Dict:
|
||||
"""将人格特质转换为字典格式"""
|
||||
return {
|
||||
"bot_nickname": self.bot_nickname,
|
||||
"personality_core": self.personality_core,
|
||||
"personality_side": self.personality_side,
|
||||
"identity": self.identity,
|
||||
"compress_personality": self.compress_personality,
|
||||
"compress_identity": self.compress_identity,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict) -> "Personality":
|
||||
"""从字典创建人格特质实例"""
|
||||
instance = cls.get_instance()
|
||||
for key, value in data.items():
|
||||
setattr(instance, key, value)
|
||||
return instance
|
||||
@@ -10,6 +10,7 @@ import base64
|
||||
from PIL import Image
|
||||
import io
|
||||
import os
|
||||
import copy # 添加copy模块用于深拷贝
|
||||
from src.common.database.database import db # 确保 db 被导入用于 create_tables
|
||||
from src.common.database.database_model import LLMUsage # 导入 LLMUsage 模型
|
||||
from src.config.config import global_config
|
||||
@@ -69,23 +70,28 @@ error_code_mapping = {
|
||||
|
||||
|
||||
async def _safely_record(request_content: Dict[str, Any], payload: Dict[str, Any]):
|
||||
"""安全地记录请求体,用于调试日志,不会修改原始payload对象"""
|
||||
# 创建payload的深拷贝,避免修改原始对象
|
||||
safe_payload = copy.deepcopy(payload)
|
||||
|
||||
image_base64: str = request_content.get("image_base64")
|
||||
image_format: str = request_content.get("image_format")
|
||||
if (
|
||||
image_base64
|
||||
and payload
|
||||
and isinstance(payload, dict)
|
||||
and "messages" in payload
|
||||
and len(payload["messages"]) > 0
|
||||
and safe_payload
|
||||
and isinstance(safe_payload, dict)
|
||||
and "messages" in safe_payload
|
||||
and len(safe_payload["messages"]) > 0
|
||||
):
|
||||
if isinstance(payload["messages"][0], dict) and "content" in payload["messages"][0]:
|
||||
content = payload["messages"][0]["content"]
|
||||
if isinstance(safe_payload["messages"][0], dict) and "content" in safe_payload["messages"][0]:
|
||||
content = safe_payload["messages"][0]["content"]
|
||||
if isinstance(content, list) and len(content) > 1 and "image_url" in content[1]:
|
||||
payload["messages"][0]["content"][1]["image_url"]["url"] = (
|
||||
# 只修改拷贝的对象,用于安全的日志记录
|
||||
safe_payload["messages"][0]["content"][1]["image_url"]["url"] = (
|
||||
f"data:image/{image_format.lower() if image_format else 'jpeg'};base64,"
|
||||
f"{image_base64[:10]}...{image_base64[-10:]}"
|
||||
)
|
||||
return payload
|
||||
return safe_payload
|
||||
|
||||
|
||||
class LLMRequest:
|
||||
@@ -109,10 +115,15 @@ class LLMRequest:
|
||||
|
||||
def __init__(self, model: dict, **kwargs):
|
||||
# 将大写的配置键转换为小写并从config中获取实际值
|
||||
logger.debug(f"🔍 [模型初始化] 开始初始化模型: {model.get('name', 'Unknown')}")
|
||||
logger.debug(f"🔍 [模型初始化] 模型配置: {model}")
|
||||
logger.debug(f"🔍 [模型初始化] 额外参数: {kwargs}")
|
||||
|
||||
try:
|
||||
# print(f"model['provider']: {model['provider']}")
|
||||
self.api_key = os.environ[f"{model['provider']}_KEY"]
|
||||
self.base_url = os.environ[f"{model['provider']}_BASE_URL"]
|
||||
logger.debug(f"🔍 [模型初始化] 成功获取环境变量: {model['provider']}_KEY 和 {model['provider']}_BASE_URL")
|
||||
except AttributeError as e:
|
||||
logger.error(f"原始 model dict 信息:{model}")
|
||||
logger.error(f"配置错误:找不到对应的配置项 - {str(e)}")
|
||||
@@ -124,6 +135,10 @@ class LLMRequest:
|
||||
self.model_name: str = model["name"]
|
||||
self.params = kwargs
|
||||
|
||||
# 记录配置文件中声明了哪些参数(不管值是什么)
|
||||
self.has_enable_thinking = "enable_thinking" in model
|
||||
self.has_thinking_budget = "thinking_budget" in model
|
||||
|
||||
self.enable_thinking = model.get("enable_thinking", False)
|
||||
self.temp = model.get("temp", 0.7)
|
||||
self.thinking_budget = model.get("thinking_budget", 4096)
|
||||
@@ -132,12 +147,24 @@ class LLMRequest:
|
||||
self.pri_out = model.get("pri_out", 0)
|
||||
self.max_tokens = model.get("max_tokens", global_config.model.model_max_output_length)
|
||||
# print(f"max_tokens: {self.max_tokens}")
|
||||
|
||||
logger.debug(f"🔍 [模型初始化] 模型参数设置完成:")
|
||||
logger.debug(f" - model_name: {self.model_name}")
|
||||
logger.debug(f" - has_enable_thinking: {self.has_enable_thinking}")
|
||||
logger.debug(f" - enable_thinking: {self.enable_thinking}")
|
||||
logger.debug(f" - has_thinking_budget: {self.has_thinking_budget}")
|
||||
logger.debug(f" - thinking_budget: {self.thinking_budget}")
|
||||
logger.debug(f" - temp: {self.temp}")
|
||||
logger.debug(f" - stream: {self.stream}")
|
||||
logger.debug(f" - max_tokens: {self.max_tokens}")
|
||||
logger.debug(f" - base_url: {self.base_url}")
|
||||
|
||||
# 获取数据库实例
|
||||
self._init_database()
|
||||
|
||||
# 从 kwargs 中提取 request_type,如果没有提供则默认为 "default"
|
||||
self.request_type = kwargs.pop("request_type", "default")
|
||||
logger.debug(f"🔍 [模型初始化] 初始化完成,request_type: {self.request_type}")
|
||||
|
||||
@staticmethod
|
||||
def _init_database():
|
||||
@@ -262,11 +289,12 @@ class LLMRequest:
|
||||
if self.temp != 0.7:
|
||||
payload["temperature"] = self.temp
|
||||
|
||||
# 添加enable_thinking参数(如果不是默认值False)
|
||||
if not self.enable_thinking:
|
||||
payload["enable_thinking"] = False
|
||||
# 添加enable_thinking参数(只有配置文件中声明了才添加,不管值是true还是false)
|
||||
if self.has_enable_thinking:
|
||||
payload["enable_thinking"] = self.enable_thinking
|
||||
|
||||
if self.thinking_budget != 4096:
|
||||
# 添加thinking_budget参数(只有配置文件中声明了才添加)
|
||||
if self.has_thinking_budget:
|
||||
payload["thinking_budget"] = self.thinking_budget
|
||||
|
||||
if self.max_tokens:
|
||||
@@ -334,6 +362,19 @@ class LLMRequest:
|
||||
# 似乎是openai流式必须要的东西,不过阿里云的qwq-plus加了这个没有影响
|
||||
if request_content["stream_mode"]:
|
||||
headers["Accept"] = "text/event-stream"
|
||||
|
||||
# 添加请求发送前的调试信息
|
||||
logger.debug(f"🔍 [请求调试] 模型 {self.model_name} 准备发送请求")
|
||||
logger.debug(f"🔍 [请求调试] API URL: {request_content['api_url']}")
|
||||
logger.debug(f"🔍 [请求调试] 请求头: {await self._build_headers(no_key=True, is_formdata=file_bytes is not None)}")
|
||||
|
||||
if not file_bytes:
|
||||
# 安全地记录请求体(隐藏敏感信息)
|
||||
safe_payload = await _safely_record(request_content, request_content["payload"])
|
||||
logger.debug(f"🔍 [请求调试] 请求体: {json.dumps(safe_payload, indent=2, ensure_ascii=False)}")
|
||||
else:
|
||||
logger.debug(f"🔍 [请求调试] 文件上传请求,文件格式: {request_content['file_format']}")
|
||||
|
||||
async with aiohttp.ClientSession(connector=await get_tcp_connector()) as session:
|
||||
post_kwargs = {"headers": headers}
|
||||
# form-data数据上传方式不同
|
||||
@@ -491,7 +532,36 @@ class LLMRequest:
|
||||
logger.warning(f"模型 {self.model_name} 请求限制(429),等待{wait_time}秒后重试...")
|
||||
raise RuntimeError("请求限制(429)")
|
||||
elif response.status in policy["abort_codes"]:
|
||||
if response.status != 403:
|
||||
# 特别处理400错误,添加详细调试信息
|
||||
if response.status == 400:
|
||||
logger.error(f"🔍 [调试信息] 模型 {self.model_name} 参数错误 (400) - 开始详细诊断")
|
||||
logger.error(f"🔍 [调试信息] 模型名称: {self.model_name}")
|
||||
logger.error(f"🔍 [调试信息] API地址: {self.base_url}")
|
||||
logger.error(f"🔍 [调试信息] 模型配置参数:")
|
||||
logger.error(f" - enable_thinking: {self.enable_thinking}")
|
||||
logger.error(f" - temp: {self.temp}")
|
||||
logger.error(f" - thinking_budget: {self.thinking_budget}")
|
||||
logger.error(f" - stream: {self.stream}")
|
||||
logger.error(f" - max_tokens: {self.max_tokens}")
|
||||
logger.error(f" - pri_in: {self.pri_in}")
|
||||
logger.error(f" - pri_out: {self.pri_out}")
|
||||
logger.error(f"🔍 [调试信息] 原始params: {self.params}")
|
||||
|
||||
# 尝试获取服务器返回的详细错误信息
|
||||
try:
|
||||
error_text = await response.text()
|
||||
logger.error(f"🔍 [调试信息] 服务器返回的原始错误内容: {error_text}")
|
||||
|
||||
try:
|
||||
error_json = json.loads(error_text)
|
||||
logger.error(f"🔍 [调试信息] 解析后的错误JSON: {json.dumps(error_json, indent=2, ensure_ascii=False)}")
|
||||
except json.JSONDecodeError:
|
||||
logger.error(f"🔍 [调试信息] 错误响应不是有效的JSON格式")
|
||||
except Exception as e:
|
||||
logger.error(f"🔍 [调试信息] 无法读取错误响应内容: {str(e)}")
|
||||
|
||||
raise RequestAbortException("参数错误,请检查调试信息", response)
|
||||
elif response.status != 403:
|
||||
raise RequestAbortException("请求出现错误,中断处理", response)
|
||||
else:
|
||||
raise PermissionDeniedException("模型禁止访问")
|
||||
@@ -510,6 +580,19 @@ class LLMRequest:
|
||||
logger.error(
|
||||
f"模型 {self.model_name} 错误码: {response.status} - {error_code_mapping.get(response.status)}"
|
||||
)
|
||||
|
||||
# 如果是400错误,额外输出请求体信息用于调试
|
||||
if response.status == 400:
|
||||
logger.error(f"🔍 [异常调试] 400错误 - 请求体调试信息:")
|
||||
try:
|
||||
safe_payload = await _safely_record(request_content, payload)
|
||||
logger.error(f"🔍 [异常调试] 发送的请求体: {json.dumps(safe_payload, indent=2, ensure_ascii=False)}")
|
||||
except Exception as debug_error:
|
||||
logger.error(f"🔍 [异常调试] 无法安全记录请求体: {str(debug_error)}")
|
||||
logger.error(f"🔍 [异常调试] 原始payload类型: {type(payload)}")
|
||||
if isinstance(payload, dict):
|
||||
logger.error(f"🔍 [异常调试] 原始payload键: {list(payload.keys())}")
|
||||
|
||||
# print(request_content)
|
||||
# print(response)
|
||||
# 尝试获取并记录服务器返回的详细错误信息
|
||||
@@ -654,14 +737,27 @@ class LLMRequest:
|
||||
"""
|
||||
# 复制一份参数,避免直接修改原始数据
|
||||
new_params = dict(params)
|
||||
|
||||
logger.debug(f"🔍 [参数转换] 模型 {self.model_name} 开始参数转换")
|
||||
logger.debug(f"🔍 [参数转换] 是否为CoT模型: {self.model_name.lower() in self.MODELS_NEEDING_TRANSFORMATION}")
|
||||
logger.debug(f"🔍 [参数转换] CoT模型列表: {self.MODELS_NEEDING_TRANSFORMATION}")
|
||||
|
||||
if self.model_name.lower() in self.MODELS_NEEDING_TRANSFORMATION:
|
||||
logger.debug(f"🔍 [参数转换] 检测到CoT模型,开始参数转换")
|
||||
# 删除 'temperature' 参数(如果存在),但避免删除我们在_build_payload中添加的自定义温度
|
||||
if "temperature" in new_params and new_params["temperature"] == 0.7:
|
||||
new_params.pop("temperature")
|
||||
removed_temp = new_params.pop("temperature")
|
||||
logger.debug(f"🔍 [参数转换] 移除默认temperature参数: {removed_temp}")
|
||||
# 如果存在 'max_tokens',则重命名为 'max_completion_tokens'
|
||||
if "max_tokens" in new_params:
|
||||
old_value = new_params["max_tokens"]
|
||||
new_params["max_completion_tokens"] = new_params.pop("max_tokens")
|
||||
logger.debug(f"🔍 [参数转换] 参数重命名: max_tokens({old_value}) -> max_completion_tokens({new_params['max_completion_tokens']})")
|
||||
else:
|
||||
logger.debug(f"🔍 [参数转换] 非CoT模型,无需参数转换")
|
||||
|
||||
logger.debug(f"🔍 [参数转换] 转换前参数: {params}")
|
||||
logger.debug(f"🔍 [参数转换] 转换后参数: {new_params}")
|
||||
return new_params
|
||||
|
||||
async def _build_formdata_payload(self, file_bytes: bytes, file_format: str) -> aiohttp.FormData:
|
||||
@@ -693,7 +789,12 @@ class LLMRequest:
|
||||
async def _build_payload(self, prompt: str, image_base64: str = None, image_format: str = None) -> dict:
|
||||
"""构建请求体"""
|
||||
# 复制一份参数,避免直接修改 self.params
|
||||
logger.debug(f"🔍 [参数构建] 模型 {self.model_name} 开始构建请求体")
|
||||
logger.debug(f"🔍 [参数构建] 原始self.params: {self.params}")
|
||||
|
||||
params_copy = await self._transform_parameters(self.params)
|
||||
logger.debug(f"🔍 [参数构建] 转换后的params_copy: {params_copy}")
|
||||
|
||||
if image_base64:
|
||||
messages = [
|
||||
{
|
||||
@@ -715,26 +816,37 @@ class LLMRequest:
|
||||
"messages": messages,
|
||||
**params_copy,
|
||||
}
|
||||
|
||||
logger.debug(f"🔍 [参数构建] 基础payload构建完成: {list(payload.keys())}")
|
||||
|
||||
# 添加temp参数(如果不是默认值0.7)
|
||||
if self.temp != 0.7:
|
||||
payload["temperature"] = self.temp
|
||||
logger.debug(f"🔍 [参数构建] 添加temperature参数: {self.temp}")
|
||||
|
||||
# 添加enable_thinking参数(如果不是默认值False)
|
||||
if not self.enable_thinking:
|
||||
payload["enable_thinking"] = False
|
||||
# 添加enable_thinking参数(只有配置文件中声明了才添加,不管值是true还是false)
|
||||
if self.has_enable_thinking:
|
||||
payload["enable_thinking"] = self.enable_thinking
|
||||
logger.debug(f"🔍 [参数构建] 添加enable_thinking参数: {self.enable_thinking}")
|
||||
|
||||
if self.thinking_budget != 4096:
|
||||
# 添加thinking_budget参数(只有配置文件中声明了才添加)
|
||||
if self.has_thinking_budget:
|
||||
payload["thinking_budget"] = self.thinking_budget
|
||||
logger.debug(f"🔍 [参数构建] 添加thinking_budget参数: {self.thinking_budget}")
|
||||
|
||||
if self.max_tokens:
|
||||
payload["max_tokens"] = self.max_tokens
|
||||
logger.debug(f"🔍 [参数构建] 添加max_tokens参数: {self.max_tokens}")
|
||||
|
||||
# if "max_tokens" not in payload and "max_completion_tokens" not in payload:
|
||||
# payload["max_tokens"] = global_config.model.model_max_output_length
|
||||
# 如果 payload 中依然存在 max_tokens 且需要转换,在这里进行再次检查
|
||||
if self.model_name.lower() in self.MODELS_NEEDING_TRANSFORMATION and "max_tokens" in payload:
|
||||
old_value = payload["max_tokens"]
|
||||
payload["max_completion_tokens"] = payload.pop("max_tokens")
|
||||
logger.debug(f"🔍 [参数构建] CoT模型参数转换: max_tokens({old_value}) -> max_completion_tokens({payload['max_completion_tokens']})")
|
||||
|
||||
logger.debug(f"🔍 [参数构建] 最终payload键列表: {list(payload.keys())}")
|
||||
return payload
|
||||
|
||||
def _default_response_handler(
|
||||
|
||||
@@ -115,7 +115,6 @@ class MainSystem:
|
||||
|
||||
# 初始化个体特征
|
||||
await self.individuality.initialize()
|
||||
logger.info("个体特征初始化成功")
|
||||
|
||||
try:
|
||||
init_time = int(1000 * (time.time() - init_start_time))
|
||||
|
||||
1
src/mais4u/constant_s4u.py
Normal file
1
src/mais4u/constant_s4u.py
Normal file
@@ -0,0 +1 @@
|
||||
ENABLE_S4U = False
|
||||
@@ -3,7 +3,7 @@ import time
|
||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import global_config
|
||||
from src.chat.message_receive.message import MessageSending, MessageRecv, MessageRecvS4U
|
||||
from src.chat.message_receive.message import MessageRecvS4U
|
||||
from src.mais4u.mais4u_chat.s4u_msg_processor import S4UMessageProcessor
|
||||
from src.mais4u.mais4u_chat.internal_manager import internal_manager
|
||||
from src.common.logger import get_logger
|
||||
@@ -1,6 +1,5 @@
|
||||
import json
|
||||
import time
|
||||
import random
|
||||
from src.chat.message_receive.message import MessageRecv
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.common.logger import get_logger
|
||||
|
||||
@@ -19,6 +19,7 @@ from src.mais4u.s4u_config import s4u_config
|
||||
from src.person_info.person_info import PersonInfoManager
|
||||
from .super_chat_manager import get_super_chat_manager
|
||||
from .yes_or_no import yes_or_no_head
|
||||
from src.mais4u.constant_s4u import ENABLE_S4U
|
||||
|
||||
logger = get_logger("S4U_chat")
|
||||
|
||||
@@ -165,7 +166,10 @@ class S4UChatManager:
|
||||
return self.s4u_chats[chat_stream.stream_id]
|
||||
|
||||
|
||||
s4u_chat_manager = S4UChatManager()
|
||||
if not ENABLE_S4U:
|
||||
s4u_chat_manager = None
|
||||
else:
|
||||
s4u_chat_manager = S4UChatManager()
|
||||
|
||||
|
||||
def get_s4u_chat_manager() -> S4UChatManager:
|
||||
@@ -486,7 +490,7 @@ class S4UChat:
|
||||
logger.info(f"[S4U] 开始为消息生成文本和音频流: '{message.processed_plain_text[:30]}...'")
|
||||
|
||||
if s4u_config.enable_streaming_output:
|
||||
logger.info(f"[S4U] 开始流式输出")
|
||||
logger.info("[S4U] 开始流式输出")
|
||||
# 流式输出,边生成边发送
|
||||
gen = self.gpt.generate_response(message, "")
|
||||
async for chunk in gen:
|
||||
@@ -494,7 +498,7 @@ class S4UChat:
|
||||
await sender_container.add_message(chunk)
|
||||
total_chars_sent += len(chunk)
|
||||
else:
|
||||
logger.info(f"[S4U] 开始一次性输出")
|
||||
logger.info("[S4U] 开始一次性输出")
|
||||
# 一次性输出,先收集所有chunk
|
||||
all_chunks = []
|
||||
gen = self.gpt.generate_response(message, "")
|
||||
|
||||
@@ -10,6 +10,7 @@ from src.config.config import global_config
|
||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||
from src.manager.async_task_manager import AsyncTask, async_task_manager
|
||||
from src.plugin_system.apis import send_api
|
||||
from src.mais4u.constant_s4u import ENABLE_S4U
|
||||
|
||||
"""
|
||||
情绪管理系统使用说明:
|
||||
@@ -446,9 +447,10 @@ class MoodManager:
|
||||
# 发送初始情绪状态到ws端
|
||||
asyncio.create_task(new_mood.send_emotion_update(new_mood.mood_values))
|
||||
|
||||
|
||||
init_prompt()
|
||||
|
||||
mood_manager = MoodManager()
|
||||
if ENABLE_S4U:
|
||||
init_prompt()
|
||||
mood_manager = MoodManager()
|
||||
else:
|
||||
mood_manager = None
|
||||
|
||||
"""全局情绪管理器"""
|
||||
|
||||
@@ -4,7 +4,7 @@ from typing import Tuple
|
||||
|
||||
from src.chat.memory_system.Hippocampus import hippocampus_manager
|
||||
from src.chat.message_receive.message import MessageRecv, MessageRecvS4U
|
||||
from maim_message.message_base import GroupInfo,UserInfo
|
||||
from maim_message.message_base import GroupInfo
|
||||
from src.chat.message_receive.storage import MessageStorage
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from src.chat.utils.timer_calculator import Timer
|
||||
|
||||
@@ -10,13 +10,13 @@ from datetime import datetime
|
||||
import asyncio
|
||||
from src.mais4u.s4u_config import s4u_config
|
||||
from src.chat.message_receive.message import MessageRecvS4U
|
||||
from src.person_info.relationship_manager import get_relationship_manager
|
||||
from src.person_info.relationship_fetcher import relationship_fetcher_manager
|
||||
from src.person_info.person_info import PersonInfoManager, get_person_info_manager
|
||||
from src.chat.message_receive.chat_stream import ChatStream
|
||||
from src.mais4u.mais4u_chat.super_chat_manager import get_super_chat_manager
|
||||
from src.mais4u.mais4u_chat.screen_manager import screen_manager
|
||||
from src.chat.express.expression_selector import expression_selector
|
||||
from .s4u_mood_manager import mood_manager
|
||||
from src.person_info.person_info import PersonInfoManager, get_person_info_manager
|
||||
from src.mais4u.mais4u_chat.internal_manager import internal_manager
|
||||
logger = get_logger("prompt")
|
||||
|
||||
@@ -149,9 +149,17 @@ class PromptBuilder:
|
||||
|
||||
relation_prompt = ""
|
||||
if global_config.relationship.enable_relationship and who_chat_in_group:
|
||||
relationship_manager = get_relationship_manager()
|
||||
relationship_fetcher = relationship_fetcher_manager.get_fetcher(chat_stream.stream_id)
|
||||
|
||||
# 将 (platform, user_id, nickname) 转换为 person_id
|
||||
person_ids = []
|
||||
for person in who_chat_in_group:
|
||||
person_id = PersonInfoManager.get_person_id(person[0], person[1])
|
||||
person_ids.append(person_id)
|
||||
|
||||
# 使用 RelationshipFetcher 的 build_relation_info 方法,设置 points_num=3 保持与原来相同的行为
|
||||
relation_info_list = await asyncio.gather(
|
||||
*[relationship_manager.build_relationship_info(person) for person in who_chat_in_group]
|
||||
*[relationship_fetcher.build_relation_info(person_id, points_num=3) for person_id in person_ids]
|
||||
)
|
||||
relation_info = "".join(relation_info_list)
|
||||
if relation_info:
|
||||
|
||||
@@ -5,7 +5,6 @@ from src.config.config import global_config
|
||||
from src.chat.message_receive.message import MessageRecvS4U
|
||||
from src.mais4u.mais4u_chat.s4u_prompt import prompt_builder
|
||||
from src.common.logger import get_logger
|
||||
from src.person_info.person_info import PersonInfoManager, get_person_info_manager
|
||||
import asyncio
|
||||
import re
|
||||
|
||||
@@ -49,19 +48,19 @@ class S4UStreamGenerator:
|
||||
self.chat_stream =None
|
||||
|
||||
async def build_last_internal_message(self,message:MessageRecvS4U,previous_reply_context:str = ""):
|
||||
person_id = PersonInfoManager.get_person_id(
|
||||
message.chat_stream.user_info.platform, message.chat_stream.user_info.user_id
|
||||
)
|
||||
person_info_manager = get_person_info_manager()
|
||||
person_name = await person_info_manager.get_value(person_id, "person_name")
|
||||
# person_id = PersonInfoManager.get_person_id(
|
||||
# message.chat_stream.user_info.platform, message.chat_stream.user_info.user_id
|
||||
# )
|
||||
# person_info_manager = get_person_info_manager()
|
||||
# person_name = await person_info_manager.get_value(person_id, "person_name")
|
||||
|
||||
if message.chat_stream.user_info.user_nickname:
|
||||
if person_name:
|
||||
sender_name = f"[{message.chat_stream.user_info.user_nickname}](你叫ta{person_name})"
|
||||
else:
|
||||
sender_name = f"[{message.chat_stream.user_info.user_nickname}]"
|
||||
else:
|
||||
sender_name = f"用户({message.chat_stream.user_info.user_id})"
|
||||
# if message.chat_stream.user_info.user_nickname:
|
||||
# if person_name:
|
||||
# sender_name = f"[{message.chat_stream.user_info.user_nickname}](你叫ta{person_name})"
|
||||
# else:
|
||||
# sender_name = f"[{message.chat_stream.user_info.user_nickname}]"
|
||||
# else:
|
||||
# sender_name = f"用户({message.chat_stream.user_info.user_id})"
|
||||
|
||||
# 构建prompt
|
||||
if previous_reply_context:
|
||||
|
||||
@@ -1,7 +1,3 @@
|
||||
import asyncio
|
||||
import time
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.plugin_system.apis import send_api
|
||||
|
||||
@@ -4,6 +4,8 @@ from dataclasses import dataclass
|
||||
from typing import Dict, List, Optional
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.message_receive.message import MessageRecvS4U
|
||||
# 全局SuperChat管理器实例
|
||||
from src.mais4u.constant_s4u import ENABLE_S4U
|
||||
|
||||
logger = get_logger("super_chat_manager")
|
||||
|
||||
@@ -296,10 +298,14 @@ class SuperChatManager:
|
||||
logger.info("SuperChat管理器已关闭")
|
||||
|
||||
|
||||
# 全局SuperChat管理器实例
|
||||
super_chat_manager = SuperChatManager()
|
||||
|
||||
|
||||
if ENABLE_S4U:
|
||||
super_chat_manager = SuperChatManager()
|
||||
else:
|
||||
super_chat_manager = None
|
||||
|
||||
def get_super_chat_manager() -> SuperChatManager:
|
||||
"""获取全局SuperChat管理器实例"""
|
||||
return super_chat_manager
|
||||
|
||||
return super_chat_manager
|
||||
@@ -1,16 +1,6 @@
|
||||
import json
|
||||
import time
|
||||
import random
|
||||
from src.chat.message_receive.message import MessageRecv
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_by_timestamp_with_chat_inclusive
|
||||
from src.config.config import global_config
|
||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||
from src.manager.async_task_manager import AsyncTask, async_task_manager
|
||||
from src.plugin_system.apis import send_api
|
||||
from json_repair import repair_json
|
||||
from src.mais4u.s4u_config import s4u_config
|
||||
from src.plugin_system.apis import send_api
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -6,7 +6,7 @@ from tomlkit import TOMLDocument
|
||||
from tomlkit.items import Table
|
||||
from dataclasses import dataclass, fields, MISSING, field
|
||||
from typing import TypeVar, Type, Any, get_origin, get_args, Literal
|
||||
|
||||
from src.mais4u.constant_s4u import ENABLE_S4U
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("s4u_config")
|
||||
@@ -353,12 +353,16 @@ def load_s4u_config(config_path: str) -> S4UGlobalConfig:
|
||||
raise e
|
||||
|
||||
|
||||
# 初始化S4U配置
|
||||
logger.info(f"S4U当前版本: {S4U_VERSION}")
|
||||
update_s4u_config()
|
||||
if not ENABLE_S4U:
|
||||
s4u_config = None
|
||||
s4u_config_main = None
|
||||
else:
|
||||
# 初始化S4U配置
|
||||
logger.info(f"S4U当前版本: {S4U_VERSION}")
|
||||
update_s4u_config()
|
||||
|
||||
logger.info("正在加载S4U配置文件...")
|
||||
s4u_config_main = load_s4u_config(config_path=CONFIG_PATH)
|
||||
logger.info("S4U配置文件加载完成!")
|
||||
logger.info("正在加载S4U配置文件...")
|
||||
s4u_config_main = load_s4u_config(config_path=CONFIG_PATH)
|
||||
logger.info("S4U配置文件加载完成!")
|
||||
|
||||
s4u_config: S4UConfig = s4u_config_main.s4u
|
||||
s4u_config: S4UConfig = s4u_config_main.s4u
|
||||
@@ -83,12 +83,12 @@ class ChatMood:
|
||||
logger.debug(
|
||||
f"base_probability: {base_probability}, time_multiplier: {time_multiplier}, interest_multiplier: {interest_multiplier}"
|
||||
)
|
||||
update_probability = min(1.0, base_probability * time_multiplier * interest_multiplier)
|
||||
update_probability = global_config.mood.mood_update_threshold * min(1.0, base_probability * time_multiplier * interest_multiplier)
|
||||
|
||||
if random.random() > update_probability:
|
||||
return
|
||||
|
||||
logger.info(f"{self.log_prefix} 更新情绪状态,感兴趣度: {interested_rate}, 更新概率: {update_probability}")
|
||||
logger.debug(f"{self.log_prefix} 更新情绪状态,感兴趣度: {interested_rate:.2f}, 更新概率: {update_probability:.2f}")
|
||||
|
||||
message_time: float = message.message_info.time # type: ignore
|
||||
message_list_before_now = get_raw_msg_by_timestamp_with_chat_inclusive(
|
||||
@@ -201,7 +201,7 @@ class MoodRegressionTask(AsyncTask):
|
||||
if mood.regression_count >= 3:
|
||||
continue
|
||||
|
||||
logger.info(f"chat {mood.chat_id} 开始情绪回归, 这是第 {mood.regression_count + 1} 次")
|
||||
logger.info(f"{mood.log_prefix} 开始情绪回归, 这是第 {mood.regression_count + 1} 次")
|
||||
await mood.regress_mood()
|
||||
|
||||
|
||||
|
||||
@@ -41,8 +41,6 @@ person_info_default = {
|
||||
"know_times": 0,
|
||||
"know_since": None,
|
||||
"last_know": None,
|
||||
# "user_cardname": None, # This field is not in Peewee model PersonInfo
|
||||
# "user_avatar": None, # This field is not in Peewee model PersonInfo
|
||||
"impression": None, # Corrected from person_impression
|
||||
"short_impression": None,
|
||||
"info_list": None,
|
||||
|
||||
@@ -112,15 +112,6 @@ class RelationshipFetcher:
|
||||
|
||||
current_points = await person_info_manager.get_value(person_id, "points") or []
|
||||
|
||||
if isinstance(current_points, str):
|
||||
try:
|
||||
current_points = json.loads(current_points)
|
||||
except json.JSONDecodeError:
|
||||
logger.error(f"解析points JSON失败: {current_points}")
|
||||
current_points = []
|
||||
elif not isinstance(current_points, list):
|
||||
current_points = []
|
||||
|
||||
# 按时间排序forgotten_points
|
||||
current_points.sort(key=lambda x: x[2])
|
||||
# 按权重加权随机抽取最多3个不重复的points,point[1]的值在1-10之间,权重越高被抽到概率越大
|
||||
@@ -370,60 +361,6 @@ class RelationshipFetcher:
|
||||
logger.error(f"{self.log_prefix} 执行信息提取时出错: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
def _organize_known_info(self) -> str:
|
||||
"""组织已知的用户信息为字符串
|
||||
|
||||
Returns:
|
||||
str: 格式化的用户信息字符串
|
||||
"""
|
||||
persons_infos_str = ""
|
||||
|
||||
if self.info_fetched_cache:
|
||||
persons_with_known_info = [] # 有已知信息的人员
|
||||
persons_with_unknown_info = [] # 有未知信息的人员
|
||||
|
||||
for person_id in self.info_fetched_cache:
|
||||
person_known_infos = []
|
||||
person_unknown_infos = []
|
||||
person_name = ""
|
||||
|
||||
for info_type in self.info_fetched_cache[person_id]:
|
||||
person_name = self.info_fetched_cache[person_id][info_type]["person_name"]
|
||||
if not self.info_fetched_cache[person_id][info_type]["unknown"]:
|
||||
info_content = self.info_fetched_cache[person_id][info_type]["info"]
|
||||
person_known_infos.append(f"[{info_type}]:{info_content}")
|
||||
else:
|
||||
person_unknown_infos.append(info_type)
|
||||
|
||||
# 如果有已知信息,添加到已知信息列表
|
||||
if person_known_infos:
|
||||
known_info_str = ";".join(person_known_infos) + ";"
|
||||
persons_with_known_info.append((person_name, known_info_str))
|
||||
|
||||
# 如果有未知信息,添加到未知信息列表
|
||||
if person_unknown_infos:
|
||||
persons_with_unknown_info.append((person_name, person_unknown_infos))
|
||||
|
||||
# 先输出有已知信息的人员
|
||||
for person_name, known_info_str in persons_with_known_info:
|
||||
persons_infos_str += f"你对 {person_name} 的了解:{known_info_str}\n"
|
||||
|
||||
# 统一处理未知信息,避免重复的警告文本
|
||||
if persons_with_unknown_info:
|
||||
unknown_persons_details = []
|
||||
for person_name, unknown_types in persons_with_unknown_info:
|
||||
unknown_types_str = "、".join(unknown_types)
|
||||
unknown_persons_details.append(f"{person_name}的[{unknown_types_str}]")
|
||||
|
||||
if len(unknown_persons_details) == 1:
|
||||
persons_infos_str += (
|
||||
f"你不了解{unknown_persons_details[0]}信息,不要胡乱回答,可以直接说不知道或忘记了;\n"
|
||||
)
|
||||
else:
|
||||
unknown_all_str = "、".join(unknown_persons_details)
|
||||
persons_infos_str += f"你不了解{unknown_all_str}等信息,不要胡乱回答,可以直接说不知道或忘记了;\n"
|
||||
|
||||
return persons_infos_str
|
||||
|
||||
async def _save_info_to_cache(self, person_id: str, info_type: str, info_content: str):
|
||||
# sourcery skip: use-next
|
||||
|
||||
@@ -55,60 +55,6 @@ class RelationshipManager:
|
||||
# person_id=person_id, user_nickname=user_nickname, user_cardname=user_cardname, user_avatar=user_avatar
|
||||
# )
|
||||
|
||||
async def build_relationship_info(self, person, is_id: bool = False) -> str:
|
||||
if is_id:
|
||||
person_id = person
|
||||
else:
|
||||
person_id = PersonInfoManager.get_person_id(person[0], person[1])
|
||||
person_info_manager = get_person_info_manager()
|
||||
person_name = await person_info_manager.get_value(person_id, "person_name")
|
||||
if not person_name or person_name == "none":
|
||||
return ""
|
||||
short_impression = await person_info_manager.get_value(person_id, "short_impression")
|
||||
|
||||
current_points = await person_info_manager.get_value(person_id, "points") or []
|
||||
# print(f"current_points: {current_points}")
|
||||
if isinstance(current_points, str):
|
||||
try:
|
||||
current_points = json.loads(current_points)
|
||||
except json.JSONDecodeError:
|
||||
logger.error(f"解析points JSON失败: {current_points}")
|
||||
current_points = []
|
||||
elif not isinstance(current_points, list):
|
||||
current_points = []
|
||||
|
||||
# 按时间排序forgotten_points
|
||||
current_points.sort(key=lambda x: x[2])
|
||||
# 按权重加权随机抽取3个points,point[1]的值在1-10之间,权重越高被抽到概率越大
|
||||
if len(current_points) > 3:
|
||||
# point[1] 取值范围1-10,直接作为权重
|
||||
weights = [max(1, min(10, int(point[1]))) for point in current_points]
|
||||
points = random.choices(current_points, weights=weights, k=3)
|
||||
else:
|
||||
points = current_points
|
||||
|
||||
# 构建points文本
|
||||
points_text = "\n".join([f"{point[2]}:{point[0]}" for point in points])
|
||||
|
||||
nickname_str = await person_info_manager.get_value(person_id, "nickname")
|
||||
platform = await person_info_manager.get_value(person_id, "platform")
|
||||
|
||||
if person_name == nickname_str and not short_impression:
|
||||
return ""
|
||||
|
||||
if person_name == nickname_str:
|
||||
relation_prompt = f"'{person_name}' :"
|
||||
else:
|
||||
relation_prompt = f"'{person_name}' ,ta在{platform}上的昵称是{nickname_str}。"
|
||||
|
||||
if short_impression:
|
||||
relation_prompt += f"你对ta的印象是:{short_impression}。\n"
|
||||
|
||||
if points_text:
|
||||
relation_prompt += f"你记得ta最近做的事:{points_text}"
|
||||
|
||||
return relation_prompt
|
||||
|
||||
async def update_person_impression(self, person_id, timestamp, bot_engaged_messages: List[Dict[str, Any]]):
|
||||
"""更新用户印象
|
||||
|
||||
|
||||
@@ -23,12 +23,6 @@ from .base import (
|
||||
EventType,
|
||||
MaiMessages,
|
||||
)
|
||||
from .core import (
|
||||
plugin_manager,
|
||||
component_registry,
|
||||
dependency_manager,
|
||||
events_manager,
|
||||
)
|
||||
|
||||
# 导入工具模块
|
||||
from .utils import (
|
||||
@@ -38,12 +32,42 @@ from .utils import (
|
||||
# generate_plugin_manifest,
|
||||
)
|
||||
|
||||
from .apis import register_plugin, get_logger
|
||||
from .apis import (
|
||||
chat_api,
|
||||
component_manage_api,
|
||||
config_api,
|
||||
database_api,
|
||||
emoji_api,
|
||||
generator_api,
|
||||
llm_api,
|
||||
message_api,
|
||||
person_api,
|
||||
plugin_manage_api,
|
||||
send_api,
|
||||
utils_api,
|
||||
register_plugin,
|
||||
get_logger,
|
||||
)
|
||||
|
||||
|
||||
__version__ = "1.0.0"
|
||||
|
||||
__all__ = [
|
||||
# API 模块
|
||||
"chat_api",
|
||||
"component_manage_api",
|
||||
"config_api",
|
||||
"database_api",
|
||||
"emoji_api",
|
||||
"generator_api",
|
||||
"llm_api",
|
||||
"message_api",
|
||||
"person_api",
|
||||
"plugin_manage_api",
|
||||
"send_api",
|
||||
"utils_api",
|
||||
"register_plugin",
|
||||
"get_logger",
|
||||
# 基础类
|
||||
"BasePlugin",
|
||||
"BaseAction",
|
||||
@@ -62,11 +86,6 @@ __all__ = [
|
||||
"EventType",
|
||||
# 消息
|
||||
"MaiMessages",
|
||||
# 管理器
|
||||
"plugin_manager",
|
||||
"component_registry",
|
||||
"dependency_manager",
|
||||
"events_manager",
|
||||
# 装饰器
|
||||
"register_plugin",
|
||||
"ConfigField",
|
||||
|
||||
@@ -7,6 +7,7 @@
|
||||
# 导入所有API模块
|
||||
from src.plugin_system.apis import (
|
||||
chat_api,
|
||||
component_manage_api,
|
||||
config_api,
|
||||
database_api,
|
||||
emoji_api,
|
||||
@@ -14,15 +15,17 @@ from src.plugin_system.apis import (
|
||||
llm_api,
|
||||
message_api,
|
||||
person_api,
|
||||
plugin_manage_api,
|
||||
send_api,
|
||||
utils_api,
|
||||
plugin_register_api,
|
||||
)
|
||||
from .logging_api import get_logger
|
||||
from .plugin_register_api import register_plugin
|
||||
|
||||
# 导出所有API模块,使它们可以通过 apis.xxx 方式访问
|
||||
__all__ = [
|
||||
"chat_api",
|
||||
"component_manage_api",
|
||||
"config_api",
|
||||
"database_api",
|
||||
"emoji_api",
|
||||
@@ -30,9 +33,9 @@ __all__ = [
|
||||
"llm_api",
|
||||
"message_api",
|
||||
"person_api",
|
||||
"plugin_manage_api",
|
||||
"send_api",
|
||||
"utils_api",
|
||||
"plugin_register_api",
|
||||
"get_logger",
|
||||
"register_plugin",
|
||||
]
|
||||
|
||||
245
src/plugin_system/apis/component_manage_api.py
Normal file
245
src/plugin_system/apis/component_manage_api.py
Normal file
@@ -0,0 +1,245 @@
|
||||
from typing import Optional, Union, Dict
|
||||
from src.plugin_system.base.component_types import (
|
||||
CommandInfo,
|
||||
ActionInfo,
|
||||
EventHandlerInfo,
|
||||
PluginInfo,
|
||||
ComponentType,
|
||||
)
|
||||
|
||||
|
||||
# === 插件信息查询 ===
|
||||
def get_all_plugin_info() -> Dict[str, PluginInfo]:
|
||||
"""
|
||||
获取所有插件的信息。
|
||||
|
||||
Returns:
|
||||
dict: 包含所有插件信息的字典,键为插件名称,值为 PluginInfo 对象。
|
||||
"""
|
||||
from src.plugin_system.core.component_registry import component_registry
|
||||
|
||||
return component_registry.get_all_plugins()
|
||||
|
||||
|
||||
def get_plugin_info(plugin_name: str) -> Optional[PluginInfo]:
|
||||
"""
|
||||
获取指定插件的信息。
|
||||
|
||||
Args:
|
||||
plugin_name (str): 插件名称。
|
||||
|
||||
Returns:
|
||||
PluginInfo: 插件信息对象,如果插件不存在则返回 None。
|
||||
"""
|
||||
from src.plugin_system.core.component_registry import component_registry
|
||||
|
||||
return component_registry.get_plugin_info(plugin_name)
|
||||
|
||||
|
||||
# === 组件查询方法 ===
|
||||
def get_component_info(
|
||||
component_name: str, component_type: ComponentType
|
||||
) -> Optional[Union[CommandInfo, ActionInfo, EventHandlerInfo]]:
|
||||
"""
|
||||
获取指定组件的信息。
|
||||
|
||||
Args:
|
||||
component_name (str): 组件名称。
|
||||
component_type (ComponentType): 组件类型。
|
||||
Returns:
|
||||
Union[CommandInfo, ActionInfo, EventHandlerInfo]: 组件信息对象,如果组件不存在则返回 None。
|
||||
"""
|
||||
from src.plugin_system.core.component_registry import component_registry
|
||||
|
||||
return component_registry.get_component_info(component_name, component_type) # type: ignore
|
||||
|
||||
|
||||
def get_components_info_by_type(
|
||||
component_type: ComponentType,
|
||||
) -> Dict[str, Union[CommandInfo, ActionInfo, EventHandlerInfo]]:
|
||||
"""
|
||||
获取指定类型的所有组件信息。
|
||||
|
||||
Args:
|
||||
component_type (ComponentType): 组件类型。
|
||||
|
||||
Returns:
|
||||
dict: 包含指定类型组件信息的字典,键为组件名称,值为对应的组件信息对象。
|
||||
"""
|
||||
from src.plugin_system.core.component_registry import component_registry
|
||||
|
||||
return component_registry.get_components_by_type(component_type) # type: ignore
|
||||
|
||||
|
||||
def get_enabled_components_info_by_type(
|
||||
component_type: ComponentType,
|
||||
) -> Dict[str, Union[CommandInfo, ActionInfo, EventHandlerInfo]]:
|
||||
"""
|
||||
获取指定类型的所有启用的组件信息。
|
||||
|
||||
Args:
|
||||
component_type (ComponentType): 组件类型。
|
||||
|
||||
Returns:
|
||||
dict: 包含指定类型启用组件信息的字典,键为组件名称,值为对应的组件信息对象。
|
||||
"""
|
||||
from src.plugin_system.core.component_registry import component_registry
|
||||
|
||||
return component_registry.get_enabled_components_by_type(component_type) # type: ignore
|
||||
|
||||
|
||||
# === Action 查询方法 ===
|
||||
def get_registered_action_info(action_name: str) -> Optional[ActionInfo]:
|
||||
"""
|
||||
获取指定 Action 的注册信息。
|
||||
|
||||
Args:
|
||||
action_name (str): Action 名称。
|
||||
|
||||
Returns:
|
||||
ActionInfo: Action 信息对象,如果 Action 不存在则返回 None。
|
||||
"""
|
||||
from src.plugin_system.core.component_registry import component_registry
|
||||
|
||||
return component_registry.get_registered_action_info(action_name)
|
||||
|
||||
|
||||
def get_registered_command_info(command_name: str) -> Optional[CommandInfo]:
|
||||
"""
|
||||
获取指定 Command 的注册信息。
|
||||
|
||||
Args:
|
||||
command_name (str): Command 名称。
|
||||
|
||||
Returns:
|
||||
CommandInfo: Command 信息对象,如果 Command 不存在则返回 None。
|
||||
"""
|
||||
from src.plugin_system.core.component_registry import component_registry
|
||||
|
||||
return component_registry.get_registered_command_info(command_name)
|
||||
|
||||
|
||||
# === EventHandler 特定查询方法 ===
|
||||
def get_registered_event_handler_info(
|
||||
event_handler_name: str,
|
||||
) -> Optional[EventHandlerInfo]:
|
||||
"""
|
||||
获取指定 EventHandler 的注册信息。
|
||||
|
||||
Args:
|
||||
event_handler_name (str): EventHandler 名称。
|
||||
|
||||
Returns:
|
||||
EventHandlerInfo: EventHandler 信息对象,如果 EventHandler 不存在则返回 None。
|
||||
"""
|
||||
from src.plugin_system.core.component_registry import component_registry
|
||||
|
||||
return component_registry.get_registered_event_handler_info(event_handler_name)
|
||||
|
||||
|
||||
# === 组件管理方法 ===
|
||||
def globally_enable_component(component_name: str, component_type: ComponentType) -> bool:
|
||||
"""
|
||||
全局启用指定组件。
|
||||
|
||||
Args:
|
||||
component_name (str): 组件名称。
|
||||
component_type (ComponentType): 组件类型。
|
||||
|
||||
Returns:
|
||||
bool: 启用成功返回 True,否则返回 False。
|
||||
"""
|
||||
from src.plugin_system.core.component_registry import component_registry
|
||||
|
||||
return component_registry.enable_component(component_name, component_type)
|
||||
|
||||
|
||||
async def globally_disable_component(component_name: str, component_type: ComponentType) -> bool:
|
||||
"""
|
||||
全局禁用指定组件。
|
||||
|
||||
**此函数是异步的,确保在异步环境中调用。**
|
||||
|
||||
Args:
|
||||
component_name (str): 组件名称。
|
||||
component_type (ComponentType): 组件类型。
|
||||
|
||||
Returns:
|
||||
bool: 禁用成功返回 True,否则返回 False。
|
||||
"""
|
||||
from src.plugin_system.core.component_registry import component_registry
|
||||
|
||||
return await component_registry.disable_component(component_name, component_type)
|
||||
|
||||
|
||||
def locally_enable_component(component_name: str, component_type: ComponentType, stream_id: str) -> bool:
|
||||
"""
|
||||
局部启用指定组件。
|
||||
|
||||
Args:
|
||||
component_name (str): 组件名称。
|
||||
component_type (ComponentType): 组件类型。
|
||||
stream_id (str): 消息流 ID。
|
||||
|
||||
Returns:
|
||||
bool: 启用成功返回 True,否则返回 False。
|
||||
"""
|
||||
from src.plugin_system.core.global_announcement_manager import global_announcement_manager
|
||||
|
||||
match component_type:
|
||||
case ComponentType.ACTION:
|
||||
return global_announcement_manager.enable_specific_chat_action(stream_id, component_name)
|
||||
case ComponentType.COMMAND:
|
||||
return global_announcement_manager.enable_specific_chat_command(stream_id, component_name)
|
||||
case ComponentType.EVENT_HANDLER:
|
||||
return global_announcement_manager.enable_specific_chat_event_handler(stream_id, component_name)
|
||||
case _:
|
||||
raise ValueError(f"未知 component type: {component_type}")
|
||||
|
||||
|
||||
def locally_disable_component(component_name: str, component_type: ComponentType, stream_id: str) -> bool:
|
||||
"""
|
||||
局部禁用指定组件。
|
||||
|
||||
Args:
|
||||
component_name (str): 组件名称。
|
||||
component_type (ComponentType): 组件类型。
|
||||
stream_id (str): 消息流 ID。
|
||||
|
||||
Returns:
|
||||
bool: 禁用成功返回 True,否则返回 False。
|
||||
"""
|
||||
from src.plugin_system.core.global_announcement_manager import global_announcement_manager
|
||||
|
||||
match component_type:
|
||||
case ComponentType.ACTION:
|
||||
return global_announcement_manager.disable_specific_chat_action(stream_id, component_name)
|
||||
case ComponentType.COMMAND:
|
||||
return global_announcement_manager.disable_specific_chat_command(stream_id, component_name)
|
||||
case ComponentType.EVENT_HANDLER:
|
||||
return global_announcement_manager.disable_specific_chat_event_handler(stream_id, component_name)
|
||||
case _:
|
||||
raise ValueError(f"未知 component type: {component_type}")
|
||||
|
||||
def get_locally_disabled_components(stream_id: str, component_type: ComponentType) -> list[str]:
|
||||
"""
|
||||
获取指定消息流中禁用的组件列表。
|
||||
|
||||
Args:
|
||||
stream_id (str): 消息流 ID。
|
||||
component_type (ComponentType): 组件类型。
|
||||
|
||||
Returns:
|
||||
list[str]: 禁用的组件名称列表。
|
||||
"""
|
||||
from src.plugin_system.core.global_announcement_manager import global_announcement_manager
|
||||
|
||||
match component_type:
|
||||
case ComponentType.ACTION:
|
||||
return global_announcement_manager.get_disabled_chat_actions(stream_id)
|
||||
case ComponentType.COMMAND:
|
||||
return global_announcement_manager.get_disabled_chat_commands(stream_id)
|
||||
case ComponentType.EVENT_HANDLER:
|
||||
return global_announcement_manager.get_disabled_chat_event_handlers(stream_id)
|
||||
case _:
|
||||
raise ValueError(f"未知 component type: {component_type}")
|
||||
95
src/plugin_system/apis/plugin_manage_api.py
Normal file
95
src/plugin_system/apis/plugin_manage_api.py
Normal file
@@ -0,0 +1,95 @@
|
||||
from typing import Tuple, List
|
||||
def list_loaded_plugins() -> List[str]:
|
||||
"""
|
||||
列出所有当前加载的插件。
|
||||
|
||||
Returns:
|
||||
list: 当前加载的插件名称列表。
|
||||
"""
|
||||
from src.plugin_system.core.plugin_manager import plugin_manager
|
||||
|
||||
return plugin_manager.list_loaded_plugins()
|
||||
|
||||
|
||||
def list_registered_plugins() -> List[str]:
|
||||
"""
|
||||
列出所有已注册的插件。
|
||||
|
||||
Returns:
|
||||
list: 已注册的插件名称列表。
|
||||
"""
|
||||
from src.plugin_system.core.plugin_manager import plugin_manager
|
||||
|
||||
return plugin_manager.list_registered_plugins()
|
||||
|
||||
|
||||
async def remove_plugin(plugin_name: str) -> bool:
|
||||
"""
|
||||
卸载指定的插件。
|
||||
|
||||
**此函数是异步的,确保在异步环境中调用。**
|
||||
|
||||
Args:
|
||||
plugin_name (str): 要卸载的插件名称。
|
||||
|
||||
Returns:
|
||||
bool: 卸载是否成功。
|
||||
"""
|
||||
from src.plugin_system.core.plugin_manager import plugin_manager
|
||||
|
||||
return await plugin_manager.remove_registered_plugin(plugin_name)
|
||||
|
||||
|
||||
async def reload_plugin(plugin_name: str) -> bool:
|
||||
"""
|
||||
重新加载指定的插件。
|
||||
|
||||
**此函数是异步的,确保在异步环境中调用。**
|
||||
|
||||
Args:
|
||||
plugin_name (str): 要重新加载的插件名称。
|
||||
|
||||
Returns:
|
||||
bool: 重新加载是否成功。
|
||||
"""
|
||||
from src.plugin_system.core.plugin_manager import plugin_manager
|
||||
|
||||
return await plugin_manager.reload_registered_plugin(plugin_name)
|
||||
|
||||
|
||||
def load_plugin(plugin_name: str) -> Tuple[bool, int]:
|
||||
"""
|
||||
加载指定的插件。
|
||||
|
||||
Args:
|
||||
plugin_name (str): 要加载的插件名称。
|
||||
|
||||
Returns:
|
||||
Tuple[bool, int]: 加载是否成功,成功或失败个数。
|
||||
"""
|
||||
from src.plugin_system.core.plugin_manager import plugin_manager
|
||||
|
||||
return plugin_manager.load_registered_plugin_classes(plugin_name)
|
||||
|
||||
def add_plugin_directory(plugin_directory: str) -> bool:
|
||||
"""
|
||||
添加插件目录。
|
||||
|
||||
Args:
|
||||
plugin_directory (str): 要添加的插件目录路径。
|
||||
Returns:
|
||||
bool: 添加是否成功。
|
||||
"""
|
||||
from src.plugin_system.core.plugin_manager import plugin_manager
|
||||
|
||||
return plugin_manager.add_plugin_directory(plugin_directory)
|
||||
|
||||
def rescan_plugin_directory() -> Tuple[int, int]:
|
||||
"""
|
||||
重新扫描插件目录,加载新插件。
|
||||
Returns:
|
||||
Tuple[int, int]: 成功加载的插件数量和失败的插件数量。
|
||||
"""
|
||||
from src.plugin_system.core.plugin_manager import plugin_manager
|
||||
|
||||
return plugin_manager.rescan_plugin_directory()
|
||||
@@ -28,7 +28,6 @@ def register_plugin(cls):
|
||||
if "." in plugin_name:
|
||||
logger.error(f"插件名称 '{plugin_name}' 包含非法字符 '.',请使用下划线替代")
|
||||
raise ValueError(f"插件名称 '{plugin_name}' 包含非法字符 '.',请使用下划线替代")
|
||||
plugin_manager.plugin_classes[plugin_name] = cls
|
||||
splitted_name = cls.__module__.split(".")
|
||||
root_path = Path(__file__)
|
||||
|
||||
@@ -40,6 +39,7 @@ def register_plugin(cls):
|
||||
logger.error(f"注册 {plugin_name} 无法找到项目根目录")
|
||||
return cls
|
||||
|
||||
plugin_manager.plugin_classes[plugin_name] = cls
|
||||
plugin_manager.plugin_paths[plugin_name] = str(Path(root_path, *splitted_name).resolve())
|
||||
logger.debug(f"插件类已注册: {plugin_name}, 路径: {plugin_manager.plugin_paths[plugin_name]}")
|
||||
|
||||
|
||||
@@ -49,12 +49,10 @@ class BaseAction(ABC):
|
||||
reasoning: 执行该动作的理由
|
||||
cycle_timers: 计时器字典
|
||||
thinking_id: 思考ID
|
||||
expressor: 表达器对象
|
||||
replyer: 回复器对象
|
||||
chat_stream: 聊天流对象
|
||||
log_prefix: 日志前缀
|
||||
shutting_down: 是否正在关闭
|
||||
plugin_config: 插件配置字典
|
||||
action_message: 消息数据
|
||||
**kwargs: 其他参数
|
||||
"""
|
||||
if plugin_config is None:
|
||||
@@ -65,21 +63,30 @@ class BaseAction(ABC):
|
||||
self.thinking_id = thinking_id
|
||||
self.log_prefix = log_prefix
|
||||
|
||||
# 保存插件配置
|
||||
self.plugin_config = plugin_config or {}
|
||||
"""对应的插件配置"""
|
||||
|
||||
# 设置动作基本信息实例属性
|
||||
self.action_name: str = getattr(self, "action_name", self.__class__.__name__.lower().replace("action", ""))
|
||||
"""Action的名字"""
|
||||
self.action_description: str = getattr(self, "action_description", self.__doc__ or "Action组件")
|
||||
"""Action的描述"""
|
||||
self.action_parameters: dict = getattr(self.__class__, "action_parameters", {}).copy()
|
||||
self.action_require: list[str] = getattr(self.__class__, "action_require", []).copy()
|
||||
|
||||
# 设置激活类型实例属性(从类属性复制,提供默认值)
|
||||
self.focus_activation_type = getattr(self.__class__, "focus_activation_type", ActionActivationType.ALWAYS)
|
||||
"""FOCUS模式下的激活类型"""
|
||||
self.normal_activation_type = getattr(self.__class__, "normal_activation_type", ActionActivationType.ALWAYS)
|
||||
"""NORMAL模式下的激活类型"""
|
||||
self.activation_type = getattr(self.__class__, "activation_type", self.focus_activation_type)
|
||||
"""激活类型"""
|
||||
self.random_activation_probability: float = getattr(self.__class__, "random_activation_probability", 0.0)
|
||||
"""当激活类型为RANDOM时的概率"""
|
||||
self.llm_judge_prompt: str = getattr(self.__class__, "llm_judge_prompt", "")
|
||||
"""协助LLM进行判断的Prompt"""
|
||||
self.activation_keywords: list[str] = getattr(self.__class__, "activation_keywords", []).copy()
|
||||
"""激活类型为KEYWORD时的KEYWORDS列表"""
|
||||
self.keyword_case_sensitive: bool = getattr(self.__class__, "keyword_case_sensitive", False)
|
||||
self.mode_enable: ChatMode = getattr(self.__class__, "mode_enable", ChatMode.ALL)
|
||||
self.parallel_action: bool = getattr(self.__class__, "parallel_action", True)
|
||||
@@ -136,7 +143,7 @@ class BaseAction(ABC):
|
||||
self.target_id = self.user_id
|
||||
|
||||
logger.debug(f"{self.log_prefix} Action组件初始化完成")
|
||||
logger.info(
|
||||
logger.debug(
|
||||
f"{self.log_prefix} 聊天信息: 类型={'群聊' if self.is_group else '私聊'}, 平台={self.platform}, 目标={self.target_id}"
|
||||
)
|
||||
|
||||
@@ -405,23 +412,11 @@ class BaseAction(ABC):
|
||||
"""
|
||||
return await self.execute()
|
||||
|
||||
# def get_action_context(self, key: str, default=None):
|
||||
# """获取action上下文信息
|
||||
|
||||
# Args:
|
||||
# key: 上下文键名
|
||||
# default: 默认值
|
||||
|
||||
# Returns:
|
||||
# Any: 上下文值或默认值
|
||||
# """
|
||||
# return self.api.get_action_context(key, default)
|
||||
|
||||
def get_config(self, key: str, default=None):
|
||||
"""获取插件配置值,支持嵌套键访问
|
||||
"""获取插件配置值,使用嵌套键访问
|
||||
|
||||
Args:
|
||||
key: 配置键名,支持嵌套访问如 "section.subsection.key"
|
||||
key: 配置键名,使用嵌套访问如 "section.subsection.key"
|
||||
default: 默认值
|
||||
|
||||
Returns:
|
||||
|
||||
@@ -17,17 +17,18 @@ class BaseCommand(ABC):
|
||||
- command_pattern: 命令匹配的正则表达式
|
||||
- command_help: 命令帮助信息
|
||||
- command_examples: 命令使用示例列表
|
||||
- intercept_message: 是否拦截消息处理(默认True拦截,False继续传递)
|
||||
"""
|
||||
|
||||
command_name: str = ""
|
||||
"""Command组件的名称"""
|
||||
command_description: str = ""
|
||||
|
||||
# 默认命令设置(子类可以覆盖)
|
||||
command_pattern: str = ""
|
||||
"""Command组件的描述"""
|
||||
# 默认命令设置
|
||||
command_pattern: str = r""
|
||||
"""命令匹配的正则表达式"""
|
||||
command_help: str = ""
|
||||
"""命令帮助信息"""
|
||||
command_examples: List[str] = []
|
||||
intercept_message: bool = True # 默认拦截消息,不继续处理
|
||||
|
||||
def __init__(self, message: MessageRecv, plugin_config: Optional[dict] = None):
|
||||
"""初始化Command组件
|
||||
@@ -53,11 +54,11 @@ class BaseCommand(ABC):
|
||||
self.matched_groups = groups
|
||||
|
||||
@abstractmethod
|
||||
async def execute(self) -> Tuple[bool, Optional[str]]:
|
||||
async def execute(self) -> Tuple[bool, Optional[str], bool]:
|
||||
"""执行Command的抽象方法,子类必须实现
|
||||
|
||||
Returns:
|
||||
Tuple[bool, Optional[str]]: (是否执行成功, 可选的回复消息)
|
||||
Tuple[bool, Optional[str], bool]: (是否执行成功, 可选的回复消息, 是否拦截消息 不进行 后续处理)
|
||||
"""
|
||||
pass
|
||||
|
||||
@@ -229,5 +230,4 @@ class BaseCommand(ABC):
|
||||
command_pattern=cls.command_pattern,
|
||||
command_help=cls.command_help,
|
||||
command_examples=cls.command_examples.copy() if cls.command_examples else [],
|
||||
intercept_message=cls.intercept_message,
|
||||
)
|
||||
|
||||
@@ -13,16 +13,23 @@ class BaseEventHandler(ABC):
|
||||
所有事件处理器都应该继承这个基类,提供事件处理的基本接口
|
||||
"""
|
||||
|
||||
event_type: EventType = EventType.UNKNOWN # 事件类型,默认为未知
|
||||
handler_name: str = "" # 处理器名称
|
||||
event_type: EventType = EventType.UNKNOWN
|
||||
"""事件类型,默认为未知"""
|
||||
handler_name: str = ""
|
||||
"""处理器名称"""
|
||||
handler_description: str = ""
|
||||
weight: int = 0 # 权重,数值越大优先级越高
|
||||
intercept_message: bool = False # 是否拦截消息,默认为否
|
||||
"""处理器描述"""
|
||||
weight: int = 0
|
||||
"""处理器权重,越大权重越高"""
|
||||
intercept_message: bool = False
|
||||
"""是否拦截消息,默认为否"""
|
||||
|
||||
def __init__(self):
|
||||
self.log_prefix = "[EventHandler]"
|
||||
self.plugin_name = "" # 对应插件名
|
||||
self.plugin_config: Optional[Dict] = None # 插件配置字典
|
||||
self.plugin_name = ""
|
||||
"""对应插件名"""
|
||||
self.plugin_config: Optional[Dict] = None
|
||||
"""插件配置字典"""
|
||||
if self.event_type == EventType.UNKNOWN:
|
||||
raise NotImplementedError("事件处理器必须指定 event_type")
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@ from typing import List, Type, Tuple, Union
|
||||
from .plugin_base import PluginBase
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.plugin_system.base.component_types import ComponentInfo, ActionInfo, CommandInfo, EventHandlerInfo
|
||||
from src.plugin_system.base.component_types import ActionInfo, CommandInfo, EventHandlerInfo
|
||||
from .base_action import BaseAction
|
||||
from .base_command import BaseCommand
|
||||
from .base_events_handler import BaseEventHandler
|
||||
|
||||
@@ -142,7 +142,6 @@ class CommandInfo(ComponentInfo):
|
||||
command_pattern: str = "" # 命令匹配模式(正则表达式)
|
||||
command_help: str = "" # 命令帮助信息
|
||||
command_examples: List[str] = field(default_factory=list) # 命令使用示例
|
||||
intercept_message: bool = True # 是否拦截消息处理(默认拦截)
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
|
||||
@@ -8,10 +8,12 @@ from src.plugin_system.core.plugin_manager import plugin_manager
|
||||
from src.plugin_system.core.component_registry import component_registry
|
||||
from src.plugin_system.core.dependency_manager import dependency_manager
|
||||
from src.plugin_system.core.events_manager import events_manager
|
||||
from src.plugin_system.core.global_announcement_manager import global_announcement_manager
|
||||
|
||||
__all__ = [
|
||||
"plugin_manager",
|
||||
"component_registry",
|
||||
"dependency_manager",
|
||||
"events_manager",
|
||||
"global_announcement_manager",
|
||||
]
|
||||
|
||||
@@ -25,27 +25,35 @@ class ComponentRegistry:
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
# 组件注册表
|
||||
self._components: Dict[str, ComponentInfo] = {} # 命名空间式组件名 -> 组件信息
|
||||
# 类型 -> 命名空间式名称 -> 组件信息
|
||||
# 命名空间式组件名构成法 f"{component_type}.{component_name}"
|
||||
self._components: Dict[str, ComponentInfo] = {}
|
||||
"""组件注册表 命名空间式组件名 -> 组件信息"""
|
||||
self._components_by_type: Dict[ComponentType, Dict[str, ComponentInfo]] = {types: {} for types in ComponentType}
|
||||
# 命名空间式组件名 -> 组件类
|
||||
"""类型 -> 组件原名称 -> 组件信息"""
|
||||
self._components_classes: Dict[str, Type[Union[BaseCommand, BaseAction, BaseEventHandler]]] = {}
|
||||
"""命名空间式组件名 -> 组件类"""
|
||||
|
||||
# 插件注册表
|
||||
self._plugins: Dict[str, PluginInfo] = {} # 插件名 -> 插件信息
|
||||
self._plugins: Dict[str, PluginInfo] = {}
|
||||
"""插件名 -> 插件信息"""
|
||||
|
||||
# Action特定注册表
|
||||
self._action_registry: Dict[str, Type[BaseAction]] = {} # action名 -> action类
|
||||
self._default_actions: Dict[str, ActionInfo] = {} # 默认动作集,即启用的Action集,用于重置ActionManager状态
|
||||
self._action_registry: Dict[str, Type[BaseAction]] = {}
|
||||
"""Action注册表 action名 -> action类"""
|
||||
self._default_actions: Dict[str, ActionInfo] = {}
|
||||
"""默认动作集,即启用的Action集,用于重置ActionManager状态"""
|
||||
|
||||
# Command特定注册表
|
||||
self._command_registry: Dict[str, Type[BaseCommand]] = {} # command名 -> command类
|
||||
self._command_patterns: Dict[Pattern, str] = {} # 编译后的正则 -> command名
|
||||
self._command_registry: Dict[str, Type[BaseCommand]] = {}
|
||||
"""Command类注册表 command名 -> command类"""
|
||||
self._command_patterns: Dict[Pattern, str] = {}
|
||||
"""编译后的正则 -> command名"""
|
||||
|
||||
# EventHandler特定注册表
|
||||
self._event_handler_registry: Dict[str, Type[BaseEventHandler]] = {} # event_handler名 -> event_handler类
|
||||
self._enabled_event_handlers: Dict[str, Type[BaseEventHandler]] = {} # 启用的事件处理器
|
||||
self._event_handler_registry: Dict[str, Type[BaseEventHandler]] = {}
|
||||
"""event_handler名 -> event_handler类"""
|
||||
self._enabled_event_handlers: Dict[str, Type[BaseEventHandler]] = {}
|
||||
"""启用的事件处理器 event_handler名 -> event_handler类"""
|
||||
|
||||
logger.info("组件注册中心初始化完成")
|
||||
|
||||
@@ -110,11 +118,17 @@ class ComponentRegistry:
|
||||
# 根据组件类型进行特定注册(使用原始名称)
|
||||
match component_type:
|
||||
case ComponentType.ACTION:
|
||||
ret = self._register_action_component(component_info, component_class) # type: ignore
|
||||
assert isinstance(component_info, ActionInfo)
|
||||
assert issubclass(component_class, BaseAction)
|
||||
ret = self._register_action_component(component_info, component_class)
|
||||
case ComponentType.COMMAND:
|
||||
ret = self._register_command_component(component_info, component_class) # type: ignore
|
||||
assert isinstance(component_info, CommandInfo)
|
||||
assert issubclass(component_class, BaseCommand)
|
||||
ret = self._register_command_component(component_info, component_class)
|
||||
case ComponentType.EVENT_HANDLER:
|
||||
ret = self._register_event_handler_component(component_info, component_class) # type: ignore
|
||||
assert isinstance(component_info, EventHandlerInfo)
|
||||
assert issubclass(component_class, BaseEventHandler)
|
||||
ret = self._register_event_handler_component(component_info, component_class)
|
||||
case _:
|
||||
logger.warning(f"未知组件类型: {component_type}")
|
||||
|
||||
@@ -160,7 +174,9 @@ class ComponentRegistry:
|
||||
if pattern not in self._command_patterns:
|
||||
self._command_patterns[pattern] = command_name
|
||||
else:
|
||||
logger.warning(f"'{command_name}' 对应的命令模式与 '{self._command_patterns[pattern]}' 重复,忽略此命令")
|
||||
logger.warning(
|
||||
f"'{command_name}' 对应的命令模式与 '{self._command_patterns[pattern]}' 重复,忽略此命令"
|
||||
)
|
||||
|
||||
return True
|
||||
|
||||
@@ -176,6 +192,10 @@ class ComponentRegistry:
|
||||
|
||||
self._event_handler_registry[handler_name] = handler_class
|
||||
|
||||
if not handler_info.enabled:
|
||||
logger.warning(f"EventHandler组件 {handler_name} 未启用")
|
||||
return True # 未启用,但是也是注册成功
|
||||
|
||||
from .events_manager import events_manager # 延迟导入防止循环导入问题
|
||||
|
||||
if events_manager.register_event_subscriber(handler_info, handler_class):
|
||||
@@ -185,6 +205,124 @@ class ComponentRegistry:
|
||||
logger.error(f"注册事件处理器 {handler_name} 失败")
|
||||
return False
|
||||
|
||||
# === 组件移除相关 ===
|
||||
|
||||
async def remove_component(self, component_name: str, component_type: ComponentType, plugin_name: str) -> bool:
|
||||
target_component_class = self.get_component_class(component_name, component_type)
|
||||
if not target_component_class:
|
||||
logger.warning(f"组件 {component_name} 未注册,无法移除")
|
||||
return False
|
||||
try:
|
||||
match component_type:
|
||||
case ComponentType.ACTION:
|
||||
self._action_registry.pop(component_name)
|
||||
self._default_actions.pop(component_name)
|
||||
case ComponentType.COMMAND:
|
||||
self._command_registry.pop(component_name)
|
||||
keys_to_remove = [k for k, v in self._command_patterns.items() if v == component_name]
|
||||
for key in keys_to_remove:
|
||||
self._command_patterns.pop(key)
|
||||
case ComponentType.EVENT_HANDLER:
|
||||
from .events_manager import events_manager # 延迟导入防止循环导入问题
|
||||
|
||||
self._event_handler_registry.pop(component_name)
|
||||
self._enabled_event_handlers.pop(component_name)
|
||||
await events_manager.unregister_event_subscriber(component_name)
|
||||
namespaced_name = f"{component_type}.{component_name}"
|
||||
self._components.pop(namespaced_name)
|
||||
self._components_by_type[component_type].pop(component_name)
|
||||
self._components_classes.pop(namespaced_name)
|
||||
logger.info(f"组件 {component_name} 已移除")
|
||||
return True
|
||||
except KeyError:
|
||||
logger.warning(f"移除组件时未找到组件: {component_name}")
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(f"移除组件 {component_name} 时发生错误: {e}")
|
||||
return False
|
||||
|
||||
def remove_plugin_registry(self, plugin_name: str) -> bool:
|
||||
"""移除插件注册信息
|
||||
|
||||
Args:
|
||||
plugin_name: 插件名称
|
||||
|
||||
Returns:
|
||||
bool: 是否成功移除
|
||||
"""
|
||||
if plugin_name not in self._plugins:
|
||||
logger.warning(f"插件 {plugin_name} 未注册,无法移除")
|
||||
return False
|
||||
del self._plugins[plugin_name]
|
||||
logger.info(f"插件 {plugin_name} 已移除")
|
||||
return True
|
||||
|
||||
# === 组件全局启用/禁用方法 ===
|
||||
|
||||
def enable_component(self, component_name: str, component_type: ComponentType) -> bool:
|
||||
"""全局的启用某个组件
|
||||
Parameters:
|
||||
component_name: 组件名称
|
||||
component_type: 组件类型
|
||||
Returns:
|
||||
bool: 启用成功返回True,失败返回False
|
||||
"""
|
||||
target_component_class = self.get_component_class(component_name, component_type)
|
||||
target_component_info = self.get_component_info(component_name, component_type)
|
||||
if not target_component_class or not target_component_info:
|
||||
logger.warning(f"组件 {component_name} 未注册,无法启用")
|
||||
return False
|
||||
target_component_info.enabled = True
|
||||
match component_type:
|
||||
case ComponentType.ACTION:
|
||||
assert isinstance(target_component_info, ActionInfo)
|
||||
self._default_actions[component_name] = target_component_info
|
||||
case ComponentType.COMMAND:
|
||||
assert isinstance(target_component_info, CommandInfo)
|
||||
pattern = target_component_info.command_pattern
|
||||
self._command_patterns[re.compile(pattern)] = component_name
|
||||
case ComponentType.EVENT_HANDLER:
|
||||
assert isinstance(target_component_info, EventHandlerInfo)
|
||||
assert issubclass(target_component_class, BaseEventHandler)
|
||||
self._enabled_event_handlers[component_name] = target_component_class
|
||||
from .events_manager import events_manager # 延迟导入防止循环导入问题
|
||||
|
||||
events_manager.register_event_subscriber(target_component_info, target_component_class)
|
||||
namespaced_name = f"{component_type}.{component_name}"
|
||||
self._components[namespaced_name].enabled = True
|
||||
self._components_by_type[component_type][component_name].enabled = True
|
||||
logger.info(f"组件 {component_name} 已启用")
|
||||
return True
|
||||
|
||||
async def disable_component(self, component_name: str, component_type: ComponentType) -> bool:
|
||||
"""全局的禁用某个组件
|
||||
Parameters:
|
||||
component_name: 组件名称
|
||||
component_type: 组件类型
|
||||
Returns:
|
||||
bool: 禁用成功返回True,失败返回False
|
||||
"""
|
||||
target_component_class = self.get_component_class(component_name, component_type)
|
||||
target_component_info = self.get_component_info(component_name, component_type)
|
||||
if not target_component_class or not target_component_info:
|
||||
logger.warning(f"组件 {component_name} 未注册,无法禁用")
|
||||
return False
|
||||
target_component_info.enabled = False
|
||||
match component_type:
|
||||
case ComponentType.ACTION:
|
||||
self._default_actions.pop(component_name, None)
|
||||
case ComponentType.COMMAND:
|
||||
self._command_patterns = {k: v for k, v in self._command_patterns.items() if v != component_name}
|
||||
case ComponentType.EVENT_HANDLER:
|
||||
self._enabled_event_handlers.pop(component_name, None)
|
||||
from .events_manager import events_manager # 延迟导入防止循环导入问题
|
||||
|
||||
await events_manager.unregister_event_subscriber(component_name)
|
||||
self._components[component_name].enabled = False
|
||||
self._components_by_type[component_type][component_name].enabled = False
|
||||
logger.info(f"组件 {component_name} 已禁用")
|
||||
return True
|
||||
|
||||
# === 组件查询方法 ===
|
||||
def get_component_info(
|
||||
self, component_name: str, component_type: Optional[ComponentType] = None
|
||||
@@ -287,7 +425,7 @@ class ComponentRegistry:
|
||||
# === Action特定查询方法 ===
|
||||
|
||||
def get_action_registry(self) -> Dict[str, Type[BaseAction]]:
|
||||
"""获取Action注册表(用于兼容现有系统)"""
|
||||
"""获取Action注册表"""
|
||||
return self._action_registry.copy()
|
||||
|
||||
def get_registered_action_info(self, action_name: str) -> Optional[ActionInfo]:
|
||||
@@ -314,7 +452,7 @@ class ComponentRegistry:
|
||||
"""获取Command模式注册表"""
|
||||
return self._command_patterns.copy()
|
||||
|
||||
def find_command_by_text(self, text: str) -> Optional[Tuple[Type[BaseCommand], dict, bool, str]]:
|
||||
def find_command_by_text(self, text: str) -> Optional[Tuple[Type[BaseCommand], dict, CommandInfo]]:
|
||||
# sourcery skip: use-named-expression, use-next
|
||||
"""根据文本查找匹配的命令
|
||||
|
||||
@@ -335,11 +473,10 @@ class ComponentRegistry:
|
||||
return (
|
||||
self._command_registry[command_name],
|
||||
candidates[0].match(text).groupdict(), # type: ignore
|
||||
command_info.intercept_message,
|
||||
command_info.plugin_name,
|
||||
command_info,
|
||||
)
|
||||
|
||||
# === 事件处理器特定查询方法 ===
|
||||
# === EventHandler 特定查询方法 ===
|
||||
|
||||
def get_event_handler_registry(self) -> Dict[str, Type[BaseEventHandler]]:
|
||||
"""获取事件处理器注册表"""
|
||||
@@ -364,9 +501,9 @@ class ComponentRegistry:
|
||||
"""获取所有插件"""
|
||||
return self._plugins.copy()
|
||||
|
||||
def get_enabled_plugins(self) -> Dict[str, PluginInfo]:
|
||||
"""获取所有启用的插件"""
|
||||
return {name: info for name, info in self._plugins.items() if info.enabled}
|
||||
# def get_enabled_plugins(self) -> Dict[str, PluginInfo]:
|
||||
# """获取所有启用的插件"""
|
||||
# return {name: info for name, info in self._plugins.items() if info.enabled}
|
||||
|
||||
def get_plugin_components(self, plugin_name: str) -> List[ComponentInfo]:
|
||||
"""获取插件的所有组件"""
|
||||
|
||||
@@ -6,6 +6,7 @@ from src.chat.message_receive.message import MessageRecv
|
||||
from src.common.logger import get_logger
|
||||
from src.plugin_system.base.component_types import EventType, EventHandlerInfo, MaiMessages
|
||||
from src.plugin_system.base.base_events_handler import BaseEventHandler
|
||||
from .global_announcement_manager import global_announcement_manager
|
||||
|
||||
logger = get_logger("events_manager")
|
||||
|
||||
@@ -28,18 +29,16 @@ class EventsManager:
|
||||
bool: 是否注册成功
|
||||
"""
|
||||
handler_name = handler_info.name
|
||||
plugin_name = getattr(handler_info, "plugin_name", "unknown")
|
||||
|
||||
namespace_name = f"{plugin_name}.{handler_name}"
|
||||
if namespace_name in self._handler_mapping:
|
||||
logger.warning(f"事件处理器 {namespace_name} 已存在,跳过注册")
|
||||
if handler_name in self._handler_mapping:
|
||||
logger.warning(f"事件处理器 {handler_name} 已存在,跳过注册")
|
||||
return False
|
||||
|
||||
if not issubclass(handler_class, BaseEventHandler):
|
||||
logger.error(f"类 {handler_class.__name__} 不是 BaseEventHandler 的子类")
|
||||
return False
|
||||
|
||||
self._handler_mapping[namespace_name] = handler_class
|
||||
self._handler_mapping[handler_name] = handler_class
|
||||
return self._insert_event_handler(handler_class, handler_info)
|
||||
|
||||
async def handle_mai_events(
|
||||
@@ -55,6 +54,10 @@ class EventsManager:
|
||||
continue_flag = True
|
||||
transformed_message = self._transform_event_message(message, llm_prompt, llm_response)
|
||||
for handler in self._events_subscribers.get(event_type, []):
|
||||
if message.chat_stream and message.chat_stream.stream_id:
|
||||
stream_id = message.chat_stream.stream_id
|
||||
if handler.handler_name in global_announcement_manager.get_disabled_chat_event_handlers(stream_id):
|
||||
continue
|
||||
handler.set_plugin_config(component_registry.get_plugin_config(handler.plugin_name) or {})
|
||||
if handler.intercept_message:
|
||||
try:
|
||||
@@ -71,7 +74,9 @@ class EventsManager:
|
||||
try:
|
||||
handler_task = asyncio.create_task(handler.execute(transformed_message))
|
||||
handler_task.add_done_callback(self._task_done_callback)
|
||||
handler_task.set_name(f"EventHandler-{handler.handler_name}-{event_type.name}")
|
||||
handler_task.set_name(f"{handler.plugin_name}-{handler.handler_name}")
|
||||
if handler.handler_name not in self._handler_tasks:
|
||||
self._handler_tasks[handler.handler_name] = []
|
||||
self._handler_tasks[handler.handler_name].append(handler_task)
|
||||
except Exception as e:
|
||||
logger.error(f"创建事件处理器任务 {handler.handler_name} 时发生异常: {e}")
|
||||
@@ -91,7 +96,7 @@ class EventsManager:
|
||||
|
||||
return True
|
||||
|
||||
def _remove_event_handler(self, handler_class: Type[BaseEventHandler]) -> bool:
|
||||
def _remove_event_handler_instance(self, handler_class: Type[BaseEventHandler]) -> bool:
|
||||
"""从事件类型列表中移除事件处理器"""
|
||||
display_handler_name = handler_class.handler_name or handler_class.__name__
|
||||
if handler_class.event_type == EventType.UNKNOWN:
|
||||
@@ -190,5 +195,20 @@ class EventsManager:
|
||||
finally:
|
||||
del self._handler_tasks[handler_name]
|
||||
|
||||
async def unregister_event_subscriber(self, handler_name: str) -> bool:
|
||||
"""取消注册事件处理器"""
|
||||
if handler_name not in self._handler_mapping:
|
||||
logger.warning(f"事件处理器 {handler_name} 不存在,无法取消注册")
|
||||
return False
|
||||
|
||||
await self.cancel_handler_tasks(handler_name)
|
||||
|
||||
handler_class = self._handler_mapping.pop(handler_name)
|
||||
if not self._remove_event_handler_instance(handler_class):
|
||||
return False
|
||||
|
||||
logger.info(f"事件处理器 {handler_name} 已成功取消注册")
|
||||
return True
|
||||
|
||||
|
||||
events_manager = EventsManager()
|
||||
|
||||
93
src/plugin_system/core/global_announcement_manager.py
Normal file
93
src/plugin_system/core/global_announcement_manager.py
Normal file
@@ -0,0 +1,93 @@
|
||||
from typing import List, Dict
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("global_announcement_manager")
|
||||
|
||||
|
||||
class GlobalAnnouncementManager:
|
||||
def __init__(self) -> None:
|
||||
# 用户禁用的动作,chat_id -> [action_name]
|
||||
self._user_disabled_actions: Dict[str, List[str]] = {}
|
||||
# 用户禁用的命令,chat_id -> [command_name]
|
||||
self._user_disabled_commands: Dict[str, List[str]] = {}
|
||||
# 用户禁用的事件处理器,chat_id -> [handler_name]
|
||||
self._user_disabled_event_handlers: Dict[str, List[str]] = {}
|
||||
|
||||
def disable_specific_chat_action(self, chat_id: str, action_name: str) -> bool:
|
||||
"""禁用特定聊天的某个动作"""
|
||||
if chat_id not in self._user_disabled_actions:
|
||||
self._user_disabled_actions[chat_id] = []
|
||||
if action_name in self._user_disabled_actions[chat_id]:
|
||||
logger.warning(f"动作 {action_name} 已经被禁用")
|
||||
return False
|
||||
self._user_disabled_actions[chat_id].append(action_name)
|
||||
return True
|
||||
|
||||
def enable_specific_chat_action(self, chat_id: str, action_name: str) -> bool:
|
||||
"""启用特定聊天的某个动作"""
|
||||
if chat_id in self._user_disabled_actions:
|
||||
try:
|
||||
self._user_disabled_actions[chat_id].remove(action_name)
|
||||
return True
|
||||
except ValueError:
|
||||
logger.warning(f"动作 {action_name} 不在禁用列表中")
|
||||
return False
|
||||
return False
|
||||
|
||||
def disable_specific_chat_command(self, chat_id: str, command_name: str) -> bool:
|
||||
"""禁用特定聊天的某个命令"""
|
||||
if chat_id not in self._user_disabled_commands:
|
||||
self._user_disabled_commands[chat_id] = []
|
||||
if command_name in self._user_disabled_commands[chat_id]:
|
||||
logger.warning(f"命令 {command_name} 已经被禁用")
|
||||
return False
|
||||
self._user_disabled_commands[chat_id].append(command_name)
|
||||
return True
|
||||
|
||||
def enable_specific_chat_command(self, chat_id: str, command_name: str) -> bool:
|
||||
"""启用特定聊天的某个命令"""
|
||||
if chat_id in self._user_disabled_commands:
|
||||
try:
|
||||
self._user_disabled_commands[chat_id].remove(command_name)
|
||||
return True
|
||||
except ValueError:
|
||||
logger.warning(f"命令 {command_name} 不在禁用列表中")
|
||||
return False
|
||||
return False
|
||||
|
||||
def disable_specific_chat_event_handler(self, chat_id: str, handler_name: str) -> bool:
|
||||
"""禁用特定聊天的某个事件处理器"""
|
||||
if chat_id not in self._user_disabled_event_handlers:
|
||||
self._user_disabled_event_handlers[chat_id] = []
|
||||
if handler_name in self._user_disabled_event_handlers[chat_id]:
|
||||
logger.warning(f"事件处理器 {handler_name} 已经被禁用")
|
||||
return False
|
||||
self._user_disabled_event_handlers[chat_id].append(handler_name)
|
||||
return True
|
||||
|
||||
def enable_specific_chat_event_handler(self, chat_id: str, handler_name: str) -> bool:
|
||||
"""启用特定聊天的某个事件处理器"""
|
||||
if chat_id in self._user_disabled_event_handlers:
|
||||
try:
|
||||
self._user_disabled_event_handlers[chat_id].remove(handler_name)
|
||||
return True
|
||||
except ValueError:
|
||||
logger.warning(f"事件处理器 {handler_name} 不在禁用列表中")
|
||||
return False
|
||||
return False
|
||||
|
||||
def get_disabled_chat_actions(self, chat_id: str) -> List[str]:
|
||||
"""获取特定聊天禁用的所有动作"""
|
||||
return self._user_disabled_actions.get(chat_id, []).copy()
|
||||
|
||||
def get_disabled_chat_commands(self, chat_id: str) -> List[str]:
|
||||
"""获取特定聊天禁用的所有命令"""
|
||||
return self._user_disabled_commands.get(chat_id, []).copy()
|
||||
|
||||
def get_disabled_chat_event_handlers(self, chat_id: str) -> List[str]:
|
||||
"""获取特定聊天禁用的所有事件处理器"""
|
||||
return self._user_disabled_event_handlers.get(chat_id, []).copy()
|
||||
|
||||
|
||||
global_announcement_manager = GlobalAnnouncementManager()
|
||||
@@ -1,5 +1,4 @@
|
||||
import os
|
||||
import inspect
|
||||
import traceback
|
||||
|
||||
from typing import Dict, List, Optional, Tuple, Type, Any
|
||||
@@ -8,11 +7,11 @@ from pathlib import Path
|
||||
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.plugin_system.core.component_registry import component_registry
|
||||
from src.plugin_system.core.dependency_manager import dependency_manager
|
||||
from src.plugin_system.base.plugin_base import PluginBase
|
||||
from src.plugin_system.base.component_types import ComponentType, PluginInfo, PythonDependency
|
||||
from src.plugin_system.base.component_types import ComponentType, PythonDependency
|
||||
from src.plugin_system.utils.manifest_utils import VersionComparator
|
||||
from .component_registry import component_registry
|
||||
from .dependency_manager import dependency_manager
|
||||
|
||||
logger = get_logger("plugin_manager")
|
||||
|
||||
@@ -36,19 +35,7 @@ class PluginManager:
|
||||
self._ensure_plugin_directories()
|
||||
logger.info("插件管理器初始化完成")
|
||||
|
||||
def _ensure_plugin_directories(self) -> None:
|
||||
"""确保所有插件根目录存在,如果不存在则创建"""
|
||||
default_directories = ["src/plugins/built_in", "plugins"]
|
||||
|
||||
for directory in default_directories:
|
||||
if not os.path.exists(directory):
|
||||
os.makedirs(directory, exist_ok=True)
|
||||
logger.info(f"创建插件根目录: {directory}")
|
||||
if directory not in self.plugin_directories:
|
||||
self.plugin_directories.append(directory)
|
||||
logger.debug(f"已添加插件根目录: {directory}")
|
||||
else:
|
||||
logger.warning(f"根目录不可重复加载: {directory}")
|
||||
# === 插件目录管理 ===
|
||||
|
||||
def add_plugin_directory(self, directory: str) -> bool:
|
||||
"""添加插件目录"""
|
||||
@@ -63,6 +50,8 @@ class PluginManager:
|
||||
logger.warning(f"插件目录不存在: {directory}")
|
||||
return False
|
||||
|
||||
# === 插件加载管理 ===
|
||||
|
||||
def load_all_plugins(self) -> Tuple[int, int]:
|
||||
"""加载所有插件
|
||||
|
||||
@@ -162,62 +151,50 @@ class PluginManager:
|
||||
logger.debug("详细错误信息: ", exc_info=True)
|
||||
return False, 1
|
||||
|
||||
def unload_registered_plugin_module(self, plugin_name: str) -> None:
|
||||
async def remove_registered_plugin(self, plugin_name: str) -> bool:
|
||||
"""
|
||||
卸载插件模块
|
||||
禁用插件模块
|
||||
"""
|
||||
pass
|
||||
if not plugin_name:
|
||||
raise ValueError("插件名称不能为空")
|
||||
if plugin_name not in self.loaded_plugins:
|
||||
logger.warning(f"插件 {plugin_name} 未加载")
|
||||
return False
|
||||
plugin_instance = self.loaded_plugins[plugin_name]
|
||||
plugin_info = plugin_instance.plugin_info
|
||||
success = True
|
||||
for component in plugin_info.components:
|
||||
success &= await component_registry.remove_component(component.name, component.component_type, plugin_name)
|
||||
success &= component_registry.remove_plugin_registry(plugin_name)
|
||||
del self.loaded_plugins[plugin_name]
|
||||
return success
|
||||
|
||||
def reload_registered_plugin_module(self, plugin_name: str) -> None:
|
||||
async def reload_registered_plugin(self, plugin_name: str) -> bool:
|
||||
"""
|
||||
重载插件模块
|
||||
"""
|
||||
self.unload_registered_plugin_module(plugin_name)
|
||||
self.load_registered_plugin_classes(plugin_name)
|
||||
if not await self.remove_registered_plugin(plugin_name):
|
||||
return False
|
||||
if not self.load_registered_plugin_classes(plugin_name)[0]:
|
||||
return False
|
||||
logger.debug(f"插件 {plugin_name} 重载成功")
|
||||
return True
|
||||
|
||||
def rescan_plugin_directory(self) -> None:
|
||||
def rescan_plugin_directory(self) -> Tuple[int, int]:
|
||||
"""
|
||||
重新扫描插件根目录
|
||||
"""
|
||||
# --------------------------------------- NEED REFACTORING ---------------------------------------
|
||||
total_success = 0
|
||||
total_fail = 0
|
||||
for directory in self.plugin_directories:
|
||||
if os.path.exists(directory):
|
||||
logger.debug(f"重新扫描插件根目录: {directory}")
|
||||
self._load_plugin_modules_from_directory(directory)
|
||||
success, fail = self._load_plugin_modules_from_directory(directory)
|
||||
total_success += success
|
||||
total_fail += fail
|
||||
else:
|
||||
logger.warning(f"插件根目录不存在: {directory}")
|
||||
|
||||
def get_loaded_plugins(self) -> List[PluginInfo]:
|
||||
"""获取所有已加载的插件信息"""
|
||||
return list(component_registry.get_all_plugins().values())
|
||||
|
||||
def get_enabled_plugins(self) -> List[PluginInfo]:
|
||||
"""获取所有启用的插件信息"""
|
||||
return list(component_registry.get_enabled_plugins().values())
|
||||
|
||||
# def enable_plugin(self, plugin_name: str) -> bool:
|
||||
# # -------------------------------- NEED REFACTORING --------------------------------
|
||||
# """启用插件"""
|
||||
# if plugin_info := component_registry.get_plugin_info(plugin_name):
|
||||
# plugin_info.enabled = True
|
||||
# # 启用插件的所有组件
|
||||
# for component in plugin_info.components:
|
||||
# component_registry.enable_component(component.name)
|
||||
# logger.debug(f"已启用插件: {plugin_name}")
|
||||
# return True
|
||||
# return False
|
||||
|
||||
# def disable_plugin(self, plugin_name: str) -> bool:
|
||||
# # -------------------------------- NEED REFACTORING --------------------------------
|
||||
# """禁用插件"""
|
||||
# if plugin_info := component_registry.get_plugin_info(plugin_name):
|
||||
# plugin_info.enabled = False
|
||||
# # 禁用插件的所有组件
|
||||
# for component in plugin_info.components:
|
||||
# component_registry.disable_component(component.name)
|
||||
# logger.debug(f"已禁用插件: {plugin_name}")
|
||||
# return True
|
||||
# return False
|
||||
return total_success, total_fail
|
||||
|
||||
def get_plugin_instance(self, plugin_name: str) -> Optional["PluginBase"]:
|
||||
"""获取插件实例
|
||||
@@ -230,25 +207,6 @@ class PluginManager:
|
||||
"""
|
||||
return self.loaded_plugins.get(plugin_name)
|
||||
|
||||
def get_plugin_stats(self) -> Dict[str, Any]:
|
||||
"""获取插件统计信息"""
|
||||
all_plugins = component_registry.get_all_plugins()
|
||||
enabled_plugins = component_registry.get_enabled_plugins()
|
||||
|
||||
action_components = component_registry.get_components_by_type(ComponentType.ACTION)
|
||||
command_components = component_registry.get_components_by_type(ComponentType.COMMAND)
|
||||
|
||||
return {
|
||||
"total_plugins": len(all_plugins),
|
||||
"enabled_plugins": len(enabled_plugins),
|
||||
"failed_plugins": len(self.failed_plugins),
|
||||
"total_components": len(action_components) + len(command_components),
|
||||
"action_components": len(action_components),
|
||||
"command_components": len(command_components),
|
||||
"loaded_plugin_files": len(self.loaded_plugins),
|
||||
"failed_plugin_details": self.failed_plugins.copy(),
|
||||
}
|
||||
|
||||
def check_all_dependencies(self, auto_install: bool = False) -> Dict[str, Any]:
|
||||
"""检查所有插件的Python依赖包
|
||||
|
||||
@@ -347,6 +305,43 @@ class PluginManager:
|
||||
|
||||
return dependency_manager.generate_requirements_file(all_dependencies, output_path)
|
||||
|
||||
# === 查询方法 ===
|
||||
def list_loaded_plugins(self) -> List[str]:
|
||||
"""
|
||||
列出所有当前加载的插件。
|
||||
|
||||
Returns:
|
||||
list: 当前加载的插件名称列表。
|
||||
"""
|
||||
return list(self.loaded_plugins.keys())
|
||||
|
||||
def list_registered_plugins(self) -> List[str]:
|
||||
"""
|
||||
列出所有已注册的插件类。
|
||||
|
||||
Returns:
|
||||
list: 已注册的插件类名称列表。
|
||||
"""
|
||||
return list(self.plugin_classes.keys())
|
||||
|
||||
# === 私有方法 ===
|
||||
# == 目录管理 ==
|
||||
def _ensure_plugin_directories(self) -> None:
|
||||
"""确保所有插件根目录存在,如果不存在则创建"""
|
||||
default_directories = ["src/plugins/built_in", "plugins"]
|
||||
|
||||
for directory in default_directories:
|
||||
if not os.path.exists(directory):
|
||||
os.makedirs(directory, exist_ok=True)
|
||||
logger.info(f"创建插件根目录: {directory}")
|
||||
if directory not in self.plugin_directories:
|
||||
self.plugin_directories.append(directory)
|
||||
logger.debug(f"已添加插件根目录: {directory}")
|
||||
else:
|
||||
logger.warning(f"根目录不可重复加载: {directory}")
|
||||
|
||||
# == 插件加载 ==
|
||||
|
||||
def _load_plugin_modules_from_directory(self, directory: str) -> tuple[int, int]:
|
||||
"""从指定目录加载插件模块"""
|
||||
loaded_count = 0
|
||||
@@ -372,18 +367,6 @@ class PluginManager:
|
||||
|
||||
return loaded_count, failed_count
|
||||
|
||||
def _find_plugin_directory(self, plugin_class: Type[PluginBase]) -> Optional[str]:
|
||||
"""查找插件类对应的目录路径"""
|
||||
try:
|
||||
# module = getmodule(plugin_class)
|
||||
# if module and hasattr(module, "__file__") and module.__file__:
|
||||
# return os.path.dirname(module.__file__)
|
||||
file_path = inspect.getfile(plugin_class)
|
||||
return os.path.dirname(file_path)
|
||||
except Exception as e:
|
||||
logger.debug(f"通过inspect获取插件目录失败: {e}")
|
||||
return None
|
||||
|
||||
def _load_plugin_module_file(self, plugin_file: str) -> bool:
|
||||
# sourcery skip: extract-method
|
||||
"""加载单个插件模块文件
|
||||
@@ -416,6 +399,8 @@ class PluginManager:
|
||||
self.failed_plugins[module_name] = error_msg
|
||||
return False
|
||||
|
||||
# == 兼容性检查 ==
|
||||
|
||||
def _check_plugin_version_compatibility(self, plugin_name: str, manifest_data: Dict[str, Any]) -> Tuple[bool, str]:
|
||||
"""检查插件版本兼容性
|
||||
|
||||
@@ -451,6 +436,8 @@ class PluginManager:
|
||||
logger.warning(f"插件 {plugin_name} 版本兼容性检查失败: {e}")
|
||||
return False, f"插件 {plugin_name} 版本兼容性检查失败: {e}" # 检查失败时默认不允许加载
|
||||
|
||||
# == 显示统计与插件信息 ==
|
||||
|
||||
def _show_stats(self, total_registered: int, total_failed_registration: int):
|
||||
# sourcery skip: low-code-quality
|
||||
# 获取组件统计信息
|
||||
@@ -493,9 +480,15 @@ class PluginManager:
|
||||
|
||||
# 组件列表
|
||||
if plugin_info.components:
|
||||
action_components = [c for c in plugin_info.components if c.component_type == ComponentType.ACTION]
|
||||
command_components = [c for c in plugin_info.components if c.component_type == ComponentType.COMMAND]
|
||||
event_handler_components = [c for c in plugin_info.components if c.component_type == ComponentType.EVENT_HANDLER]
|
||||
action_components = [
|
||||
c for c in plugin_info.components if c.component_type == ComponentType.ACTION
|
||||
]
|
||||
command_components = [
|
||||
c for c in plugin_info.components if c.component_type == ComponentType.COMMAND
|
||||
]
|
||||
event_handler_components = [
|
||||
c for c in plugin_info.components if c.component_type == ComponentType.EVENT_HANDLER
|
||||
]
|
||||
|
||||
if action_components:
|
||||
action_names = [c.name for c in action_components]
|
||||
@@ -504,7 +497,7 @@ class PluginManager:
|
||||
if command_components:
|
||||
command_names = [c.name for c in command_components]
|
||||
logger.info(f" ⚡ Command组件: {', '.join(command_names)}")
|
||||
|
||||
|
||||
if event_handler_components:
|
||||
event_handler_names = [c.name for c in event_handler_components]
|
||||
logger.info(f" 📢 EventHandler组件: {', '.join(event_handler_names)}")
|
||||
|
||||
@@ -10,6 +10,7 @@ from src.common.logger import get_logger
|
||||
# 导入API模块 - 标准Python包方式
|
||||
from src.plugin_system.apis import emoji_api, llm_api, message_api
|
||||
from src.plugins.built_in.core_actions.no_reply import NoReplyAction
|
||||
from src.config.config import global_config
|
||||
|
||||
|
||||
logger = get_logger("emoji")
|
||||
@@ -102,7 +103,11 @@ class EmojiAction(BaseAction):
|
||||
这里是可用的情感标签:{available_emotions}
|
||||
请直接返回最匹配的那个情感标签,不要进行任何解释或添加其他多余的文字。
|
||||
"""
|
||||
logger.info(f"{self.log_prefix} 生成的LLM Prompt: {prompt}")
|
||||
|
||||
if global_config.debug.show_prompt:
|
||||
logger.info(f"{self.log_prefix} 生成的LLM Prompt: {prompt}")
|
||||
else:
|
||||
logger.debug(f"{self.log_prefix} 生成的LLM Prompt: {prompt}")
|
||||
|
||||
# 5. 调用LLM
|
||||
models = llm_api.get_available_models()
|
||||
|
||||
@@ -13,7 +13,7 @@ from src.plugin_system.apis import message_api
|
||||
from src.config.config import global_config
|
||||
|
||||
|
||||
logger = get_logger("core_actions")
|
||||
logger = get_logger("no_reply_action")
|
||||
|
||||
|
||||
class NoReplyAction(BaseAction):
|
||||
|
||||
@@ -5,15 +5,10 @@
|
||||
这是系统的内置插件,提供基础的聊天交互功能
|
||||
"""
|
||||
|
||||
import random
|
||||
import time
|
||||
from typing import List, Tuple, Type
|
||||
import asyncio
|
||||
import re
|
||||
import traceback
|
||||
|
||||
# 导入新插件系统
|
||||
from src.plugin_system import BasePlugin, register_plugin, BaseAction, ComponentInfo, ActionActivationType, ChatMode
|
||||
from src.plugin_system import BasePlugin, register_plugin, ComponentInfo, ActionActivationType
|
||||
from src.plugin_system.base.config_types import ConfigField
|
||||
from src.config.config import global_config
|
||||
|
||||
@@ -21,139 +16,12 @@ from src.config.config import global_config
|
||||
from src.common.logger import get_logger
|
||||
|
||||
# 导入API模块 - 标准Python包方式
|
||||
from src.plugin_system.apis import generator_api, message_api
|
||||
from src.plugins.built_in.core_actions.no_reply import NoReplyAction
|
||||
from src.plugins.built_in.core_actions.emoji import EmojiAction
|
||||
from src.person_info.person_info import get_person_info_manager
|
||||
from src.chat.mai_thinking.mai_think import mai_thinking_manager
|
||||
from src.plugins.built_in.core_actions.reply import ReplyAction
|
||||
|
||||
logger = get_logger("core_actions")
|
||||
|
||||
# 常量定义
|
||||
WAITING_TIME_THRESHOLD = 1200 # 等待新消息时间阈值,单位秒
|
||||
|
||||
ENABLE_THINKING = False
|
||||
|
||||
class ReplyAction(BaseAction):
|
||||
"""回复动作 - 参与聊天回复"""
|
||||
|
||||
# 激活设置
|
||||
focus_activation_type = ActionActivationType.NEVER
|
||||
normal_activation_type = ActionActivationType.NEVER
|
||||
mode_enable = ChatMode.FOCUS
|
||||
parallel_action = False
|
||||
|
||||
# 动作基本信息
|
||||
action_name = "reply"
|
||||
action_description = "参与聊天回复,发送文本进行表达"
|
||||
|
||||
# 动作参数定义
|
||||
action_parameters = {}
|
||||
|
||||
# 动作使用场景
|
||||
action_require = ["你想要闲聊或者随便附和", "有人提到你", "如果你刚刚进行了回复,不要对同一个话题重复回应"]
|
||||
|
||||
# 关联类型
|
||||
associated_types = ["text"]
|
||||
|
||||
def _parse_reply_target(self, target_message: str) -> tuple:
|
||||
sender = ""
|
||||
target = ""
|
||||
if ":" in target_message or ":" in target_message:
|
||||
# 使用正则表达式匹配中文或英文冒号
|
||||
parts = re.split(pattern=r"[::]", string=target_message, maxsplit=1)
|
||||
if len(parts) == 2:
|
||||
sender = parts[0].strip()
|
||||
target = parts[1].strip()
|
||||
return sender, target
|
||||
|
||||
async def execute(self) -> Tuple[bool, str]:
|
||||
"""执行回复动作"""
|
||||
logger.info(f"{self.log_prefix} 决定进行回复")
|
||||
start_time = self.action_data.get("loop_start_time", time.time())
|
||||
|
||||
user_id = self.user_id
|
||||
platform = self.platform
|
||||
# logger.info(f"{self.log_prefix} 用户ID: {user_id}, 平台: {platform}")
|
||||
person_id = get_person_info_manager().get_person_id(platform, user_id)
|
||||
# logger.info(f"{self.log_prefix} 人物ID: {person_id}")
|
||||
person_name = get_person_info_manager().get_value_sync(person_id, "person_name")
|
||||
reply_to = f"{person_name}:{self.action_message.get('processed_plain_text', '')}"
|
||||
logger.info(f"{self.log_prefix} 回复目标: {reply_to}")
|
||||
|
||||
try:
|
||||
if prepared_reply := self.action_data.get("prepared_reply", ""):
|
||||
reply_text = prepared_reply
|
||||
else:
|
||||
try:
|
||||
success, reply_set, _ = await asyncio.wait_for(
|
||||
generator_api.generate_reply(
|
||||
extra_info="",
|
||||
reply_to=reply_to,
|
||||
chat_id=self.chat_id,
|
||||
request_type="chat.replyer.focus",
|
||||
enable_tool=global_config.tool.enable_in_focus_chat,
|
||||
),
|
||||
timeout=global_config.chat.thinking_timeout,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning(f"{self.log_prefix} 回复生成超时 ({global_config.chat.thinking_timeout}s)")
|
||||
return False, "timeout"
|
||||
|
||||
# 检查从start_time以来的新消息数量
|
||||
# 获取动作触发时间或使用默认值
|
||||
current_time = time.time()
|
||||
new_message_count = message_api.count_new_messages(
|
||||
chat_id=self.chat_id, start_time=start_time, end_time=current_time
|
||||
)
|
||||
|
||||
# 根据新消息数量决定是否使用reply_to
|
||||
need_reply = new_message_count >= random.randint(2, 4)
|
||||
logger.info(
|
||||
f"{self.log_prefix} 从思考到回复,共有{new_message_count}条新消息,{'使用' if need_reply else '不使用'}引用回复"
|
||||
)
|
||||
# 构建回复文本
|
||||
reply_text = ""
|
||||
first_replied = False
|
||||
reply_to_platform_id = f"{platform}:{user_id}"
|
||||
for reply_seg in reply_set:
|
||||
data = reply_seg[1]
|
||||
if not first_replied:
|
||||
if need_reply:
|
||||
await self.send_text(
|
||||
content=data, reply_to=reply_to, reply_to_platform_id=reply_to_platform_id, typing=False
|
||||
)
|
||||
else:
|
||||
await self.send_text(content=data, reply_to_platform_id=reply_to_platform_id, typing=False)
|
||||
first_replied = True
|
||||
else:
|
||||
await self.send_text(content=data, reply_to_platform_id=reply_to_platform_id, typing=True)
|
||||
reply_text += data
|
||||
|
||||
# 存储动作记录
|
||||
reply_text = f"你对{person_name}进行了回复:{reply_text}"
|
||||
|
||||
|
||||
if ENABLE_THINKING:
|
||||
await mai_thinking_manager.get_mai_think(self.chat_id).do_think_after_response(reply_text)
|
||||
|
||||
|
||||
await self.store_action_info(
|
||||
action_build_into_prompt=False,
|
||||
action_prompt_display=reply_text,
|
||||
action_done=True,
|
||||
)
|
||||
|
||||
# 重置NoReplyAction的连续计数器
|
||||
NoReplyAction.reset_consecutive_count()
|
||||
|
||||
return success, reply_text
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 回复动作执行失败: {e}")
|
||||
traceback.print_exc()
|
||||
return False, f"回复失败: {str(e)}"
|
||||
|
||||
|
||||
@register_plugin
|
||||
class CoreActionsPlugin(BasePlugin):
|
||||
@@ -168,11 +36,11 @@ class CoreActionsPlugin(BasePlugin):
|
||||
"""
|
||||
|
||||
# 插件基本信息
|
||||
plugin_name = "core_actions" # 内部标识符
|
||||
enable_plugin = True
|
||||
dependencies = [] # 插件依赖列表
|
||||
python_dependencies = [] # Python包依赖列表
|
||||
config_file_name = "config.toml"
|
||||
plugin_name: str = "core_actions" # 内部标识符
|
||||
enable_plugin: bool = True
|
||||
dependencies: list[str] = [] # 插件依赖列表
|
||||
python_dependencies: list[str] = [] # Python包依赖列表
|
||||
config_file_name: str = "config.toml"
|
||||
|
||||
# 配置节描述
|
||||
config_section_descriptions = {
|
||||
@@ -181,7 +49,7 @@ class CoreActionsPlugin(BasePlugin):
|
||||
}
|
||||
|
||||
# 配置Schema定义
|
||||
config_schema = {
|
||||
config_schema: dict = {
|
||||
"plugin": {
|
||||
"enabled": ConfigField(type=bool, default=True, description="是否启用插件"),
|
||||
"config_version": ConfigField(type=str, default="0.4.0", description="配置文件版本"),
|
||||
|
||||
149
src/plugins/built_in/core_actions/reply.py
Normal file
149
src/plugins/built_in/core_actions/reply.py
Normal file
@@ -0,0 +1,149 @@
|
||||
# 导入新插件系统
|
||||
from src.plugin_system import BaseAction, ActionActivationType, ChatMode
|
||||
from src.config.config import global_config
|
||||
import random
|
||||
import time
|
||||
from typing import Tuple
|
||||
import asyncio
|
||||
import re
|
||||
import traceback
|
||||
|
||||
# 导入依赖的系统组件
|
||||
from src.common.logger import get_logger
|
||||
|
||||
# 导入API模块 - 标准Python包方式
|
||||
from src.plugin_system.apis import generator_api, message_api
|
||||
from src.plugins.built_in.core_actions.no_reply import NoReplyAction
|
||||
from src.person_info.person_info import get_person_info_manager
|
||||
from src.mais4u.mai_think import mai_thinking_manager
|
||||
from src.mais4u.constant_s4u import ENABLE_S4U
|
||||
|
||||
logger = get_logger("reply_action")
|
||||
|
||||
|
||||
class ReplyAction(BaseAction):
|
||||
"""回复动作 - 参与聊天回复"""
|
||||
|
||||
# 激活设置
|
||||
focus_activation_type = ActionActivationType.NEVER
|
||||
normal_activation_type = ActionActivationType.NEVER
|
||||
mode_enable = ChatMode.FOCUS
|
||||
parallel_action = False
|
||||
|
||||
# 动作基本信息
|
||||
action_name = "reply"
|
||||
action_description = ""
|
||||
|
||||
# 动作参数定义
|
||||
action_parameters = {}
|
||||
|
||||
# 动作使用场景
|
||||
action_require = [""]
|
||||
|
||||
# 关联类型
|
||||
associated_types = ["text"]
|
||||
|
||||
def _parse_reply_target(self, target_message: str) -> tuple:
|
||||
sender = ""
|
||||
target = ""
|
||||
# 添加None检查,防止NoneType错误
|
||||
if target_message is None:
|
||||
return sender, target
|
||||
if ":" in target_message or ":" in target_message:
|
||||
# 使用正则表达式匹配中文或英文冒号
|
||||
parts = re.split(pattern=r"[::]", string=target_message, maxsplit=1)
|
||||
if len(parts) == 2:
|
||||
sender = parts[0].strip()
|
||||
target = parts[1].strip()
|
||||
return sender, target
|
||||
|
||||
async def execute(self) -> Tuple[bool, str]:
|
||||
"""执行回复动作"""
|
||||
logger.debug(f"{self.log_prefix} 决定进行回复")
|
||||
start_time = self.action_data.get("loop_start_time", time.time())
|
||||
|
||||
user_id = self.user_id
|
||||
platform = self.platform
|
||||
# logger.info(f"{self.log_prefix} 用户ID: {user_id}, 平台: {platform}")
|
||||
person_id = get_person_info_manager().get_person_id(platform, user_id) # type: ignore
|
||||
# logger.info(f"{self.log_prefix} 人物ID: {person_id}")
|
||||
person_name = get_person_info_manager().get_value_sync(person_id, "person_name")
|
||||
reply_to = f"{person_name}:{self.action_message.get('processed_plain_text', '')}" # type: ignore
|
||||
logger.info(f"{self.log_prefix} 决定进行回复,目标: {reply_to}")
|
||||
|
||||
try:
|
||||
if prepared_reply := self.action_data.get("prepared_reply", ""):
|
||||
reply_text = prepared_reply
|
||||
else:
|
||||
try:
|
||||
success, reply_set, _ = await asyncio.wait_for(
|
||||
generator_api.generate_reply(
|
||||
extra_info="",
|
||||
reply_to=reply_to,
|
||||
chat_id=self.chat_id,
|
||||
request_type="chat.replyer.focus",
|
||||
enable_tool=global_config.tool.enable_in_focus_chat,
|
||||
),
|
||||
timeout=global_config.chat.thinking_timeout,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning(f"{self.log_prefix} 回复生成超时 ({global_config.chat.thinking_timeout}s)")
|
||||
return False, "timeout"
|
||||
|
||||
# 检查从start_time以来的新消息数量
|
||||
# 获取动作触发时间或使用默认值
|
||||
current_time = time.time()
|
||||
new_message_count = message_api.count_new_messages(
|
||||
chat_id=self.chat_id, start_time=start_time, end_time=current_time
|
||||
)
|
||||
|
||||
# 根据新消息数量决定是否使用reply_to
|
||||
need_reply = new_message_count >= random.randint(2, 4)
|
||||
if need_reply:
|
||||
logger.info(
|
||||
f"{self.log_prefix} 从思考到回复,共有{new_message_count}条新消息,使用引用回复"
|
||||
)
|
||||
else:
|
||||
logger.debug(
|
||||
f"{self.log_prefix} 从思考到回复,共有{new_message_count}条新消息,不使用引用回复"
|
||||
)
|
||||
|
||||
# 构建回复文本
|
||||
reply_text = ""
|
||||
first_replied = False
|
||||
reply_to_platform_id = f"{platform}:{user_id}"
|
||||
for reply_seg in reply_set:
|
||||
data = reply_seg[1]
|
||||
if not first_replied:
|
||||
if need_reply:
|
||||
await self.send_text(
|
||||
content=data, reply_to=reply_to, reply_to_platform_id=reply_to_platform_id, typing=False
|
||||
)
|
||||
else:
|
||||
await self.send_text(content=data, reply_to_platform_id=reply_to_platform_id, typing=False)
|
||||
first_replied = True
|
||||
else:
|
||||
await self.send_text(content=data, reply_to_platform_id=reply_to_platform_id, typing=True)
|
||||
reply_text += data
|
||||
|
||||
# 存储动作记录
|
||||
reply_text = f"你对{person_name}进行了回复:{reply_text}"
|
||||
|
||||
if ENABLE_S4U:
|
||||
await mai_thinking_manager.get_mai_think(self.chat_id).do_think_after_response(reply_text)
|
||||
|
||||
await self.store_action_info(
|
||||
action_build_into_prompt=False,
|
||||
action_prompt_display=reply_text,
|
||||
action_done=True,
|
||||
)
|
||||
|
||||
# 重置NoReplyAction的连续计数器
|
||||
NoReplyAction.reset_consecutive_count()
|
||||
|
||||
return success, reply_text
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 回复动作执行失败: {e}")
|
||||
traceback.print_exc()
|
||||
return False, f"回复失败: {str(e)}"
|
||||
39
src/plugins/built_in/plugin_management/_manifest.json
Normal file
39
src/plugins/built_in/plugin_management/_manifest.json
Normal file
@@ -0,0 +1,39 @@
|
||||
{
|
||||
"manifest_version": 1,
|
||||
"name": "插件和组件管理 (Plugin and Component Management)",
|
||||
"version": "1.0.0",
|
||||
"description": "通过系统API管理插件和组件的生命周期,包括加载、卸载、启用和禁用等操作。",
|
||||
"author": {
|
||||
"name": "MaiBot团队",
|
||||
"url": "https://github.com/MaiM-with-u"
|
||||
},
|
||||
"license": "GPL-v3.0-or-later",
|
||||
"host_application": {
|
||||
"min_version": "0.9.0"
|
||||
},
|
||||
"homepage_url": "https://github.com/MaiM-with-u/maibot",
|
||||
"repository_url": "https://github.com/MaiM-with-u/maibot",
|
||||
"keywords": [
|
||||
"plugins",
|
||||
"components",
|
||||
"management",
|
||||
"built-in"
|
||||
],
|
||||
"categories": [
|
||||
"Core System",
|
||||
"Plugin Management"
|
||||
],
|
||||
"default_locale": "zh-CN",
|
||||
"locales_path": "_locales",
|
||||
"plugin_info": {
|
||||
"is_built_in": true,
|
||||
"plugin_type": "plugin_management",
|
||||
"components": [
|
||||
{
|
||||
"type": "command",
|
||||
"name": "plugin_management",
|
||||
"description": "管理插件和组件的生命周期,包括加载、卸载、启用和禁用等操作。"
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
440
src/plugins/built_in/plugin_management/plugin.py
Normal file
440
src/plugins/built_in/plugin_management/plugin.py
Normal file
@@ -0,0 +1,440 @@
|
||||
import asyncio
|
||||
|
||||
from typing import List, Tuple, Type
|
||||
from src.plugin_system import (
|
||||
BasePlugin,
|
||||
BaseCommand,
|
||||
CommandInfo,
|
||||
ConfigField,
|
||||
register_plugin,
|
||||
plugin_manage_api,
|
||||
component_manage_api,
|
||||
ComponentInfo,
|
||||
ComponentType,
|
||||
)
|
||||
|
||||
|
||||
class ManagementCommand(BaseCommand):
|
||||
command_name: str = "management"
|
||||
description: str = "管理命令"
|
||||
command_pattern: str = r"(?P<manage_command>^/pm(\s[a-zA-Z0-9_]+)*\s*$)"
|
||||
|
||||
async def execute(self) -> Tuple[bool, str, bool]:
|
||||
# sourcery skip: merge-duplicate-blocks
|
||||
if (
|
||||
not self.message
|
||||
or not self.message.message_info
|
||||
or not self.message.message_info.user_info
|
||||
or str(self.message.message_info.user_info.user_id) not in self.get_config("plugin.permission", []) # type: ignore
|
||||
):
|
||||
await self.send_text("你没有权限使用插件管理命令")
|
||||
return False, "没有权限", True
|
||||
command_list = self.matched_groups["manage_command"].strip().split(" ")
|
||||
if len(command_list) == 1:
|
||||
await self.show_help("all")
|
||||
return True, "帮助已发送", True
|
||||
if len(command_list) == 2:
|
||||
match command_list[1]:
|
||||
case "plugin":
|
||||
await self.show_help("plugin")
|
||||
case "component":
|
||||
await self.show_help("component")
|
||||
case "help":
|
||||
await self.show_help("all")
|
||||
case _:
|
||||
await self.send_text("插件管理命令不合法")
|
||||
return False, "命令不合法", True
|
||||
if len(command_list) == 3:
|
||||
if command_list[1] == "plugin":
|
||||
match command_list[2]:
|
||||
case "help":
|
||||
await self.show_help("plugin")
|
||||
case "list":
|
||||
await self._list_registered_plugins()
|
||||
case "list_enabled":
|
||||
await self._list_loaded_plugins()
|
||||
case "rescan":
|
||||
await self._rescan_plugin_dirs()
|
||||
case _:
|
||||
await self.send_text("插件管理命令不合法")
|
||||
return False, "命令不合法", True
|
||||
elif command_list[1] == "component":
|
||||
if command_list[2] == "list":
|
||||
await self._list_all_registered_components()
|
||||
elif command_list[2] == "help":
|
||||
await self.show_help("component")
|
||||
else:
|
||||
await self.send_text("插件管理命令不合法")
|
||||
return False, "命令不合法", True
|
||||
else:
|
||||
await self.send_text("插件管理命令不合法")
|
||||
return False, "命令不合法", True
|
||||
if len(command_list) == 4:
|
||||
if command_list[1] == "plugin":
|
||||
match command_list[2]:
|
||||
case "load":
|
||||
await self._load_plugin(command_list[3])
|
||||
case "unload":
|
||||
await self._unload_plugin(command_list[3])
|
||||
case "reload":
|
||||
await self._reload_plugin(command_list[3])
|
||||
case "add_dir":
|
||||
await self._add_dir(command_list[3])
|
||||
case _:
|
||||
await self.send_text("插件管理命令不合法")
|
||||
return False, "命令不合法", True
|
||||
elif command_list[1] == "component":
|
||||
if command_list[2] != "list":
|
||||
await self.send_text("插件管理命令不合法")
|
||||
return False, "命令不合法", True
|
||||
if command_list[3] == "enabled":
|
||||
await self._list_enabled_components()
|
||||
elif command_list[3] == "disabled":
|
||||
await self._list_disabled_components()
|
||||
else:
|
||||
await self.send_text("插件管理命令不合法")
|
||||
return False, "命令不合法", True
|
||||
else:
|
||||
await self.send_text("插件管理命令不合法")
|
||||
return False, "命令不合法", True
|
||||
if len(command_list) == 5:
|
||||
if command_list[1] != "component":
|
||||
await self.send_text("插件管理命令不合法")
|
||||
return False, "命令不合法", True
|
||||
if command_list[2] != "list":
|
||||
await self.send_text("插件管理命令不合法")
|
||||
return False, "命令不合法", True
|
||||
if command_list[3] == "enabled":
|
||||
await self._list_enabled_components(target_type=command_list[4])
|
||||
elif command_list[3] == "disabled":
|
||||
await self._list_disabled_components(target_type=command_list[4])
|
||||
elif command_list[3] == "type":
|
||||
await self._list_registered_components_by_type(command_list[4])
|
||||
else:
|
||||
await self.send_text("插件管理命令不合法")
|
||||
return False, "命令不合法", True
|
||||
if len(command_list) == 6:
|
||||
if command_list[1] != "component":
|
||||
await self.send_text("插件管理命令不合法")
|
||||
return False, "命令不合法", True
|
||||
if command_list[2] == "enable":
|
||||
if command_list[3] == "global":
|
||||
await self._globally_enable_component(command_list[4], command_list[5])
|
||||
elif command_list[3] == "local":
|
||||
await self._locally_enable_component(command_list[4], command_list[5])
|
||||
else:
|
||||
await self.send_text("插件管理命令不合法")
|
||||
return False, "命令不合法", True
|
||||
elif command_list[2] == "disable":
|
||||
if command_list[3] == "global":
|
||||
await self._globally_disable_component(command_list[4], command_list[5])
|
||||
elif command_list[3] == "local":
|
||||
await self._locally_disable_component(command_list[4], command_list[5])
|
||||
else:
|
||||
await self.send_text("插件管理命令不合法")
|
||||
return False, "命令不合法", True
|
||||
else:
|
||||
await self.send_text("插件管理命令不合法")
|
||||
return False, "命令不合法", True
|
||||
|
||||
return True, "命令执行完成", True
|
||||
|
||||
async def show_help(self, target: str):
|
||||
help_msg = ""
|
||||
match target:
|
||||
case "all":
|
||||
help_msg = (
|
||||
"管理命令帮助\n"
|
||||
"/pm help 管理命令提示\n"
|
||||
"/pm plugin 插件管理命令\n"
|
||||
"/pm component 组件管理命令\n"
|
||||
"使用 /pm plugin help 或 /pm component help 获取具体帮助"
|
||||
)
|
||||
case "plugin":
|
||||
help_msg = (
|
||||
"插件管理命令帮助\n"
|
||||
"/pm plugin help 插件管理命令提示\n"
|
||||
"/pm plugin list 列出所有注册的插件\n"
|
||||
"/pm plugin list_enabled 列出所有加载(启用)的插件\n"
|
||||
"/pm plugin rescan 重新扫描所有目录\n"
|
||||
"/pm plugin load <plugin_name> 加载指定插件\n"
|
||||
"/pm plugin unload <plugin_name> 卸载指定插件\n"
|
||||
"/pm plugin reload <plugin_name> 重新加载指定插件\n"
|
||||
"/pm plugin add_dir <directory_path> 添加插件目录\n"
|
||||
)
|
||||
case "component":
|
||||
help_msg = (
|
||||
"组件管理命令帮助\n"
|
||||
"/pm component help 组件管理命令提示\n"
|
||||
"/pm component list 列出所有注册的组件\n"
|
||||
"/pm component list enabled <可选: type> 列出所有启用的组件\n"
|
||||
"/pm component list disabled <可选: type> 列出所有禁用的组件\n"
|
||||
" - <type> 可选项: local,代表当前聊天中的;global,代表全局的\n"
|
||||
" - <type> 不填时为 global\n"
|
||||
"/pm component list type <component_type> 列出已经注册的指定类型的组件\n"
|
||||
"/pm component enable global <component_name> <component_type> 全局启用组件\n"
|
||||
"/pm component enable local <component_name> <component_type> 本聊天启用组件\n"
|
||||
"/pm component disable global <component_name> <component_type> 全局禁用组件\n"
|
||||
"/pm component disable local <component_name> <component_type> 本聊天禁用组件\n"
|
||||
" - <component_type> 可选项: action, command, event_handler\n"
|
||||
)
|
||||
case _:
|
||||
return
|
||||
await self.send_text(help_msg)
|
||||
|
||||
async def _list_loaded_plugins(self):
|
||||
plugins = plugin_manage_api.list_loaded_plugins()
|
||||
await self.send_text(f"已加载的插件: {', '.join(plugins)}")
|
||||
|
||||
async def _list_registered_plugins(self):
|
||||
plugins = plugin_manage_api.list_registered_plugins()
|
||||
await self.send_text(f"已注册的插件: {', '.join(plugins)}")
|
||||
|
||||
async def _rescan_plugin_dirs(self):
|
||||
plugin_manage_api.rescan_plugin_directory()
|
||||
await self.send_text("插件目录重新扫描执行中")
|
||||
|
||||
async def _load_plugin(self, plugin_name: str):
|
||||
success, count = plugin_manage_api.load_plugin(plugin_name)
|
||||
if success:
|
||||
await self.send_text(f"插件加载成功: {plugin_name}")
|
||||
else:
|
||||
if count == 0:
|
||||
await self.send_text(f"插件{plugin_name}为禁用状态")
|
||||
await self.send_text(f"插件加载失败: {plugin_name}")
|
||||
|
||||
async def _unload_plugin(self, plugin_name: str):
|
||||
success = await plugin_manage_api.remove_plugin(plugin_name)
|
||||
if success:
|
||||
await self.send_text(f"插件卸载成功: {plugin_name}")
|
||||
else:
|
||||
await self.send_text(f"插件卸载失败: {plugin_name}")
|
||||
|
||||
async def _reload_plugin(self, plugin_name: str):
|
||||
success = await plugin_manage_api.reload_plugin(plugin_name)
|
||||
if success:
|
||||
await self.send_text(f"插件重新加载成功: {plugin_name}")
|
||||
else:
|
||||
await self.send_text(f"插件重新加载失败: {plugin_name}")
|
||||
|
||||
async def _add_dir(self, dir_path: str):
|
||||
await self.send_text(f"正在添加插件目录: {dir_path}")
|
||||
success = plugin_manage_api.add_plugin_directory(dir_path)
|
||||
await asyncio.sleep(0.5) # 防止乱序发送
|
||||
if success:
|
||||
await self.send_text(f"插件目录添加成功: {dir_path}")
|
||||
else:
|
||||
await self.send_text(f"插件目录添加失败: {dir_path}")
|
||||
|
||||
def _fetch_all_registered_components(self) -> List[ComponentInfo]:
|
||||
all_plugin_info = component_manage_api.get_all_plugin_info()
|
||||
if not all_plugin_info:
|
||||
return []
|
||||
|
||||
components_info: List[ComponentInfo] = []
|
||||
for plugin_info in all_plugin_info.values():
|
||||
components_info.extend(plugin_info.components)
|
||||
return components_info
|
||||
|
||||
def _fetch_locally_disabled_components(self) -> List[str]:
|
||||
locally_disabled_components_actions = component_manage_api.get_locally_disabled_components(
|
||||
self.message.chat_stream.stream_id, ComponentType.ACTION
|
||||
)
|
||||
locally_disabled_components_commands = component_manage_api.get_locally_disabled_components(
|
||||
self.message.chat_stream.stream_id, ComponentType.COMMAND
|
||||
)
|
||||
locally_disabled_components_event_handlers = component_manage_api.get_locally_disabled_components(
|
||||
self.message.chat_stream.stream_id, ComponentType.EVENT_HANDLER
|
||||
)
|
||||
return (
|
||||
locally_disabled_components_actions
|
||||
+ locally_disabled_components_commands
|
||||
+ locally_disabled_components_event_handlers
|
||||
)
|
||||
|
||||
async def _list_all_registered_components(self):
|
||||
components_info = self._fetch_all_registered_components()
|
||||
if not components_info:
|
||||
await self.send_text("没有注册的组件")
|
||||
return
|
||||
|
||||
all_components_str = ", ".join(
|
||||
f"{component.name} ({component.component_type})" for component in components_info
|
||||
)
|
||||
await self.send_text(f"已注册的组件: {all_components_str}")
|
||||
|
||||
async def _list_enabled_components(self, target_type: str = "global"):
|
||||
components_info = self._fetch_all_registered_components()
|
||||
if not components_info:
|
||||
await self.send_text("没有注册的组件")
|
||||
return
|
||||
|
||||
if target_type == "global":
|
||||
enabled_components = [component for component in components_info if component.enabled]
|
||||
if not enabled_components:
|
||||
await self.send_text("没有满足条件的已启用全局组件")
|
||||
return
|
||||
enabled_components_str = ", ".join(
|
||||
f"{component.name} ({component.component_type})" for component in enabled_components
|
||||
)
|
||||
await self.send_text(f"满足条件的已启用全局组件: {enabled_components_str}")
|
||||
elif target_type == "local":
|
||||
locally_disabled_components = self._fetch_locally_disabled_components()
|
||||
enabled_components = [
|
||||
component
|
||||
for component in components_info
|
||||
if (component.name not in locally_disabled_components and component.enabled)
|
||||
]
|
||||
if not enabled_components:
|
||||
await self.send_text("本聊天没有满足条件的已启用组件")
|
||||
return
|
||||
enabled_components_str = ", ".join(
|
||||
f"{component.name} ({component.component_type})" for component in enabled_components
|
||||
)
|
||||
await self.send_text(f"本聊天满足条件的已启用组件: {enabled_components_str}")
|
||||
|
||||
async def _list_disabled_components(self, target_type: str = "global"):
|
||||
components_info = self._fetch_all_registered_components()
|
||||
if not components_info:
|
||||
await self.send_text("没有注册的组件")
|
||||
return
|
||||
|
||||
if target_type == "global":
|
||||
disabled_components = [component for component in components_info if not component.enabled]
|
||||
if not disabled_components:
|
||||
await self.send_text("没有满足条件的已禁用全局组件")
|
||||
return
|
||||
disabled_components_str = ", ".join(
|
||||
f"{component.name} ({component.component_type})" for component in disabled_components
|
||||
)
|
||||
await self.send_text(f"满足条件的已禁用全局组件: {disabled_components_str}")
|
||||
elif target_type == "local":
|
||||
locally_disabled_components = self._fetch_locally_disabled_components()
|
||||
disabled_components = [
|
||||
component
|
||||
for component in components_info
|
||||
if (component.name in locally_disabled_components or not component.enabled)
|
||||
]
|
||||
if not disabled_components:
|
||||
await self.send_text("本聊天没有满足条件的已禁用组件")
|
||||
return
|
||||
disabled_components_str = ", ".join(
|
||||
f"{component.name} ({component.component_type})" for component in disabled_components
|
||||
)
|
||||
await self.send_text(f"本聊天满足条件的已禁用组件: {disabled_components_str}")
|
||||
|
||||
async def _list_registered_components_by_type(self, target_type: str):
|
||||
match target_type:
|
||||
case "action":
|
||||
component_type = ComponentType.ACTION
|
||||
case "command":
|
||||
component_type = ComponentType.COMMAND
|
||||
case "event_handler":
|
||||
component_type = ComponentType.EVENT_HANDLER
|
||||
case _:
|
||||
await self.send_text(f"未知组件类型: {target_type}")
|
||||
return
|
||||
|
||||
components_info = component_manage_api.get_components_info_by_type(component_type)
|
||||
if not components_info:
|
||||
await self.send_text(f"没有注册的 {target_type} 组件")
|
||||
return
|
||||
|
||||
components_str = ", ".join(
|
||||
f"{name} ({component.component_type})" for name, component in components_info.items()
|
||||
)
|
||||
await self.send_text(f"注册的 {target_type} 组件: {components_str}")
|
||||
|
||||
async def _globally_enable_component(self, component_name: str, component_type: str):
|
||||
match component_type:
|
||||
case "action":
|
||||
target_component_type = ComponentType.ACTION
|
||||
case "command":
|
||||
target_component_type = ComponentType.COMMAND
|
||||
case "event_handler":
|
||||
target_component_type = ComponentType.EVENT_HANDLER
|
||||
case _:
|
||||
await self.send_text(f"未知组件类型: {component_type}")
|
||||
return
|
||||
if component_manage_api.globally_enable_component(component_name, target_component_type):
|
||||
await self.send_text(f"全局启用组件成功: {component_name}")
|
||||
else:
|
||||
await self.send_text(f"全局启用组件失败: {component_name}")
|
||||
|
||||
async def _globally_disable_component(self, component_name: str, component_type: str):
|
||||
match component_type:
|
||||
case "action":
|
||||
target_component_type = ComponentType.ACTION
|
||||
case "command":
|
||||
target_component_type = ComponentType.COMMAND
|
||||
case "event_handler":
|
||||
target_component_type = ComponentType.EVENT_HANDLER
|
||||
case _:
|
||||
await self.send_text(f"未知组件类型: {component_type}")
|
||||
return
|
||||
success = await component_manage_api.globally_disable_component(component_name, target_component_type)
|
||||
if success:
|
||||
await self.send_text(f"全局禁用组件成功: {component_name}")
|
||||
else:
|
||||
await self.send_text(f"全局禁用组件失败: {component_name}")
|
||||
|
||||
async def _locally_enable_component(self, component_name: str, component_type: str):
|
||||
match component_type:
|
||||
case "action":
|
||||
target_component_type = ComponentType.ACTION
|
||||
case "command":
|
||||
target_component_type = ComponentType.COMMAND
|
||||
case "event_handler":
|
||||
target_component_type = ComponentType.EVENT_HANDLER
|
||||
case _:
|
||||
await self.send_text(f"未知组件类型: {component_type}")
|
||||
return
|
||||
if component_manage_api.locally_enable_component(
|
||||
component_name,
|
||||
target_component_type,
|
||||
self.message.chat_stream.stream_id,
|
||||
):
|
||||
await self.send_text(f"本地启用组件成功: {component_name}")
|
||||
else:
|
||||
await self.send_text(f"本地启用组件失败: {component_name}")
|
||||
|
||||
async def _locally_disable_component(self, component_name: str, component_type: str):
|
||||
match component_type:
|
||||
case "action":
|
||||
target_component_type = ComponentType.ACTION
|
||||
case "command":
|
||||
target_component_type = ComponentType.COMMAND
|
||||
case "event_handler":
|
||||
target_component_type = ComponentType.EVENT_HANDLER
|
||||
case _:
|
||||
await self.send_text(f"未知组件类型: {component_type}")
|
||||
return
|
||||
if component_manage_api.locally_disable_component(
|
||||
component_name,
|
||||
target_component_type,
|
||||
self.message.chat_stream.stream_id,
|
||||
):
|
||||
await self.send_text(f"本地禁用组件成功: {component_name}")
|
||||
else:
|
||||
await self.send_text(f"本地禁用组件失败: {component_name}")
|
||||
|
||||
|
||||
@register_plugin
|
||||
class PluginManagementPlugin(BasePlugin):
|
||||
plugin_name: str = "plugin_management_plugin"
|
||||
enable_plugin: bool = True
|
||||
dependencies: list[str] = []
|
||||
python_dependencies: list[str] = []
|
||||
config_file_name: str = "config.toml"
|
||||
config_schema: dict = {
|
||||
"plugin": {
|
||||
"enable": ConfigField(bool, default=True, description="是否启用插件"),
|
||||
"permission": ConfigField(list, default=[], description="有权限使用插件管理命令的用户列表"),
|
||||
},
|
||||
}
|
||||
|
||||
def get_plugin_components(self) -> List[Tuple[CommandInfo, Type[BaseCommand]]]:
|
||||
components = []
|
||||
if self.get_config("plugin.enable", True):
|
||||
components.append((ManagementCommand.get_command_info(), ManagementCommand))
|
||||
return components
|
||||
@@ -92,7 +92,7 @@ class TTSAction(BaseAction):
|
||||
|
||||
# 确保句子结尾有合适的标点
|
||||
if not any(processed_text.endswith(end) for end in [".", "?", "!", "。", "!", "?"]):
|
||||
processed_text = processed_text + "。"
|
||||
processed_text = f"{processed_text}。"
|
||||
|
||||
return processed_text
|
||||
|
||||
@@ -107,11 +107,11 @@ class TTSPlugin(BasePlugin):
|
||||
"""
|
||||
|
||||
# 插件基本信息
|
||||
plugin_name = "tts_plugin" # 内部标识符
|
||||
enable_plugin = True
|
||||
dependencies = [] # 插件依赖列表
|
||||
python_dependencies = [] # Python包依赖列表
|
||||
config_file_name = "config.toml"
|
||||
plugin_name: str = "tts_plugin" # 内部标识符
|
||||
enable_plugin: bool = True
|
||||
dependencies: list[str] = [] # 插件依赖列表
|
||||
python_dependencies: list[str] = [] # Python包依赖列表
|
||||
config_file_name: str = "config.toml"
|
||||
|
||||
# 配置节描述
|
||||
config_section_descriptions = {
|
||||
@@ -121,7 +121,7 @@ class TTSPlugin(BasePlugin):
|
||||
}
|
||||
|
||||
# 配置Schema定义
|
||||
config_schema = {
|
||||
config_schema: dict = {
|
||||
"plugin": {
|
||||
"name": ConfigField(type=str, default="tts_plugin", description="插件名称", required=True),
|
||||
"version": ConfigField(type=str, default="0.1.0", description="插件版本号"),
|
||||
|
||||
Reference in New Issue
Block a user