ruff,私聊视为提及了bot

This commit is contained in:
Windpicker-owo
2025-09-20 22:34:22 +08:00
parent 006f9130b9
commit 444f1ca315
76 changed files with 1066 additions and 882 deletions

25
bot.py
View File

@@ -34,16 +34,18 @@ script_dir = os.path.dirname(os.path.abspath(__file__))
os.chdir(script_dir) os.chdir(script_dir)
logger.info(f"已设置工作目录为: {script_dir}") logger.info(f"已设置工作目录为: {script_dir}")
# 检查并创建.env文件 # 检查并创建.env文件
def ensure_env_file(): def ensure_env_file():
"""确保.env文件存在如果不存在则从模板创建""" """确保.env文件存在如果不存在则从模板创建"""
env_file = Path(".env") env_file = Path(".env")
template_env = Path("template/template.env") template_env = Path("template/template.env")
if not env_file.exists(): if not env_file.exists():
if template_env.exists(): if template_env.exists():
logger.info("未找到.env文件正在从模板创建...") logger.info("未找到.env文件正在从模板创建...")
import shutil import shutil
shutil.copy(template_env, env_file) shutil.copy(template_env, env_file)
logger.info("已从template/template.env创建.env文件") logger.info("已从template/template.env创建.env文件")
logger.warning("请编辑.env文件将EULA_CONFIRMED设置为true并配置其他必要参数") logger.warning("请编辑.env文件将EULA_CONFIRMED设置为true并配置其他必要参数")
@@ -51,6 +53,7 @@ def ensure_env_file():
logger.error("未找到.env文件和template.env模板文件") logger.error("未找到.env文件和template.env模板文件")
sys.exit(1) sys.exit(1)
# 确保环境文件存在 # 确保环境文件存在
ensure_env_file() ensure_env_file()
@@ -130,32 +133,32 @@ async def graceful_shutdown():
def check_eula(): def check_eula():
"""检查EULA和隐私条款确认状态 - 环境变量版类似Minecraft""" """检查EULA和隐私条款确认状态 - 环境变量版类似Minecraft"""
# 检查环境变量中的EULA确认 # 检查环境变量中的EULA确认
eula_confirmed = os.getenv('EULA_CONFIRMED', '').lower() eula_confirmed = os.getenv("EULA_CONFIRMED", "").lower()
if eula_confirmed == 'true': if eula_confirmed == "true":
logger.info("EULA已通过环境变量确认") logger.info("EULA已通过环境变量确认")
return return
# 如果没有确认,提示用户 # 如果没有确认,提示用户
confirm_logger.critical("您需要同意EULA和隐私条款才能使用MoFox_Bot") confirm_logger.critical("您需要同意EULA和隐私条款才能使用MoFox_Bot")
confirm_logger.critical("请阅读以下文件:") confirm_logger.critical("请阅读以下文件:")
confirm_logger.critical(" - EULA.md (用户许可协议)") confirm_logger.critical(" - EULA.md (用户许可协议)")
confirm_logger.critical(" - PRIVACY.md (隐私条款)") confirm_logger.critical(" - PRIVACY.md (隐私条款)")
confirm_logger.critical("然后编辑 .env 文件,将 'EULA_CONFIRMED=false' 改为 'EULA_CONFIRMED=true'") confirm_logger.critical("然后编辑 .env 文件,将 'EULA_CONFIRMED=false' 改为 'EULA_CONFIRMED=true'")
# 等待用户确认 # 等待用户确认
while True: while True:
try: try:
load_dotenv(override=True) # 重新加载.env文件 load_dotenv(override=True) # 重新加载.env文件
eula_confirmed = os.getenv('EULA_CONFIRMED', '').lower() eula_confirmed = os.getenv("EULA_CONFIRMED", "").lower()
if eula_confirmed == 'true': if eula_confirmed == "true":
confirm_logger.info("EULA确认成功感谢您的同意") confirm_logger.info("EULA确认成功感谢您的同意")
return return
confirm_logger.critical("请修改 .env 文件中的 EULA_CONFIRMED=true 后重新启动程序") confirm_logger.critical("请修改 .env 文件中的 EULA_CONFIRMED=true 后重新启动程序")
input("按Enter键检查.env文件状态...") input("按Enter键检查.env文件状态...")
except KeyboardInterrupt: except KeyboardInterrupt:
confirm_logger.info("用户取消,程序退出") confirm_logger.info("用户取消,程序退出")
sys.exit(0) sys.exit(0)

View File

@@ -20,25 +20,26 @@ files_to_update = [
"src/mais4u/mais4u_chat/s4u_mood_manager.py", "src/mais4u/mais4u_chat/s4u_mood_manager.py",
"src/plugin_system/core/tool_use.py", "src/plugin_system/core/tool_use.py",
"src/chat/memory_system/memory_activator.py", "src/chat/memory_system/memory_activator.py",
"src/chat/utils/smart_prompt.py" "src/chat/utils/smart_prompt.py",
] ]
def update_prompt_imports(file_path): def update_prompt_imports(file_path):
"""更新文件中的Prompt导入""" """更新文件中的Prompt导入"""
if not os.path.exists(file_path): if not os.path.exists(file_path):
print(f"文件不存在: {file_path}") print(f"文件不存在: {file_path}")
return False return False
with open(file_path, 'r', encoding='utf-8') as f: with open(file_path, "r", encoding="utf-8") as f:
content = f.read() content = f.read()
# 替换导入语句 # 替换导入语句
old_import = "from src.chat.utils.prompt_builder import Prompt, global_prompt_manager" old_import = "from src.chat.utils.prompt_builder import Prompt, global_prompt_manager"
new_import = "from src.chat.utils.prompt import Prompt, global_prompt_manager" new_import = "from src.chat.utils.prompt import Prompt, global_prompt_manager"
if old_import in content: if old_import in content:
new_content = content.replace(old_import, new_import) new_content = content.replace(old_import, new_import)
with open(file_path, 'w', encoding='utf-8') as f: with open(file_path, "w", encoding="utf-8") as f:
f.write(new_content) f.write(new_content)
print(f"已更新: {file_path}") print(f"已更新: {file_path}")
return True return True
@@ -46,14 +47,16 @@ def update_prompt_imports(file_path):
print(f"无需更新: {file_path}") print(f"无需更新: {file_path}")
return False return False
def main(): def main():
"""主函数""" """主函数"""
updated_count = 0 updated_count = 0
for file_path in files_to_update: for file_path in files_to_update:
if update_prompt_imports(file_path): if update_prompt_imports(file_path):
updated_count += 1 updated_count += 1
print(f"\n更新完成!共更新了 {updated_count} 个文件") print(f"\n更新完成!共更新了 {updated_count} 个文件")
if __name__ == "__main__": if __name__ == "__main__":
main() main()

View File

@@ -5,4 +5,4 @@
from src.chat.affinity_flow.afc_manager import afc_manager from src.chat.affinity_flow.afc_manager import afc_manager
__all__ = ['afc_manager', 'AFCManager', 'AffinityFlowChatter'] __all__ = ["afc_manager", "AFCManager", "AffinityFlowChatter"]

View File

@@ -2,6 +2,7 @@
亲和力聊天处理流管理器 亲和力聊天处理流管理器
管理不同聊天流的亲和力聊天处理流,统一获取新消息并分发到对应的亲和力聊天处理流 管理不同聊天流的亲和力聊天处理流,统一获取新消息并分发到对应的亲和力聊天处理流
""" """
import time import time
import traceback import traceback
from typing import Dict, Optional, List from typing import Dict, Optional, List
@@ -20,7 +21,7 @@ class AFCManager:
def __init__(self): def __init__(self):
self.affinity_flow_chatters: Dict[str, "AffinityFlowChatter"] = {} self.affinity_flow_chatters: Dict[str, "AffinityFlowChatter"] = {}
'''所有聊天流的亲和力聊天处理流stream_id -> affinity_flow_chatter''' """所有聊天流的亲和力聊天处理流stream_id -> affinity_flow_chatter"""
# 动作管理器 # 动作管理器
self.action_manager = ActionManager() self.action_manager = ActionManager()
@@ -40,11 +41,7 @@ class AFCManager:
# 创建增强版规划器 # 创建增强版规划器
planner = ActionPlanner(stream_id, self.action_manager) planner = ActionPlanner(stream_id, self.action_manager)
chatter = AffinityFlowChatter( chatter = AffinityFlowChatter(stream_id=stream_id, planner=planner, action_manager=self.action_manager)
stream_id=stream_id,
planner=planner,
action_manager=self.action_manager
)
self.affinity_flow_chatters[stream_id] = chatter self.affinity_flow_chatters[stream_id] = chatter
logger.info(f"创建新的亲和力聊天处理器: {stream_id}") logger.info(f"创建新的亲和力聊天处理器: {stream_id}")
@@ -74,7 +71,6 @@ class AFCManager:
"executed_count": 0, "executed_count": 0,
} }
def get_chatter_stats(self, stream_id: str) -> Optional[Dict[str, any]]: def get_chatter_stats(self, stream_id: str) -> Optional[Dict[str, any]]:
"""获取聊天处理器统计""" """获取聊天处理器统计"""
if stream_id in self.affinity_flow_chatters: if stream_id in self.affinity_flow_chatters:
@@ -131,4 +127,5 @@ class AFCManager:
self.affinity_flow_chatters[stream_id].update_interest_keywords(new_keywords) self.affinity_flow_chatters[stream_id].update_interest_keywords(new_keywords)
logger.info(f"已更新聊天流 {stream_id} 的兴趣关键词: {list(new_keywords.keys())}") logger.info(f"已更新聊天流 {stream_id} 的兴趣关键词: {list(new_keywords.keys())}")
afc_manager = AFCManager()
afc_manager = AFCManager()

View File

@@ -2,6 +2,7 @@
亲和力聊天处理器 亲和力聊天处理器
单个聊天流的处理器,负责处理特定聊天流的完整交互流程 单个聊天流的处理器,负责处理特定聊天流的完整交互流程
""" """
import time import time
import traceback import traceback
from datetime import datetime from datetime import datetime
@@ -57,10 +58,7 @@ class AffinityFlowChatter:
unread_messages = context.get_unread_messages() unread_messages = context.get_unread_messages()
# 使用增强版规划器处理消息 # 使用增强版规划器处理消息
actions, target_message = await self.planner.plan( actions, target_message = await self.planner.plan(mode=ChatMode.FOCUS, context=context)
mode=ChatMode.FOCUS,
context=context
)
self.stats["plans_created"] += 1 self.stats["plans_created"] += 1
# 执行动作(如果规划器返回了动作) # 执行动作(如果规划器返回了动作)
@@ -84,7 +82,9 @@ class AffinityFlowChatter:
**execution_result, **execution_result,
} }
logger.info(f"聊天流 {self.stream_id} StreamContext处理成功: 动作数={result['actions_count']}, 未读消息={result['unread_messages_processed']}") logger.info(
f"聊天流 {self.stream_id} StreamContext处理成功: 动作数={result['actions_count']}, 未读消息={result['unread_messages_processed']}"
)
return result return result
@@ -197,7 +197,9 @@ class AffinityFlowChatter:
def __repr__(self) -> str: def __repr__(self) -> str:
"""详细字符串表示""" """详细字符串表示"""
return (f"AffinityFlowChatter(stream_id={self.stream_id}, " return (
f"messages_processed={self.stats['messages_processed']}, " f"AffinityFlowChatter(stream_id={self.stream_id}, "
f"plans_created={self.stats['plans_created']}, " f"messages_processed={self.stats['messages_processed']}, "
f"last_activity={datetime.fromtimestamp(self.last_activity_time)})") f"plans_created={self.stats['plans_created']}, "
f"last_activity={datetime.fromtimestamp(self.last_activity_time)})"
)

View File

@@ -38,7 +38,9 @@ class InterestScoringSystem:
# 连续不回复概率提升 # 连续不回复概率提升
self.no_reply_count = 0 self.no_reply_count = 0
self.max_no_reply_count = affinity_config.max_no_reply_count self.max_no_reply_count = affinity_config.max_no_reply_count
self.probability_boost_per_no_reply = affinity_config.no_reply_threshold_adjustment / affinity_config.max_no_reply_count # 每次不回复增加的概率 self.probability_boost_per_no_reply = (
affinity_config.no_reply_threshold_adjustment / affinity_config.max_no_reply_count
) # 每次不回复增加的概率
# 用户关系数据 # 用户关系数据
self.user_relationships: Dict[str, float] = {} # user_id -> relationship_score self.user_relationships: Dict[str, float] = {} # user_id -> relationship_score
@@ -153,7 +155,9 @@ class InterestScoringSystem:
# 返回匹配分数,考虑置信度和匹配标签数量 # 返回匹配分数,考虑置信度和匹配标签数量
affinity_config = global_config.affinity_flow affinity_config = global_config.affinity_flow
match_count_bonus = min(len(match_result.matched_tags) * affinity_config.match_count_bonus, affinity_config.max_match_bonus) match_count_bonus = min(
len(match_result.matched_tags) * affinity_config.match_count_bonus, affinity_config.max_match_bonus
)
final_score = match_result.overall_score * 1.15 * match_result.confidence + match_count_bonus final_score = match_result.overall_score * 1.15 * match_result.confidence + match_count_bonus
logger.debug( logger.debug(
f"⚖️ 最终分数计算: 总分({match_result.overall_score:.3f}) × 1.3 × 置信度({match_result.confidence:.3f}) + 标签数量奖励({match_count_bonus:.3f}) = {final_score:.3f}" f"⚖️ 最终分数计算: 总分({match_result.overall_score:.3f}) × 1.3 × 置信度({match_result.confidence:.3f}) + 标签数量奖励({match_count_bonus:.3f}) = {final_score:.3f}"
@@ -263,7 +267,17 @@ class InterestScoringSystem:
if not msg.processed_plain_text: if not msg.processed_plain_text:
return 0.0 return 0.0
if msg.is_mentioned or (bot_nickname and bot_nickname in msg.processed_plain_text): # 检查是否被提及
is_mentioned = msg.is_mentioned or (bot_nickname and bot_nickname in msg.processed_plain_text)
# 检查是否为私聊group_info为None表示私聊
is_private_chat = msg.group_info is None
# 如果被提及或是私聊都视为提及了bot
if is_mentioned or is_private_chat:
logger.debug(f"🔍 提及检测 - 被提及: {is_mentioned}, 私聊: {is_private_chat}")
if is_private_chat and not is_mentioned:
logger.debug("💬 私聊消息自动视为提及bot")
return global_config.affinity_flow.mention_bot_interest_score return global_config.affinity_flow.mention_bot_interest_score
return 0.0 return 0.0
@@ -282,7 +296,9 @@ class InterestScoringSystem:
logger.debug(f"📋 基础阈值: {base_threshold:.3f}") logger.debug(f"📋 基础阈值: {base_threshold:.3f}")
# 如果被提及,降低阈值 # 如果被提及,降低阈值
if score.mentioned_score >= global_config.affinity_flow.mention_bot_interest_score * 0.5: # 使用提及bot兴趣分的一半作为判断阈值 if (
score.mentioned_score >= global_config.affinity_flow.mention_bot_interest_score * 0.5
): # 使用提及bot兴趣分的一半作为判断阈值
base_threshold = self.mention_threshold base_threshold = self.mention_threshold
logger.debug(f"📣 消息提及了机器人,使用降低阈值: {base_threshold:.3f}") logger.debug(f"📣 消息提及了机器人,使用降低阈值: {base_threshold:.3f}")
@@ -325,7 +341,9 @@ class InterestScoringSystem:
def update_user_relationship(self, user_id: str, relationship_change: float): def update_user_relationship(self, user_id: str, relationship_change: float):
"""更新用户关系""" """更新用户关系"""
old_score = self.user_relationships.get(user_id, global_config.affinity_flow.base_relationship_score) # 默认新用户分数 old_score = self.user_relationships.get(
user_id, global_config.affinity_flow.base_relationship_score
) # 默认新用户分数
new_score = max(0.0, min(1.0, old_score + relationship_change)) new_score = max(0.0, min(1.0, old_score + relationship_change))
self.user_relationships[user_id] = new_score self.user_relationships[user_id] = new_score

View File

