ruff,私聊视为提及了bot

This commit is contained in:
Windpicker-owo
2025-09-20 22:34:22 +08:00
parent 006f9130b9
commit 444f1ca315
76 changed files with 1066 additions and 882 deletions

11
bot.py
View File

@@ -34,6 +34,7 @@ script_dir = os.path.dirname(os.path.abspath(__file__))
os.chdir(script_dir) os.chdir(script_dir)
logger.info(f"已设置工作目录为: {script_dir}") logger.info(f"已设置工作目录为: {script_dir}")
# 检查并创建.env文件 # 检查并创建.env文件
def ensure_env_file(): def ensure_env_file():
"""确保.env文件存在如果不存在则从模板创建""" """确保.env文件存在如果不存在则从模板创建"""
@@ -44,6 +45,7 @@ def ensure_env_file():
if template_env.exists(): if template_env.exists():
logger.info("未找到.env文件正在从模板创建...") logger.info("未找到.env文件正在从模板创建...")
import shutil import shutil
shutil.copy(template_env, env_file) shutil.copy(template_env, env_file)
logger.info("已从template/template.env创建.env文件") logger.info("已从template/template.env创建.env文件")
logger.warning("请编辑.env文件将EULA_CONFIRMED设置为true并配置其他必要参数") logger.warning("请编辑.env文件将EULA_CONFIRMED设置为true并配置其他必要参数")
@@ -51,6 +53,7 @@ def ensure_env_file():
logger.error("未找到.env文件和template.env模板文件") logger.error("未找到.env文件和template.env模板文件")
sys.exit(1) sys.exit(1)
# 确保环境文件存在 # 确保环境文件存在
ensure_env_file() ensure_env_file()
@@ -130,9 +133,9 @@ async def graceful_shutdown():
def check_eula(): def check_eula():
"""检查EULA和隐私条款确认状态 - 环境变量版类似Minecraft""" """检查EULA和隐私条款确认状态 - 环境变量版类似Minecraft"""
# 检查环境变量中的EULA确认 # 检查环境变量中的EULA确认
eula_confirmed = os.getenv('EULA_CONFIRMED', '').lower() eula_confirmed = os.getenv("EULA_CONFIRMED", "").lower()
if eula_confirmed == 'true': if eula_confirmed == "true":
logger.info("EULA已通过环境变量确认") logger.info("EULA已通过环境变量确认")
return return
@@ -148,8 +151,8 @@ def check_eula():
try: try:
load_dotenv(override=True) # 重新加载.env文件 load_dotenv(override=True) # 重新加载.env文件
eula_confirmed = os.getenv('EULA_CONFIRMED', '').lower() eula_confirmed = os.getenv("EULA_CONFIRMED", "").lower()
if eula_confirmed == 'true': if eula_confirmed == "true":
confirm_logger.info("EULA确认成功感谢您的同意") confirm_logger.info("EULA确认成功感谢您的同意")
return return

View File

@@ -20,16 +20,17 @@ files_to_update = [
"src/mais4u/mais4u_chat/s4u_mood_manager.py", "src/mais4u/mais4u_chat/s4u_mood_manager.py",
"src/plugin_system/core/tool_use.py", "src/plugin_system/core/tool_use.py",
"src/chat/memory_system/memory_activator.py", "src/chat/memory_system/memory_activator.py",
"src/chat/utils/smart_prompt.py" "src/chat/utils/smart_prompt.py",
] ]
def update_prompt_imports(file_path): def update_prompt_imports(file_path):
"""更新文件中的Prompt导入""" """更新文件中的Prompt导入"""
if not os.path.exists(file_path): if not os.path.exists(file_path):
print(f"文件不存在: {file_path}") print(f"文件不存在: {file_path}")
return False return False
with open(file_path, 'r', encoding='utf-8') as f: with open(file_path, "r", encoding="utf-8") as f:
content = f.read() content = f.read()
# 替换导入语句 # 替换导入语句
@@ -38,7 +39,7 @@ def update_prompt_imports(file_path):
if old_import in content: if old_import in content:
new_content = content.replace(old_import, new_import) new_content = content.replace(old_import, new_import)
with open(file_path, 'w', encoding='utf-8') as f: with open(file_path, "w", encoding="utf-8") as f:
f.write(new_content) f.write(new_content)
print(f"已更新: {file_path}") print(f"已更新: {file_path}")
return True return True
@@ -46,6 +47,7 @@ def update_prompt_imports(file_path):
print(f"无需更新: {file_path}") print(f"无需更新: {file_path}")
return False return False
def main(): def main():
"""主函数""" """主函数"""
updated_count = 0 updated_count = 0
@@ -55,5 +57,6 @@ def main():
print(f"\n更新完成!共更新了 {updated_count} 个文件") print(f"\n更新完成!共更新了 {updated_count} 个文件")
if __name__ == "__main__": if __name__ == "__main__":
main() main()

View File

@@ -5,4 +5,4 @@
from src.chat.affinity_flow.afc_manager import afc_manager from src.chat.affinity_flow.afc_manager import afc_manager
__all__ = ['afc_manager', 'AFCManager', 'AffinityFlowChatter'] __all__ = ["afc_manager", "AFCManager", "AffinityFlowChatter"]

View File

@@ -2,6 +2,7 @@
亲和力聊天处理流管理器 亲和力聊天处理流管理器
管理不同聊天流的亲和力聊天处理流,统一获取新消息并分发到对应的亲和力聊天处理流 管理不同聊天流的亲和力聊天处理流,统一获取新消息并分发到对应的亲和力聊天处理流
""" """
import time import time
import traceback import traceback
from typing import Dict, Optional, List from typing import Dict, Optional, List
@@ -20,7 +21,7 @@ class AFCManager:
def __init__(self): def __init__(self):
self.affinity_flow_chatters: Dict[str, "AffinityFlowChatter"] = {} self.affinity_flow_chatters: Dict[str, "AffinityFlowChatter"] = {}
'''所有聊天流的亲和力聊天处理流stream_id -> affinity_flow_chatter''' """所有聊天流的亲和力聊天处理流stream_id -> affinity_flow_chatter"""
# 动作管理器 # 动作管理器
self.action_manager = ActionManager() self.action_manager = ActionManager()
@@ -40,11 +41,7 @@ class AFCManager:
# 创建增强版规划器 # 创建增强版规划器
planner = ActionPlanner(stream_id, self.action_manager) planner = ActionPlanner(stream_id, self.action_manager)
chatter = AffinityFlowChatter( chatter = AffinityFlowChatter(stream_id=stream_id, planner=planner, action_manager=self.action_manager)
stream_id=stream_id,
planner=planner,
action_manager=self.action_manager
)
self.affinity_flow_chatters[stream_id] = chatter self.affinity_flow_chatters[stream_id] = chatter
logger.info(f"创建新的亲和力聊天处理器: {stream_id}") logger.info(f"创建新的亲和力聊天处理器: {stream_id}")
@@ -74,7 +71,6 @@ class AFCManager:
"executed_count": 0, "executed_count": 0,
} }
def get_chatter_stats(self, stream_id: str) -> Optional[Dict[str, any]]: def get_chatter_stats(self, stream_id: str) -> Optional[Dict[str, any]]:
"""获取聊天处理器统计""" """获取聊天处理器统计"""
if stream_id in self.affinity_flow_chatters: if stream_id in self.affinity_flow_chatters:
@@ -131,4 +127,5 @@ class AFCManager:
self.affinity_flow_chatters[stream_id].update_interest_keywords(new_keywords) self.affinity_flow_chatters[stream_id].update_interest_keywords(new_keywords)
logger.info(f"已更新聊天流 {stream_id} 的兴趣关键词: {list(new_keywords.keys())}") logger.info(f"已更新聊天流 {stream_id} 的兴趣关键词: {list(new_keywords.keys())}")
afc_manager = AFCManager() afc_manager = AFCManager()

View File

@@ -2,6 +2,7 @@
亲和力聊天处理器 亲和力聊天处理器
单个聊天流的处理器,负责处理特定聊天流的完整交互流程 单个聊天流的处理器,负责处理特定聊天流的完整交互流程
""" """
import time import time
import traceback import traceback
from datetime import datetime from datetime import datetime
@@ -57,10 +58,7 @@ class AffinityFlowChatter:
unread_messages = context.get_unread_messages() unread_messages = context.get_unread_messages()
# 使用增强版规划器处理消息 # 使用增强版规划器处理消息
actions, target_message = await self.planner.plan( actions, target_message = await self.planner.plan(mode=ChatMode.FOCUS, context=context)
mode=ChatMode.FOCUS,
context=context
)
self.stats["plans_created"] += 1 self.stats["plans_created"] += 1
# 执行动作(如果规划器返回了动作) # 执行动作(如果规划器返回了动作)
@@ -84,7 +82,9 @@ class AffinityFlowChatter:
**execution_result, **execution_result,
} }
logger.info(f"聊天流 {self.stream_id} StreamContext处理成功: 动作数={result['actions_count']}, 未读消息={result['unread_messages_processed']}") logger.info(
f"聊天流 {self.stream_id} StreamContext处理成功: 动作数={result['actions_count']}, 未读消息={result['unread_messages_processed']}"
)
return result return result
@@ -197,7 +197,9 @@ class AffinityFlowChatter:
def __repr__(self) -> str: def __repr__(self) -> str:
"""详细字符串表示""" """详细字符串表示"""
return (f"AffinityFlowChatter(stream_id={self.stream_id}, " return (
f"messages_processed={self.stats['messages_processed']}, " f"AffinityFlowChatter(stream_id={self.stream_id}, "
f"plans_created={self.stats['plans_created']}, " f"messages_processed={self.stats['messages_processed']}, "
f"last_activity={datetime.fromtimestamp(self.last_activity_time)})") f"plans_created={self.stats['plans_created']}, "
f"last_activity={datetime.fromtimestamp(self.last_activity_time)})"
)

View File

@@ -38,7 +38,9 @@ class InterestScoringSystem:
# 连续不回复概率提升 # 连续不回复概率提升
self.no_reply_count = 0 self.no_reply_count = 0
self.max_no_reply_count = affinity_config.max_no_reply_count self.max_no_reply_count = affinity_config.max_no_reply_count
self.probability_boost_per_no_reply = affinity_config.no_reply_threshold_adjustment / affinity_config.max_no_reply_count # 每次不回复增加的概率 self.probability_boost_per_no_reply = (
affinity_config.no_reply_threshold_adjustment / affinity_config.max_no_reply_count
) # 每次不回复增加的概率
# 用户关系数据 # 用户关系数据
self.user_relationships: Dict[str, float] = {} # user_id -> relationship_score self.user_relationships: Dict[str, float] = {} # user_id -> relationship_score
@@ -153,7 +155,9 @@ class InterestScoringSystem:
# 返回匹配分数,考虑置信度和匹配标签数量 # 返回匹配分数,考虑置信度和匹配标签数量
affinity_config = global_config.affinity_flow affinity_config = global_config.affinity_flow
match_count_bonus = min(len(match_result.matched_tags) * affinity_config.match_count_bonus, affinity_config.max_match_bonus) match_count_bonus = min(
len(match_result.matched_tags) * affinity_config.match_count_bonus, affinity_config.max_match_bonus
)
final_score = match_result.overall_score * 1.15 * match_result.confidence + match_count_bonus final_score = match_result.overall_score * 1.15 * match_result.confidence + match_count_bonus
logger.debug( logger.debug(
f"⚖️ 最终分数计算: 总分({match_result.overall_score:.3f}) × 1.3 × 置信度({match_result.confidence:.3f}) + 标签数量奖励({match_count_bonus:.3f}) = {final_score:.3f}" f"⚖️ 最终分数计算: 总分({match_result.overall_score:.3f}) × 1.3 × 置信度({match_result.confidence:.3f}) + 标签数量奖励({match_count_bonus:.3f}) = {final_score:.3f}"
@@ -263,7 +267,17 @@ class InterestScoringSystem:
if not msg.processed_plain_text: if not msg.processed_plain_text:
return 0.0 return 0.0
if msg.is_mentioned or (bot_nickname and bot_nickname in msg.processed_plain_text): # 检查是否被提及
is_mentioned = msg.is_mentioned or (bot_nickname and bot_nickname in msg.processed_plain_text)
# 检查是否为私聊group_info为None表示私聊
is_private_chat = msg.group_info is None
# 如果被提及或是私聊都视为提及了bot
if is_mentioned or is_private_chat:
logger.debug(f"🔍 提及检测 - 被提及: {is_mentioned}, 私聊: {is_private_chat}")
if is_private_chat and not is_mentioned:
logger.debug("💬 私聊消息自动视为提及bot")
return global_config.affinity_flow.mention_bot_interest_score return global_config.affinity_flow.mention_bot_interest_score
return 0.0 return 0.0
@@ -282,7 +296,9 @@ class InterestScoringSystem:
logger.debug(f"📋 基础阈值: {base_threshold:.3f}") logger.debug(f"📋 基础阈值: {base_threshold:.3f}")
# 如果被提及,降低阈值 # 如果被提及,降低阈值
if score.mentioned_score >= global_config.affinity_flow.mention_bot_interest_score * 0.5: # 使用提及bot兴趣分的一半作为判断阈值 if (
score.mentioned_score >= global_config.affinity_flow.mention_bot_interest_score * 0.5
): # 使用提及bot兴趣分的一半作为判断阈值
base_threshold = self.mention_threshold base_threshold = self.mention_threshold
logger.debug(f"📣 消息提及了机器人,使用降低阈值: {base_threshold:.3f}") logger.debug(f"📣 消息提及了机器人,使用降低阈值: {base_threshold:.3f}")
@@ -325,7 +341,9 @@ class InterestScoringSystem:
def update_user_relationship(self, user_id: str, relationship_change: float): def update_user_relationship(self, user_id: str, relationship_change: float):
"""更新用户关系""" """更新用户关系"""
old_score = self.user_relationships.get(user_id, global_config.affinity_flow.base_relationship_score) # 默认新用户分数 old_score = self.user_relationships.get(
user_id, global_config.affinity_flow.base_relationship_score
) # 默认新用户分数
new_score = max(0.0, min(1.0, old_score + relationship_change)) new_score = max(0.0, min(1.0, old_score + relationship_change))
self.user_relationships[user_id] = new_score self.user_relationships[user_id] = new_score

View File