@@ -116,6 +116,7 @@ class UserRelationshipTracker:
try: try:
# 获取bot人设信息 # 获取bot人设信息
from src.individuality.individuality import Individuality from src.individuality.individuality import Individuality
individuality = Individuality() individuality = Individuality()
bot_personality = await individuality.get_personality_block() bot_personality = await individuality.get_personality_block()
@@ -168,7 +169,17 @@ class UserRelationshipTracker:
# 清理LLM响应移除可能的格式标记 # 清理LLM响应移除可能的格式标记
cleaned_response = self._clean_llm_json_response(llm_response) cleaned_response = self._clean_llm_json_response(llm_response)
response_data = json.loads(cleaned_response) response_data = json.loads(cleaned_response)
new_score = max(0.0, min(1.0, float(response_data.get("new_relationship_score", global_config.affinity_flow.base_relationship_score)))) new_score = max(
0.0,
min(
1.0,
float(
response_data.get(
"new_relationship_score", global_config.affinity_flow.base_relationship_score
)
),
),
)
if self.interest_scoring_system: if self.interest_scoring_system:
self.interest_scoring_system.update_user_relationship( self.interest_scoring_system.update_user_relationship(
@@ -295,7 +306,9 @@ class UserRelationshipTracker:
# 更新缓存 # 更新缓存
self.user_relationship_cache[user_id] = { self.user_relationship_cache[user_id] = {
"relationship_text": relationship_data.get("relationship_text", ""), "relationship_text": relationship_data.get("relationship_text", ""),
"relationship_score": relationship_data.get("relationship_score", global_config.affinity_flow.base_relationship_score), "relationship_score": relationship_data.get(
"relationship_score", global_config.affinity_flow.base_relationship_score
),
"last_tracked": time.time(), "last_tracked": time.time(),
} }
return relationship_data.get("relationship_score", global_config.affinity_flow.base_relationship_score) return relationship_data.get("relationship_score", global_config.affinity_flow.base_relationship_score)
@@ -386,7 +399,11 @@ class UserRelationshipTracker:
# 获取当前关系数据 # 获取当前关系数据
current_relationship = self._get_user_relationship_from_db(user_id) current_relationship = self._get_user_relationship_from_db(user_id)
current_score = current_relationship.get("relationship_score", global_config.affinity_flow.base_relationship_score) if current_relationship else global_config.affinity_flow.base_relationship_score current_score = (
current_relationship.get("relationship_score", global_config.affinity_flow.base_relationship_score)
if current_relationship
else global_config.affinity_flow.base_relationship_score
)
current_text = current_relationship.get("relationship_text", "新用户") if current_relationship else "新用户" current_text = current_relationship.get("relationship_text", "新用户") if current_relationship else "新用户"
# 使用LLM分析并更新关系 # 使用LLM分析并更新关系
@@ -501,6 +518,7 @@ class UserRelationshipTracker:
# 获取bot人设信息 # 获取bot人设信息
from src.individuality.individuality import Individuality from src.individuality.individuality import Individuality
individuality = Individuality() individuality = Individuality()
bot_personality = await individuality.get_personality_block() bot_personality = await individuality.get_personality_block()

View File

@@ -2,6 +2,7 @@
""" """
表情包发送历史记录模块 表情包发送历史记录模块
""" """
import os import os
from typing import List, Dict from typing import List, Dict
from collections import deque from collections import deque
@@ -26,15 +27,15 @@ def add_emoji_to_history(chat_id: str, emoji_description: str):
""" """
if not chat_id or not emoji_description: if not chat_id or not emoji_description:
return return
# 如果当前聊天还没有历史记录,则创建一个新的 deque # 如果当前聊天还没有历史记录,则创建一个新的 deque
if chat_id not in _history_cache: if chat_id not in _history_cache:
_history_cache[chat_id] = deque(maxlen=MAX_HISTORY_SIZE) _history_cache[chat_id] = deque(maxlen=MAX_HISTORY_SIZE)
# 添加新表情到历史记录 # 添加新表情到历史记录
history = _history_cache[chat_id] history = _history_cache[chat_id]
history.append(emoji_description) history.append(emoji_description)
logger.debug(f"已将表情 '{emoji_description}' 添加到聊天 {chat_id} 的内存历史中") logger.debug(f"已将表情 '{emoji_description}' 添加到聊天 {chat_id} 的内存历史中")
@@ -50,10 +51,10 @@ def get_recent_emojis(chat_id: str, limit: int = 5) -> List[str]:
return [] return []
history = _history_cache[chat_id] history = _history_cache[chat_id]
# 从 deque 的右侧(即最近添加的)开始取 # 从 deque 的右侧(即最近添加的)开始取
num_to_get = min(limit, len(history)) num_to_get = min(limit, len(history))
recent_emojis = [history[-i] for i in range(1, num_to_get + 1)] recent_emojis = [history[-i] for i in range(1, num_to_get + 1)]
logger.debug(f"为聊天 {chat_id} 从内存中获取到最近 {len(recent_emojis)} 个表情: {recent_emojis}") logger.debug(f"为聊天 {chat_id} 从内存中获取到最近 {len(recent_emojis)} 个表情: {recent_emojis}")
return recent_emojis return recent_emojis

View File

@@ -477,7 +477,7 @@ class EmojiManager:
emoji_options_str = "" emoji_options_str = ""
for i, emoji in enumerate(candidate_emojis): for i, emoji in enumerate(candidate_emojis):
# 为每个表情包创建一个编号和它的详细描述 # 为每个表情包创建一个编号和它的详细描述
emoji_options_str += f"编号: {i+1}\n描述: {emoji.description}\n\n" emoji_options_str += f"编号: {i + 1}\n描述: {emoji.description}\n\n"
# 精心设计的prompt引导LLM做出选择 # 精心设计的prompt引导LLM做出选择
prompt = f""" prompt = f"""
@@ -524,10 +524,8 @@ class EmojiManager:
self.record_usage(selected_emoji.hash) self.record_usage(selected_emoji.hash)
_time_end = time.time() _time_end = time.time()
logger.info( logger.info(f"找到匹配描述的表情包: {selected_emoji.description}, 耗时: {(_time_end - _time_start):.2f}s")
f"找到匹配描述的表情包: {selected_emoji.description}, 耗时: {(_time_end - _time_start):.2f}s"
)
# 8. 返回选中的表情包信息 # 8. 返回选中的表情包信息
return selected_emoji.full_path, f"[表情包:{selected_emoji.description}]", text_emotion return selected_emoji.full_path, f"[表情包:{selected_emoji.description}]", text_emotion
@@ -627,8 +625,9 @@ class EmojiManager:
# 无论steal_emoji是否开启都检查emoji文件夹以支持手动注册 # 无论steal_emoji是否开启都检查emoji文件夹以支持手动注册
# 只有在需要腾出空间或填充表情库时,才真正执行注册 # 只有在需要腾出空间或填充表情库时,才真正执行注册
if (self.emoji_num > self.emoji_num_max and global_config.emoji.do_replace) or \ if (self.emoji_num > self.emoji_num_max and global_config.emoji.do_replace) or (
(self.emoji_num < self.emoji_num_max): self.emoji_num < self.emoji_num_max
):
try: try:
# 获取目录下所有图片文件 # 获取目录下所有图片文件
files_to_process = [ files_to_process = [
@@ -931,16 +930,21 @@ class EmojiManager:
image_base64 = image_base64.encode("ascii", errors="ignore").decode("ascii") image_base64 = image_base64.encode("ascii", errors="ignore").decode("ascii")
image_bytes = base64.b64decode(image_base64) image_bytes = base64.b64decode(image_base64)
image_hash = hashlib.md5(image_bytes).hexdigest() image_hash = hashlib.md5(image_bytes).hexdigest()
image_format = Image.open(io.BytesIO(image_bytes)).format.lower() if Image.open(io.BytesIO(image_bytes)).format else "jpeg" image_format = (
Image.open(io.BytesIO(image_bytes)).format.lower()
if Image.open(io.BytesIO(image_bytes)).format
else "jpeg"
)
# 2. 检查数据库中是否已存在该表情包的描述,实现复用 # 2. 检查数据库中是否已存在该表情包的描述,实现复用
existing_description = None existing_description = None
try: try:
with get_db_session() as session: with get_db_session() as session:
existing_image = session.query(Images).filter( existing_image = (
(Images.emoji_hash == image_hash) & (Images.type == "emoji") session.query(Images)
).one_or_none() .filter((Images.emoji_hash == image_hash) & (Images.type == "emoji"))
.one_or_none()
)
if existing_image and existing_image.description: if existing_image and existing_image.description:
existing_description = existing_image.description existing_description = existing_image.description
logger.info(f"[复用描述] 找到已有详细描述: {existing_description[:50]}...") logger.info(f"[复用描述] 找到已有详细描述: {existing_description[:50]}...")

View File

@@ -14,6 +14,7 @@ Chat Frequency Analyzer
- MIN_CHATS_FOR_PEAK: 在一个窗口内需要多少次聊天才能被认为是高峰时段。 - MIN_CHATS_FOR_PEAK: 在一个窗口内需要多少次聊天才能被认为是高峰时段。
- MIN_GAP_BETWEEN_PEAKS_HOURS: 两个独立高峰时段之间的最小间隔(小时)。 - MIN_GAP_BETWEEN_PEAKS_HOURS: 两个独立高峰时段之间的最小间隔(小时)。
""" """
import time as time_module import time as time_module
from datetime import datetime, timedelta, time from datetime import datetime, timedelta, time
from typing import List, Tuple, Optional from typing import List, Tuple, Optional
@@ -71,12 +72,14 @@ class ChatFrequencyAnalyzer:
current_window_end = datetimes[i] current_window_end = datetimes[i]
# 合并重叠或相邻的高峰时段 # 合并重叠或相邻的高峰时段
if peak_windows and current_window_start - peak_windows[-1][1] < timedelta(hours=MIN_GAP_BETWEEN_PEAKS_HOURS): if peak_windows and current_window_start - peak_windows[-1][1] < timedelta(
hours=MIN_GAP_BETWEEN_PEAKS_HOURS
):
# 扩展上一个窗口的结束时间 # 扩展上一个窗口的结束时间
peak_windows[-1] = (peak_windows[-1][0], current_window_end) peak_windows[-1] = (peak_windows[-1][0], current_window_end)
else: else:
peak_windows.append((current_window_start, current_window_end)) peak_windows.append((current_window_start, current_window_end))
return peak_windows return peak_windows
def get_peak_chat_times(self, chat_id: str) -> List[Tuple[time, time]]: def get_peak_chat_times(self, chat_id: str) -> List[Tuple[time, time]]:
@@ -99,7 +102,7 @@ class ChatFrequencyAnalyzer:
return [] return []
peak_datetime_windows = self._find_peak_windows(timestamps) peak_datetime_windows = self._find_peak_windows(timestamps)
# 将 datetime 窗口转换为 time 窗口,并进行归一化处理 # 将 datetime 窗口转换为 time 窗口,并进行归一化处理
peak_time_windows = [] peak_time_windows = []
for start_dt, end_dt in peak_datetime_windows: for start_dt, end_dt in peak_datetime_windows:
@@ -109,7 +112,7 @@ class ChatFrequencyAnalyzer:
# 更新缓存 # 更新缓存
self._analysis_cache[chat_id] = (time_module.time(), peak_time_windows) self._analysis_cache[chat_id] = (time_module.time(), peak_time_windows)
return peak_time_windows return peak_time_windows
def is_in_peak_time(self, chat_id: str, now: Optional[datetime] = None) -> bool: def is_in_peak_time(self, chat_id: str, now: Optional[datetime] = None) -> bool:
@@ -125,7 +128,7 @@ class ChatFrequencyAnalyzer:
""" """
if now is None: if now is None:
now = datetime.now() now = datetime.now()
now_time = now.time() now_time = now.time()
peak_times = self.get_peak_chat_times(chat_id) peak_times = self.get_peak_chat_times(chat_id)
@@ -136,7 +139,7 @@ class ChatFrequencyAnalyzer:
else: # 跨天 else: # 跨天
if now_time >= start_time or now_time <= end_time: if now_time >= start_time or now_time <= end_time:
return True return True
return False return False

View File

@@ -55,7 +55,7 @@ class ChatFrequencyTracker:
now = time.time() now = time.time()
if chat_id not in self._timestamps: if chat_id not in self._timestamps:
self._timestamps[chat_id] = [] self._timestamps[chat_id] = []
self._timestamps[chat_id].append(now) self._timestamps[chat_id].append(now)
logger.debug(f"为 chat_id '{chat_id}' 记录了新的聊天时间: {now}") logger.debug(f"为 chat_id '{chat_id}' 记录了新的聊天时间: {now}")
self._save_timestamps() self._save_timestamps()

View File

@@ -14,6 +14,7 @@ Frequency-Based Proactive Trigger
- TRIGGER_CHECK_INTERVAL_SECONDS: 触发器检查的周期(秒)。 - TRIGGER_CHECK_INTERVAL_SECONDS: 触发器检查的周期(秒)。
- COOLDOWN_HOURS: 在同一个高峰时段内触发一次后的冷却时间(小时)。 - COOLDOWN_HOURS: 在同一个高峰时段内触发一次后的冷却时间(小时)。
""" """
import asyncio import asyncio
import time import time
from datetime import datetime from datetime import datetime
@@ -21,6 +22,7 @@ from typing import Dict, Optional
from src.common.logger import get_logger from src.common.logger import get_logger
from src.chat.affinity_flow.afc_manager import afc_manager from src.chat.affinity_flow.afc_manager import afc_manager
# TODO: 需要重新实现主动思考和睡眠管理功能 # TODO: 需要重新实现主动思考和睡眠管理功能
from .analyzer import chat_frequency_analyzer from .analyzer import chat_frequency_analyzer
@@ -65,7 +67,7 @@ class FrequencyBasedTrigger:
continue continue
now = datetime.now() now = datetime.now()
for chat_id in all_chat_ids: for chat_id in all_chat_ids:
# 3. 检查是否处于冷却时间内 # 3. 检查是否处于冷却时间内
last_triggered_time = self._last_triggered.get(chat_id, 0) last_triggered_time = self._last_triggered.get(chat_id, 0)
@@ -74,7 +76,6 @@ class FrequencyBasedTrigger:
# 4. 检查当前是否是该用户的高峰聊天时间 # 4. 检查当前是否是该用户的高峰聊天时间
if chat_frequency_analyzer.is_in_peak_time(chat_id, now): if chat_frequency_analyzer.is_in_peak_time(chat_id, now):
# 5. 检查用户当前是否已有活跃的处理任务 # 5. 检查用户当前是否已有活跃的处理任务
# 亲和力流系统不直接提供循环状态,通过检查最后活动时间来判断是否忙碌 # 亲和力流系统不直接提供循环状态,通过检查最后活动时间来判断是否忙碌
chatter = afc_manager.get_or_create_chatter(chat_id) chatter = afc_manager.get_or_create_chatter(chat_id)
@@ -87,13 +88,13 @@ class FrequencyBasedTrigger:
if current_time - chatter.get_activity_time() < 60: if current_time - chatter.get_activity_time() < 60:
logger.debug(f"用户 {chat_id} 的亲和力处理器正忙,本次不触发。") logger.debug(f"用户 {chat_id} 的亲和力处理器正忙,本次不触发。")
continue continue
logger.info(f"检测到用户 {chat_id} 处于聊天高峰期,且处理器空闲,准备触发主动思考。") logger.info(f"检测到用户 {chat_id} 处于聊天高峰期,且处理器空闲,准备触发主动思考。")
# 6. TODO: 亲和力流系统的主动思考机制需要另行实现 # 6. TODO: 亲和力流系统的主动思考机制需要另行实现
# 目前先记录日志,等待后续实现 # 目前先记录日志,等待后续实现
logger.info(f"用户 {chat_id} 处于高峰期,但亲和力流的主动思考功能暂未实现") logger.info(f"用户 {chat_id} 处于高峰期,但亲和力流的主动思考功能暂未实现")
# 7. 更新触发时间,进入冷却 # 7. 更新触发时间,进入冷却
self._last_triggered[chat_id] = time.time() self._last_triggered[chat_id] = time.time()

View File

@@ -4,14 +4,12 @@
""" """
from .bot_interest_manager import BotInterestManager, bot_interest_manager from .bot_interest_manager import BotInterestManager, bot_interest_manager
from src.common.data_models.bot_interest_data_model import ( from src.common.data_models.bot_interest_data_model import BotInterestTag, BotPersonalityInterests, InterestMatchResult
BotInterestTag, BotPersonalityInterests, InterestMatchResult
)
__all__ = [ __all__ = [
"BotInterestManager", "BotInterestManager",
"bot_interest_manager", "bot_interest_manager",
"BotInterestTag", "BotInterestTag",
"BotPersonalityInterests", "BotPersonalityInterests",
"InterestMatchResult" "InterestMatchResult",
] ]

View File

@@ -2,6 +2,7 @@
机器人兴趣标签管理系统 机器人兴趣标签管理系统
基于人设生成兴趣标签并使用embedding计算匹配度 基于人设生成兴趣标签并使用embedding计算匹配度
""" """
import orjson import orjson
import traceback import traceback
from typing import List, Dict, Optional, Any from typing import List, Dict, Optional, Any
@@ -10,9 +11,7 @@ import numpy as np
from src.common.logger import get_logger from src.common.logger import get_logger
from src.config.config import global_config from src.config.config import global_config
from src.common.data_models.bot_interest_data_model import ( from src.common.data_models.bot_interest_data_model import BotPersonalityInterests, BotInterestTag, InterestMatchResult
BotPersonalityInterests, BotInterestTag, InterestMatchResult
)
logger = get_logger("bot_interest_manager") logger = get_logger("bot_interest_manager")
@@ -87,7 +86,7 @@ class BotInterestManager:
logger.debug("✅ 成功导入embedding相关模块") logger.debug("✅ 成功导入embedding相关模块")
# 检查embedding配置是否存在 # 检查embedding配置是否存在
if not hasattr(model_config.model_task_config, 'embedding'): if not hasattr(model_config.model_task_config, "embedding"):
raise RuntimeError("❌ 未找到embedding模型配置") raise RuntimeError("❌ 未找到embedding模型配置")
logger.info("📋 找到embedding模型配置") logger.info("📋 找到embedding模型配置")
@@ -101,7 +100,7 @@ class BotInterestManager:
logger.info(f"🔗 客户端类型: {type(self.embedding_request).__name__}") logger.info(f"🔗 客户端类型: {type(self.embedding_request).__name__}")
# 获取第一个embedding模型的ModelInfo # 获取第一个embedding模型的ModelInfo
if hasattr(self.embedding_config, 'model_list') and self.embedding_config.model_list: if hasattr(self.embedding_config, "model_list") and self.embedding_config.model_list:
first_model_name = self.embedding_config.model_list[0] first_model_name = self.embedding_config.model_list[0]
logger.info(f"🎯 使用embedding模型: {first_model_name}") logger.info(f"🎯 使用embedding模型: {first_model_name}")
else: else:
@@ -127,7 +126,9 @@ class BotInterestManager:
# 生成新的兴趣标签 # 生成新的兴趣标签
logger.info("🆕 数据库中未找到兴趣标签,开始生成新的...") logger.info("🆕 数据库中未找到兴趣标签,开始生成新的...")
logger.info("🤖 正在调用LLM生成个性化兴趣标签...") logger.info("🤖 正在调用LLM生成个性化兴趣标签...")
generated_interests = await self._generate_interests_from_personality(personality_description, personality_id) generated_interests = await self._generate_interests_from_personality(
personality_description, personality_id
)
if generated_interests: if generated_interests:
self.current_interests = generated_interests self.current_interests = generated_interests
@@ -140,14 +141,16 @@ class BotInterestManager:
else: else:
raise RuntimeError("❌ 兴趣标签生成失败") raise RuntimeError("❌ 兴趣标签生成失败")
async def _generate_interests_from_personality(self, personality_description: str, personality_id: str) -> Optional[BotPersonalityInterests]: async def _generate_interests_from_personality(
self, personality_description: str, personality_id: str
) -> Optional[BotPersonalityInterests]:
"""根据人设生成兴趣标签""" """根据人设生成兴趣标签"""
try: try:
logger.info("🎨 开始根据人设生成兴趣标签...") logger.info("🎨 开始根据人设生成兴趣标签...")
logger.info(f"📝 人设长度: {len(personality_description)} 字符") logger.info(f"📝 人设长度: {len(personality_description)} 字符")
# 检查embedding客户端是否可用 # 检查embedding客户端是否可用
if not hasattr(self, 'embedding_request'): if not hasattr(self, "embedding_request"):
raise RuntimeError("❌ Embedding客户端未初始化无法生成兴趣标签") raise RuntimeError("❌ Embedding客户端未初始化无法生成兴趣标签")
# 构建提示词 # 构建提示词
@@ -190,8 +193,7 @@ class BotInterestManager:
interests_data = orjson.loads(response) interests_data = orjson.loads(response)
bot_interests = BotPersonalityInterests( bot_interests = BotPersonalityInterests(
personality_id=personality_id, personality_id=personality_id, personality_description=personality_description
personality_description=personality_description
) )
# 解析生成的兴趣标签 # 解析生成的兴趣标签
@@ -202,10 +204,7 @@ class BotInterestManager:
tag_name = tag_data.get("name", f"标签_{i}") tag_name = tag_data.get("name", f"标签_{i}")
weight = tag_data.get("weight", 0.5) weight = tag_data.get("weight", 0.5)
tag = BotInterestTag( tag = BotInterestTag(tag_name=tag_name, weight=weight)
tag_name=tag_name,
weight=weight
)
bot_interests.interest_tags.append(tag) bot_interests.interest_tags.append(tag)
logger.debug(f" 🏷️ {tag_name} (权重: {weight:.2f})") logger.debug(f" 🏷️ {tag_name} (权重: {weight:.2f})")
@@ -225,7 +224,6 @@ class BotInterestManager:
traceback.print_exc() traceback.print_exc()
raise raise
async def _call_llm_for_interest_generation(self, prompt: str) -> Optional[str]: async def _call_llm_for_interest_generation(self, prompt: str) -> Optional[str]:
"""调用LLM生成兴趣标签""" """调用LLM生成兴趣标签"""
try: try:
@@ -241,10 +239,10 @@ class BotInterestManager:
{prompt} {prompt}
请确保返回格式为有效的JSON不要包含任何额外的文本、解释或代码块标记。只返回JSON对象本身。""" 请确保返回格式为有效的JSON不要包含任何额外的文本、解释或代码块标记。只返回JSON对象本身。"""
# 使用replyer模型配置 # 使用replyer模型配置
replyer_config = model_config.model_task_config.replyer replyer_config = model_config.model_task_config.replyer
# 调用LLM API # 调用LLM API
logger.info("🚀 正在通过LLM API发送请求...") logger.info("🚀 正在通过LLM API发送请求...")
success, response, reasoning_content, model_name = await llm_api.generate_with_model( success, response, reasoning_content, model_name = await llm_api.generate_with_model(
@@ -252,15 +250,17 @@ class BotInterestManager:
model_config=replyer_config, model_config=replyer_config,
request_type="interest_generation", request_type="interest_generation",
temperature=0.7, temperature=0.7,
max_tokens=2000 max_tokens=2000,
) )
if success and response: if success and response:
logger.info(f"✅ LLM响应成功模型: {model_name}, 响应长度: {len(response)} 字符") logger.info(f"✅ LLM响应成功模型: {model_name}, 响应长度: {len(response)} 字符")
logger.debug(f"📄 LLM响应内容: {response[:200]}..." if len(response) > 200 else f"📄 LLM响应内容: {response}") logger.debug(
f"📄 LLM响应内容: {response[:200]}..." if len(response) > 200 else f"📄 LLM响应内容: {response}"
)
if reasoning_content: if reasoning_content:
logger.debug(f"🧠 推理内容: {reasoning_content[:100]}...") logger.debug(f"🧠 推理内容: {reasoning_content[:100]}...")
# 清理响应内容,移除可能的代码块标记 # 清理响应内容,移除可能的代码块标记
cleaned_response = self._clean_llm_response(response) cleaned_response = self._clean_llm_response(response)
return cleaned_response return cleaned_response
@@ -277,25 +277,25 @@ class BotInterestManager:
def _clean_llm_response(self, response: str) -> str: def _clean_llm_response(self, response: str) -> str:
"""清理LLM响应移除代码块标记和其他非JSON内容""" """清理LLM响应移除代码块标记和其他非JSON内容"""
import re import re
# 移除 ```json 和 ``` 标记 # 移除 ```json 和 ``` 标记
cleaned = re.sub(r'```json\s*', '', response) cleaned = re.sub(r"```json\s*", "", response)
cleaned = re.sub(r'\s*```', '', cleaned) cleaned = re.sub(r"\s*```", "", cleaned)
# 移除可能的多余空格和换行 # 移除可能的多余空格和换行
cleaned = cleaned.strip() cleaned = cleaned.strip()
# 尝试提取JSON对象如果响应中有其他文本 # 尝试提取JSON对象如果响应中有其他文本
json_match = re.search(r'\{.*\}', cleaned, re.DOTALL) json_match = re.search(r"\{.*\}", cleaned, re.DOTALL)
if json_match: if json_match:
cleaned = json_match.group(0) cleaned = json_match.group(0)
logger.debug(f"🧹 清理后的响应: {cleaned[:200]}..." if len(cleaned) > 200 else f"🧹 清理后的响应: {cleaned}") logger.debug(f"🧹 清理后的响应: {cleaned[:200]}..." if len(cleaned) > 200 else f"🧹 清理后的响应: {cleaned}")
return cleaned return cleaned
async def _generate_embeddings_for_tags(self, interests: BotPersonalityInterests): async def _generate_embeddings_for_tags(self, interests: BotPersonalityInterests):
"""为所有兴趣标签生成embedding""" """为所有兴趣标签生成embedding"""
if not hasattr(self, 'embedding_request'): if not hasattr(self, "embedding_request"):
raise RuntimeError("❌ Embedding客户端未初始化无法生成embedding") raise RuntimeError("❌ Embedding客户端未初始化无法生成embedding")
total_tags = len(interests.interest_tags) total_tags = len(interests.interest_tags)
@@ -342,7 +342,7 @@ class BotInterestManager:
async def _get_embedding(self, text: str) -> List[float]: async def _get_embedding(self, text: str) -> List[float]:
"""获取文本的embedding向量""" """获取文本的embedding向量"""
if not hasattr(self, 'embedding_request'): if not hasattr(self, "embedding_request"):
raise RuntimeError("❌ Embedding请求客户端未初始化") raise RuntimeError("❌ Embedding请求客户端未初始化")
# 检查缓存 # 检查缓存
@@ -376,7 +376,9 @@ class BotInterestManager:
logger.debug(f"✅ 消息embedding生成成功维度: {len(embedding)}") logger.debug(f"✅ 消息embedding生成成功维度: {len(embedding)}")
return embedding return embedding
async def _calculate_similarity_scores(self, result: InterestMatchResult, message_embedding: List[float], keywords: List[str]): async def _calculate_similarity_scores(
self, result: InterestMatchResult, message_embedding: List[float], keywords: List[str]
):
"""计算消息与兴趣标签的相似度分数""" """计算消息与兴趣标签的相似度分数"""
try: try:
if not self.current_interests: if not self.current_interests:
@@ -397,7 +399,9 @@ class BotInterestManager:
# 设置相似度阈值为0.3 # 设置相似度阈值为0.3
if similarity > 0.3: if similarity > 0.3:
result.add_match(tag.tag_name, weighted_score, keywords) result.add_match(tag.tag_name, weighted_score, keywords)
logger.debug(f" 🏷️ '{tag.tag_name}': 相似度={similarity:.3f}, 权重={tag.weight:.2f}, 加权分数={weighted_score:.3f}") logger.debug(
f" 🏷️ '{tag.tag_name}': 相似度={similarity:.3f}, 权重={tag.weight:.2f}, 加权分数={weighted_score:.3f}"
)
except Exception as e: except Exception as e:
logger.error(f"❌ 计算相似度分数失败: {e}") logger.error(f"❌ 计算相似度分数失败: {e}")
@@ -455,7 +459,9 @@ class BotInterestManager:
match_count += 1 match_count += 1
high_similarity_count += 1 high_similarity_count += 1
result.add_match(tag.tag_name, enhanced_score, [tag.tag_name]) result.add_match(tag.tag_name, enhanced_score, [tag.tag_name])
logger.debug(f" 🏷️ '{tag.tag_name}': 相似度={similarity:.3f}, 权重={tag.weight:.2f}, 基础分数={weighted_score:.3f}, 增强分数={enhanced_score:.3f} [高匹配]") logger.debug(
f" 🏷️ '{tag.tag_name}': 相似度={similarity:.3f}, 权重={tag.weight:.2f}, 基础分数={weighted_score:.3f}, 增强分数={enhanced_score:.3f} [高匹配]"
)
elif similarity > medium_threshold: elif similarity > medium_threshold:
# 中相似度:中等加成 # 中相似度:中等加成
@@ -463,7 +469,9 @@ class BotInterestManager:
match_count += 1 match_count += 1
medium_similarity_count += 1 medium_similarity_count += 1
result.add_match(tag.tag_name, enhanced_score, [tag.tag_name]) result.add_match(tag.tag_name, enhanced_score, [tag.tag_name])
logger.debug(f" 🏷️ '{tag.tag_name}': 相似度={similarity:.3f}, 权重={tag.weight:.2f}, 基础分数={weighted_score:.3f}, 增强分数={enhanced_score:.3f} [中匹配]") logger.debug(
f" 🏷️ '{tag.tag_name}': 相似度={similarity:.3f}, 权重={tag.weight:.2f}, 基础分数={weighted_score:.3f}, 增强分数={enhanced_score:.3f} [中匹配]"
)
elif similarity > low_threshold: elif similarity > low_threshold:
# 低相似度:轻微加成 # 低相似度:轻微加成
@@ -471,7 +479,9 @@ class BotInterestManager:
match_count += 1 match_count += 1
low_similarity_count += 1 low_similarity_count += 1
result.add_match(tag.tag_name, enhanced_score, [tag.tag_name]) result.add_match(tag.tag_name, enhanced_score, [tag.tag_name])
logger.debug(f" 🏷️ '{tag.tag_name}': 相似度={similarity:.3f}, 权重={tag.weight:.2f}, 基础分数={weighted_score:.3f}, 增强分数={enhanced_score:.3f} [低匹配]") logger.debug(
f" 🏷️ '{tag.tag_name}': 相似度={similarity:.3f}, 权重={tag.weight:.2f}, 基础分数={weighted_score:.3f}, 增强分数={enhanced_score:.3f} [低匹配]"
)
logger.info(f"📈 匹配统计: {match_count}/{len(active_tags)} 个标签超过阈值") logger.info(f"📈 匹配统计: {match_count}/{len(active_tags)} 个标签超过阈值")
logger.info(f"🔥 高相似度匹配(>{high_threshold}): {high_similarity_count}") logger.info(f"🔥 高相似度匹配(>{high_threshold}): {high_similarity_count}")
@@ -488,7 +498,9 @@ class BotInterestManager:
original_score = result.match_scores[tag_name] original_score = result.match_scores[tag_name]
bonus = keyword_bonus[tag_name] bonus = keyword_bonus[tag_name]
result.match_scores[tag_name] = original_score + bonus result.match_scores[tag_name] = original_score + bonus
logger.debug(f" 🏷️ '{tag_name}': 原始分数={original_score:.3f}, 奖励={bonus:.3f}, 最终分数={result.match_scores[tag_name]:.3f}") logger.debug(
f" 🏷️ '{tag_name}': 原始分数={original_score:.3f}, 奖励={bonus:.3f}, 最终分数={result.match_scores[tag_name]:.3f}"
)
# 计算总体分数 # 计算总体分数
result.calculate_overall_score() result.calculate_overall_score()
@@ -499,10 +511,11 @@ class BotInterestManager:
result.top_tag = top_tag_name result.top_tag = top_tag_name
logger.info(f"🏆 最佳匹配标签: '{top_tag_name}' (分数: {result.match_scores[top_tag_name]:.3f})") logger.info(f"🏆 最佳匹配标签: '{top_tag_name}' (分数: {result.match_scores[top_tag_name]:.3f})")
logger.info(f"📊 最终结果: 总分={result.overall_score:.3f}, 置信度={result.confidence:.3f}, 匹配标签数={len(result.matched_tags)}") logger.info(
f"📊 最终结果: 总分={result.overall_score:.3f}, 置信度={result.confidence:.3f}, 匹配标签数={len(result.matched_tags)}"
)
return result return result
def _calculate_keyword_match_bonus(self, keywords: List[str], matched_tags: List[str]) -> Dict[str, float]: def _calculate_keyword_match_bonus(self, keywords: List[str], matched_tags: List[str]) -> Dict[str, float]:
"""计算关键词直接匹配奖励""" """计算关键词直接匹配奖励"""
if not keywords or not matched_tags: if not keywords or not matched_tags:
@@ -522,17 +535,25 @@ class BotInterestManager:
# 完全匹配 # 完全匹配
if keyword_lower == tag_name_lower: if keyword_lower == tag_name_lower:
bonus += affinity_config.high_match_interest_threshold * 0.6 # 使用高匹配阈值的60%作为完全匹配奖励 bonus += affinity_config.high_match_interest_threshold * 0.6 # 使用高匹配阈值的60%作为完全匹配奖励
logger.debug(f" 🎯 关键词完全匹配: '{keyword}' == '{tag_name}' (+{affinity_config.high_match_interest_threshold * 0.6:.3f})") logger.debug(
f" 🎯 关键词完全匹配: '{keyword}' == '{tag_name}' (+{affinity_config.high_match_interest_threshold * 0.6:.3f})"
)
# 包含匹配 # 包含匹配
elif keyword_lower in tag_name_lower or tag_name_lower in keyword_lower: elif keyword_lower in tag_name_lower or tag_name_lower in keyword_lower:
bonus += affinity_config.medium_match_interest_threshold * 0.3 # 使用中匹配阈值的30%作为包含匹配奖励 bonus += (
logger.debug(f" 🎯 关键词包含匹配: '{keyword}''{tag_name}' (+{affinity_config.medium_match_interest_threshold * 0.3:.3f})") affinity_config.medium_match_interest_threshold * 0.3
) # 使用中匹配阈值的30%作为包含匹配奖励
logger.debug(
f" 🎯 关键词包含匹配: '{keyword}''{tag_name}' (+{affinity_config.medium_match_interest_threshold * 0.3:.3f})"
)
# 部分匹配(编辑距离) # 部分匹配(编辑距离)
elif self._calculate_partial_match(keyword_lower, tag_name_lower): elif self._calculate_partial_match(keyword_lower, tag_name_lower):
bonus += affinity_config.low_match_interest_threshold * 0.4 # 使用低匹配阈值的40%作为部分匹配奖励 bonus += affinity_config.low_match_interest_threshold * 0.4 # 使用低匹配阈值的40%作为部分匹配奖励
logger.debug(f" 🎯 关键词部分匹配: '{keyword}''{tag_name}' (+{affinity_config.low_match_interest_threshold * 0.4:.3f})") logger.debug(
f" 🎯 关键词部分匹配: '{keyword}''{tag_name}' (+{affinity_config.low_match_interest_threshold * 0.4:.3f})"
)
if bonus > 0: if bonus > 0:
bonus_dict[tag_name] = min(bonus, affinity_config.max_match_bonus) # 使用配置的最大奖励限制 bonus_dict[tag_name] = min(bonus, affinity_config.max_match_bonus) # 使用配置的最大奖励限制
@@ -608,12 +629,12 @@ class BotInterestManager:
with get_db_session() as session: with get_db_session() as session:
# 查询最新的兴趣标签配置 # 查询最新的兴趣标签配置
db_interests = session.query(DBBotPersonalityInterests).filter( db_interests = (
DBBotPersonalityInterests.personality_id == personality_id session.query(DBBotPersonalityInterests)
).order_by( .filter(DBBotPersonalityInterests.personality_id == personality_id)
DBBotPersonalityInterests.version.desc(), .order_by(DBBotPersonalityInterests.version.desc(), DBBotPersonalityInterests.last_updated.desc())
DBBotPersonalityInterests.last_updated.desc() .first()
).first() )
if db_interests: if db_interests:
logger.info(f"✅ 找到数据库中的兴趣标签配置,版本: {db_interests.version}") logger.info(f"✅ 找到数据库中的兴趣标签配置,版本: {db_interests.version}")
@@ -631,7 +652,7 @@ class BotInterestManager:
personality_description=db_interests.personality_description, personality_description=db_interests.personality_description,
embedding_model=db_interests.embedding_model, embedding_model=db_interests.embedding_model,
version=db_interests.version, version=db_interests.version,
last_updated=db_interests.last_updated last_updated=db_interests.last_updated,
) )
# 解析兴趣标签 # 解析兴趣标签
@@ -639,10 +660,14 @@ class BotInterestManager:
tag = BotInterestTag( tag = BotInterestTag(
tag_name=tag_data.get("tag_name", ""), tag_name=tag_data.get("tag_name", ""),
weight=tag_data.get("weight", 0.5), weight=tag_data.get("weight", 0.5),
created_at=datetime.fromisoformat(tag_data.get("created_at", datetime.now().isoformat())), created_at=datetime.fromisoformat(
updated_at=datetime.fromisoformat(tag_data.get("updated_at", datetime.now().isoformat())), tag_data.get("created_at", datetime.now().isoformat())
),
updated_at=datetime.fromisoformat(
tag_data.get("updated_at", datetime.now().isoformat())
),
is_active=tag_data.get("is_active", True), is_active=tag_data.get("is_active", True),
embedding=tag_data.get("embedding") embedding=tag_data.get("embedding"),
) )
interests.interest_tags.append(tag) interests.interest_tags.append(tag)
@@ -685,7 +710,7 @@ class BotInterestManager:
"created_at": tag.created_at.isoformat(), "created_at": tag.created_at.isoformat(),
"updated_at": tag.updated_at.isoformat(), "updated_at": tag.updated_at.isoformat(),
"is_active": tag.is_active, "is_active": tag.is_active,
"embedding": tag.embedding "embedding": tag.embedding,
} }
tags_data.append(tag_dict) tags_data.append(tag_dict)
@@ -694,9 +719,11 @@ class BotInterestManager:
with get_db_session() as session: with get_db_session() as session:
# 检查是否已存在相同personality_id的记录 # 检查是否已存在相同personality_id的记录
existing_record = session.query(DBBotPersonalityInterests).filter( existing_record = (
DBBotPersonalityInterests.personality_id == interests.personality_id session.query(DBBotPersonalityInterests)
).first() .filter(DBBotPersonalityInterests.personality_id == interests.personality_id)
.first()
)
if existing_record: if existing_record:
# 更新现有记录 # 更新现有记录
@@ -718,7 +745,7 @@ class BotInterestManager:
interest_tags=json_data, interest_tags=json_data,
embedding_model=interests.embedding_model, embedding_model=interests.embedding_model,
version=interests.version, version=interests.version,
last_updated=interests.last_updated last_updated=interests.last_updated,
) )
session.add(new_record) session.add(new_record)
session.commit() session.commit()
@@ -728,9 +755,11 @@ class BotInterestManager:
# 验证保存是否成功 # 验证保存是否成功
with get_db_session() as session: with get_db_session() as session:
saved_record = session.query(DBBotPersonalityInterests).filter( saved_record = (
DBBotPersonalityInterests.personality_id == interests.personality_id session.query(DBBotPersonalityInterests)
).first() .filter(DBBotPersonalityInterests.personality_id == interests.personality_id)
.first()
)
session.commit() session.commit()
if saved_record: if saved_record:
logger.info(f"✅ 验证成功数据库中存在personality_id为 {interests.personality_id} 的记录") logger.info(f"✅ 验证成功数据库中存在personality_id为 {interests.personality_id} 的记录")
@@ -760,7 +789,7 @@ class BotInterestManager:
"total_tags": len(active_tags), "total_tags": len(active_tags),
"embedding_model": self.current_interests.embedding_model, "embedding_model": self.current_interests.embedding_model,
"last_updated": self.current_interests.last_updated.isoformat(), "last_updated": self.current_interests.last_updated.isoformat(),
"cache_size": len(self.embedding_cache) "cache_size": len(self.embedding_cache),
} }
async def update_interest_tags(self, new_personality_description: str = None): async def update_interest_tags(self, new_personality_description: str = None):
@@ -775,8 +804,7 @@ class BotInterestManager:
# 重新生成兴趣标签 # 重新生成兴趣标签
new_interests = await self._generate_interests_from_personality( new_interests = await self._generate_interests_from_personality(
self.current_interests.personality_description, self.current_interests.personality_description, self.current_interests.personality_id
self.current_interests.personality_id
) )
if new_interests: if new_interests:
@@ -791,4 +819,4 @@ class BotInterestManager:
# 创建全局实例(重新创建以包含新的属性) # 创建全局实例(重新创建以包含新的属性)
bot_interest_manager = BotInterestManager() bot_interest_manager = BotInterestManager()

View File

@@ -4,13 +4,11 @@
""" """
from .message_manager import MessageManager, message_manager from .message_manager import MessageManager, message_manager
from src.common.data_models.message_manager_data_model import StreamContext, MessageStatus, MessageManagerStats, StreamStats from src.common.data_models.message_manager_data_model import (
StreamContext,
MessageStatus,
MessageManagerStats,
StreamStats,
)
__all__ = [ __all__ = ["MessageManager", "message_manager", "StreamContext", "MessageStatus", "MessageManagerStats", "StreamStats"]
"MessageManager",
"message_manager",
"StreamContext",
"MessageStatus",
"MessageManagerStats",
"StreamStats"
]

View File

@@ -2,6 +2,7 @@
消息管理模块 消息管理模块
管理每个聊天流的上下文信息,包含历史记录和未读消息,定期检查并处理新消息 管理每个聊天流的上下文信息,包含历史记录和未读消息,定期检查并处理新消息
""" """
import asyncio import asyncio
import time import time
import traceback import traceback
@@ -100,9 +101,7 @@ class MessageManager:
# 如果没有处理任务,创建一个 # 如果没有处理任务,创建一个
if not context.processing_task or context.processing_task.done(): if not context.processing_task or context.processing_task.done():
context.processing_task = asyncio.create_task( context.processing_task = asyncio.create_task(self._process_stream_messages(stream_id))
self._process_stream_messages(stream_id)
)
# 更新统计 # 更新统计
self.stats.active_streams = active_streams self.stats.active_streams = active_streams
@@ -128,11 +127,11 @@ class MessageManager:
try: try:
# 发送到AFC处理器传递StreamContext对象 # 发送到AFC处理器传递StreamContext对象
results = await afc_manager.process_stream_context(stream_id, context) results = await afc_manager.process_stream_context(stream_id, context)
# 处理结果,标记消息为已读 # 处理结果,标记消息为已读
if results.get("success", False): if results.get("success", False):
self._clear_all_unread_messages(context) self._clear_all_unread_messages(context)
except Exception as e: except Exception as e:
logger.error(f"处理聊天流 {stream_id} 时发生异常,将清除所有未读消息: {e}") logger.error(f"处理聊天流 {stream_id} 时发生异常,将清除所有未读消息: {e}")
raise raise
@@ -175,7 +174,7 @@ class MessageManager:
unread_count=len(context.get_unread_messages()), unread_count=len(context.get_unread_messages()),
history_count=len(context.history_messages), history_count=len(context.history_messages),
last_check_time=context.last_check_time, last_check_time=context.last_check_time,
has_active_task=context.processing_task and not context.processing_task.done() has_active_task=context.processing_task and not context.processing_task.done(),
) )
def get_manager_stats(self) -> Dict[str, Any]: def get_manager_stats(self) -> Dict[str, Any]:
@@ -186,7 +185,7 @@ class MessageManager:
"total_unread_messages": self.stats.total_unread_messages, "total_unread_messages": self.stats.total_unread_messages,
"total_processed_messages": self.stats.total_processed_messages, "total_processed_messages": self.stats.total_processed_messages,
"uptime": self.stats.uptime, "uptime": self.stats.uptime,
"start_time": self.stats.start_time "start_time": self.stats.start_time,
} }
def cleanup_inactive_streams(self, max_inactive_hours: int = 24): def cleanup_inactive_streams(self, max_inactive_hours: int = 24):
@@ -196,8 +195,7 @@ class MessageManager:
inactive_streams = [] inactive_streams = []
for stream_id, context in self.stream_contexts.items(): for stream_id, context in self.stream_contexts.items():
if (current_time - context.last_check_time > max_inactive_seconds and if current_time - context.last_check_time > max_inactive_seconds and not context.get_unread_messages():
not context.get_unread_messages()):
inactive_streams.append(stream_id) inactive_streams.append(stream_id)
for stream_id in inactive_streams: for stream_id in inactive_streams:
@@ -210,9 +208,9 @@ class MessageManager:
unread_messages = context.get_unread_messages() unread_messages = context.get_unread_messages()
if not unread_messages: if not unread_messages:
return return
logger.warning(f"正在清除 {len(unread_messages)} 条未读消息") logger.warning(f"正在清除 {len(unread_messages)} 条未读消息")
# 将所有未读消息标记为已读并移动到历史记录 # 将所有未读消息标记为已读并移动到历史记录
for msg in unread_messages[:]: # 使用切片复制避免迭代时修改列表 for msg in unread_messages[:]: # 使用切片复制避免迭代时修改列表
try: try:
@@ -224,4 +222,4 @@ class MessageManager:
# 创建全局消息管理器实例 # 创建全局消息管理器实例
message_manager = MessageManager() message_manager = MessageManager()

View File

@@ -17,6 +17,7 @@ from src.plugin_system.core import component_registry, event_manager, global_ann
from src.plugin_system.base import BaseCommand, EventType from src.plugin_system.base import BaseCommand, EventType
from src.mais4u.mais4u_chat.s4u_msg_processor import S4UMessageProcessor from src.mais4u.mais4u_chat.s4u_msg_processor import S4UMessageProcessor
from src.chat.utils.utils import is_mentioned_bot_in_message from src.chat.utils.utils import is_mentioned_bot_in_message
# 导入反注入系统 # 导入反注入系统
from src.chat.antipromptinjector import initialize_anti_injector from src.chat.antipromptinjector import initialize_anti_injector
@@ -511,7 +512,7 @@ class ChatBot:
chat_info_user_id=message.chat_stream.user_info.user_id, chat_info_user_id=message.chat_stream.user_info.user_id,
chat_info_user_nickname=message.chat_stream.user_info.user_nickname, chat_info_user_nickname=message.chat_stream.user_info.user_nickname,
chat_info_user_cardname=message.chat_stream.user_info.user_cardname, chat_info_user_cardname=message.chat_stream.user_info.user_cardname,
chat_info_user_platform=message.chat_stream.user_info.platform chat_info_user_platform=message.chat_stream.user_info.platform,
) )
# 如果是群聊,添加群组信息 # 如果是群聊,添加群组信息

View File

@@ -84,7 +84,9 @@ class ChatStream:
self.saved = False self.saved = False
self.context: ChatMessageContext = None # type: ignore # 用于存储该聊天的上下文信息 self.context: ChatMessageContext = None # type: ignore # 用于存储该聊天的上下文信息
# 从配置文件中读取focus_value如果没有则使用默认值1.0 # 从配置文件中读取focus_value如果没有则使用默认值1.0
self.focus_energy = data.get("focus_energy", global_config.chat.focus_value) if data else global_config.chat.focus_value self.focus_energy = (
data.get("focus_energy", global_config.chat.focus_value) if data else global_config.chat.focus_value
)
self.no_reply_consecutive = 0 self.no_reply_consecutive = 0
self.breaking_accumulated_interest = 0.0 self.breaking_accumulated_interest = 0.0

View File

@@ -165,10 +165,15 @@ class ActionManager:
# 通过chat_id获取chat_stream # 通过chat_id获取chat_stream
chat_manager = get_chat_manager() chat_manager = get_chat_manager()
chat_stream = chat_manager.get_stream(chat_id) chat_stream = chat_manager.get_stream(chat_id)
if not chat_stream: if not chat_stream:
logger.error(f"{log_prefix} 无法找到chat_id对应的chat_stream: {chat_id}") logger.error(f"{log_prefix} 无法找到chat_id对应的chat_stream: {chat_id}")
return {"action_type": action_name, "success": False, "reply_text": "", "error": "chat_stream not found"} return {
"action_type": action_name,
"success": False,
"reply_text": "",
"error": "chat_stream not found",
}
if action_name == "no_action": if action_name == "no_action":
return {"action_type": "no_action", "success": True, "reply_text": "", "command": ""} return {"action_type": "no_action", "success": True, "reply_text": "", "command": ""}
@@ -177,7 +182,7 @@ class ActionManager:
# 直接处理no_reply逻辑不再通过动作系统 # 直接处理no_reply逻辑不再通过动作系统
reason = reasoning or "选择不回复" reason = reasoning or "选择不回复"
logger.info(f"{log_prefix} 选择不回复,原因: {reason}") logger.info(f"{log_prefix} 选择不回复,原因: {reason}")
# 存储no_reply信息到数据库 # 存储no_reply信息到数据库
await database_api.store_action_info( await database_api.store_action_info(
chat_stream=chat_stream, chat_stream=chat_stream,
@@ -396,7 +401,7 @@ class ActionManager:
} }
return loop_info, reply_text, cycle_timers return loop_info, reply_text, cycle_timers
async def send_response(self, chat_stream, reply_set, thinking_start_time, message_data) -> str: async def send_response(self, chat_stream, reply_set, thinking_start_time, message_data) -> str:
""" """
发送回复内容的具体实现 发送回复内容的具体实现
@@ -471,4 +476,4 @@ class ActionManager:
typing=True, typing=True,
) )
return reply_text return reply_text

View File

@@ -1,6 +1,7 @@
""" """
PlanGenerator: 负责搜集和汇总所有决策所需的信息,生成一个未经筛选的“原始计划” (Plan)。 PlanGenerator: 负责搜集和汇总所有决策所需的信息,生成一个未经筛选的“原始计划” (Plan)。
""" """
import time import time
from typing import Dict from typing import Dict
@@ -35,6 +36,7 @@ class PlanGenerator:
chat_id (str): 当前聊天的 ID。 chat_id (str): 当前聊天的 ID。
""" """
from src.chat.planner_actions.action_manager import ActionManager from src.chat.planner_actions.action_manager import ActionManager
self.chat_id = chat_id self.chat_id = chat_id
# 注意ActionManager 可能需要根据实际情况初始化 # 注意ActionManager 可能需要根据实际情况初始化
self.action_manager = ActionManager() self.action_manager = ActionManager()
@@ -52,7 +54,7 @@ class PlanGenerator:
Plan: 一个填充了初始上下文信息的 Plan 对象。 Plan: 一个填充了初始上下文信息的 Plan 对象。
""" """
_is_group_chat, chat_target_info_dict = get_chat_type_and_target_info(self.chat_id) _is_group_chat, chat_target_info_dict = get_chat_type_and_target_info(self.chat_id)
target_info = None target_info = None
if chat_target_info_dict: if chat_target_info_dict:
target_info = TargetPersonInfo(**chat_target_info_dict) target_info = TargetPersonInfo(**chat_target_info_dict)
@@ -65,7 +67,6 @@ class PlanGenerator:
) )
chat_history = [DatabaseMessages(**msg) for msg in chat_history_raw] chat_history = [DatabaseMessages(**msg) for msg in chat_history_raw]
plan = Plan( plan = Plan(
chat_id=self.chat_id, chat_id=self.chat_id,
mode=mode, mode=mode,
@@ -86,10 +87,10 @@ class PlanGenerator:
Dict[str, "ActionInfo"]: 一个字典,键是动作名称,值是 ActionInfo 对象。 Dict[str, "ActionInfo"]: 一个字典,键是动作名称,值是 ActionInfo 对象。
""" """
current_available_actions_dict = self.action_manager.get_using_actions() current_available_actions_dict = self.action_manager.get_using_actions()
all_registered_actions: Dict[str, ActionInfo] = component_registry.get_components_by_type( # type: ignore all_registered_actions: Dict[str, ActionInfo] = component_registry.get_components_by_type( # type: ignore
ComponentType.ACTION ComponentType.ACTION
) )
current_available_actions = {} current_available_actions = {}
for action_name in current_available_actions_dict: for action_name in current_available_actions_dict:
if action_name in all_registered_actions: if action_name in all_registered_actions:
@@ -99,16 +100,13 @@ class PlanGenerator:
name="reply", name="reply",
component_type=ComponentType.ACTION, component_type=ComponentType.ACTION,
description="系统级动作:选择回复消息的决策", description="系统级动作:选择回复消息的决策",
action_parameters={ action_parameters={"content": "回复的文本内容", "reply_to_message_id": "要回复的消息ID"},
"content": "回复的文本内容",
"reply_to_message_id": "要回复的消息ID"
},
action_require=[ action_require=[
"你想要闲聊或者随便附和", "你想要闲聊或者随便附和",
"当用户提到你或艾特你时", "当用户提到你或艾特你时",
"当需要回答用户的问题时", "当需要回答用户的问题时",
"当你想参与对话时", "当你想参与对话时",
"当用户分享有趣的内容时" "当用户分享有趣的内容时",
], ],
activation_type=ActionActivationType.ALWAYS, activation_type=ActionActivationType.ALWAYS,
activation_keywords=[], activation_keywords=[],
@@ -131,4 +129,4 @@ class PlanGenerator:
) )
current_available_actions["no_reply"] = no_reply_info current_available_actions["no_reply"] = no_reply_info
current_available_actions["reply"] = reply_info current_available_actions["reply"] = reply_info
return current_available_actions return current_available_actions

View File

@@ -109,9 +109,7 @@ class ActionPlanner:
self.planner_stats["failed_plans"] += 1 self.planner_stats["failed_plans"] += 1
return [], None return [], None
async def _enhanced_plan_flow( async def _enhanced_plan_flow(self, mode: ChatMode, context: StreamContext) -> Tuple[List[Dict], Optional[Dict]]:
self, mode: ChatMode, context: StreamContext
) -> Tuple[List[Dict], Optional[Dict]]:
"""执行增强版规划流程""" """执行增强版规划流程"""
try: try:
# 1. 生成初始 Plan # 1. 生成初始 Plan
@@ -137,7 +135,9 @@ class ActionPlanner:
# 检查兴趣度是否达到非回复动作阈值 # 检查兴趣度是否达到非回复动作阈值
non_reply_action_interest_threshold = global_config.affinity_flow.non_reply_action_interest_threshold non_reply_action_interest_threshold = global_config.affinity_flow.non_reply_action_interest_threshold
if score < non_reply_action_interest_threshold: if score < non_reply_action_interest_threshold:
logger.info(f"❌ 兴趣度不足非回复动作阈值: {score:.3f} < {non_reply_action_interest_threshold:.3f}直接返回no_action") logger.info(
f"❌ 兴趣度不足非回复动作阈值: {score:.3f} < {non_reply_action_interest_threshold:.3f}直接返回no_action"
)
logger.info(f"📊 最低要求: {non_reply_action_interest_threshold:.3f}") logger.info(f"📊 最低要求: {non_reply_action_interest_threshold:.3f}")
# 直接返回 no_action # 直接返回 no_action
from src.common.data_models.info_data_model import ActionPlannerInfo from src.common.data_models.info_data_model import ActionPlannerInfo