@@ -116,6 +116,7 @@ class UserRelationshipTracker:
try: try:
# 获取bot人设信息 # 获取bot人设信息
from src.individuality.individuality import Individuality from src.individuality.individuality import Individuality
individuality = Individuality() individuality = Individuality()
bot_personality = await individuality.get_personality_block() bot_personality = await individuality.get_personality_block()
@@ -168,7 +169,17 @@ class UserRelationshipTracker:
# 清理LLM响应移除可能的格式标记 # 清理LLM响应移除可能的格式标记
cleaned_response = self._clean_llm_json_response(llm_response) cleaned_response = self._clean_llm_json_response(llm_response)
response_data = json.loads(cleaned_response) response_data = json.loads(cleaned_response)
new_score = max(0.0, min(1.0, float(response_data.get("new_relationship_score", global_config.affinity_flow.base_relationship_score)))) new_score = max(
0.0,
min(
1.0,
float(
response_data.get(
"new_relationship_score", global_config.affinity_flow.base_relationship_score
)
),
),
)
if self.interest_scoring_system: if self.interest_scoring_system:
self.interest_scoring_system.update_user_relationship( self.interest_scoring_system.update_user_relationship(
@@ -295,7 +306,9 @@ class UserRelationshipTracker:
# 更新缓存 # 更新缓存
self.user_relationship_cache[user_id] = { self.user_relationship_cache[user_id] = {
"relationship_text": relationship_data.get("relationship_text", ""), "relationship_text": relationship_data.get("relationship_text", ""),
"relationship_score": relationship_data.get("relationship_score", global_config.affinity_flow.base_relationship_score), "relationship_score": relationship_data.get(
"relationship_score", global_config.affinity_flow.base_relationship_score
),
"last_tracked": time.time(), "last_tracked": time.time(),
} }
return relationship_data.get("relationship_score", global_config.affinity_flow.base_relationship_score) return relationship_data.get("relationship_score", global_config.affinity_flow.base_relationship_score)
@@ -386,7 +399,11 @@ class UserRelationshipTracker:
# 获取当前关系数据 # 获取当前关系数据
current_relationship = self._get_user_relationship_from_db(user_id) current_relationship = self._get_user_relationship_from_db(user_id)
current_score = current_relationship.get("relationship_score", global_config.affinity_flow.base_relationship_score) if current_relationship else global_config.affinity_flow.base_relationship_score current_score = (
current_relationship.get("relationship_score", global_config.affinity_flow.base_relationship_score)
if current_relationship
else global_config.affinity_flow.base_relationship_score
)
current_text = current_relationship.get("relationship_text", "新用户") if current_relationship else "新用户" current_text = current_relationship.get("relationship_text", "新用户") if current_relationship else "新用户"
# 使用LLM分析并更新关系 # 使用LLM分析并更新关系
@@ -501,6 +518,7 @@ class UserRelationshipTracker:
# 获取bot人设信息 # 获取bot人设信息
from src.individuality.individuality import Individuality from src.individuality.individuality import Individuality
individuality = Individuality() individuality = Individuality()
bot_personality = await individuality.get_personality_block() bot_personality = await individuality.get_personality_block()

View File

@@ -2,6 +2,7 @@
""" """
表情包发送历史记录模块 表情包发送历史记录模块
""" """
import os import os
from typing import List, Dict from typing import List, Dict
from collections import deque from collections import deque

View File

@@ -477,7 +477,7 @@ class EmojiManager:
emoji_options_str = "" emoji_options_str = ""
for i, emoji in enumerate(candidate_emojis): for i, emoji in enumerate(candidate_emojis):
# 为每个表情包创建一个编号和它的详细描述 # 为每个表情包创建一个编号和它的详细描述
emoji_options_str += f"编号: {i+1}\n描述: {emoji.description}\n\n" emoji_options_str += f"编号: {i + 1}\n描述: {emoji.description}\n\n"
# 精心设计的prompt引导LLM做出选择 # 精心设计的prompt引导LLM做出选择
prompt = f""" prompt = f"""
@@ -524,9 +524,7 @@ class EmojiManager:
self.record_usage(selected_emoji.hash) self.record_usage(selected_emoji.hash)
_time_end = time.time() _time_end = time.time()
logger.info( logger.info(f"找到匹配描述的表情包: {selected_emoji.description}, 耗时: {(_time_end - _time_start):.2f}s")
f"找到匹配描述的表情包: {selected_emoji.description}, 耗时: {(_time_end - _time_start):.2f}s"
)
# 8. 返回选中的表情包信息 # 8. 返回选中的表情包信息
return selected_emoji.full_path, f"[表情包:{selected_emoji.description}]", text_emotion return selected_emoji.full_path, f"[表情包:{selected_emoji.description}]", text_emotion
@@ -627,8 +625,9 @@ class EmojiManager:
# 无论steal_emoji是否开启都检查emoji文件夹以支持手动注册 # 无论steal_emoji是否开启都检查emoji文件夹以支持手动注册
# 只有在需要腾出空间或填充表情库时,才真正执行注册 # 只有在需要腾出空间或填充表情库时,才真正执行注册
if (self.emoji_num > self.emoji_num_max and global_config.emoji.do_replace) or \ if (self.emoji_num > self.emoji_num_max and global_config.emoji.do_replace) or (
(self.emoji_num < self.emoji_num_max): self.emoji_num < self.emoji_num_max
):
try: try:
# 获取目录下所有图片文件 # 获取目录下所有图片文件
files_to_process = [ files_to_process = [
@@ -931,16 +930,21 @@ class EmojiManager:
image_base64 = image_base64.encode("ascii", errors="ignore").decode("ascii") image_base64 = image_base64.encode("ascii", errors="ignore").decode("ascii")
image_bytes = base64.b64decode(image_base64) image_bytes = base64.b64decode(image_base64)
image_hash = hashlib.md5(image_bytes).hexdigest() image_hash = hashlib.md5(image_bytes).hexdigest()
image_format = Image.open(io.BytesIO(image_bytes)).format.lower() if Image.open(io.BytesIO(image_bytes)).format else "jpeg" image_format = (
Image.open(io.BytesIO(image_bytes)).format.lower()
if Image.open(io.BytesIO(image_bytes)).format
else "jpeg"
)
# 2. 检查数据库中是否已存在该表情包的描述,实现复用 # 2. 检查数据库中是否已存在该表情包的描述,实现复用
existing_description = None existing_description = None
try: try:
with get_db_session() as session: with get_db_session() as session:
existing_image = session.query(Images).filter( existing_image = (
(Images.emoji_hash == image_hash) & (Images.type == "emoji") session.query(Images)
).one_or_none() .filter((Images.emoji_hash == image_hash) & (Images.type == "emoji"))
.one_or_none()
)
if existing_image and existing_image.description: if existing_image and existing_image.description:
existing_description = existing_image.description existing_description = existing_image.description
logger.info(f"[复用描述] 找到已有详细描述: {existing_description[:50]}...") logger.info(f"[复用描述] 找到已有详细描述: {existing_description[:50]}...")

View File

@@ -14,6 +14,7 @@ Chat Frequency Analyzer
- MIN_CHATS_FOR_PEAK: 在一个窗口内需要多少次聊天才能被认为是高峰时段。 - MIN_CHATS_FOR_PEAK: 在一个窗口内需要多少次聊天才能被认为是高峰时段。
- MIN_GAP_BETWEEN_PEAKS_HOURS: 两个独立高峰时段之间的最小间隔(小时)。 - MIN_GAP_BETWEEN_PEAKS_HOURS: 两个独立高峰时段之间的最小间隔(小时)。
""" """
import time as time_module import time as time_module
from datetime import datetime, timedelta, time from datetime import datetime, timedelta, time
from typing import List, Tuple, Optional from typing import List, Tuple, Optional
@@ -71,7 +72,9 @@ class ChatFrequencyAnalyzer:
current_window_end = datetimes[i] current_window_end = datetimes[i]
# 合并重叠或相邻的高峰时段 # 合并重叠或相邻的高峰时段
if peak_windows and current_window_start - peak_windows[-1][1] < timedelta(hours=MIN_GAP_BETWEEN_PEAKS_HOURS): if peak_windows and current_window_start - peak_windows[-1][1] < timedelta(
hours=MIN_GAP_BETWEEN_PEAKS_HOURS
):
# 扩展上一个窗口的结束时间 # 扩展上一个窗口的结束时间
peak_windows[-1] = (peak_windows[-1][0], current_window_end) peak_windows[-1] = (peak_windows[-1][0], current_window_end)
else: else:

View File

@@ -14,6 +14,7 @@ Frequency-Based Proactive Trigger
- TRIGGER_CHECK_INTERVAL_SECONDS: 触发器检查的周期(秒)。 - TRIGGER_CHECK_INTERVAL_SECONDS: 触发器检查的周期(秒)。
- COOLDOWN_HOURS: 在同一个高峰时段内触发一次后的冷却时间(小时)。 - COOLDOWN_HOURS: 在同一个高峰时段内触发一次后的冷却时间(小时)。
""" """
import asyncio import asyncio
import time import time
from datetime import datetime from datetime import datetime
@@ -21,6 +22,7 @@ from typing import Dict, Optional
from src.common.logger import get_logger from src.common.logger import get_logger
from src.chat.affinity_flow.afc_manager import afc_manager from src.chat.affinity_flow.afc_manager import afc_manager
# TODO: 需要重新实现主动思考和睡眠管理功能 # TODO: 需要重新实现主动思考和睡眠管理功能
from .analyzer import chat_frequency_analyzer from .analyzer import chat_frequency_analyzer
@@ -74,7 +76,6 @@ class FrequencyBasedTrigger:
# 4. 检查当前是否是该用户的高峰聊天时间 # 4. 检查当前是否是该用户的高峰聊天时间
if chat_frequency_analyzer.is_in_peak_time(chat_id, now): if chat_frequency_analyzer.is_in_peak_time(chat_id, now):
# 5. 检查用户当前是否已有活跃的处理任务 # 5. 检查用户当前是否已有活跃的处理任务
# 亲和力流系统不直接提供循环状态,通过检查最后活动时间来判断是否忙碌 # 亲和力流系统不直接提供循环状态,通过检查最后活动时间来判断是否忙碌
chatter = afc_manager.get_or_create_chatter(chat_id) chatter = afc_manager.get_or_create_chatter(chat_id)

View File

@@ -4,14 +4,12 @@
""" """
from .bot_interest_manager import BotInterestManager, bot_interest_manager from .bot_interest_manager import BotInterestManager, bot_interest_manager
from src.common.data_models.bot_interest_data_model import ( from src.common.data_models.bot_interest_data_model import BotInterestTag, BotPersonalityInterests, InterestMatchResult
BotInterestTag, BotPersonalityInterests, InterestMatchResult
)
__all__ = [ __all__ = [
"BotInterestManager", "BotInterestManager",
"bot_interest_manager", "bot_interest_manager",
"BotInterestTag", "BotInterestTag",
"BotPersonalityInterests", "BotPersonalityInterests",
"InterestMatchResult" "InterestMatchResult",
] ]

View File

@@ -2,6 +2,7 @@
机器人兴趣标签管理系统 机器人兴趣标签管理系统
基于人设生成兴趣标签并使用embedding计算匹配度 基于人设生成兴趣标签并使用embedding计算匹配度
""" """
import orjson import orjson
import traceback import traceback
from typing import List, Dict, Optional, Any from typing import List, Dict, Optional, Any
@@ -10,9 +11,7 @@ import numpy as np
from src.common.logger import get_logger from src.common.logger import get_logger
from src.config.config import global_config from src.config.config import global_config
from src.common.data_models.bot_interest_data_model import ( from src.common.data_models.bot_interest_data_model import BotPersonalityInterests, BotInterestTag, InterestMatchResult
BotPersonalityInterests, BotInterestTag, InterestMatchResult
)
logger = get_logger("bot_interest_manager") logger = get_logger("bot_interest_manager")
@@ -87,7 +86,7 @@ class BotInterestManager:
logger.debug("✅ 成功导入embedding相关模块") logger.debug("✅ 成功导入embedding相关模块")
# 检查embedding配置是否存在 # 检查embedding配置是否存在
if not hasattr(model_config.model_task_config, 'embedding'): if not hasattr(model_config.model_task_config, "embedding"):
raise RuntimeError("❌ 未找到embedding模型配置") raise RuntimeError("❌ 未找到embedding模型配置")
logger.info("📋 找到embedding模型配置") logger.info("📋 找到embedding模型配置")
@@ -101,7 +100,7 @@ class BotInterestManager:
logger.info(f"🔗 客户端类型: {type(self.embedding_request).__name__}") logger.info(f"🔗 客户端类型: {type(self.embedding_request).__name__}")
# 获取第一个embedding模型的ModelInfo # 获取第一个embedding模型的ModelInfo
if hasattr(self.embedding_config, 'model_list') and self.embedding_config.model_list: if hasattr(self.embedding_config, "model_list") and self.embedding_config.model_list:
first_model_name = self.embedding_config.model_list[0] first_model_name = self.embedding_config.model_list[0]
logger.info(f"🎯 使用embedding模型: {first_model_name}") logger.info(f"🎯 使用embedding模型: {first_model_name}")
else: else:
@@ -127,7 +126,9 @@ class BotInterestManager:
# 生成新的兴趣标签 # 生成新的兴趣标签
logger.info("🆕 数据库中未找到兴趣标签,开始生成新的...") logger.info("🆕 数据库中未找到兴趣标签,开始生成新的...")
logger.info("🤖 正在调用LLM生成个性化兴趣标签...") logger.info("🤖 正在调用LLM生成个性化兴趣标签...")
generated_interests = await self._generate_interests_from_personality(personality_description, personality_id) generated_interests = await self._generate_interests_from_personality(
personality_description, personality_id
)
if generated_interests: if generated_interests:
self.current_interests = generated_interests self.current_interests = generated_interests
@@ -140,14 +141,16 @@ class BotInterestManager:
else: else:
raise RuntimeError("❌ 兴趣标签生成失败") raise RuntimeError("❌ 兴趣标签生成失败")
async def _generate_interests_from_personality(self, personality_description: str, personality_id: str) -> Optional[BotPersonalityInterests]: async def _generate_interests_from_personality(
self, personality_description: str, personality_id: str
) -> Optional[BotPersonalityInterests]:
"""根据人设生成兴趣标签""" """根据人设生成兴趣标签"""
try: try:
logger.info("🎨 开始根据人设生成兴趣标签...") logger.info("🎨 开始根据人设生成兴趣标签...")
logger.info(f"📝 人设长度: {len(personality_description)} 字符") logger.info(f"📝 人设长度: {len(personality_description)} 字符")
# 检查embedding客户端是否可用 # 检查embedding客户端是否可用
if not hasattr(self, 'embedding_request'): if not hasattr(self, "embedding_request"):
raise RuntimeError("❌ Embedding客户端未初始化无法生成兴趣标签") raise RuntimeError("❌ Embedding客户端未初始化无法生成兴趣标签")
# 构建提示词 # 构建提示词
@@ -190,8 +193,7 @@ class BotInterestManager:
interests_data = orjson.loads(response) interests_data = orjson.loads(response)
bot_interests = BotPersonalityInterests( bot_interests = BotPersonalityInterests(
personality_id=personality_id, personality_id=personality_id, personality_description=personality_description
personality_description=personality_description
) )
# 解析生成的兴趣标签 # 解析生成的兴趣标签
@@ -202,10 +204,7 @@ class BotInterestManager:
tag_name = tag_data.get("name", f"标签_{i}") tag_name = tag_data.get("name", f"标签_{i}")
weight = tag_data.get("weight", 0.5) weight = tag_data.get("weight", 0.5)
tag = BotInterestTag( tag = BotInterestTag(tag_name=tag_name, weight=weight)
tag_name=tag_name,
weight=weight
)
bot_interests.interest_tags.append(tag) bot_interests.interest_tags.append(tag)
logger.debug(f" 🏷️ {tag_name} (权重: {weight:.2f})") logger.debug(f" 🏷️ {tag_name} (权重: {weight:.2f})")
@@ -225,7 +224,6 @@ class BotInterestManager:
traceback.print_exc() traceback.print_exc()
raise raise
async def _call_llm_for_interest_generation(self, prompt: str) -> Optional[str]: async def _call_llm_for_interest_generation(self, prompt: str) -> Optional[str]:
"""调用LLM生成兴趣标签""" """调用LLM生成兴趣标签"""
try: try:
@@ -252,12 +250,14 @@ class BotInterestManager:
model_config=replyer_config, model_config=replyer_config,
request_type="interest_generation", request_type="interest_generation",
temperature=0.7, temperature=0.7,
max_tokens=2000 max_tokens=2000,
) )
if success and response: if success and response:
logger.info(f"✅ LLM响应成功模型: {model_name}, 响应长度: {len(response)} 字符") logger.info(f"✅ LLM响应成功模型: {model_name}, 响应长度: {len(response)} 字符")
logger.debug(f"📄 LLM响应内容: {response[:200]}..." if len(response) > 200 else f"📄 LLM响应内容: {response}") logger.debug(
f"📄 LLM响应内容: {response[:200]}..." if len(response) > 200 else f"📄 LLM响应内容: {response}"
)
if reasoning_content: if reasoning_content:
logger.debug(f"🧠 推理内容: {reasoning_content[:100]}...") logger.debug(f"🧠 推理内容: {reasoning_content[:100]}...")
@@ -279,14 +279,14 @@ class BotInterestManager:
import re import re
# 移除 ```json 和 ``` 标记 # 移除 ```json 和 ``` 标记
cleaned = re.sub(r'```json\s*', '', response) cleaned = re.sub(r"```json\s*", "", response)
cleaned = re.sub(r'\s*```', '', cleaned) cleaned = re.sub(r"\s*```", "", cleaned)
# 移除可能的多余空格和换行 # 移除可能的多余空格和换行
cleaned = cleaned.strip() cleaned = cleaned.strip()
# 尝试提取JSON对象如果响应中有其他文本 # 尝试提取JSON对象如果响应中有其他文本
json_match = re.search(r'\{.*\}', cleaned, re.DOTALL) json_match = re.search(r"\{.*\}", cleaned, re.DOTALL)
if json_match: if json_match:
cleaned = json_match.group(0) cleaned = json_match.group(0)
@@ -295,7 +295,7 @@ class BotInterestManager:
async def _generate_embeddings_for_tags(self, interests: BotPersonalityInterests): async def _generate_embeddings_for_tags(self, interests: BotPersonalityInterests):
"""为所有兴趣标签生成embedding""" """为所有兴趣标签生成embedding"""
if not hasattr(self, 'embedding_request'): if not hasattr(self, "embedding_request"):
raise RuntimeError("❌ Embedding客户端未初始化无法生成embedding") raise RuntimeError("❌ Embedding客户端未初始化无法生成embedding")
total_tags = len(interests.interest_tags) total_tags = len(interests.interest_tags)
@@ -342,7 +342,7 @@ class BotInterestManager:
async def _get_embedding(self, text: str) -> List[float]: async def _get_embedding(self, text: str) -> List[float]:
"""获取文本的embedding向量""" """获取文本的embedding向量"""
if not hasattr(self, 'embedding_request'): if not hasattr(self, "embedding_request"):
raise RuntimeError("❌ Embedding请求客户端未初始化") raise RuntimeError("❌ Embedding请求客户端未初始化")
# 检查缓存 # 检查缓存
@@ -376,7 +376,9 @@ class BotInterestManager:
logger.debug(f"✅ 消息embedding生成成功维度: {len(embedding)}") logger.debug(f"✅ 消息embedding生成成功维度: {len(embedding)}")
return embedding return embedding
async def _calculate_similarity_scores(self, result: InterestMatchResult, message_embedding: List[float], keywords: List[str]): async def _calculate_similarity_scores(
self, result: InterestMatchResult, message_embedding: List[float], keywords: List[str]
):
"""计算消息与兴趣标签的相似度分数""" """计算消息与兴趣标签的相似度分数"""
try: try:
if not self.current_interests: if not self.current_interests:
@@ -397,7 +399,9 @@ class BotInterestManager:
# 设置相似度阈值为0.3 # 设置相似度阈值为0.3
if similarity > 0.3: if similarity > 0.3:
result.add_match(tag.tag_name, weighted_score, keywords) result.add_match(tag.tag_name, weighted_score, keywords)
logger.debug(f" 🏷️ '{tag.tag_name}': 相似度={similarity:.3f}, 权重={tag.weight:.2f}, 加权分数={weighted_score:.3f}") logger.debug(
f" 🏷️ '{tag.tag_name}': 相似度={similarity:.3f}, 权重={tag.weight:.2f}, 加权分数={weighted_score:.3f}"
)
except Exception as e: except Exception as e:
logger.error(f"❌ 计算相似度分数失败: {e}") logger.error(f"❌ 计算相似度分数失败: {e}")
@@ -455,7 +459,9 @@ class BotInterestManager:
match_count += 1 match_count += 1
high_similarity_count += 1 high_similarity_count += 1
result.add_match(tag.tag_name, enhanced_score, [tag.tag_name]) result.add_match(tag.tag_name, enhanced_score, [tag.tag_name])
logger.debug(f" 🏷️ '{tag.tag_name}': 相似度={similarity:.3f}, 权重={tag.weight:.2f}, 基础分数={weighted_score:.3f}, 增强分数={enhanced_score:.3f} [高匹配]") logger.debug(
f" 🏷️ '{tag.tag_name}': 相似度={similarity:.3f}, 权重={tag.weight:.2f}, 基础分数={weighted_score:.3f}, 增强分数={enhanced_score:.3f} [高匹配]"
)
elif similarity > medium_threshold: elif similarity > medium_threshold:
# 中相似度:中等加成 # 中相似度:中等加成
@@ -463,7 +469,9 @@ class BotInterestManager:
match_count += 1 match_count += 1
medium_similarity_count += 1 medium_similarity_count += 1
result.add_match(tag.tag_name, enhanced_score, [tag.tag_name]) result.add_match(tag.tag_name, enhanced_score, [tag.tag_name])
logger.debug(f" 🏷️ '{tag.tag_name}': 相似度={similarity:.3f}, 权重={tag.weight:.2f}, 基础分数={weighted_score:.3f}, 增强分数={enhanced_score:.3f} [中匹配]") logger.debug(
f" 🏷️ '{tag.tag_name}': 相似度={similarity:.3f}, 权重={tag.weight:.2f}, 基础分数={weighted_score:.3f}, 增强分数={enhanced_score:.3f} [中匹配]"
)
elif similarity > low_threshold: elif similarity > low_threshold:
# 低相似度:轻微加成 # 低相似度:轻微加成
@@ -471,7 +479,9 @@ class BotInterestManager:
match_count += 1 match_count += 1
low_similarity_count += 1 low_similarity_count += 1
result.add_match(tag.tag_name, enhanced_score, [tag.tag_name]) result.add_match(tag.tag_name, enhanced_score, [tag.tag_name])
logger.debug(f" 🏷️ '{tag.tag_name}': 相似度={similarity:.3f}, 权重={tag.weight:.2f}, 基础分数={weighted_score:.3f}, 增强分数={enhanced_score:.3f} [低匹配]") logger.debug(
f" 🏷️ '{tag.tag_name}': 相似度={similarity:.3f}, 权重={tag.weight:.2f}, 基础分数={weighted_score:.3f}, 增强分数={enhanced_score:.3f} [低匹配]"
)
logger.info(f"📈 匹配统计: {match_count}/{len(active_tags)} 个标签超过阈值") logger.info(f"📈 匹配统计: {match_count}/{len(active_tags)} 个标签超过阈值")
logger.info(f"🔥 高相似度匹配(>{high_threshold}): {high_similarity_count}") logger.info(f"🔥 高相似度匹配(>{high_threshold}): {high_similarity_count}")
@@ -488,7 +498,9 @@ class BotInterestManager:
original_score = result.match_scores[tag_name] original_score = result.match_scores[tag_name]
bonus = keyword_bonus[tag_name] bonus = keyword_bonus[tag_name]
result.match_scores[tag_name] = original_score + bonus result.match_scores[tag_name] = original_score + bonus
logger.debug(f" 🏷️ '{tag_name}': 原始分数={original_score:.3f}, 奖励={bonus:.3f}, 最终分数={result.match_scores[tag_name]:.3f}") logger.debug(
f" 🏷️ '{tag_name}': 原始分数={original_score:.3f}, 奖励={bonus:.3f}, 最终分数={result.match_scores[tag_name]:.3f}"
)
# 计算总体分数 # 计算总体分数
result.calculate_overall_score() result.calculate_overall_score()
@@ -499,10 +511,11 @@ class BotInterestManager:
result.top_tag = top_tag_name result.top_tag = top_tag_name
logger.info(f"🏆 最佳匹配标签: '{top_tag_name}' (分数: {result.match_scores[top_tag_name]:.3f})") logger.info(f"🏆 最佳匹配标签: '{top_tag_name}' (分数: {result.match_scores[top_tag_name]:.3f})")
logger.info(f"📊 最终结果: 总分={result.overall_score:.3f}, 置信度={result.confidence:.3f}, 匹配标签数={len(result.matched_tags)}") logger.info(
f"📊 最终结果: 总分={result.overall_score:.3f}, 置信度={result.confidence:.3f}, 匹配标签数={len(result.matched_tags)}"
)
return result return result
def _calculate_keyword_match_bonus(self, keywords: List[str], matched_tags: List[str]) -> Dict[str, float]: def _calculate_keyword_match_bonus(self, keywords: List[str], matched_tags: List[str]) -> Dict[str, float]:
"""计算关键词直接匹配奖励""" """计算关键词直接匹配奖励"""
if not keywords or not matched_tags: if not keywords or not matched_tags:
@@ -522,17 +535,25 @@ class BotInterestManager:
# 完全匹配 # 完全匹配
if keyword_lower == tag_name_lower: if keyword_lower == tag_name_lower:
bonus += affinity_config.high_match_interest_threshold * 0.6 # 使用高匹配阈值的60%作为完全匹配奖励 bonus += affinity_config.high_match_interest_threshold * 0.6 # 使用高匹配阈值的60%作为完全匹配奖励
logger.debug(f" 🎯 关键词完全匹配: '{keyword}' == '{tag_name}' (+{affinity_config.high_match_interest_threshold * 0.6:.3f})") logger.debug(
f" 🎯 关键词完全匹配: '{keyword}' == '{tag_name}' (+{affinity_config.high_match_interest_threshold * 0.6:.3f})"
)
# 包含匹配 # 包含匹配
elif keyword_lower in tag_name_lower or tag_name_lower in keyword_lower: elif keyword_lower in tag_name_lower or tag_name_lower in keyword_lower:
bonus += affinity_config.medium_match_interest_threshold * 0.3 # 使用中匹配阈值的30%作为包含匹配奖励 bonus += (
logger.debug(f" 🎯 关键词包含匹配: '{keyword}''{tag_name}' (+{affinity_config.medium_match_interest_threshold * 0.3:.3f})") affinity_config.medium_match_interest_threshold * 0.3
) # 使用中匹配阈值的30%作为包含匹配奖励
logger.debug(
f" 🎯 关键词包含匹配: '{keyword}''{tag_name}' (+{affinity_config.medium_match_interest_threshold * 0.3:.3f})"
)
# 部分匹配(编辑距离) # 部分匹配(编辑距离)
elif self._calculate_partial_match(keyword_lower, tag_name_lower): elif self._calculate_partial_match(keyword_lower, tag_name_lower):
bonus += affinity_config.low_match_interest_threshold * 0.4 # 使用低匹配阈值的40%作为部分匹配奖励 bonus += affinity_config.low_match_interest_threshold * 0.4 # 使用低匹配阈值的40%作为部分匹配奖励
logger.debug(f" 🎯 关键词部分匹配: '{keyword}''{tag_name}' (+{affinity_config.low_match_interest_threshold * 0.4:.3f})") logger.debug(
f" 🎯 关键词部分匹配: '{keyword}''{tag_name}' (+{affinity_config.low_match_interest_threshold * 0.4:.3f})"
)
if bonus > 0: if bonus > 0:
bonus_dict[tag_name] = min(bonus, affinity_config.max_match_bonus) # 使用配置的最大奖励限制 bonus_dict[tag_name] = min(bonus, affinity_config.max_match_bonus) # 使用配置的最大奖励限制
@@ -608,12 +629,12 @@ class BotInterestManager:
with get_db_session() as session: with get_db_session() as session:
# 查询最新的兴趣标签配置 # 查询最新的兴趣标签配置
db_interests = session.query(DBBotPersonalityInterests).filter( db_interests = (
DBBotPersonalityInterests.personality_id == personality_id session.query(DBBotPersonalityInterests)
).order_by( .filter(DBBotPersonalityInterests.personality_id == personality_id)
DBBotPersonalityInterests.version.desc(), .order_by(DBBotPersonalityInterests.version.desc(), DBBotPersonalityInterests.last_updated.desc())
DBBotPersonalityInterests.last_updated.desc() .first()
).first() )
if db_interests: if db_interests:
logger.info(f"✅ 找到数据库中的兴趣标签配置,版本: {db_interests.version}") logger.info(f"✅ 找到数据库中的兴趣标签配置,版本: {db_interests.version}")
@@ -631,7 +652,7 @@ class BotInterestManager:
personality_description=db_interests.personality_description, personality_description=db_interests.personality_description,
embedding_model=db_interests.embedding_model, embedding_model=db_interests.embedding_model,
version=db_interests.version, version=db_interests.version,
last_updated=db_interests.last_updated last_updated=db_interests.last_updated,
) )
# 解析兴趣标签 # 解析兴趣标签
@@ -639,10 +660,14 @@ class BotInterestManager:
tag = BotInterestTag( tag = BotInterestTag(
tag_name=tag_data.get("tag_name", ""), tag_name=tag_data.get("tag_name", ""),
weight=tag_data.get("weight", 0.5), weight=tag_data.get("weight", 0.5),
created_at=datetime.fromisoformat(tag_data.get("created_at", datetime.now().isoformat())), created_at=datetime.fromisoformat(
updated_at=datetime.fromisoformat(tag_data.get("updated_at", datetime.now().isoformat())), tag_data.get("created_at", datetime.now().isoformat())
),
updated_at=datetime.fromisoformat(
tag_data.get("updated_at", datetime.now().isoformat())
),
is_active=tag_data.get("is_active", True), is_active=tag_data.get("is_active", True),
embedding=tag_data.get("embedding") embedding=tag_data.get("embedding"),
) )
interests.interest_tags.append(tag) interests.interest_tags.append(tag)
@@ -685,7 +710,7 @@ class BotInterestManager:
"created_at": tag.created_at.isoformat(), "created_at": tag.created_at.isoformat(),
"updated_at": tag.updated_at.isoformat(), "updated_at": tag.updated_at.isoformat(),
"is_active": tag.is_active, "is_active": tag.is_active,
"embedding": tag.embedding "embedding": tag.embedding,
} }
tags_data.append(tag_dict) tags_data.append(tag_dict)
@@ -694,9 +719,11 @@ class BotInterestManager:
with get_db_session() as session: with get_db_session() as session:
# 检查是否已存在相同personality_id的记录 # 检查是否已存在相同personality_id的记录
existing_record = session.query(DBBotPersonalityInterests).filter( existing_record = (
DBBotPersonalityInterests.personality_id == interests.personality_id session.query(DBBotPersonalityInterests)
).first() .filter(DBBotPersonalityInterests.personality_id == interests.personality_id)
.first()
)
if existing_record: if existing_record:
# 更新现有记录 # 更新现有记录
@@ -718,7 +745,7 @@ class BotInterestManager:
interest_tags=json_data, interest_tags=json_data,
embedding_model=interests.embedding_model, embedding_model=interests.embedding_model,
version=interests.version, version=interests.version,
last_updated=interests.last_updated last_updated=interests.last_updated,
) )
session.add(new_record) session.add(new_record)
session.commit() session.commit()
@@ -728,9 +755,11 @@ class BotInterestManager:
# 验证保存是否成功 # 验证保存是否成功
with get_db_session() as session: with get_db_session() as session:
saved_record = session.query(DBBotPersonalityInterests).filter( saved_record = (
DBBotPersonalityInterests.personality_id == interests.personality_id session.query(DBBotPersonalityInterests)
).first() .filter(DBBotPersonalityInterests.personality_id == interests.personality_id)
.first()
)
session.commit() session.commit()
if saved_record: if saved_record:
logger.info(f"✅ 验证成功数据库中存在personality_id为 {interests.personality_id} 的记录") logger.info(f"✅ 验证成功数据库中存在personality_id为 {interests.personality_id} 的记录")
@@ -760,7 +789,7 @@ class BotInterestManager:
"total_tags": len(active_tags), "total_tags": len(active_tags),
"embedding_model": self.current_interests.embedding_model, "embedding_model": self.current_interests.embedding_model,
"last_updated": self.current_interests.last_updated.isoformat(), "last_updated": self.current_interests.last_updated.isoformat(),
"cache_size": len(self.embedding_cache) "cache_size": len(self.embedding_cache),
} }
async def update_interest_tags(self, new_personality_description: str = None): async def update_interest_tags(self, new_personality_description: str = None):
@@ -775,8 +804,7 @@ class BotInterestManager:
# 重新生成兴趣标签 # 重新生成兴趣标签
new_interests = await self._generate_interests_from_personality( new_interests = await self._generate_interests_from_personality(
self.current_interests.personality_description, self.current_interests.personality_description, self.current_interests.personality_id
self.current_interests.personality_id
) )
if new_interests: if new_interests:

View File

@@ -4,13 +4,11 @@
""" """
from .message_manager import MessageManager, message_manager from .message_manager import MessageManager, message_manager
from src.common.data_models.message_manager_data_model import StreamContext, MessageStatus, MessageManagerStats, StreamStats from src.common.data_models.message_manager_data_model import (
StreamContext,
MessageStatus,
MessageManagerStats,
StreamStats,
)
__all__ = [ __all__ = ["MessageManager", "message_manager", "StreamContext", "MessageStatus", "MessageManagerStats", "StreamStats"]
"MessageManager",
"message_manager",
"StreamContext",
"MessageStatus",
"MessageManagerStats",
"StreamStats"
]

View File

@@ -2,6 +2,7 @@
消息管理模块 消息管理模块
管理每个聊天流的上下文信息,包含历史记录和未读消息,定期检查并处理新消息 管理每个聊天流的上下文信息,包含历史记录和未读消息,定期检查并处理新消息
""" """
import asyncio import asyncio
import time import time
import traceback import traceback
@@ -100,9 +101,7 @@ class MessageManager:
# 如果没有处理任务,创建一个 # 如果没有处理任务,创建一个
if not context.processing_task or context.processing_task.done(): if not context.processing_task or context.processing_task.done():
context.processing_task = asyncio.create_task( context.processing_task = asyncio.create_task(self._process_stream_messages(stream_id))
self._process_stream_messages(stream_id)
)
# 更新统计 # 更新统计
self.stats.active_streams = active_streams self.stats.active_streams = active_streams
@@ -175,7 +174,7 @@ class MessageManager:
unread_count=len(context.get_unread_messages()), unread_count=len(context.get_unread_messages()),
history_count=len(context.history_messages), history_count=len(context.history_messages),
last_check_time=context.last_check_time, last_check_time=context.last_check_time,
has_active_task=context.processing_task and not context.processing_task.done() has_active_task=context.processing_task and not context.processing_task.done(),
) )
def get_manager_stats(self) -> Dict[str, Any]: def get_manager_stats(self) -> Dict[str, Any]:
@@ -186,7 +185,7 @@ class MessageManager:
"total_unread_messages": self.stats.total_unread_messages, "total_unread_messages": self.stats.total_unread_messages,
"total_processed_messages": self.stats.total_processed_messages, "total_processed_messages": self.stats.total_processed_messages,
"uptime": self.stats.uptime, "uptime": self.stats.uptime,
"start_time": self.stats.start_time "start_time": self.stats.start_time,
} }
def cleanup_inactive_streams(self, max_inactive_hours: int = 24): def cleanup_inactive_streams(self, max_inactive_hours: int = 24):
@@ -196,8 +195,7 @@ class MessageManager:
inactive_streams = [] inactive_streams = []
for stream_id, context in self.stream_contexts.items(): for stream_id, context in self.stream_contexts.items():
if (current_time - context.last_check_time > max_inactive_seconds and if current_time - context.last_check_time > max_inactive_seconds and not context.get_unread_messages():
not context.get_unread_messages()):
inactive_streams.append(stream_id) inactive_streams.append(stream_id)
for stream_id in inactive_streams: for stream_id in inactive_streams:

View File

@@ -17,6 +17,7 @@ from src.plugin_system.core import component_registry, event_manager, global_ann
from src.plugin_system.base import BaseCommand, EventType from src.plugin_system.base import BaseCommand, EventType
from src.mais4u.mais4u_chat.s4u_msg_processor import S4UMessageProcessor from src.mais4u.mais4u_chat.s4u_msg_processor import S4UMessageProcessor
from src.chat.utils.utils import is_mentioned_bot_in_message from src.chat.utils.utils import is_mentioned_bot_in_message
# 导入反注入系统 # 导入反注入系统
from src.chat.antipromptinjector import initialize_anti_injector from src.chat.antipromptinjector import initialize_anti_injector
@@ -511,7 +512,7 @@ class ChatBot:
chat_info_user_id=message.chat_stream.user_info.user_id, chat_info_user_id=message.chat_stream.user_info.user_id,
chat_info_user_nickname=message.chat_stream.user_info.user_nickname, chat_info_user_nickname=message.chat_stream.user_info.user_nickname,
chat_info_user_cardname=message.chat_stream.user_info.user_cardname, chat_info_user_cardname=message.chat_stream.user_info.user_cardname,
chat_info_user_platform=message.chat_stream.user_info.platform chat_info_user_platform=message.chat_stream.user_info.platform,
) )
# 如果是群聊,添加群组信息 # 如果是群聊,添加群组信息

View File

@@ -84,7 +84,9 @@ class ChatStream:
self.saved = False self.saved = False
self.context: ChatMessageContext = None # type: ignore # 用于存储该聊天的上下文信息 self.context: ChatMessageContext = None # type: ignore # 用于存储该聊天的上下文信息
# 从配置文件中读取focus_value如果没有则使用默认值1.0 # 从配置文件中读取focus_value如果没有则使用默认值1.0
self.focus_energy = data.get("focus_energy", global_config.chat.focus_value) if data else global_config.chat.focus_value self.focus_energy = (
data.get("focus_energy", global_config.chat.focus_value) if data else global_config.chat.focus_value
)
self.no_reply_consecutive = 0 self.no_reply_consecutive = 0
self.breaking_accumulated_interest = 0.0 self.breaking_accumulated_interest = 0.0

View File

@@ -168,7 +168,12 @@ class ActionManager:
if not chat_stream: if not chat_stream:
logger.error(f"{log_prefix} 无法找到chat_id对应的chat_stream: {chat_id}") logger.error(f"{log_prefix} 无法找到chat_id对应的chat_stream: {chat_id}")
return {"action_type": action_name, "success": False, "reply_text": "", "error": "chat_stream not found"} return {
"action_type": action_name,
"success": False,
"reply_text": "",
"error": "chat_stream not found",
}
if action_name == "no_action": if action_name == "no_action":
return {"action_type": "no_action", "success": True, "reply_text": "", "command": ""} return {"action_type": "no_action", "success": True, "reply_text": "", "command": ""}

View File