View File

@@ -303,7 +303,7 @@ class DefaultReplyer:
"model": model_name, "model": model_name,
"tool_calls": tool_call, "tool_calls": tool_call,
} }
# 触发 AFTER_LLM 事件 # 触发 AFTER_LLM 事件
if not from_plugin: if not from_plugin:
result = await event_manager.trigger_event( result = await event_manager.trigger_event(
@@ -598,6 +598,7 @@ class DefaultReplyer:
def _parse_reply_target(self, target_message: str) -> Tuple[str, str]: def _parse_reply_target(self, target_message: str) -> Tuple[str, str]:
"""解析回复目标消息 - 使用共享工具""" """解析回复目标消息 - 使用共享工具"""
from src.chat.utils.prompt import Prompt from src.chat.utils.prompt import Prompt
if target_message is None: if target_message is None:
logger.warning("target_message为None返回默认值") logger.warning("target_message为None返回默认值")
return "未知用户", "(无消息内容)" return "未知用户", "(无消息内容)"
@@ -704,22 +705,24 @@ class DefaultReplyer:
unread_history_prompt = "" unread_history_prompt = ""
if unread_messages: if unread_messages:
# 尝试获取兴趣度评分 # 尝试获取兴趣度评分
interest_scores = await self._get_interest_scores_for_messages([msg.flatten() for msg in unread_messages]) interest_scores = await self._get_interest_scores_for_messages(
[msg.flatten() for msg in unread_messages]
)
unread_lines = [] unread_lines = []
for msg in unread_messages: for msg in unread_messages:
msg_id = msg.message_id msg_id = msg.message_id
msg_time = time.strftime('%H:%M:%S', time.localtime(msg.time)) msg_time = time.strftime("%H:%M:%S", time.localtime(msg.time))
msg_content = msg.processed_plain_text msg_content = msg.processed_plain_text
# 使用与已读历史消息相同的方法获取用户名 # 使用与已读历史消息相同的方法获取用户名
from src.person_info.person_info import PersonInfoManager, get_person_info_manager from src.person_info.person_info import PersonInfoManager, get_person_info_manager
# 获取用户信息 # 获取用户信息
user_info = getattr(msg, 'user_info', {}) user_info = getattr(msg, "user_info", {})
platform = getattr(user_info, 'platform', '') or getattr(msg, 'platform', '') platform = getattr(user_info, "platform", "") or getattr(msg, "platform", "")
user_id = getattr(user_info, 'user_id', '') or getattr(msg, 'user_id', '') user_id = getattr(user_info, "user_id", "") or getattr(msg, "user_id", "")
# 获取用户名 # 获取用户名
if platform and user_id: if platform and user_id:
person_id = PersonInfoManager.get_person_id(platform, user_id) person_id = PersonInfoManager.get_person_id(platform, user_id)
@@ -727,11 +730,11 @@ class DefaultReplyer:
sender_name = person_info_manager.get_value_sync(person_id, "person_name") or "未知用户" sender_name = person_info_manager.get_value_sync(person_id, "person_name") or "未知用户"
else: else:
sender_name = "未知用户" sender_name = "未知用户"
# 添加兴趣度信息 # 添加兴趣度信息
interest_score = interest_scores.get(msg_id, 0.0) interest_score = interest_scores.get(msg_id, 0.0)
interest_text = f" [兴趣度: {interest_score:.3f}]" if interest_score > 0 else "" interest_text = f" [兴趣度: {interest_score:.3f}]" if interest_score > 0 else ""
unread_lines.append(f"{msg_time} {sender_name}: {msg_content}{interest_text}") unread_lines.append(f"{msg_time} {sender_name}: {msg_content}{interest_text}")
unread_history_prompt_str = "\n".join(unread_lines) unread_history_prompt_str = "\n".join(unread_lines)
@@ -808,17 +811,17 @@ class DefaultReplyer:
unread_lines = [] unread_lines = []
for msg in unread_messages: for msg in unread_messages:
msg_id = msg.get("message_id", "") msg_id = msg.get("message_id", "")
msg_time = time.strftime('%H:%M:%S', time.localtime(msg.get("time", time.time()))) msg_time = time.strftime("%H:%M:%S", time.localtime(msg.get("time", time.time())))
msg_content = msg.get("processed_plain_text", "") msg_content = msg.get("processed_plain_text", "")
# 使用与已读历史消息相同的方法获取用户名 # 使用与已读历史消息相同的方法获取用户名
from src.person_info.person_info import PersonInfoManager, get_person_info_manager from src.person_info.person_info import PersonInfoManager, get_person_info_manager
# 获取用户信息 # 获取用户信息
user_info = msg.get("user_info", {}) user_info = msg.get("user_info", {})
platform = user_info.get("platform") or msg.get("platform", "") platform = user_info.get("platform") or msg.get("platform", "")
user_id = user_info.get("user_id") or msg.get("user_id", "") user_id = user_info.get("user_id") or msg.get("user_id", "")
# 获取用户名 # 获取用户名
if platform and user_id: if platform and user_id:
person_id = PersonInfoManager.get_person_id(platform, user_id) person_id = PersonInfoManager.get_person_id(platform, user_id)
@@ -834,7 +837,9 @@ class DefaultReplyer:
unread_lines.append(f"{msg_time} {sender_name}: {msg_content}{interest_text}") unread_lines.append(f"{msg_time} {sender_name}: {msg_content}{interest_text}")
unread_history_prompt_str = "\n".join(unread_lines) unread_history_prompt_str = "\n".join(unread_lines)
unread_history_prompt = f"这是未读历史消息,包含兴趣度评分,请优先对兴趣值高的消息做出动作:\n{unread_history_prompt_str}" unread_history_prompt = (
f"这是未读历史消息,包含兴趣度评分,请优先对兴趣值高的消息做出动作:\n{unread_history_prompt_str}"
)
else: else:
unread_history_prompt = "暂无未读历史消息" unread_history_prompt = "暂无未读历史消息"
@@ -982,7 +987,7 @@ class DefaultReplyer:
reply_message.get("user_id"), # type: ignore reply_message.get("user_id"), # type: ignore
) )
person_name = await person_info_manager.get_value(person_id, "person_name") person_name = await person_info_manager.get_value(person_id, "person_name")
# 如果person_name为None使用fallback值 # 如果person_name为None使用fallback值
if person_name is None: if person_name is None:
# 尝试从reply_message获取用户名 # 尝试从reply_message获取用户名
@@ -990,12 +995,12 @@ class DefaultReplyer:
logger.warning(f"未知用户,将存储用户信息:{fallback_name}") logger.warning(f"未知用户,将存储用户信息:{fallback_name}")
person_name = str(fallback_name) person_name = str(fallback_name)
person_info_manager.set_value(person_id, "person_name", fallback_name) person_info_manager.set_value(person_id, "person_name", fallback_name)
# 检查是否是bot自己的名字如果是则替换为"(你)" # 检查是否是bot自己的名字如果是则替换为"(你)"
bot_user_id = str(global_config.bot.qq_account) bot_user_id = str(global_config.bot.qq_account)
current_user_id = person_info_manager.get_value_sync(person_id, "user_id") current_user_id = person_info_manager.get_value_sync(person_id, "user_id")
current_platform = reply_message.get("chat_info_platform") current_platform = reply_message.get("chat_info_platform")
if current_user_id == bot_user_id and current_platform == global_config.bot.platform: if current_user_id == bot_user_id and current_platform == global_config.bot.platform:
sender = f"{person_name}(你)" sender = f"{person_name}(你)"
else: else:
@@ -1050,8 +1055,9 @@ class DefaultReplyer:
target_user_info = None target_user_info = None
if sender: if sender:
target_user_info = await person_info_manager.get_person_info_by_name(sender) target_user_info = await person_info_manager.get_person_info_by_name(sender)
from src.chat.utils.prompt import Prompt from src.chat.utils.prompt import Prompt
# 并行执行六个构建任务 # 并行执行六个构建任务
task_results = await asyncio.gather( task_results = await asyncio.gather(
self._time_and_run_task( self._time_and_run_task(
@@ -1127,6 +1133,7 @@ class DefaultReplyer:
schedule_block = "" schedule_block = ""
if global_config.planning_system.schedule_enable: if global_config.planning_system.schedule_enable:
from src.schedule.schedule_manager import schedule_manager from src.schedule.schedule_manager import schedule_manager
current_activity = schedule_manager.get_current_activity() current_activity = schedule_manager.get_current_activity()
if current_activity: if current_activity:
schedule_block = f"你当前正在:{current_activity}" schedule_block = f"你当前正在:{current_activity}"
@@ -1139,7 +1146,7 @@ class DefaultReplyer:
safety_guidelines = global_config.personality.safety_guidelines safety_guidelines = global_config.personality.safety_guidelines
safety_guidelines_block = "" safety_guidelines_block = ""
if safety_guidelines: if safety_guidelines:
guidelines_text = "\n".join(f"{i+1}. {line}" for i, line in enumerate(safety_guidelines)) guidelines_text = "\n".join(f"{i + 1}. {line}" for i, line in enumerate(safety_guidelines))
safety_guidelines_block = f"""### 安全与互动底线 safety_guidelines_block = f"""### 安全与互动底线
在任何情况下,你都必须遵守以下由你的设定者为你定义的原则: 在任何情况下,你都必须遵守以下由你的设定者为你定义的原则:
{guidelines_text} {guidelines_text}
@@ -1212,7 +1219,7 @@ class DefaultReplyer:
template_name = "normal_style_prompt" template_name = "normal_style_prompt"
elif current_prompt_mode == "minimal": elif current_prompt_mode == "minimal":
template_name = "default_expressor_prompt" template_name = "default_expressor_prompt"
# 获取模板内容 # 获取模板内容
template_prompt = await global_prompt_manager.get_prompt_async(template_name) template_prompt = await global_prompt_manager.get_prompt_async(template_name)
prompt = Prompt(template=template_prompt.template, parameters=prompt_parameters) prompt = Prompt(template=template_prompt.template, parameters=prompt_parameters)
@@ -1488,19 +1495,19 @@ class DefaultReplyer:
# 使用AFC关系追踪器获取关系信息 # 使用AFC关系追踪器获取关系信息
try: try:
from src.chat.affinity_flow.relationship_integration import get_relationship_tracker from src.chat.affinity_flow.relationship_integration import get_relationship_tracker
relationship_tracker = get_relationship_tracker() relationship_tracker = get_relationship_tracker()
if relationship_tracker: if relationship_tracker:
# 获取用户信息以获取真实的user_id # 获取用户信息以获取真实的user_id
user_info = await person_info_manager.get_values(person_id, ["user_id", "platform"]) user_info = await person_info_manager.get_values(person_id, ["user_id", "platform"])
user_id = user_info.get("user_id", "unknown") user_id = user_info.get("user_id", "unknown")
# 从数据库获取关系数据 # 从数据库获取关系数据
relationship_data = relationship_tracker._get_user_relationship_from_db(user_id) relationship_data = relationship_tracker._get_user_relationship_from_db(user_id)
if relationship_data: if relationship_data:
relationship_text = relationship_data.get("relationship_text", "") relationship_text = relationship_data.get("relationship_text", "")
relationship_score = relationship_data.get("relationship_score", 0.3) relationship_score = relationship_data.get("relationship_score", 0.3)
# 构建丰富的关系信息描述 # 构建丰富的关系信息描述
if relationship_text: if relationship_text:
# 转换关系分数为描述性文本 # 转换关系分数为描述性文本
@@ -1514,7 +1521,7 @@ class DefaultReplyer:
relationship_level = "认识的人" relationship_level = "认识的人"
else: else:
relationship_level = "陌生人" relationship_level = "陌生人"
return f"你与{sender}的关系:{relationship_level}(关系分:{relationship_score:.2f}/1.0)。{relationship_text}" return f"你与{sender}的关系:{relationship_level}(关系分:{relationship_score:.2f}/1.0)。{relationship_text}"
else: else:
return f"你与{sender}是初次见面,关系分:{relationship_score:.2f}/1.0。" return f"你与{sender}是初次见面,关系分:{relationship_score:.2f}/1.0。"
@@ -1523,7 +1530,7 @@ class DefaultReplyer:
else: else:
logger.warning("AFC关系追踪器未初始化使用默认关系信息") logger.warning("AFC关系追踪器未初始化使用默认关系信息")
return f"你与{sender}是普通朋友关系。" return f"你与{sender}是普通朋友关系。"
except Exception as e: except Exception as e:
logger.error(f"获取AFC关系信息失败: {e}") logger.error(f"获取AFC关系信息失败: {e}")
return f"你与{sender}是普通朋友关系。" return f"你与{sender}是普通朋友关系。"

View File

@@ -37,7 +37,7 @@ def replace_user_references_sync(
""" """
if not content: if not content:
return "" return ""
if name_resolver is None: if name_resolver is None:
person_info_manager = get_person_info_manager() person_info_manager = get_person_info_manager()
@@ -821,7 +821,7 @@ def build_pic_mapping_info(pic_id_mapping: Dict[str, str]) -> str:
try: try:
with get_db_session() as session: with get_db_session() as session:
image = session.execute(select(Images).where(Images.image_id == pic_id)).scalar_one_or_none() image = session.execute(select(Images).where(Images.image_id == pic_id)).scalar_one_or_none()
if image and image.description: # type: ignore if image and image.description: # type: ignore
description = image.description description = image.description
except Exception: except Exception:
# 如果查询失败,保持默认描述 # 如果查询失败,保持默认描述

View File

@@ -25,7 +25,7 @@ logger = get_logger("unified_prompt")
@dataclass @dataclass
class PromptParameters: class PromptParameters:
"""统一提示词参数系统""" """统一提示词参数系统"""
# 基础参数 # 基础参数
chat_id: str = "" chat_id: str = ""
is_group_chat: bool = False is_group_chat: bool = False
@@ -34,7 +34,7 @@ class PromptParameters:
reply_to: str = "" reply_to: str = ""
extra_info: str = "" extra_info: str = ""
prompt_mode: Literal["s4u", "normal", "minimal"] = "s4u" prompt_mode: Literal["s4u", "normal", "minimal"] = "s4u"
# 功能开关 # 功能开关
enable_tool: bool = True enable_tool: bool = True
enable_memory: bool = True enable_memory: bool = True
@@ -42,20 +42,20 @@ class PromptParameters:
enable_relation: bool = True enable_relation: bool = True
enable_cross_context: bool = True enable_cross_context: bool = True
enable_knowledge: bool = True enable_knowledge: bool = True
# 性能控制 # 性能控制
max_context_messages: int = 50 max_context_messages: int = 50
# 调试选项 # 调试选项
debug_mode: bool = False debug_mode: bool = False
# 聊天历史和上下文 # 聊天历史和上下文
chat_target_info: Optional[Dict[str, Any]] = None chat_target_info: Optional[Dict[str, Any]] = None
message_list_before_now_long: List[Dict[str, Any]] = field(default_factory=list) message_list_before_now_long: List[Dict[str, Any]] = field(default_factory=list)
message_list_before_short: List[Dict[str, Any]] = field(default_factory=list) message_list_before_short: List[Dict[str, Any]] = field(default_factory=list)
chat_talking_prompt_short: str = "" chat_talking_prompt_short: str = ""
target_user_info: Optional[Dict[str, Any]] = None target_user_info: Optional[Dict[str, Any]] = None
# 已构建的内容块 # 已构建的内容块
expression_habits_block: str = "" expression_habits_block: str = ""
relation_info_block: str = "" relation_info_block: str = ""
@@ -63,7 +63,7 @@ class PromptParameters:
tool_info_block: str = "" tool_info_block: str = ""
knowledge_prompt: str = "" knowledge_prompt: str = ""
cross_context_block: str = "" cross_context_block: str = ""
# 其他内容块 # 其他内容块
keywords_reaction_prompt: str = "" keywords_reaction_prompt: str = ""
extra_info_block: str = "" extra_info_block: str = ""
@@ -75,10 +75,10 @@ class PromptParameters:
reply_target_block: str = "" reply_target_block: str = ""
mood_prompt: str = "" mood_prompt: str = ""
action_descriptions: str = "" action_descriptions: str = ""
# 可用动作信息 # 可用动作信息
available_actions: Optional[Dict[str, Any]] = None available_actions: Optional[Dict[str, Any]] = None
def validate(self) -> List[str]: def validate(self) -> List[str]:
"""参数验证""" """参数验证"""
errors = [] errors = []
@@ -93,22 +93,22 @@ class PromptParameters:
class PromptContext: class PromptContext:
"""提示词上下文管理器""" """提示词上下文管理器"""
def __init__(self): def __init__(self):
self._context_prompts: Dict[str, Dict[str, "Prompt"]] = {} self._context_prompts: Dict[str, Dict[str, "Prompt"]] = {}
self._current_context_var = contextvars.ContextVar("current_context", default=None) self._current_context_var = contextvars.ContextVar("current_context", default=None)
self._context_lock = asyncio.Lock() self._context_lock = asyncio.Lock()
@property @property
def _current_context(self) -> Optional[str]: def _current_context(self) -> Optional[str]:
"""获取当前协程的上下文ID""" """获取当前协程的上下文ID"""
return self._current_context_var.get() return self._current_context_var.get()
@_current_context.setter @_current_context.setter
def _current_context(self, value: Optional[str]): def _current_context(self, value: Optional[str]):
"""设置当前协程的上下文ID""" """设置当前协程的上下文ID"""
self._current_context_var.set(value) # type: ignore self._current_context_var.set(value) # type: ignore
@asynccontextmanager @asynccontextmanager
async def async_scope(self, context_id: Optional[str] = None): async def async_scope(self, context_id: Optional[str] = None):
"""创建一个异步的临时提示模板作用域""" """创建一个异步的临时提示模板作用域"""
@@ -123,13 +123,13 @@ class PromptContext:
except asyncio.TimeoutError: except asyncio.TimeoutError:
logger.warning(f"获取上下文锁超时context_id: {context_id}") logger.warning(f"获取上下文锁超时context_id: {context_id}")
context_id = None context_id = None
previous_context = self._current_context previous_context = self._current_context
token = self._current_context_var.set(context_id) if context_id else None token = self._current_context_var.set(context_id) if context_id else None
else: else:
previous_context = self._current_context previous_context = self._current_context
token = None token = None
try: try:
yield self yield self
finally: finally:
@@ -142,7 +142,7 @@ class PromptContext:
self._current_context = previous_context self._current_context = previous_context
except Exception: except Exception:
... ...
async def get_prompt_async(self, name: str) -> Optional["Prompt"]: async def get_prompt_async(self, name: str) -> Optional["Prompt"]:
"""异步获取当前作用域中的提示模板""" """异步获取当前作用域中的提示模板"""
async with self._context_lock: async with self._context_lock:
@@ -155,7 +155,7 @@ class PromptContext:
): ):
return self._context_prompts[current_context][name] return self._context_prompts[current_context][name]
return None return None
async def register_async(self, prompt: "Prompt", context_id: Optional[str] = None) -> None: async def register_async(self, prompt: "Prompt", context_id: Optional[str] = None) -> None:
"""异步注册提示模板到指定作用域""" """异步注册提示模板到指定作用域"""
async with self._context_lock: async with self._context_lock:
@@ -166,49 +166,49 @@ class PromptContext:
class PromptManager: class PromptManager:
"""统一提示词管理器""" """统一提示词管理器"""
def __init__(self): def __init__(self):
self._prompts = {} self._prompts = {}
self._counter = 0 self._counter = 0
self._context = PromptContext() self._context = PromptContext()
self._lock = asyncio.Lock() self._lock = asyncio.Lock()
@asynccontextmanager @asynccontextmanager
async def async_message_scope(self, message_id: Optional[str] = None): async def async_message_scope(self, message_id: Optional[str] = None):
"""为消息处理创建异步临时作用域""" """为消息处理创建异步临时作用域"""
async with self._context.async_scope(message_id): async with self._context.async_scope(message_id):
yield self yield self
async def get_prompt_async(self, name: str) -> "Prompt": async def get_prompt_async(self, name: str) -> "Prompt":
"""异步获取提示模板""" """异步获取提示模板"""
context_prompt = await self._context.get_prompt_async(name) context_prompt = await self._context.get_prompt_async(name)
if context_prompt is not None: if context_prompt is not None:
logger.debug(f"从上下文中获取提示词: {name} {context_prompt}") logger.debug(f"从上下文中获取提示词: {name} {context_prompt}")
return context_prompt return context_prompt
async with self._lock: async with self._lock:
if name not in self._prompts: if name not in self._prompts:
raise KeyError(f"Prompt '{name}' not found") raise KeyError(f"Prompt '{name}' not found")
return self._prompts[name] return self._prompts[name]
def generate_name(self, template: str) -> str: def generate_name(self, template: str) -> str:
"""为未命名的prompt生成名称""" """为未命名的prompt生成名称"""
self._counter += 1 self._counter += 1
return f"prompt_{self._counter}" return f"prompt_{self._counter}"
def register(self, prompt: "Prompt") -> None: def register(self, prompt: "Prompt") -> None:
"""注册一个prompt""" """注册一个prompt"""
if not prompt.name: if not prompt.name:
prompt.name = self.generate_name(prompt.template) prompt.name = self.generate_name(prompt.template)
self._prompts[prompt.name] = prompt self._prompts[prompt.name] = prompt
def add_prompt(self, name: str, fstr: str) -> "Prompt": def add_prompt(self, name: str, fstr: str) -> "Prompt":
"""添加新提示模板""" """添加新提示模板"""
prompt = Prompt(fstr, name=name) prompt = Prompt(fstr, name=name)
if prompt.name: if prompt.name:
self._prompts[prompt.name] = prompt self._prompts[prompt.name] = prompt
return prompt return prompt
async def format_prompt(self, name: str, **kwargs) -> str: async def format_prompt(self, name: str, **kwargs) -> str:
"""格式化提示模板""" """格式化提示模板"""
prompt = await self.get_prompt_async(name) prompt = await self.get_prompt_async(name)
@@ -225,21 +225,21 @@ class Prompt:
统一提示词类 - 合并模板管理和智能构建功能 统一提示词类 - 合并模板管理和智能构建功能
真正的Prompt类支持模板管理和智能上下文构建 真正的Prompt类支持模板管理和智能上下文构建
""" """
# 临时标记,作为类常量 # 临时标记,作为类常量
_TEMP_LEFT_BRACE = "__ESCAPED_LEFT_BRACE__" _TEMP_LEFT_BRACE = "__ESCAPED_LEFT_BRACE__"
_TEMP_RIGHT_BRACE = "__ESCAPED_RIGHT_BRACE__" _TEMP_RIGHT_BRACE = "__ESCAPED_RIGHT_BRACE__"
def __init__( def __init__(
self, self,
template: str, template: str,
name: Optional[str] = None, name: Optional[str] = None,
parameters: Optional[PromptParameters] = None, parameters: Optional[PromptParameters] = None,
should_register: bool = True should_register: bool = True,
): ):
""" """
初始化统一提示词 初始化统一提示词
Args: Args:
template: 提示词模板字符串 template: 提示词模板字符串
name: 提示词名称 name: 提示词名称
@@ -251,14 +251,14 @@ class Prompt:
self.parameters = parameters or PromptParameters() self.parameters = parameters or PromptParameters()
self.args = self._parse_template_args(template) self.args = self._parse_template_args(template)
self._formatted_result = "" self._formatted_result = ""
# 预处理模板中的转义花括号 # 预处理模板中的转义花括号
self._processed_template = self._process_escaped_braces(template) self._processed_template = self._process_escaped_braces(template)
# 自动注册 # 自动注册
if should_register and not global_prompt_manager._context._current_context: if should_register and not global_prompt_manager._context._current_context:
global_prompt_manager.register(self) global_prompt_manager.register(self)
@staticmethod @staticmethod
def _process_escaped_braces(template) -> str: def _process_escaped_braces(template) -> str:
"""处理模板中的转义花括号""" """处理模板中的转义花括号"""
@@ -266,14 +266,14 @@ class Prompt:
template = "\n".join(str(item) for item in template) template = "\n".join(str(item) for item in template)
elif not isinstance(template, str): elif not isinstance(template, str):
template = str(template) template = str(template)
return template.replace("\\{", Prompt._TEMP_LEFT_BRACE).replace("\\}", Prompt._TEMP_RIGHT_BRACE) return template.replace("\\{", Prompt._TEMP_LEFT_BRACE).replace("\\}", Prompt._TEMP_RIGHT_BRACE)
@staticmethod @staticmethod
def _restore_escaped_braces(template: str) -> str: def _restore_escaped_braces(template: str) -> str:
"""将临时标记还原为实际的花括号字符""" """将临时标记还原为实际的花括号字符"""
return template.replace(Prompt._TEMP_LEFT_BRACE, "{").replace(Prompt._TEMP_RIGHT_BRACE, "}") return template.replace(Prompt._TEMP_LEFT_BRACE, "{").replace(Prompt._TEMP_RIGHT_BRACE, "}")
def _parse_template_args(self, template: str) -> List[str]: def _parse_template_args(self, template: str) -> List[str]:
"""解析模板参数""" """解析模板参数"""
template_args = [] template_args = []
@@ -283,11 +283,11 @@ class Prompt:
if expr and expr not in template_args: if expr and expr not in template_args:
template_args.append(expr) template_args.append(expr)
return template_args return template_args
async def build(self) -> str: async def build(self) -> str:
""" """
构建完整的提示词,包含智能上下文 构建完整的提示词,包含智能上下文
Returns: Returns:
str: 构建完成的提示词文本 str: 构建完成的提示词文本
""" """
@@ -296,38 +296,38 @@ class Prompt:
if errors: if errors:
logger.error(f"参数验证失败: {', '.join(errors)}") logger.error(f"参数验证失败: {', '.join(errors)}")
raise ValueError(f"参数验证失败: {', '.join(errors)}") raise ValueError(f"参数验证失败: {', '.join(errors)}")
start_time = time.time() start_time = time.time()
try: try:
# 构建上下文数据 # 构建上下文数据
context_data = await self._build_context_data() context_data = await self._build_context_data()
# 格式化模板 # 格式化模板
result = await self._format_with_context(context_data) result = await self._format_with_context(context_data)
total_time = time.time() - start_time total_time = time.time() - start_time
logger.debug(f"Prompt构建完成模式: {self.parameters.prompt_mode}, 耗时: {total_time:.2f}s") logger.debug(f"Prompt构建完成模式: {self.parameters.prompt_mode}, 耗时: {total_time:.2f}s")
self._formatted_result = result self._formatted_result = result
return result return result
except asyncio.TimeoutError as e: except asyncio.TimeoutError as e:
logger.error(f"构建Prompt超时: {e}") logger.error(f"构建Prompt超时: {e}")
raise TimeoutError(f"构建Prompt超时: {e}") from e raise TimeoutError(f"构建Prompt超时: {e}") from e
except Exception as e: except Exception as e:
logger.error(f"构建Prompt失败: {e}") logger.error(f"构建Prompt失败: {e}")
raise RuntimeError(f"构建Prompt失败: {e}") from e raise RuntimeError(f"构建Prompt失败: {e}") from e
async def _build_context_data(self) -> Dict[str, Any]: async def _build_context_data(self) -> Dict[str, Any]:
"""构建智能上下文数据""" """构建智能上下文数据"""
# 并行执行所有构建任务 # 并行执行所有构建任务
start_time = time.time() start_time = time.time()
try: try:
# 准备构建任务 # 准备构建任务
tasks = [] tasks = []
task_names = [] task_names = []
# 初始化预构建参数 # 初始化预构建参数
pre_built_params = {} pre_built_params = {}
if self.parameters.expression_habits_block: if self.parameters.expression_habits_block:
@@ -342,32 +342,32 @@ class Prompt:
pre_built_params["knowledge_prompt"] = self.parameters.knowledge_prompt pre_built_params["knowledge_prompt"] = self.parameters.knowledge_prompt
if self.parameters.cross_context_block: if self.parameters.cross_context_block:
pre_built_params["cross_context_block"] = self.parameters.cross_context_block pre_built_params["cross_context_block"] = self.parameters.cross_context_block
# 根据参数确定要构建的项 # 根据参数确定要构建的项
if self.parameters.enable_expression and not pre_built_params.get("expression_habits_block"): if self.parameters.enable_expression and not pre_built_params.get("expression_habits_block"):
tasks.append(self._build_expression_habits()) tasks.append(self._build_expression_habits())
task_names.append("expression_habits") task_names.append("expression_habits")
if self.parameters.enable_memory and not pre_built_params.get("memory_block"): if self.parameters.enable_memory and not pre_built_params.get("memory_block"):
tasks.append(self._build_memory_block()) tasks.append(self._build_memory_block())
task_names.append("memory_block") task_names.append("memory_block")
if self.parameters.enable_relation and not pre_built_params.get("relation_info_block"): if self.parameters.enable_relation and not pre_built_params.get("relation_info_block"):
tasks.append(self._build_relation_info()) tasks.append(self._build_relation_info())
task_names.append("relation_info") task_names.append("relation_info")
if self.parameters.enable_tool and not pre_built_params.get("tool_info_block"): if self.parameters.enable_tool and not pre_built_params.get("tool_info_block"):
tasks.append(self._build_tool_info()) tasks.append(self._build_tool_info())
task_names.append("tool_info") task_names.append("tool_info")
if self.parameters.enable_knowledge and not pre_built_params.get("knowledge_prompt"): if self.parameters.enable_knowledge and not pre_built_params.get("knowledge_prompt"):
tasks.append(self._build_knowledge_info()) tasks.append(self._build_knowledge_info())
task_names.append("knowledge_info") task_names.append("knowledge_info")
if self.parameters.enable_cross_context and not pre_built_params.get("cross_context_block"): if self.parameters.enable_cross_context and not pre_built_params.get("cross_context_block"):
tasks.append(self._build_cross_context()) tasks.append(self._build_cross_context())
task_names.append("cross_context") task_names.append("cross_context")
# 性能优化 # 性能优化
base_timeout = 10.0 base_timeout = 10.0
task_timeout = 2.0 task_timeout = 2.0
@@ -375,13 +375,13 @@ class Prompt:
max(base_timeout, len(tasks) * task_timeout), max(base_timeout, len(tasks) * task_timeout),
30.0, 30.0,
) )
max_concurrent_tasks = 5 max_concurrent_tasks = 5
if len(tasks) > max_concurrent_tasks: if len(tasks) > max_concurrent_tasks:
results = [] results = []
for i in range(0, len(tasks), max_concurrent_tasks): for i in range(0, len(tasks), max_concurrent_tasks):
batch_tasks = tasks[i : i + max_concurrent_tasks] batch_tasks = tasks[i : i + max_concurrent_tasks]
batch_results = await asyncio.wait_for( batch_results = await asyncio.wait_for(
asyncio.gather(*batch_tasks, return_exceptions=True), timeout=timeout_seconds asyncio.gather(*batch_tasks, return_exceptions=True), timeout=timeout_seconds
) )
@@ -390,53 +390,55 @@ class Prompt:
results = await asyncio.wait_for( results = await asyncio.wait_for(
asyncio.gather(*tasks, return_exceptions=True), timeout=timeout_seconds asyncio.gather(*tasks, return_exceptions=True), timeout=timeout_seconds
) )
# 处理结果 # 处理结果
context_data = {} context_data = {}
for i, result in enumerate(results): for i, result in enumerate(results):
task_name = task_names[i] if i < len(task_names) else f"task_{i}" task_name = task_names[i] if i < len(task_names) else f"task_{i}"
if isinstance(result, Exception): if isinstance(result, Exception):
logger.error(f"构建任务{task_name}失败: {str(result)}") logger.error(f"构建任务{task_name}失败: {str(result)}")
elif isinstance(result, dict): elif isinstance(result, dict):
context_data.update(result) context_data.update(result)
# 添加预构建的参数 # 添加预构建的参数
for key, value in pre_built_params.items(): for key, value in pre_built_params.items():
if value: if value:
context_data[key] = value context_data[key] = value
except asyncio.TimeoutError: except asyncio.TimeoutError:
logger.error(f"构建超时 ({timeout_seconds}s)") logger.error(f"构建超时 ({timeout_seconds}s)")
context_data = {} context_data = {}
for key, value in pre_built_params.items(): for key, value in pre_built_params.items():
if value: if value:
context_data[key] = value context_data[key] = value
# 构建聊天历史 # 构建聊天历史
if self.parameters.prompt_mode == "s4u": if self.parameters.prompt_mode == "s4u":
await self._build_s4u_chat_context(context_data) await self._build_s4u_chat_context(context_data)
else: else:
await self._build_normal_chat_context(context_data) await self._build_normal_chat_context(context_data)
# 补充基础信息 # 补充基础信息
context_data.update({ context_data.update(
"keywords_reaction_prompt": self.parameters.keywords_reaction_prompt, {
"extra_info_block": self.parameters.extra_info_block, "keywords_reaction_prompt": self.parameters.keywords_reaction_prompt,
"time_block": self.parameters.time_block or f"当前时间:{time.strftime('%Y-%m-%d %H:%M:%S')}", "extra_info_block": self.parameters.extra_info_block,
"identity": self.parameters.identity_block, "time_block": self.parameters.time_block or f"当前时间:{time.strftime('%Y-%m-%d %H:%M:%S')}",
"schedule_block": self.parameters.schedule_block, "identity": self.parameters.identity_block,
"moderation_prompt": self.parameters.moderation_prompt_block, "schedule_block": self.parameters.schedule_block,
"reply_target_block": self.parameters.reply_target_block, "moderation_prompt": self.parameters.moderation_prompt_block,
"mood_state": self.parameters.mood_prompt, "reply_target_block": self.parameters.reply_target_block,
"action_descriptions": self.parameters.action_descriptions, "mood_state": self.parameters.mood_prompt,
}) "action_descriptions": self.parameters.action_descriptions,
}
)
total_time = time.time() - start_time total_time = time.time() - start_time
logger.debug(f"上下文构建完成,总耗时: {total_time:.2f}s") logger.debug(f"上下文构建完成,总耗时: {total_time:.2f}s")
return context_data return context_data
async def _build_s4u_chat_context(self, context_data: Dict[str, Any]) -> None: async def _build_s4u_chat_context(self, context_data: Dict[str, Any]) -> None:
"""构建S4U模式的聊天上下文""" """构建S4U模式的聊天上下文"""
if not self.parameters.message_list_before_now_long: if not self.parameters.message_list_before_now_long:
@@ -446,20 +448,20 @@ class Prompt:
self.parameters.message_list_before_now_long, self.parameters.message_list_before_now_long,
self.parameters.target_user_info.get("user_id") if self.parameters.target_user_info else "", self.parameters.target_user_info.get("user_id") if self.parameters.target_user_info else "",
self.parameters.sender, self.parameters.sender,
self.parameters.chat_id self.parameters.chat_id,
) )
context_data["read_history_prompt"] = read_history_prompt context_data["read_history_prompt"] = read_history_prompt
context_data["unread_history_prompt"] = unread_history_prompt context_data["unread_history_prompt"] = unread_history_prompt
async def _build_normal_chat_context(self, context_data: Dict[str, Any]) -> None: async def _build_normal_chat_context(self, context_data: Dict[str, Any]) -> None:
"""构建normal模式的聊天上下文""" """构建normal模式的聊天上下文"""
if not self.parameters.chat_talking_prompt_short: if not self.parameters.chat_talking_prompt_short:
return return
context_data["chat_info"] = f"""群里的聊天内容: context_data["chat_info"] = f"""群里的聊天内容:
{self.parameters.chat_talking_prompt_short}""" {self.parameters.chat_talking_prompt_short}"""
async def _build_s4u_chat_history_prompts( async def _build_s4u_chat_history_prompts(
self, message_list_before_now: List[Dict[str, Any]], target_user_id: str, sender: str, chat_id: str self, message_list_before_now: List[Dict[str, Any]], target_user_id: str, sender: str, chat_id: str
) -> Tuple[str, str]: ) -> Tuple[str, str]:
@@ -476,101 +478,92 @@ class Prompt:
except Exception as e: except Exception as e:
logger.error(f"构建S4U历史消息prompt失败: {e}") logger.error(f"构建S4U历史消息prompt失败: {e}")
async def _build_expression_habits(self) -> Dict[str, Any]: async def _build_expression_habits(self) -> Dict[str, Any]:
"""构建表达习惯""" """构建表达习惯"""
if not global_config.expression.enable_expression: if not global_config.expression.enable_expression:
return {"expression_habits_block": ""} return {"expression_habits_block": ""}
try: try:
from src.chat.express.expression_selector import ExpressionSelector from src.chat.express.expression_selector import ExpressionSelector
# 获取聊天历史用于表情选择 # 获取聊天历史用于表情选择
chat_history = "" chat_history = ""
if self.parameters.message_list_before_now_long: if self.parameters.message_list_before_now_long:
recent_messages = self.parameters.message_list_before_now_long[-10:] recent_messages = self.parameters.message_list_before_now_long[-10:]
chat_history = build_readable_messages( chat_history = build_readable_messages(
recent_messages, recent_messages, replace_bot_name=True, timestamp_mode="normal", truncate=True
replace_bot_name=True,
timestamp_mode="normal",
truncate=True
) )
# 创建表情选择器 # 创建表情选择器
expression_selector = ExpressionSelector(self.parameters.chat_id) expression_selector = ExpressionSelector(self.parameters.chat_id)
# 选择合适的表情 # 选择合适的表情
selected_expressions = await expression_selector.select_suitable_expressions_llm( selected_expressions = await expression_selector.select_suitable_expressions_llm(
chat_history=chat_history, chat_history=chat_history,
current_message=self.parameters.target, current_message=self.parameters.target,
emotional_tone="neutral", emotional_tone="neutral",
topic_type="general" topic_type="general",
) )
# 构建表达习惯块 # 构建表达习惯块
if selected_expressions: if selected_expressions:
style_habits_str = "\n".join([f"- {expr}" for expr in selected_expressions]) style_habits_str = "\n".join([f"- {expr}" for expr in selected_expressions])
expression_habits_block = f"- 你可以参考以下的语言习惯,当情景合适就使用,但不要生硬使用,以合理的方式结合到你的回复中:\n{style_habits_str}" expression_habits_block = f"- 你可以参考以下的语言习惯,当情景合适就使用,但不要生硬使用,以合理的方式结合到你的回复中:\n{style_habits_str}"
else: else:
expression_habits_block = "" expression_habits_block = ""
return {"expression_habits_block": expression_habits_block} return {"expression_habits_block": expression_habits_block}
except Exception as e: except Exception as e:
logger.error(f"构建表达习惯失败: {e}") logger.error(f"构建表达习惯失败: {e}")
return {"expression_habits_block": ""} return {"expression_habits_block": ""}
async def _build_memory_block(self) -> Dict[str, Any]: async def _build_memory_block(self) -> Dict[str, Any]:
"""构建记忆块""" """构建记忆块"""
if not global_config.memory.enable_memory: if not global_config.memory.enable_memory:
return {"memory_block": ""} return {"memory_block": ""}
try: try:
from src.chat.memory_system.memory_activator import MemoryActivator from src.chat.memory_system.memory_activator import MemoryActivator
from src.chat.memory_system.async_instant_memory_wrapper import get_async_instant_memory from src.chat.memory_system.async_instant_memory_wrapper import get_async_instant_memory
# 获取聊天历史 # 获取聊天历史
chat_history = "" chat_history = ""
if self.parameters.message_list_before_now_long: if self.parameters.message_list_before_now_long:
recent_messages = self.parameters.message_list_before_now_long[-20:] recent_messages = self.parameters.message_list_before_now_long[-20:]
chat_history = build_readable_messages( chat_history = build_readable_messages(
recent_messages, recent_messages, replace_bot_name=True, timestamp_mode="normal", truncate=True
replace_bot_name=True,
timestamp_mode="normal",
truncate=True
) )
# 激活长期记忆 # 激活长期记忆
memory_activator = MemoryActivator() memory_activator = MemoryActivator()
running_memories = await memory_activator.activate_memory_with_chat_history( running_memories = await memory_activator.activate_memory_with_chat_history(
target_message=self.parameters.target, target_message=self.parameters.target, chat_history_prompt=chat_history
chat_history_prompt=chat_history
) )
# 获取即时记忆 # 获取即时记忆
async_memory_wrapper = get_async_instant_memory(self.parameters.chat_id) async_memory_wrapper = get_async_instant_memory(self.parameters.chat_id)
instant_memory = await async_memory_wrapper.get_memory_with_fallback(self.parameters.target) instant_memory = await async_memory_wrapper.get_memory_with_fallback(self.parameters.target)
# 构建记忆块 # 构建记忆块
memory_parts = [] memory_parts = []
if running_memories: if running_memories:
memory_parts.append("以下是当前在聊天中,你回忆起的记忆:") memory_parts.append("以下是当前在聊天中,你回忆起的记忆:")
for memory in running_memories: for memory in running_memories:
memory_parts.append(f"- {memory['content']}") memory_parts.append(f"- {memory['content']}")
if instant_memory: if instant_memory:
memory_parts.append(f"- {instant_memory}") memory_parts.append(f"- {instant_memory}")
memory_block = "\n".join(memory_parts) if memory_parts else "" memory_block = "\n".join(memory_parts) if memory_parts else ""
return {"memory_block": memory_block} return {"memory_block": memory_block}
except Exception as e: except Exception as e:
logger.error(f"构建记忆块失败: {e}") logger.error(f"构建记忆块失败: {e}")
return {"memory_block": ""} return {"memory_block": ""}
async def _build_relation_info(self) -> Dict[str, Any]: async def _build_relation_info(self) -> Dict[str, Any]:
"""构建关系信息""" """构建关系信息"""
try: try:
@@ -579,110 +572,104 @@ class Prompt:
except Exception as e: except Exception as e:
logger.error(f"构建关系信息失败: {e}") logger.error(f"构建关系信息失败: {e}")
return {"relation_info_block": ""} return {"relation_info_block": ""}
async def _build_tool_info(self) -> Dict[str, Any]: async def _build_tool_info(self) -> Dict[str, Any]:
"""构建工具信息""" """构建工具信息"""
if not global_config.tool.enable_tool: if not global_config.tool.enable_tool:
return {"tool_info_block": ""} return {"tool_info_block": ""}
try: try:
from src.plugin_system.core.tool_use import ToolExecutor from src.plugin_system.core.tool_use import ToolExecutor
# 获取聊天历史 # 获取聊天历史
chat_history = "" chat_history = ""
if self.parameters.message_list_before_now_long: if self.parameters.message_list_before_now_long:
recent_messages = self.parameters.message_list_before_now_long[-15:] recent_messages = self.parameters.message_list_before_now_long[-15:]
chat_history = build_readable_messages( chat_history = build_readable_messages(
recent_messages, recent_messages, replace_bot_name=True, timestamp_mode="normal", truncate=True
replace_bot_name=True,
timestamp_mode="normal",
truncate=True
) )
# 创建工具执行器 # 创建工具执行器
tool_executor = ToolExecutor(chat_id=self.parameters.chat_id) tool_executor = ToolExecutor(chat_id=self.parameters.chat_id)
# 执行工具获取信息 # 执行工具获取信息
tool_results, _, _ = await tool_executor.execute_from_chat_message( tool_results, _, _ = await tool_executor.execute_from_chat_message(
sender=self.parameters.sender, sender=self.parameters.sender,
target_message=self.parameters.target, target_message=self.parameters.target,
chat_history=chat_history, chat_history=chat_history,
return_details=False return_details=False,
) )
# 构建工具信息块 # 构建工具信息块
if tool_results: if tool_results:
tool_info_parts = ["## 工具信息","以下是你通过工具获取到的实时信息:"] tool_info_parts = ["## 工具信息", "以下是你通过工具获取到的实时信息:"]
for tool_result in tool_results: for tool_result in tool_results:
tool_name = tool_result.get("tool_name", "unknown") tool_name = tool_result.get("tool_name", "unknown")
content = tool_result.get("content", "") content = tool_result.get("content", "")
result_type = tool_result.get("type", "tool_result") result_type = tool_result.get("type", "tool_result")
tool_info_parts.append(f"- 【{tool_name}{result_type}: {content}") tool_info_parts.append(f"- 【{tool_name}{result_type}: {content}")
tool_info_parts.append("以上是你获取到的实时信息,请在回复时参考这些信息。") tool_info_parts.append("以上是你获取到的实时信息,请在回复时参考这些信息。")
tool_info_block = "\n".join(tool_info_parts) tool_info_block = "\n".join(tool_info_parts)
else: else:
tool_info_block = "" tool_info_block = ""
return {"tool_info_block": tool_info_block} return {"tool_info_block": tool_info_block}
except Exception as e: except Exception as e:
logger.error(f"构建工具信息失败: {e}") logger.error(f"构建工具信息失败: {e}")
return {"tool_info_block": ""} return {"tool_info_block": ""}
async def _build_knowledge_info(self) -> Dict[str, Any]: async def _build_knowledge_info(self) -> Dict[str, Any]:
"""构建知识信息""" """构建知识信息"""
if not global_config.lpmm_knowledge.enable: if not global_config.lpmm_knowledge.enable:
return {"knowledge_prompt": ""} return {"knowledge_prompt": ""}
try: try:
from src.chat.knowledge.knowledge_lib import QAManager from src.chat.knowledge.knowledge_lib import QAManager
# 获取问题文本(当前消息) # 获取问题文本(当前消息)
question = self.parameters.target or "" question = self.parameters.target or ""
if not question: if not question:
return {"knowledge_prompt": ""} return {"knowledge_prompt": ""}
# 创建QA管理器 # 创建QA管理器
qa_manager = QAManager() qa_manager = QAManager()
# 搜索相关知识 # 搜索相关知识
knowledge_results = await qa_manager.get_knowledge( knowledge_results = await qa_manager.get_knowledge(
question=question, question=question, chat_id=self.parameters.chat_id, max_results=5, min_similarity=0.5
chat_id=self.parameters.chat_id,
max_results=5,
min_similarity=0.5
) )
# 构建知识块 # 构建知识块
if knowledge_results and knowledge_results.get("knowledge_items"): if knowledge_results and knowledge_results.get("knowledge_items"):
knowledge_parts = ["## 知识库信息","以下是与你当前对话相关的知识信息:"] knowledge_parts = ["## 知识库信息", "以下是与你当前对话相关的知识信息:"]
for item in knowledge_results["knowledge_items"]: for item in knowledge_results["knowledge_items"]:
content = item.get("content", "") content = item.get("content", "")
source = item.get("source", "") source = item.get("source", "")
relevance = item.get("relevance", 0.0) relevance = item.get("relevance", 0.0)
if content: if content:
if source: if source:
knowledge_parts.append(f"- [{relevance:.2f}] {content} (来源: {source})") knowledge_parts.append(f"- [{relevance:.2f}] {content} (来源: {source})")
else: else:
knowledge_parts.append(f"- [{relevance:.2f}] {content}") knowledge_parts.append(f"- [{relevance:.2f}] {content}")
if knowledge_results.get("summary"): if knowledge_results.get("summary"):
knowledge_parts.append(f"\n知识总结: {knowledge_results['summary']}") knowledge_parts.append(f"\n知识总结: {knowledge_results['summary']}")
knowledge_prompt = "\n".join(knowledge_parts) knowledge_prompt = "\n".join(knowledge_parts)
else: else:
knowledge_prompt = "" knowledge_prompt = ""
return {"knowledge_prompt": knowledge_prompt} return {"knowledge_prompt": knowledge_prompt}
except Exception as e: except Exception as e:
logger.error(f"构建知识信息失败: {e}") logger.error(f"构建知识信息失败: {e}")
return {"knowledge_prompt": ""} return {"knowledge_prompt": ""}
async def _build_cross_context(self) -> Dict[str, Any]: async def _build_cross_context(self) -> Dict[str, Any]:
"""构建跨群上下文""" """构建跨群上下文"""
try: try:
@@ -693,7 +680,7 @@ class Prompt:
except Exception as e: except Exception as e:
logger.error(f"构建跨群上下文失败: {e}") logger.error(f"构建跨群上下文失败: {e}")
return {"cross_context_block": ""} return {"cross_context_block": ""}
async def _format_with_context(self, context_data: Dict[str, Any]) -> str: async def _format_with_context(self, context_data: Dict[str, Any]) -> str:
"""使用上下文数据格式化模板""" """使用上下文数据格式化模板"""
if self.parameters.prompt_mode == "s4u": if self.parameters.prompt_mode == "s4u":
@@ -702,9 +689,9 @@ class Prompt:
params = self._prepare_normal_params(context_data) params = self._prepare_normal_params(context_data)
else: else:
params = self._prepare_default_params(context_data) params = self._prepare_default_params(context_data)
return await global_prompt_manager.format_prompt(self.name, **params) if self.name else self.format(**params) return await global_prompt_manager.format_prompt(self.name, **params) if self.name else self.format(**params)
def _prepare_s4u_params(self, context_data: Dict[str, Any]) -> Dict[str, Any]: def _prepare_s4u_params(self, context_data: Dict[str, Any]) -> Dict[str, Any]:
"""准备S4U模式的参数""" """准备S4U模式的参数"""
return { return {
@@ -725,11 +712,13 @@ class Prompt:
"time_block": context_data.get("time_block", ""), "time_block": context_data.get("time_block", ""),
"reply_target_block": context_data.get("reply_target_block", ""), "reply_target_block": context_data.get("reply_target_block", ""),
"reply_style": global_config.personality.reply_style, "reply_style": global_config.personality.reply_style,
"keywords_reaction_prompt": self.parameters.keywords_reaction_prompt or context_data.get("keywords_reaction_prompt", ""), "keywords_reaction_prompt": self.parameters.keywords_reaction_prompt
or context_data.get("keywords_reaction_prompt", ""),
"moderation_prompt": self.parameters.moderation_prompt_block or context_data.get("moderation_prompt", ""), "moderation_prompt": self.parameters.moderation_prompt_block or context_data.get("moderation_prompt", ""),
"safety_guidelines_block": self.parameters.safety_guidelines_block or context_data.get("safety_guidelines_block", ""), "safety_guidelines_block": self.parameters.safety_guidelines_block
or context_data.get("safety_guidelines_block", ""),
} }
def _prepare_normal_params(self, context_data: Dict[str, Any]) -> Dict[str, Any]: def _prepare_normal_params(self, context_data: Dict[str, Any]) -> Dict[str, Any]:
"""准备Normal模式的参数""" """准备Normal模式的参数"""
return { return {
@@ -749,11 +738,13 @@ class Prompt:
"reply_target_block": context_data.get("reply_target_block", ""), "reply_target_block": context_data.get("reply_target_block", ""),
"config_expression_style": global_config.personality.reply_style, "config_expression_style": global_config.personality.reply_style,
"mood_state": self.parameters.mood_prompt or context_data.get("mood_state", ""), "mood_state": self.parameters.mood_prompt or context_data.get("mood_state", ""),
"keywords_reaction_prompt": self.parameters.keywords_reaction_prompt or context_data.get("keywords_reaction_prompt", ""), "keywords_reaction_prompt": self.parameters.keywords_reaction_prompt
or context_data.get("keywords_reaction_prompt", ""),
"moderation_prompt": self.parameters.moderation_prompt_block or context_data.get("moderation_prompt", ""), "moderation_prompt": self.parameters.moderation_prompt_block or context_data.get("moderation_prompt", ""),
"safety_guidelines_block": self.parameters.safety_guidelines_block or context_data.get("safety_guidelines_block", ""), "safety_guidelines_block": self.parameters.safety_guidelines_block
or context_data.get("safety_guidelines_block", ""),
} }
def _prepare_default_params(self, context_data: Dict[str, Any]) -> Dict[str, Any]: def _prepare_default_params(self, context_data: Dict[str, Any]) -> Dict[str, Any]:
"""准备默认模式的参数""" """准备默认模式的参数"""
return { return {
@@ -769,11 +760,13 @@ class Prompt:
"reason": "", "reason": "",
"mood_state": self.parameters.mood_prompt or context_data.get("mood_state", ""), "mood_state": self.parameters.mood_prompt or context_data.get("mood_state", ""),
"reply_style": global_config.personality.reply_style, "reply_style": global_config.personality.reply_style,
"keywords_reaction_prompt": self.parameters.keywords_reaction_prompt or context_data.get("keywords_reaction_prompt", ""), "keywords_reaction_prompt": self.parameters.keywords_reaction_prompt
or context_data.get("keywords_reaction_prompt", ""),
"moderation_prompt": self.parameters.moderation_prompt_block or context_data.get("moderation_prompt", ""), "moderation_prompt": self.parameters.moderation_prompt_block or context_data.get("moderation_prompt", ""),
"safety_guidelines_block": self.parameters.safety_guidelines_block or context_data.get("safety_guidelines_block", ""), "safety_guidelines_block": self.parameters.safety_guidelines_block
or context_data.get("safety_guidelines_block", ""),
} }
def format(self, *args, **kwargs) -> str: def format(self, *args, **kwargs) -> str:
"""格式化模板,支持位置参数和关键字参数""" """格式化模板,支持位置参数和关键字参数"""
try: try:
@@ -786,21 +779,21 @@ class Prompt:
processed_template = self._processed_template.format(**formatted_args) processed_template = self._processed_template.format(**formatted_args)
else: else:
processed_template = self._processed_template processed_template = self._processed_template
# 再用关键字参数格式化 # 再用关键字参数格式化
if kwargs: if kwargs:
processed_template = processed_template.format(**kwargs) processed_template = processed_template.format(**kwargs)
# 将临时标记还原为实际的花括号 # 将临时标记还原为实际的花括号
result = self._restore_escaped_braces(processed_template) result = self._restore_escaped_braces(processed_template)
return result return result
except (IndexError, KeyError) as e: except (IndexError, KeyError) as e:
raise ValueError(f"格式化模板失败: {self.template}, args={args}, kwargs={kwargs} {str(e)}") from e raise ValueError(f"格式化模板失败: {self.template}, args={args}, kwargs={kwargs} {str(e)}") from e
def __str__(self) -> str: def __str__(self) -> str:
"""返回格式化后的结果或原始模板""" """返回格式化后的结果或原始模板"""
return self._formatted_result if self._formatted_result else self.template return self._formatted_result if self._formatted_result else self.template
def __repr__(self) -> str: def __repr__(self) -> str:
"""返回提示词的表示形式""" """返回提示词的表示形式"""
return f"Prompt(template='{self.template}', name='{self.name}')" return f"Prompt(template='{self.template}', name='{self.name}')"
@@ -872,9 +865,7 @@ class Prompt:
return await relationship_fetcher.build_relation_info(person_id, points_num=5) return await relationship_fetcher.build_relation_info(person_id, points_num=5)
@staticmethod @staticmethod
async def build_cross_context( async def build_cross_context(chat_id: str, prompt_mode: str, target_user_info: Optional[Dict[str, Any]]) -> str:
chat_id: str, prompt_mode: str, target_user_info: Optional[Dict[str, Any]]
) -> str:
""" """
构建跨群聊上下文 - 统一实现 构建跨群聊上下文 - 统一实现
@@ -890,7 +881,7 @@ class Prompt:
return "" return ""
from src.plugin_system.apis import cross_context_api from src.plugin_system.apis import cross_context_api
other_chat_raw_ids = cross_context_api.get_context_groups(chat_id) other_chat_raw_ids = cross_context_api.get_context_groups(chat_id)
if not other_chat_raw_ids: if not other_chat_raw_ids:
return "" return ""
@@ -937,10 +928,7 @@ class Prompt:
# 工厂函数 # 工厂函数
def create_prompt( def create_prompt(
template: str, template: str, name: Optional[str] = None, parameters: Optional[PromptParameters] = None, **kwargs
name: Optional[str] = None,
parameters: Optional[PromptParameters] = None,
**kwargs
) -> Prompt: ) -> Prompt:
"""快速创建Prompt实例的工厂函数""" """快速创建Prompt实例的工厂函数"""
if parameters is None: if parameters is None:
@@ -949,14 +937,10 @@ def create_prompt(
async def create_prompt_async( async def create_prompt_async(
template: str, template: str, name: Optional[str] = None, parameters: Optional[PromptParameters] = None, **kwargs
name: Optional[str] = None,
parameters: Optional[PromptParameters] = None,
**kwargs
) -> Prompt: ) -> Prompt:
"""异步创建Prompt实例""" """异步创建Prompt实例"""
prompt = create_prompt(template, name, parameters, **kwargs) prompt = create_prompt(template, name, parameters, **kwargs)
if global_prompt_manager._context._current_context: if global_prompt_manager._context._current_context:
await global_prompt_manager._context.register_async(prompt) await global_prompt_manager._context.register_async(prompt)
return prompt return prompt

View File

@@ -332,9 +332,9 @@ def process_llm_response(text: str, enable_splitter: bool = True, enable_chinese
if global_config.response_splitter.enable and enable_splitter: if global_config.response_splitter.enable and enable_splitter:
logger.info(f"回复分割器已启用,模式: {global_config.response_splitter.split_mode}") logger.info(f"回复分割器已启用,模式: {global_config.response_splitter.split_mode}")
split_mode = global_config.response_splitter.split_mode split_mode = global_config.response_splitter.split_mode
if split_mode == "llm" and "[SPLIT]" in cleaned_text: if split_mode == "llm" and "[SPLIT]" in cleaned_text:
logger.debug("检测到 [SPLIT] 标记,使用 LLM 自定义分割。") logger.debug("检测到 [SPLIT] 标记,使用 LLM 自定义分割。")
split_sentences_raw = cleaned_text.split("[SPLIT]") split_sentences_raw = cleaned_text.split("[SPLIT]")
@@ -343,7 +343,7 @@ def process_llm_response(text: str, enable_splitter: bool = True, enable_chinese
if split_mode == "llm": if split_mode == "llm":
logger.debug("未检测到 [SPLIT] 标记,本次不进行分割。") logger.debug("未检测到 [SPLIT] 标记,本次不进行分割。")
split_sentences = [cleaned_text] split_sentences = [cleaned_text]
else: # mode == "punctuation" else: # mode == "punctuation"
logger.debug("使用基于标点的传统模式进行分割。") logger.debug("使用基于标点的传统模式进行分割。")
split_sentences = split_into_sentences_w_remove_punctuation(cleaned_text) split_sentences = split_into_sentences_w_remove_punctuation(cleaned_text)
else: else:

View File

@@ -6,6 +6,7 @@ class BaseDataModel:
def deepcopy(self): def deepcopy(self):
return copy.deepcopy(self) return copy.deepcopy(self)
def temporarily_transform_class_to_dict(obj: Any) -> Any: def temporarily_transform_class_to_dict(obj: Any) -> Any:
# sourcery skip: assign-if-exp, reintroduce-else # sourcery skip: assign-if-exp, reintroduce-else
""" """

View File

@@ -2,6 +2,7 @@
机器人兴趣标签数据模型 机器人兴趣标签数据模型
定义机器人的兴趣标签和相关的embedding数据结构 定义机器人的兴趣标签和相关的embedding数据结构
""" """
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import List, Dict, Optional, Any from typing import List, Dict, Optional, Any
from datetime import datetime from datetime import datetime
@@ -12,6 +13,7 @@ from . import BaseDataModel
@dataclass @dataclass
class BotInterestTag(BaseDataModel): class BotInterestTag(BaseDataModel):
"""机器人兴趣标签""" """机器人兴趣标签"""
tag_name: str tag_name: str
weight: float = 1.0 # 权重,表示对这个兴趣的喜好程度 (0.0-1.0) weight: float = 1.0 # 权重,表示对这个兴趣的喜好程度 (0.0-1.0)
embedding: Optional[List[float]] = None # 标签的embedding向量 embedding: Optional[List[float]] = None # 标签的embedding向量
@@ -27,7 +29,7 @@ class BotInterestTag(BaseDataModel):
"embedding": self.embedding, "embedding": self.embedding,
"created_at": self.created_at.isoformat(), "created_at": self.created_at.isoformat(),
"updated_at": self.updated_at.isoformat(), "updated_at": self.updated_at.isoformat(),
"is_active": self.is_active "is_active": self.is_active,
} }
@classmethod @classmethod
@@ -39,13 +41,14 @@ class BotInterestTag(BaseDataModel):
embedding=data.get("embedding"), embedding=data.get("embedding"),
created_at=datetime.fromisoformat(data["created_at"]) if data.get("created_at") else datetime.now(), created_at=datetime.fromisoformat(data["created_at"]) if data.get("created_at") else datetime.now(),
updated_at=datetime.fromisoformat(data["updated_at"]) if data.get("updated_at") else datetime.now(), updated_at=datetime.fromisoformat(data["updated_at"]) if data.get("updated_at") else datetime.now(),
is_active=data.get("is_active", True) is_active=data.get("is_active", True),
) )
@dataclass @dataclass
class BotPersonalityInterests(BaseDataModel): class BotPersonalityInterests(BaseDataModel):
"""机器人人格化兴趣配置""" """机器人人格化兴趣配置"""
personality_id: str personality_id: str
personality_description: str # 人设描述文本 personality_description: str # 人设描述文本
interest_tags: List[BotInterestTag] = field(default_factory=list) interest_tags: List[BotInterestTag] = field(default_factory=list)
@@ -57,7 +60,6 @@ class BotPersonalityInterests(BaseDataModel):
"""获取活跃的兴趣标签""" """获取活跃的兴趣标签"""
return [tag for tag in self.interest_tags if tag.is_active] return [tag for tag in self.interest_tags if tag.is_active]
def to_dict(self) -> Dict[str, Any]: def to_dict(self) -> Dict[str, Any]:
"""转换为字典格式""" """转换为字典格式"""
return { return {
@@ -66,7 +68,7 @@ class BotPersonalityInterests(BaseDataModel):
"interest_tags": [tag.to_dict() for tag in self.interest_tags], "interest_tags": [tag.to_dict() for tag in self.interest_tags],
"embedding_model": self.embedding_model, "embedding_model": self.embedding_model,
"last_updated": self.last_updated.isoformat(), "last_updated": self.last_updated.isoformat(),
"version": self.version "version": self.version,
} }
@classmethod @classmethod
@@ -78,13 +80,14 @@ class BotPersonalityInterests(BaseDataModel):
interest_tags=[BotInterestTag.from_dict(tag_data) for tag_data in data.get("interest_tags", [])], interest_tags=[BotInterestTag.from_dict(tag_data) for tag_data in data.get("interest_tags", [])],
embedding_model=data.get("embedding_model", "text-embedding-ada-002"), embedding_model=data.get("embedding_model", "text-embedding-ada-002"),
last_updated=datetime.fromisoformat(data["last_updated"]) if data.get("last_updated") else datetime.now(), last_updated=datetime.fromisoformat(data["last_updated"]) if data.get("last_updated") else datetime.now(),
version=data.get("version", 1) version=data.get("version", 1),
) )
@dataclass @dataclass
class InterestMatchResult(BaseDataModel): class InterestMatchResult(BaseDataModel):
"""兴趣匹配结果""" """兴趣匹配结果"""
message_id: str message_id: str
matched_tags: List[str] = field(default_factory=list) matched_tags: List[str] = field(default_factory=list)
match_scores: Dict[str, float] = field(default_factory=dict) # tag_name -> score match_scores: Dict[str, float] = field(default_factory=dict) # tag_name -> score
@@ -120,7 +123,9 @@ class InterestMatchResult(BaseDataModel):
# 计算置信度(基于匹配标签数量和分数分布) # 计算置信度(基于匹配标签数量和分数分布)
if len(self.match_scores) > 0: if len(self.match_scores) > 0:
avg_score = self.overall_score avg_score = self.overall_score
score_variance = sum((score - avg_score) ** 2 for score in self.match_scores.values()) / len(self.match_scores) score_variance = sum((score - avg_score) ** 2 for score in self.match_scores.values()) / len(
self.match_scores
)
# 分数越集中,置信度越高 # 分数越集中,置信度越高
self.confidence = max(0.0, 1.0 - score_variance) self.confidence = max(0.0, 1.0 - score_variance)
else: else:
@@ -129,4 +134,4 @@ class InterestMatchResult(BaseDataModel):
def get_top_matches(self, top_n: int = 3) -> List[tuple]: def get_top_matches(self, top_n: int = 3) -> List[tuple]:
"""获取前N个最佳匹配""" """获取前N个最佳匹配"""
sorted_matches = sorted(self.match_scores.items(), key=lambda x: x[1], reverse=True) sorted_matches = sorted(self.match_scores.items(), key=lambda x: x[1], reverse=True)
return sorted_matches[:top_n] return sorted_matches[:top_n]