@@ -1,6 +1,7 @@
""" """
PlanGenerator: 负责搜集和汇总所有决策所需的信息,生成一个未经筛选的“原始计划” (Plan)。 PlanGenerator: 负责搜集和汇总所有决策所需的信息,生成一个未经筛选的“原始计划” (Plan)。
""" """
import time import time
from typing import Dict from typing import Dict
@@ -35,6 +36,7 @@ class PlanGenerator:
chat_id (str): 当前聊天的 ID。 chat_id (str): 当前聊天的 ID。
""" """
from src.chat.planner_actions.action_manager import ActionManager from src.chat.planner_actions.action_manager import ActionManager
self.chat_id = chat_id self.chat_id = chat_id
# 注意ActionManager 可能需要根据实际情况初始化 # 注意ActionManager 可能需要根据实际情况初始化
self.action_manager = ActionManager() self.action_manager = ActionManager()
@@ -65,7 +67,6 @@ class PlanGenerator:
) )
chat_history = [DatabaseMessages(**msg) for msg in chat_history_raw] chat_history = [DatabaseMessages(**msg) for msg in chat_history_raw]
plan = Plan( plan = Plan(
chat_id=self.chat_id, chat_id=self.chat_id,
mode=mode, mode=mode,
@@ -86,7 +87,7 @@ class PlanGenerator:
Dict[str, "ActionInfo"]: 一个字典,键是动作名称,值是 ActionInfo 对象。 Dict[str, "ActionInfo"]: 一个字典,键是动作名称,值是 ActionInfo 对象。
""" """
current_available_actions_dict = self.action_manager.get_using_actions() current_available_actions_dict = self.action_manager.get_using_actions()
all_registered_actions: Dict[str, ActionInfo] = component_registry.get_components_by_type( # type: ignore all_registered_actions: Dict[str, ActionInfo] = component_registry.get_components_by_type( # type: ignore
ComponentType.ACTION ComponentType.ACTION
) )
@@ -99,16 +100,13 @@ class PlanGenerator:
name="reply", name="reply",
component_type=ComponentType.ACTION, component_type=ComponentType.ACTION,
description="系统级动作:选择回复消息的决策", description="系统级动作:选择回复消息的决策",
action_parameters={ action_parameters={"content": "回复的文本内容", "reply_to_message_id": "要回复的消息ID"},
"content": "回复的文本内容",
"reply_to_message_id": "要回复的消息ID"
},
action_require=[ action_require=[
"你想要闲聊或者随便附和", "你想要闲聊或者随便附和",
"当用户提到你或艾特你时", "当用户提到你或艾特你时",
"当需要回答用户的问题时", "当需要回答用户的问题时",
"当你想参与对话时", "当你想参与对话时",
"当用户分享有趣的内容时" "当用户分享有趣的内容时",
], ],
activation_type=ActionActivationType.ALWAYS, activation_type=ActionActivationType.ALWAYS,
activation_keywords=[], activation_keywords=[],

View File

@@ -109,9 +109,7 @@ class ActionPlanner:
self.planner_stats["failed_plans"] += 1 self.planner_stats["failed_plans"] += 1
return [], None return [], None
async def _enhanced_plan_flow( async def _enhanced_plan_flow(self, mode: ChatMode, context: StreamContext) -> Tuple[List[Dict], Optional[Dict]]:
self, mode: ChatMode, context: StreamContext
) -> Tuple[List[Dict], Optional[Dict]]:
"""执行增强版规划流程""" """执行增强版规划流程"""
try: try:
# 1. 生成初始 Plan # 1. 生成初始 Plan
@@ -137,7 +135,9 @@ class ActionPlanner:
# 检查兴趣度是否达到非回复动作阈值 # 检查兴趣度是否达到非回复动作阈值
non_reply_action_interest_threshold = global_config.affinity_flow.non_reply_action_interest_threshold non_reply_action_interest_threshold = global_config.affinity_flow.non_reply_action_interest_threshold
if score < non_reply_action_interest_threshold: if score < non_reply_action_interest_threshold:
logger.info(f"❌ 兴趣度不足非回复动作阈值: {score:.3f} < {non_reply_action_interest_threshold:.3f}直接返回no_action") logger.info(
f"❌ 兴趣度不足非回复动作阈值: {score:.3f} < {non_reply_action_interest_threshold:.3f}直接返回no_action"
)
logger.info(f"📊 最低要求: {non_reply_action_interest_threshold:.3f}") logger.info(f"📊 最低要求: {non_reply_action_interest_threshold:.3f}")
# 直接返回 no_action # 直接返回 no_action
from src.common.data_models.info_data_model import ActionPlannerInfo from src.common.data_models.info_data_model import ActionPlannerInfo

View File

@@ -598,6 +598,7 @@ class DefaultReplyer:
def _parse_reply_target(self, target_message: str) -> Tuple[str, str]: def _parse_reply_target(self, target_message: str) -> Tuple[str, str]:
"""解析回复目标消息 - 使用共享工具""" """解析回复目标消息 - 使用共享工具"""
from src.chat.utils.prompt import Prompt from src.chat.utils.prompt import Prompt
if target_message is None: if target_message is None:
logger.warning("target_message为None返回默认值") logger.warning("target_message为None返回默认值")
return "未知用户", "(无消息内容)" return "未知用户", "(无消息内容)"
@@ -704,21 +705,23 @@ class DefaultReplyer:
unread_history_prompt = "" unread_history_prompt = ""
if unread_messages: if unread_messages:
# 尝试获取兴趣度评分 # 尝试获取兴趣度评分
interest_scores = await self._get_interest_scores_for_messages([msg.flatten() for msg in unread_messages]) interest_scores = await self._get_interest_scores_for_messages(
[msg.flatten() for msg in unread_messages]
)
unread_lines = [] unread_lines = []
for msg in unread_messages: for msg in unread_messages:
msg_id = msg.message_id msg_id = msg.message_id
msg_time = time.strftime('%H:%M:%S', time.localtime(msg.time)) msg_time = time.strftime("%H:%M:%S", time.localtime(msg.time))
msg_content = msg.processed_plain_text msg_content = msg.processed_plain_text
# 使用与已读历史消息相同的方法获取用户名 # 使用与已读历史消息相同的方法获取用户名
from src.person_info.person_info import PersonInfoManager, get_person_info_manager from src.person_info.person_info import PersonInfoManager, get_person_info_manager
# 获取用户信息 # 获取用户信息
user_info = getattr(msg, 'user_info', {}) user_info = getattr(msg, "user_info", {})
platform = getattr(user_info, 'platform', '') or getattr(msg, 'platform', '') platform = getattr(user_info, "platform", "") or getattr(msg, "platform", "")
user_id = getattr(user_info, 'user_id', '') or getattr(msg, 'user_id', '') user_id = getattr(user_info, "user_id", "") or getattr(msg, "user_id", "")
# 获取用户名 # 获取用户名
if platform and user_id: if platform and user_id:
@@ -808,7 +811,7 @@ class DefaultReplyer:
unread_lines = [] unread_lines = []
for msg in unread_messages: for msg in unread_messages:
msg_id = msg.get("message_id", "") msg_id = msg.get("message_id", "")
msg_time = time.strftime('%H:%M:%S', time.localtime(msg.get("time", time.time()))) msg_time = time.strftime("%H:%M:%S", time.localtime(msg.get("time", time.time())))
msg_content = msg.get("processed_plain_text", "") msg_content = msg.get("processed_plain_text", "")
# 使用与已读历史消息相同的方法获取用户名 # 使用与已读历史消息相同的方法获取用户名
@@ -834,7 +837,9 @@ class DefaultReplyer:
unread_lines.append(f"{msg_time} {sender_name}: {msg_content}{interest_text}") unread_lines.append(f"{msg_time} {sender_name}: {msg_content}{interest_text}")
unread_history_prompt_str = "\n".join(unread_lines) unread_history_prompt_str = "\n".join(unread_lines)
unread_history_prompt = f"这是未读历史消息,包含兴趣度评分,请优先对兴趣值高的消息做出动作:\n{unread_history_prompt_str}" unread_history_prompt = (
f"这是未读历史消息,包含兴趣度评分,请优先对兴趣值高的消息做出动作:\n{unread_history_prompt_str}"
)
else: else:
unread_history_prompt = "暂无未读历史消息" unread_history_prompt = "暂无未读历史消息"
@@ -1052,6 +1057,7 @@ class DefaultReplyer:
target_user_info = await person_info_manager.get_person_info_by_name(sender) target_user_info = await person_info_manager.get_person_info_by_name(sender)
from src.chat.utils.prompt import Prompt from src.chat.utils.prompt import Prompt
# 并行执行六个构建任务 # 并行执行六个构建任务
task_results = await asyncio.gather( task_results = await asyncio.gather(
self._time_and_run_task( self._time_and_run_task(
@@ -1127,6 +1133,7 @@ class DefaultReplyer:
schedule_block = "" schedule_block = ""
if global_config.planning_system.schedule_enable: if global_config.planning_system.schedule_enable:
from src.schedule.schedule_manager import schedule_manager from src.schedule.schedule_manager import schedule_manager
current_activity = schedule_manager.get_current_activity() current_activity = schedule_manager.get_current_activity()
if current_activity: if current_activity:
schedule_block = f"你当前正在:{current_activity}" schedule_block = f"你当前正在:{current_activity}"
@@ -1139,7 +1146,7 @@ class DefaultReplyer:
safety_guidelines = global_config.personality.safety_guidelines safety_guidelines = global_config.personality.safety_guidelines
safety_guidelines_block = "" safety_guidelines_block = ""
if safety_guidelines: if safety_guidelines:
guidelines_text = "\n".join(f"{i+1}. {line}" for i, line in enumerate(safety_guidelines)) guidelines_text = "\n".join(f"{i + 1}. {line}" for i, line in enumerate(safety_guidelines))
safety_guidelines_block = f"""### 安全与互动底线 safety_guidelines_block = f"""### 安全与互动底线
在任何情况下,你都必须遵守以下由你的设定者为你定义的原则: 在任何情况下,你都必须遵守以下由你的设定者为你定义的原则:
{guidelines_text} {guidelines_text}

View File

@@ -821,7 +821,7 @@ def build_pic_mapping_info(pic_id_mapping: Dict[str, str]) -> str:
try: try:
with get_db_session() as session: with get_db_session() as session:
image = session.execute(select(Images).where(Images.image_id == pic_id)).scalar_one_or_none() image = session.execute(select(Images).where(Images.image_id == pic_id)).scalar_one_or_none()
if image and image.description: # type: ignore if image and image.description: # type: ignore
description = image.description description = image.description
except Exception: except Exception:
# 如果查询失败,保持默认描述 # 如果查询失败,保持默认描述

View File

@@ -235,7 +235,7 @@ class Prompt:
template: str, template: str,
name: Optional[str] = None, name: Optional[str] = None,
parameters: Optional[PromptParameters] = None, parameters: Optional[PromptParameters] = None,
should_register: bool = True should_register: bool = True,
): ):
""" """
初始化统一提示词 初始化统一提示词
@@ -420,17 +420,19 @@ class Prompt:
await self._build_normal_chat_context(context_data) await self._build_normal_chat_context(context_data)
# 补充基础信息 # 补充基础信息
context_data.update({ context_data.update(
"keywords_reaction_prompt": self.parameters.keywords_reaction_prompt, {
"extra_info_block": self.parameters.extra_info_block, "keywords_reaction_prompt": self.parameters.keywords_reaction_prompt,
"time_block": self.parameters.time_block or f"当前时间:{time.strftime('%Y-%m-%d %H:%M:%S')}", "extra_info_block": self.parameters.extra_info_block,
"identity": self.parameters.identity_block, "time_block": self.parameters.time_block or f"当前时间:{time.strftime('%Y-%m-%d %H:%M:%S')}",
"schedule_block": self.parameters.schedule_block, "identity": self.parameters.identity_block,
"moderation_prompt": self.parameters.moderation_prompt_block, "schedule_block": self.parameters.schedule_block,
"reply_target_block": self.parameters.reply_target_block, "moderation_prompt": self.parameters.moderation_prompt_block,
"mood_state": self.parameters.mood_prompt, "reply_target_block": self.parameters.reply_target_block,
"action_descriptions": self.parameters.action_descriptions, "mood_state": self.parameters.mood_prompt,
}) "action_descriptions": self.parameters.action_descriptions,
}
)
total_time = time.time() - start_time total_time = time.time() - start_time
logger.debug(f"上下文构建完成,总耗时: {total_time:.2f}s") logger.debug(f"上下文构建完成,总耗时: {total_time:.2f}s")
@@ -446,7 +448,7 @@ class Prompt:
self.parameters.message_list_before_now_long, self.parameters.message_list_before_now_long,
self.parameters.target_user_info.get("user_id") if self.parameters.target_user_info else "", self.parameters.target_user_info.get("user_id") if self.parameters.target_user_info else "",
self.parameters.sender, self.parameters.sender,
self.parameters.chat_id self.parameters.chat_id,
) )
context_data["read_history_prompt"] = read_history_prompt context_data["read_history_prompt"] = read_history_prompt
@@ -476,8 +478,6 @@ class Prompt:
except Exception as e: except Exception as e:
logger.error(f"构建S4U历史消息prompt失败: {e}") logger.error(f"构建S4U历史消息prompt失败: {e}")
async def _build_expression_habits(self) -> Dict[str, Any]: async def _build_expression_habits(self) -> Dict[str, Any]:
"""构建表达习惯""" """构建表达习惯"""
if not global_config.expression.enable_expression: if not global_config.expression.enable_expression:
@@ -491,10 +491,7 @@ class Prompt:
if self.parameters.message_list_before_now_long: if self.parameters.message_list_before_now_long:
recent_messages = self.parameters.message_list_before_now_long[-10:] recent_messages = self.parameters.message_list_before_now_long[-10:]
chat_history = build_readable_messages( chat_history = build_readable_messages(
recent_messages, recent_messages, replace_bot_name=True, timestamp_mode="normal", truncate=True
replace_bot_name=True,
timestamp_mode="normal",
truncate=True
) )
# 创建表情选择器 # 创建表情选择器
@@ -505,7 +502,7 @@ class Prompt:
chat_history=chat_history, chat_history=chat_history,
current_message=self.parameters.target, current_message=self.parameters.target,
emotional_tone="neutral", emotional_tone="neutral",
topic_type="general" topic_type="general",
) )
# 构建表达习惯块 # 构建表达习惯块
@@ -535,17 +532,13 @@ class Prompt:
if self.parameters.message_list_before_now_long: if self.parameters.message_list_before_now_long:
recent_messages = self.parameters.message_list_before_now_long[-20:] recent_messages = self.parameters.message_list_before_now_long[-20:]
chat_history = build_readable_messages( chat_history = build_readable_messages(
recent_messages, recent_messages, replace_bot_name=True, timestamp_mode="normal", truncate=True
replace_bot_name=True,
timestamp_mode="normal",
truncate=True
) )
# 激活长期记忆 # 激活长期记忆
memory_activator = MemoryActivator() memory_activator = MemoryActivator()
running_memories = await memory_activator.activate_memory_with_chat_history( running_memories = await memory_activator.activate_memory_with_chat_history(
target_message=self.parameters.target, target_message=self.parameters.target, chat_history_prompt=chat_history
chat_history_prompt=chat_history
) )
# 获取即时记忆 # 获取即时记忆
@@ -593,10 +586,7 @@ class Prompt:
if self.parameters.message_list_before_now_long: if self.parameters.message_list_before_now_long:
recent_messages = self.parameters.message_list_before_now_long[-15:] recent_messages = self.parameters.message_list_before_now_long[-15:]
chat_history = build_readable_messages( chat_history = build_readable_messages(
recent_messages, recent_messages, replace_bot_name=True, timestamp_mode="normal", truncate=True
replace_bot_name=True,
timestamp_mode="normal",
truncate=True
) )
# 创建工具执行器 # 创建工具执行器
@@ -607,12 +597,12 @@ class Prompt:
sender=self.parameters.sender, sender=self.parameters.sender,
target_message=self.parameters.target, target_message=self.parameters.target,
chat_history=chat_history, chat_history=chat_history,
return_details=False return_details=False,
) )
# 构建工具信息块 # 构建工具信息块
if tool_results: if tool_results:
tool_info_parts = ["## 工具信息","以下是你通过工具获取到的实时信息:"] tool_info_parts = ["## 工具信息", "以下是你通过工具获取到的实时信息:"]
for tool_result in tool_results: for tool_result in tool_results:
tool_name = tool_result.get("tool_name", "unknown") tool_name = tool_result.get("tool_name", "unknown")
content = tool_result.get("content", "") content = tool_result.get("content", "")
@@ -649,15 +639,12 @@ class Prompt:
# 搜索相关知识 # 搜索相关知识
knowledge_results = await qa_manager.get_knowledge( knowledge_results = await qa_manager.get_knowledge(
question=question, question=question, chat_id=self.parameters.chat_id, max_results=5, min_similarity=0.5
chat_id=self.parameters.chat_id,
max_results=5,
min_similarity=0.5
) )
# 构建知识块 # 构建知识块
if knowledge_results and knowledge_results.get("knowledge_items"): if knowledge_results and knowledge_results.get("knowledge_items"):
knowledge_parts = ["## 知识库信息","以下是与你当前对话相关的知识信息:"] knowledge_parts = ["## 知识库信息", "以下是与你当前对话相关的知识信息:"]
for item in knowledge_results["knowledge_items"]: for item in knowledge_results["knowledge_items"]:
content = item.get("content", "") content = item.get("content", "")
@@ -725,9 +712,11 @@ class Prompt:
"time_block": context_data.get("time_block", ""), "time_block": context_data.get("time_block", ""),
"reply_target_block": context_data.get("reply_target_block", ""), "reply_target_block": context_data.get("reply_target_block", ""),
"reply_style": global_config.personality.reply_style, "reply_style": global_config.personality.reply_style,
"keywords_reaction_prompt": self.parameters.keywords_reaction_prompt or context_data.get("keywords_reaction_prompt", ""), "keywords_reaction_prompt": self.parameters.keywords_reaction_prompt
or context_data.get("keywords_reaction_prompt", ""),
"moderation_prompt": self.parameters.moderation_prompt_block or context_data.get("moderation_prompt", ""), "moderation_prompt": self.parameters.moderation_prompt_block or context_data.get("moderation_prompt", ""),
"safety_guidelines_block": self.parameters.safety_guidelines_block or context_data.get("safety_guidelines_block", ""), "safety_guidelines_block": self.parameters.safety_guidelines_block
or context_data.get("safety_guidelines_block", ""),
} }
def _prepare_normal_params(self, context_data: Dict[str, Any]) -> Dict[str, Any]: def _prepare_normal_params(self, context_data: Dict[str, Any]) -> Dict[str, Any]:
@@ -749,9 +738,11 @@ class Prompt:
"reply_target_block": context_data.get("reply_target_block", ""), "reply_target_block": context_data.get("reply_target_block", ""),
"config_expression_style": global_config.personality.reply_style, "config_expression_style": global_config.personality.reply_style,
"mood_state": self.parameters.mood_prompt or context_data.get("mood_state", ""), "mood_state": self.parameters.mood_prompt or context_data.get("mood_state", ""),
"keywords_reaction_prompt": self.parameters.keywords_reaction_prompt or context_data.get("keywords_reaction_prompt", ""), "keywords_reaction_prompt": self.parameters.keywords_reaction_prompt
or context_data.get("keywords_reaction_prompt", ""),
"moderation_prompt": self.parameters.moderation_prompt_block or context_data.get("moderation_prompt", ""), "moderation_prompt": self.parameters.moderation_prompt_block or context_data.get("moderation_prompt", ""),
"safety_guidelines_block": self.parameters.safety_guidelines_block or context_data.get("safety_guidelines_block", ""), "safety_guidelines_block": self.parameters.safety_guidelines_block
or context_data.get("safety_guidelines_block", ""),
} }
def _prepare_default_params(self, context_data: Dict[str, Any]) -> Dict[str, Any]: def _prepare_default_params(self, context_data: Dict[str, Any]) -> Dict[str, Any]:
@@ -769,9 +760,11 @@ class Prompt:
"reason": "", "reason": "",
"mood_state": self.parameters.mood_prompt or context_data.get("mood_state", ""), "mood_state": self.parameters.mood_prompt or context_data.get("mood_state", ""),
"reply_style": global_config.personality.reply_style, "reply_style": global_config.personality.reply_style,
"keywords_reaction_prompt": self.parameters.keywords_reaction_prompt or context_data.get("keywords_reaction_prompt", ""), "keywords_reaction_prompt": self.parameters.keywords_reaction_prompt
or context_data.get("keywords_reaction_prompt", ""),
"moderation_prompt": self.parameters.moderation_prompt_block or context_data.get("moderation_prompt", ""), "moderation_prompt": self.parameters.moderation_prompt_block or context_data.get("moderation_prompt", ""),
"safety_guidelines_block": self.parameters.safety_guidelines_block or context_data.get("safety_guidelines_block", ""), "safety_guidelines_block": self.parameters.safety_guidelines_block
or context_data.get("safety_guidelines_block", ""),
} }
def format(self, *args, **kwargs) -> str: def format(self, *args, **kwargs) -> str:
@@ -872,9 +865,7 @@ class Prompt:
return await relationship_fetcher.build_relation_info(person_id, points_num=5) return await relationship_fetcher.build_relation_info(person_id, points_num=5)
@staticmethod @staticmethod
async def build_cross_context( async def build_cross_context(chat_id: str, prompt_mode: str, target_user_info: Optional[Dict[str, Any]]) -> str:
chat_id: str, prompt_mode: str, target_user_info: Optional[Dict[str, Any]]
) -> str:
""" """
构建跨群聊上下文 - 统一实现 构建跨群聊上下文 - 统一实现
@@ -937,10 +928,7 @@ class Prompt:
# 工厂函数 # 工厂函数
def create_prompt( def create_prompt(
template: str, template: str, name: Optional[str] = None, parameters: Optional[PromptParameters] = None, **kwargs
name: Optional[str] = None,
parameters: Optional[PromptParameters] = None,
**kwargs
) -> Prompt: ) -> Prompt:
"""快速创建Prompt实例的工厂函数""" """快速创建Prompt实例的工厂函数"""
if parameters is None: if parameters is None:
@@ -949,14 +937,10 @@ def create_prompt(
async def create_prompt_async( async def create_prompt_async(
template: str, template: str, name: Optional[str] = None, parameters: Optional[PromptParameters] = None, **kwargs
name: Optional[str] = None,
parameters: Optional[PromptParameters] = None,
**kwargs
) -> Prompt: ) -> Prompt:
"""异步创建Prompt实例""" """异步创建Prompt实例"""
prompt = create_prompt(template, name, parameters, **kwargs) prompt = create_prompt(template, name, parameters, **kwargs)
if global_prompt_manager._context._current_context: if global_prompt_manager._context._current_context:
await global_prompt_manager._context.register_async(prompt) await global_prompt_manager._context.register_async(prompt)
return prompt return prompt

View File

@@ -343,7 +343,7 @@ def process_llm_response(text: str, enable_splitter: bool = True, enable_chinese
if split_mode == "llm": if split_mode == "llm":
logger.debug("未检测到 [SPLIT] 标记,本次不进行分割。") logger.debug("未检测到 [SPLIT] 标记,本次不进行分割。")
split_sentences = [cleaned_text] split_sentences = [cleaned_text]
else: # mode == "punctuation" else: # mode == "punctuation"
logger.debug("使用基于标点的传统模式进行分割。") logger.debug("使用基于标点的传统模式进行分割。")
split_sentences = split_into_sentences_w_remove_punctuation(cleaned_text) split_sentences = split_into_sentences_w_remove_punctuation(cleaned_text)
else: else:

View File

@@ -6,6 +6,7 @@ class BaseDataModel:
def deepcopy(self): def deepcopy(self):
return copy.deepcopy(self) return copy.deepcopy(self)
def temporarily_transform_class_to_dict(obj: Any) -> Any: def temporarily_transform_class_to_dict(obj: Any) -> Any:
# sourcery skip: assign-if-exp, reintroduce-else # sourcery skip: assign-if-exp, reintroduce-else
""" """

View File

@@ -2,6 +2,7 @@
机器人兴趣标签数据模型 机器人兴趣标签数据模型
定义机器人的兴趣标签和相关的embedding数据结构 定义机器人的兴趣标签和相关的embedding数据结构
""" """
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import List, Dict, Optional, Any from typing import List, Dict, Optional, Any
from datetime import datetime from datetime import datetime
@@ -12,6 +13,7 @@ from . import BaseDataModel
@dataclass @dataclass
class BotInterestTag(BaseDataModel): class BotInterestTag(BaseDataModel):
"""机器人兴趣标签""" """机器人兴趣标签"""
tag_name: str tag_name: str
weight: float = 1.0 # 权重,表示对这个兴趣的喜好程度 (0.0-1.0) weight: float = 1.0 # 权重,表示对这个兴趣的喜好程度 (0.0-1.0)
embedding: Optional[List[float]] = None # 标签的embedding向量 embedding: Optional[List[float]] = None # 标签的embedding向量
@@ -27,7 +29,7 @@ class BotInterestTag(BaseDataModel):
"embedding": self.embedding, "embedding": self.embedding,
"created_at": self.created_at.isoformat(), "created_at": self.created_at.isoformat(),
"updated_at": self.updated_at.isoformat(), "updated_at": self.updated_at.isoformat(),
"is_active": self.is_active "is_active": self.is_active,
} }
@classmethod @classmethod
@@ -39,13 +41,14 @@ class BotInterestTag(BaseDataModel):
embedding=data.get("embedding"), embedding=data.get("embedding"),
created_at=datetime.fromisoformat(data["created_at"]) if data.get("created_at") else datetime.now(), created_at=datetime.fromisoformat(data["created_at"]) if data.get("created_at") else datetime.now(),
updated_at=datetime.fromisoformat(data["updated_at"]) if data.get("updated_at") else datetime.now(), updated_at=datetime.fromisoformat(data["updated_at"]) if data.get("updated_at") else datetime.now(),
is_active=data.get("is_active", True) is_active=data.get("is_active", True),
) )
@dataclass @dataclass
class BotPersonalityInterests(BaseDataModel): class BotPersonalityInterests(BaseDataModel):
"""机器人人格化兴趣配置""" """机器人人格化兴趣配置"""
personality_id: str personality_id: str
personality_description: str # 人设描述文本 personality_description: str # 人设描述文本
interest_tags: List[BotInterestTag] = field(default_factory=list) interest_tags: List[BotInterestTag] = field(default_factory=list)
@@ -57,7 +60,6 @@ class BotPersonalityInterests(BaseDataModel):
"""获取活跃的兴趣标签""" """获取活跃的兴趣标签"""
return [tag for tag in self.interest_tags if tag.is_active] return [tag for tag in self.interest_tags if tag.is_active]
def to_dict(self) -> Dict[str, Any]: def to_dict(self) -> Dict[str, Any]:
"""转换为字典格式""" """转换为字典格式"""
return { return {
@@ -66,7 +68,7 @@ class BotPersonalityInterests(BaseDataModel):
"interest_tags": [tag.to_dict() for tag in self.interest_tags], "interest_tags": [tag.to_dict() for tag in self.interest_tags],
"embedding_model": self.embedding_model, "embedding_model": self.embedding_model,
"last_updated": self.last_updated.isoformat(), "last_updated": self.last_updated.isoformat(),
"version": self.version "version": self.version,
} }
@classmethod @classmethod
@@ -78,13 +80,14 @@ class BotPersonalityInterests(BaseDataModel):
interest_tags=[BotInterestTag.from_dict(tag_data) for tag_data in data.get("interest_tags", [])], interest_tags=[BotInterestTag.from_dict(tag_data) for tag_data in data.get("interest_tags", [])],
embedding_model=data.get("embedding_model", "text-embedding-ada-002"), embedding_model=data.get("embedding_model", "text-embedding-ada-002"),
last_updated=datetime.fromisoformat(data["last_updated"]) if data.get("last_updated") else datetime.now(), last_updated=datetime.fromisoformat(data["last_updated"]) if data.get("last_updated") else datetime.now(),
version=data.get("version", 1) version=data.get("version", 1),
) )
@dataclass @dataclass
class InterestMatchResult(BaseDataModel): class InterestMatchResult(BaseDataModel):
"""兴趣匹配结果""" """兴趣匹配结果"""
message_id: str message_id: str
matched_tags: List[str] = field(default_factory=list) matched_tags: List[str] = field(default_factory=list)
match_scores: Dict[str, float] = field(default_factory=dict) # tag_name -> score match_scores: Dict[str, float] = field(default_factory=dict) # tag_name -> score
@@ -120,7 +123,9 @@ class InterestMatchResult(BaseDataModel):
# 计算置信度(基于匹配标签数量和分数分布) # 计算置信度(基于匹配标签数量和分数分布)
if len(self.match_scores) > 0: if len(self.match_scores) > 0:
avg_score = self.overall_score avg_score = self.overall_score
score_variance = sum((score - avg_score) ** 2 for score in self.match_scores.values()) / len(self.match_scores) score_variance = sum((score - avg_score) ** 2 for score in self.match_scores.values()) / len(
self.match_scores
)
# 分数越集中,置信度越高 # 分数越集中,置信度越高
self.confidence = max(0.0, 1.0 - score_variance) self.confidence = max(0.0, 1.0 - score_variance)
else: else:

View File