View File

@@ -208,6 +208,7 @@ class DatabaseMessages(BaseDataModel):
"chat_info_user_cardname": self.chat_info.user_info.user_cardname, "chat_info_user_cardname": self.chat_info.user_info.user_cardname,
} }
@dataclass(init=False) @dataclass(init=False)
class DatabaseActionRecords(BaseDataModel): class DatabaseActionRecords(BaseDataModel):
def __init__( def __init__(
@@ -235,4 +236,4 @@ class DatabaseActionRecords(BaseDataModel):
self.action_prompt_display = action_prompt_display self.action_prompt_display = action_prompt_display
self.chat_id = chat_id self.chat_id = chat_id
self.chat_info_stream_id = chat_info_stream_id self.chat_info_stream_id = chat_info_stream_id
self.chat_info_platform = chat_info_platform self.chat_info_platform = chat_info_platform

View File

@@ -28,6 +28,7 @@ class ActionPlannerInfo(BaseDataModel):
@dataclass @dataclass
class InterestScore(BaseDataModel): class InterestScore(BaseDataModel):
"""兴趣度评分结果""" """兴趣度评分结果"""
message_id: str message_id: str
total_score: float total_score: float
interest_match_score: float interest_match_score: float
@@ -41,6 +42,7 @@ class Plan(BaseDataModel):
""" """
统一规划数据模型 统一规划数据模型
""" """
chat_id: str chat_id: str
mode: "ChatMode" mode: "ChatMode"

View File

@@ -2,9 +2,11 @@ from dataclasses import dataclass
from typing import Optional, List, Tuple, TYPE_CHECKING, Any from typing import Optional, List, Tuple, TYPE_CHECKING, Any
from . import BaseDataModel from . import BaseDataModel
if TYPE_CHECKING: if TYPE_CHECKING:
from src.llm_models.payload_content.tool_option import ToolCall from src.llm_models.payload_content.tool_option import ToolCall
@dataclass @dataclass
class LLMGenerationDataModel(BaseDataModel): class LLMGenerationDataModel(BaseDataModel):
content: Optional[str] = None content: Optional[str] = None
@@ -13,4 +15,4 @@ class LLMGenerationDataModel(BaseDataModel):
tool_calls: Optional[List["ToolCall"]] = None tool_calls: Optional[List["ToolCall"]] = None
prompt: Optional[str] = None prompt: Optional[str] = None
selected_expressions: Optional[List[int]] = None selected_expressions: Optional[List[int]] = None
reply_set: Optional[List[Tuple[str, Any]]] = None reply_set: Optional[List[Tuple[str, Any]]] = None

View File

@@ -2,6 +2,7 @@
消息管理模块数据模型 消息管理模块数据模型
定义消息管理器使用的数据结构 定义消息管理器使用的数据结构
""" """
import asyncio import asyncio
import time import time
from dataclasses import dataclass, field from dataclasses import dataclass, field
@@ -16,14 +17,16 @@ if TYPE_CHECKING:
class MessageStatus(Enum): class MessageStatus(Enum):
"""消息状态枚举""" """消息状态枚举"""
UNREAD = "unread" # 未读消息
READ = "read" # 读消息 UNREAD = "unread" # 读消息
READ = "read" # 已读消息
PROCESSING = "processing" # 处理中 PROCESSING = "processing" # 处理中
@dataclass @dataclass
class StreamContext(BaseDataModel): class StreamContext(BaseDataModel):
"""聊天流上下文信息""" """聊天流上下文信息"""
stream_id: str stream_id: str
unread_messages: List["DatabaseMessages"] = field(default_factory=list) unread_messages: List["DatabaseMessages"] = field(default_factory=list)
history_messages: List["DatabaseMessages"] = field(default_factory=list) history_messages: List["DatabaseMessages"] = field(default_factory=list)
@@ -59,6 +62,7 @@ class StreamContext(BaseDataModel):
@dataclass @dataclass
class MessageManagerStats(BaseDataModel): class MessageManagerStats(BaseDataModel):
"""消息管理器统计信息""" """消息管理器统计信息"""
total_streams: int = 0 total_streams: int = 0
active_streams: int = 0 active_streams: int = 0
total_unread_messages: int = 0 total_unread_messages: int = 0
@@ -74,9 +78,10 @@ class MessageManagerStats(BaseDataModel):
@dataclass @dataclass
class StreamStats(BaseDataModel): class StreamStats(BaseDataModel):
"""聊天流统计信息""" """聊天流统计信息"""
stream_id: str stream_id: str
is_active: bool is_active: bool
unread_count: int unread_count: int
history_count: int history_count: int
last_check_time: float last_check_time: float
has_active_task: bool has_active_task: bool

View File

@@ -23,15 +23,15 @@ def get_global_api() -> MessageServer: # sourcery skip: extract-method
maim_message_config = global_config.maim_message maim_message_config = global_config.maim_message
# 设置基本参数 # 设置基本参数
host = os.getenv("HOST", "127.0.0.1") host = os.getenv("HOST", "127.0.0.1")
port_str = os.getenv("PORT", "8000") port_str = os.getenv("PORT", "8000")
try: try:
port = int(port_str) port = int(port_str)
except ValueError: except ValueError:
port = 8000 port = 8000
kwargs = { kwargs = {
"host": host, "host": host,
"port": port, "port": port,

View File

@@ -31,7 +31,9 @@ class TelemetryHeartBeatTask(AsyncTask):
self.client_uuid: str | None = local_storage["mofox_uuid"] if "mofox_uuid" in local_storage else None # type: ignore self.client_uuid: str | None = local_storage["mofox_uuid"] if "mofox_uuid" in local_storage else None # type: ignore
"""客户端UUID""" """客户端UUID"""
self.private_key_pem: str | None = local_storage["mofox_private_key"] if "mofox_private_key" in local_storage else None # type: ignore self.private_key_pem: str | None = (
local_storage["mofox_private_key"] if "mofox_private_key" in local_storage else None
) # type: ignore
"""客户端私钥""" """客户端私钥"""
self.info_dict = self._get_sys_info() self.info_dict = self._get_sys_info()
@@ -61,78 +63,65 @@ class TelemetryHeartBeatTask(AsyncTask):
def _generate_signature(self, request_body: dict) -> tuple[str, str]: def _generate_signature(self, request_body: dict) -> tuple[str, str]:
""" """
生成RSA签名 生成RSA签名
Returns: Returns:
tuple[str, str]: (timestamp, signature_b64) tuple[str, str]: (timestamp, signature_b64)
""" """
if not self.private_key_pem: if not self.private_key_pem:
raise ValueError("私钥未初始化") raise ValueError("私钥未初始化")
# 生成时间戳 # 生成时间戳
timestamp = datetime.now(timezone.utc).isoformat() timestamp = datetime.now(timezone.utc).isoformat()
# 创建签名数据字符串 # 创建签名数据字符串
sign_data = f"{self.client_uuid}:{timestamp}:{json.dumps(request_body, separators=(',', ':'))}" sign_data = f"{self.client_uuid}:{timestamp}:{json.dumps(request_body, separators=(',', ':'))}"
# 加载私钥 # 加载私钥
private_key = serialization.load_pem_private_key( private_key = serialization.load_pem_private_key(self.private_key_pem.encode("utf-8"), password=None)
self.private_key_pem.encode('utf-8'),
password=None
)
# 确保是RSA私钥 # 确保是RSA私钥
if not isinstance(private_key, rsa.RSAPrivateKey): if not isinstance(private_key, rsa.RSAPrivateKey):
raise ValueError("私钥必须是RSA格式") raise ValueError("私钥必须是RSA格式")
# 生成签名 # 生成签名
signature = private_key.sign( signature = private_key.sign(
sign_data.encode('utf-8'), sign_data.encode("utf-8"),
padding.PSS( padding.PSS(mgf=padding.MGF1(hashes.SHA256()), salt_length=padding.PSS.MAX_LENGTH),
mgf=padding.MGF1(hashes.SHA256()), hashes.SHA256(),
salt_length=padding.PSS.MAX_LENGTH
),
hashes.SHA256()
) )
# Base64编码 # Base64编码
signature_b64 = base64.b64encode(signature).decode('utf-8') signature_b64 = base64.b64encode(signature).decode("utf-8")
return timestamp, signature_b64 return timestamp, signature_b64
def _decrypt_challenge(self, challenge_b64: str) -> str: def _decrypt_challenge(self, challenge_b64: str) -> str:
""" """
解密挑战数据 解密挑战数据
Args: Args:
challenge_b64: Base64编码的挑战数据 challenge_b64: Base64编码的挑战数据
Returns: Returns:
str: 解密后的UUID字符串 str: 解密后的UUID字符串
""" """
if not self.private_key_pem: if not self.private_key_pem:
raise ValueError("私钥未初始化") raise ValueError("私钥未初始化")
# 加载私钥 # 加载私钥
private_key = serialization.load_pem_private_key( private_key = serialization.load_pem_private_key(self.private_key_pem.encode("utf-8"), password=None)
self.private_key_pem.encode('utf-8'),
password=None
)
# 确保是RSA私钥 # 确保是RSA私钥
if not isinstance(private_key, rsa.RSAPrivateKey): if not isinstance(private_key, rsa.RSAPrivateKey):
raise ValueError("私钥必须是RSA格式") raise ValueError("私钥必须是RSA格式")
# 解密挑战数据 # 解密挑战数据
decrypted_bytes = private_key.decrypt( decrypted_bytes = private_key.decrypt(
base64.b64decode(challenge_b64), base64.b64decode(challenge_b64),
padding.OAEP( padding.OAEP(mgf=padding.MGF1(hashes.SHA256()), algorithm=hashes.SHA256(), label=None),
mgf=padding.MGF1(hashes.SHA256()),
algorithm=hashes.SHA256(),
label=None
)
) )
return decrypted_bytes.decode('utf-8') return decrypted_bytes.decode("utf-8")
async def _req_uuid(self) -> bool: async def _req_uuid(self) -> bool:
""" """
@@ -155,28 +144,26 @@ class TelemetryHeartBeatTask(AsyncTask):
if response.status != 200: if response.status != 200:
response_text = await response.text() response_text = await response.text()
logger.error( logger.error(f"注册步骤1失败状态码: {response.status}, 响应内容: {response_text}")
f"注册步骤1失败状态码: {response.status}, 响应内容: {response_text}"
)
raise aiohttp.ClientResponseError( raise aiohttp.ClientResponseError(
request_info=response.request_info, request_info=response.request_info,
history=response.history, history=response.history,
status=response.status, status=response.status,
message=f"Step1 failed: {response_text}" message=f"Step1 failed: {response_text}",
) )
step1_data = await response.json() step1_data = await response.json()
temp_uuid = step1_data.get("temp_uuid") temp_uuid = step1_data.get("temp_uuid")
private_key = step1_data.get("private_key") private_key = step1_data.get("private_key")
challenge = step1_data.get("challenge") challenge = step1_data.get("challenge")
if not all([temp_uuid, private_key, challenge]): if not all([temp_uuid, private_key, challenge]):
logger.error("Step1响应缺少必要字段temp_uuid, private_key 或 challenge") logger.error("Step1响应缺少必要字段temp_uuid, private_key 或 challenge")
raise ValueError("Step1响应数据不完整") raise ValueError("Step1响应数据不完整")
# 临时保存私钥用于解密 # 临时保存私钥用于解密
self.private_key_pem = private_key self.private_key_pem = private_key
# 解密挑战数据 # 解密挑战数据
logger.debug("解密挑战数据...") logger.debug("解密挑战数据...")
try: try:
@@ -184,21 +171,18 @@ class TelemetryHeartBeatTask(AsyncTask):
except Exception as e: except Exception as e:
logger.error(f"解密挑战数据失败: {e}") logger.error(f"解密挑战数据失败: {e}")
raise raise
# 验证解密结果 # 验证解密结果
if decrypted_uuid != temp_uuid: if decrypted_uuid != temp_uuid:
logger.error(f"解密结果验证失败: 期望 {temp_uuid}, 实际 {decrypted_uuid}") logger.error(f"解密结果验证失败: 期望 {temp_uuid}, 实际 {decrypted_uuid}")
raise ValueError("解密结果与临时UUID不匹配") raise ValueError("解密结果与临时UUID不匹配")
logger.debug("挑战数据解密成功开始注册步骤2") logger.debug("挑战数据解密成功开始注册步骤2")
# Step 2: 发送解密结果完成注册 # Step 2: 发送解密结果完成注册
async with session.post( async with session.post(
f"{TELEMETRY_SERVER_URL}/stat/reg_client_step2", f"{TELEMETRY_SERVER_URL}/stat/reg_client_step2",
json={ json={"temp_uuid": temp_uuid, "decrypted_uuid": decrypted_uuid},
"temp_uuid": temp_uuid,
"decrypted_uuid": decrypted_uuid
},
timeout=aiohttp.ClientTimeout(total=5), timeout=aiohttp.ClientTimeout(total=5),
) as response: ) as response:
logger.debug(f"Step2 Response status: {response.status}") logger.debug(f"Step2 Response status: {response.status}")
@@ -206,7 +190,7 @@ class TelemetryHeartBeatTask(AsyncTask):
if response.status == 200: if response.status == 200:
step2_data = await response.json() step2_data = await response.json()
mofox_uuid = step2_data.get("mofox_uuid") mofox_uuid = step2_data.get("mofox_uuid")
if mofox_uuid: if mofox_uuid:
# 将正式UUID和私钥存储到本地 # 将正式UUID和私钥存储到本地
local_storage["mofox_uuid"] = mofox_uuid local_storage["mofox_uuid"] = mofox_uuid
@@ -225,23 +209,19 @@ class TelemetryHeartBeatTask(AsyncTask):
raise ValueError(f"Step2失败: {response_text}") raise ValueError(f"Step2失败: {response_text}")
else: else:
response_text = await response.text() response_text = await response.text()
logger.error( logger.error(f"注册步骤2失败状态码: {response.status}, 响应内容: {response_text}")
f"注册步骤2失败状态码: {response.status}, 响应内容: {response_text}"
)
raise aiohttp.ClientResponseError( raise aiohttp.ClientResponseError(
request_info=response.request_info, request_info=response.request_info,
history=response.history, history=response.history,
status=response.status, status=response.status,
message=f"Step2 failed: {response_text}" message=f"Step2 failed: {response_text}",
) )
except Exception as e: except Exception as e:
import traceback import traceback
error_msg = str(e) or "未知错误" error_msg = str(e) or "未知错误"
logger.warning( logger.warning(f"注册客户端出错,不过你还是可以正常使用墨狐: {type(e).__name__}: {error_msg}")
f"注册客户端出错,不过你还是可以正常使用墨狐: {type(e).__name__}: {error_msg}"
)
logger.debug(f"完整错误信息: {traceback.format_exc()}") logger.debug(f"完整错误信息: {traceback.format_exc()}")
# 请求失败,重试次数+1 # 请求失败,重试次数+1
@@ -264,13 +244,13 @@ class TelemetryHeartBeatTask(AsyncTask):
try: try:
# 生成签名 # 生成签名
timestamp, signature = self._generate_signature(self.info_dict) timestamp, signature = self._generate_signature(self.info_dict)
headers = { headers = {
"X-mofox-UUID": self.client_uuid, "X-mofox-UUID": self.client_uuid,
"X-mofox-Signature": signature, "X-mofox-Signature": signature,
"X-mofox-Timestamp": timestamp, "X-mofox-Timestamp": timestamp,
"User-Agent": f"MofoxClient/{self.client_uuid[:8]}", "User-Agent": f"MofoxClient/{self.client_uuid[:8]}",
"Content-Type": "application/json" "Content-Type": "application/json",
} }
logger.debug(f"正在发送心跳到服务器: {self.server_url}") logger.debug(f"正在发送心跳到服务器: {self.server_url}")
@@ -347,4 +327,4 @@ class TelemetryHeartBeatTask(AsyncTask):
logger.warning("客户端注册失败,跳过此次心跳") logger.warning("客户端注册失败,跳过此次心跳")
return return
await self._send_heartbeat() await self._send_heartbeat()

View File

@@ -99,14 +99,13 @@ def get_global_server() -> Server:
"""获取全局服务器实例""" """获取全局服务器实例"""
global global_server global global_server
if global_server is None: if global_server is None:
host = os.getenv("HOST", "127.0.0.1") host = os.getenv("HOST", "127.0.0.1")
port_str = os.getenv("PORT", "8000") port_str = os.getenv("PORT", "8000")
try: try:
port = int(port_str) port = int(port_str)
except ValueError: except ValueError:
port = 8000 port = 8000
global_server = Server(host=host, port=port) global_server = Server(host=host, port=port)
return global_server return global_server

View File

@@ -44,7 +44,7 @@ from src.config.official_configs import (
PermissionConfig, PermissionConfig,
CommandConfig, CommandConfig,
PlanningSystemConfig, PlanningSystemConfig,
AffinityFlowConfig AffinityFlowConfig,
) )
from .api_ada_configs import ( from .api_ada_configs import (
@@ -399,9 +399,7 @@ class Config(ValidatedConfigBase):
cross_context: CrossContextConfig = Field( cross_context: CrossContextConfig = Field(
default_factory=lambda: CrossContextConfig(), description="跨群聊上下文共享配置" default_factory=lambda: CrossContextConfig(), description="跨群聊上下文共享配置"
) )
affinity_flow: AffinityFlowConfig = Field( affinity_flow: AffinityFlowConfig = Field(default_factory=lambda: AffinityFlowConfig(), description="亲和流配置")
default_factory=lambda: AffinityFlowConfig(), description="亲和流配置"
)
class APIAdapterConfig(ValidatedConfigBase): class APIAdapterConfig(ValidatedConfigBase):

View File

@@ -51,8 +51,12 @@ class PersonalityConfig(ValidatedConfigBase):
personality_core: str = Field(..., description="核心人格") personality_core: str = Field(..., description="核心人格")
personality_side: str = Field(..., description="人格侧写") personality_side: str = Field(..., description="人格侧写")
identity: str = Field(default="", description="身份特征") identity: str = Field(default="", description="身份特征")
background_story: str = Field(default="", description="世界观背景故事这部分内容会作为背景知识LLM被指导不应主动复述") background_story: str = Field(
safety_guidelines: List[str] = Field(default_factory=list, description="安全与互动底线Bot在任何情况下都必须遵守的原则") default="", description="世界观背景故事这部分内容会作为背景知识LLM被指导不应主动复述"
)
safety_guidelines: List[str] = Field(
default_factory=list, description="安全与互动底线Bot在任何情况下都必须遵守的原则"
)
reply_style: str = Field(default="", description="表达风格") reply_style: str = Field(default="", description="表达风格")
prompt_mode: Literal["s4u", "normal"] = Field(default="s4u", description="Prompt模式") prompt_mode: Literal["s4u", "normal"] = Field(default="s4u", description="Prompt模式")
compress_personality: bool = Field(default=True, description="是否压缩人格") compress_personality: bool = Field(default=True, description="是否压缩人格")
@@ -79,7 +83,8 @@ class ChatConfig(ValidatedConfigBase):
talk_frequency_adjust: list[list[str]] = Field(default_factory=lambda: [], description="聊天频率调整") talk_frequency_adjust: list[list[str]] = Field(default_factory=lambda: [], description="聊天频率调整")
focus_value: float = Field(default=1.0, description="专注值") focus_value: float = Field(default=1.0, description="专注值")
focus_mode_quiet_groups: List[str] = Field( focus_mode_quiet_groups: List[str] = Field(
default_factory=list, description='专注模式下需要保持安静的群组列表, 格式: ["platform:group_id1", "platform:group_id2"]' default_factory=list,
description='专注模式下需要保持安静的群组列表, 格式: ["platform:group_id1", "platform:group_id2"]',
) )
force_reply_private: bool = Field(default=False, description="强制回复私聊") force_reply_private: bool = Field(default=False, description="强制回复私聊")
group_chat_mode: Literal["auto", "normal", "focus"] = Field(default="auto", description="群聊模式") group_chat_mode: Literal["auto", "normal", "focus"] = Field(default="auto", description="群聊模式")
@@ -343,6 +348,7 @@ class ExpressionConfig(ValidatedConfigBase):
# 如果都没有匹配,返回默认值 # 如果都没有匹配,返回默认值
return True, True, 1.0 return True, True, 1.0
class ToolConfig(ValidatedConfigBase): class ToolConfig(ValidatedConfigBase):
"""工具配置类""" """工具配置类"""
@@ -477,7 +483,6 @@ class ExperimentalConfig(ValidatedConfigBase):
pfc_chatting: bool = Field(default=False, description="启用PFC聊天") pfc_chatting: bool = Field(default=False, description="启用PFC聊天")
class MaimMessageConfig(ValidatedConfigBase): class MaimMessageConfig(ValidatedConfigBase):
"""maim_message配置类""" """maim_message配置类"""
@@ -602,8 +607,12 @@ class SleepSystemConfig(ValidatedConfigBase):
sleep_by_schedule: bool = Field(default=True, description="是否根据日程表进行睡觉") sleep_by_schedule: bool = Field(default=True, description="是否根据日程表进行睡觉")
fixed_sleep_time: str = Field(default="23:00", description="固定的睡觉时间") fixed_sleep_time: str = Field(default="23:00", description="固定的睡觉时间")
fixed_wake_up_time: str = Field(default="07:00", description="固定的起床时间") fixed_wake_up_time: str = Field(default="07:00", description="固定的起床时间")
sleep_time_offset_minutes: int = Field(default=15, ge=0, le=60, description="睡觉时间随机偏移量范围(分钟),实际睡觉时间会在±该值范围内随机") sleep_time_offset_minutes: int = Field(
wake_up_time_offset_minutes: int = Field(default=15, ge=0, le=60, description="起床时间随机偏移量范围(分钟),实际起床时间会在±该值范围内随机") default=15, ge=0, le=60, description="睡觉时间随机偏移量范围(分钟),实际睡觉时间会在±该值范围内随机"
)
wake_up_time_offset_minutes: int = Field(
default=15, ge=0, le=60, description="起床时间随机偏移量范围(分钟),实际起床时间会在±该值范围内随机"
)
wakeup_threshold: float = Field(default=15.0, ge=1.0, description="唤醒阈值,达到此值时会被唤醒") wakeup_threshold: float = Field(default=15.0, ge=1.0, description="唤醒阈值,达到此值时会被唤醒")
private_message_increment: float = Field(default=3.0, ge=0.1, description="私聊消息增加的唤醒度") private_message_increment: float = Field(default=3.0, ge=0.1, description="私聊消息增加的唤醒度")
group_mention_increment: float = Field(default=2.0, ge=0.1, description="群聊艾特增加的唤醒度") group_mention_increment: float = Field(default=2.0, ge=0.1, description="群聊艾特增加的唤醒度")
@@ -618,10 +627,10 @@ class SleepSystemConfig(ValidatedConfigBase):
# --- 失眠机制相关参数 --- # --- 失眠机制相关参数 ---
enable_insomnia_system: bool = Field(default=True, description="是否启用失眠系统") enable_insomnia_system: bool = Field(default=True, description="是否启用失眠系统")
insomnia_trigger_delay_minutes: List[int] = Field( insomnia_trigger_delay_minutes: List[int] = Field(
default_factory=lambda:[30, 60], description="入睡后触发失眠判定的延迟时间范围(分钟)" default_factory=lambda: [30, 60], description="入睡后触发失眠判定的延迟时间范围(分钟)"
) )
insomnia_duration_minutes: List[int] = Field( insomnia_duration_minutes: List[int] = Field(
default_factory=lambda:[15, 45], description="单次失眠状态的持续时间范围(分钟)" default_factory=lambda: [15, 45], description="单次失眠状态的持续时间范围(分钟)"
) )
sleep_pressure_threshold: float = Field(default=30.0, description="触发“压力不足型失眠”的睡眠压力阈值") sleep_pressure_threshold: float = Field(default=30.0, description="触发“压力不足型失眠”的睡眠压力阈值")
deep_sleep_threshold: float = Field(default=80.0, description="进入“深度睡眠”的睡眠压力阈值") deep_sleep_threshold: float = Field(default=80.0, description="进入“深度睡眠”的睡眠压力阈值")
@@ -657,6 +666,8 @@ class CrossContextConfig(ValidatedConfigBase):
enable: bool = Field(default=False, description="是否启用跨群聊上下文共享功能") enable: bool = Field(default=False, description="是否启用跨群聊上下文共享功能")
groups: List[ContextGroup] = Field(default_factory=list, description="上下文共享组列表") groups: List[ContextGroup] = Field(default_factory=list, description="上下文共享组列表")
class CommandConfig(ValidatedConfigBase): class CommandConfig(ValidatedConfigBase):
"""命令系统配置类""" """命令系统配置类"""

View File

@@ -88,8 +88,7 @@ class Individuality:
# 初始化智能兴趣系统 # 初始化智能兴趣系统
await interest_scoring_system.initialize_smart_interests( await interest_scoring_system.initialize_smart_interests(
personality_description=full_personality, personality_description=full_personality, personality_id=self.bot_person_id
personality_id=self.bot_person_id
) )
logger.info("智能兴趣系统初始化完成") logger.info("智能兴趣系统初始化完成")

View File

@@ -130,7 +130,8 @@ class MainSystem:
# 停止消息重组器 # 停止消息重组器
from src.plugin_system.core.event_manager import event_manager from src.plugin_system.core.event_manager import event_manager
from src.plugin_system import EventType from src.plugin_system import EventType
asyncio.run(event_manager.trigger_event(EventType.ON_STOP,permission_group="SYSTEM"))
asyncio.run(event_manager.trigger_event(EventType.ON_STOP, permission_group="SYSTEM"))
from src.utils.message_chunker import reassembler from src.utils.message_chunker import reassembler
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
@@ -216,7 +217,7 @@ MoFox_Bot(第三方修改版)
# 添加统计信息输出任务 # 添加统计信息输出任务
await async_task_manager.add_task(StatisticOutputTask()) await async_task_manager.add_task(StatisticOutputTask())
# 添加遥测心跳任务 # 添加遥测心跳任务
await async_task_manager.add_task(TelemetryHeartBeatTask()) await async_task_manager.add_task(TelemetryHeartBeatTask())
@@ -250,6 +251,7 @@ MoFox_Bot(第三方修改版)
# 初始化回复后关系追踪系统 # 初始化回复后关系追踪系统
from src.chat.affinity_flow.relationship_integration import initialize_relationship_tracking from src.chat.affinity_flow.relationship_integration import initialize_relationship_tracking
relationship_tracker = initialize_relationship_tracking() relationship_tracker = initialize_relationship_tracking()
if relationship_tracker: if relationship_tracker:
logger.info("回复后关系追踪系统初始化成功") logger.info("回复后关系追踪系统初始化成功")
@@ -273,6 +275,7 @@ MoFox_Bot(第三方修改版)
# 初始化LPMM知识库 # 初始化LPMM知识库
from src.chat.knowledge.knowledge_lib import initialize_lpmm_knowledge from src.chat.knowledge.knowledge_lib import initialize_lpmm_knowledge
initialize_lpmm_knowledge() initialize_lpmm_knowledge()
logger.info("LPMM知识库初始化成功") logger.info("LPMM知识库初始化成功")
@@ -298,6 +301,7 @@ MoFox_Bot(第三方修改版)
# 启动消息管理器 # 启动消息管理器
from src.chat.message_manager import message_manager from src.chat.message_manager import message_manager
await message_manager.start() await message_manager.start()
logger.info("消息管理器已启动") logger.info("消息管理器已启动")

View File

@@ -96,12 +96,13 @@ class PersonInfoManager:
# 在此处打一个补丁如果platform为qq尝试生成id后检查是否存在如果不存在则将平台换为napcat后再次检查如果存在则更新原id为platform为qq的id # 在此处打一个补丁如果platform为qq尝试生成id后检查是否存在如果不存在则将平台换为napcat后再次检查如果存在则更新原id为platform为qq的id
components = [platform, str(user_id)] components = [platform, str(user_id)]
key = "_".join(components) key = "_".join(components)
# 如果不是 qq 平台,直接返回计算的 id # 如果不是 qq 平台,直接返回计算的 id
if platform != "qq": if platform != "qq":
return hashlib.md5(key.encode()).hexdigest() return hashlib.md5(key.encode()).hexdigest()
qq_id = hashlib.md5(key.encode()).hexdigest() qq_id = hashlib.md5(key.encode()).hexdigest()
# 对于 qq 平台,先检查该 person_id 是否已存在;如果存在直接返回 # 对于 qq 平台,先检查该 person_id 是否已存在;如果存在直接返回
def _db_check_and_migrate_sync(p_id: str, raw_user_id: str): def _db_check_and_migrate_sync(p_id: str, raw_user_id: str):
try: try:
@@ -191,16 +192,16 @@ class PersonInfoManager:
# Ensure person_id is correctly set from the argument # Ensure person_id is correctly set from the argument
final_data["person_id"] = person_id final_data["person_id"] = person_id
# 你们的英文注释是何意味? # 你们的英文注释是何意味?
# 检查并修复关键字段为None的情况喵 # 检查并修复关键字段为None的情况喵
if final_data.get("user_id") is None: if final_data.get("user_id") is None:
logger.warning(f"user_id为None使用'unknown'作为默认值 person_id={person_id}") logger.warning(f"user_id为None使用'unknown'作为默认值 person_id={person_id}")
final_data["user_id"] = "unknown" final_data["user_id"] = "unknown"
if final_data.get("platform") is None: if final_data.get("platform") is None:
logger.warning(f"platform为None使用'unknown'作为默认值 person_id={person_id}") logger.warning(f"platform为None使用'unknown'作为默认值 person_id={person_id}")
final_data["platform"] = "unknown" final_data["platform"] = "unknown"
# 这里的目的是为了防止在识别出错的情况下有一个最小回退,不只是针对@消息识别成视频后的报错问题 # 这里的目的是为了防止在识别出错的情况下有一个最小回退,不只是针对@消息识别成视频后的报错问题
# Serialize JSON fields # Serialize JSON fields
@@ -251,12 +252,12 @@ class PersonInfoManager:
# Ensure person_id is correctly set from the argument # Ensure person_id is correctly set from the argument
final_data["person_id"] = person_id final_data["person_id"] = person_id
# 检查并修复关键字段为None的情况 # 检查并修复关键字段为None的情况
if final_data.get("user_id") is None: if final_data.get("user_id") is None:
logger.warning(f"user_id为None使用'unknown'作为默认值 person_id={person_id}") logger.warning(f"user_id为None使用'unknown'作为默认值 person_id={person_id}")
final_data["user_id"] = "unknown" final_data["user_id"] = "unknown"
if final_data.get("platform") is None: if final_data.get("platform") is None:
logger.warning(f"platform为None使用'unknown'作为默认值 person_id={person_id}") logger.warning(f"platform为None使用'unknown'作为默认值 person_id={person_id}")
final_data["platform"] = "unknown" final_data["platform"] = "unknown"
@@ -356,12 +357,12 @@ class PersonInfoManager:
creation_data["platform"] = data["platform"] creation_data["platform"] = data["platform"]
if data and "user_id" in data: if data and "user_id" in data:
creation_data["user_id"] = data["user_id"] creation_data["user_id"] = data["user_id"]
# 额外检查关键字段如果为None则使用默认值 # 额外检查关键字段如果为None则使用默认值
if creation_data.get("user_id") is None: if creation_data.get("user_id") is None:
logger.warning(f"创建用户时user_id为None使用'unknown'作为默认值 person_id={person_id}") logger.warning(f"创建用户时user_id为None使用'unknown'作为默认值 person_id={person_id}")
creation_data["user_id"] = "unknown" creation_data["user_id"] = "unknown"
if creation_data.get("platform") is None: if creation_data.get("platform") is None:
logger.warning(f"创建用户时platform为None使用'unknown'作为默认值 person_id={person_id}") logger.warning(f"创建用户时platform为None使用'unknown'作为默认值 person_id={person_id}")
creation_data["platform"] = "unknown" creation_data["platform"] = "unknown"

View File

@@ -123,7 +123,9 @@ class RelationshipFetcher:
all_points = current_points + forgotten_points all_points = current_points + forgotten_points
if all_points: if all_points:
# 按权重和时效性综合排序 # 按权重和时效性综合排序
all_points.sort(key=lambda x: (float(x[1]) if len(x) > 1 else 0, float(x[2]) if len(x) > 2 else 0), reverse=True) all_points.sort(
key=lambda x: (float(x[1]) if len(x) > 1 else 0, float(x[2]) if len(x) > 2 else 0), reverse=True
)
selected_points = all_points[:points_num] selected_points = all_points[:points_num]
points_text = "\n".join([f"- {point[0]}{point[2]}" for point in selected_points if len(point) > 2]) points_text = "\n".join([f"- {point[0]}{point[2]}" for point in selected_points if len(point) > 2])
else: else:
@@ -139,15 +141,17 @@ class RelationshipFetcher:
# 2. 认识时间和频率 # 2. 认识时间和频率
if know_since: if know_since:
from datetime import datetime from datetime import datetime
know_time = datetime.fromtimestamp(know_since).strftime('%Y年%m月%d')
know_time = datetime.fromtimestamp(know_since).strftime("%Y年%m月%d")
relation_parts.append(f"你从{know_time}开始认识{person_name}") relation_parts.append(f"你从{know_time}开始认识{person_name}")
if know_times > 0: if know_times > 0:
relation_parts.append(f"你们已经交流过{int(know_times)}") relation_parts.append(f"你们已经交流过{int(know_times)}")
if last_know: if last_know:
from datetime import datetime from datetime import datetime
last_time = datetime.fromtimestamp(last_know).strftime('%m月%d')
last_time = datetime.fromtimestamp(last_know).strftime("%m月%d")
relation_parts.append(f"最近一次交流是在{last_time}") relation_parts.append(f"最近一次交流是在{last_time}")
# 3. 态度和印象 # 3. 态度和印象
@@ -156,7 +160,7 @@ class RelationshipFetcher:
if short_impression: if short_impression:
relation_parts.append(f"你对ta的总体印象{short_impression}") relation_parts.append(f"你对ta的总体印象{short_impression}")
if full_impression: if full_impression:
relation_parts.append(f"更详细的了解:{full_impression}") relation_parts.append(f"更详细的了解:{full_impression}")
@@ -168,14 +172,14 @@ class RelationshipFetcher:
try: try:
from src.common.database.sqlalchemy_database_api import db_query from src.common.database.sqlalchemy_database_api import db_query
from src.common.database.sqlalchemy_models import UserRelationships from src.common.database.sqlalchemy_models import UserRelationships
# 查询用户关系数据 # 查询用户关系数据
relationships = await db_query( relationships = await db_query(
UserRelationships, UserRelationships,
filters=[UserRelationships.user_id == str(person_info_manager.get_value_sync(person_id, "user_id"))], filters=[UserRelationships.user_id == str(person_info_manager.get_value_sync(person_id, "user_id"))],
limit=1 limit=1,
) )
if relationships: if relationships:
rel_data = relationships[0] rel_data = relationships[0]
if rel_data.relationship_text: if rel_data.relationship_text:
@@ -183,13 +187,15 @@ class RelationshipFetcher:
if rel_data.relationship_score: if rel_data.relationship_score:
score_desc = self._get_relationship_score_description(rel_data.relationship_score) score_desc = self._get_relationship_score_description(rel_data.relationship_score)
relation_parts.append(f"关系亲密程度:{score_desc}") relation_parts.append(f"关系亲密程度:{score_desc}")
except Exception as e: except Exception as e:
logger.debug(f"查询UserRelationships表失败: {e}") logger.debug(f"查询UserRelationships表失败: {e}")
# 构建最终的关系信息字符串 # 构建最终的关系信息字符串
if relation_parts: if relation_parts:
relation_info = f"关于{person_name},你知道以下信息:\n" + "\n".join([f"{part}" for part in relation_parts]) relation_info = f"关于{person_name},你知道以下信息:\n" + "\n".join(
[f"{part}" for part in relation_parts]
)
else: else:
relation_info = f"你对{person_name}了解不多,这是比较初步的交流。" relation_info = f"你对{person_name}了解不多,这是比较初步的交流。"

View File

@@ -93,7 +93,6 @@ class BaseAction(ABC):
self.associated_types: list[str] = getattr(self.__class__, "associated_types", []).copy() self.associated_types: list[str] = getattr(self.__class__, "associated_types", []).copy()
self.chat_type_allow: ChatType = getattr(self.__class__, "chat_type_allow", ChatType.ALL) self.chat_type_allow: ChatType = getattr(self.__class__, "chat_type_allow", ChatType.ALL)
# ============================================================================= # =============================================================================
# 便捷属性 - 直接在初始化时获取常用聊天信息(带类型注解) # 便捷属性 - 直接在初始化时获取常用聊天信息(带类型注解)
# ============================================================================= # =============================================================================
@@ -398,6 +397,7 @@ class BaseAction(ABC):
try: try:
# 1. 从注册中心获取Action类 # 1. 从注册中心获取Action类
from src.plugin_system.core.component_registry import component_registry from src.plugin_system.core.component_registry import component_registry
action_class = component_registry.get_component_class(action_name, ComponentType.ACTION) action_class = component_registry.get_component_class(action_name, ComponentType.ACTION)
if not action_class: if not action_class:
logger.error(f"{log_prefix} 未找到Action: {action_name}") logger.error(f"{log_prefix} 未找到Action: {action_name}")
@@ -406,7 +406,7 @@ class BaseAction(ABC):
# 2. 准备实例化参数 # 2. 准备实例化参数
# 复用当前Action的大部分上下文信息 # 复用当前Action的大部分上下文信息
called_action_data = action_data if action_data is not None else self.action_data called_action_data = action_data if action_data is not None else self.action_data
component_info = component_registry.get_component_info(action_name, ComponentType.ACTION) component_info = component_registry.get_component_info(action_name, ComponentType.ACTION)
if not component_info: if not component_info:
logger.warning(f"{log_prefix} 未找到Action组件信息: {action_name}") logger.warning(f"{log_prefix} 未找到Action组件信息: {action_name}")

View File

@@ -98,7 +98,7 @@ class BaseEventHandler(ABC):
weight=cls.weight, weight=cls.weight,
intercept_message=cls.intercept_message, intercept_message=cls.intercept_message,
) )
def set_plugin_name(self, plugin_name: str) -> None: def set_plugin_name(self, plugin_name: str) -> None:
"""设置插件名称 """设置插件名称
@@ -107,9 +107,9 @@ class BaseEventHandler(ABC):
""" """
self.plugin_name = plugin_name self.plugin_name = plugin_name
def set_plugin_config(self,plugin_config) -> None: def set_plugin_config(self, plugin_config) -> None:
self.plugin_config = plugin_config self.plugin_config = plugin_config
def get_config(self, key: str, default=None): def get_config(self, key: str, default=None):
"""获取插件配置值,支持嵌套键访问 """获取插件配置值,支持嵌套键访问

View File

@@ -69,7 +69,7 @@ class EventType(Enum):
""" """
ON_START = "on_start" # 启动事件,用于调用按时任务 ON_START = "on_start" # 启动事件,用于调用按时任务
ON_STOP ="on_stop" ON_STOP = "on_stop"
ON_MESSAGE = "on_message" ON_MESSAGE = "on_message"
ON_PLAN = "on_plan" ON_PLAN = "on_plan"
POST_LLM = "post_llm" POST_LLM = "post_llm"

View File

@@ -270,7 +270,9 @@ class ComponentRegistry:
# 使用EventManager进行事件处理器注册 # 使用EventManager进行事件处理器注册
from src.plugin_system.core.event_manager import event_manager from src.plugin_system.core.event_manager import event_manager
return event_manager.register_event_handler(handler_class,self.get_plugin_config(handler_info.plugin_name) or {}) return event_manager.register_event_handler(
handler_class, self.get_plugin_config(handler_info.plugin_name) or {}
)
# === 组件移除相关 === # === 组件移除相关 ===
@@ -682,19 +684,20 @@ class ComponentRegistry:
plugin_instance = plugin_manager.get_plugin_instance(plugin_name) plugin_instance = plugin_manager.get_plugin_instance(plugin_name)
if plugin_instance and plugin_instance.config: if plugin_instance and plugin_instance.config:
return plugin_instance.config return plugin_instance.config
# 如果插件实例不存在,尝试从配置文件读取 # 如果插件实例不存在,尝试从配置文件读取
try: try:
import toml import toml
config_path = Path("config") / "plugins" / plugin_name / "config.toml" config_path = Path("config") / "plugins" / plugin_name / "config.toml"
if config_path.exists(): if config_path.exists():
with open(config_path, 'r', encoding='utf-8') as f: with open(config_path, "r", encoding="utf-8") as f:
config_data = toml.load(f) config_data = toml.load(f)
logger.debug(f"从配置文件读取插件 {plugin_name} 的配置") logger.debug(f"从配置文件读取插件 {plugin_name} 的配置")
return config_data return config_data
except Exception as e: except Exception as e:
logger.debug(f"读取插件 {plugin_name} 配置文件失败: {e}") logger.debug(f"读取插件 {plugin_name} 配置文件失败: {e}")
return {} return {}
def get_registry_stats(self) -> Dict[str, Any]: def get_registry_stats(self) -> Dict[str, Any]:

View File

@@ -145,7 +145,9 @@ class EventManager:
logger.info(f"事件 {event_name} 已禁用") logger.info(f"事件 {event_name} 已禁用")
return True return True
def register_event_handler(self, handler_class: Type[BaseEventHandler], plugin_config: Optional[dict] = None) -> bool: def register_event_handler(
self, handler_class: Type[BaseEventHandler], plugin_config: Optional[dict] = None
) -> bool:
"""注册事件处理器 """注册事件处理器
Args: Args:
@@ -167,7 +169,7 @@ class EventManager:
# 创建事件处理器实例,传递插件配置 # 创建事件处理器实例,传递插件配置
handler_instance = handler_class() handler_instance = handler_class()
handler_instance.plugin_config = plugin_config handler_instance.plugin_config = plugin_config
if plugin_config is not None and hasattr(handler_instance, 'set_plugin_config'): if plugin_config is not None and hasattr(handler_instance, "set_plugin_config"):
handler_instance.set_plugin_config(plugin_config) handler_instance.set_plugin_config(plugin_config)
self._event_handlers[handler_name] = handler_instance self._event_handlers[handler_name] = handler_instance

View File

@@ -199,9 +199,7 @@ class PluginManager:
self._show_plugin_components(plugin_name) self._show_plugin_components(plugin_name)
# 检查并调用 on_plugin_loaded 钩子(如果存在) # 检查并调用 on_plugin_loaded 钩子(如果存在)
if hasattr(plugin_instance, "on_plugin_loaded") and callable( if hasattr(plugin_instance, "on_plugin_loaded") and callable(plugin_instance.on_plugin_loaded):
plugin_instance.on_plugin_loaded
):
logger.debug(f"为插件 '{plugin_name}' 调用 on_plugin_loaded 钩子") logger.debug(f"为插件 '{plugin_name}' 调用 on_plugin_loaded 钩子")
try: try:
# 使用 asyncio.create_task 确保它不会阻塞加载流程 # 使用 asyncio.create_task 确保它不会阻塞加载流程

View File

@@ -64,50 +64,50 @@ class AtAction(BaseAction):
# 使用回复器生成艾特回复,而不是直接发送命令 # 使用回复器生成艾特回复,而不是直接发送命令
from src.chat.replyer.default_generator import DefaultReplyer from src.chat.replyer.default_generator import DefaultReplyer
from src.chat.message_receive.chat_stream import get_chat_manager from src.chat.message_receive.chat_stream import get_chat_manager
# 获取当前聊天流 # 获取当前聊天流
chat_manager = get_chat_manager() chat_manager = get_chat_manager()
chat_stream = self.chat_stream or chat_manager.get_stream(self.chat_id) chat_stream = self.chat_stream or chat_manager.get_stream(self.chat_id)
if not chat_stream: if not chat_stream:
logger.error(f"找不到聊天流: {self.chat_stream}") logger.error(f"找不到聊天流: {self.chat_stream}")
return False, "聊天流不存在" return False, "聊天流不存在"
# 创建回复器实例 # 创建回复器实例
replyer = DefaultReplyer(chat_stream) replyer = DefaultReplyer(chat_stream)
# 构建回复对象,将艾特消息作为回复目标 # 构建回复对象,将艾特消息作为回复目标
reply_to = f"{user_name}:{at_message}" reply_to = f"{user_name}:{at_message}"
extra_info = f"你需要艾特用户 {user_name} 并回复他们说: {at_message}" extra_info = f"你需要艾特用户 {user_name} 并回复他们说: {at_message}"
# 使用回复器生成回复 # 使用回复器生成回复
success, llm_response, prompt = await replyer.generate_reply_with_context( success, llm_response, prompt = await replyer.generate_reply_with_context(
reply_to=reply_to, reply_to=reply_to,
extra_info=extra_info, extra_info=extra_info,
enable_tool=False, # 艾特回复通常不需要工具调用 enable_tool=False, # 艾特回复通常不需要工具调用
from_plugin=False from_plugin=False,
) )
if success and llm_response: if success and llm_response:
# 获取生成的回复内容 # 获取生成的回复内容
reply_content = llm_response.get("content", "") reply_content = llm_response.get("content", "")
if reply_content: if reply_content:
# 获取用户QQ号发送真正的艾特消息 # 获取用户QQ号发送真正的艾特消息
user_id = user_info.get("user_id") user_id = user_info.get("user_id")
# 发送真正的艾特命令,使用回复器生成的智能内容 # 发送真正的艾特命令,使用回复器生成的智能内容
await self.send_command( await self.send_command(
"SEND_AT_MESSAGE", "SEND_AT_MESSAGE",
args={"qq_id": user_id, "text": reply_content}, args={"qq_id": user_id, "text": reply_content},
display_message=f"艾特用户 {user_name} 并发送智能回复: {reply_content}", display_message=f"艾特用户 {user_name} 并发送智能回复: {reply_content}",
) )
await self.store_action_info( await self.store_action_info(
action_build_into_prompt=True, action_build_into_prompt=True,
action_prompt_display=f"执行了艾特用户动作:艾特用户 {user_name} 并发送智能回复: {reply_content}", action_prompt_display=f"执行了艾特用户动作:艾特用户 {user_name} 并发送智能回复: {reply_content}",
action_done=True, action_done=True,
) )
logger.info(f"成功通过回复器生成智能内容并发送真正的艾特消息给 {user_name}: {reply_content}") logger.info(f"成功通过回复器生成智能内容并发送真正的艾特消息给 {user_name}: {reply_content}")
return True, "智能艾特消息发送成功" return True, "智能艾特消息发送成功"
else: else:
@@ -116,7 +116,7 @@ class AtAction(BaseAction):
else: else:
logger.error("回复器生成回复失败") logger.error("回复器生成回复失败")
return False, "回复生成失败" return False, "回复生成失败"
except Exception as e: except Exception as e:
logger.error(f"执行艾特用户动作时发生异常: {e}", exc_info=True) logger.error(f"执行艾特用户动作时发生异常: {e}", exc_info=True)
await self.store_action_info( await self.store_action_info(

View File

@@ -70,7 +70,9 @@ class EmojiAction(BaseAction):
# 2. 获取所有有效的表情包对象 # 2. 获取所有有效的表情包对象
emoji_manager = get_emoji_manager() emoji_manager = get_emoji_manager()
all_emojis_obj: list[MaiEmoji] = [e for e in emoji_manager.emoji_objects if not e.is_deleted and e.description] all_emojis_obj: list[MaiEmoji] = [
e for e in emoji_manager.emoji_objects if not e.is_deleted and e.description
]
if not all_emojis_obj: if not all_emojis_obj:
logger.warning(f"{self.log_prefix} 无法获取任何带有描述的有效表情包") logger.warning(f"{self.log_prefix} 无法获取任何带有描述的有效表情包")
return False, "无法获取任何带有描述的有效表情包" return False, "无法获取任何带有描述的有效表情包"
@@ -91,12 +93,12 @@ class EmojiAction(BaseAction):
# 4. 准备情感数据和后备列表 # 4. 准备情感数据和后备列表
emotion_map = {} emotion_map = {}
all_emojis_data = [] all_emojis_data = []
for emoji in all_emojis_obj: for emoji in all_emojis_obj:
b64 = image_path_to_base64(emoji.full_path) b64 = image_path_to_base64(emoji.full_path)
if not b64: if not b64:
continue continue
desc = emoji.description desc = emoji.description
emotions = emoji.emotion emotions = emoji.emotion
all_emojis_data.append((b64, desc)) all_emojis_data.append((b64, desc))
@@ -168,16 +170,18 @@ class EmojiAction(BaseAction):
# 使用模糊匹配来查找最相关的情感标签 # 使用模糊匹配来查找最相关的情感标签
matched_key = next((key for key in emotion_map if chosen_emotion in key), None) matched_key = next((key for key in emotion_map if chosen_emotion in key), None)
if matched_key: if matched_key:
emoji_base64, emoji_description = random.choice(emotion_map[matched_key]) emoji_base64, emoji_description = random.choice(emotion_map[matched_key])
logger.info(f"{self.log_prefix} 找到匹配情感 '{chosen_emotion}' (匹配到: '{matched_key}') 的表情包: {emoji_description}") logger.info(
f"{self.log_prefix} 找到匹配情感 '{chosen_emotion}' (匹配到: '{matched_key}') 的表情包: {emoji_description}"
)
else: else:
logger.warning( logger.warning(
f"{self.log_prefix} LLM选择的情感 '{chosen_emotion}' 不在可用列表中, 将随机选择一个表情包" f"{self.log_prefix} LLM选择的情感 '{chosen_emotion}' 不在可用列表中, 将随机选择一个表情包"
) )
emoji_base64, emoji_description = random.choice(all_emojis_data) emoji_base64, emoji_description = random.choice(all_emojis_data)
elif global_config.emoji.emoji_selection_mode == "description": elif global_config.emoji.emoji_selection_mode == "description":
# --- 详细描述选择模式 --- # --- 详细描述选择模式 ---
# 获取最近的5条消息内容用于判断 # 获取最近的5条消息内容用于判断
@@ -226,15 +230,23 @@ class EmojiAction(BaseAction):
logger.info(f"{self.log_prefix} LLM选择的描述: {chosen_description}") logger.info(f"{self.log_prefix} LLM选择的描述: {chosen_description}")
# 简单关键词匹配 # 简单关键词匹配
matched_emoji = next((item for item in all_emojis_data if chosen_description.lower() in item[1].lower() or item[1].lower() in chosen_description.lower()), None) matched_emoji = next(
(
item
for item in all_emojis_data
if chosen_description.lower() in item[1].lower()
or item[1].lower() in chosen_description.lower()
),
None,
)
# 如果包含匹配失败,尝试关键词匹配 # 如果包含匹配失败,尝试关键词匹配
if not matched_emoji: if not matched_emoji:
keywords = ['惊讶', '困惑', '呆滞', '震惊', '', '无语', '', '可爱'] keywords = ["惊讶", "困惑", "呆滞", "震惊", "", "无语", "", "可爱"]
for keyword in keywords: for keyword in keywords:
if keyword in chosen_description: if keyword in chosen_description:
for item in all_emojis_data: for item in all_emojis_data:
if any(k in item[1] for k in ['', '', '', '困惑', '无语']): if any(k in item[1] for k in ["", "", "", "困惑", "无语"]):
matched_emoji = item matched_emoji = item
break break
if matched_emoji: if matched_emoji:
@@ -255,7 +267,9 @@ class EmojiAction(BaseAction):
if not success: if not success:
logger.error(f"{self.log_prefix} 表情包发送失败") logger.error(f"{self.log_prefix} 表情包发送失败")
await self.store_action_info(action_build_into_prompt = True,action_prompt_display =f"发送了一个表情包,但失败了",action_done= False) await self.store_action_info(
action_build_into_prompt=True, action_prompt_display=f"发送了一个表情包,但失败了", action_done=False
)
return False, "表情包发送失败" return False, "表情包发送失败"
# 发送成功后,记录到历史 # 发送成功后,记录到历史
@@ -263,8 +277,10 @@ class EmojiAction(BaseAction):
add_emoji_to_history(self.chat_id, emoji_description) add_emoji_to_history(self.chat_id, emoji_description)
except Exception as e: except Exception as e:
logger.error(f"{self.log_prefix} 添加表情到历史记录时出错: {e}") logger.error(f"{self.log_prefix} 添加表情到历史记录时出错: {e}")
await self.store_action_info(action_build_into_prompt = True,action_prompt_display =f"发送了一个表情包",action_done= True) await self.store_action_info(
action_build_into_prompt=True, action_prompt_display=f"发送了一个表情包", action_done=True
)
return True, f"发送表情包: {emoji_description}" return True, f"发送表情包: {emoji_description}"

View File

@@ -1,4 +1,3 @@
from src.plugin_system import BaseEventHandler from src.plugin_system import BaseEventHandler
from src.plugin_system.base.base_event import HandlerResult from src.plugin_system.base.base_event import HandlerResult
@@ -1748,6 +1747,7 @@ class SetGroupSignHandler(BaseEventHandler):
logger.error("事件 napcat_set_group_sign 请求失败!") logger.error("事件 napcat_set_group_sign 请求失败!")
return HandlerResult(False, False, {"status": "error"}) return HandlerResult(False, False, {"status": "error"})
# ===PERSONAL=== # ===PERSONAL===
class SetInputStatusHandler(BaseEventHandler): class SetInputStatusHandler(BaseEventHandler):
handler_name: str = "napcat_set_input_status_handler" handler_name: str = "napcat_set_input_status_handler"

View File

@@ -227,7 +227,7 @@ class LauchNapcatAdapterHandler(BaseEventHandler):
await reassembler.start_cleanup_task() await reassembler.start_cleanup_task()
logger.info("开始启动Napcat Adapter") logger.info("开始启动Napcat Adapter")
# 创建单独的异步任务,防止阻塞主线程 # 创建单独的异步任务,防止阻塞主线程
asyncio.create_task(self._start_maibot_connection()) asyncio.create_task(self._start_maibot_connection())
asyncio.create_task(napcat_server(self.plugin_config)) asyncio.create_task(napcat_server(self.plugin_config))
@@ -238,10 +238,10 @@ class LauchNapcatAdapterHandler(BaseEventHandler):
"""非阻塞方式启动MaiBot连接等待主服务启动后再连接""" """非阻塞方式启动MaiBot连接等待主服务启动后再连接"""
# 等待一段时间让MaiBot主服务完全启动 # 等待一段时间让MaiBot主服务完全启动
await asyncio.sleep(5) await asyncio.sleep(5)
max_attempts = 10 max_attempts = 10
attempt = 0 attempt = 0
while attempt < max_attempts: while attempt < max_attempts:
try: try:
logger.info(f"尝试连接MaiBot (第{attempt + 1}次)") logger.info(f"尝试连接MaiBot (第{attempt + 1}次)")
@@ -285,7 +285,7 @@ class NapcatAdapterPlugin(BasePlugin):
def enable_plugin(self) -> bool: def enable_plugin(self) -> bool:
"""通过配置文件动态控制插件启用状态""" """通过配置文件动态控制插件启用状态"""
# 如果已经通过配置加载了状态,使用配置中的值 # 如果已经通过配置加载了状态,使用配置中的值
if hasattr(self, '_is_enabled'): if hasattr(self, "_is_enabled"):
return self._is_enabled return self._is_enabled
# 否则使用默认值(禁用状态) # 否则使用默认值(禁用状态)
return False return False
@@ -308,60 +308,107 @@ class NapcatAdapterPlugin(BasePlugin):
"nickname": ConfigField(type=str, default="", description="昵称配置(目前未使用)"), "nickname": ConfigField(type=str, default="", description="昵称配置(目前未使用)"),
}, },
"napcat_server": { "napcat_server": {
"mode": ConfigField(type=str, default="reverse", description="连接模式reverse=反向连接(作为服务器), forward=正向连接(作为客户端)", choices=["reverse", "forward"]), "mode": ConfigField(
type=str,
default="reverse",
description="连接模式reverse=反向连接(作为服务器), forward=正向连接(作为客户端)",
choices=["reverse", "forward"],
),
"host": ConfigField(type=str, default="localhost", description="主机地址"), "host": ConfigField(type=str, default="localhost", description="主机地址"),
"port": ConfigField(type=int, default=8095, description="端口号"), "port": ConfigField(type=int, default=8095, description="端口号"),
"url": ConfigField(type=str, default="", description="正向连接时的完整WebSocket URL如 ws://localhost:8080/ws (仅在forward模式下使用)"), "url": ConfigField(
"access_token": ConfigField(type=str, default="", description="WebSocket 连接的访问令牌,用于身份验证(可选)"), type=str,
default="",
description="正向连接时的完整WebSocket URL如 ws://localhost:8080/ws (仅在forward模式下使用)",
),
"access_token": ConfigField(
type=str, default="", description="WebSocket 连接的访问令牌,用于身份验证(可选)"
),
"heartbeat_interval": ConfigField(type=int, default=30, description="心跳间隔时间(按秒计)"), "heartbeat_interval": ConfigField(type=int, default=30, description="心跳间隔时间(按秒计)"),
}, },
"maibot_server": { "maibot_server": {
"host": ConfigField(type=str, default="localhost", description="麦麦在.env文件中设置的主机地址即HOST字段"), "host": ConfigField(
type=str, default="localhost", description="麦麦在.env文件中设置的主机地址即HOST字段"
),
"port": ConfigField(type=int, default=8000, description="麦麦在.env文件中设置的端口即PORT字段"), "port": ConfigField(type=int, default=8000, description="麦麦在.env文件中设置的端口即PORT字段"),
"platform_name": ConfigField(type=str, default="qq", description="平台名称,用于消息路由"), "platform_name": ConfigField(type=str, default="qq", description="平台名称,用于消息路由"),
}, },
"voice": { "voice": {
"use_tts": ConfigField(type=bool, default=False, description="是否使用tts语音请确保你配置了tts并有对应的adapter"), "use_tts": ConfigField(
type=bool, default=False, description="是否使用tts语音请确保你配置了tts并有对应的adapter"
),
}, },
"slicing": { "slicing": {
"max_frame_size": ConfigField(type=int, default=64, description="WebSocket帧的最大大小单位为字节默认64KB"), "max_frame_size": ConfigField(
type=int, default=64, description="WebSocket帧的最大大小单位为字节默认64KB"
),
"delay_ms": ConfigField(type=int, default=10, description="切片发送间隔时间,单位为毫秒"), "delay_ms": ConfigField(type=int, default=10, description="切片发送间隔时间,单位为毫秒"),
}, },
"debug": { "debug": {
"level": ConfigField(type=str, default="INFO", description="日志等级DEBUG, INFO, WARNING, ERROR, CRITICAL", choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]), "level": ConfigField(
type=str,
default="INFO",
description="日志等级DEBUG, INFO, WARNING, ERROR, CRITICAL",
choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
),
}, },
"features": { "features": {
# 权限设置 # 权限设置
"group_list_type": ConfigField(type=str, default="blacklist", description="群聊列表类型whitelist白名单或 blacklist黑名单", choices=["whitelist", "blacklist"]), "group_list_type": ConfigField(
type=str,
default="blacklist",
description="群聊列表类型whitelist白名单或 blacklist黑名单",
choices=["whitelist", "blacklist"],
),
"group_list": ConfigField(type=list, default=[], description="群聊ID列表"), "group_list": ConfigField(type=list, default=[], description="群聊ID列表"),
"private_list_type": ConfigField(type=str, default="blacklist", description="私聊列表类型whitelist白名单或 blacklist黑名单", choices=["whitelist", "blacklist"]), "private_list_type": ConfigField(
type=str,
default="blacklist",
description="私聊列表类型whitelist白名单或 blacklist黑名单",
choices=["whitelist", "blacklist"],
),
"private_list": ConfigField(type=list, default=[], description="用户ID列表"), "private_list": ConfigField(type=list, default=[], description="用户ID列表"),
"ban_user_id": ConfigField(type=list, default=[], description="全局禁止用户ID列表这些用户无法在任何地方使用机器人"), "ban_user_id": ConfigField(
type=list, default=[], description="全局禁止用户ID列表这些用户无法在任何地方使用机器人"
),
"ban_qq_bot": ConfigField(type=bool, default=False, description="是否屏蔽QQ官方机器人消息"), "ban_qq_bot": ConfigField(type=bool, default=False, description="是否屏蔽QQ官方机器人消息"),
# 聊天功能设置 # 聊天功能设置
"enable_poke": ConfigField(type=bool, default=True, description="是否启用戳一戳功能"), "enable_poke": ConfigField(type=bool, default=True, description="是否启用戳一戳功能"),
"ignore_non_self_poke": ConfigField(type=bool, default=False, description="是否无视不是针对自己的戳一戳"), "ignore_non_self_poke": ConfigField(type=bool, default=False, description="是否无视不是针对自己的戳一戳"),
"poke_debounce_seconds": ConfigField(type=int, default=3, description="戳一戳防抖时间(秒),在指定时间内第二次针对机器人的戳一戳将被忽略"), "poke_debounce_seconds": ConfigField(
type=int, default=3, description="戳一戳防抖时间(秒),在指定时间内第二次针对机器人的戳一戳将被忽略"
),
"enable_reply_at": ConfigField(type=bool, default=True, description="是否启用引用回复时艾特用户的功能"), "enable_reply_at": ConfigField(type=bool, default=True, description="是否启用引用回复时艾特用户的功能"),
"reply_at_rate": ConfigField(type=float, default=0.5, description="引用回复时艾特用户的几率 (0.0 ~ 1.0)"), "reply_at_rate": ConfigField(type=float, default=0.5, description="引用回复时艾特用户的几率 (0.0 ~ 1.0)"),
"enable_emoji_like": ConfigField(type=bool, default=True, description="是否启用群聊表情回复功能"), "enable_emoji_like": ConfigField(type=bool, default=True, description="是否启用群聊表情回复功能"),
# 视频处理设置 # 视频处理设置
"enable_video_analysis": ConfigField(type=bool, default=True, description="是否启用视频识别功能"), "enable_video_analysis": ConfigField(type=bool, default=True, description="是否启用视频识别功能"),
"max_video_size_mb": ConfigField(type=int, default=100, description="视频文件最大大小限制MB"), "max_video_size_mb": ConfigField(type=int, default=100, description="视频文件最大大小限制MB"),
"download_timeout": ConfigField(type=int, default=60, description="视频下载超时时间(秒)"), "download_timeout": ConfigField(type=int, default=60, description="视频下载超时时间(秒)"),
"supported_formats": ConfigField(type=list, default=["mp4", "avi", "mov", "mkv", "flv", "wmv", "webm"], description="支持的视频格式"), "supported_formats": ConfigField(
type=list, default=["mp4", "avi", "mov", "mkv", "flv", "wmv", "webm"], description="支持的视频格式"
),
# 消息缓冲设置 # 消息缓冲设置
"enable_message_buffer": ConfigField(type=bool, default=True, description="是否启用消息缓冲合并功能"), "enable_message_buffer": ConfigField(type=bool, default=True, description="是否启用消息缓冲合并功能"),
"message_buffer_enable_group": ConfigField(type=bool, default=True, description="是否启用群聊消息缓冲合并"), "message_buffer_enable_group": ConfigField(type=bool, default=True, description="是否启用群聊消息缓冲合并"),
"message_buffer_enable_private": ConfigField(type=bool, default=True, description="是否启用私聊消息缓冲合并"), "message_buffer_enable_private": ConfigField(
"message_buffer_interval": ConfigField(type=float, default=3.0, description="消息合并间隔时间(秒),在此时间内的连续消息将被合并"), type=bool, default=True, description="是否启用私聊消息缓冲合并"
"message_buffer_initial_delay": ConfigField(type=float, default=0.5, description="消息缓冲初始延迟(秒),收到第一条消息后等待此时间开始合并"), ),
"message_buffer_max_components": ConfigField(type=int, default=50, description="单个会话最大缓冲消息组件数量,超过此数量将强制合并"), "message_buffer_interval": ConfigField(
"message_buffer_block_prefixes": ConfigField(type=list, default=["/", "!", "", ".", "", "#", "%"], description="消息缓冲屏蔽前缀,以这些前缀开头的消息不会被缓冲"), type=float, default=3.0, description="消息合并间隔时间(秒),在此时间内的连续消息将被合并"
} ),
"message_buffer_initial_delay": ConfigField(
type=float, default=0.5, description="消息缓冲初始延迟(秒),收到第一条消息后等待此时间开始合并"
),
"message_buffer_max_components": ConfigField(
type=int, default=50, description="单个会话最大缓冲消息组件数量,超过此数量将强制合并"
),
"message_buffer_block_prefixes": ConfigField(
type=list,
default=["/", "!", "", ".", "", "#", "%"],
description="消息缓冲屏蔽前缀,以这些前缀开头的消息不会被缓冲",
),
},
} }
# 配置节描述 # 配置节描述
@@ -374,7 +421,7 @@ class NapcatAdapterPlugin(BasePlugin):
"voice": "发送语音设置", "voice": "发送语音设置",
"slicing": "WebSocket消息切片设置", "slicing": "WebSocket消息切片设置",
"debug": "调试设置", "debug": "调试设置",
"features": "功能设置(权限控制、聊天功能、视频处理、消息缓冲等)" "features": "功能设置(权限控制、聊天功能、视频处理、消息缓冲等)",
} }
def register_events(self): def register_events(self):
@@ -409,6 +456,7 @@ class NapcatAdapterPlugin(BasePlugin):
chunker.set_plugin_config(self.config) chunker.set_plugin_config(self.config)
# 设置response_pool的插件配置 # 设置response_pool的插件配置
from .src.response_pool import set_plugin_config as set_response_pool_config from .src.response_pool import set_plugin_config as set_response_pool_config
set_response_pool_config(self.config) set_response_pool_config(self.config)
# 设置send_handler的插件配置 # 设置send_handler的插件配置
send_handler.set_plugin_config(self.config) send_handler.set_plugin_config(self.config)
@@ -418,4 +466,4 @@ class NapcatAdapterPlugin(BasePlugin):
notice_handler.set_plugin_config(self.config) notice_handler.set_plugin_config(self.config)
# 设置meta_event_handler的插件配置 # 设置meta_event_handler的插件配置
meta_event_handler.set_plugin_config(self.config) meta_event_handler.set_plugin_config(self.config)
# 设置其他handler的插件配置现在由component_registry在注册时自动设置 # 设置其他handler的插件配置现在由component_registry在注册时自动设置