@@ -208,6 +208,7 @@ class DatabaseMessages(BaseDataModel):
"chat_info_user_cardname": self.chat_info.user_info.user_cardname, "chat_info_user_cardname": self.chat_info.user_info.user_cardname,
} }
@dataclass(init=False) @dataclass(init=False)
class DatabaseActionRecords(BaseDataModel): class DatabaseActionRecords(BaseDataModel):
def __init__( def __init__(

View File

@@ -28,6 +28,7 @@ class ActionPlannerInfo(BaseDataModel):
@dataclass @dataclass
class InterestScore(BaseDataModel): class InterestScore(BaseDataModel):
"""兴趣度评分结果""" """兴趣度评分结果"""
message_id: str message_id: str
total_score: float total_score: float
interest_match_score: float interest_match_score: float
@@ -41,6 +42,7 @@ class Plan(BaseDataModel):
""" """
统一规划数据模型 统一规划数据模型
""" """
chat_id: str chat_id: str
mode: "ChatMode" mode: "ChatMode"

View File

@@ -2,9 +2,11 @@ from dataclasses import dataclass
from typing import Optional, List, Tuple, TYPE_CHECKING, Any from typing import Optional, List, Tuple, TYPE_CHECKING, Any
from . import BaseDataModel from . import BaseDataModel
if TYPE_CHECKING: if TYPE_CHECKING:
from src.llm_models.payload_content.tool_option import ToolCall from src.llm_models.payload_content.tool_option import ToolCall
@dataclass @dataclass
class LLMGenerationDataModel(BaseDataModel): class LLMGenerationDataModel(BaseDataModel):
content: Optional[str] = None content: Optional[str] = None

View File

@@ -2,6 +2,7 @@
消息管理模块数据模型 消息管理模块数据模型
定义消息管理器使用的数据结构 定义消息管理器使用的数据结构
""" """
import asyncio import asyncio
import time import time
from dataclasses import dataclass, field from dataclasses import dataclass, field
@@ -16,14 +17,16 @@ if TYPE_CHECKING:
class MessageStatus(Enum): class MessageStatus(Enum):
"""消息状态枚举""" """消息状态枚举"""
UNREAD = "unread" # 未读消息
READ = "read" # 读消息 UNREAD = "unread" # 读消息
READ = "read" # 已读消息
PROCESSING = "processing" # 处理中 PROCESSING = "processing" # 处理中
@dataclass @dataclass
class StreamContext(BaseDataModel): class StreamContext(BaseDataModel):
"""聊天流上下文信息""" """聊天流上下文信息"""
stream_id: str stream_id: str
unread_messages: List["DatabaseMessages"] = field(default_factory=list) unread_messages: List["DatabaseMessages"] = field(default_factory=list)
history_messages: List["DatabaseMessages"] = field(default_factory=list) history_messages: List["DatabaseMessages"] = field(default_factory=list)
@@ -59,6 +62,7 @@ class StreamContext(BaseDataModel):
@dataclass @dataclass
class MessageManagerStats(BaseDataModel): class MessageManagerStats(BaseDataModel):
"""消息管理器统计信息""" """消息管理器统计信息"""
total_streams: int = 0 total_streams: int = 0
active_streams: int = 0 active_streams: int = 0
total_unread_messages: int = 0 total_unread_messages: int = 0
@@ -74,6 +78,7 @@ class MessageManagerStats(BaseDataModel):
@dataclass @dataclass
class StreamStats(BaseDataModel): class StreamStats(BaseDataModel):
"""聊天流统计信息""" """聊天流统计信息"""
stream_id: str stream_id: str
is_active: bool is_active: bool
unread_count: int unread_count: int

View File

@@ -31,7 +31,9 @@ class TelemetryHeartBeatTask(AsyncTask):
self.client_uuid: str | None = local_storage["mofox_uuid"] if "mofox_uuid" in local_storage else None # type: ignore self.client_uuid: str | None = local_storage["mofox_uuid"] if "mofox_uuid" in local_storage else None # type: ignore
"""客户端UUID""" """客户端UUID"""
self.private_key_pem: str | None = local_storage["mofox_private_key"] if "mofox_private_key" in local_storage else None # type: ignore self.private_key_pem: str | None = (
local_storage["mofox_private_key"] if "mofox_private_key" in local_storage else None
) # type: ignore
"""客户端私钥""" """客户端私钥"""
self.info_dict = self._get_sys_info() self.info_dict = self._get_sys_info()
@@ -75,10 +77,7 @@ class TelemetryHeartBeatTask(AsyncTask):
sign_data = f"{self.client_uuid}:{timestamp}:{json.dumps(request_body, separators=(',', ':'))}" sign_data = f"{self.client_uuid}:{timestamp}:{json.dumps(request_body, separators=(',', ':'))}"
# 加载私钥 # 加载私钥
private_key = serialization.load_pem_private_key( private_key = serialization.load_pem_private_key(self.private_key_pem.encode("utf-8"), password=None)
self.private_key_pem.encode('utf-8'),
password=None
)
# 确保是RSA私钥 # 确保是RSA私钥
if not isinstance(private_key, rsa.RSAPrivateKey): if not isinstance(private_key, rsa.RSAPrivateKey):
@@ -86,16 +85,13 @@ class TelemetryHeartBeatTask(AsyncTask):
# 生成签名 # 生成签名
signature = private_key.sign( signature = private_key.sign(
sign_data.encode('utf-8'), sign_data.encode("utf-8"),
padding.PSS( padding.PSS(mgf=padding.MGF1(hashes.SHA256()), salt_length=padding.PSS.MAX_LENGTH),
mgf=padding.MGF1(hashes.SHA256()), hashes.SHA256(),
salt_length=padding.PSS.MAX_LENGTH
),
hashes.SHA256()
) )
# Base64编码 # Base64编码
signature_b64 = base64.b64encode(signature).decode('utf-8') signature_b64 = base64.b64encode(signature).decode("utf-8")
return timestamp, signature_b64 return timestamp, signature_b64
@@ -113,10 +109,7 @@ class TelemetryHeartBeatTask(AsyncTask):
raise ValueError("私钥未初始化") raise ValueError("私钥未初始化")
# 加载私钥 # 加载私钥
private_key = serialization.load_pem_private_key( private_key = serialization.load_pem_private_key(self.private_key_pem.encode("utf-8"), password=None)
self.private_key_pem.encode('utf-8'),
password=None
)
# 确保是RSA私钥 # 确保是RSA私钥
if not isinstance(private_key, rsa.RSAPrivateKey): if not isinstance(private_key, rsa.RSAPrivateKey):
@@ -125,14 +118,10 @@ class TelemetryHeartBeatTask(AsyncTask):
# 解密挑战数据 # 解密挑战数据
decrypted_bytes = private_key.decrypt( decrypted_bytes = private_key.decrypt(
base64.b64decode(challenge_b64), base64.b64decode(challenge_b64),
padding.OAEP( padding.OAEP(mgf=padding.MGF1(hashes.SHA256()), algorithm=hashes.SHA256(), label=None),
mgf=padding.MGF1(hashes.SHA256()),
algorithm=hashes.SHA256(),
label=None
)
) )
return decrypted_bytes.decode('utf-8') return decrypted_bytes.decode("utf-8")
async def _req_uuid(self) -> bool: async def _req_uuid(self) -> bool:
""" """
@@ -155,14 +144,12 @@ class TelemetryHeartBeatTask(AsyncTask):
if response.status != 200: if response.status != 200:
response_text = await response.text() response_text = await response.text()
logger.error( logger.error(f"注册步骤1失败状态码: {response.status}, 响应内容: {response_text}")
f"注册步骤1失败状态码: {response.status}, 响应内容: {response_text}"
)
raise aiohttp.ClientResponseError( raise aiohttp.ClientResponseError(
request_info=response.request_info, request_info=response.request_info,
history=response.history, history=response.history,
status=response.status, status=response.status,
message=f"Step1 failed: {response_text}" message=f"Step1 failed: {response_text}",
) )
step1_data = await response.json() step1_data = await response.json()
@@ -195,10 +182,7 @@ class TelemetryHeartBeatTask(AsyncTask):
# Step 2: 发送解密结果完成注册 # Step 2: 发送解密结果完成注册
async with session.post( async with session.post(
f"{TELEMETRY_SERVER_URL}/stat/reg_client_step2", f"{TELEMETRY_SERVER_URL}/stat/reg_client_step2",
json={ json={"temp_uuid": temp_uuid, "decrypted_uuid": decrypted_uuid},
"temp_uuid": temp_uuid,
"decrypted_uuid": decrypted_uuid
},
timeout=aiohttp.ClientTimeout(total=5), timeout=aiohttp.ClientTimeout(total=5),
) as response: ) as response:
logger.debug(f"Step2 Response status: {response.status}") logger.debug(f"Step2 Response status: {response.status}")
@@ -225,23 +209,19 @@ class TelemetryHeartBeatTask(AsyncTask):
raise ValueError(f"Step2失败: {response_text}") raise ValueError(f"Step2失败: {response_text}")
else: else:
response_text = await response.text() response_text = await response.text()
logger.error( logger.error(f"注册步骤2失败状态码: {response.status}, 响应内容: {response_text}")
f"注册步骤2失败状态码: {response.status}, 响应内容: {response_text}"
)
raise aiohttp.ClientResponseError( raise aiohttp.ClientResponseError(
request_info=response.request_info, request_info=response.request_info,
history=response.history, history=response.history,
status=response.status, status=response.status,
message=f"Step2 failed: {response_text}" message=f"Step2 failed: {response_text}",
) )
except Exception as e: except Exception as e:
import traceback import traceback
error_msg = str(e) or "未知错误" error_msg = str(e) or "未知错误"
logger.warning( logger.warning(f"注册客户端出错,不过你还是可以正常使用墨狐: {type(e).__name__}: {error_msg}")
f"注册客户端出错,不过你还是可以正常使用墨狐: {type(e).__name__}: {error_msg}"
)
logger.debug(f"完整错误信息: {traceback.format_exc()}") logger.debug(f"完整错误信息: {traceback.format_exc()}")
# 请求失败,重试次数+1 # 请求失败,重试次数+1
@@ -270,7 +250,7 @@ class TelemetryHeartBeatTask(AsyncTask):
"X-mofox-Signature": signature, "X-mofox-Signature": signature,
"X-mofox-Timestamp": timestamp, "X-mofox-Timestamp": timestamp,
"User-Agent": f"MofoxClient/{self.client_uuid[:8]}", "User-Agent": f"MofoxClient/{self.client_uuid[:8]}",
"Content-Type": "application/json" "Content-Type": "application/json",
} }
logger.debug(f"正在发送心跳到服务器: {self.server_url}") logger.debug(f"正在发送心跳到服务器: {self.server_url}")

View File

@@ -99,7 +99,6 @@ def get_global_server() -> Server:
"""获取全局服务器实例""" """获取全局服务器实例"""
global global_server global global_server
if global_server is None: if global_server is None:
host = os.getenv("HOST", "127.0.0.1") host = os.getenv("HOST", "127.0.0.1")
port_str = os.getenv("PORT", "8000") port_str = os.getenv("PORT", "8000")

View File

@@ -44,7 +44,7 @@ from src.config.official_configs import (
PermissionConfig, PermissionConfig,
CommandConfig, CommandConfig,
PlanningSystemConfig, PlanningSystemConfig,
AffinityFlowConfig AffinityFlowConfig,
) )
from .api_ada_configs import ( from .api_ada_configs import (
@@ -399,9 +399,7 @@ class Config(ValidatedConfigBase):
cross_context: CrossContextConfig = Field( cross_context: CrossContextConfig = Field(
default_factory=lambda: CrossContextConfig(), description="跨群聊上下文共享配置" default_factory=lambda: CrossContextConfig(), description="跨群聊上下文共享配置"
) )
affinity_flow: AffinityFlowConfig = Field( affinity_flow: AffinityFlowConfig = Field(default_factory=lambda: AffinityFlowConfig(), description="亲和流配置")
default_factory=lambda: AffinityFlowConfig(), description="亲和流配置"
)
class APIAdapterConfig(ValidatedConfigBase): class APIAdapterConfig(ValidatedConfigBase):

View File

@@ -51,8 +51,12 @@ class PersonalityConfig(ValidatedConfigBase):
personality_core: str = Field(..., description="核心人格") personality_core: str = Field(..., description="核心人格")
personality_side: str = Field(..., description="人格侧写") personality_side: str = Field(..., description="人格侧写")
identity: str = Field(default="", description="身份特征") identity: str = Field(default="", description="身份特征")
background_story: str = Field(default="", description="世界观背景故事这部分内容会作为背景知识LLM被指导不应主动复述") background_story: str = Field(
safety_guidelines: List[str] = Field(default_factory=list, description="安全与互动底线Bot在任何情况下都必须遵守的原则") default="", description="世界观背景故事这部分内容会作为背景知识LLM被指导不应主动复述"
)
safety_guidelines: List[str] = Field(
default_factory=list, description="安全与互动底线Bot在任何情况下都必须遵守的原则"
)
reply_style: str = Field(default="", description="表达风格") reply_style: str = Field(default="", description="表达风格")
prompt_mode: Literal["s4u", "normal"] = Field(default="s4u", description="Prompt模式") prompt_mode: Literal["s4u", "normal"] = Field(default="s4u", description="Prompt模式")
compress_personality: bool = Field(default=True, description="是否压缩人格") compress_personality: bool = Field(default=True, description="是否压缩人格")
@@ -79,7 +83,8 @@ class ChatConfig(ValidatedConfigBase):
talk_frequency_adjust: list[list[str]] = Field(default_factory=lambda: [], description="聊天频率调整") talk_frequency_adjust: list[list[str]] = Field(default_factory=lambda: [], description="聊天频率调整")
focus_value: float = Field(default=1.0, description="专注值") focus_value: float = Field(default=1.0, description="专注值")
focus_mode_quiet_groups: List[str] = Field( focus_mode_quiet_groups: List[str] = Field(
default_factory=list, description='专注模式下需要保持安静的群组列表, 格式: ["platform:group_id1", "platform:group_id2"]' default_factory=list,
description='专注模式下需要保持安静的群组列表, 格式: ["platform:group_id1", "platform:group_id2"]',
) )
force_reply_private: bool = Field(default=False, description="强制回复私聊") force_reply_private: bool = Field(default=False, description="强制回复私聊")
group_chat_mode: Literal["auto", "normal", "focus"] = Field(default="auto", description="群聊模式") group_chat_mode: Literal["auto", "normal", "focus"] = Field(default="auto", description="群聊模式")
@@ -343,6 +348,7 @@ class ExpressionConfig(ValidatedConfigBase):
# 如果都没有匹配,返回默认值 # 如果都没有匹配,返回默认值
return True, True, 1.0 return True, True, 1.0
class ToolConfig(ValidatedConfigBase): class ToolConfig(ValidatedConfigBase):
"""工具配置类""" """工具配置类"""
@@ -477,7 +483,6 @@ class ExperimentalConfig(ValidatedConfigBase):
pfc_chatting: bool = Field(default=False, description="启用PFC聊天") pfc_chatting: bool = Field(default=False, description="启用PFC聊天")
class MaimMessageConfig(ValidatedConfigBase): class MaimMessageConfig(ValidatedConfigBase):
"""maim_message配置类""" """maim_message配置类"""
@@ -602,8 +607,12 @@ class SleepSystemConfig(ValidatedConfigBase):
sleep_by_schedule: bool = Field(default=True, description="是否根据日程表进行睡觉") sleep_by_schedule: bool = Field(default=True, description="是否根据日程表进行睡觉")
fixed_sleep_time: str = Field(default="23:00", description="固定的睡觉时间") fixed_sleep_time: str = Field(default="23:00", description="固定的睡觉时间")
fixed_wake_up_time: str = Field(default="07:00", description="固定的起床时间") fixed_wake_up_time: str = Field(default="07:00", description="固定的起床时间")
sleep_time_offset_minutes: int = Field(default=15, ge=0, le=60, description="睡觉时间随机偏移量范围(分钟),实际睡觉时间会在±该值范围内随机") sleep_time_offset_minutes: int = Field(
wake_up_time_offset_minutes: int = Field(default=15, ge=0, le=60, description="起床时间随机偏移量范围(分钟),实际起床时间会在±该值范围内随机") default=15, ge=0, le=60, description="睡觉时间随机偏移量范围(分钟),实际睡觉时间会在±该值范围内随机"
)
wake_up_time_offset_minutes: int = Field(
default=15, ge=0, le=60, description="起床时间随机偏移量范围(分钟),实际起床时间会在±该值范围内随机"
)
wakeup_threshold: float = Field(default=15.0, ge=1.0, description="唤醒阈值,达到此值时会被唤醒") wakeup_threshold: float = Field(default=15.0, ge=1.0, description="唤醒阈值,达到此值时会被唤醒")
private_message_increment: float = Field(default=3.0, ge=0.1, description="私聊消息增加的唤醒度") private_message_increment: float = Field(default=3.0, ge=0.1, description="私聊消息增加的唤醒度")
group_mention_increment: float = Field(default=2.0, ge=0.1, description="群聊艾特增加的唤醒度") group_mention_increment: float = Field(default=2.0, ge=0.1, description="群聊艾特增加的唤醒度")
@@ -618,10 +627,10 @@ class SleepSystemConfig(ValidatedConfigBase):
# --- 失眠机制相关参数 --- # --- 失眠机制相关参数 ---
enable_insomnia_system: bool = Field(default=True, description="是否启用失眠系统") enable_insomnia_system: bool = Field(default=True, description="是否启用失眠系统")
insomnia_trigger_delay_minutes: List[int] = Field( insomnia_trigger_delay_minutes: List[int] = Field(
default_factory=lambda:[30, 60], description="入睡后触发失眠判定的延迟时间范围(分钟)" default_factory=lambda: [30, 60], description="入睡后触发失眠判定的延迟时间范围(分钟)"
) )
insomnia_duration_minutes: List[int] = Field( insomnia_duration_minutes: List[int] = Field(
default_factory=lambda:[15, 45], description="单次失眠状态的持续时间范围(分钟)" default_factory=lambda: [15, 45], description="单次失眠状态的持续时间范围(分钟)"
) )
sleep_pressure_threshold: float = Field(default=30.0, description="触发“压力不足型失眠”的睡眠压力阈值") sleep_pressure_threshold: float = Field(default=30.0, description="触发“压力不足型失眠”的睡眠压力阈值")
deep_sleep_threshold: float = Field(default=80.0, description="进入“深度睡眠”的睡眠压力阈值") deep_sleep_threshold: float = Field(default=80.0, description="进入“深度睡眠”的睡眠压力阈值")
@@ -657,6 +666,8 @@ class CrossContextConfig(ValidatedConfigBase):
enable: bool = Field(default=False, description="是否启用跨群聊上下文共享功能") enable: bool = Field(default=False, description="是否启用跨群聊上下文共享功能")
groups: List[ContextGroup] = Field(default_factory=list, description="上下文共享组列表") groups: List[ContextGroup] = Field(default_factory=list, description="上下文共享组列表")
class CommandConfig(ValidatedConfigBase): class CommandConfig(ValidatedConfigBase):
"""命令系统配置类""" """命令系统配置类"""

View File

@@ -88,8 +88,7 @@ class Individuality:
# 初始化智能兴趣系统 # 初始化智能兴趣系统
await interest_scoring_system.initialize_smart_interests( await interest_scoring_system.initialize_smart_interests(
personality_description=full_personality, personality_description=full_personality, personality_id=self.bot_person_id
personality_id=self.bot_person_id
) )
logger.info("智能兴趣系统初始化完成") logger.info("智能兴趣系统初始化完成")

View File

@@ -130,7 +130,8 @@ class MainSystem:
# 停止消息重组器 # 停止消息重组器
from src.plugin_system.core.event_manager import event_manager from src.plugin_system.core.event_manager import event_manager
from src.plugin_system import EventType from src.plugin_system import EventType
asyncio.run(event_manager.trigger_event(EventType.ON_STOP,permission_group="SYSTEM"))
asyncio.run(event_manager.trigger_event(EventType.ON_STOP, permission_group="SYSTEM"))
from src.utils.message_chunker import reassembler from src.utils.message_chunker import reassembler
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
@@ -250,6 +251,7 @@ MoFox_Bot(第三方修改版)
# 初始化回复后关系追踪系统 # 初始化回复后关系追踪系统
from src.chat.affinity_flow.relationship_integration import initialize_relationship_tracking from src.chat.affinity_flow.relationship_integration import initialize_relationship_tracking
relationship_tracker = initialize_relationship_tracking() relationship_tracker = initialize_relationship_tracking()
if relationship_tracker: if relationship_tracker:
logger.info("回复后关系追踪系统初始化成功") logger.info("回复后关系追踪系统初始化成功")
@@ -273,6 +275,7 @@ MoFox_Bot(第三方修改版)
# 初始化LPMM知识库 # 初始化LPMM知识库
from src.chat.knowledge.knowledge_lib import initialize_lpmm_knowledge from src.chat.knowledge.knowledge_lib import initialize_lpmm_knowledge
initialize_lpmm_knowledge() initialize_lpmm_knowledge()
logger.info("LPMM知识库初始化成功") logger.info("LPMM知识库初始化成功")
@@ -298,6 +301,7 @@ MoFox_Bot(第三方修改版)
# 启动消息管理器 # 启动消息管理器
from src.chat.message_manager import message_manager from src.chat.message_manager import message_manager
await message_manager.start() await message_manager.start()
logger.info("消息管理器已启动") logger.info("消息管理器已启动")

View File

@@ -102,6 +102,7 @@ class PersonInfoManager:
return hashlib.md5(key.encode()).hexdigest() return hashlib.md5(key.encode()).hexdigest()
qq_id = hashlib.md5(key.encode()).hexdigest() qq_id = hashlib.md5(key.encode()).hexdigest()
# 对于 qq 平台,先检查该 person_id 是否已存在;如果存在直接返回 # 对于 qq 平台,先检查该 person_id 是否已存在;如果存在直接返回
def _db_check_and_migrate_sync(p_id: str, raw_user_id: str): def _db_check_and_migrate_sync(p_id: str, raw_user_id: str):
try: try:

View File

@@ -123,7 +123,9 @@ class RelationshipFetcher:
all_points = current_points + forgotten_points all_points = current_points + forgotten_points
if all_points: if all_points:
# 按权重和时效性综合排序 # 按权重和时效性综合排序
all_points.sort(key=lambda x: (float(x[1]) if len(x) > 1 else 0, float(x[2]) if len(x) > 2 else 0), reverse=True) all_points.sort(
key=lambda x: (float(x[1]) if len(x) > 1 else 0, float(x[2]) if len(x) > 2 else 0), reverse=True
)
selected_points = all_points[:points_num] selected_points = all_points[:points_num]
points_text = "\n".join([f"- {point[0]}{point[2]}" for point in selected_points if len(point) > 2]) points_text = "\n".join([f"- {point[0]}{point[2]}" for point in selected_points if len(point) > 2])
else: else:
@@ -139,7 +141,8 @@ class RelationshipFetcher:
# 2. 认识时间和频率 # 2. 认识时间和频率
if know_since: if know_since:
from datetime import datetime from datetime import datetime
know_time = datetime.fromtimestamp(know_since).strftime('%Y年%m月%d')
know_time = datetime.fromtimestamp(know_since).strftime("%Y年%m月%d")
relation_parts.append(f"你从{know_time}开始认识{person_name}") relation_parts.append(f"你从{know_time}开始认识{person_name}")
if know_times > 0: if know_times > 0:
@@ -147,7 +150,8 @@ class RelationshipFetcher:
if last_know: if last_know:
from datetime import datetime from datetime import datetime
last_time = datetime.fromtimestamp(last_know).strftime('%m月%d')
last_time = datetime.fromtimestamp(last_know).strftime("%m月%d")
relation_parts.append(f"最近一次交流是在{last_time}") relation_parts.append(f"最近一次交流是在{last_time}")
# 3. 态度和印象 # 3. 态度和印象
@@ -173,7 +177,7 @@ class RelationshipFetcher:
relationships = await db_query( relationships = await db_query(
UserRelationships, UserRelationships,
filters=[UserRelationships.user_id == str(person_info_manager.get_value_sync(person_id, "user_id"))], filters=[UserRelationships.user_id == str(person_info_manager.get_value_sync(person_id, "user_id"))],
limit=1 limit=1,
) )
if relationships: if relationships:
@@ -189,7 +193,9 @@ class RelationshipFetcher:
# 构建最终的关系信息字符串 # 构建最终的关系信息字符串
if relation_parts: if relation_parts:
relation_info = f"关于{person_name},你知道以下信息:\n" + "\n".join([f"{part}" for part in relation_parts]) relation_info = f"关于{person_name},你知道以下信息:\n" + "\n".join(
[f"{part}" for part in relation_parts]
)
else: else:
relation_info = f"你对{person_name}了解不多,这是比较初步的交流。" relation_info = f"你对{person_name}了解不多,这是比较初步的交流。"

View File

@@ -93,7 +93,6 @@ class BaseAction(ABC):
self.associated_types: list[str] = getattr(self.__class__, "associated_types", []).copy() self.associated_types: list[str] = getattr(self.__class__, "associated_types", []).copy()
self.chat_type_allow: ChatType = getattr(self.__class__, "chat_type_allow", ChatType.ALL) self.chat_type_allow: ChatType = getattr(self.__class__, "chat_type_allow", ChatType.ALL)
# ============================================================================= # =============================================================================
# 便捷属性 - 直接在初始化时获取常用聊天信息(带类型注解) # 便捷属性 - 直接在初始化时获取常用聊天信息(带类型注解)
# ============================================================================= # =============================================================================
@@ -398,6 +397,7 @@ class BaseAction(ABC):
try: try:
# 1. 从注册中心获取Action类 # 1. 从注册中心获取Action类
from src.plugin_system.core.component_registry import component_registry from src.plugin_system.core.component_registry import component_registry
action_class = component_registry.get_component_class(action_name, ComponentType.ACTION) action_class = component_registry.get_component_class(action_name, ComponentType.ACTION)
if not action_class: if not action_class:
logger.error(f"{log_prefix} 未找到Action: {action_name}") logger.error(f"{log_prefix} 未找到Action: {action_name}")

View File

@@ -107,7 +107,7 @@ class BaseEventHandler(ABC):
""" """
self.plugin_name = plugin_name self.plugin_name = plugin_name
def set_plugin_config(self,plugin_config) -> None: def set_plugin_config(self, plugin_config) -> None:
self.plugin_config = plugin_config self.plugin_config = plugin_config
def get_config(self, key: str, default=None): def get_config(self, key: str, default=None):

View File

@@ -69,7 +69,7 @@ class EventType(Enum):
""" """
ON_START = "on_start" # 启动事件,用于调用按时任务 ON_START = "on_start" # 启动事件,用于调用按时任务
ON_STOP ="on_stop" ON_STOP = "on_stop"
ON_MESSAGE = "on_message" ON_MESSAGE = "on_message"
ON_PLAN = "on_plan" ON_PLAN = "on_plan"
POST_LLM = "post_llm" POST_LLM = "post_llm"

View File

@@ -270,7 +270,9 @@ class ComponentRegistry:
# 使用EventManager进行事件处理器注册 # 使用EventManager进行事件处理器注册
from src.plugin_system.core.event_manager import event_manager from src.plugin_system.core.event_manager import event_manager
return event_manager.register_event_handler(handler_class,self.get_plugin_config(handler_info.plugin_name) or {}) return event_manager.register_event_handler(
handler_class, self.get_plugin_config(handler_info.plugin_name) or {}
)
# === 组件移除相关 === # === 组件移除相关 ===
@@ -686,9 +688,10 @@ class ComponentRegistry:
# 如果插件实例不存在,尝试从配置文件读取 # 如果插件实例不存在,尝试从配置文件读取
try: try:
import toml import toml
config_path = Path("config") / "plugins" / plugin_name / "config.toml" config_path = Path("config") / "plugins" / plugin_name / "config.toml"
if config_path.exists(): if config_path.exists():
with open(config_path, 'r', encoding='utf-8') as f: with open(config_path, "r", encoding="utf-8") as f:
config_data = toml.load(f) config_data = toml.load(f)
logger.debug(f"从配置文件读取插件 {plugin_name} 的配置") logger.debug(f"从配置文件读取插件 {plugin_name} 的配置")
return config_data return config_data

View File

@@ -145,7 +145,9 @@ class EventManager:
logger.info(f"事件 {event_name} 已禁用") logger.info(f"事件 {event_name} 已禁用")
return True return True
def register_event_handler(self, handler_class: Type[BaseEventHandler], plugin_config: Optional[dict] = None) -> bool: def register_event_handler(
self, handler_class: Type[BaseEventHandler], plugin_config: Optional[dict] = None
) -> bool:
"""注册事件处理器 """注册事件处理器
Args: Args:
@@ -167,7 +169,7 @@ class EventManager:
# 创建事件处理器实例,传递插件配置 # 创建事件处理器实例,传递插件配置
handler_instance = handler_class() handler_instance = handler_class()
handler_instance.plugin_config = plugin_config handler_instance.plugin_config = plugin_config
if plugin_config is not None and hasattr(handler_instance, 'set_plugin_config'): if plugin_config is not None and hasattr(handler_instance, "set_plugin_config"):
handler_instance.set_plugin_config(plugin_config) handler_instance.set_plugin_config(plugin_config)
self._event_handlers[handler_name] = handler_instance self._event_handlers[handler_name] = handler_instance

View File

@@ -199,9 +199,7 @@ class PluginManager:
self._show_plugin_components(plugin_name) self._show_plugin_components(plugin_name)
# 检查并调用 on_plugin_loaded 钩子(如果存在) # 检查并调用 on_plugin_loaded 钩子(如果存在)
if hasattr(plugin_instance, "on_plugin_loaded") and callable( if hasattr(plugin_instance, "on_plugin_loaded") and callable(plugin_instance.on_plugin_loaded):
plugin_instance.on_plugin_loaded
):
logger.debug(f"为插件 '{plugin_name}' 调用 on_plugin_loaded 钩子") logger.debug(f"为插件 '{plugin_name}' 调用 on_plugin_loaded 钩子")
try: try:
# 使用 asyncio.create_task 确保它不会阻塞加载流程 # 使用 asyncio.create_task 确保它不会阻塞加载流程

View File

@@ -85,7 +85,7 @@ class AtAction(BaseAction):
reply_to=reply_to, reply_to=reply_to,
extra_info=extra_info, extra_info=extra_info,
enable_tool=False, # 艾特回复通常不需要工具调用 enable_tool=False, # 艾特回复通常不需要工具调用
from_plugin=False from_plugin=False,
) )
if success and llm_response: if success and llm_response:

View File

@@ -70,7 +70,9 @@ class EmojiAction(BaseAction):
# 2. 获取所有有效的表情包对象 # 2. 获取所有有效的表情包对象
emoji_manager = get_emoji_manager() emoji_manager = get_emoji_manager()
all_emojis_obj: list[MaiEmoji] = [e for e in emoji_manager.emoji_objects if not e.is_deleted and e.description] all_emojis_obj: list[MaiEmoji] = [
e for e in emoji_manager.emoji_objects if not e.is_deleted and e.description
]
if not all_emojis_obj: if not all_emojis_obj:
logger.warning(f"{self.log_prefix} 无法获取任何带有描述的有效表情包") logger.warning(f"{self.log_prefix} 无法获取任何带有描述的有效表情包")
return False, "无法获取任何带有描述的有效表情包" return False, "无法获取任何带有描述的有效表情包"
@@ -171,7 +173,9 @@ class EmojiAction(BaseAction):
if matched_key: if matched_key:
emoji_base64, emoji_description = random.choice(emotion_map[matched_key]) emoji_base64, emoji_description = random.choice(emotion_map[matched_key])
logger.info(f"{self.log_prefix} 找到匹配情感 '{chosen_emotion}' (匹配到: '{matched_key}') 的表情包: {emoji_description}") logger.info(
f"{self.log_prefix} 找到匹配情感 '{chosen_emotion}' (匹配到: '{matched_key}') 的表情包: {emoji_description}"
)
else: else:
logger.warning( logger.warning(
f"{self.log_prefix} LLM选择的情感 '{chosen_emotion}' 不在可用列表中, 将随机选择一个表情包" f"{self.log_prefix} LLM选择的情感 '{chosen_emotion}' 不在可用列表中, 将随机选择一个表情包"
@@ -226,15 +230,23 @@ class EmojiAction(BaseAction):
logger.info(f"{self.log_prefix} LLM选择的描述: {chosen_description}") logger.info(f"{self.log_prefix} LLM选择的描述: {chosen_description}")
# 简单关键词匹配 # 简单关键词匹配
matched_emoji = next((item for item in all_emojis_data if chosen_description.lower() in item[1].lower() or item[1].lower() in chosen_description.lower()), None) matched_emoji = next(
(
item
for item in all_emojis_data
if chosen_description.lower() in item[1].lower()
or item[1].lower() in chosen_description.lower()
),
None,
)
# 如果包含匹配失败,尝试关键词匹配 # 如果包含匹配失败,尝试关键词匹配
if not matched_emoji: if not matched_emoji:
keywords = ['惊讶', '困惑', '呆滞', '震惊', '', '无语', '', '可爱'] keywords = ["惊讶", "困惑", "呆滞", "震惊", "", "无语", "", "可爱"]
for keyword in keywords: for keyword in keywords:
if keyword in chosen_description: if keyword in chosen_description:
for item in all_emojis_data: for item in all_emojis_data:
if any(k in item[1] for k in ['', '', '', '困惑', '无语']): if any(k in item[1] for k in ["", "", "", "困惑", "无语"]):
matched_emoji = item matched_emoji = item
break break
if matched_emoji: if matched_emoji:
@@ -255,7 +267,9 @@ class EmojiAction(BaseAction):
if not success: if not success:
logger.error(f"{self.log_prefix} 表情包发送失败") logger.error(f"{self.log_prefix} 表情包发送失败")
await self.store_action_info(action_build_into_prompt = True,action_prompt_display =f"发送了一个表情包,但失败了",action_done= False) await self.store_action_info(
action_build_into_prompt=True, action_prompt_display=f"发送了一个表情包,但失败了", action_done=False
)
return False, "表情包发送失败" return False, "表情包发送失败"
# 发送成功后,记录到历史 # 发送成功后,记录到历史
@@ -264,7 +278,9 @@ class EmojiAction(BaseAction):
except Exception as e: except Exception as e:
logger.error(f"{self.log_prefix} 添加表情到历史记录时出错: {e}") logger.error(f"{self.log_prefix} 添加表情到历史记录时出错: {e}")
await self.store_action_info(action_build_into_prompt = True,action_prompt_display =f"发送了一个表情包",action_done= True) await self.store_action_info(
action_build_into_prompt=True, action_prompt_display=f"发送了一个表情包", action_done=True
)
return True, f"发送表情包: {emoji_description}" return True, f"发送表情包: {emoji_description}"

View File

@@ -1,4 +1,3 @@
from src.plugin_system import BaseEventHandler from src.plugin_system import BaseEventHandler
from src.plugin_system.base.base_event import HandlerResult from src.plugin_system.base.base_event import HandlerResult
@@ -1748,6 +1747,7 @@ class SetGroupSignHandler(BaseEventHandler):
logger.error("事件 napcat_set_group_sign 请求失败!") logger.error("事件 napcat_set_group_sign 请求失败!")
return HandlerResult(False, False, {"status": "error"}) return HandlerResult(False, False, {"status": "error"})
# ===PERSONAL=== # ===PERSONAL===
class SetInputStatusHandler(BaseEventHandler): class SetInputStatusHandler(BaseEventHandler):
handler_name: str = "napcat_set_input_status_handler" handler_name: str = "napcat_set_input_status_handler"

View File

@@ -285,7 +285,7 @@ class NapcatAdapterPlugin(BasePlugin):
def enable_plugin(self) -> bool: def enable_plugin(self) -> bool:
"""通过配置文件动态控制插件启用状态""" """通过配置文件动态控制插件启用状态"""
# 如果已经通过配置加载了状态,使用配置中的值 # 如果已经通过配置加载了状态,使用配置中的值
if hasattr(self, '_is_enabled'): if hasattr(self, "_is_enabled"):
return self._is_enabled return self._is_enabled
# 否则使用默认值(禁用状态) # 否则使用默认值(禁用状态)
return False return False
@@ -308,60 +308,107 @@ class NapcatAdapterPlugin(BasePlugin):
"nickname": ConfigField(type=str, default="", description="昵称配置(目前未使用)"), "nickname": ConfigField(type=str, default="", description="昵称配置(目前未使用)"),
}, },
"napcat_server": { "napcat_server": {
"mode": ConfigField(type=str, default="reverse", description="连接模式reverse=反向连接(作为服务器), forward=正向连接(作为客户端)", choices=["reverse", "forward"]), "mode": ConfigField(
type=str,
default="reverse",
description="连接模式reverse=反向连接(作为服务器), forward=正向连接(作为客户端)",
choices=["reverse", "forward"],
),
"host": ConfigField(type=str, default="localhost", description="主机地址"), "host": ConfigField(type=str, default="localhost", description="主机地址"),
"port": ConfigField(type=int, default=8095, description="端口号"), "port": ConfigField(type=int, default=8095, description="端口号"),
"url": ConfigField(type=str, default="", description="正向连接时的完整WebSocket URL如 ws://localhost:8080/ws (仅在forward模式下使用)"), "url": ConfigField(
"access_token": ConfigField(type=str, default="", description="WebSocket 连接的访问令牌,用于身份验证(可选)"), type=str,
default="",
description="正向连接时的完整WebSocket URL如 ws://localhost:8080/ws (仅在forward模式下使用)",
),
"access_token": ConfigField(
type=str, default="", description="WebSocket 连接的访问令牌,用于身份验证(可选)"
),
"heartbeat_interval": ConfigField(type=int, default=30, description="心跳间隔时间(按秒计)"), "heartbeat_interval": ConfigField(type=int, default=30, description="心跳间隔时间(按秒计)"),
}, },
"maibot_server": { "maibot_server": {
"host": ConfigField(type=str, default="localhost", description="麦麦在.env文件中设置的主机地址即HOST字段"), "host": ConfigField(
type=str, default="localhost", description="麦麦在.env文件中设置的主机地址即HOST字段"
),
"port": ConfigField(type=int, default=8000, description="麦麦在.env文件中设置的端口即PORT字段"), "port": ConfigField(type=int, default=8000, description="麦麦在.env文件中设置的端口即PORT字段"),
"platform_name": ConfigField(type=str, default="qq", description="平台名称,用于消息路由"), "platform_name": ConfigField(type=str, default="qq", description="平台名称,用于消息路由"),
}, },
"voice": { "voice": {
"use_tts": ConfigField(type=bool, default=False, description="是否使用tts语音请确保你配置了tts并有对应的adapter"), "use_tts": ConfigField(
type=bool, default=False, description="是否使用tts语音请确保你配置了tts并有对应的adapter"
),
}, },
"slicing": { "slicing": {
"max_frame_size": ConfigField(type=int, default=64, description="WebSocket帧的最大大小单位为字节默认64KB"), "max_frame_size": ConfigField(
type=int, default=64, description="WebSocket帧的最大大小单位为字节默认64KB"
),
"delay_ms": ConfigField(type=int, default=10, description="切片发送间隔时间,单位为毫秒"), "delay_ms": ConfigField(type=int, default=10, description="切片发送间隔时间,单位为毫秒"),
}, },
"debug": { "debug": {
"level": ConfigField(type=str, default="INFO", description="日志等级DEBUG, INFO, WARNING, ERROR, CRITICAL", choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]), "level": ConfigField(
type=str,
default="INFO",
description="日志等级DEBUG, INFO, WARNING, ERROR, CRITICAL",
choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
),
}, },
"features": { "features": {
# 权限设置 # 权限设置
"group_list_type": ConfigField(type=str, default="blacklist", description="群聊列表类型whitelist白名单或 blacklist黑名单", choices=["whitelist", "blacklist"]), "group_list_type": ConfigField(
type=str,
default="blacklist",
description="群聊列表类型whitelist白名单或 blacklist黑名单",
choices=["whitelist", "blacklist"],
),
"group_list": ConfigField(type=list, default=[], description="群聊ID列表"), "group_list": ConfigField(type=list, default=[], description="群聊ID列表"),
"private_list_type": ConfigField(type=str, default="blacklist", description="私聊列表类型whitelist白名单或 blacklist黑名单", choices=["whitelist", "blacklist"]), "private_list_type": ConfigField(
type=str,
default="blacklist",
description="私聊列表类型whitelist白名单或 blacklist黑名单",
choices=["whitelist", "blacklist"],
),
"private_list": ConfigField(type=list, default=[], description="用户ID列表"), "private_list": ConfigField(type=list, default=[], description="用户ID列表"),
"ban_user_id": ConfigField(type=list, default=[], description="全局禁止用户ID列表这些用户无法在任何地方使用机器人"), "ban_user_id": ConfigField(
type=list, default=[], description="全局禁止用户ID列表这些用户无法在任何地方使用机器人"
),
"ban_qq_bot": ConfigField(type=bool, default=False, description="是否屏蔽QQ官方机器人消息"), "ban_qq_bot": ConfigField(type=bool, default=False, description="是否屏蔽QQ官方机器人消息"),
# 聊天功能设置 # 聊天功能设置
"enable_poke": ConfigField(type=bool, default=True, description="是否启用戳一戳功能"), "enable_poke": ConfigField(type=bool, default=True, description="是否启用戳一戳功能"),
"ignore_non_self_poke": ConfigField(type=bool, default=False, description="是否无视不是针对自己的戳一戳"), "ignore_non_self_poke": ConfigField(type=bool, default=False, description="是否无视不是针对自己的戳一戳"),
"poke_debounce_seconds": ConfigField(type=int, default=3, description="戳一戳防抖时间(秒),在指定时间内第二次针对机器人的戳一戳将被忽略"), "poke_debounce_seconds": ConfigField(
type=int, default=3, description="戳一戳防抖时间(秒),在指定时间内第二次针对机器人的戳一戳将被忽略"
),
"enable_reply_at": ConfigField(type=bool, default=True, description="是否启用引用回复时艾特用户的功能"), "enable_reply_at": ConfigField(type=bool, default=True, description="是否启用引用回复时艾特用户的功能"),
"reply_at_rate": ConfigField(type=float, default=0.5, description="引用回复时艾特用户的几率 (0.0 ~ 1.0)"), "reply_at_rate": ConfigField(type=float, default=0.5, description="引用回复时艾特用户的几率 (0.0 ~ 1.0)"),
"enable_emoji_like": ConfigField(type=bool, default=True, description="是否启用群聊表情回复功能"), "enable_emoji_like": ConfigField(type=bool, default=True, description="是否启用群聊表情回复功能"),
# 视频处理设置 # 视频处理设置
"enable_video_analysis": ConfigField(type=bool, default=True, description="是否启用视频识别功能"), "enable_video_analysis": ConfigField(type=bool, default=True, description="是否启用视频识别功能"),
"max_video_size_mb": ConfigField(type=int, default=100, description="视频文件最大大小限制MB"), "max_video_size_mb": ConfigField(type=int, default=100, description="视频文件最大大小限制MB"),
"download_timeout": ConfigField(type=int, default=60, description="视频下载超时时间(秒)"), "download_timeout": ConfigField(type=int, default=60, description="视频下载超时时间(秒)"),
"supported_formats": ConfigField(type=list, default=["mp4", "avi", "mov", "mkv", "flv", "wmv", "webm"], description="支持的视频格式"), "supported_formats": ConfigField(
type=list, default=["mp4", "avi", "mov", "mkv", "flv", "wmv", "webm"], description="支持的视频格式"
),
# 消息缓冲设置 # 消息缓冲设置
"enable_message_buffer": ConfigField(type=bool, default=True, description="是否启用消息缓冲合并功能"), "enable_message_buffer": ConfigField(type=bool, default=True, description="是否启用消息缓冲合并功能"),
"message_buffer_enable_group": ConfigField(type=bool, default=True, description="是否启用群聊消息缓冲合并"), "message_buffer_enable_group": ConfigField(type=bool, default=True, description="是否启用群聊消息缓冲合并"),
"message_buffer_enable_private": ConfigField(type=bool, default=True, description="是否启用私聊消息缓冲合并"), "message_buffer_enable_private": ConfigField(
"message_buffer_interval": ConfigField(type=float, default=3.0, description="消息合并间隔时间(秒),在此时间内的连续消息将被合并"), type=bool, default=True, description="是否启用私聊消息缓冲合并"
"message_buffer_initial_delay": ConfigField(type=float, default=0.5, description="消息缓冲初始延迟(秒),收到第一条消息后等待此时间开始合并"), ),
"message_buffer_max_components": ConfigField(type=int, default=50, description="单个会话最大缓冲消息组件数量,超过此数量将强制合并"), "message_buffer_interval": ConfigField(
"message_buffer_block_prefixes": ConfigField(type=list, default=["/", "!", "", ".", "", "#", "%"], description="消息缓冲屏蔽前缀,以这些前缀开头的消息不会被缓冲"), type=float, default=3.0, description="消息合并间隔时间(秒),在此时间内的连续消息将被合并"
} ),
"message_buffer_initial_delay": ConfigField(
type=float, default=0.5, description="消息缓冲初始延迟(秒),收到第一条消息后等待此时间开始合并"
),
"message_buffer_max_components": ConfigField(
type=int, default=50, description="单个会话最大缓冲消息组件数量,超过此数量将强制合并"
),
"message_buffer_block_prefixes": ConfigField(
type=list,
default=["/", "!", "", ".", "", "#", "%"],
description="消息缓冲屏蔽前缀,以这些前缀开头的消息不会被缓冲",
),
},
} }
# 配置节描述 # 配置节描述
@@ -374,7 +421,7 @@ class NapcatAdapterPlugin(BasePlugin):
"voice": "发送语音设置", "voice": "发送语音设置",
"slicing": "WebSocket消息切片设置", "slicing": "WebSocket消息切片设置",
"debug": "调试设置", "debug": "调试设置",
"features": "功能设置(权限控制、聊天功能、视频处理、消息缓冲等)" "features": "功能设置(权限控制、聊天功能、视频处理、消息缓冲等)",
} }
def register_events(self): def register_events(self):
@@ -409,6 +456,7 @@ class NapcatAdapterPlugin(BasePlugin):
chunker.set_plugin_config(self.config) chunker.set_plugin_config(self.config)
# 设置response_pool的插件配置 # 设置response_pool的插件配置
from .src.response_pool import set_plugin_config as set_response_pool_config from .src.response_pool import set_plugin_config as set_response_pool_config
set_response_pool_config(self.config) set_response_pool_config(self.config)
# 设置send_handler的插件配置 # 设置send_handler的插件配置
send_handler.set_plugin_config(self.config) send_handler.set_plugin_config(self.config)

View File

@@ -102,7 +102,9 @@ class SimpleMessageBuffer:
return True return True
# 检查屏蔽前缀 # 检查屏蔽前缀
block_prefixes = tuple(config_api.get_plugin_config(self.plugin_config, "features.message_buffer_block_prefixes", [])) block_prefixes = tuple(
config_api.get_plugin_config(self.plugin_config, "features.message_buffer_block_prefixes", [])
)
text = text.strip() text = text.strip()
if text.startswith(block_prefixes): if text.startswith(block_prefixes):
@@ -134,9 +136,13 @@ class SimpleMessageBuffer:
# 检查是否启用对应类型的缓冲 # 检查是否启用对应类型的缓冲
message_type = event_data.get("message_type", "") message_type = event_data.get("message_type", "")
if message_type == "group" and not config_api.get_plugin_config(self.plugin_config, "features.message_buffer_enable_group", False): if message_type == "group" and not config_api.get_plugin_config(
self.plugin_config, "features.message_buffer_enable_group", False
):
return False return False
elif message_type == "private" and not config_api.get_plugin_config(self.plugin_config, "features.message_buffer_enable_private", False): elif message_type == "private" and not config_api.get_plugin_config(
self.plugin_config, "features.message_buffer_enable_private", False
):
return False return False
# 提取文本 # 提取文本
@@ -158,7 +164,9 @@ class SimpleMessageBuffer:
session = self.buffer_pool[session_id] session = self.buffer_pool[session_id]
# 检查是否超过最大组件数量 # 检查是否超过最大组件数量
if len(session.messages) >= config_api.get_plugin_config(self.plugin_config, "features.message_buffer_max_components", 5): if len(session.messages) >= config_api.get_plugin_config(
self.plugin_config, "features.message_buffer_max_components", 5
):
logger.debug(f"会话 {session_id} 消息数量达到上限,强制合并") logger.debug(f"会话 {session_id} 消息数量达到上限,强制合并")
asyncio.create_task(self._force_merge_session(session_id)) asyncio.create_task(self._force_merge_session(session_id))
self.buffer_pool[session_id] = BufferedSession(session_id=session_id, original_event=original_event) self.buffer_pool[session_id] = BufferedSession(session_id=session_id, original_event=original_event)