View File

@@ -102,7 +102,9 @@ class SimpleMessageBuffer:
return True return True
# 检查屏蔽前缀 # 检查屏蔽前缀
block_prefixes = tuple(config_api.get_plugin_config(self.plugin_config, "features.message_buffer_block_prefixes", [])) block_prefixes = tuple(
config_api.get_plugin_config(self.plugin_config, "features.message_buffer_block_prefixes", [])
)
text = text.strip() text = text.strip()
if text.startswith(block_prefixes): if text.startswith(block_prefixes):
@@ -134,9 +136,13 @@ class SimpleMessageBuffer:
# 检查是否启用对应类型的缓冲 # 检查是否启用对应类型的缓冲
message_type = event_data.get("message_type", "") message_type = event_data.get("message_type", "")
if message_type == "group" and not config_api.get_plugin_config(self.plugin_config, "features.message_buffer_enable_group", False): if message_type == "group" and not config_api.get_plugin_config(
self.plugin_config, "features.message_buffer_enable_group", False
):
return False return False
elif message_type == "private" and not config_api.get_plugin_config(self.plugin_config, "features.message_buffer_enable_private", False): elif message_type == "private" and not config_api.get_plugin_config(
self.plugin_config, "features.message_buffer_enable_private", False
):
return False return False
# 提取文本 # 提取文本
@@ -158,7 +164,9 @@ class SimpleMessageBuffer:
session = self.buffer_pool[session_id] session = self.buffer_pool[session_id]
# 检查是否超过最大组件数量 # 检查是否超过最大组件数量
if len(session.messages) >= config_api.get_plugin_config(self.plugin_config, "features.message_buffer_max_components", 5): if len(session.messages) >= config_api.get_plugin_config(
self.plugin_config, "features.message_buffer_max_components", 5
):
logger.debug(f"会话 {session_id} 消息数量达到上限,强制合并") logger.debug(f"会话 {session_id} 消息数量达到上限,强制合并")
asyncio.create_task(self._force_merge_session(session_id)) asyncio.create_task(self._force_merge_session(session_id))
self.buffer_pool[session_id] = BufferedSession(session_id=session_id, original_event=original_event) self.buffer_pool[session_id] = BufferedSession(session_id=session_id, original_event=original_event)