View File

@@ -32,7 +32,7 @@ class NoticeType: # 通知事件
group_recall = "group_recall" # 群聊消息撤回 group_recall = "group_recall" # 群聊消息撤回
notify = "notify" notify = "notify"
group_ban = "group_ban" # 群禁言 group_ban = "group_ban" # 群禁言
group_msg_emoji_like = "group_msg_emoji_like" # 群聊表情回复 group_msg_emoji_like = "group_msg_emoji_like" # 群聊表情回复
class Notify: class Notify:
poke = "poke" # 戳一戳 poke = "poke" # 戳一戳

View File

@@ -111,7 +111,9 @@ class MessageHandler:
return False return False
else: else:
# 检查私聊黑白名单 # 检查私聊黑白名单
private_list_type = config_api.get_plugin_config(self.plugin_config, "features.private_list_type", "blacklist") private_list_type = config_api.get_plugin_config(
self.plugin_config, "features.private_list_type", "blacklist"
)
private_list = config_api.get_plugin_config(self.plugin_config, "features.private_list", []) private_list = config_api.get_plugin_config(self.plugin_config, "features.private_list", [])
if private_list_type == "whitelist": if private_list_type == "whitelist":
@@ -158,17 +160,19 @@ class MessageHandler:
""" """
# 添加原始消息调试日志特别关注message字段 # 添加原始消息调试日志特别关注message字段
logger.debug(f"收到原始消息: message_type={raw_message.get('message_type')}, message_id={raw_message.get('message_id')}") logger.debug(
f"收到原始消息: message_type={raw_message.get('message_type')}, message_id={raw_message.get('message_id')}"
)
logger.debug(f"原始消息内容: {raw_message.get('message', [])}") logger.debug(f"原始消息内容: {raw_message.get('message', [])}")
# 检查是否包含@或video消息段 # 检查是否包含@或video消息段
message_segments = raw_message.get('message', []) message_segments = raw_message.get("message", [])
if message_segments: if message_segments:
for i, seg in enumerate(message_segments): for i, seg in enumerate(message_segments):
seg_type = seg.get('type') seg_type = seg.get("type")
if seg_type in ['at', 'video']: if seg_type in ["at", "video"]:
logger.info(f"检测到 {seg_type.upper()} 消息段 [{i}]: {seg}") logger.info(f"检测到 {seg_type.upper()} 消息段 [{i}]: {seg}")
elif seg_type not in ['text', 'face', 'image']: elif seg_type not in ["text", "face", "image"]:
logger.warning(f"检测到特殊消息段 [{i}]: type={seg_type}, data={seg.get('data', {})}") logger.warning(f"检测到特殊消息段 [{i}]: type={seg_type}, data={seg.get('data', {})}")
message_type: str = raw_message.get("message_type") message_type: str = raw_message.get("message_type")
@@ -308,9 +312,13 @@ class MessageHandler:
message_type = raw_message.get("message_type") message_type = raw_message.get("message_type")
should_use_buffer = False should_use_buffer = False
if message_type == "group" and config_api.get_plugin_config(self.plugin_config, "features.message_buffer_enable_group", True): if message_type == "group" and config_api.get_plugin_config(
self.plugin_config, "features.message_buffer_enable_group", True
):
should_use_buffer = True should_use_buffer = True
elif message_type == "private" and config_api.get_plugin_config(self.plugin_config, "features.message_buffer_enable_private", True): elif message_type == "private" and config_api.get_plugin_config(
self.plugin_config, "features.message_buffer_enable_private", True
):
should_use_buffer = True should_use_buffer = True
if should_use_buffer: if should_use_buffer:

View File

@@ -33,6 +33,7 @@ class MessageSending:
try: try:
# 重新导入router # 重新导入router
from ..mmc_com_layer import router from ..mmc_com_layer import router
self.maibot_router = router self.maibot_router = router
if self.maibot_router is not None: if self.maibot_router is not None:
logger.info("MaiBot router重连成功") logger.info("MaiBot router重连成功")
@@ -75,7 +76,7 @@ class MessageSending:
platform = message_base.message_info.platform platform = message_base.message_info.platform
# 再次检查router状态防止运行时被重置 # 再次检查router状态防止运行时被重置
if self.maibot_router is None or not hasattr(self.maibot_router, 'clients'): if self.maibot_router is None or not hasattr(self.maibot_router, "clients"):
logger.warning("MaiBot router连接已断开尝试重新连接") logger.warning("MaiBot router连接已断开尝试重新连接")
if not await self._attempt_reconnect(): if not await self._attempt_reconnect():
logger.error("MaiBot router重连失败切片发送中止") logger.error("MaiBot router重连失败切片发送中止")

View File

@@ -22,7 +22,9 @@ class MetaEventHandler:
"""设置插件配置""" """设置插件配置"""
self.plugin_config = plugin_config self.plugin_config = plugin_config
# 更新interval值 # 更新interval值
self.interval = config_api.get_plugin_config(self.plugin_config, "napcat_server.heartbeat_interval", 5000) / 1000 self.interval = (
config_api.get_plugin_config(self.plugin_config, "napcat_server.heartbeat_interval", 5000) / 1000
)
async def handle_meta_event(self, message: dict) -> None: async def handle_meta_event(self, message: dict) -> None:
event_type = message.get("meta_event_type") event_type = message.get("meta_event_type")

View File

@@ -116,9 +116,9 @@ class NoticeHandler:
sub_type = raw_message.get("sub_type") sub_type = raw_message.get("sub_type")
match sub_type: match sub_type:
case NoticeType.Notify.poke: case NoticeType.Notify.poke:
if config_api.get_plugin_config(self.plugin_config, "features.enable_poke", True) and await message_handler.check_allow_to_chat( if config_api.get_plugin_config(
user_id, group_id, False, False self.plugin_config, "features.enable_poke", True
): ) and await message_handler.check_allow_to_chat(user_id, group_id, False, False):
logger.debug("处理戳一戳消息") logger.debug("处理戳一戳消息")
handled_message, user_info = await self.handle_poke_notify(raw_message, group_id, user_id) handled_message, user_info = await self.handle_poke_notify(raw_message, group_id, user_id)
else: else:
@@ -127,14 +127,18 @@ class NoticeHandler:
from src.plugin_system.core.event_manager import event_manager from src.plugin_system.core.event_manager import event_manager
from ...event_types import NapcatEvent from ...event_types import NapcatEvent
await event_manager.trigger_event(NapcatEvent.ON_RECEIVED.FRIEND_INPUT, permission_group=PLUGIN_NAME) await event_manager.trigger_event(
NapcatEvent.ON_RECEIVED.FRIEND_INPUT, permission_group=PLUGIN_NAME
)
case _: case _:
logger.warning(f"不支持的notify类型: {notice_type}.{sub_type}") logger.warning(f"不支持的notify类型: {notice_type}.{sub_type}")
case NoticeType.group_msg_emoji_like: case NoticeType.group_msg_emoji_like:
# 该事件转移到 handle_group_emoji_like_notify函数内触发 # 该事件转移到 handle_group_emoji_like_notify函数内触发
if config_api.get_plugin_config(self.plugin_config, "features.enable_emoji_like", True): if config_api.get_plugin_config(self.plugin_config, "features.enable_emoji_like", True):
logger.debug("处理群聊表情回复") logger.debug("处理群聊表情回复")
handled_message, user_info = await self.handle_group_emoji_like_notify(raw_message,group_id,user_id) handled_message, user_info = await self.handle_group_emoji_like_notify(
raw_message, group_id, user_id
)
else: else:
logger.warning("群聊表情回复被禁用,取消群聊表情回复处理") logger.warning("群聊表情回复被禁用,取消群聊表情回复处理")
case NoticeType.group_ban: case NoticeType.group_ban:
@@ -308,8 +312,10 @@ class NoticeHandler:
from src.plugin_system.core.event_manager import event_manager from src.plugin_system.core.event_manager import event_manager
from ...event_types import NapcatEvent from ...event_types import NapcatEvent
target_message = await event_manager.trigger_event(NapcatEvent.MESSAGE.GET_MSG,message_id=raw_message.get("message_id","")) target_message = await event_manager.trigger_event(
target_message_text = target_message.get_message_result().get("data",{}).get("raw_message","") NapcatEvent.MESSAGE.GET_MSG, message_id=raw_message.get("message_id", "")
)
target_message_text = target_message.get_message_result().get("data", {}).get("raw_message", "")
if not target_message: if not target_message:
logger.error("未找到对应消息") logger.error("未找到对应消息")
return None, None return None, None
@@ -325,14 +331,17 @@ class NoticeHandler:
like_emoji_id = raw_message.get("likes")[0].get("emoji_id") like_emoji_id = raw_message.get("likes")[0].get("emoji_id")
await event_manager.trigger_event( await event_manager.trigger_event(
NapcatEvent.ON_RECEIVED.EMOJI_LIEK, NapcatEvent.ON_RECEIVED.EMOJI_LIEK,
permission_group=PLUGIN_NAME, permission_group=PLUGIN_NAME,
group_id=group_id, group_id=group_id,
user_id=user_id, user_id=user_id,
message_id=raw_message.get("message_id",""), message_id=raw_message.get("message_id", ""),
emoji_id=like_emoji_id emoji_id=like_emoji_id,
) )
seg_data = Seg(type="text",data=f"{user_name}使用Emoji表情{QQ_FACE.get(like_emoji_id,"")}回复了你的消息[{target_message_text}]") seg_data = Seg(
type="text",
data=f"{user_name}使用Emoji表情{QQ_FACE.get(like_emoji_id, '')}回复了你的消息[{target_message_text}]",
)
return seg_data, user_info return seg_data, user_info
async def handle_ban_notify(self, raw_message: dict, group_id: int) -> Tuple[Seg, UserInfo] | Tuple[None, None]: async def handle_ban_notify(self, raw_message: dict, group_id: int) -> Tuple[Seg, UserInfo] | Tuple[None, None]:

View File

@@ -297,9 +297,9 @@ class SendHandler:
try: try:
# 检查是否为缓冲消息ID格式buffered-{original_id}-{timestamp} # 检查是否为缓冲消息ID格式buffered-{original_id}-{timestamp}
if id.startswith('buffered-'): if id.startswith("buffered-"):
# 从缓冲消息ID中提取原始消息ID # 从缓冲消息ID中提取原始消息ID
original_id = id.split('-')[1] original_id = id.split("-")[1]
msg_info_response = await self.send_message_to_napcat("get_msg", {"message_id": int(original_id)}) msg_info_response = await self.send_message_to_napcat("get_msg", {"message_id": int(original_id)})
else: else:
msg_info_response = await self.send_message_to_napcat("get_msg", {"message_id": int(id)}) msg_info_response = await self.send_message_to_napcat("get_msg", {"message_id": int(id)})

View File

@@ -18,7 +18,9 @@ class WebSocketManager:
self.max_reconnect_attempts = 10 # 最大重连次数 self.max_reconnect_attempts = 10 # 最大重连次数
self.plugin_config = None self.plugin_config = None
async def start_connection(self, message_handler: Callable[[Server.ServerConnection], Any], plugin_config: dict) -> None: async def start_connection(
self, message_handler: Callable[[Server.ServerConnection], Any], plugin_config: dict
) -> None:
"""根据配置启动 WebSocket 连接""" """根据配置启动 WebSocket 连接"""
self.plugin_config = plugin_config self.plugin_config = plugin_config
mode = config_api.get_plugin_config(plugin_config, "napcat_server.mode") mode = config_api.get_plugin_config(plugin_config, "napcat_server.mode")
@@ -72,9 +74,7 @@ class WebSocketManager:
# 如果配置了访问令牌,添加到请求头 # 如果配置了访问令牌,添加到请求头
access_token = config_api.get_plugin_config(self.plugin_config, "napcat_server.access_token") access_token = config_api.get_plugin_config(self.plugin_config, "napcat_server.access_token")
if access_token: if access_token:
connect_kwargs["additional_headers"] = { connect_kwargs["additional_headers"] = {"Authorization": f"Bearer {access_token}"}
"Authorization": f"Bearer {access_token}"
}
logger.info("已添加访问令牌到连接请求头") logger.info("已添加访问令牌到连接请求头")
async with Server.connect(url, **connect_kwargs) as websocket: async with Server.connect(url, **connect_kwargs) as websocket:

View File

@@ -1,6 +1,7 @@
""" """
Base search engine interface Base search engine interface
""" """
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Dict, List, Any from typing import Dict, List, Any

View File

@@ -1,6 +1,7 @@
""" """
Bing search engine implementation Bing search engine implementation
""" """
import asyncio import asyncio
import functools import functools
import random import random
@@ -202,12 +203,7 @@ class BingSearchEngine(BaseSearchEngine):
if ABSTRACT_MAX_LENGTH and len(abstract) > ABSTRACT_MAX_LENGTH: if ABSTRACT_MAX_LENGTH and len(abstract) > ABSTRACT_MAX_LENGTH:
abstract = abstract[:ABSTRACT_MAX_LENGTH] + "..." abstract = abstract[:ABSTRACT_MAX_LENGTH] + "..."
list_data.append({ list_data.append({"title": title, "url": url, "snippet": abstract, "provider": "Bing"})
"title": title,
"url": url,
"snippet": abstract,
"provider": "Bing"
})
if len(list_data) >= 10: # 限制结果数量 if len(list_data) >= 10: # 限制结果数量
break break
@@ -222,16 +218,28 @@ class BingSearchEngine(BaseSearchEngine):
text = link.get_text().strip() text = link.get_text().strip()
# 过滤有效的搜索结果链接 # 过滤有效的搜索结果链接
if (href and text and len(text) > 10 if (
href
and text
and len(text) > 10
and not href.startswith("javascript:") and not href.startswith("javascript:")
and not href.startswith("#") and not href.startswith("#")
and "http" in href and "http" in href
and not any(x in href for x in [ and not any(
"bing.com/search", "bing.com/images", "bing.com/videos", x in href
"bing.com/maps", "bing.com/news", "login", "account", for x in [
"microsoft", "javascript" "bing.com/search",
])): "bing.com/images",
"bing.com/videos",
"bing.com/maps",
"bing.com/news",
"login",
"account",
"microsoft",
"javascript",
]
)
):
# 尝试获取摘要 # 尝试获取摘要
abstract = "" abstract = ""
parent = link.parent parent = link.parent
@@ -244,12 +252,7 @@ class BingSearchEngine(BaseSearchEngine):
if ABSTRACT_MAX_LENGTH and len(abstract) > ABSTRACT_MAX_LENGTH: if ABSTRACT_MAX_LENGTH and len(abstract) > ABSTRACT_MAX_LENGTH:
abstract = abstract[:ABSTRACT_MAX_LENGTH] + "..." abstract = abstract[:ABSTRACT_MAX_LENGTH] + "..."
list_data.append({ list_data.append({"title": text, "url": href, "snippet": abstract, "provider": "Bing"})
"title": text,
"url": href,
"snippet": abstract,
"provider": "Bing"
})
if len(list_data) >= 10: if len(list_data) >= 10:
break break

View File

@@ -1,6 +1,7 @@
""" """
DuckDuckGo search engine implementation DuckDuckGo search engine implementation
""" """
from typing import Dict, List, Any from typing import Dict, List, Any
from asyncddgs import aDDGS from asyncddgs import aDDGS
@@ -29,12 +30,7 @@ class DDGSearchEngine(BaseSearchEngine):
search_response = await ddgs.text(query, max_results=num_results) search_response = await ddgs.text(query, max_results=num_results)
return [ return [
{ {"title": r.get("title"), "url": r.get("href"), "snippet": r.get("body"), "provider": "DuckDuckGo"}
"title": r.get("title"),
"url": r.get("href"),
"snippet": r.get("body"),
"provider": "DuckDuckGo"
}
for r in search_response for r in search_response
] ]
except Exception as e: except Exception as e:

View File

@@ -1,6 +1,7 @@
""" """
Exa search engine implementation Exa search engine implementation
""" """
import asyncio import asyncio
import functools import functools
from datetime import datetime, timedelta from datetime import datetime, timedelta
@@ -29,11 +30,7 @@ class ExaSearchEngine(BaseSearchEngine):
exa_api_keys = config_api.get_global_config("web_search.exa_api_keys", None) exa_api_keys = config_api.get_global_config("web_search.exa_api_keys", None)
# 创建API密钥管理器 # 创建API密钥管理器
self.api_manager = create_api_key_manager_from_config( self.api_manager = create_api_key_manager_from_config(exa_api_keys, lambda key: Exa(api_key=key), "Exa")
exa_api_keys,
lambda key: Exa(api_key=key),
"Exa"
)
def is_available(self) -> bool: def is_available(self) -> bool:
"""检查Exa搜索引擎是否可用""" """检查Exa搜索引擎是否可用"""
@@ -52,7 +49,7 @@ class ExaSearchEngine(BaseSearchEngine):
if time_range != "any": if time_range != "any":
today = datetime.now() today = datetime.now()
start_date = today - timedelta(days=7 if time_range == "week" else 30) start_date = today - timedelta(days=7 if time_range == "week" else 30)
exa_args["start_published_date"] = start_date.strftime('%Y-%m-%d') exa_args["start_published_date"] = start_date.strftime("%Y-%m-%d")
try: try:
# 使用API密钥管理器获取下一个客户端 # 使用API密钥管理器获取下一个客户端
@@ -69,8 +66,8 @@ class ExaSearchEngine(BaseSearchEngine):
{ {
"title": res.title, "title": res.title,
"url": res.url, "url": res.url,
"snippet": " ".join(getattr(res, 'highlights', [])) or (getattr(res, 'text', '')[:250] + '...'), "snippet": " ".join(getattr(res, "highlights", [])) or (getattr(res, "text", "")[:250] + "..."),
"provider": "Exa" "provider": "Exa",
} }
for res in search_response.results for res in search_response.results
] ]

View File

@@ -1,6 +1,7 @@
""" """
Tavily search engine implementation Tavily search engine implementation
""" """
import asyncio import asyncio
import functools import functools
from typing import Dict, List, Any from typing import Dict, List, Any
@@ -29,9 +30,7 @@ class TavilySearchEngine(BaseSearchEngine):
# 创建API密钥管理器 # 创建API密钥管理器
self.api_manager = create_api_key_manager_from_config( self.api_manager = create_api_key_manager_from_config(
tavily_api_keys, tavily_api_keys, lambda key: TavilyClient(api_key=key), "Tavily"
lambda key: TavilyClient(api_key=key),
"Tavily"
) )
def is_available(self) -> bool: def is_available(self) -> bool:
@@ -60,7 +59,7 @@ class TavilySearchEngine(BaseSearchEngine):
"max_results": num_results, "max_results": num_results,
"search_depth": "basic", "search_depth": "basic",
"include_answer": False, "include_answer": False,
"include_raw_content": False "include_raw_content": False,
} }
# 根据时间范围调整搜索参数 # 根据时间范围调整搜索参数
@@ -76,12 +75,14 @@ class TavilySearchEngine(BaseSearchEngine):
results = [] results = []
if search_response and "results" in search_response: if search_response and "results" in search_response:
for res in search_response["results"]: for res in search_response["results"]:
results.append({ results.append(
"title": res.get("title", "无标题"), {
"url": res.get("url", ""), "title": res.get("title", "无标题"),
"snippet": res.get("content", "")[:300] + "..." if res.get("content") else "无摘要", "url": res.get("url", ""),
"provider": "Tavily" "snippet": res.get("content", "")[:300] + "..." if res.get("content") else "无摘要",
}) "provider": "Tavily",
}
)
return results return results

View File

@@ -3,15 +3,10 @@ Web Search Tool Plugin
一个功能强大的网络搜索和URL解析插件支持多种搜索引擎和解析策略。 一个功能强大的网络搜索和URL解析插件支持多种搜索引擎和解析策略。
""" """
from typing import List, Tuple, Type from typing import List, Tuple, Type
from src.plugin_system import ( from src.plugin_system import BasePlugin, register_plugin, ComponentInfo, ConfigField, PythonDependency
BasePlugin,
register_plugin,
ComponentInfo,
ConfigField,
PythonDependency
)
from src.plugin_system.apis import config_api from src.plugin_system.apis import config_api
from src.common.logger import get_logger from src.common.logger import get_logger
@@ -61,7 +56,7 @@ class WEBSEARCHPLUGIN(BasePlugin):
"Exa": exa_engine.is_available(), "Exa": exa_engine.is_available(),
"Tavily": tavily_engine.is_available(), "Tavily": tavily_engine.is_available(),
"DuckDuckGo": ddg_engine.is_available(), "DuckDuckGo": ddg_engine.is_available(),
"Bing": bing_engine.is_available() "Bing": bing_engine.is_available(),
} }
available_engines = [name for name, available in engines_status.items() if available] available_engines = [name for name, available in engines_status.items() if available]
@@ -77,37 +72,30 @@ class WEBSEARCHPLUGIN(BasePlugin):
# Python包依赖列表 # Python包依赖列表
python_dependencies: List[PythonDependency] = [ python_dependencies: List[PythonDependency] = [
PythonDependency( PythonDependency(package_name="asyncddgs", description="异步DuckDuckGo搜索库", optional=False),
package_name="asyncddgs",
description="异步DuckDuckGo搜索库",
optional=False
),
PythonDependency( PythonDependency(
package_name="exa_py", package_name="exa_py",
description="Exa搜索API客户端库", description="Exa搜索API客户端库",
optional=True # 如果没有API密钥这个是可选的 optional=True, # 如果没有API密钥这个是可选的
), ),
PythonDependency( PythonDependency(
package_name="tavily", package_name="tavily",
install_name="tavily-python", # 安装时使用这个名称 install_name="tavily-python", # 安装时使用这个名称
description="Tavily搜索API客户端库", description="Tavily搜索API客户端库",
optional=True # 如果没有API密钥这个是可选的 optional=True, # 如果没有API密钥这个是可选的
), ),
PythonDependency( PythonDependency(
package_name="httpx", package_name="httpx",
version=">=0.20.0", version=">=0.20.0",
install_name="httpx[socks]", # 安装时使用这个名称(包含可选依赖) install_name="httpx[socks]", # 安装时使用这个名称(包含可选依赖)
description="支持SOCKS代理的HTTP客户端库", description="支持SOCKS代理的HTTP客户端库",
optional=False optional=False,
) ),
] ]
config_file_name: str = "config.toml" # 配置文件名 config_file_name: str = "config.toml" # 配置文件名
# 配置节描述 # 配置节描述
config_section_descriptions = { config_section_descriptions = {"plugin": "插件基本信息", "proxy": "链接本地解析代理配置"}
"plugin": "插件基本信息",
"proxy": "链接本地解析代理配置"
}
# 配置Schema定义 # 配置Schema定义
# 注意EXA配置和组件设置已迁移到主配置文件(bot_config.toml)的[exa]和[web_search]部分 # 注意EXA配置和组件设置已迁移到主配置文件(bot_config.toml)的[exa]和[web_search]部分
@@ -119,25 +107,15 @@ class WEBSEARCHPLUGIN(BasePlugin):
}, },
"proxy": { "proxy": {
"http_proxy": ConfigField( "http_proxy": ConfigField(
type=str, type=str, default=None, description="HTTP代理地址格式如: http://proxy.example.com:8080"
default=None,
description="HTTP代理地址格式如: http://proxy.example.com:8080"
), ),
"https_proxy": ConfigField( "https_proxy": ConfigField(
type=str, type=str, default=None, description="HTTPS代理地址格式如: http://proxy.example.com:8080"
default=None,
description="HTTPS代理地址格式如: http://proxy.example.com:8080"
), ),
"socks5_proxy": ConfigField( "socks5_proxy": ConfigField(
type=str, type=str, default=None, description="SOCKS5代理地址格式如: socks5://proxy.example.com:1080"
default=None,
description="SOCKS5代理地址格式如: socks5://proxy.example.com:1080"
), ),
"enable_proxy": ConfigField( "enable_proxy": ConfigField(type=bool, default=False, description="是否启用代理"),
type=bool,
default=False,
description="是否启用代理"
)
}, },
} }