View File

@@ -14,7 +14,7 @@ def create_router(plugin_config: dict):
platform_name = config_api.get_plugin_config(plugin_config, "maibot_server.platform_name", "qq") platform_name = config_api.get_plugin_config(plugin_config, "maibot_server.platform_name", "qq")
host = config_api.get_plugin_config(plugin_config, "maibot_server.host", "localhost") host = config_api.get_plugin_config(plugin_config, "maibot_server.host", "localhost")
port = config_api.get_plugin_config(plugin_config, "maibot_server.port", 8000) port = config_api.get_plugin_config(plugin_config, "maibot_server.port", 8000)
route_config = RouteConfig( route_config = RouteConfig(
route_config={ route_config={
platform_name: TargetConfig( platform_name: TargetConfig(
@@ -32,7 +32,7 @@ async def mmc_start_com(plugin_config: dict = None):
logger.info("正在连接MaiBot") logger.info("正在连接MaiBot")
if plugin_config: if plugin_config:
create_router(plugin_config) create_router(plugin_config)
if router: if router:
router.register_class_handler(send_handler.handle_message) router.register_class_handler(send_handler.handle_message)
await router.run() await router.run()

View File

@@ -32,7 +32,7 @@ class NoticeType: # 通知事件
group_recall = "group_recall" # 群聊消息撤回 group_recall = "group_recall" # 群聊消息撤回
notify = "notify" notify = "notify"
group_ban = "group_ban" # 群禁言 group_ban = "group_ban" # 群禁言
group_msg_emoji_like = "group_msg_emoji_like" # 群聊表情回复 group_msg_emoji_like = "group_msg_emoji_like" # 群聊表情回复
class Notify: class Notify:
poke = "poke" # 戳一戳 poke = "poke" # 戳一戳

View File

@@ -100,7 +100,7 @@ class MessageHandler:
# 检查群聊黑白名单 # 检查群聊黑白名单
group_list_type = config_api.get_plugin_config(self.plugin_config, "features.group_list_type", "blacklist") group_list_type = config_api.get_plugin_config(self.plugin_config, "features.group_list_type", "blacklist")
group_list = config_api.get_plugin_config(self.plugin_config, "features.group_list", []) group_list = config_api.get_plugin_config(self.plugin_config, "features.group_list", [])
if group_list_type == "whitelist": if group_list_type == "whitelist":
if group_id not in group_list: if group_id not in group_list:
logger.warning("群聊不在白名单中,消息被丢弃") logger.warning("群聊不在白名单中,消息被丢弃")
@@ -111,9 +111,11 @@ class MessageHandler:
return False return False
else: else:
# 检查私聊黑白名单 # 检查私聊黑白名单
private_list_type = config_api.get_plugin_config(self.plugin_config, "features.private_list_type", "blacklist") private_list_type = config_api.get_plugin_config(
self.plugin_config, "features.private_list_type", "blacklist"
)
private_list = config_api.get_plugin_config(self.plugin_config, "features.private_list", []) private_list = config_api.get_plugin_config(self.plugin_config, "features.private_list", [])
if private_list_type == "whitelist": if private_list_type == "whitelist":
if user_id not in private_list: if user_id not in private_list:
logger.warning("私聊不在白名单中,消息被丢弃") logger.warning("私聊不在白名单中,消息被丢弃")
@@ -156,21 +158,23 @@ class MessageHandler:
Parameters: Parameters:
raw_message: dict: 原始消息 raw_message: dict: 原始消息
""" """
# 添加原始消息调试日志特别关注message字段 # 添加原始消息调试日志特别关注message字段
logger.debug(f"收到原始消息: message_type={raw_message.get('message_type')}, message_id={raw_message.get('message_id')}") logger.debug(
f"收到原始消息: message_type={raw_message.get('message_type')}, message_id={raw_message.get('message_id')}"
)
logger.debug(f"原始消息内容: {raw_message.get('message', [])}") logger.debug(f"原始消息内容: {raw_message.get('message', [])}")
# 检查是否包含@或video消息段 # 检查是否包含@或video消息段
message_segments = raw_message.get('message', []) message_segments = raw_message.get("message", [])
if message_segments: if message_segments:
for i, seg in enumerate(message_segments): for i, seg in enumerate(message_segments):
seg_type = seg.get('type') seg_type = seg.get("type")
if seg_type in ['at', 'video']: if seg_type in ["at", "video"]:
logger.info(f"检测到 {seg_type.upper()} 消息段 [{i}]: {seg}") logger.info(f"检测到 {seg_type.upper()} 消息段 [{i}]: {seg}")
elif seg_type not in ['text', 'face', 'image']: elif seg_type not in ["text", "face", "image"]:
logger.warning(f"检测到特殊消息段 [{i}]: type={seg_type}, data={seg.get('data', {})}") logger.warning(f"检测到特殊消息段 [{i}]: type={seg_type}, data={seg.get('data', {})}")
message_type: str = raw_message.get("message_type") message_type: str = raw_message.get("message_type")
message_id: int = raw_message.get("message_id") message_id: int = raw_message.get("message_id")
# message_time: int = raw_message.get("time") # message_time: int = raw_message.get("time")
@@ -308,9 +312,13 @@ class MessageHandler:
message_type = raw_message.get("message_type") message_type = raw_message.get("message_type")
should_use_buffer = False should_use_buffer = False
if message_type == "group" and config_api.get_plugin_config(self.plugin_config, "features.message_buffer_enable_group", True): if message_type == "group" and config_api.get_plugin_config(
self.plugin_config, "features.message_buffer_enable_group", True
):
should_use_buffer = True should_use_buffer = True
elif message_type == "private" and config_api.get_plugin_config(self.plugin_config, "features.message_buffer_enable_private", True): elif message_type == "private" and config_api.get_plugin_config(
self.plugin_config, "features.message_buffer_enable_private", True
):
should_use_buffer = True should_use_buffer = True
if should_use_buffer: if should_use_buffer:
@@ -368,10 +376,10 @@ class MessageHandler:
for sub_message in real_message: for sub_message in real_message:
sub_message: dict sub_message: dict
sub_message_type = sub_message.get("type") sub_message_type = sub_message.get("type")
# 添加详细的消息类型调试信息 # 添加详细的消息类型调试信息
logger.debug(f"处理消息段: type={sub_message_type}, data={sub_message.get('data', {})}") logger.debug(f"处理消息段: type={sub_message_type}, data={sub_message.get('data', {})}")
# 特别关注 at 和 video 消息的识别 # 特别关注 at 和 video 消息的识别
if sub_message_type == "at": if sub_message_type == "at":
logger.debug(f"检测到@消息: {sub_message}") logger.debug(f"检测到@消息: {sub_message}")
@@ -379,7 +387,7 @@ class MessageHandler:
logger.debug(f"检测到VIDEO消息: {sub_message}") logger.debug(f"检测到VIDEO消息: {sub_message}")
elif sub_message_type not in ["text", "face", "image", "record"]: elif sub_message_type not in ["text", "face", "image", "record"]:
logger.warning(f"检测到特殊消息类型: {sub_message_type}, 完整消息: {sub_message}") logger.warning(f"检测到特殊消息类型: {sub_message_type}, 完整消息: {sub_message}")
match sub_message_type: match sub_message_type:
case RealMessageType.text: case RealMessageType.text:
ret_seg = await self.handle_text_message(sub_message) ret_seg = await self.handle_text_message(sub_message)

View File

@@ -33,6 +33,7 @@ class MessageSending:
try: try:
# 重新导入router # 重新导入router
from ..mmc_com_layer import router from ..mmc_com_layer import router
self.maibot_router = router self.maibot_router = router
if self.maibot_router is not None: if self.maibot_router is not None:
logger.info("MaiBot router重连成功") logger.info("MaiBot router重连成功")
@@ -73,14 +74,14 @@ class MessageSending:
# 获取对应的客户端并发送切片 # 获取对应的客户端并发送切片
platform = message_base.message_info.platform platform = message_base.message_info.platform
# 再次检查router状态防止运行时被重置 # 再次检查router状态防止运行时被重置
if self.maibot_router is None or not hasattr(self.maibot_router, 'clients'): if self.maibot_router is None or not hasattr(self.maibot_router, "clients"):
logger.warning("MaiBot router连接已断开尝试重新连接") logger.warning("MaiBot router连接已断开尝试重新连接")
if not await self._attempt_reconnect(): if not await self._attempt_reconnect():
logger.error("MaiBot router重连失败切片发送中止") logger.error("MaiBot router重连失败切片发送中止")
return False return False
if platform not in self.maibot_router.clients: if platform not in self.maibot_router.clients:
logger.error(f"平台 {platform} 未连接") logger.error(f"平台 {platform} 未连接")
return False return False

View File

@@ -22,7 +22,9 @@ class MetaEventHandler:
"""设置插件配置""" """设置插件配置"""
self.plugin_config = plugin_config self.plugin_config = plugin_config
# 更新interval值 # 更新interval值
self.interval = config_api.get_plugin_config(self.plugin_config, "napcat_server.heartbeat_interval", 5000) / 1000 self.interval = (
config_api.get_plugin_config(self.plugin_config, "napcat_server.heartbeat_interval", 5000) / 1000
)
async def handle_meta_event(self, message: dict) -> None: async def handle_meta_event(self, message: dict) -> None:
event_type = message.get("meta_event_type") event_type = message.get("meta_event_type")

View File

@@ -116,9 +116,9 @@ class NoticeHandler:
sub_type = raw_message.get("sub_type") sub_type = raw_message.get("sub_type")
match sub_type: match sub_type:
case NoticeType.Notify.poke: case NoticeType.Notify.poke:
if config_api.get_plugin_config(self.plugin_config, "features.enable_poke", True) and await message_handler.check_allow_to_chat( if config_api.get_plugin_config(
user_id, group_id, False, False self.plugin_config, "features.enable_poke", True
): ) and await message_handler.check_allow_to_chat(user_id, group_id, False, False):
logger.debug("处理戳一戳消息") logger.debug("处理戳一戳消息")
handled_message, user_info = await self.handle_poke_notify(raw_message, group_id, user_id) handled_message, user_info = await self.handle_poke_notify(raw_message, group_id, user_id)
else: else:
@@ -127,14 +127,18 @@ class NoticeHandler:
from src.plugin_system.core.event_manager import event_manager from src.plugin_system.core.event_manager import event_manager
from ...event_types import NapcatEvent from ...event_types import NapcatEvent
await event_manager.trigger_event(NapcatEvent.ON_RECEIVED.FRIEND_INPUT, permission_group=PLUGIN_NAME) await event_manager.trigger_event(
NapcatEvent.ON_RECEIVED.FRIEND_INPUT, permission_group=PLUGIN_NAME
)
case _: case _:
logger.warning(f"不支持的notify类型: {notice_type}.{sub_type}") logger.warning(f"不支持的notify类型: {notice_type}.{sub_type}")
case NoticeType.group_msg_emoji_like: case NoticeType.group_msg_emoji_like:
# 该事件转移到 handle_group_emoji_like_notify函数内触发 # 该事件转移到 handle_group_emoji_like_notify函数内触发
if config_api.get_plugin_config(self.plugin_config, "features.enable_emoji_like", True): if config_api.get_plugin_config(self.plugin_config, "features.enable_emoji_like", True):
logger.debug("处理群聊表情回复") logger.debug("处理群聊表情回复")
handled_message, user_info = await self.handle_group_emoji_like_notify(raw_message,group_id,user_id) handled_message, user_info = await self.handle_group_emoji_like_notify(
raw_message, group_id, user_id
)
else: else:
logger.warning("群聊表情回复被禁用,取消群聊表情回复处理") logger.warning("群聊表情回复被禁用,取消群聊表情回复处理")
case NoticeType.group_ban: case NoticeType.group_ban:
@@ -294,7 +298,7 @@ class NoticeHandler:
async def handle_group_emoji_like_notify(self, raw_message: dict, group_id: int, user_id: int): async def handle_group_emoji_like_notify(self, raw_message: dict, group_id: int, user_id: int):
if not group_id: if not group_id:
logger.error("群ID不能为空无法处理群聊表情回复通知") logger.error("群ID不能为空无法处理群聊表情回复通知")
return None, None return None, None
user_qq_info: dict = await get_member_info(self.get_server_connection(), group_id, user_id) user_qq_info: dict = await get_member_info(self.get_server_connection(), group_id, user_id)
if user_qq_info: if user_qq_info:
@@ -304,37 +308,42 @@ class NoticeHandler:
user_name = "QQ用户" user_name = "QQ用户"
user_cardname = "QQ用户" user_cardname = "QQ用户"
logger.debug("无法获取表情回复对方的用户昵称") logger.debug("无法获取表情回复对方的用户昵称")
from src.plugin_system.core.event_manager import event_manager from src.plugin_system.core.event_manager import event_manager
from ...event_types import NapcatEvent from ...event_types import NapcatEvent
target_message = await event_manager.trigger_event(NapcatEvent.MESSAGE.GET_MSG,message_id=raw_message.get("message_id","")) target_message = await event_manager.trigger_event(
target_message_text = target_message.get_message_result().get("data",{}).get("raw_message","") NapcatEvent.MESSAGE.GET_MSG, message_id=raw_message.get("message_id", "")
)
target_message_text = target_message.get_message_result().get("data", {}).get("raw_message", "")
if not target_message: if not target_message:
logger.error("未找到对应消息") logger.error("未找到对应消息")
return None, None return None, None
if len(target_message_text) > 15: if len(target_message_text) > 15:
target_message_text = target_message_text[:15] + "..." target_message_text = target_message_text[:15] + "..."
user_info: UserInfo = UserInfo( user_info: UserInfo = UserInfo(
platform=config_api.get_plugin_config(self.plugin_config, "maibot_server.platform_name", "qq"), platform=config_api.get_plugin_config(self.plugin_config, "maibot_server.platform_name", "qq"),
user_id=user_id, user_id=user_id,
user_nickname=user_name, user_nickname=user_name,
user_cardname=user_cardname, user_cardname=user_cardname,
) )
like_emoji_id = raw_message.get("likes")[0].get("emoji_id") like_emoji_id = raw_message.get("likes")[0].get("emoji_id")
await event_manager.trigger_event( await event_manager.trigger_event(
NapcatEvent.ON_RECEIVED.EMOJI_LIEK, NapcatEvent.ON_RECEIVED.EMOJI_LIEK,
permission_group=PLUGIN_NAME, permission_group=PLUGIN_NAME,
group_id=group_id, group_id=group_id,
user_id=user_id, user_id=user_id,
message_id=raw_message.get("message_id",""), message_id=raw_message.get("message_id", ""),
emoji_id=like_emoji_id emoji_id=like_emoji_id,
) )
seg_data = Seg(type="text",data=f"{user_name}使用Emoji表情{QQ_FACE.get(like_emoji_id,"")}回复了你的消息[{target_message_text}]") seg_data = Seg(
type="text",
data=f"{user_name}使用Emoji表情{QQ_FACE.get(like_emoji_id, '')}回复了你的消息[{target_message_text}]",
)
return seg_data, user_info return seg_data, user_info
async def handle_ban_notify(self, raw_message: dict, group_id: int) -> Tuple[Seg, UserInfo] | Tuple[None, None]: async def handle_ban_notify(self, raw_message: dict, group_id: int) -> Tuple[Seg, UserInfo] | Tuple[None, None]:
if not group_id: if not group_id:
logger.error("群ID不能为空无法处理禁言通知") logger.error("群ID不能为空无法处理禁言通知")

View File

@@ -45,12 +45,12 @@ async def check_timeout_response() -> None:
while True: while True:
cleaned_message_count: int = 0 cleaned_message_count: int = 0
now_time = time.time() now_time = time.time()
# 获取心跳间隔配置 # 获取心跳间隔配置
heartbeat_interval = 30 # 默认值 heartbeat_interval = 30 # 默认值
if plugin_config: if plugin_config:
heartbeat_interval = config_api.get_plugin_config(plugin_config, "napcat_server.heartbeat_interval", 30) heartbeat_interval = config_api.get_plugin_config(plugin_config, "napcat_server.heartbeat_interval", 30)
for echo_id, response_time in list(response_time_dict.items()): for echo_id, response_time in list(response_time_dict.items()):
if now_time - response_time > heartbeat_interval: if now_time - response_time > heartbeat_interval:
cleaned_message_count += 1 cleaned_message_count += 1

View File

@@ -297,9 +297,9 @@ class SendHandler:
try: try:
# 检查是否为缓冲消息ID格式buffered-{original_id}-{timestamp} # 检查是否为缓冲消息ID格式buffered-{original_id}-{timestamp}
if id.startswith('buffered-'): if id.startswith("buffered-"):
# 从缓冲消息ID中提取原始消息ID # 从缓冲消息ID中提取原始消息ID
original_id = id.split('-')[1] original_id = id.split("-")[1]
msg_info_response = await self.send_message_to_napcat("get_msg", {"message_id": int(original_id)}) msg_info_response = await self.send_message_to_napcat("get_msg", {"message_id": int(original_id)})
else: else:
msg_info_response = await self.send_message_to_napcat("get_msg", {"message_id": int(id)}) msg_info_response = await self.send_message_to_napcat("get_msg", {"message_id": int(id)})
@@ -363,7 +363,7 @@ class SendHandler:
use_tts = False use_tts = False
if self.plugin_config: if self.plugin_config:
use_tts = config_api.get_plugin_config(self.plugin_config, "voice.use_tts", False) use_tts = config_api.get_plugin_config(self.plugin_config, "voice.use_tts", False)
if not use_tts: if not use_tts:
logger.warning("未启用语音消息处理") logger.warning("未启用语音消息处理")
return {} return {}

View File

@@ -18,7 +18,9 @@ class WebSocketManager:
self.max_reconnect_attempts = 10 # 最大重连次数 self.max_reconnect_attempts = 10 # 最大重连次数
self.plugin_config = None self.plugin_config = None
async def start_connection(self, message_handler: Callable[[Server.ServerConnection], Any], plugin_config: dict) -> None: async def start_connection(
self, message_handler: Callable[[Server.ServerConnection], Any], plugin_config: dict
) -> None:
"""根据配置启动 WebSocket 连接""" """根据配置启动 WebSocket 连接"""
self.plugin_config = plugin_config self.plugin_config = plugin_config
mode = config_api.get_plugin_config(plugin_config, "napcat_server.mode") mode = config_api.get_plugin_config(plugin_config, "napcat_server.mode")
@@ -72,9 +74,7 @@ class WebSocketManager:
# 如果配置了访问令牌,添加到请求头 # 如果配置了访问令牌,添加到请求头
access_token = config_api.get_plugin_config(self.plugin_config, "napcat_server.access_token") access_token = config_api.get_plugin_config(self.plugin_config, "napcat_server.access_token")
if access_token: if access_token:
connect_kwargs["additional_headers"] = { connect_kwargs["additional_headers"] = {"Authorization": f"Bearer {access_token}"}
"Authorization": f"Bearer {access_token}"
}
logger.info("已添加访问令牌到连接请求头") logger.info("已添加访问令牌到连接请求头")
async with Server.connect(url, **connect_kwargs) as websocket: async with Server.connect(url, **connect_kwargs) as websocket:

View File

@@ -1,6 +1,7 @@
""" """
Base search engine interface Base search engine interface
""" """
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Dict, List, Any from typing import Dict, List, Any
@@ -9,20 +10,20 @@ class BaseSearchEngine(ABC):
""" """
搜索引擎基类 搜索引擎基类
""" """
@abstractmethod @abstractmethod
async def search(self, args: Dict[str, Any]) -> List[Dict[str, Any]]: async def search(self, args: Dict[str, Any]) -> List[Dict[str, Any]]:
""" """
执行搜索 执行搜索
Args: Args:
args: 搜索参数,包含 query、num_results、time_range 等 args: 搜索参数,包含 query、num_results、time_range 等
Returns: Returns:
搜索结果列表,每个结果包含 title、url、snippet、provider 字段 搜索结果列表,每个结果包含 title、url、snippet、provider 字段
""" """
pass pass
@abstractmethod @abstractmethod
def is_available(self) -> bool: def is_available(self) -> bool:
""" """

View File

@@ -1,6 +1,7 @@
""" """
Bing search engine implementation Bing search engine implementation
""" """
import asyncio import asyncio
import functools import functools
import random import random
@@ -58,21 +59,21 @@ class BingSearchEngine(BaseSearchEngine):
""" """
Bing搜索引擎实现 Bing搜索引擎实现
""" """
def __init__(self): def __init__(self):
self.session = requests.Session() self.session = requests.Session()
self.session.headers = HEADERS self.session.headers = HEADERS
def is_available(self) -> bool: def is_available(self) -> bool:
"""检查Bing搜索引擎是否可用""" """检查Bing搜索引擎是否可用"""
return True # Bing是免费搜索引擎总是可用 return True # Bing是免费搜索引擎总是可用
async def search(self, args: Dict[str, Any]) -> List[Dict[str, Any]]: async def search(self, args: Dict[str, Any]) -> List[Dict[str, Any]]:
"""执行Bing搜索""" """执行Bing搜索"""
query = args["query"] query = args["query"]
num_results = args.get("num_results", 3) num_results = args.get("num_results", 3)
time_range = args.get("time_range", "any") time_range = args.get("time_range", "any")
try: try:
loop = asyncio.get_running_loop() loop = asyncio.get_running_loop()
func = functools.partial(self._search_sync, query, num_results, time_range) func = functools.partial(self._search_sync, query, num_results, time_range)
@@ -81,17 +82,17 @@ class BingSearchEngine(BaseSearchEngine):
except Exception as e: except Exception as e:
logger.error(f"Bing 搜索失败: {e}") logger.error(f"Bing 搜索失败: {e}")
return [] return []
def _search_sync(self, keyword: str, num_results: int, time_range: str) -> List[Dict[str, Any]]: def _search_sync(self, keyword: str, num_results: int, time_range: str) -> List[Dict[str, Any]]:
"""同步执行Bing搜索""" """同步执行Bing搜索"""
if not keyword: if not keyword:
return [] return []
list_result = [] list_result = []
# 构建搜索URL # 构建搜索URL
search_url = bing_search_url + keyword search_url = bing_search_url + keyword
# 如果指定了时间范围,添加时间过滤参数 # 如果指定了时间范围,添加时间过滤参数
if time_range == "week": if time_range == "week":
search_url += "&qft=+filterui:date-range-7" search_url += "&qft=+filterui:date-range-7"
@@ -181,34 +182,29 @@ class BingSearchEngine(BaseSearchEngine):
# 尝试提取搜索结果 # 尝试提取搜索结果
# 方法1: 查找标准的搜索结果容器 # 方法1: 查找标准的搜索结果容器
results = root.select("ol#b_results li.b_algo") results = root.select("ol#b_results li.b_algo")
if results: if results:
for _rank, result in enumerate(results, 1): for _rank, result in enumerate(results, 1):
# 提取标题和链接 # 提取标题和链接
title_link = result.select_one("h2 a") title_link = result.select_one("h2 a")
if not title_link: if not title_link:
continue continue
title = title_link.get_text().strip() title = title_link.get_text().strip()
url = title_link.get("href", "") url = title_link.get("href", "")
# 提取摘要 # 提取摘要
abstract = "" abstract = ""
abstract_elem = result.select_one("div.b_caption p") abstract_elem = result.select_one("div.b_caption p")
if abstract_elem: if abstract_elem:
abstract = abstract_elem.get_text().strip() abstract = abstract_elem.get_text().strip()
# 限制摘要长度 # 限制摘要长度
if ABSTRACT_MAX_LENGTH and len(abstract) > ABSTRACT_MAX_LENGTH: if ABSTRACT_MAX_LENGTH and len(abstract) > ABSTRACT_MAX_LENGTH:
abstract = abstract[:ABSTRACT_MAX_LENGTH] + "..." abstract = abstract[:ABSTRACT_MAX_LENGTH] + "..."
list_data.append({ list_data.append({"title": title, "url": url, "snippet": abstract, "provider": "Bing"})
"title": title,
"url": url,
"snippet": abstract,
"provider": "Bing"
})
if len(list_data) >= 10: # 限制结果数量 if len(list_data) >= 10: # 限制结果数量
break break
@@ -216,22 +212,34 @@ class BingSearchEngine(BaseSearchEngine):
if not list_data: if not list_data:
# 查找所有可能的搜索结果链接 # 查找所有可能的搜索结果链接
all_links = root.find_all("a") all_links = root.find_all("a")
for link in all_links: for link in all_links:
href = link.get("href", "") href = link.get("href", "")
text = link.get_text().strip() text = link.get_text().strip()
# 过滤有效的搜索结果链接 # 过滤有效的搜索结果链接
if (href and text and len(text) > 10 if (
href
and text
and len(text) > 10
and not href.startswith("javascript:") and not href.startswith("javascript:")
and not href.startswith("#") and not href.startswith("#")
and "http" in href and "http" in href
and not any(x in href for x in [ and not any(
"bing.com/search", "bing.com/images", "bing.com/videos", x in href
"bing.com/maps", "bing.com/news", "login", "account", for x in [
"microsoft", "javascript" "bing.com/search",
])): "bing.com/images",
"bing.com/videos",
"bing.com/maps",
"bing.com/news",
"login",
"account",
"microsoft",
"javascript",
]
)
):
# 尝试获取摘要 # 尝试获取摘要
abstract = "" abstract = ""
parent = link.parent parent = link.parent
@@ -239,18 +247,13 @@ class BingSearchEngine(BaseSearchEngine):
full_text = parent.get_text().strip() full_text = parent.get_text().strip()
if len(full_text) > len(text): if len(full_text) > len(text):
abstract = full_text.replace(text, "", 1).strip() abstract = full_text.replace(text, "", 1).strip()
# 限制摘要长度 # 限制摘要长度
if ABSTRACT_MAX_LENGTH and len(abstract) > ABSTRACT_MAX_LENGTH: if ABSTRACT_MAX_LENGTH and len(abstract) > ABSTRACT_MAX_LENGTH:
abstract = abstract[:ABSTRACT_MAX_LENGTH] + "..." abstract = abstract[:ABSTRACT_MAX_LENGTH] + "..."
list_data.append({ list_data.append({"title": text, "url": href, "snippet": abstract, "provider": "Bing"})
"title": text,
"url": href,
"snippet": abstract,
"provider": "Bing"
})
if len(list_data) >= 10: if len(list_data) >= 10:
break break

View File

@@ -1,6 +1,7 @@
""" """
DuckDuckGo search engine implementation DuckDuckGo search engine implementation
""" """
from typing import Dict, List, Any from typing import Dict, List, Any
from asyncddgs import aDDGS from asyncddgs import aDDGS
@@ -14,27 +15,22 @@ class DDGSearchEngine(BaseSearchEngine):
""" """
DuckDuckGo搜索引擎实现 DuckDuckGo搜索引擎实现
""" """
def is_available(self) -> bool: def is_available(self) -> bool:
"""检查DuckDuckGo搜索引擎是否可用""" """检查DuckDuckGo搜索引擎是否可用"""
return True # DuckDuckGo不需要API密钥总是可用 return True # DuckDuckGo不需要API密钥总是可用
async def search(self, args: Dict[str, Any]) -> List[Dict[str, Any]]: async def search(self, args: Dict[str, Any]) -> List[Dict[str, Any]]:
"""执行DuckDuckGo搜索""" """执行DuckDuckGo搜索"""
query = args["query"] query = args["query"]
num_results = args.get("num_results", 3) num_results = args.get("num_results", 3)
try: try:
async with aDDGS() as ddgs: async with aDDGS() as ddgs:
search_response = await ddgs.text(query, max_results=num_results) search_response = await ddgs.text(query, max_results=num_results)
return [ return [
{ {"title": r.get("title"), "url": r.get("href"), "snippet": r.get("body"), "provider": "DuckDuckGo"}
"title": r.get("title"),
"url": r.get("href"),
"snippet": r.get("body"),
"provider": "DuckDuckGo"
}
for r in search_response for r in search_response
] ]
except Exception as e: except Exception as e:

View File

@@ -1,6 +1,7 @@
""" """
Exa search engine implementation Exa search engine implementation
""" """
import asyncio import asyncio
import functools import functools
from datetime import datetime, timedelta from datetime import datetime, timedelta
@@ -19,31 +20,27 @@ class ExaSearchEngine(BaseSearchEngine):
""" """
Exa搜索引擎实现 Exa搜索引擎实现
""" """
def __init__(self): def __init__(self):
self._initialize_clients() self._initialize_clients()
def _initialize_clients(self): def _initialize_clients(self):
"""初始化Exa客户端""" """初始化Exa客户端"""
# 从主配置文件读取API密钥 # 从主配置文件读取API密钥
exa_api_keys = config_api.get_global_config("web_search.exa_api_keys", None) exa_api_keys = config_api.get_global_config("web_search.exa_api_keys", None)
# 创建API密钥管理器 # 创建API密钥管理器
self.api_manager = create_api_key_manager_from_config( self.api_manager = create_api_key_manager_from_config(exa_api_keys, lambda key: Exa(api_key=key), "Exa")
exa_api_keys,
lambda key: Exa(api_key=key),
"Exa"
)
def is_available(self) -> bool: def is_available(self) -> bool:
"""检查Exa搜索引擎是否可用""" """检查Exa搜索引擎是否可用"""
return self.api_manager.is_available() return self.api_manager.is_available()
async def search(self, args: Dict[str, Any]) -> List[Dict[str, Any]]: async def search(self, args: Dict[str, Any]) -> List[Dict[str, Any]]:
"""执行Exa搜索""" """执行Exa搜索"""
if not self.is_available(): if not self.is_available():
return [] return []
query = args["query"] query = args["query"]
num_results = args.get("num_results", 3) num_results = args.get("num_results", 3)
time_range = args.get("time_range", "any") time_range = args.get("time_range", "any")
@@ -52,7 +49,7 @@ class ExaSearchEngine(BaseSearchEngine):
if time_range != "any": if time_range != "any":
today = datetime.now() today = datetime.now()
start_date = today - timedelta(days=7 if time_range == "week" else 30) start_date = today - timedelta(days=7 if time_range == "week" else 30)
exa_args["start_published_date"] = start_date.strftime('%Y-%m-%d') exa_args["start_published_date"] = start_date.strftime("%Y-%m-%d")
try: try:
# 使用API密钥管理器获取下一个客户端 # 使用API密钥管理器获取下一个客户端
@@ -60,17 +57,17 @@ class ExaSearchEngine(BaseSearchEngine):
if not exa_client: if not exa_client:
logger.error("无法获取Exa客户端") logger.error("无法获取Exa客户端")
return [] return []
loop = asyncio.get_running_loop() loop = asyncio.get_running_loop()
func = functools.partial(exa_client.search_and_contents, query, **exa_args) func = functools.partial(exa_client.search_and_contents, query, **exa_args)
search_response = await loop.run_in_executor(None, func) search_response = await loop.run_in_executor(None, func)
return [ return [
{ {
"title": res.title, "title": res.title,
"url": res.url, "url": res.url,
"snippet": " ".join(getattr(res, 'highlights', [])) or (getattr(res, 'text', '')[:250] + '...'), "snippet": " ".join(getattr(res, "highlights", [])) or (getattr(res, "text", "")[:250] + "..."),
"provider": "Exa" "provider": "Exa",
} }
for res in search_response.results for res in search_response.results
] ]

View File

@@ -1,6 +1,7 @@
""" """
Tavily search engine implementation Tavily search engine implementation
""" """
import asyncio import asyncio
import functools import functools
from typing import Dict, List, Any from typing import Dict, List, Any
@@ -18,31 +19,29 @@ class TavilySearchEngine(BaseSearchEngine):
""" """
Tavily搜索引擎实现 Tavily搜索引擎实现
""" """
def __init__(self): def __init__(self):
self._initialize_clients() self._initialize_clients()
def _initialize_clients(self): def _initialize_clients(self):
"""初始化Tavily客户端""" """初始化Tavily客户端"""
# 从主配置文件读取API密钥 # 从主配置文件读取API密钥
tavily_api_keys = config_api.get_global_config("web_search.tavily_api_keys", None) tavily_api_keys = config_api.get_global_config("web_search.tavily_api_keys", None)
# 创建API密钥管理器 # 创建API密钥管理器
self.api_manager = create_api_key_manager_from_config( self.api_manager = create_api_key_manager_from_config(
tavily_api_keys, tavily_api_keys, lambda key: TavilyClient(api_key=key), "Tavily"
lambda key: TavilyClient(api_key=key),
"Tavily"
) )
def is_available(self) -> bool: def is_available(self) -> bool:
"""检查Tavily搜索引擎是否可用""" """检查Tavily搜索引擎是否可用"""
return self.api_manager.is_available() return self.api_manager.is_available()
async def search(self, args: Dict[str, Any]) -> List[Dict[str, Any]]: async def search(self, args: Dict[str, Any]) -> List[Dict[str, Any]]:
"""执行Tavily搜索""" """执行Tavily搜索"""
if not self.is_available(): if not self.is_available():
return [] return []
query = args["query"] query = args["query"]
num_results = args.get("num_results", 3) num_results = args.get("num_results", 3)
time_range = args.get("time_range", "any") time_range = args.get("time_range", "any")
@@ -53,38 +52,40 @@ class TavilySearchEngine(BaseSearchEngine):
if not tavily_client: if not tavily_client:
logger.error("无法获取Tavily客户端") logger.error("无法获取Tavily客户端")
return [] return []
# 构建Tavily搜索参数 # 构建Tavily搜索参数
search_params = { search_params = {
"query": query, "query": query,
"max_results": num_results, "max_results": num_results,
"search_depth": "basic", "search_depth": "basic",
"include_answer": False, "include_answer": False,
"include_raw_content": False "include_raw_content": False,
} }
# 根据时间范围调整搜索参数 # 根据时间范围调整搜索参数
if time_range == "week": if time_range == "week":
search_params["days"] = 7 search_params["days"] = 7
elif time_range == "month": elif time_range == "month":
search_params["days"] = 30 search_params["days"] = 30
loop = asyncio.get_running_loop() loop = asyncio.get_running_loop()
func = functools.partial(tavily_client.search, **search_params) func = functools.partial(tavily_client.search, **search_params)
search_response = await loop.run_in_executor(None, func) search_response = await loop.run_in_executor(None, func)
results = [] results = []
if search_response and "results" in search_response: if search_response and "results" in search_response:
for res in search_response["results"]: for res in search_response["results"]:
results.append({ results.append(
"title": res.get("title", "无标题"), {
"url": res.get("url", ""), "title": res.get("title", "无标题"),
"snippet": res.get("content", "")[:300] + "..." if res.get("content") else "无摘要", "url": res.get("url", ""),
"provider": "Tavily" "snippet": res.get("content", "")[:300] + "..." if res.get("content") else "无摘要",
}) "provider": "Tavily",
}
)
return results return results
except Exception as e: except Exception as e:
logger.error(f"Tavily 搜索失败: {e}") logger.error(f"Tavily 搜索失败: {e}")
return [] return []

View File

@@ -3,15 +3,10 @@ Web Search Tool Plugin
一个功能强大的网络搜索和URL解析插件支持多种搜索引擎和解析策略。 一个功能强大的网络搜索和URL解析插件支持多种搜索引擎和解析策略。
""" """
from typing import List, Tuple, Type from typing import List, Tuple, Type
from src.plugin_system import ( from src.plugin_system import BasePlugin, register_plugin, ComponentInfo, ConfigField, PythonDependency
BasePlugin,
register_plugin,
ComponentInfo,
ConfigField,
PythonDependency
)
from src.plugin_system.apis import config_api from src.plugin_system.apis import config_api
from src.common.logger import get_logger from src.common.logger import get_logger
@@ -25,7 +20,7 @@ logger = get_logger("web_search_plugin")
class WEBSEARCHPLUGIN(BasePlugin): class WEBSEARCHPLUGIN(BasePlugin):
""" """
网络搜索工具插件 网络搜索工具插件
提供网络搜索和URL解析功能支持多种搜索引擎 提供网络搜索和URL解析功能支持多种搜索引擎
- Exa (需要API密钥) - Exa (需要API密钥)
- Tavily (需要API密钥) - Tavily (需要API密钥)
@@ -37,11 +32,11 @@ class WEBSEARCHPLUGIN(BasePlugin):
plugin_name: str = "web_search_tool" # 内部标识符 plugin_name: str = "web_search_tool" # 内部标识符
enable_plugin: bool = True enable_plugin: bool = True
dependencies: List[str] = [] # 插件依赖列表 dependencies: List[str] = [] # 插件依赖列表
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
"""初始化插件,立即加载所有搜索引擎""" """初始化插件,立即加载所有搜索引擎"""
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
# 立即初始化所有搜索引擎触发API密钥管理器的日志输出 # 立即初始化所有搜索引擎触发API密钥管理器的日志输出
logger.info("🚀 正在初始化所有搜索引擎...") logger.info("🚀 正在初始化所有搜索引擎...")
try: try:
@@ -49,65 +44,58 @@ class WEBSEARCHPLUGIN(BasePlugin):
from .engines.tavily_engine import TavilySearchEngine from .engines.tavily_engine import TavilySearchEngine
from .engines.ddg_engine import DDGSearchEngine from .engines.ddg_engine import DDGSearchEngine
from .engines.bing_engine import BingSearchEngine from .engines.bing_engine import BingSearchEngine
# 实例化所有搜索引擎这会触发API密钥管理器的初始化 # 实例化所有搜索引擎这会触发API密钥管理器的初始化
exa_engine = ExaSearchEngine() exa_engine = ExaSearchEngine()
tavily_engine = TavilySearchEngine() tavily_engine = TavilySearchEngine()
ddg_engine = DDGSearchEngine() ddg_engine = DDGSearchEngine()
bing_engine = BingSearchEngine() bing_engine = BingSearchEngine()
# 报告每个引擎的状态 # 报告每个引擎的状态
engines_status = { engines_status = {
"Exa": exa_engine.is_available(), "Exa": exa_engine.is_available(),
"Tavily": tavily_engine.is_available(), "Tavily": tavily_engine.is_available(),
"DuckDuckGo": ddg_engine.is_available(), "DuckDuckGo": ddg_engine.is_available(),
"Bing": bing_engine.is_available() "Bing": bing_engine.is_available(),
} }
available_engines = [name for name, available in engines_status.items() if available] available_engines = [name for name, available in engines_status.items() if available]
unavailable_engines = [name for name, available in engines_status.items() if not available] unavailable_engines = [name for name, available in engines_status.items() if not available]
if available_engines: if available_engines:
logger.info(f"✅ 可用搜索引擎: {', '.join(available_engines)}") logger.info(f"✅ 可用搜索引擎: {', '.join(available_engines)}")
if unavailable_engines: if unavailable_engines:
logger.info(f"❌ 不可用搜索引擎: {', '.join(unavailable_engines)}") logger.info(f"❌ 不可用搜索引擎: {', '.join(unavailable_engines)}")
except Exception as e: except Exception as e:
logger.error(f"❌ 搜索引擎初始化失败: {e}", exc_info=True) logger.error(f"❌ 搜索引擎初始化失败: {e}", exc_info=True)
# Python包依赖列表 # Python包依赖列表
python_dependencies: List[PythonDependency] = [ python_dependencies: List[PythonDependency] = [
PythonDependency( PythonDependency(package_name="asyncddgs", description="异步DuckDuckGo搜索库", optional=False),
package_name="asyncddgs",
description="异步DuckDuckGo搜索库",
optional=False
),
PythonDependency( PythonDependency(
package_name="exa_py", package_name="exa_py",
description="Exa搜索API客户端库", description="Exa搜索API客户端库",
optional=True # 如果没有API密钥这个是可选的 optional=True, # 如果没有API密钥这个是可选的
), ),
PythonDependency( PythonDependency(
package_name="tavily", package_name="tavily",
install_name="tavily-python", # 安装时使用这个名称 install_name="tavily-python", # 安装时使用这个名称
description="Tavily搜索API客户端库", description="Tavily搜索API客户端库",
optional=True # 如果没有API密钥这个是可选的 optional=True, # 如果没有API密钥这个是可选的
), ),
PythonDependency( PythonDependency(
package_name="httpx", package_name="httpx",
version=">=0.20.0", version=">=0.20.0",
install_name="httpx[socks]", # 安装时使用这个名称(包含可选依赖) install_name="httpx[socks]", # 安装时使用这个名称(包含可选依赖)
description="支持SOCKS代理的HTTP客户端库", description="支持SOCKS代理的HTTP客户端库",
optional=False optional=False,
) ),
] ]
config_file_name: str = "config.toml" # 配置文件名 config_file_name: str = "config.toml" # 配置文件名
# 配置节描述 # 配置节描述
config_section_descriptions = { config_section_descriptions = {"plugin": "插件基本信息", "proxy": "链接本地解析代理配置"}
"plugin": "插件基本信息",
"proxy": "链接本地解析代理配置"
}
# 配置Schema定义 # 配置Schema定义
# 注意EXA配置和组件设置已迁移到主配置文件(bot_config.toml)的[exa]和[web_search]部分 # 注意EXA配置和组件设置已迁移到主配置文件(bot_config.toml)的[exa]和[web_search]部分
@@ -119,42 +107,32 @@ class WEBSEARCHPLUGIN(BasePlugin):
}, },
"proxy": { "proxy": {
"http_proxy": ConfigField( "http_proxy": ConfigField(
type=str, type=str, default=None, description="HTTP代理地址格式如: http://proxy.example.com:8080"
default=None,
description="HTTP代理地址格式如: http://proxy.example.com:8080"
), ),
"https_proxy": ConfigField( "https_proxy": ConfigField(
type=str, type=str, default=None, description="HTTPS代理地址格式如: http://proxy.example.com:8080"
default=None,
description="HTTPS代理地址格式如: http://proxy.example.com:8080"
), ),
"socks5_proxy": ConfigField( "socks5_proxy": ConfigField(
type=str, type=str, default=None, description="SOCKS5代理地址格式如: socks5://proxy.example.com:1080"
default=None,
description="SOCKS5代理地址格式如: socks5://proxy.example.com:1080"
), ),
"enable_proxy": ConfigField( "enable_proxy": ConfigField(type=bool, default=False, description="是否启用代理"),
type=bool,
default=False,
description="是否启用代理"
)
}, },
} }
def get_plugin_components(self) -> List[Tuple[ComponentInfo, Type]]: def get_plugin_components(self) -> List[Tuple[ComponentInfo, Type]]:
""" """
获取插件组件列表 获取插件组件列表
Returns: Returns:
组件信息和类型的元组列表 组件信息和类型的元组列表
""" """
enable_tool = [] enable_tool = []
# 从主配置文件读取组件启用配置 # 从主配置文件读取组件启用配置
if config_api.get_global_config("web_search.enable_web_search_tool", True): if config_api.get_global_config("web_search.enable_web_search_tool", True):
enable_tool.append((WebSurfingTool.get_tool_info(), WebSurfingTool)) enable_tool.append((WebSurfingTool.get_tool_info(), WebSurfingTool))
if config_api.get_global_config("web_search.enable_url_tool", True): if config_api.get_global_config("web_search.enable_url_tool", True):
enable_tool.append((URLParserTool.get_tool_info(), URLParserTool)) enable_tool.append((URLParserTool.get_tool_info(), URLParserTool))
return enable_tool return enable_tool

View File

@@ -1,6 +1,7 @@
""" """
URL parser tool implementation URL parser tool implementation
""" """
import asyncio import asyncio
import functools import functools
from typing import Any, Dict from typing import Any, Dict
@@ -24,17 +25,18 @@ class URLParserTool(BaseTool):
""" """
一个用于解析和总结一个或多个网页URL内容的工具。 一个用于解析和总结一个或多个网页URL内容的工具。
""" """
name: str = "parse_url" name: str = "parse_url"
description: str = "当需要理解一个或多个特定网页链接的内容时,使用此工具。例如:'这些网页讲了什么?[https://example.com, https://example2.com]''帮我总结一下这些文章'" description: str = "当需要理解一个或多个特定网页链接的内容时,使用此工具。例如:'这些网页讲了什么?[https://example.com, https://example2.com]''帮我总结一下这些文章'"
available_for_llm: bool = True available_for_llm: bool = True
parameters = [ parameters = [
("urls", ToolParamType.STRING, "要理解的网站", True, None), ("urls", ToolParamType.STRING, "要理解的网站", True, None),
] ]
def __init__(self, plugin_config=None): def __init__(self, plugin_config=None):
super().__init__(plugin_config) super().__init__(plugin_config)
self._initialize_exa_clients() self._initialize_exa_clients()
def _initialize_exa_clients(self): def _initialize_exa_clients(self):
"""初始化Exa客户端""" """初始化Exa客户端"""
# 优先从主配置文件读取,如果没有则从插件配置文件读取 # 优先从主配置文件读取,如果没有则从插件配置文件读取
@@ -42,12 +44,10 @@ class URLParserTool(BaseTool):
if exa_api_keys is None: if exa_api_keys is None:
# 从插件配置文件读取 # 从插件配置文件读取
exa_api_keys = self.get_config("exa.api_keys", []) exa_api_keys = self.get_config("exa.api_keys", [])
# 创建API密钥管理器 # 创建API密钥管理器
self.api_manager = create_api_key_manager_from_config( self.api_manager = create_api_key_manager_from_config(
exa_api_keys, exa_api_keys, lambda key: Exa(api_key=key), "Exa URL Parser"
lambda key: Exa(api_key=key),
"Exa URL Parser"
) )
async def _local_parse_and_summarize(self, url: str) -> Dict[str, Any]: async def _local_parse_and_summarize(self, url: str) -> Dict[str, Any]:
@@ -58,12 +58,12 @@ class URLParserTool(BaseTool):
# 读取代理配置 # 读取代理配置
enable_proxy = self.get_config("proxy.enable_proxy", False) enable_proxy = self.get_config("proxy.enable_proxy", False)
proxies = None proxies = None
if enable_proxy: if enable_proxy:
socks5_proxy = self.get_config("proxy.socks5_proxy", None) socks5_proxy = self.get_config("proxy.socks5_proxy", None)
http_proxy = self.get_config("proxy.http_proxy", None) http_proxy = self.get_config("proxy.http_proxy", None)
https_proxy = self.get_config("proxy.https_proxy", None) https_proxy = self.get_config("proxy.https_proxy", None)
# 优先使用SOCKS5代理全协议代理 # 优先使用SOCKS5代理全协议代理
if socks5_proxy: if socks5_proxy:
proxies = socks5_proxy proxies = socks5_proxy
@@ -75,17 +75,17 @@ class URLParserTool(BaseTool):
if https_proxy: if https_proxy:
proxies["https://"] = https_proxy proxies["https://"] = https_proxy
logger.info(f"使用HTTP/HTTPS代理配置: {proxies}") logger.info(f"使用HTTP/HTTPS代理配置: {proxies}")
client_kwargs = {"timeout": 15.0, "follow_redirects": True} client_kwargs = {"timeout": 15.0, "follow_redirects": True}
if proxies: if proxies:
client_kwargs["proxies"] = proxies client_kwargs["proxies"] = proxies
async with httpx.AsyncClient(**client_kwargs) as client: async with httpx.AsyncClient(**client_kwargs) as client:
response = await client.get(url) response = await client.get(url)
response.raise_for_status() response.raise_for_status()
soup = BeautifulSoup(response.text, "html.parser") soup = BeautifulSoup(response.text, "html.parser")
title = soup.title.string if soup.title else "无标题" title = soup.title.string if soup.title else "无标题"
for script in soup(["script", "style"]): for script in soup(["script", "style"]):
script.extract() script.extract()
@@ -104,12 +104,12 @@ class URLParserTool(BaseTool):
return {"error": "未配置LLM模型"} return {"error": "未配置LLM模型"}
success, summary, reasoning, model_name = await llm_api.generate_with_model( success, summary, reasoning, model_name = await llm_api.generate_with_model(
prompt=summary_prompt, prompt=summary_prompt,
model_config=model_config, model_config=model_config,
request_type="story.generate", request_type="story.generate",
temperature=0.3, temperature=0.3,
max_tokens=1000 max_tokens=1000,
) )
if not success: if not success:
logger.info(f"生成摘要失败: {summary}") logger.info(f"生成摘要失败: {summary}")
@@ -117,12 +117,7 @@ class URLParserTool(BaseTool):
logger.info(f"成功生成摘要内容:'{summary}'") logger.info(f"成功生成摘要内容:'{summary}'")
return { return {"title": title, "url": url, "snippet": summary, "source": "local"}
"title": title,
"url": url,
"snippet": summary,
"source": "local"
}
except httpx.HTTPStatusError as e: except httpx.HTTPStatusError as e:
logger.warning(f"本地解析URL '{url}' 失败 (HTTP {e.response.status_code})") logger.warning(f"本地解析URL '{url}' 失败 (HTTP {e.response.status_code})")
@@ -137,6 +132,7 @@ class URLParserTool(BaseTool):
""" """
# 获取当前文件路径用于缓存键 # 获取当前文件路径用于缓存键
import os import os
current_file_path = os.path.abspath(__file__) current_file_path = os.path.abspath(__file__)
# 检查缓存 # 检查缓存
@@ -144,7 +140,7 @@ class URLParserTool(BaseTool):
if cached_result: if cached_result:
logger.info(f"缓存命中: {self.name} -> {function_args}") logger.info(f"缓存命中: {self.name} -> {function_args}")
return cached_result return cached_result
urls_input = function_args.get("urls") urls_input = function_args.get("urls")
if not urls_input: if not urls_input:
return {"error": "URL列表不能为空。"} return {"error": "URL列表不能为空。"}
@@ -158,14 +154,14 @@ class URLParserTool(BaseTool):
valid_urls = validate_urls(urls) valid_urls = validate_urls(urls)
if not valid_urls: if not valid_urls:
return {"error": "未找到有效的URL。"} return {"error": "未找到有效的URL。"}
urls = valid_urls urls = valid_urls
logger.info(f"准备解析 {len(urls)} 个URL: {urls}") logger.info(f"准备解析 {len(urls)} 个URL: {urls}")
successful_results = [] successful_results = []
error_messages = [] error_messages = []
urls_to_retry_locally = [] urls_to_retry_locally = []
# 步骤 1: 尝试使用 Exa API 进行解析 # 步骤 1: 尝试使用 Exa API 进行解析
contents_response = None contents_response = None
if self.api_manager.is_available(): if self.api_manager.is_available():
@@ -182,41 +178,45 @@ class URLParserTool(BaseTool):
contents_response = await loop.run_in_executor(None, func) contents_response = await loop.run_in_executor(None, func)
except Exception as e: except Exception as e:
logger.error(f"执行 Exa URL解析时发生严重异常: {e}", exc_info=True) logger.error(f"执行 Exa URL解析时发生严重异常: {e}", exc_info=True)
contents_response = None # 确保异常后为None contents_response = None # 确保异常后为None
# 步骤 2: 处理Exa的响应 # 步骤 2: 处理Exa的响应
if contents_response and hasattr(contents_response, 'statuses'): if contents_response and hasattr(contents_response, "statuses"):
results_map = {res.url: res for res in contents_response.results} if hasattr(contents_response, 'results') else {} results_map = (
{res.url: res for res in contents_response.results} if hasattr(contents_response, "results") else {}
)
if contents_response.statuses: if contents_response.statuses:
for status in contents_response.statuses: for status in contents_response.statuses:
if status.status == 'success': if status.status == "success":
res = results_map.get(status.id) res = results_map.get(status.id)
if res: if res:
summary = getattr(res, 'summary', '') summary = getattr(res, "summary", "")
highlights = " ".join(getattr(res, 'highlights', [])) highlights = " ".join(getattr(res, "highlights", []))
text_snippet = (getattr(res, 'text', '')[:300] + '...') if getattr(res, 'text', '') else '' text_snippet = (getattr(res, "text", "")[:300] + "...") if getattr(res, "text", "") else ""
snippet = summary or highlights or text_snippet or '无摘要' snippet = summary or highlights or text_snippet or "无摘要"
successful_results.append({ successful_results.append(
"title": getattr(res, 'title', '无标题'), {
"url": getattr(res, 'url', status.id), "title": getattr(res, "title", "无标题"),
"snippet": snippet, "url": getattr(res, "url", status.id),
"source": "exa" "snippet": snippet,
}) "source": "exa",
}
)
else: else:
error_tag = getattr(status, 'error', '未知错误') error_tag = getattr(status, "error", "未知错误")
logger.warning(f"Exa解析URL '{status.id}' 失败: {error_tag}。准备本地重试。") logger.warning(f"Exa解析URL '{status.id}' 失败: {error_tag}。准备本地重试。")
urls_to_retry_locally.append(status.id) urls_to_retry_locally.append(status.id)
else: else:
# 如果Exa未配置、API调用失败或返回无效响应则所有URL都进入本地重试 # 如果Exa未配置、API调用失败或返回无效响应则所有URL都进入本地重试
urls_to_retry_locally.extend(url for url in urls if url not in [res['url'] for res in successful_results]) urls_to_retry_locally.extend(url for url in urls if url not in [res["url"] for res in successful_results])
# 步骤 3: 对失败的URL进行本地解析 # 步骤 3: 对失败的URL进行本地解析
if urls_to_retry_locally: if urls_to_retry_locally:
logger.info(f"开始本地解析以下URL: {urls_to_retry_locally}") logger.info(f"开始本地解析以下URL: {urls_to_retry_locally}")
local_tasks = [self._local_parse_and_summarize(url) for url in urls_to_retry_locally] local_tasks = [self._local_parse_and_summarize(url) for url in urls_to_retry_locally]
local_results = await asyncio.gather(*local_tasks) local_results = await asyncio.gather(*local_tasks)
for i, res in enumerate(local_results): for i, res in enumerate(local_results):
url = urls_to_retry_locally[i] url = urls_to_retry_locally[i]
if "error" in res: if "error" in res:
@@ -228,13 +228,9 @@ class URLParserTool(BaseTool):
return {"error": "无法从所有给定的URL获取内容。", "details": error_messages} return {"error": "无法从所有给定的URL获取内容。", "details": error_messages}
formatted_content = format_url_parse_results(successful_results) formatted_content = format_url_parse_results(successful_results)
result = { result = {"type": "url_parse_result", "content": formatted_content, "errors": error_messages}
"type": "url_parse_result",
"content": formatted_content,
"errors": error_messages
}
# 保存到缓存 # 保存到缓存
if "error" not in result: if "error" not in result:
await tool_cache.set(self.name, function_args, current_file_path, result) await tool_cache.set(self.name, function_args, current_file_path, result)

View File

@@ -1,6 +1,7 @@
""" """
Web search tool implementation Web search tool implementation
""" """
import asyncio import asyncio
from typing import Any, Dict, List from typing import Any, Dict, List
@@ -22,14 +23,23 @@ class WebSurfingTool(BaseTool):
""" """
网络搜索工具 网络搜索工具
""" """
name: str = "web_search" name: str = "web_search"
description: str = "用于执行网络搜索。当用户明确要求搜索,或者需要获取关于公司、产品、事件的最新信息、新闻或动态时,必须使用此工具" description: str = (
"用于执行网络搜索。当用户明确要求搜索,或者需要获取关于公司、产品、事件的最新信息、新闻或动态时,必须使用此工具"
)
available_for_llm: bool = True available_for_llm: bool = True
parameters = [ parameters = [
("query", ToolParamType.STRING, "要搜索的关键词或问题。", True, None), ("query", ToolParamType.STRING, "要搜索的关键词或问题。", True, None),
("num_results", ToolParamType.INTEGER, "期望每个搜索引擎返回的搜索结果数量默认为5。", False, None), ("num_results", ToolParamType.INTEGER, "期望每个搜索引擎返回的搜索结果数量默认为5。", False, None),
("time_range", ToolParamType.STRING, "指定搜索的时间范围,可以是 'any', 'week', 'month'。默认为 'any'", False, ["any", "week", "month"]) (
] # type: ignore "time_range",
ToolParamType.STRING,
"指定搜索的时间范围,可以是 'any', 'week', 'month'。默认为 'any'",
False,
["any", "week", "month"],
),
] # type: ignore
def __init__(self, plugin_config=None): def __init__(self, plugin_config=None):
super().__init__(plugin_config) super().__init__(plugin_config)
@@ -38,7 +48,7 @@ class WebSurfingTool(BaseTool):
"exa": ExaSearchEngine(), "exa": ExaSearchEngine(),
"tavily": TavilySearchEngine(), "tavily": TavilySearchEngine(),
"ddg": DDGSearchEngine(), "ddg": DDGSearchEngine(),
"bing": BingSearchEngine() "bing": BingSearchEngine(),
} }
async def execute(self, function_args: Dict[str, Any]) -> Dict[str, Any]: async def execute(self, function_args: Dict[str, Any]) -> Dict[str, Any]:
@@ -48,6 +58,7 @@ class WebSurfingTool(BaseTool):
# 获取当前文件路径用于缓存键 # 获取当前文件路径用于缓存键
import os import os
current_file_path = os.path.abspath(__file__) current_file_path = os.path.abspath(__file__)
# 检查缓存 # 检查缓存
@@ -59,7 +70,7 @@ class WebSurfingTool(BaseTool):
# 读取搜索配置 # 读取搜索配置
enabled_engines = config_api.get_global_config("web_search.enabled_engines", ["ddg"]) enabled_engines = config_api.get_global_config("web_search.enabled_engines", ["ddg"])
search_strategy = config_api.get_global_config("web_search.search_strategy", "single") search_strategy = config_api.get_global_config("web_search.search_strategy", "single")
logger.info(f"开始搜索,策略: {search_strategy}, 启用引擎: {enabled_engines}, 参数: '{function_args}'") logger.info(f"开始搜索,策略: {search_strategy}, 启用引擎: {enabled_engines}, 参数: '{function_args}'")
# 根据策略执行搜索 # 根据策略执行搜索
@@ -69,17 +80,19 @@ class WebSurfingTool(BaseTool):
result = await self._execute_fallback_search(function_args, enabled_engines) result = await self._execute_fallback_search(function_args, enabled_engines)
else: # single else: # single
result = await self._execute_single_search(function_args, enabled_engines) result = await self._execute_single_search(function_args, enabled_engines)
# 保存到缓存 # 保存到缓存
if "error" not in result: if "error" not in result:
await tool_cache.set(self.name, function_args, current_file_path, result, semantic_query=query) await tool_cache.set(self.name, function_args, current_file_path, result, semantic_query=query)
return result return result
async def _execute_parallel_search(self, function_args: Dict[str, Any], enabled_engines: List[str]) -> Dict[str, Any]: async def _execute_parallel_search(
self, function_args: Dict[str, Any], enabled_engines: List[str]
) -> Dict[str, Any]:
"""并行搜索策略:同时使用所有启用的搜索引擎""" """并行搜索策略:同时使用所有启用的搜索引擎"""
search_tasks = [] search_tasks = []
for engine_name in enabled_engines: for engine_name in enabled_engines:
engine = self.engines.get(engine_name) engine = self.engines.get(engine_name)
if engine and engine.is_available(): if engine and engine.is_available():
@@ -92,7 +105,7 @@ class WebSurfingTool(BaseTool):
try: try:
search_results_lists = await asyncio.gather(*search_tasks, return_exceptions=True) search_results_lists = await asyncio.gather(*search_tasks, return_exceptions=True)
all_results = [] all_results = []
for result in search_results_lists: for result in search_results_lists:
if isinstance(result, list): if isinstance(result, list):
@@ -103,7 +116,7 @@ class WebSurfingTool(BaseTool):
# 去重并格式化 # 去重并格式化
unique_results = deduplicate_results(all_results) unique_results = deduplicate_results(all_results)
formatted_content = format_search_results(unique_results) formatted_content = format_search_results(unique_results)
return { return {
"type": "web_search_result", "type": "web_search_result",
"content": formatted_content, "content": formatted_content,
@@ -113,30 +126,32 @@ class WebSurfingTool(BaseTool):
logger.error(f"执行并行网络搜索时发生异常: {e}", exc_info=True) logger.error(f"执行并行网络搜索时发生异常: {e}", exc_info=True)
return {"error": f"执行网络搜索时发生严重错误: {str(e)}"} return {"error": f"执行网络搜索时发生严重错误: {str(e)}"}
async def _execute_fallback_search(self, function_args: Dict[str, Any], enabled_engines: List[str]) -> Dict[str, Any]: async def _execute_fallback_search(
self, function_args: Dict[str, Any], enabled_engines: List[str]
) -> Dict[str, Any]:
"""回退搜索策略:按顺序尝试搜索引擎,失败则尝试下一个""" """回退搜索策略:按顺序尝试搜索引擎,失败则尝试下一个"""
for engine_name in enabled_engines: for engine_name in enabled_engines:
engine = self.engines.get(engine_name) engine = self.engines.get(engine_name)
if not engine or not engine.is_available(): if not engine or not engine.is_available():
continue continue
try: try:
custom_args = function_args.copy() custom_args = function_args.copy()
custom_args["num_results"] = custom_args.get("num_results", 5) custom_args["num_results"] = custom_args.get("num_results", 5)
results = await engine.search(custom_args) results = await engine.search(custom_args)
if results: # 如果有结果,直接返回 if results: # 如果有结果,直接返回
formatted_content = format_search_results(results) formatted_content = format_search_results(results)
return { return {
"type": "web_search_result", "type": "web_search_result",
"content": formatted_content, "content": formatted_content,
} }
except Exception as e: except Exception as e:
logger.warning(f"{engine_name} 搜索失败,尝试下一个引擎: {e}") logger.warning(f"{engine_name} 搜索失败,尝试下一个引擎: {e}")
continue continue
return {"error": "所有搜索引擎都失败了。"} return {"error": "所有搜索引擎都失败了。"}
async def _execute_single_search(self, function_args: Dict[str, Any], enabled_engines: List[str]) -> Dict[str, Any]: async def _execute_single_search(self, function_args: Dict[str, Any], enabled_engines: List[str]) -> Dict[str, Any]:
@@ -145,20 +160,20 @@ class WebSurfingTool(BaseTool):
engine = self.engines.get(engine_name) engine = self.engines.get(engine_name)
if not engine or not engine.is_available(): if not engine or not engine.is_available():
continue continue
try: try:
custom_args = function_args.copy() custom_args = function_args.copy()
custom_args["num_results"] = custom_args.get("num_results", 5) custom_args["num_results"] = custom_args.get("num_results", 5)
results = await engine.search(custom_args) results = await engine.search(custom_args)
formatted_content = format_search_results(results) formatted_content = format_search_results(results)
return { return {
"type": "web_search_result", "type": "web_search_result",
"content": formatted_content, "content": formatted_content,
} }
except Exception as e: except Exception as e:
logger.error(f"{engine_name} 搜索失败: {e}") logger.error(f"{engine_name} 搜索失败: {e}")
return {"error": f"{engine_name} 搜索失败: {str(e)}"} return {"error": f"{engine_name} 搜索失败: {str(e)}"}
return {"error": "没有可用的搜索引擎。"} return {"error": "没有可用的搜索引擎。"}

View File

@@ -1,24 +1,25 @@
""" """
API密钥管理器提供轮询机制 API密钥管理器提供轮询机制
""" """
import itertools import itertools
from typing import List, Optional, TypeVar, Generic, Callable from typing import List, Optional, TypeVar, Generic, Callable
from src.common.logger import get_logger from src.common.logger import get_logger
logger = get_logger("api_key_manager") logger = get_logger("api_key_manager")
T = TypeVar('T') T = TypeVar("T")
class APIKeyManager(Generic[T]): class APIKeyManager(Generic[T]):
""" """
API密钥管理器支持轮询机制 API密钥管理器支持轮询机制
""" """
def __init__(self, api_keys: List[str], client_factory: Callable[[str], T], service_name: str = "Unknown"): def __init__(self, api_keys: List[str], client_factory: Callable[[str], T], service_name: str = "Unknown"):
""" """
初始化API密钥管理器 初始化API密钥管理器
Args: Args:
api_keys: API密钥列表 api_keys: API密钥列表
client_factory: 客户端工厂函数接受API密钥参数并返回客户端实例 client_factory: 客户端工厂函数接受API密钥参数并返回客户端实例
@@ -27,14 +28,14 @@ class APIKeyManager(Generic[T]):
self.service_name = service_name self.service_name = service_name
self.clients: List[T] = [] self.clients: List[T] = []
self.client_cycle: Optional[itertools.cycle] = None self.client_cycle: Optional[itertools.cycle] = None
if api_keys: if api_keys:
# 过滤有效的API密钥排除None、空字符串、"None"字符串等 # 过滤有效的API密钥排除None、空字符串、"None"字符串等
valid_keys = [] valid_keys = []
for key in api_keys: for key in api_keys:
if isinstance(key, str) and key.strip() and key.strip().lower() not in ("none", "null", ""): if isinstance(key, str) and key.strip() and key.strip().lower() not in ("none", "null", ""):
valid_keys.append(key.strip()) valid_keys.append(key.strip())
if valid_keys: if valid_keys:
try: try:
self.clients = [client_factory(key) for key in valid_keys] self.clients = [client_factory(key) for key in valid_keys]
@@ -48,35 +49,33 @@ class APIKeyManager(Generic[T]):
logger.warning(f"⚠️ {service_name} API Keys 配置无效包含None或空值{service_name} 功能将不可用") logger.warning(f"⚠️ {service_name} API Keys 配置无效包含None或空值{service_name} 功能将不可用")
else: else:
logger.warning(f"⚠️ {service_name} API Keys 未配置,{service_name} 功能将不可用") logger.warning(f"⚠️ {service_name} API Keys 未配置,{service_name} 功能将不可用")
def is_available(self) -> bool: def is_available(self) -> bool:
"""检查是否有可用的客户端""" """检查是否有可用的客户端"""
return bool(self.clients and self.client_cycle) return bool(self.clients and self.client_cycle)
def get_next_client(self) -> Optional[T]: def get_next_client(self) -> Optional[T]:
"""获取下一个客户端(轮询)""" """获取下一个客户端(轮询)"""
if not self.is_available(): if not self.is_available():
return None return None
return next(self.client_cycle) return next(self.client_cycle)
def get_client_count(self) -> int: def get_client_count(self) -> int:
"""获取可用客户端数量""" """获取可用客户端数量"""
return len(self.clients) return len(self.clients)
def create_api_key_manager_from_config( def create_api_key_manager_from_config(
config_keys: Optional[List[str]], config_keys: Optional[List[str]], client_factory: Callable[[str], T], service_name: str
client_factory: Callable[[str], T],
service_name: str
) -> APIKeyManager[T]: ) -> APIKeyManager[T]:
""" """
从配置创建API密钥管理器的便捷函数 从配置创建API密钥管理器的便捷函数
Args: Args:
config_keys: 从配置读取的API密钥列表 config_keys: 从配置读取的API密钥列表
client_factory: 客户端工厂函数 client_factory: 客户端工厂函数
service_name: 服务名称 service_name: 服务名称
Returns: Returns:
API密钥管理器实例 API密钥管理器实例
""" """

View File

@@ -1,6 +1,7 @@
""" """
Formatters for web search results Formatters for web search results
""" """
from typing import List, Dict, Any from typing import List, Dict, Any
@@ -13,15 +14,15 @@ def format_search_results(results: List[Dict[str, Any]]) -> str:
formatted_string = "根据网络搜索结果:\n\n" formatted_string = "根据网络搜索结果:\n\n"
for i, res in enumerate(results, 1): for i, res in enumerate(results, 1):
title = res.get("title", '无标题') title = res.get("title", "无标题")
url = res.get("url", '#') url = res.get("url", "#")
snippet = res.get("snippet", '无摘要') snippet = res.get("snippet", "无摘要")
provider = res.get("provider", "未知来源") provider = res.get("provider", "未知来源")
formatted_string += f"{i}. **{title}** (来自: {provider})\n" formatted_string += f"{i}. **{title}** (来自: {provider})\n"
formatted_string += f" - 摘要: {snippet}\n" formatted_string += f" - 摘要: {snippet}\n"
formatted_string += f" - 来源: {url}\n\n" formatted_string += f" - 来源: {url}\n\n"
return formatted_string return formatted_string
@@ -31,10 +32,10 @@ def format_url_parse_results(results: List[Dict[str, Any]]) -> str:
""" """
formatted_parts = [] formatted_parts = []
for res in results: for res in results:
title = res.get('title', '无标题') title = res.get("title", "无标题")
url = res.get('url', '#') url = res.get("url", "#")
snippet = res.get('snippet', '无摘要') snippet = res.get("snippet", "无摘要")
source = res.get('source', '未知') source = res.get("source", "未知")
formatted_string = f"**{title}**\n" formatted_string = f"**{title}**\n"
formatted_string += f"**内容摘要**:\n{snippet}\n" formatted_string += f"**内容摘要**:\n{snippet}\n"

View File

@@ -1,6 +1,7 @@
""" """
URL processing utilities URL processing utilities
""" """
import re import re
from typing import List from typing import List
@@ -12,11 +13,11 @@ def parse_urls_from_input(urls_input) -> List[str]:
if isinstance(urls_input, str): if isinstance(urls_input, str):
# 如果是字符串尝试解析为URL列表 # 如果是字符串尝试解析为URL列表
# 提取所有HTTP/HTTPS URL # 提取所有HTTP/HTTPS URL
url_pattern = r'https?://[^\s\],]+' url_pattern = r"https?://[^\s\],]+"
urls = re.findall(url_pattern, urls_input) urls = re.findall(url_pattern, urls_input)
if not urls: if not urls:
# 如果没有找到标准URL将整个字符串作为单个URL # 如果没有找到标准URL将整个字符串作为单个URL
if urls_input.strip().startswith(('http://', 'https://')): if urls_input.strip().startswith(("http://", "https://")):
urls = [urls_input.strip()] urls = [urls_input.strip()]
else: else:
return [] return []
@@ -24,7 +25,7 @@ def parse_urls_from_input(urls_input) -> List[str]:
urls = [url.strip() for url in urls_input if isinstance(url, str) and url.strip()] urls = [url.strip() for url in urls_input if isinstance(url, str) and url.strip()]
else: else:
return [] return []
return urls return urls
@@ -34,6 +35,6 @@ def validate_urls(urls: List[str]) -> List[str]:
""" """
valid_urls = [] valid_urls = []
for url in urls: for url in urls:
if url.startswith(('http://', 'https://')): if url.startswith(("http://", "https://")):
valid_urls.append(url) valid_urls.append(url)
return valid_urls return valid_urls

View File

@@ -21,8 +21,18 @@ logger = get_logger(__name__)
# ============================ AsyncTask ============================ # ============================ AsyncTask ============================
class ReminderTask(AsyncTask): class ReminderTask(AsyncTask):
def __init__(self, delay: float, stream_id: str, is_group: bool, target_user_id: str, target_user_name: str, event_details: str, creator_name: str): def __init__(
self,
delay: float,
stream_id: str,
is_group: bool,
target_user_id: str,
target_user_name: str,
event_details: str,
creator_name: str,
):
super().__init__(task_name=f"ReminderTask_{target_user_id}_{datetime.now().timestamp()}") super().__init__(task_name=f"ReminderTask_{target_user_id}_{datetime.now().timestamp()}")
self.delay = delay self.delay = delay
self.stream_id = stream_id self.stream_id = stream_id
@@ -37,22 +47,22 @@ class ReminderTask(AsyncTask):
if self.delay > 0: if self.delay > 0:
logger.info(f"等待 {self.delay:.2f} 秒后执行提醒...") logger.info(f"等待 {self.delay:.2f} 秒后执行提醒...")
await asyncio.sleep(self.delay) await asyncio.sleep(self.delay)
logger.info(f"执行提醒任务: 给 {self.target_user_name} 发送关于 '{self.event_details}' 的提醒") logger.info(f"执行提醒任务: 给 {self.target_user_name} 发送关于 '{self.event_details}' 的提醒")
reminder_text = f"叮咚!这是 {self.creator_name} 让我准时提醒你的事情:\n\n{self.event_details}" reminder_text = f"叮咚!这是 {self.creator_name} 让我准时提醒你的事情:\n\n{self.event_details}"
if self.is_group: if self.is_group:
# 在群聊中,构造 @ 消息段并发送 # 在群聊中,构造 @ 消息段并发送
group_id = self.stream_id.split('_')[-1] if '_' in self.stream_id else self.stream_id group_id = self.stream_id.split("_")[-1] if "_" in self.stream_id else self.stream_id
message_payload = [ message_payload = [
{"type": "at", "data": {"qq": self.target_user_id}}, {"type": "at", "data": {"qq": self.target_user_id}},
{"type": "text", "data": {"text": f" {reminder_text}"}} {"type": "text", "data": {"text": f" {reminder_text}"}},
] ]
await send_api.adapter_command_to_stream( await send_api.adapter_command_to_stream(
action="send_group_msg", action="send_group_msg",
params={"group_id": group_id, "message": message_payload}, params={"group_id": group_id, "message": message_payload},
stream_id=self.stream_id stream_id=self.stream_id,
) )
else: else:
# 在私聊中,直接发送文本 # 在私聊中,直接发送文本
@@ -66,6 +76,7 @@ class ReminderTask(AsyncTask):
# =============================== Actions =============================== # =============================== Actions ===============================
class RemindAction(BaseAction): class RemindAction(BaseAction):
"""一个能从对话中智能识别并设置定时提醒的动作。""" """一个能从对话中智能识别并设置定时提醒的动作。"""
@@ -95,12 +106,12 @@ class RemindAction(BaseAction):
action_parameters = { action_parameters = {
"user_name": "需要被提醒的人的称呼或名字,如果没有明确指定给某人,则默认为'自己'", "user_name": "需要被提醒的人的称呼或名字,如果没有明确指定给某人,则默认为'自己'",
"remind_time": "描述提醒时间的自然语言字符串,例如'十分钟后''明天下午3点'", "remind_time": "描述提醒时间的自然语言字符串,例如'十分钟后''明天下午3点'",
"event_details": "需要提醒的具体事件内容" "event_details": "需要提醒的具体事件内容",
} }
action_require = [ action_require = [
"当用户请求在未来的某个时间点提醒他/她或别人某件事时使用", "当用户请求在未来的某个时间点提醒他/她或别人某件事时使用",
"适用于包含明确时间信息和事件描述的对话", "适用于包含明确时间信息和事件描述的对话",
"例如:'10分钟后提醒我收快递''明天早上九点喊一下李四参加晨会'" "例如:'10分钟后提醒我收快递''明天早上九点喊一下李四参加晨会'",
] ]
async def execute(self) -> Tuple[bool, str]: async def execute(self) -> Tuple[bool, str]:
@@ -110,7 +121,15 @@ class RemindAction(BaseAction):
event_details = self.action_data.get("event_details") event_details = self.action_data.get("event_details")
if not all([user_name, remind_time_str, event_details]): if not all([user_name, remind_time_str, event_details]):
missing_params = [p for p, v in {"user_name": user_name, "remind_time": remind_time_str, "event_details": event_details}.items() if not v] missing_params = [
p
for p, v in {
"user_name": user_name,
"remind_time": remind_time_str,
"event_details": event_details,
}.items()
if not v
]
error_msg = f"缺少必要的提醒参数: {', '.join(missing_params)}" error_msg = f"缺少必要的提醒参数: {', '.join(missing_params)}"
logger.warning(f"[ReminderPlugin] LLM未能提取完整参数: {error_msg}") logger.warning(f"[ReminderPlugin] LLM未能提取完整参数: {error_msg}")
return False, error_msg return False, error_msg
@@ -135,9 +154,9 @@ class RemindAction(BaseAction):
person_manager = get_person_info_manager() person_manager = get_person_info_manager()
user_id_to_remind = None user_id_to_remind = None
user_name_to_remind = "" user_name_to_remind = ""
assert isinstance(user_name, str) assert isinstance(user_name, str)
if user_name.strip() in ["自己", "", "me"]: if user_name.strip() in ["自己", "", "me"]:
user_id_to_remind = self.user_id user_id_to_remind = self.user_id
user_name_to_remind = self.user_nickname user_name_to_remind = self.user_nickname
@@ -154,7 +173,7 @@ class RemindAction(BaseAction):
try: try:
assert user_id_to_remind is not None assert user_id_to_remind is not None
assert event_details is not None assert event_details is not None
reminder_task = ReminderTask( reminder_task = ReminderTask(
delay=delay_seconds, delay=delay_seconds,
stream_id=self.chat_id, stream_id=self.chat_id,
@@ -162,14 +181,14 @@ class RemindAction(BaseAction):
target_user_id=str(user_id_to_remind), target_user_id=str(user_id_to_remind),
target_user_name=str(user_name_to_remind), target_user_name=str(user_name_to_remind),
event_details=str(event_details), event_details=str(event_details),
creator_name=str(self.user_nickname) creator_name=str(self.user_nickname),
) )
await async_task_manager.add_task(reminder_task) await async_task_manager.add_task(reminder_task)
# 4. 发送确认消息 # 4. 发送确认消息
confirm_message = f"好的,我记下了。\n将在 {target_time.strftime('%Y-%m-%d %H:%M:%S')} 提醒 {user_name_to_remind}\n{event_details}" confirm_message = f"好的,我记下了。\n将在 {target_time.strftime('%Y-%m-%d %H:%M:%S')} 提醒 {user_name_to_remind}\n{event_details}"
await self.send_text(confirm_message) await self.send_text(confirm_message)
return True, "提醒设置成功" return True, "提醒设置成功"
except Exception as e: except Exception as e:
logger.error(f"[ReminderPlugin] 创建提醒任务时出错: {e}", exc_info=True) logger.error(f"[ReminderPlugin] 创建提醒任务时出错: {e}", exc_info=True)
@@ -179,6 +198,7 @@ class RemindAction(BaseAction):
# =============================== Plugin =============================== # =============================== Plugin ===============================
@register_plugin @register_plugin
class ReminderPlugin(BasePlugin): class ReminderPlugin(BasePlugin):
"""一个能从对话中智能识别并设置定时提醒的插件。""" """一个能从对话中智能识别并设置定时提醒的插件。"""
@@ -193,6 +213,4 @@ class ReminderPlugin(BasePlugin):
def get_plugin_components(self) -> List[Tuple[ActionInfo, Type[BaseAction]]]: def get_plugin_components(self) -> List[Tuple[ActionInfo, Type[BaseAction]]]:
"""注册插件的所有功能组件。""" """注册插件的所有功能组件。"""
return [ return [(RemindAction.get_action_info(), RemindAction)]
(RemindAction.get_action_info(), RemindAction)
]

View File

@@ -290,4 +290,4 @@ def has_active_plans(month: str) -> bool:
return count > 0 return count > 0
except Exception as e: except Exception as e:
logger.error(f"检查 {month} 的有效月度计划时发生错误: {e}") logger.error(f"检查 {month} 的有效月度计划时发生错误: {e}")
return False return False

View File

@@ -221,4 +221,4 @@ class MonthlyPlanLLMGenerator:
return plans return plans
except Exception as e: except Exception as e:
logger.error(f"解析月度计划响应时发生错误: {e}") logger.error(f"解析月度计划响应时发生错误: {e}")
return [] return []

View File

@@ -102,4 +102,4 @@ class PlanManager:
def get_plans_for_schedule(self, month: str, max_count: int) -> List: def get_plans_for_schedule(self, month: str, max_count: int) -> List:
avoid_days = global_config.planning_system.avoid_repetition_days avoid_days = global_config.planning_system.avoid_repetition_days
return get_smart_plans_for_daily_schedule(month, max_count=max_count, avoid_days=avoid_days) return get_smart_plans_for_daily_schedule(month, max_count=max_count, avoid_days=avoid_days)

View File

@@ -96,4 +96,4 @@ class ScheduleData(BaseModel):
covered[i] = True covered[i] = True
# 检查是否所有分钟都被覆盖 # 检查是否所有分钟都被覆盖
return all(covered) return all(covered)