View File

@@ -1,6 +1,7 @@
""" """
URL parser tool implementation URL parser tool implementation
""" """
import asyncio import asyncio
import functools import functools
from typing import Any, Dict from typing import Any, Dict
@@ -24,6 +25,7 @@ class URLParserTool(BaseTool):
""" """
一个用于解析和总结一个或多个网页URL内容的工具。 一个用于解析和总结一个或多个网页URL内容的工具。
""" """
name: str = "parse_url" name: str = "parse_url"
description: str = "当需要理解一个或多个特定网页链接的内容时,使用此工具。例如:'这些网页讲了什么?[https://example.com, https://example2.com]''帮我总结一下这些文章'" description: str = "当需要理解一个或多个特定网页链接的内容时,使用此工具。例如:'这些网页讲了什么?[https://example.com, https://example2.com]''帮我总结一下这些文章'"
available_for_llm: bool = True available_for_llm: bool = True
@@ -45,9 +47,7 @@ class URLParserTool(BaseTool):
# 创建API密钥管理器 # 创建API密钥管理器
self.api_manager = create_api_key_manager_from_config( self.api_manager = create_api_key_manager_from_config(
exa_api_keys, exa_api_keys, lambda key: Exa(api_key=key), "Exa URL Parser"
lambda key: Exa(api_key=key),
"Exa URL Parser"
) )
async def _local_parse_and_summarize(self, url: str) -> Dict[str, Any]: async def _local_parse_and_summarize(self, url: str) -> Dict[str, Any]:
@@ -104,12 +104,12 @@ class URLParserTool(BaseTool):
return {"error": "未配置LLM模型"} return {"error": "未配置LLM模型"}
success, summary, reasoning, model_name = await llm_api.generate_with_model( success, summary, reasoning, model_name = await llm_api.generate_with_model(
prompt=summary_prompt, prompt=summary_prompt,
model_config=model_config, model_config=model_config,
request_type="story.generate", request_type="story.generate",
temperature=0.3, temperature=0.3,
max_tokens=1000 max_tokens=1000,
) )
if not success: if not success:
logger.info(f"生成摘要失败: {summary}") logger.info(f"生成摘要失败: {summary}")
@@ -117,12 +117,7 @@ class URLParserTool(BaseTool):
logger.info(f"成功生成摘要内容:'{summary}'") logger.info(f"成功生成摘要内容:'{summary}'")
return { return {"title": title, "url": url, "snippet": summary, "source": "local"}
"title": title,
"url": url,
"snippet": summary,
"source": "local"
}
except httpx.HTTPStatusError as e: except httpx.HTTPStatusError as e:
logger.warning(f"本地解析URL '{url}' 失败 (HTTP {e.response.status_code})") logger.warning(f"本地解析URL '{url}' 失败 (HTTP {e.response.status_code})")
@@ -137,6 +132,7 @@ class URLParserTool(BaseTool):
""" """
# 获取当前文件路径用于缓存键 # 获取当前文件路径用于缓存键
import os import os
current_file_path = os.path.abspath(__file__) current_file_path = os.path.abspath(__file__)
# 检查缓存 # 检查缓存
@@ -182,34 +178,38 @@ class URLParserTool(BaseTool):
contents_response = await loop.run_in_executor(None, func) contents_response = await loop.run_in_executor(None, func)
except Exception as e: except Exception as e:
logger.error(f"执行 Exa URL解析时发生严重异常: {e}", exc_info=True) logger.error(f"执行 Exa URL解析时发生严重异常: {e}", exc_info=True)
contents_response = None # 确保异常后为None contents_response = None # 确保异常后为None
# 步骤 2: 处理Exa的响应 # 步骤 2: 处理Exa的响应
if contents_response and hasattr(contents_response, 'statuses'): if contents_response and hasattr(contents_response, "statuses"):
results_map = {res.url: res for res in contents_response.results} if hasattr(contents_response, 'results') else {} results_map = (
{res.url: res for res in contents_response.results} if hasattr(contents_response, "results") else {}
)
if contents_response.statuses: if contents_response.statuses:
for status in contents_response.statuses: for status in contents_response.statuses:
if status.status == 'success': if status.status == "success":
res = results_map.get(status.id) res = results_map.get(status.id)
if res: if res:
summary = getattr(res, 'summary', '') summary = getattr(res, "summary", "")
highlights = " ".join(getattr(res, 'highlights', [])) highlights = " ".join(getattr(res, "highlights", []))
text_snippet = (getattr(res, 'text', '')[:300] + '...') if getattr(res, 'text', '') else '' text_snippet = (getattr(res, "text", "")[:300] + "...") if getattr(res, "text", "") else ""
snippet = summary or highlights or text_snippet or '无摘要' snippet = summary or highlights or text_snippet or "无摘要"
successful_results.append({ successful_results.append(
"title": getattr(res, 'title', '无标题'), {
"url": getattr(res, 'url', status.id), "title": getattr(res, "title", "无标题"),
"snippet": snippet, "url": getattr(res, "url", status.id),
"source": "exa" "snippet": snippet,
}) "source": "exa",
}
)
else: else:
error_tag = getattr(status, 'error', '未知错误') error_tag = getattr(status, "error", "未知错误")
logger.warning(f"Exa解析URL '{status.id}' 失败: {error_tag}。准备本地重试。") logger.warning(f"Exa解析URL '{status.id}' 失败: {error_tag}。准备本地重试。")
urls_to_retry_locally.append(status.id) urls_to_retry_locally.append(status.id)
else: else:
# 如果Exa未配置、API调用失败或返回无效响应则所有URL都进入本地重试 # 如果Exa未配置、API调用失败或返回无效响应则所有URL都进入本地重试
urls_to_retry_locally.extend(url for url in urls if url not in [res['url'] for res in successful_results]) urls_to_retry_locally.extend(url for url in urls if url not in [res["url"] for res in successful_results])
# 步骤 3: 对失败的URL进行本地解析 # 步骤 3: 对失败的URL进行本地解析
if urls_to_retry_locally: if urls_to_retry_locally:
@@ -229,11 +229,7 @@ class URLParserTool(BaseTool):
formatted_content = format_url_parse_results(successful_results) formatted_content = format_url_parse_results(successful_results)
result = { result = {"type": "url_parse_result", "content": formatted_content, "errors": error_messages}
"type": "url_parse_result",
"content": formatted_content,
"errors": error_messages
}
# 保存到缓存 # 保存到缓存
if "error" not in result: if "error" not in result:

View File

@@ -1,6 +1,7 @@
""" """
Web search tool implementation Web search tool implementation
""" """
import asyncio import asyncio
from typing import Any, Dict, List from typing import Any, Dict, List
@@ -22,14 +23,23 @@ class WebSurfingTool(BaseTool):
""" """
网络搜索工具 网络搜索工具
""" """
name: str = "web_search" name: str = "web_search"
description: str = "用于执行网络搜索。当用户明确要求搜索,或者需要获取关于公司、产品、事件的最新信息、新闻或动态时,必须使用此工具" description: str = (
"用于执行网络搜索。当用户明确要求搜索,或者需要获取关于公司、产品、事件的最新信息、新闻或动态时,必须使用此工具"
)
available_for_llm: bool = True available_for_llm: bool = True
parameters = [ parameters = [
("query", ToolParamType.STRING, "要搜索的关键词或问题。", True, None), ("query", ToolParamType.STRING, "要搜索的关键词或问题。", True, None),
("num_results", ToolParamType.INTEGER, "期望每个搜索引擎返回的搜索结果数量默认为5。", False, None), ("num_results", ToolParamType.INTEGER, "期望每个搜索引擎返回的搜索结果数量默认为5。", False, None),
("time_range", ToolParamType.STRING, "指定搜索的时间范围,可以是 'any', 'week', 'month'。默认为 'any'", False, ["any", "week", "month"]) (
] # type: ignore "time_range",
ToolParamType.STRING,
"指定搜索的时间范围,可以是 'any', 'week', 'month'。默认为 'any'",
False,
["any", "week", "month"],
),
] # type: ignore
def __init__(self, plugin_config=None): def __init__(self, plugin_config=None):
super().__init__(plugin_config) super().__init__(plugin_config)
@@ -38,7 +48,7 @@ class WebSurfingTool(BaseTool):
"exa": ExaSearchEngine(), "exa": ExaSearchEngine(),
"tavily": TavilySearchEngine(), "tavily": TavilySearchEngine(),
"ddg": DDGSearchEngine(), "ddg": DDGSearchEngine(),
"bing": BingSearchEngine() "bing": BingSearchEngine(),
} }
async def execute(self, function_args: Dict[str, Any]) -> Dict[str, Any]: async def execute(self, function_args: Dict[str, Any]) -> Dict[str, Any]:
@@ -48,6 +58,7 @@ class WebSurfingTool(BaseTool):
# 获取当前文件路径用于缓存键 # 获取当前文件路径用于缓存键
import os import os
current_file_path = os.path.abspath(__file__) current_file_path = os.path.abspath(__file__)
# 检查缓存 # 检查缓存
@@ -76,7 +87,9 @@ class WebSurfingTool(BaseTool):
return result return result
async def _execute_parallel_search(self, function_args: Dict[str, Any], enabled_engines: List[str]) -> Dict[str, Any]: async def _execute_parallel_search(
self, function_args: Dict[str, Any], enabled_engines: List[str]
) -> Dict[str, Any]:
"""并行搜索策略:同时使用所有启用的搜索引擎""" """并行搜索策略:同时使用所有启用的搜索引擎"""
search_tasks = [] search_tasks = []
@@ -113,7 +126,9 @@ class WebSurfingTool(BaseTool):
logger.error(f"执行并行网络搜索时发生异常: {e}", exc_info=True) logger.error(f"执行并行网络搜索时发生异常: {e}", exc_info=True)
return {"error": f"执行网络搜索时发生严重错误: {str(e)}"} return {"error": f"执行网络搜索时发生严重错误: {str(e)}"}
async def _execute_fallback_search(self, function_args: Dict[str, Any], enabled_engines: List[str]) -> Dict[str, Any]: async def _execute_fallback_search(
self, function_args: Dict[str, Any], enabled_engines: List[str]
) -> Dict[str, Any]:
"""回退搜索策略:按顺序尝试搜索引擎,失败则尝试下一个""" """回退搜索策略:按顺序尝试搜索引擎,失败则尝试下一个"""
for engine_name in enabled_engines: for engine_name in enabled_engines:
engine = self.engines.get(engine_name) engine = self.engines.get(engine_name)

View File

@@ -1,13 +1,14 @@
""" """
API密钥管理器提供轮询机制 API密钥管理器提供轮询机制
""" """
import itertools import itertools
from typing import List, Optional, TypeVar, Generic, Callable from typing import List, Optional, TypeVar, Generic, Callable
from src.common.logger import get_logger from src.common.logger import get_logger
logger = get_logger("api_key_manager") logger = get_logger("api_key_manager")
T = TypeVar('T') T = TypeVar("T")
class APIKeyManager(Generic[T]): class APIKeyManager(Generic[T]):
@@ -65,9 +66,7 @@ class APIKeyManager(Generic[T]):
def create_api_key_manager_from_config( def create_api_key_manager_from_config(
config_keys: Optional[List[str]], config_keys: Optional[List[str]], client_factory: Callable[[str], T], service_name: str
client_factory: Callable[[str], T],
service_name: str
) -> APIKeyManager[T]: ) -> APIKeyManager[T]:
""" """
从配置创建API密钥管理器的便捷函数 从配置创建API密钥管理器的便捷函数

View File

@@ -1,6 +1,7 @@
""" """
Formatters for web search results Formatters for web search results
""" """
from typing import List, Dict, Any from typing import List, Dict, Any
@@ -13,9 +14,9 @@ def format_search_results(results: List[Dict[str, Any]]) -> str:
formatted_string = "根据网络搜索结果:\n\n" formatted_string = "根据网络搜索结果:\n\n"
for i, res in enumerate(results, 1): for i, res in enumerate(results, 1):
title = res.get("title", '无标题') title = res.get("title", "无标题")
url = res.get("url", '#') url = res.get("url", "#")
snippet = res.get("snippet", '无摘要') snippet = res.get("snippet", "无摘要")
provider = res.get("provider", "未知来源") provider = res.get("provider", "未知来源")
formatted_string += f"{i}. **{title}** (来自: {provider})\n" formatted_string += f"{i}. **{title}** (来自: {provider})\n"
@@ -31,10 +32,10 @@ def format_url_parse_results(results: List[Dict[str, Any]]) -> str:
""" """
formatted_parts = [] formatted_parts = []
for res in results: for res in results:
title = res.get('title', '无标题') title = res.get("title", "无标题")
url = res.get('url', '#') url = res.get("url", "#")
snippet = res.get('snippet', '无摘要') snippet = res.get("snippet", "无摘要")
source = res.get('source', '未知') source = res.get("source", "未知")
formatted_string = f"**{title}**\n" formatted_string = f"**{title}**\n"
formatted_string += f"**内容摘要**:\n{snippet}\n" formatted_string += f"**内容摘要**:\n{snippet}\n"

View File

@@ -1,6 +1,7 @@
""" """
URL processing utilities URL processing utilities
""" """
import re import re
from typing import List from typing import List
@@ -12,11 +13,11 @@ def parse_urls_from_input(urls_input) -> List[str]:
if isinstance(urls_input, str): if isinstance(urls_input, str):
# 如果是字符串尝试解析为URL列表 # 如果是字符串尝试解析为URL列表
# 提取所有HTTP/HTTPS URL # 提取所有HTTP/HTTPS URL
url_pattern = r'https?://[^\s\],]+' url_pattern = r"https?://[^\s\],]+"
urls = re.findall(url_pattern, urls_input) urls = re.findall(url_pattern, urls_input)
if not urls: if not urls:
# 如果没有找到标准URL将整个字符串作为单个URL # 如果没有找到标准URL将整个字符串作为单个URL
if urls_input.strip().startswith(('http://', 'https://')): if urls_input.strip().startswith(("http://", "https://")):
urls = [urls_input.strip()] urls = [urls_input.strip()]
else: else:
return [] return []
@@ -34,6 +35,6 @@ def validate_urls(urls: List[str]) -> List[str]:
""" """
valid_urls = [] valid_urls = []
for url in urls: for url in urls:
if url.startswith(('http://', 'https://')): if url.startswith(("http://", "https://")):
valid_urls.append(url) valid_urls.append(url)
return valid_urls return valid_urls

View File

@@ -21,8 +21,18 @@ logger = get_logger(__name__)
# ============================ AsyncTask ============================ # ============================ AsyncTask ============================
class ReminderTask(AsyncTask): class ReminderTask(AsyncTask):
def __init__(self, delay: float, stream_id: str, is_group: bool, target_user_id: str, target_user_name: str, event_details: str, creator_name: str): def __init__(
self,
delay: float,
stream_id: str,
is_group: bool,
target_user_id: str,
target_user_name: str,
event_details: str,
creator_name: str,
):
super().__init__(task_name=f"ReminderTask_{target_user_id}_{datetime.now().timestamp()}") super().__init__(task_name=f"ReminderTask_{target_user_id}_{datetime.now().timestamp()}")
self.delay = delay self.delay = delay
self.stream_id = stream_id self.stream_id = stream_id
@@ -44,15 +54,15 @@ class ReminderTask(AsyncTask):
if self.is_group: if self.is_group:
# 在群聊中,构造 @ 消息段并发送 # 在群聊中,构造 @ 消息段并发送
group_id = self.stream_id.split('_')[-1] if '_' in self.stream_id else self.stream_id group_id = self.stream_id.split("_")[-1] if "_" in self.stream_id else self.stream_id
message_payload = [ message_payload = [
{"type": "at", "data": {"qq": self.target_user_id}}, {"type": "at", "data": {"qq": self.target_user_id}},
{"type": "text", "data": {"text": f" {reminder_text}"}} {"type": "text", "data": {"text": f" {reminder_text}"}},
] ]
await send_api.adapter_command_to_stream( await send_api.adapter_command_to_stream(
action="send_group_msg", action="send_group_msg",
params={"group_id": group_id, "message": message_payload}, params={"group_id": group_id, "message": message_payload},
stream_id=self.stream_id stream_id=self.stream_id,
) )
else: else:
# 在私聊中,直接发送文本 # 在私聊中,直接发送文本
@@ -66,6 +76,7 @@ class ReminderTask(AsyncTask):
# =============================== Actions =============================== # =============================== Actions ===============================
class RemindAction(BaseAction): class RemindAction(BaseAction):
"""一个能从对话中智能识别并设置定时提醒的动作。""" """一个能从对话中智能识别并设置定时提醒的动作。"""
@@ -95,12 +106,12 @@ class RemindAction(BaseAction):
action_parameters = { action_parameters = {
"user_name": "需要被提醒的人的称呼或名字,如果没有明确指定给某人,则默认为'自己'", "user_name": "需要被提醒的人的称呼或名字,如果没有明确指定给某人,则默认为'自己'",
"remind_time": "描述提醒时间的自然语言字符串,例如'十分钟后''明天下午3点'", "remind_time": "描述提醒时间的自然语言字符串,例如'十分钟后''明天下午3点'",
"event_details": "需要提醒的具体事件内容" "event_details": "需要提醒的具体事件内容",
} }
action_require = [ action_require = [
"当用户请求在未来的某个时间点提醒他/她或别人某件事时使用", "当用户请求在未来的某个时间点提醒他/她或别人某件事时使用",
"适用于包含明确时间信息和事件描述的对话", "适用于包含明确时间信息和事件描述的对话",
"例如:'10分钟后提醒我收快递''明天早上九点喊一下李四参加晨会'" "例如:'10分钟后提醒我收快递''明天早上九点喊一下李四参加晨会'",
] ]
async def execute(self) -> Tuple[bool, str]: async def execute(self) -> Tuple[bool, str]:
@@ -110,7 +121,15 @@ class RemindAction(BaseAction):
event_details = self.action_data.get("event_details") event_details = self.action_data.get("event_details")
if not all([user_name, remind_time_str, event_details]): if not all([user_name, remind_time_str, event_details]):
missing_params = [p for p, v in {"user_name": user_name, "remind_time": remind_time_str, "event_details": event_details}.items() if not v] missing_params = [
p
for p, v in {
"user_name": user_name,
"remind_time": remind_time_str,
"event_details": event_details,
}.items()
if not v
]
error_msg = f"缺少必要的提醒参数: {', '.join(missing_params)}" error_msg = f"缺少必要的提醒参数: {', '.join(missing_params)}"
logger.warning(f"[ReminderPlugin] LLM未能提取完整参数: {error_msg}") logger.warning(f"[ReminderPlugin] LLM未能提取完整参数: {error_msg}")
return False, error_msg return False, error_msg
@@ -162,7 +181,7 @@ class RemindAction(BaseAction):
target_user_id=str(user_id_to_remind), target_user_id=str(user_id_to_remind),
target_user_name=str(user_name_to_remind), target_user_name=str(user_name_to_remind),
event_details=str(event_details), event_details=str(event_details),
creator_name=str(self.user_nickname) creator_name=str(self.user_nickname),
) )
await async_task_manager.add_task(reminder_task) await async_task_manager.add_task(reminder_task)
@@ -179,6 +198,7 @@ class RemindAction(BaseAction):
# =============================== Plugin =============================== # =============================== Plugin ===============================
@register_plugin @register_plugin
class ReminderPlugin(BasePlugin): class ReminderPlugin(BasePlugin):
"""一个能从对话中智能识别并设置定时提醒的插件。""" """一个能从对话中智能识别并设置定时提醒的插件。"""
@@ -193,6 +213,4 @@ class ReminderPlugin(BasePlugin):
def get_plugin_components(self) -> List[Tuple[ActionInfo, Type[BaseAction]]]: def get_plugin_components(self) -> List[Tuple[ActionInfo, Type[BaseAction]]]:
"""注册插件的所有功能组件。""" """注册插件的所有功能组件。"""
return [ return [(RemindAction.get_action_info(), RemindAction)]
(RemindAction.get_action_info(), RemindAction)
]