diff --git a/bot.py b/bot.py index 2490e6a97..5c4299f34 100644 --- a/bot.py +++ b/bot.py @@ -34,16 +34,18 @@ script_dir = os.path.dirname(os.path.abspath(__file__)) os.chdir(script_dir) logger.info(f"已设置工作目录为: {script_dir}") + # 检查并创建.env文件 def ensure_env_file(): """确保.env文件存在,如果不存在则从模板创建""" env_file = Path(".env") template_env = Path("template/template.env") - + if not env_file.exists(): if template_env.exists(): logger.info("未找到.env文件,正在从模板创建...") import shutil + shutil.copy(template_env, env_file) logger.info("已从template/template.env创建.env文件") logger.warning("请编辑.env文件,将EULA_CONFIRMED设置为true并配置其他必要参数") @@ -51,6 +53,7 @@ def ensure_env_file(): logger.error("未找到.env文件和template.env模板文件") sys.exit(1) + # 确保环境文件存在 ensure_env_file() @@ -130,32 +133,32 @@ async def graceful_shutdown(): def check_eula(): """检查EULA和隐私条款确认状态 - 环境变量版(类似Minecraft)""" # 检查环境变量中的EULA确认 - eula_confirmed = os.getenv('EULA_CONFIRMED', '').lower() - - if eula_confirmed == 'true': + eula_confirmed = os.getenv("EULA_CONFIRMED", "").lower() + + if eula_confirmed == "true": logger.info("EULA已通过环境变量确认") return - + # 如果没有确认,提示用户 confirm_logger.critical("您需要同意EULA和隐私条款才能使用MoFox_Bot") confirm_logger.critical("请阅读以下文件:") confirm_logger.critical(" - EULA.md (用户许可协议)") confirm_logger.critical(" - PRIVACY.md (隐私条款)") confirm_logger.critical("然后编辑 .env 文件,将 'EULA_CONFIRMED=false' 改为 'EULA_CONFIRMED=true'") - + # 等待用户确认 while True: try: load_dotenv(override=True) # 重新加载.env文件 - - eula_confirmed = os.getenv('EULA_CONFIRMED', '').lower() - if eula_confirmed == 'true': + + eula_confirmed = os.getenv("EULA_CONFIRMED", "").lower() + if eula_confirmed == "true": confirm_logger.info("EULA确认成功,感谢您的同意") return - + confirm_logger.critical("请修改 .env 文件中的 EULA_CONFIRMED=true 后重新启动程序") input("按Enter键检查.env文件状态...") - + except KeyboardInterrupt: confirm_logger.info("用户取消,程序退出") sys.exit(0) diff --git a/scripts/update_prompt_imports.py b/scripts/update_prompt_imports.py index 289d7f327..227491ec2 100644 --- a/scripts/update_prompt_imports.py +++ b/scripts/update_prompt_imports.py @@ -20,25 +20,26 @@ files_to_update = [ "src/mais4u/mais4u_chat/s4u_mood_manager.py", "src/plugin_system/core/tool_use.py", "src/chat/memory_system/memory_activator.py", - "src/chat/utils/smart_prompt.py" + "src/chat/utils/smart_prompt.py", ] + def update_prompt_imports(file_path): """更新文件中的Prompt导入""" if not os.path.exists(file_path): print(f"文件不存在: {file_path}") return False - - with open(file_path, 'r', encoding='utf-8') as f: + + with open(file_path, "r", encoding="utf-8") as f: content = f.read() - + # 替换导入语句 old_import = "from src.chat.utils.prompt_builder import Prompt, global_prompt_manager" new_import = "from src.chat.utils.prompt import Prompt, global_prompt_manager" - + if old_import in content: new_content = content.replace(old_import, new_import) - with open(file_path, 'w', encoding='utf-8') as f: + with open(file_path, "w", encoding="utf-8") as f: f.write(new_content) print(f"已更新: {file_path}") return True @@ -46,14 +47,16 @@ def update_prompt_imports(file_path): print(f"无需更新: {file_path}") return False + def main(): """主函数""" updated_count = 0 for file_path in files_to_update: if update_prompt_imports(file_path): updated_count += 1 - + print(f"\n更新完成!共更新了 {updated_count} 个文件") + if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/src/chat/affinity_flow/__init__.py b/src/chat/affinity_flow/__init__.py index ae0f33fec..59f35bacd 100644 --- a/src/chat/affinity_flow/__init__.py +++ b/src/chat/affinity_flow/__init__.py @@ -5,4 +5,4 @@ from src.chat.affinity_flow.afc_manager import afc_manager -__all__ = ['afc_manager', 'AFCManager', 'AffinityFlowChatter'] \ No newline at end of file +__all__ = ["afc_manager", "AFCManager", "AffinityFlowChatter"] diff --git a/src/chat/affinity_flow/afc_manager.py b/src/chat/affinity_flow/afc_manager.py index c96873089..9555ee5ea 100644 --- a/src/chat/affinity_flow/afc_manager.py +++ b/src/chat/affinity_flow/afc_manager.py @@ -2,6 +2,7 @@ 亲和力聊天处理流管理器 管理不同聊天流的亲和力聊天处理流,统一获取新消息并分发到对应的亲和力聊天处理流 """ + import time import traceback from typing import Dict, Optional, List @@ -20,7 +21,7 @@ class AFCManager: def __init__(self): self.affinity_flow_chatters: Dict[str, "AffinityFlowChatter"] = {} - '''所有聊天流的亲和力聊天处理流,stream_id -> affinity_flow_chatter''' + """所有聊天流的亲和力聊天处理流,stream_id -> affinity_flow_chatter""" # 动作管理器 self.action_manager = ActionManager() @@ -40,11 +41,7 @@ class AFCManager: # 创建增强版规划器 planner = ActionPlanner(stream_id, self.action_manager) - chatter = AffinityFlowChatter( - stream_id=stream_id, - planner=planner, - action_manager=self.action_manager - ) + chatter = AffinityFlowChatter(stream_id=stream_id, planner=planner, action_manager=self.action_manager) self.affinity_flow_chatters[stream_id] = chatter logger.info(f"创建新的亲和力聊天处理器: {stream_id}") @@ -74,7 +71,6 @@ class AFCManager: "executed_count": 0, } - def get_chatter_stats(self, stream_id: str) -> Optional[Dict[str, any]]: """获取聊天处理器统计""" if stream_id in self.affinity_flow_chatters: @@ -131,4 +127,5 @@ class AFCManager: self.affinity_flow_chatters[stream_id].update_interest_keywords(new_keywords) logger.info(f"已更新聊天流 {stream_id} 的兴趣关键词: {list(new_keywords.keys())}") -afc_manager = AFCManager() \ No newline at end of file + +afc_manager = AFCManager() diff --git a/src/chat/affinity_flow/chatter.py b/src/chat/affinity_flow/chatter.py index 7e5f9e6f1..fa3445924 100644 --- a/src/chat/affinity_flow/chatter.py +++ b/src/chat/affinity_flow/chatter.py @@ -2,6 +2,7 @@ 亲和力聊天处理器 单个聊天流的处理器,负责处理特定聊天流的完整交互流程 """ + import time import traceback from datetime import datetime @@ -57,10 +58,7 @@ class AffinityFlowChatter: unread_messages = context.get_unread_messages() # 使用增强版规划器处理消息 - actions, target_message = await self.planner.plan( - mode=ChatMode.FOCUS, - context=context - ) + actions, target_message = await self.planner.plan(mode=ChatMode.FOCUS, context=context) self.stats["plans_created"] += 1 # 执行动作(如果规划器返回了动作) @@ -84,7 +82,9 @@ class AffinityFlowChatter: **execution_result, } - logger.info(f"聊天流 {self.stream_id} StreamContext处理成功: 动作数={result['actions_count']}, 未读消息={result['unread_messages_processed']}") + logger.info( + f"聊天流 {self.stream_id} StreamContext处理成功: 动作数={result['actions_count']}, 未读消息={result['unread_messages_processed']}" + ) return result @@ -197,7 +197,9 @@ class AffinityFlowChatter: def __repr__(self) -> str: """详细字符串表示""" - return (f"AffinityFlowChatter(stream_id={self.stream_id}, " - f"messages_processed={self.stats['messages_processed']}, " - f"plans_created={self.stats['plans_created']}, " - f"last_activity={datetime.fromtimestamp(self.last_activity_time)})") \ No newline at end of file + return ( + f"AffinityFlowChatter(stream_id={self.stream_id}, " + f"messages_processed={self.stats['messages_processed']}, " + f"plans_created={self.stats['plans_created']}, " + f"last_activity={datetime.fromtimestamp(self.last_activity_time)})" + ) diff --git a/src/chat/affinity_flow/interest_scoring.py b/src/chat/affinity_flow/interest_scoring.py index d8dfc2778..cf5200bbc 100644 --- a/src/chat/affinity_flow/interest_scoring.py +++ b/src/chat/affinity_flow/interest_scoring.py @@ -38,7 +38,9 @@ class InterestScoringSystem: # 连续不回复概率提升 self.no_reply_count = 0 self.max_no_reply_count = affinity_config.max_no_reply_count - self.probability_boost_per_no_reply = affinity_config.no_reply_threshold_adjustment / affinity_config.max_no_reply_count # 每次不回复增加的概率 + self.probability_boost_per_no_reply = ( + affinity_config.no_reply_threshold_adjustment / affinity_config.max_no_reply_count + ) # 每次不回复增加的概率 # 用户关系数据 self.user_relationships: Dict[str, float] = {} # user_id -> relationship_score @@ -153,7 +155,9 @@ class InterestScoringSystem: # 返回匹配分数,考虑置信度和匹配标签数量 affinity_config = global_config.affinity_flow - match_count_bonus = min(len(match_result.matched_tags) * affinity_config.match_count_bonus, affinity_config.max_match_bonus) + match_count_bonus = min( + len(match_result.matched_tags) * affinity_config.match_count_bonus, affinity_config.max_match_bonus + ) final_score = match_result.overall_score * 1.15 * match_result.confidence + match_count_bonus logger.debug( f"⚖️ 最终分数计算: 总分({match_result.overall_score:.3f}) × 1.3 × 置信度({match_result.confidence:.3f}) + 标签数量奖励({match_count_bonus:.3f}) = {final_score:.3f}" @@ -263,7 +267,17 @@ class InterestScoringSystem: if not msg.processed_plain_text: return 0.0 - if msg.is_mentioned or (bot_nickname and bot_nickname in msg.processed_plain_text): + # 检查是否被提及 + is_mentioned = msg.is_mentioned or (bot_nickname and bot_nickname in msg.processed_plain_text) + + # 检查是否为私聊(group_info为None表示私聊) + is_private_chat = msg.group_info is None + + # 如果被提及或是私聊,都视为提及了bot + if is_mentioned or is_private_chat: + logger.debug(f"🔍 提及检测 - 被提及: {is_mentioned}, 私聊: {is_private_chat}") + if is_private_chat and not is_mentioned: + logger.debug("💬 私聊消息自动视为提及bot") return global_config.affinity_flow.mention_bot_interest_score return 0.0 @@ -282,7 +296,9 @@ class InterestScoringSystem: logger.debug(f"📋 基础阈值: {base_threshold:.3f}") # 如果被提及,降低阈值 - if score.mentioned_score >= global_config.affinity_flow.mention_bot_interest_score * 0.5: # 使用提及bot兴趣分的一半作为判断阈值 + if ( + score.mentioned_score >= global_config.affinity_flow.mention_bot_interest_score * 0.5 + ): # 使用提及bot兴趣分的一半作为判断阈值 base_threshold = self.mention_threshold logger.debug(f"📣 消息提及了机器人,使用降低阈值: {base_threshold:.3f}") @@ -325,7 +341,9 @@ class InterestScoringSystem: def update_user_relationship(self, user_id: str, relationship_change: float): """更新用户关系""" - old_score = self.user_relationships.get(user_id, global_config.affinity_flow.base_relationship_score) # 默认新用户分数 + old_score = self.user_relationships.get( + user_id, global_config.affinity_flow.base_relationship_score + ) # 默认新用户分数 new_score = max(0.0, min(1.0, old_score + relationship_change)) self.user_relationships[user_id] = new_score diff --git a/src/chat/affinity_flow/relationship_tracker.py b/src/chat/affinity_flow/relationship_tracker.py index 8d5af05d0..49074bf93 100644 --- a/src/chat/affinity_flow/relationship_tracker.py +++ b/src/chat/affinity_flow/relationship_tracker.py @@ -116,6 +116,7 @@ class UserRelationshipTracker: try: # 获取bot人设信息 from src.individuality.individuality import Individuality + individuality = Individuality() bot_personality = await individuality.get_personality_block() @@ -168,7 +169,17 @@ class UserRelationshipTracker: # 清理LLM响应,移除可能的格式标记 cleaned_response = self._clean_llm_json_response(llm_response) response_data = json.loads(cleaned_response) - new_score = max(0.0, min(1.0, float(response_data.get("new_relationship_score", global_config.affinity_flow.base_relationship_score)))) + new_score = max( + 0.0, + min( + 1.0, + float( + response_data.get( + "new_relationship_score", global_config.affinity_flow.base_relationship_score + ) + ), + ), + ) if self.interest_scoring_system: self.interest_scoring_system.update_user_relationship( @@ -295,7 +306,9 @@ class UserRelationshipTracker: # 更新缓存 self.user_relationship_cache[user_id] = { "relationship_text": relationship_data.get("relationship_text", ""), - "relationship_score": relationship_data.get("relationship_score", global_config.affinity_flow.base_relationship_score), + "relationship_score": relationship_data.get( + "relationship_score", global_config.affinity_flow.base_relationship_score + ), "last_tracked": time.time(), } return relationship_data.get("relationship_score", global_config.affinity_flow.base_relationship_score) @@ -386,7 +399,11 @@ class UserRelationshipTracker: # 获取当前关系数据 current_relationship = self._get_user_relationship_from_db(user_id) - current_score = current_relationship.get("relationship_score", global_config.affinity_flow.base_relationship_score) if current_relationship else global_config.affinity_flow.base_relationship_score + current_score = ( + current_relationship.get("relationship_score", global_config.affinity_flow.base_relationship_score) + if current_relationship + else global_config.affinity_flow.base_relationship_score + ) current_text = current_relationship.get("relationship_text", "新用户") if current_relationship else "新用户" # 使用LLM分析并更新关系 @@ -501,6 +518,7 @@ class UserRelationshipTracker: # 获取bot人设信息 from src.individuality.individuality import Individuality + individuality = Individuality() bot_personality = await individuality.get_personality_block() diff --git a/src/chat/emoji_system/emoji_history.py b/src/chat/emoji_system/emoji_history.py index a25063f52..804f61e0a 100644 --- a/src/chat/emoji_system/emoji_history.py +++ b/src/chat/emoji_system/emoji_history.py @@ -2,6 +2,7 @@ """ 表情包发送历史记录模块 """ + import os from typing import List, Dict from collections import deque @@ -26,15 +27,15 @@ def add_emoji_to_history(chat_id: str, emoji_description: str): """ if not chat_id or not emoji_description: return - + # 如果当前聊天还没有历史记录,则创建一个新的 deque if chat_id not in _history_cache: _history_cache[chat_id] = deque(maxlen=MAX_HISTORY_SIZE) - + # 添加新表情到历史记录 history = _history_cache[chat_id] history.append(emoji_description) - + logger.debug(f"已将表情 '{emoji_description}' 添加到聊天 {chat_id} 的内存历史中") @@ -50,10 +51,10 @@ def get_recent_emojis(chat_id: str, limit: int = 5) -> List[str]: return [] history = _history_cache[chat_id] - + # 从 deque 的右侧(即最近添加的)开始取 num_to_get = min(limit, len(history)) recent_emojis = [history[-i] for i in range(1, num_to_get + 1)] - + logger.debug(f"为聊天 {chat_id} 从内存中获取到最近 {len(recent_emojis)} 个表情: {recent_emojis}") return recent_emojis diff --git a/src/chat/emoji_system/emoji_manager.py b/src/chat/emoji_system/emoji_manager.py index 8e6079897..b614345f0 100644 --- a/src/chat/emoji_system/emoji_manager.py +++ b/src/chat/emoji_system/emoji_manager.py @@ -477,7 +477,7 @@ class EmojiManager: emoji_options_str = "" for i, emoji in enumerate(candidate_emojis): # 为每个表情包创建一个编号和它的详细描述 - emoji_options_str += f"编号: {i+1}\n描述: {emoji.description}\n\n" + emoji_options_str += f"编号: {i + 1}\n描述: {emoji.description}\n\n" # 精心设计的prompt,引导LLM做出选择 prompt = f""" @@ -524,10 +524,8 @@ class EmojiManager: self.record_usage(selected_emoji.hash) _time_end = time.time() - logger.info( - f"找到匹配描述的表情包: {selected_emoji.description}, 耗时: {(_time_end - _time_start):.2f}s" - ) - + logger.info(f"找到匹配描述的表情包: {selected_emoji.description}, 耗时: {(_time_end - _time_start):.2f}s") + # 8. 返回选中的表情包信息 return selected_emoji.full_path, f"[表情包:{selected_emoji.description}]", text_emotion @@ -627,8 +625,9 @@ class EmojiManager: # 无论steal_emoji是否开启,都检查emoji文件夹以支持手动注册 # 只有在需要腾出空间或填充表情库时,才真正执行注册 - if (self.emoji_num > self.emoji_num_max and global_config.emoji.do_replace) or \ - (self.emoji_num < self.emoji_num_max): + if (self.emoji_num > self.emoji_num_max and global_config.emoji.do_replace) or ( + self.emoji_num < self.emoji_num_max + ): try: # 获取目录下所有图片文件 files_to_process = [ @@ -931,16 +930,21 @@ class EmojiManager: image_base64 = image_base64.encode("ascii", errors="ignore").decode("ascii") image_bytes = base64.b64decode(image_base64) image_hash = hashlib.md5(image_bytes).hexdigest() - image_format = Image.open(io.BytesIO(image_bytes)).format.lower() if Image.open(io.BytesIO(image_bytes)).format else "jpeg" - + image_format = ( + Image.open(io.BytesIO(image_bytes)).format.lower() + if Image.open(io.BytesIO(image_bytes)).format + else "jpeg" + ) # 2. 检查数据库中是否已存在该表情包的描述,实现复用 existing_description = None try: with get_db_session() as session: - existing_image = session.query(Images).filter( - (Images.emoji_hash == image_hash) & (Images.type == "emoji") - ).one_or_none() + existing_image = ( + session.query(Images) + .filter((Images.emoji_hash == image_hash) & (Images.type == "emoji")) + .one_or_none() + ) if existing_image and existing_image.description: existing_description = existing_image.description logger.info(f"[复用描述] 找到已有详细描述: {existing_description[:50]}...") diff --git a/src/chat/frequency_analyzer/analyzer.py b/src/chat/frequency_analyzer/analyzer.py index bd6331465..aa8141f59 100644 --- a/src/chat/frequency_analyzer/analyzer.py +++ b/src/chat/frequency_analyzer/analyzer.py @@ -14,6 +14,7 @@ Chat Frequency Analyzer - MIN_CHATS_FOR_PEAK: 在一个窗口内需要多少次聊天才能被认为是高峰时段。 - MIN_GAP_BETWEEN_PEAKS_HOURS: 两个独立高峰时段之间的最小间隔(小时)。 """ + import time as time_module from datetime import datetime, timedelta, time from typing import List, Tuple, Optional @@ -71,12 +72,14 @@ class ChatFrequencyAnalyzer: current_window_end = datetimes[i] # 合并重叠或相邻的高峰时段 - if peak_windows and current_window_start - peak_windows[-1][1] < timedelta(hours=MIN_GAP_BETWEEN_PEAKS_HOURS): + if peak_windows and current_window_start - peak_windows[-1][1] < timedelta( + hours=MIN_GAP_BETWEEN_PEAKS_HOURS + ): # 扩展上一个窗口的结束时间 peak_windows[-1] = (peak_windows[-1][0], current_window_end) else: peak_windows.append((current_window_start, current_window_end)) - + return peak_windows def get_peak_chat_times(self, chat_id: str) -> List[Tuple[time, time]]: @@ -99,7 +102,7 @@ class ChatFrequencyAnalyzer: return [] peak_datetime_windows = self._find_peak_windows(timestamps) - + # 将 datetime 窗口转换为 time 窗口,并进行归一化处理 peak_time_windows = [] for start_dt, end_dt in peak_datetime_windows: @@ -109,7 +112,7 @@ class ChatFrequencyAnalyzer: # 更新缓存 self._analysis_cache[chat_id] = (time_module.time(), peak_time_windows) - + return peak_time_windows def is_in_peak_time(self, chat_id: str, now: Optional[datetime] = None) -> bool: @@ -125,7 +128,7 @@ class ChatFrequencyAnalyzer: """ if now is None: now = datetime.now() - + now_time = now.time() peak_times = self.get_peak_chat_times(chat_id) @@ -136,7 +139,7 @@ class ChatFrequencyAnalyzer: else: # 跨天 if now_time >= start_time or now_time <= end_time: return True - + return False diff --git a/src/chat/frequency_analyzer/tracker.py b/src/chat/frequency_analyzer/tracker.py index bee9e4623..55b5add30 100644 --- a/src/chat/frequency_analyzer/tracker.py +++ b/src/chat/frequency_analyzer/tracker.py @@ -55,7 +55,7 @@ class ChatFrequencyTracker: now = time.time() if chat_id not in self._timestamps: self._timestamps[chat_id] = [] - + self._timestamps[chat_id].append(now) logger.debug(f"为 chat_id '{chat_id}' 记录了新的聊天时间: {now}") self._save_timestamps() diff --git a/src/chat/frequency_analyzer/trigger.py b/src/chat/frequency_analyzer/trigger.py index 1558c923a..156d300dd 100644 --- a/src/chat/frequency_analyzer/trigger.py +++ b/src/chat/frequency_analyzer/trigger.py @@ -14,6 +14,7 @@ Frequency-Based Proactive Trigger - TRIGGER_CHECK_INTERVAL_SECONDS: 触发器检查的周期(秒)。 - COOLDOWN_HOURS: 在同一个高峰时段内触发一次后的冷却时间(小时)。 """ + import asyncio import time from datetime import datetime @@ -21,6 +22,7 @@ from typing import Dict, Optional from src.common.logger import get_logger from src.chat.affinity_flow.afc_manager import afc_manager + # TODO: 需要重新实现主动思考和睡眠管理功能 from .analyzer import chat_frequency_analyzer @@ -65,7 +67,7 @@ class FrequencyBasedTrigger: continue now = datetime.now() - + for chat_id in all_chat_ids: # 3. 检查是否处于冷却时间内 last_triggered_time = self._last_triggered.get(chat_id, 0) @@ -74,7 +76,6 @@ class FrequencyBasedTrigger: # 4. 检查当前是否是该用户的高峰聊天时间 if chat_frequency_analyzer.is_in_peak_time(chat_id, now): - # 5. 检查用户当前是否已有活跃的处理任务 # 亲和力流系统不直接提供循环状态,通过检查最后活动时间来判断是否忙碌 chatter = afc_manager.get_or_create_chatter(chat_id) @@ -87,13 +88,13 @@ class FrequencyBasedTrigger: if current_time - chatter.get_activity_time() < 60: logger.debug(f"用户 {chat_id} 的亲和力处理器正忙,本次不触发。") continue - + logger.info(f"检测到用户 {chat_id} 处于聊天高峰期,且处理器空闲,准备触发主动思考。") - + # 6. TODO: 亲和力流系统的主动思考机制需要另行实现 # 目前先记录日志,等待后续实现 logger.info(f"用户 {chat_id} 处于高峰期,但亲和力流的主动思考功能暂未实现") - + # 7. 更新触发时间,进入冷却 self._last_triggered[chat_id] = time.time() diff --git a/src/chat/interest_system/__init__.py b/src/chat/interest_system/__init__.py index 3fe14e7bf..e64f25a2f 100644 --- a/src/chat/interest_system/__init__.py +++ b/src/chat/interest_system/__init__.py @@ -4,14 +4,12 @@ """ from .bot_interest_manager import BotInterestManager, bot_interest_manager -from src.common.data_models.bot_interest_data_model import ( - BotInterestTag, BotPersonalityInterests, InterestMatchResult -) +from src.common.data_models.bot_interest_data_model import BotInterestTag, BotPersonalityInterests, InterestMatchResult __all__ = [ "BotInterestManager", "bot_interest_manager", "BotInterestTag", "BotPersonalityInterests", - "InterestMatchResult" -] \ No newline at end of file + "InterestMatchResult", +] diff --git a/src/chat/interest_system/bot_interest_manager.py b/src/chat/interest_system/bot_interest_manager.py index c8bd4a004..abdc3563d 100644 --- a/src/chat/interest_system/bot_interest_manager.py +++ b/src/chat/interest_system/bot_interest_manager.py @@ -2,6 +2,7 @@ 机器人兴趣标签管理系统 基于人设生成兴趣标签,并使用embedding计算匹配度 """ + import orjson import traceback from typing import List, Dict, Optional, Any @@ -10,9 +11,7 @@ import numpy as np from src.common.logger import get_logger from src.config.config import global_config -from src.common.data_models.bot_interest_data_model import ( - BotPersonalityInterests, BotInterestTag, InterestMatchResult -) +from src.common.data_models.bot_interest_data_model import BotPersonalityInterests, BotInterestTag, InterestMatchResult logger = get_logger("bot_interest_manager") @@ -87,7 +86,7 @@ class BotInterestManager: logger.debug("✅ 成功导入embedding相关模块") # 检查embedding配置是否存在 - if not hasattr(model_config.model_task_config, 'embedding'): + if not hasattr(model_config.model_task_config, "embedding"): raise RuntimeError("❌ 未找到embedding模型配置") logger.info("📋 找到embedding模型配置") @@ -101,7 +100,7 @@ class BotInterestManager: logger.info(f"🔗 客户端类型: {type(self.embedding_request).__name__}") # 获取第一个embedding模型的ModelInfo - if hasattr(self.embedding_config, 'model_list') and self.embedding_config.model_list: + if hasattr(self.embedding_config, "model_list") and self.embedding_config.model_list: first_model_name = self.embedding_config.model_list[0] logger.info(f"🎯 使用embedding模型: {first_model_name}") else: @@ -127,7 +126,9 @@ class BotInterestManager: # 生成新的兴趣标签 logger.info("🆕 数据库中未找到兴趣标签,开始生成新的...") logger.info("🤖 正在调用LLM生成个性化兴趣标签...") - generated_interests = await self._generate_interests_from_personality(personality_description, personality_id) + generated_interests = await self._generate_interests_from_personality( + personality_description, personality_id + ) if generated_interests: self.current_interests = generated_interests @@ -140,14 +141,16 @@ class BotInterestManager: else: raise RuntimeError("❌ 兴趣标签生成失败") - async def _generate_interests_from_personality(self, personality_description: str, personality_id: str) -> Optional[BotPersonalityInterests]: + async def _generate_interests_from_personality( + self, personality_description: str, personality_id: str + ) -> Optional[BotPersonalityInterests]: """根据人设生成兴趣标签""" try: logger.info("🎨 开始根据人设生成兴趣标签...") logger.info(f"📝 人设长度: {len(personality_description)} 字符") # 检查embedding客户端是否可用 - if not hasattr(self, 'embedding_request'): + if not hasattr(self, "embedding_request"): raise RuntimeError("❌ Embedding客户端未初始化,无法生成兴趣标签") # 构建提示词 @@ -190,8 +193,7 @@ class BotInterestManager: interests_data = orjson.loads(response) bot_interests = BotPersonalityInterests( - personality_id=personality_id, - personality_description=personality_description + personality_id=personality_id, personality_description=personality_description ) # 解析生成的兴趣标签 @@ -202,10 +204,7 @@ class BotInterestManager: tag_name = tag_data.get("name", f"标签_{i}") weight = tag_data.get("weight", 0.5) - tag = BotInterestTag( - tag_name=tag_name, - weight=weight - ) + tag = BotInterestTag(tag_name=tag_name, weight=weight) bot_interests.interest_tags.append(tag) logger.debug(f" 🏷️ {tag_name} (权重: {weight:.2f})") @@ -225,7 +224,6 @@ class BotInterestManager: traceback.print_exc() raise - async def _call_llm_for_interest_generation(self, prompt: str) -> Optional[str]: """调用LLM生成兴趣标签""" try: @@ -241,10 +239,10 @@ class BotInterestManager: {prompt} 请确保返回格式为有效的JSON,不要包含任何额外的文本、解释或代码块标记。只返回JSON对象本身。""" - + # 使用replyer模型配置 replyer_config = model_config.model_task_config.replyer - + # 调用LLM API logger.info("🚀 正在通过LLM API发送请求...") success, response, reasoning_content, model_name = await llm_api.generate_with_model( @@ -252,15 +250,17 @@ class BotInterestManager: model_config=replyer_config, request_type="interest_generation", temperature=0.7, - max_tokens=2000 + max_tokens=2000, ) if success and response: logger.info(f"✅ LLM响应成功,模型: {model_name}, 响应长度: {len(response)} 字符") - logger.debug(f"📄 LLM响应内容: {response[:200]}..." if len(response) > 200 else f"📄 LLM响应内容: {response}") + logger.debug( + f"📄 LLM响应内容: {response[:200]}..." if len(response) > 200 else f"📄 LLM响应内容: {response}" + ) if reasoning_content: logger.debug(f"🧠 推理内容: {reasoning_content[:100]}...") - + # 清理响应内容,移除可能的代码块标记 cleaned_response = self._clean_llm_response(response) return cleaned_response @@ -277,25 +277,25 @@ class BotInterestManager: def _clean_llm_response(self, response: str) -> str: """清理LLM响应,移除代码块标记和其他非JSON内容""" import re - + # 移除 ```json 和 ``` 标记 - cleaned = re.sub(r'```json\s*', '', response) - cleaned = re.sub(r'\s*```', '', cleaned) - + cleaned = re.sub(r"```json\s*", "", response) + cleaned = re.sub(r"\s*```", "", cleaned) + # 移除可能的多余空格和换行 cleaned = cleaned.strip() - + # 尝试提取JSON对象(如果响应中有其他文本) - json_match = re.search(r'\{.*\}', cleaned, re.DOTALL) + json_match = re.search(r"\{.*\}", cleaned, re.DOTALL) if json_match: cleaned = json_match.group(0) - + logger.debug(f"🧹 清理后的响应: {cleaned[:200]}..." if len(cleaned) > 200 else f"🧹 清理后的响应: {cleaned}") return cleaned async def _generate_embeddings_for_tags(self, interests: BotPersonalityInterests): """为所有兴趣标签生成embedding""" - if not hasattr(self, 'embedding_request'): + if not hasattr(self, "embedding_request"): raise RuntimeError("❌ Embedding客户端未初始化,无法生成embedding") total_tags = len(interests.interest_tags) @@ -342,7 +342,7 @@ class BotInterestManager: async def _get_embedding(self, text: str) -> List[float]: """获取文本的embedding向量""" - if not hasattr(self, 'embedding_request'): + if not hasattr(self, "embedding_request"): raise RuntimeError("❌ Embedding请求客户端未初始化") # 检查缓存 @@ -376,7 +376,9 @@ class BotInterestManager: logger.debug(f"✅ 消息embedding生成成功,维度: {len(embedding)}") return embedding - async def _calculate_similarity_scores(self, result: InterestMatchResult, message_embedding: List[float], keywords: List[str]): + async def _calculate_similarity_scores( + self, result: InterestMatchResult, message_embedding: List[float], keywords: List[str] + ): """计算消息与兴趣标签的相似度分数""" try: if not self.current_interests: @@ -397,7 +399,9 @@ class BotInterestManager: # 设置相似度阈值为0.3 if similarity > 0.3: result.add_match(tag.tag_name, weighted_score, keywords) - logger.debug(f" 🏷️ '{tag.tag_name}': 相似度={similarity:.3f}, 权重={tag.weight:.2f}, 加权分数={weighted_score:.3f}") + logger.debug( + f" 🏷️ '{tag.tag_name}': 相似度={similarity:.3f}, 权重={tag.weight:.2f}, 加权分数={weighted_score:.3f}" + ) except Exception as e: logger.error(f"❌ 计算相似度分数失败: {e}") @@ -455,7 +459,9 @@ class BotInterestManager: match_count += 1 high_similarity_count += 1 result.add_match(tag.tag_name, enhanced_score, [tag.tag_name]) - logger.debug(f" 🏷️ '{tag.tag_name}': 相似度={similarity:.3f}, 权重={tag.weight:.2f}, 基础分数={weighted_score:.3f}, 增强分数={enhanced_score:.3f} [高匹配]") + logger.debug( + f" 🏷️ '{tag.tag_name}': 相似度={similarity:.3f}, 权重={tag.weight:.2f}, 基础分数={weighted_score:.3f}, 增强分数={enhanced_score:.3f} [高匹配]" + ) elif similarity > medium_threshold: # 中相似度:中等加成 @@ -463,7 +469,9 @@ class BotInterestManager: match_count += 1 medium_similarity_count += 1 result.add_match(tag.tag_name, enhanced_score, [tag.tag_name]) - logger.debug(f" 🏷️ '{tag.tag_name}': 相似度={similarity:.3f}, 权重={tag.weight:.2f}, 基础分数={weighted_score:.3f}, 增强分数={enhanced_score:.3f} [中匹配]") + logger.debug( + f" 🏷️ '{tag.tag_name}': 相似度={similarity:.3f}, 权重={tag.weight:.2f}, 基础分数={weighted_score:.3f}, 增强分数={enhanced_score:.3f} [中匹配]" + ) elif similarity > low_threshold: # 低相似度:轻微加成 @@ -471,7 +479,9 @@ class BotInterestManager: match_count += 1 low_similarity_count += 1 result.add_match(tag.tag_name, enhanced_score, [tag.tag_name]) - logger.debug(f" 🏷️ '{tag.tag_name}': 相似度={similarity:.3f}, 权重={tag.weight:.2f}, 基础分数={weighted_score:.3f}, 增强分数={enhanced_score:.3f} [低匹配]") + logger.debug( + f" 🏷️ '{tag.tag_name}': 相似度={similarity:.3f}, 权重={tag.weight:.2f}, 基础分数={weighted_score:.3f}, 增强分数={enhanced_score:.3f} [低匹配]" + ) logger.info(f"📈 匹配统计: {match_count}/{len(active_tags)} 个标签超过阈值") logger.info(f"🔥 高相似度匹配(>{high_threshold}): {high_similarity_count} 个") @@ -488,7 +498,9 @@ class BotInterestManager: original_score = result.match_scores[tag_name] bonus = keyword_bonus[tag_name] result.match_scores[tag_name] = original_score + bonus - logger.debug(f" 🏷️ '{tag_name}': 原始分数={original_score:.3f}, 奖励={bonus:.3f}, 最终分数={result.match_scores[tag_name]:.3f}") + logger.debug( + f" 🏷️ '{tag_name}': 原始分数={original_score:.3f}, 奖励={bonus:.3f}, 最终分数={result.match_scores[tag_name]:.3f}" + ) # 计算总体分数 result.calculate_overall_score() @@ -499,10 +511,11 @@ class BotInterestManager: result.top_tag = top_tag_name logger.info(f"🏆 最佳匹配标签: '{top_tag_name}' (分数: {result.match_scores[top_tag_name]:.3f})") - logger.info(f"📊 最终结果: 总分={result.overall_score:.3f}, 置信度={result.confidence:.3f}, 匹配标签数={len(result.matched_tags)}") + logger.info( + f"📊 最终结果: 总分={result.overall_score:.3f}, 置信度={result.confidence:.3f}, 匹配标签数={len(result.matched_tags)}" + ) return result - def _calculate_keyword_match_bonus(self, keywords: List[str], matched_tags: List[str]) -> Dict[str, float]: """计算关键词直接匹配奖励""" if not keywords or not matched_tags: @@ -522,17 +535,25 @@ class BotInterestManager: # 完全匹配 if keyword_lower == tag_name_lower: bonus += affinity_config.high_match_interest_threshold * 0.6 # 使用高匹配阈值的60%作为完全匹配奖励 - logger.debug(f" 🎯 关键词完全匹配: '{keyword}' == '{tag_name}' (+{affinity_config.high_match_interest_threshold * 0.6:.3f})") + logger.debug( + f" 🎯 关键词完全匹配: '{keyword}' == '{tag_name}' (+{affinity_config.high_match_interest_threshold * 0.6:.3f})" + ) # 包含匹配 elif keyword_lower in tag_name_lower or tag_name_lower in keyword_lower: - bonus += affinity_config.medium_match_interest_threshold * 0.3 # 使用中匹配阈值的30%作为包含匹配奖励 - logger.debug(f" 🎯 关键词包含匹配: '{keyword}' ⊃ '{tag_name}' (+{affinity_config.medium_match_interest_threshold * 0.3:.3f})") + bonus += ( + affinity_config.medium_match_interest_threshold * 0.3 + ) # 使用中匹配阈值的30%作为包含匹配奖励 + logger.debug( + f" 🎯 关键词包含匹配: '{keyword}' ⊃ '{tag_name}' (+{affinity_config.medium_match_interest_threshold * 0.3:.3f})" + ) # 部分匹配(编辑距离) elif self._calculate_partial_match(keyword_lower, tag_name_lower): bonus += affinity_config.low_match_interest_threshold * 0.4 # 使用低匹配阈值的40%作为部分匹配奖励 - logger.debug(f" 🎯 关键词部分匹配: '{keyword}' ≈ '{tag_name}' (+{affinity_config.low_match_interest_threshold * 0.4:.3f})") + logger.debug( + f" 🎯 关键词部分匹配: '{keyword}' ≈ '{tag_name}' (+{affinity_config.low_match_interest_threshold * 0.4:.3f})" + ) if bonus > 0: bonus_dict[tag_name] = min(bonus, affinity_config.max_match_bonus) # 使用配置的最大奖励限制 @@ -608,12 +629,12 @@ class BotInterestManager: with get_db_session() as session: # 查询最新的兴趣标签配置 - db_interests = session.query(DBBotPersonalityInterests).filter( - DBBotPersonalityInterests.personality_id == personality_id - ).order_by( - DBBotPersonalityInterests.version.desc(), - DBBotPersonalityInterests.last_updated.desc() - ).first() + db_interests = ( + session.query(DBBotPersonalityInterests) + .filter(DBBotPersonalityInterests.personality_id == personality_id) + .order_by(DBBotPersonalityInterests.version.desc(), DBBotPersonalityInterests.last_updated.desc()) + .first() + ) if db_interests: logger.info(f"✅ 找到数据库中的兴趣标签配置,版本: {db_interests.version}") @@ -631,7 +652,7 @@ class BotInterestManager: personality_description=db_interests.personality_description, embedding_model=db_interests.embedding_model, version=db_interests.version, - last_updated=db_interests.last_updated + last_updated=db_interests.last_updated, ) # 解析兴趣标签 @@ -639,10 +660,14 @@ class BotInterestManager: tag = BotInterestTag( tag_name=tag_data.get("tag_name", ""), weight=tag_data.get("weight", 0.5), - created_at=datetime.fromisoformat(tag_data.get("created_at", datetime.now().isoformat())), - updated_at=datetime.fromisoformat(tag_data.get("updated_at", datetime.now().isoformat())), + created_at=datetime.fromisoformat( + tag_data.get("created_at", datetime.now().isoformat()) + ), + updated_at=datetime.fromisoformat( + tag_data.get("updated_at", datetime.now().isoformat()) + ), is_active=tag_data.get("is_active", True), - embedding=tag_data.get("embedding") + embedding=tag_data.get("embedding"), ) interests.interest_tags.append(tag) @@ -685,7 +710,7 @@ class BotInterestManager: "created_at": tag.created_at.isoformat(), "updated_at": tag.updated_at.isoformat(), "is_active": tag.is_active, - "embedding": tag.embedding + "embedding": tag.embedding, } tags_data.append(tag_dict) @@ -694,9 +719,11 @@ class BotInterestManager: with get_db_session() as session: # 检查是否已存在相同personality_id的记录 - existing_record = session.query(DBBotPersonalityInterests).filter( - DBBotPersonalityInterests.personality_id == interests.personality_id - ).first() + existing_record = ( + session.query(DBBotPersonalityInterests) + .filter(DBBotPersonalityInterests.personality_id == interests.personality_id) + .first() + ) if existing_record: # 更新现有记录 @@ -718,7 +745,7 @@ class BotInterestManager: interest_tags=json_data, embedding_model=interests.embedding_model, version=interests.version, - last_updated=interests.last_updated + last_updated=interests.last_updated, ) session.add(new_record) session.commit() @@ -728,9 +755,11 @@ class BotInterestManager: # 验证保存是否成功 with get_db_session() as session: - saved_record = session.query(DBBotPersonalityInterests).filter( - DBBotPersonalityInterests.personality_id == interests.personality_id - ).first() + saved_record = ( + session.query(DBBotPersonalityInterests) + .filter(DBBotPersonalityInterests.personality_id == interests.personality_id) + .first() + ) session.commit() if saved_record: logger.info(f"✅ 验证成功:数据库中存在personality_id为 {interests.personality_id} 的记录") @@ -760,7 +789,7 @@ class BotInterestManager: "total_tags": len(active_tags), "embedding_model": self.current_interests.embedding_model, "last_updated": self.current_interests.last_updated.isoformat(), - "cache_size": len(self.embedding_cache) + "cache_size": len(self.embedding_cache), } async def update_interest_tags(self, new_personality_description: str = None): @@ -775,8 +804,7 @@ class BotInterestManager: # 重新生成兴趣标签 new_interests = await self._generate_interests_from_personality( - self.current_interests.personality_description, - self.current_interests.personality_id + self.current_interests.personality_description, self.current_interests.personality_id ) if new_interests: @@ -791,4 +819,4 @@ class BotInterestManager: # 创建全局实例(重新创建以包含新的属性) -bot_interest_manager = BotInterestManager() \ No newline at end of file +bot_interest_manager = BotInterestManager() diff --git a/src/chat/message_manager/__init__.py b/src/chat/message_manager/__init__.py index c52e7f1b2..f909f720a 100644 --- a/src/chat/message_manager/__init__.py +++ b/src/chat/message_manager/__init__.py @@ -4,13 +4,11 @@ """ from .message_manager import MessageManager, message_manager -from src.common.data_models.message_manager_data_model import StreamContext, MessageStatus, MessageManagerStats, StreamStats +from src.common.data_models.message_manager_data_model import ( + StreamContext, + MessageStatus, + MessageManagerStats, + StreamStats, +) -__all__ = [ - "MessageManager", - "message_manager", - "StreamContext", - "MessageStatus", - "MessageManagerStats", - "StreamStats" -] \ No newline at end of file +__all__ = ["MessageManager", "message_manager", "StreamContext", "MessageStatus", "MessageManagerStats", "StreamStats"] diff --git a/src/chat/message_manager/message_manager.py b/src/chat/message_manager/message_manager.py index 2f12112ca..b660beba6 100644 --- a/src/chat/message_manager/message_manager.py +++ b/src/chat/message_manager/message_manager.py @@ -2,6 +2,7 @@ 消息管理模块 管理每个聊天流的上下文信息,包含历史记录和未读消息,定期检查并处理新消息 """ + import asyncio import time import traceback @@ -100,9 +101,7 @@ class MessageManager: # 如果没有处理任务,创建一个 if not context.processing_task or context.processing_task.done(): - context.processing_task = asyncio.create_task( - self._process_stream_messages(stream_id) - ) + context.processing_task = asyncio.create_task(self._process_stream_messages(stream_id)) # 更新统计 self.stats.active_streams = active_streams @@ -128,11 +127,11 @@ class MessageManager: try: # 发送到AFC处理器,传递StreamContext对象 results = await afc_manager.process_stream_context(stream_id, context) - + # 处理结果,标记消息为已读 if results.get("success", False): self._clear_all_unread_messages(context) - + except Exception as e: logger.error(f"处理聊天流 {stream_id} 时发生异常,将清除所有未读消息: {e}") raise @@ -175,7 +174,7 @@ class MessageManager: unread_count=len(context.get_unread_messages()), history_count=len(context.history_messages), last_check_time=context.last_check_time, - has_active_task=context.processing_task and not context.processing_task.done() + has_active_task=context.processing_task and not context.processing_task.done(), ) def get_manager_stats(self) -> Dict[str, Any]: @@ -186,7 +185,7 @@ class MessageManager: "total_unread_messages": self.stats.total_unread_messages, "total_processed_messages": self.stats.total_processed_messages, "uptime": self.stats.uptime, - "start_time": self.stats.start_time + "start_time": self.stats.start_time, } def cleanup_inactive_streams(self, max_inactive_hours: int = 24): @@ -196,8 +195,7 @@ class MessageManager: inactive_streams = [] for stream_id, context in self.stream_contexts.items(): - if (current_time - context.last_check_time > max_inactive_seconds and - not context.get_unread_messages()): + if current_time - context.last_check_time > max_inactive_seconds and not context.get_unread_messages(): inactive_streams.append(stream_id) for stream_id in inactive_streams: @@ -210,9 +208,9 @@ class MessageManager: unread_messages = context.get_unread_messages() if not unread_messages: return - + logger.warning(f"正在清除 {len(unread_messages)} 条未读消息") - + # 将所有未读消息标记为已读并移动到历史记录 for msg in unread_messages[:]: # 使用切片复制避免迭代时修改列表 try: @@ -224,4 +222,4 @@ class MessageManager: # 创建全局消息管理器实例 -message_manager = MessageManager() \ No newline at end of file +message_manager = MessageManager() diff --git a/src/chat/message_receive/bot.py b/src/chat/message_receive/bot.py index f3fff8bb9..d0c1146e6 100644 --- a/src/chat/message_receive/bot.py +++ b/src/chat/message_receive/bot.py @@ -17,6 +17,7 @@ from src.plugin_system.core import component_registry, event_manager, global_ann from src.plugin_system.base import BaseCommand, EventType from src.mais4u.mais4u_chat.s4u_msg_processor import S4UMessageProcessor from src.chat.utils.utils import is_mentioned_bot_in_message + # 导入反注入系统 from src.chat.antipromptinjector import initialize_anti_injector @@ -511,7 +512,7 @@ class ChatBot: chat_info_user_id=message.chat_stream.user_info.user_id, chat_info_user_nickname=message.chat_stream.user_info.user_nickname, chat_info_user_cardname=message.chat_stream.user_info.user_cardname, - chat_info_user_platform=message.chat_stream.user_info.platform + chat_info_user_platform=message.chat_stream.user_info.platform, ) # 如果是群聊,添加群组信息 diff --git a/src/chat/message_receive/chat_stream.py b/src/chat/message_receive/chat_stream.py index f5822acfb..63ec0346e 100644 --- a/src/chat/message_receive/chat_stream.py +++ b/src/chat/message_receive/chat_stream.py @@ -84,7 +84,9 @@ class ChatStream: self.saved = False self.context: ChatMessageContext = None # type: ignore # 用于存储该聊天的上下文信息 # 从配置文件中读取focus_value,如果没有则使用默认值1.0 - self.focus_energy = data.get("focus_energy", global_config.chat.focus_value) if data else global_config.chat.focus_value + self.focus_energy = ( + data.get("focus_energy", global_config.chat.focus_value) if data else global_config.chat.focus_value + ) self.no_reply_consecutive = 0 self.breaking_accumulated_interest = 0.0 diff --git a/src/chat/planner_actions/action_manager.py b/src/chat/planner_actions/action_manager.py index 8f09894c3..ba6196804 100644 --- a/src/chat/planner_actions/action_manager.py +++ b/src/chat/planner_actions/action_manager.py @@ -165,10 +165,15 @@ class ActionManager: # 通过chat_id获取chat_stream chat_manager = get_chat_manager() chat_stream = chat_manager.get_stream(chat_id) - + if not chat_stream: logger.error(f"{log_prefix} 无法找到chat_id对应的chat_stream: {chat_id}") - return {"action_type": action_name, "success": False, "reply_text": "", "error": "chat_stream not found"} + return { + "action_type": action_name, + "success": False, + "reply_text": "", + "error": "chat_stream not found", + } if action_name == "no_action": return {"action_type": "no_action", "success": True, "reply_text": "", "command": ""} @@ -177,7 +182,7 @@ class ActionManager: # 直接处理no_reply逻辑,不再通过动作系统 reason = reasoning or "选择不回复" logger.info(f"{log_prefix} 选择不回复,原因: {reason}") - + # 存储no_reply信息到数据库 await database_api.store_action_info( chat_stream=chat_stream, @@ -396,7 +401,7 @@ class ActionManager: } return loop_info, reply_text, cycle_timers - + async def send_response(self, chat_stream, reply_set, thinking_start_time, message_data) -> str: """ 发送回复内容的具体实现 @@ -471,4 +476,4 @@ class ActionManager: typing=True, ) - return reply_text \ No newline at end of file + return reply_text diff --git a/src/chat/planner_actions/plan_generator.py b/src/chat/planner_actions/plan_generator.py index 49ef38a18..26e05fcf1 100644 --- a/src/chat/planner_actions/plan_generator.py +++ b/src/chat/planner_actions/plan_generator.py @@ -1,6 +1,7 @@ """ PlanGenerator: 负责搜集和汇总所有决策所需的信息,生成一个未经筛选的“原始计划” (Plan)。 """ + import time from typing import Dict @@ -35,6 +36,7 @@ class PlanGenerator: chat_id (str): 当前聊天的 ID。 """ from src.chat.planner_actions.action_manager import ActionManager + self.chat_id = chat_id # 注意:ActionManager 可能需要根据实际情况初始化 self.action_manager = ActionManager() @@ -52,7 +54,7 @@ class PlanGenerator: Plan: 一个填充了初始上下文信息的 Plan 对象。 """ _is_group_chat, chat_target_info_dict = get_chat_type_and_target_info(self.chat_id) - + target_info = None if chat_target_info_dict: target_info = TargetPersonInfo(**chat_target_info_dict) @@ -65,7 +67,6 @@ class PlanGenerator: ) chat_history = [DatabaseMessages(**msg) for msg in chat_history_raw] - plan = Plan( chat_id=self.chat_id, mode=mode, @@ -86,10 +87,10 @@ class PlanGenerator: Dict[str, "ActionInfo"]: 一个字典,键是动作名称,值是 ActionInfo 对象。 """ current_available_actions_dict = self.action_manager.get_using_actions() - all_registered_actions: Dict[str, ActionInfo] = component_registry.get_components_by_type( # type: ignore + all_registered_actions: Dict[str, ActionInfo] = component_registry.get_components_by_type( # type: ignore ComponentType.ACTION ) - + current_available_actions = {} for action_name in current_available_actions_dict: if action_name in all_registered_actions: @@ -99,16 +100,13 @@ class PlanGenerator: name="reply", component_type=ComponentType.ACTION, description="系统级动作:选择回复消息的决策", - action_parameters={ - "content": "回复的文本内容", - "reply_to_message_id": "要回复的消息ID" - }, + action_parameters={"content": "回复的文本内容", "reply_to_message_id": "要回复的消息ID"}, action_require=[ "你想要闲聊或者随便附和", "当用户提到你或艾特你时", "当需要回答用户的问题时", "当你想参与对话时", - "当用户分享有趣的内容时" + "当用户分享有趣的内容时", ], activation_type=ActionActivationType.ALWAYS, activation_keywords=[], @@ -131,4 +129,4 @@ class PlanGenerator: ) current_available_actions["no_reply"] = no_reply_info current_available_actions["reply"] = reply_info - return current_available_actions \ No newline at end of file + return current_available_actions diff --git a/src/chat/planner_actions/planner.py b/src/chat/planner_actions/planner.py index f39ee67a0..d648c2292 100644 --- a/src/chat/planner_actions/planner.py +++ b/src/chat/planner_actions/planner.py @@ -109,9 +109,7 @@ class ActionPlanner: self.planner_stats["failed_plans"] += 1 return [], None - async def _enhanced_plan_flow( - self, mode: ChatMode, context: StreamContext - ) -> Tuple[List[Dict], Optional[Dict]]: + async def _enhanced_plan_flow(self, mode: ChatMode, context: StreamContext) -> Tuple[List[Dict], Optional[Dict]]: """执行增强版规划流程""" try: # 1. 生成初始 Plan @@ -137,7 +135,9 @@ class ActionPlanner: # 检查兴趣度是否达到非回复动作阈值 non_reply_action_interest_threshold = global_config.affinity_flow.non_reply_action_interest_threshold if score < non_reply_action_interest_threshold: - logger.info(f"❌ 兴趣度不足非回复动作阈值: {score:.3f} < {non_reply_action_interest_threshold:.3f},直接返回no_action") + logger.info( + f"❌ 兴趣度不足非回复动作阈值: {score:.3f} < {non_reply_action_interest_threshold:.3f},直接返回no_action" + ) logger.info(f"📊 最低要求: {non_reply_action_interest_threshold:.3f}") # 直接返回 no_action from src.common.data_models.info_data_model import ActionPlannerInfo diff --git a/src/chat/replyer/default_generator.py b/src/chat/replyer/default_generator.py index e837964ca..7f9b1c501 100644 --- a/src/chat/replyer/default_generator.py +++ b/src/chat/replyer/default_generator.py @@ -303,7 +303,7 @@ class DefaultReplyer: "model": model_name, "tool_calls": tool_call, } - + # 触发 AFTER_LLM 事件 if not from_plugin: result = await event_manager.trigger_event( @@ -598,6 +598,7 @@ class DefaultReplyer: def _parse_reply_target(self, target_message: str) -> Tuple[str, str]: """解析回复目标消息 - 使用共享工具""" from src.chat.utils.prompt import Prompt + if target_message is None: logger.warning("target_message为None,返回默认值") return "未知用户", "(无消息内容)" @@ -704,22 +705,24 @@ class DefaultReplyer: unread_history_prompt = "" if unread_messages: # 尝试获取兴趣度评分 - interest_scores = await self._get_interest_scores_for_messages([msg.flatten() for msg in unread_messages]) + interest_scores = await self._get_interest_scores_for_messages( + [msg.flatten() for msg in unread_messages] + ) unread_lines = [] for msg in unread_messages: msg_id = msg.message_id - msg_time = time.strftime('%H:%M:%S', time.localtime(msg.time)) + msg_time = time.strftime("%H:%M:%S", time.localtime(msg.time)) msg_content = msg.processed_plain_text - + # 使用与已读历史消息相同的方法获取用户名 from src.person_info.person_info import PersonInfoManager, get_person_info_manager - + # 获取用户信息 - user_info = getattr(msg, 'user_info', {}) - platform = getattr(user_info, 'platform', '') or getattr(msg, 'platform', '') - user_id = getattr(user_info, 'user_id', '') or getattr(msg, 'user_id', '') - + user_info = getattr(msg, "user_info", {}) + platform = getattr(user_info, "platform", "") or getattr(msg, "platform", "") + user_id = getattr(user_info, "user_id", "") or getattr(msg, "user_id", "") + # 获取用户名 if platform and user_id: person_id = PersonInfoManager.get_person_id(platform, user_id) @@ -727,11 +730,11 @@ class DefaultReplyer: sender_name = person_info_manager.get_value_sync(person_id, "person_name") or "未知用户" else: sender_name = "未知用户" - + # 添加兴趣度信息 interest_score = interest_scores.get(msg_id, 0.0) interest_text = f" [兴趣度: {interest_score:.3f}]" if interest_score > 0 else "" - + unread_lines.append(f"{msg_time} {sender_name}: {msg_content}{interest_text}") unread_history_prompt_str = "\n".join(unread_lines) @@ -808,17 +811,17 @@ class DefaultReplyer: unread_lines = [] for msg in unread_messages: msg_id = msg.get("message_id", "") - msg_time = time.strftime('%H:%M:%S', time.localtime(msg.get("time", time.time()))) + msg_time = time.strftime("%H:%M:%S", time.localtime(msg.get("time", time.time()))) msg_content = msg.get("processed_plain_text", "") - + # 使用与已读历史消息相同的方法获取用户名 from src.person_info.person_info import PersonInfoManager, get_person_info_manager - + # 获取用户信息 user_info = msg.get("user_info", {}) platform = user_info.get("platform") or msg.get("platform", "") user_id = user_info.get("user_id") or msg.get("user_id", "") - + # 获取用户名 if platform and user_id: person_id = PersonInfoManager.get_person_id(platform, user_id) @@ -834,7 +837,9 @@ class DefaultReplyer: unread_lines.append(f"{msg_time} {sender_name}: {msg_content}{interest_text}") unread_history_prompt_str = "\n".join(unread_lines) - unread_history_prompt = f"这是未读历史消息,包含兴趣度评分,请优先对兴趣值高的消息做出动作:\n{unread_history_prompt_str}" + unread_history_prompt = ( + f"这是未读历史消息,包含兴趣度评分,请优先对兴趣值高的消息做出动作:\n{unread_history_prompt_str}" + ) else: unread_history_prompt = "暂无未读历史消息" @@ -982,7 +987,7 @@ class DefaultReplyer: reply_message.get("user_id"), # type: ignore ) person_name = await person_info_manager.get_value(person_id, "person_name") - + # 如果person_name为None,使用fallback值 if person_name is None: # 尝试从reply_message获取用户名 @@ -990,12 +995,12 @@ class DefaultReplyer: logger.warning(f"未知用户,将存储用户信息:{fallback_name}") person_name = str(fallback_name) person_info_manager.set_value(person_id, "person_name", fallback_name) - + # 检查是否是bot自己的名字,如果是则替换为"(你)" bot_user_id = str(global_config.bot.qq_account) current_user_id = person_info_manager.get_value_sync(person_id, "user_id") current_platform = reply_message.get("chat_info_platform") - + if current_user_id == bot_user_id and current_platform == global_config.bot.platform: sender = f"{person_name}(你)" else: @@ -1050,8 +1055,9 @@ class DefaultReplyer: target_user_info = None if sender: target_user_info = await person_info_manager.get_person_info_by_name(sender) - + from src.chat.utils.prompt import Prompt + # 并行执行六个构建任务 task_results = await asyncio.gather( self._time_and_run_task( @@ -1127,6 +1133,7 @@ class DefaultReplyer: schedule_block = "" if global_config.planning_system.schedule_enable: from src.schedule.schedule_manager import schedule_manager + current_activity = schedule_manager.get_current_activity() if current_activity: schedule_block = f"你当前正在:{current_activity}。" @@ -1139,7 +1146,7 @@ class DefaultReplyer: safety_guidelines = global_config.personality.safety_guidelines safety_guidelines_block = "" if safety_guidelines: - guidelines_text = "\n".join(f"{i+1}. {line}" for i, line in enumerate(safety_guidelines)) + guidelines_text = "\n".join(f"{i + 1}. {line}" for i, line in enumerate(safety_guidelines)) safety_guidelines_block = f"""### 安全与互动底线 在任何情况下,你都必须遵守以下由你的设定者为你定义的原则: {guidelines_text} @@ -1212,7 +1219,7 @@ class DefaultReplyer: template_name = "normal_style_prompt" elif current_prompt_mode == "minimal": template_name = "default_expressor_prompt" - + # 获取模板内容 template_prompt = await global_prompt_manager.get_prompt_async(template_name) prompt = Prompt(template=template_prompt.template, parameters=prompt_parameters) @@ -1488,19 +1495,19 @@ class DefaultReplyer: # 使用AFC关系追踪器获取关系信息 try: from src.chat.affinity_flow.relationship_integration import get_relationship_tracker - + relationship_tracker = get_relationship_tracker() if relationship_tracker: # 获取用户信息以获取真实的user_id user_info = await person_info_manager.get_values(person_id, ["user_id", "platform"]) user_id = user_info.get("user_id", "unknown") - + # 从数据库获取关系数据 relationship_data = relationship_tracker._get_user_relationship_from_db(user_id) if relationship_data: relationship_text = relationship_data.get("relationship_text", "") relationship_score = relationship_data.get("relationship_score", 0.3) - + # 构建丰富的关系信息描述 if relationship_text: # 转换关系分数为描述性文本 @@ -1514,7 +1521,7 @@ class DefaultReplyer: relationship_level = "认识的人" else: relationship_level = "陌生人" - + return f"你与{sender}的关系:{relationship_level}(关系分:{relationship_score:.2f}/1.0)。{relationship_text}" else: return f"你与{sender}是初次见面,关系分:{relationship_score:.2f}/1.0。" @@ -1523,7 +1530,7 @@ class DefaultReplyer: else: logger.warning("AFC关系追踪器未初始化,使用默认关系信息") return f"你与{sender}是普通朋友关系。" - + except Exception as e: logger.error(f"获取AFC关系信息失败: {e}") return f"你与{sender}是普通朋友关系。" diff --git a/src/chat/utils/chat_message_builder.py b/src/chat/utils/chat_message_builder.py index 83b1b0587..41c2f2ed9 100644 --- a/src/chat/utils/chat_message_builder.py +++ b/src/chat/utils/chat_message_builder.py @@ -37,7 +37,7 @@ def replace_user_references_sync( """ if not content: return "" - + if name_resolver is None: person_info_manager = get_person_info_manager() @@ -821,7 +821,7 @@ def build_pic_mapping_info(pic_id_mapping: Dict[str, str]) -> str: try: with get_db_session() as session: image = session.execute(select(Images).where(Images.image_id == pic_id)).scalar_one_or_none() - if image and image.description: # type: ignore + if image and image.description: # type: ignore description = image.description except Exception: # 如果查询失败,保持默认描述 diff --git a/src/chat/utils/prompt.py b/src/chat/utils/prompt.py index fa73f9538..dd6010937 100644 --- a/src/chat/utils/prompt.py +++ b/src/chat/utils/prompt.py @@ -25,7 +25,7 @@ logger = get_logger("unified_prompt") @dataclass class PromptParameters: """统一提示词参数系统""" - + # 基础参数 chat_id: str = "" is_group_chat: bool = False @@ -34,7 +34,7 @@ class PromptParameters: reply_to: str = "" extra_info: str = "" prompt_mode: Literal["s4u", "normal", "minimal"] = "s4u" - + # 功能开关 enable_tool: bool = True enable_memory: bool = True @@ -42,20 +42,20 @@ class PromptParameters: enable_relation: bool = True enable_cross_context: bool = True enable_knowledge: bool = True - + # 性能控制 max_context_messages: int = 50 - + # 调试选项 debug_mode: bool = False - + # 聊天历史和上下文 chat_target_info: Optional[Dict[str, Any]] = None message_list_before_now_long: List[Dict[str, Any]] = field(default_factory=list) message_list_before_short: List[Dict[str, Any]] = field(default_factory=list) chat_talking_prompt_short: str = "" target_user_info: Optional[Dict[str, Any]] = None - + # 已构建的内容块 expression_habits_block: str = "" relation_info_block: str = "" @@ -63,7 +63,7 @@ class PromptParameters: tool_info_block: str = "" knowledge_prompt: str = "" cross_context_block: str = "" - + # 其他内容块 keywords_reaction_prompt: str = "" extra_info_block: str = "" @@ -75,10 +75,10 @@ class PromptParameters: reply_target_block: str = "" mood_prompt: str = "" action_descriptions: str = "" - + # 可用动作信息 available_actions: Optional[Dict[str, Any]] = None - + def validate(self) -> List[str]: """参数验证""" errors = [] @@ -93,22 +93,22 @@ class PromptParameters: class PromptContext: """提示词上下文管理器""" - + def __init__(self): self._context_prompts: Dict[str, Dict[str, "Prompt"]] = {} self._current_context_var = contextvars.ContextVar("current_context", default=None) self._context_lock = asyncio.Lock() - + @property def _current_context(self) -> Optional[str]: """获取当前协程的上下文ID""" return self._current_context_var.get() - + @_current_context.setter def _current_context(self, value: Optional[str]): """设置当前协程的上下文ID""" self._current_context_var.set(value) # type: ignore - + @asynccontextmanager async def async_scope(self, context_id: Optional[str] = None): """创建一个异步的临时提示模板作用域""" @@ -123,13 +123,13 @@ class PromptContext: except asyncio.TimeoutError: logger.warning(f"获取上下文锁超时,context_id: {context_id}") context_id = None - + previous_context = self._current_context token = self._current_context_var.set(context_id) if context_id else None else: previous_context = self._current_context token = None - + try: yield self finally: @@ -142,7 +142,7 @@ class PromptContext: self._current_context = previous_context except Exception: ... - + async def get_prompt_async(self, name: str) -> Optional["Prompt"]: """异步获取当前作用域中的提示模板""" async with self._context_lock: @@ -155,7 +155,7 @@ class PromptContext: ): return self._context_prompts[current_context][name] return None - + async def register_async(self, prompt: "Prompt", context_id: Optional[str] = None) -> None: """异步注册提示模板到指定作用域""" async with self._context_lock: @@ -166,49 +166,49 @@ class PromptContext: class PromptManager: """统一提示词管理器""" - + def __init__(self): self._prompts = {} self._counter = 0 self._context = PromptContext() self._lock = asyncio.Lock() - + @asynccontextmanager async def async_message_scope(self, message_id: Optional[str] = None): """为消息处理创建异步临时作用域""" async with self._context.async_scope(message_id): yield self - + async def get_prompt_async(self, name: str) -> "Prompt": """异步获取提示模板""" context_prompt = await self._context.get_prompt_async(name) if context_prompt is not None: logger.debug(f"从上下文中获取提示词: {name} {context_prompt}") return context_prompt - + async with self._lock: if name not in self._prompts: raise KeyError(f"Prompt '{name}' not found") return self._prompts[name] - + def generate_name(self, template: str) -> str: """为未命名的prompt生成名称""" self._counter += 1 return f"prompt_{self._counter}" - + def register(self, prompt: "Prompt") -> None: """注册一个prompt""" if not prompt.name: prompt.name = self.generate_name(prompt.template) self._prompts[prompt.name] = prompt - + def add_prompt(self, name: str, fstr: str) -> "Prompt": """添加新提示模板""" prompt = Prompt(fstr, name=name) if prompt.name: self._prompts[prompt.name] = prompt return prompt - + async def format_prompt(self, name: str, **kwargs) -> str: """格式化提示模板""" prompt = await self.get_prompt_async(name) @@ -225,21 +225,21 @@ class Prompt: 统一提示词类 - 合并模板管理和智能构建功能 真正的Prompt类,支持模板管理和智能上下文构建 """ - + # 临时标记,作为类常量 _TEMP_LEFT_BRACE = "__ESCAPED_LEFT_BRACE__" _TEMP_RIGHT_BRACE = "__ESCAPED_RIGHT_BRACE__" - + def __init__( self, template: str, name: Optional[str] = None, parameters: Optional[PromptParameters] = None, - should_register: bool = True + should_register: bool = True, ): """ 初始化统一提示词 - + Args: template: 提示词模板字符串 name: 提示词名称 @@ -251,14 +251,14 @@ class Prompt: self.parameters = parameters or PromptParameters() self.args = self._parse_template_args(template) self._formatted_result = "" - + # 预处理模板中的转义花括号 self._processed_template = self._process_escaped_braces(template) - + # 自动注册 if should_register and not global_prompt_manager._context._current_context: global_prompt_manager.register(self) - + @staticmethod def _process_escaped_braces(template) -> str: """处理模板中的转义花括号""" @@ -266,14 +266,14 @@ class Prompt: template = "\n".join(str(item) for item in template) elif not isinstance(template, str): template = str(template) - + return template.replace("\\{", Prompt._TEMP_LEFT_BRACE).replace("\\}", Prompt._TEMP_RIGHT_BRACE) - + @staticmethod def _restore_escaped_braces(template: str) -> str: """将临时标记还原为实际的花括号字符""" return template.replace(Prompt._TEMP_LEFT_BRACE, "{").replace(Prompt._TEMP_RIGHT_BRACE, "}") - + def _parse_template_args(self, template: str) -> List[str]: """解析模板参数""" template_args = [] @@ -283,11 +283,11 @@ class Prompt: if expr and expr not in template_args: template_args.append(expr) return template_args - + async def build(self) -> str: """ 构建完整的提示词,包含智能上下文 - + Returns: str: 构建完成的提示词文本 """ @@ -296,38 +296,38 @@ class Prompt: if errors: logger.error(f"参数验证失败: {', '.join(errors)}") raise ValueError(f"参数验证失败: {', '.join(errors)}") - + start_time = time.time() try: # 构建上下文数据 context_data = await self._build_context_data() - + # 格式化模板 result = await self._format_with_context(context_data) - + total_time = time.time() - start_time logger.debug(f"Prompt构建完成,模式: {self.parameters.prompt_mode}, 耗时: {total_time:.2f}s") - + self._formatted_result = result return result - + except asyncio.TimeoutError as e: logger.error(f"构建Prompt超时: {e}") raise TimeoutError(f"构建Prompt超时: {e}") from e except Exception as e: logger.error(f"构建Prompt失败: {e}") raise RuntimeError(f"构建Prompt失败: {e}") from e - + async def _build_context_data(self) -> Dict[str, Any]: """构建智能上下文数据""" # 并行执行所有构建任务 start_time = time.time() - + try: # 准备构建任务 tasks = [] task_names = [] - + # 初始化预构建参数 pre_built_params = {} if self.parameters.expression_habits_block: @@ -342,32 +342,32 @@ class Prompt: pre_built_params["knowledge_prompt"] = self.parameters.knowledge_prompt if self.parameters.cross_context_block: pre_built_params["cross_context_block"] = self.parameters.cross_context_block - + # 根据参数确定要构建的项 if self.parameters.enable_expression and not pre_built_params.get("expression_habits_block"): tasks.append(self._build_expression_habits()) task_names.append("expression_habits") - + if self.parameters.enable_memory and not pre_built_params.get("memory_block"): tasks.append(self._build_memory_block()) task_names.append("memory_block") - + if self.parameters.enable_relation and not pre_built_params.get("relation_info_block"): tasks.append(self._build_relation_info()) task_names.append("relation_info") - + if self.parameters.enable_tool and not pre_built_params.get("tool_info_block"): tasks.append(self._build_tool_info()) task_names.append("tool_info") - + if self.parameters.enable_knowledge and not pre_built_params.get("knowledge_prompt"): tasks.append(self._build_knowledge_info()) task_names.append("knowledge_info") - + if self.parameters.enable_cross_context and not pre_built_params.get("cross_context_block"): tasks.append(self._build_cross_context()) task_names.append("cross_context") - + # 性能优化 base_timeout = 10.0 task_timeout = 2.0 @@ -375,13 +375,13 @@ class Prompt: max(base_timeout, len(tasks) * task_timeout), 30.0, ) - + max_concurrent_tasks = 5 if len(tasks) > max_concurrent_tasks: results = [] for i in range(0, len(tasks), max_concurrent_tasks): batch_tasks = tasks[i : i + max_concurrent_tasks] - + batch_results = await asyncio.wait_for( asyncio.gather(*batch_tasks, return_exceptions=True), timeout=timeout_seconds ) @@ -390,53 +390,55 @@ class Prompt: results = await asyncio.wait_for( asyncio.gather(*tasks, return_exceptions=True), timeout=timeout_seconds ) - + # 处理结果 context_data = {} for i, result in enumerate(results): task_name = task_names[i] if i < len(task_names) else f"task_{i}" - + if isinstance(result, Exception): logger.error(f"构建任务{task_name}失败: {str(result)}") elif isinstance(result, dict): context_data.update(result) - + # 添加预构建的参数 for key, value in pre_built_params.items(): if value: context_data[key] = value - + except asyncio.TimeoutError: logger.error(f"构建超时 ({timeout_seconds}s)") context_data = {} for key, value in pre_built_params.items(): if value: context_data[key] = value - + # 构建聊天历史 if self.parameters.prompt_mode == "s4u": await self._build_s4u_chat_context(context_data) else: await self._build_normal_chat_context(context_data) - + # 补充基础信息 - context_data.update({ - "keywords_reaction_prompt": self.parameters.keywords_reaction_prompt, - "extra_info_block": self.parameters.extra_info_block, - "time_block": self.parameters.time_block or f"当前时间:{time.strftime('%Y-%m-%d %H:%M:%S')}", - "identity": self.parameters.identity_block, - "schedule_block": self.parameters.schedule_block, - "moderation_prompt": self.parameters.moderation_prompt_block, - "reply_target_block": self.parameters.reply_target_block, - "mood_state": self.parameters.mood_prompt, - "action_descriptions": self.parameters.action_descriptions, - }) - + context_data.update( + { + "keywords_reaction_prompt": self.parameters.keywords_reaction_prompt, + "extra_info_block": self.parameters.extra_info_block, + "time_block": self.parameters.time_block or f"当前时间:{time.strftime('%Y-%m-%d %H:%M:%S')}", + "identity": self.parameters.identity_block, + "schedule_block": self.parameters.schedule_block, + "moderation_prompt": self.parameters.moderation_prompt_block, + "reply_target_block": self.parameters.reply_target_block, + "mood_state": self.parameters.mood_prompt, + "action_descriptions": self.parameters.action_descriptions, + } + ) + total_time = time.time() - start_time logger.debug(f"上下文构建完成,总耗时: {total_time:.2f}s") - + return context_data - + async def _build_s4u_chat_context(self, context_data: Dict[str, Any]) -> None: """构建S4U模式的聊天上下文""" if not self.parameters.message_list_before_now_long: @@ -446,20 +448,20 @@ class Prompt: self.parameters.message_list_before_now_long, self.parameters.target_user_info.get("user_id") if self.parameters.target_user_info else "", self.parameters.sender, - self.parameters.chat_id + self.parameters.chat_id, ) context_data["read_history_prompt"] = read_history_prompt context_data["unread_history_prompt"] = unread_history_prompt - + async def _build_normal_chat_context(self, context_data: Dict[str, Any]) -> None: """构建normal模式的聊天上下文""" if not self.parameters.chat_talking_prompt_short: return - + context_data["chat_info"] = f"""群里的聊天内容: {self.parameters.chat_talking_prompt_short}""" - + async def _build_s4u_chat_history_prompts( self, message_list_before_now: List[Dict[str, Any]], target_user_id: str, sender: str, chat_id: str ) -> Tuple[str, str]: @@ -476,101 +478,92 @@ class Prompt: except Exception as e: logger.error(f"构建S4U历史消息prompt失败: {e}") - - async def _build_expression_habits(self) -> Dict[str, Any]: """构建表达习惯""" if not global_config.expression.enable_expression: return {"expression_habits_block": ""} - + try: from src.chat.express.expression_selector import ExpressionSelector - + # 获取聊天历史用于表情选择 chat_history = "" if self.parameters.message_list_before_now_long: recent_messages = self.parameters.message_list_before_now_long[-10:] chat_history = build_readable_messages( - recent_messages, - replace_bot_name=True, - timestamp_mode="normal", - truncate=True + recent_messages, replace_bot_name=True, timestamp_mode="normal", truncate=True ) - + # 创建表情选择器 expression_selector = ExpressionSelector(self.parameters.chat_id) - + # 选择合适的表情 selected_expressions = await expression_selector.select_suitable_expressions_llm( chat_history=chat_history, current_message=self.parameters.target, emotional_tone="neutral", - topic_type="general" + topic_type="general", ) - + # 构建表达习惯块 if selected_expressions: style_habits_str = "\n".join([f"- {expr}" for expr in selected_expressions]) expression_habits_block = f"- 你可以参考以下的语言习惯,当情景合适就使用,但不要生硬使用,以合理的方式结合到你的回复中:\n{style_habits_str}" else: expression_habits_block = "" - + return {"expression_habits_block": expression_habits_block} - + except Exception as e: logger.error(f"构建表达习惯失败: {e}") return {"expression_habits_block": ""} - + async def _build_memory_block(self) -> Dict[str, Any]: """构建记忆块""" if not global_config.memory.enable_memory: return {"memory_block": ""} - + try: from src.chat.memory_system.memory_activator import MemoryActivator from src.chat.memory_system.async_instant_memory_wrapper import get_async_instant_memory - + # 获取聊天历史 chat_history = "" if self.parameters.message_list_before_now_long: recent_messages = self.parameters.message_list_before_now_long[-20:] chat_history = build_readable_messages( - recent_messages, - replace_bot_name=True, - timestamp_mode="normal", - truncate=True + recent_messages, replace_bot_name=True, timestamp_mode="normal", truncate=True ) - + # 激活长期记忆 memory_activator = MemoryActivator() running_memories = await memory_activator.activate_memory_with_chat_history( - target_message=self.parameters.target, - chat_history_prompt=chat_history + target_message=self.parameters.target, chat_history_prompt=chat_history ) - + # 获取即时记忆 async_memory_wrapper = get_async_instant_memory(self.parameters.chat_id) instant_memory = await async_memory_wrapper.get_memory_with_fallback(self.parameters.target) - + # 构建记忆块 memory_parts = [] - + if running_memories: memory_parts.append("以下是当前在聊天中,你回忆起的记忆:") for memory in running_memories: memory_parts.append(f"- {memory['content']}") - + if instant_memory: memory_parts.append(f"- {instant_memory}") - + memory_block = "\n".join(memory_parts) if memory_parts else "" - + return {"memory_block": memory_block} - + except Exception as e: logger.error(f"构建记忆块失败: {e}") return {"memory_block": ""} - + async def _build_relation_info(self) -> Dict[str, Any]: """构建关系信息""" try: @@ -579,110 +572,104 @@ class Prompt: except Exception as e: logger.error(f"构建关系信息失败: {e}") return {"relation_info_block": ""} - + async def _build_tool_info(self) -> Dict[str, Any]: """构建工具信息""" if not global_config.tool.enable_tool: return {"tool_info_block": ""} - + try: from src.plugin_system.core.tool_use import ToolExecutor - + # 获取聊天历史 chat_history = "" if self.parameters.message_list_before_now_long: recent_messages = self.parameters.message_list_before_now_long[-15:] chat_history = build_readable_messages( - recent_messages, - replace_bot_name=True, - timestamp_mode="normal", - truncate=True + recent_messages, replace_bot_name=True, timestamp_mode="normal", truncate=True ) - + # 创建工具执行器 tool_executor = ToolExecutor(chat_id=self.parameters.chat_id) - + # 执行工具获取信息 tool_results, _, _ = await tool_executor.execute_from_chat_message( sender=self.parameters.sender, target_message=self.parameters.target, chat_history=chat_history, - return_details=False + return_details=False, ) - + # 构建工具信息块 if tool_results: - tool_info_parts = ["## 工具信息","以下是你通过工具获取到的实时信息:"] + tool_info_parts = ["## 工具信息", "以下是你通过工具获取到的实时信息:"] for tool_result in tool_results: tool_name = tool_result.get("tool_name", "unknown") content = tool_result.get("content", "") result_type = tool_result.get("type", "tool_result") - + tool_info_parts.append(f"- 【{tool_name}】{result_type}: {content}") - + tool_info_parts.append("以上是你获取到的实时信息,请在回复时参考这些信息。") tool_info_block = "\n".join(tool_info_parts) else: tool_info_block = "" - + return {"tool_info_block": tool_info_block} - + except Exception as e: logger.error(f"构建工具信息失败: {e}") return {"tool_info_block": ""} - + async def _build_knowledge_info(self) -> Dict[str, Any]: """构建知识信息""" if not global_config.lpmm_knowledge.enable: return {"knowledge_prompt": ""} - + try: from src.chat.knowledge.knowledge_lib import QAManager - + # 获取问题文本(当前消息) question = self.parameters.target or "" if not question: return {"knowledge_prompt": ""} - + # 创建QA管理器 qa_manager = QAManager() - + # 搜索相关知识 knowledge_results = await qa_manager.get_knowledge( - question=question, - chat_id=self.parameters.chat_id, - max_results=5, - min_similarity=0.5 + question=question, chat_id=self.parameters.chat_id, max_results=5, min_similarity=0.5 ) - + # 构建知识块 if knowledge_results and knowledge_results.get("knowledge_items"): - knowledge_parts = ["## 知识库信息","以下是与你当前对话相关的知识信息:"] - + knowledge_parts = ["## 知识库信息", "以下是与你当前对话相关的知识信息:"] + for item in knowledge_results["knowledge_items"]: content = item.get("content", "") source = item.get("source", "") relevance = item.get("relevance", 0.0) - + if content: if source: knowledge_parts.append(f"- [{relevance:.2f}] {content} (来源: {source})") else: knowledge_parts.append(f"- [{relevance:.2f}] {content}") - + if knowledge_results.get("summary"): knowledge_parts.append(f"\n知识总结: {knowledge_results['summary']}") - + knowledge_prompt = "\n".join(knowledge_parts) else: knowledge_prompt = "" - + return {"knowledge_prompt": knowledge_prompt} - + except Exception as e: logger.error(f"构建知识信息失败: {e}") return {"knowledge_prompt": ""} - + async def _build_cross_context(self) -> Dict[str, Any]: """构建跨群上下文""" try: @@ -693,7 +680,7 @@ class Prompt: except Exception as e: logger.error(f"构建跨群上下文失败: {e}") return {"cross_context_block": ""} - + async def _format_with_context(self, context_data: Dict[str, Any]) -> str: """使用上下文数据格式化模板""" if self.parameters.prompt_mode == "s4u": @@ -702,9 +689,9 @@ class Prompt: params = self._prepare_normal_params(context_data) else: params = self._prepare_default_params(context_data) - + return await global_prompt_manager.format_prompt(self.name, **params) if self.name else self.format(**params) - + def _prepare_s4u_params(self, context_data: Dict[str, Any]) -> Dict[str, Any]: """准备S4U模式的参数""" return { @@ -725,11 +712,13 @@ class Prompt: "time_block": context_data.get("time_block", ""), "reply_target_block": context_data.get("reply_target_block", ""), "reply_style": global_config.personality.reply_style, - "keywords_reaction_prompt": self.parameters.keywords_reaction_prompt or context_data.get("keywords_reaction_prompt", ""), + "keywords_reaction_prompt": self.parameters.keywords_reaction_prompt + or context_data.get("keywords_reaction_prompt", ""), "moderation_prompt": self.parameters.moderation_prompt_block or context_data.get("moderation_prompt", ""), - "safety_guidelines_block": self.parameters.safety_guidelines_block or context_data.get("safety_guidelines_block", ""), + "safety_guidelines_block": self.parameters.safety_guidelines_block + or context_data.get("safety_guidelines_block", ""), } - + def _prepare_normal_params(self, context_data: Dict[str, Any]) -> Dict[str, Any]: """准备Normal模式的参数""" return { @@ -749,11 +738,13 @@ class Prompt: "reply_target_block": context_data.get("reply_target_block", ""), "config_expression_style": global_config.personality.reply_style, "mood_state": self.parameters.mood_prompt or context_data.get("mood_state", ""), - "keywords_reaction_prompt": self.parameters.keywords_reaction_prompt or context_data.get("keywords_reaction_prompt", ""), + "keywords_reaction_prompt": self.parameters.keywords_reaction_prompt + or context_data.get("keywords_reaction_prompt", ""), "moderation_prompt": self.parameters.moderation_prompt_block or context_data.get("moderation_prompt", ""), - "safety_guidelines_block": self.parameters.safety_guidelines_block or context_data.get("safety_guidelines_block", ""), + "safety_guidelines_block": self.parameters.safety_guidelines_block + or context_data.get("safety_guidelines_block", ""), } - + def _prepare_default_params(self, context_data: Dict[str, Any]) -> Dict[str, Any]: """准备默认模式的参数""" return { @@ -769,11 +760,13 @@ class Prompt: "reason": "", "mood_state": self.parameters.mood_prompt or context_data.get("mood_state", ""), "reply_style": global_config.personality.reply_style, - "keywords_reaction_prompt": self.parameters.keywords_reaction_prompt or context_data.get("keywords_reaction_prompt", ""), + "keywords_reaction_prompt": self.parameters.keywords_reaction_prompt + or context_data.get("keywords_reaction_prompt", ""), "moderation_prompt": self.parameters.moderation_prompt_block or context_data.get("moderation_prompt", ""), - "safety_guidelines_block": self.parameters.safety_guidelines_block or context_data.get("safety_guidelines_block", ""), + "safety_guidelines_block": self.parameters.safety_guidelines_block + or context_data.get("safety_guidelines_block", ""), } - + def format(self, *args, **kwargs) -> str: """格式化模板,支持位置参数和关键字参数""" try: @@ -786,21 +779,21 @@ class Prompt: processed_template = self._processed_template.format(**formatted_args) else: processed_template = self._processed_template - + # 再用关键字参数格式化 if kwargs: processed_template = processed_template.format(**kwargs) - + # 将临时标记还原为实际的花括号 result = self._restore_escaped_braces(processed_template) return result except (IndexError, KeyError) as e: raise ValueError(f"格式化模板失败: {self.template}, args={args}, kwargs={kwargs} {str(e)}") from e - + def __str__(self) -> str: """返回格式化后的结果或原始模板""" return self._formatted_result if self._formatted_result else self.template - + def __repr__(self) -> str: """返回提示词的表示形式""" return f"Prompt(template='{self.template}', name='{self.name}')" @@ -872,9 +865,7 @@ class Prompt: return await relationship_fetcher.build_relation_info(person_id, points_num=5) @staticmethod - async def build_cross_context( - chat_id: str, prompt_mode: str, target_user_info: Optional[Dict[str, Any]] - ) -> str: + async def build_cross_context(chat_id: str, prompt_mode: str, target_user_info: Optional[Dict[str, Any]]) -> str: """ 构建跨群聊上下文 - 统一实现 @@ -890,7 +881,7 @@ class Prompt: return "" from src.plugin_system.apis import cross_context_api - + other_chat_raw_ids = cross_context_api.get_context_groups(chat_id) if not other_chat_raw_ids: return "" @@ -937,10 +928,7 @@ class Prompt: # 工厂函数 def create_prompt( - template: str, - name: Optional[str] = None, - parameters: Optional[PromptParameters] = None, - **kwargs + template: str, name: Optional[str] = None, parameters: Optional[PromptParameters] = None, **kwargs ) -> Prompt: """快速创建Prompt实例的工厂函数""" if parameters is None: @@ -949,14 +937,10 @@ def create_prompt( async def create_prompt_async( - template: str, - name: Optional[str] = None, - parameters: Optional[PromptParameters] = None, - **kwargs + template: str, name: Optional[str] = None, parameters: Optional[PromptParameters] = None, **kwargs ) -> Prompt: """异步创建Prompt实例""" prompt = create_prompt(template, name, parameters, **kwargs) if global_prompt_manager._context._current_context: await global_prompt_manager._context.register_async(prompt) return prompt - diff --git a/src/chat/utils/utils.py b/src/chat/utils/utils.py index 675bf4b85..38780ec3f 100644 --- a/src/chat/utils/utils.py +++ b/src/chat/utils/utils.py @@ -332,9 +332,9 @@ def process_llm_response(text: str, enable_splitter: bool = True, enable_chinese if global_config.response_splitter.enable and enable_splitter: logger.info(f"回复分割器已启用,模式: {global_config.response_splitter.split_mode}。") - + split_mode = global_config.response_splitter.split_mode - + if split_mode == "llm" and "[SPLIT]" in cleaned_text: logger.debug("检测到 [SPLIT] 标记,使用 LLM 自定义分割。") split_sentences_raw = cleaned_text.split("[SPLIT]") @@ -343,7 +343,7 @@ def process_llm_response(text: str, enable_splitter: bool = True, enable_chinese if split_mode == "llm": logger.debug("未检测到 [SPLIT] 标记,本次不进行分割。") split_sentences = [cleaned_text] - else: # mode == "punctuation" + else: # mode == "punctuation" logger.debug("使用基于标点的传统模式进行分割。") split_sentences = split_into_sentences_w_remove_punctuation(cleaned_text) else: diff --git a/src/common/data_models/__init__.py b/src/common/data_models/__init__.py index 222ff59ca..d104eec9c 100644 --- a/src/common/data_models/__init__.py +++ b/src/common/data_models/__init__.py @@ -6,6 +6,7 @@ class BaseDataModel: def deepcopy(self): return copy.deepcopy(self) + def temporarily_transform_class_to_dict(obj: Any) -> Any: # sourcery skip: assign-if-exp, reintroduce-else """ diff --git a/src/common/data_models/bot_interest_data_model.py b/src/common/data_models/bot_interest_data_model.py index e0f86237f..819b50a8f 100644 --- a/src/common/data_models/bot_interest_data_model.py +++ b/src/common/data_models/bot_interest_data_model.py @@ -2,6 +2,7 @@ 机器人兴趣标签数据模型 定义机器人的兴趣标签和相关的embedding数据结构 """ + from dataclasses import dataclass, field from typing import List, Dict, Optional, Any from datetime import datetime @@ -12,6 +13,7 @@ from . import BaseDataModel @dataclass class BotInterestTag(BaseDataModel): """机器人兴趣标签""" + tag_name: str weight: float = 1.0 # 权重,表示对这个兴趣的喜好程度 (0.0-1.0) embedding: Optional[List[float]] = None # 标签的embedding向量 @@ -27,7 +29,7 @@ class BotInterestTag(BaseDataModel): "embedding": self.embedding, "created_at": self.created_at.isoformat(), "updated_at": self.updated_at.isoformat(), - "is_active": self.is_active + "is_active": self.is_active, } @classmethod @@ -39,13 +41,14 @@ class BotInterestTag(BaseDataModel): embedding=data.get("embedding"), created_at=datetime.fromisoformat(data["created_at"]) if data.get("created_at") else datetime.now(), updated_at=datetime.fromisoformat(data["updated_at"]) if data.get("updated_at") else datetime.now(), - is_active=data.get("is_active", True) + is_active=data.get("is_active", True), ) @dataclass class BotPersonalityInterests(BaseDataModel): """机器人人格化兴趣配置""" + personality_id: str personality_description: str # 人设描述文本 interest_tags: List[BotInterestTag] = field(default_factory=list) @@ -57,7 +60,6 @@ class BotPersonalityInterests(BaseDataModel): """获取活跃的兴趣标签""" return [tag for tag in self.interest_tags if tag.is_active] - def to_dict(self) -> Dict[str, Any]: """转换为字典格式""" return { @@ -66,7 +68,7 @@ class BotPersonalityInterests(BaseDataModel): "interest_tags": [tag.to_dict() for tag in self.interest_tags], "embedding_model": self.embedding_model, "last_updated": self.last_updated.isoformat(), - "version": self.version + "version": self.version, } @classmethod @@ -78,13 +80,14 @@ class BotPersonalityInterests(BaseDataModel): interest_tags=[BotInterestTag.from_dict(tag_data) for tag_data in data.get("interest_tags", [])], embedding_model=data.get("embedding_model", "text-embedding-ada-002"), last_updated=datetime.fromisoformat(data["last_updated"]) if data.get("last_updated") else datetime.now(), - version=data.get("version", 1) + version=data.get("version", 1), ) @dataclass class InterestMatchResult(BaseDataModel): """兴趣匹配结果""" + message_id: str matched_tags: List[str] = field(default_factory=list) match_scores: Dict[str, float] = field(default_factory=dict) # tag_name -> score @@ -120,7 +123,9 @@ class InterestMatchResult(BaseDataModel): # 计算置信度(基于匹配标签数量和分数分布) if len(self.match_scores) > 0: avg_score = self.overall_score - score_variance = sum((score - avg_score) ** 2 for score in self.match_scores.values()) / len(self.match_scores) + score_variance = sum((score - avg_score) ** 2 for score in self.match_scores.values()) / len( + self.match_scores + ) # 分数越集中,置信度越高 self.confidence = max(0.0, 1.0 - score_variance) else: @@ -129,4 +134,4 @@ class InterestMatchResult(BaseDataModel): def get_top_matches(self, top_n: int = 3) -> List[tuple]: """获取前N个最佳匹配""" sorted_matches = sorted(self.match_scores.items(), key=lambda x: x[1], reverse=True) - return sorted_matches[:top_n] \ No newline at end of file + return sorted_matches[:top_n] diff --git a/src/common/data_models/database_data_model.py b/src/common/data_models/database_data_model.py index 7167c64cb..4d2e00e3b 100644 --- a/src/common/data_models/database_data_model.py +++ b/src/common/data_models/database_data_model.py @@ -208,6 +208,7 @@ class DatabaseMessages(BaseDataModel): "chat_info_user_cardname": self.chat_info.user_info.user_cardname, } + @dataclass(init=False) class DatabaseActionRecords(BaseDataModel): def __init__( @@ -235,4 +236,4 @@ class DatabaseActionRecords(BaseDataModel): self.action_prompt_display = action_prompt_display self.chat_id = chat_id self.chat_info_stream_id = chat_info_stream_id - self.chat_info_platform = chat_info_platform \ No newline at end of file + self.chat_info_platform = chat_info_platform diff --git a/src/common/data_models/info_data_model.py b/src/common/data_models/info_data_model.py index 0e3cfd35d..5351ab76a 100644 --- a/src/common/data_models/info_data_model.py +++ b/src/common/data_models/info_data_model.py @@ -28,6 +28,7 @@ class ActionPlannerInfo(BaseDataModel): @dataclass class InterestScore(BaseDataModel): """兴趣度评分结果""" + message_id: str total_score: float interest_match_score: float @@ -41,6 +42,7 @@ class Plan(BaseDataModel): """ 统一规划数据模型 """ + chat_id: str mode: "ChatMode" diff --git a/src/common/data_models/llm_data_model.py b/src/common/data_models/llm_data_model.py index 1d5b75e0c..d862e9b54 100644 --- a/src/common/data_models/llm_data_model.py +++ b/src/common/data_models/llm_data_model.py @@ -2,9 +2,11 @@ from dataclasses import dataclass from typing import Optional, List, Tuple, TYPE_CHECKING, Any from . import BaseDataModel + if TYPE_CHECKING: from src.llm_models.payload_content.tool_option import ToolCall + @dataclass class LLMGenerationDataModel(BaseDataModel): content: Optional[str] = None @@ -13,4 +15,4 @@ class LLMGenerationDataModel(BaseDataModel): tool_calls: Optional[List["ToolCall"]] = None prompt: Optional[str] = None selected_expressions: Optional[List[int]] = None - reply_set: Optional[List[Tuple[str, Any]]] = None \ No newline at end of file + reply_set: Optional[List[Tuple[str, Any]]] = None diff --git a/src/common/data_models/message_manager_data_model.py b/src/common/data_models/message_manager_data_model.py index a54cb826b..27ed03759 100644 --- a/src/common/data_models/message_manager_data_model.py +++ b/src/common/data_models/message_manager_data_model.py @@ -2,6 +2,7 @@ 消息管理模块数据模型 定义消息管理器使用的数据结构 """ + import asyncio import time from dataclasses import dataclass, field @@ -16,14 +17,16 @@ if TYPE_CHECKING: class MessageStatus(Enum): """消息状态枚举""" - UNREAD = "unread" # 未读消息 - READ = "read" # 已读消息 + + UNREAD = "unread" # 未读消息 + READ = "read" # 已读消息 PROCESSING = "processing" # 处理中 @dataclass class StreamContext(BaseDataModel): """聊天流上下文信息""" + stream_id: str unread_messages: List["DatabaseMessages"] = field(default_factory=list) history_messages: List["DatabaseMessages"] = field(default_factory=list) @@ -59,6 +62,7 @@ class StreamContext(BaseDataModel): @dataclass class MessageManagerStats(BaseDataModel): """消息管理器统计信息""" + total_streams: int = 0 active_streams: int = 0 total_unread_messages: int = 0 @@ -74,9 +78,10 @@ class MessageManagerStats(BaseDataModel): @dataclass class StreamStats(BaseDataModel): """聊天流统计信息""" + stream_id: str is_active: bool unread_count: int history_count: int last_check_time: float - has_active_task: bool \ No newline at end of file + has_active_task: bool diff --git a/src/common/message/api.py b/src/common/message/api.py index d24574d6e..37b7a7ddc 100644 --- a/src/common/message/api.py +++ b/src/common/message/api.py @@ -23,15 +23,15 @@ def get_global_api() -> MessageServer: # sourcery skip: extract-method maim_message_config = global_config.maim_message # 设置基本参数 - + host = os.getenv("HOST", "127.0.0.1") port_str = os.getenv("PORT", "8000") - + try: port = int(port_str) except ValueError: port = 8000 - + kwargs = { "host": host, "port": port, diff --git a/src/common/remote.py b/src/common/remote.py index 2aa750449..95202f810 100644 --- a/src/common/remote.py +++ b/src/common/remote.py @@ -31,7 +31,9 @@ class TelemetryHeartBeatTask(AsyncTask): self.client_uuid: str | None = local_storage["mofox_uuid"] if "mofox_uuid" in local_storage else None # type: ignore """客户端UUID""" - self.private_key_pem: str | None = local_storage["mofox_private_key"] if "mofox_private_key" in local_storage else None # type: ignore + self.private_key_pem: str | None = ( + local_storage["mofox_private_key"] if "mofox_private_key" in local_storage else None + ) # type: ignore """客户端私钥""" self.info_dict = self._get_sys_info() @@ -61,78 +63,65 @@ class TelemetryHeartBeatTask(AsyncTask): def _generate_signature(self, request_body: dict) -> tuple[str, str]: """ 生成RSA签名 - + Returns: tuple[str, str]: (timestamp, signature_b64) """ if not self.private_key_pem: raise ValueError("私钥未初始化") - + # 生成时间戳 timestamp = datetime.now(timezone.utc).isoformat() - + # 创建签名数据字符串 sign_data = f"{self.client_uuid}:{timestamp}:{json.dumps(request_body, separators=(',', ':'))}" - + # 加载私钥 - private_key = serialization.load_pem_private_key( - self.private_key_pem.encode('utf-8'), - password=None - ) - + private_key = serialization.load_pem_private_key(self.private_key_pem.encode("utf-8"), password=None) + # 确保是RSA私钥 if not isinstance(private_key, rsa.RSAPrivateKey): raise ValueError("私钥必须是RSA格式") - + # 生成签名 signature = private_key.sign( - sign_data.encode('utf-8'), - padding.PSS( - mgf=padding.MGF1(hashes.SHA256()), - salt_length=padding.PSS.MAX_LENGTH - ), - hashes.SHA256() + sign_data.encode("utf-8"), + padding.PSS(mgf=padding.MGF1(hashes.SHA256()), salt_length=padding.PSS.MAX_LENGTH), + hashes.SHA256(), ) - + # Base64编码 - signature_b64 = base64.b64encode(signature).decode('utf-8') - + signature_b64 = base64.b64encode(signature).decode("utf-8") + return timestamp, signature_b64 def _decrypt_challenge(self, challenge_b64: str) -> str: """ 解密挑战数据 - + Args: challenge_b64: Base64编码的挑战数据 - + Returns: str: 解密后的UUID字符串 """ if not self.private_key_pem: raise ValueError("私钥未初始化") - + # 加载私钥 - private_key = serialization.load_pem_private_key( - self.private_key_pem.encode('utf-8'), - password=None - ) - + private_key = serialization.load_pem_private_key(self.private_key_pem.encode("utf-8"), password=None) + # 确保是RSA私钥 if not isinstance(private_key, rsa.RSAPrivateKey): raise ValueError("私钥必须是RSA格式") - + # 解密挑战数据 decrypted_bytes = private_key.decrypt( base64.b64decode(challenge_b64), - padding.OAEP( - mgf=padding.MGF1(hashes.SHA256()), - algorithm=hashes.SHA256(), - label=None - ) + padding.OAEP(mgf=padding.MGF1(hashes.SHA256()), algorithm=hashes.SHA256(), label=None), ) - - return decrypted_bytes.decode('utf-8') + + return decrypted_bytes.decode("utf-8") async def _req_uuid(self) -> bool: """ @@ -155,28 +144,26 @@ class TelemetryHeartBeatTask(AsyncTask): if response.status != 200: response_text = await response.text() - logger.error( - f"注册步骤1失败,状态码: {response.status}, 响应内容: {response_text}" - ) + logger.error(f"注册步骤1失败,状态码: {response.status}, 响应内容: {response_text}") raise aiohttp.ClientResponseError( request_info=response.request_info, history=response.history, status=response.status, - message=f"Step1 failed: {response_text}" + message=f"Step1 failed: {response_text}", ) step1_data = await response.json() temp_uuid = step1_data.get("temp_uuid") private_key = step1_data.get("private_key") challenge = step1_data.get("challenge") - + if not all([temp_uuid, private_key, challenge]): logger.error("Step1响应缺少必要字段:temp_uuid, private_key 或 challenge") raise ValueError("Step1响应数据不完整") # 临时保存私钥用于解密 self.private_key_pem = private_key - + # 解密挑战数据 logger.debug("解密挑战数据...") try: @@ -184,21 +171,18 @@ class TelemetryHeartBeatTask(AsyncTask): except Exception as e: logger.error(f"解密挑战数据失败: {e}") raise - + # 验证解密结果 if decrypted_uuid != temp_uuid: logger.error(f"解密结果验证失败: 期望 {temp_uuid}, 实际 {decrypted_uuid}") raise ValueError("解密结果与临时UUID不匹配") - + logger.debug("挑战数据解密成功,开始注册步骤2") # Step 2: 发送解密结果完成注册 async with session.post( f"{TELEMETRY_SERVER_URL}/stat/reg_client_step2", - json={ - "temp_uuid": temp_uuid, - "decrypted_uuid": decrypted_uuid - }, + json={"temp_uuid": temp_uuid, "decrypted_uuid": decrypted_uuid}, timeout=aiohttp.ClientTimeout(total=5), ) as response: logger.debug(f"Step2 Response status: {response.status}") @@ -206,7 +190,7 @@ class TelemetryHeartBeatTask(AsyncTask): if response.status == 200: step2_data = await response.json() mofox_uuid = step2_data.get("mofox_uuid") - + if mofox_uuid: # 将正式UUID和私钥存储到本地 local_storage["mofox_uuid"] = mofox_uuid @@ -225,23 +209,19 @@ class TelemetryHeartBeatTask(AsyncTask): raise ValueError(f"Step2失败: {response_text}") else: response_text = await response.text() - logger.error( - f"注册步骤2失败,状态码: {response.status}, 响应内容: {response_text}" - ) + logger.error(f"注册步骤2失败,状态码: {response.status}, 响应内容: {response_text}") raise aiohttp.ClientResponseError( request_info=response.request_info, history=response.history, status=response.status, - message=f"Step2 failed: {response_text}" + message=f"Step2 failed: {response_text}", ) except Exception as e: import traceback error_msg = str(e) or "未知错误" - logger.warning( - f"注册客户端出错,不过你还是可以正常使用墨狐: {type(e).__name__}: {error_msg}" - ) + logger.warning(f"注册客户端出错,不过你还是可以正常使用墨狐: {type(e).__name__}: {error_msg}") logger.debug(f"完整错误信息: {traceback.format_exc()}") # 请求失败,重试次数+1 @@ -264,13 +244,13 @@ class TelemetryHeartBeatTask(AsyncTask): try: # 生成签名 timestamp, signature = self._generate_signature(self.info_dict) - + headers = { "X-mofox-UUID": self.client_uuid, "X-mofox-Signature": signature, "X-mofox-Timestamp": timestamp, "User-Agent": f"MofoxClient/{self.client_uuid[:8]}", - "Content-Type": "application/json" + "Content-Type": "application/json", } logger.debug(f"正在发送心跳到服务器: {self.server_url}") @@ -347,4 +327,4 @@ class TelemetryHeartBeatTask(AsyncTask): logger.warning("客户端注册失败,跳过此次心跳") return - await self._send_heartbeat() \ No newline at end of file + await self._send_heartbeat() diff --git a/src/common/server.py b/src/common/server.py index 3263589a2..6d6dbfd4e 100644 --- a/src/common/server.py +++ b/src/common/server.py @@ -99,14 +99,13 @@ def get_global_server() -> Server: """获取全局服务器实例""" global global_server if global_server is None: - host = os.getenv("HOST", "127.0.0.1") port_str = os.getenv("PORT", "8000") - + try: port = int(port_str) except ValueError: port = 8000 - + global_server = Server(host=host, port=port) return global_server diff --git a/src/config/config.py b/src/config/config.py index 2cebe54a8..fdd450e01 100644 --- a/src/config/config.py +++ b/src/config/config.py @@ -44,7 +44,7 @@ from src.config.official_configs import ( PermissionConfig, CommandConfig, PlanningSystemConfig, - AffinityFlowConfig + AffinityFlowConfig, ) from .api_ada_configs import ( @@ -399,9 +399,7 @@ class Config(ValidatedConfigBase): cross_context: CrossContextConfig = Field( default_factory=lambda: CrossContextConfig(), description="跨群聊上下文共享配置" ) - affinity_flow: AffinityFlowConfig = Field( - default_factory=lambda: AffinityFlowConfig(), description="亲和流配置" - ) + affinity_flow: AffinityFlowConfig = Field(default_factory=lambda: AffinityFlowConfig(), description="亲和流配置") class APIAdapterConfig(ValidatedConfigBase): diff --git a/src/config/official_configs.py b/src/config/official_configs.py index f98da99fb..d97e92f56 100644 --- a/src/config/official_configs.py +++ b/src/config/official_configs.py @@ -51,8 +51,12 @@ class PersonalityConfig(ValidatedConfigBase): personality_core: str = Field(..., description="核心人格") personality_side: str = Field(..., description="人格侧写") identity: str = Field(default="", description="身份特征") - background_story: str = Field(default="", description="世界观背景故事,这部分内容会作为背景知识,LLM被指导不应主动复述") - safety_guidelines: List[str] = Field(default_factory=list, description="安全与互动底线,Bot在任何情况下都必须遵守的原则") + background_story: str = Field( + default="", description="世界观背景故事,这部分内容会作为背景知识,LLM被指导不应主动复述" + ) + safety_guidelines: List[str] = Field( + default_factory=list, description="安全与互动底线,Bot在任何情况下都必须遵守的原则" + ) reply_style: str = Field(default="", description="表达风格") prompt_mode: Literal["s4u", "normal"] = Field(default="s4u", description="Prompt模式") compress_personality: bool = Field(default=True, description="是否压缩人格") @@ -79,7 +83,8 @@ class ChatConfig(ValidatedConfigBase): talk_frequency_adjust: list[list[str]] = Field(default_factory=lambda: [], description="聊天频率调整") focus_value: float = Field(default=1.0, description="专注值") focus_mode_quiet_groups: List[str] = Field( - default_factory=list, description='专注模式下需要保持安静的群组列表, 格式: ["platform:group_id1", "platform:group_id2"]' + default_factory=list, + description='专注模式下需要保持安静的群组列表, 格式: ["platform:group_id1", "platform:group_id2"]', ) force_reply_private: bool = Field(default=False, description="强制回复私聊") group_chat_mode: Literal["auto", "normal", "focus"] = Field(default="auto", description="群聊模式") @@ -343,6 +348,7 @@ class ExpressionConfig(ValidatedConfigBase): # 如果都没有匹配,返回默认值 return True, True, 1.0 + class ToolConfig(ValidatedConfigBase): """工具配置类""" @@ -477,7 +483,6 @@ class ExperimentalConfig(ValidatedConfigBase): pfc_chatting: bool = Field(default=False, description="启用PFC聊天") - class MaimMessageConfig(ValidatedConfigBase): """maim_message配置类""" @@ -602,8 +607,12 @@ class SleepSystemConfig(ValidatedConfigBase): sleep_by_schedule: bool = Field(default=True, description="是否根据日程表进行睡觉") fixed_sleep_time: str = Field(default="23:00", description="固定的睡觉时间") fixed_wake_up_time: str = Field(default="07:00", description="固定的起床时间") - sleep_time_offset_minutes: int = Field(default=15, ge=0, le=60, description="睡觉时间随机偏移量范围(分钟),实际睡觉时间会在±该值范围内随机") - wake_up_time_offset_minutes: int = Field(default=15, ge=0, le=60, description="起床时间随机偏移量范围(分钟),实际起床时间会在±该值范围内随机") + sleep_time_offset_minutes: int = Field( + default=15, ge=0, le=60, description="睡觉时间随机偏移量范围(分钟),实际睡觉时间会在±该值范围内随机" + ) + wake_up_time_offset_minutes: int = Field( + default=15, ge=0, le=60, description="起床时间随机偏移量范围(分钟),实际起床时间会在±该值范围内随机" + ) wakeup_threshold: float = Field(default=15.0, ge=1.0, description="唤醒阈值,达到此值时会被唤醒") private_message_increment: float = Field(default=3.0, ge=0.1, description="私聊消息增加的唤醒度") group_mention_increment: float = Field(default=2.0, ge=0.1, description="群聊艾特增加的唤醒度") @@ -618,10 +627,10 @@ class SleepSystemConfig(ValidatedConfigBase): # --- 失眠机制相关参数 --- enable_insomnia_system: bool = Field(default=True, description="是否启用失眠系统") insomnia_trigger_delay_minutes: List[int] = Field( - default_factory=lambda:[30, 60], description="入睡后触发失眠判定的延迟时间范围(分钟)" + default_factory=lambda: [30, 60], description="入睡后触发失眠判定的延迟时间范围(分钟)" ) insomnia_duration_minutes: List[int] = Field( - default_factory=lambda:[15, 45], description="单次失眠状态的持续时间范围(分钟)" + default_factory=lambda: [15, 45], description="单次失眠状态的持续时间范围(分钟)" ) sleep_pressure_threshold: float = Field(default=30.0, description="触发“压力不足型失眠”的睡眠压力阈值") deep_sleep_threshold: float = Field(default=80.0, description="进入“深度睡眠”的睡眠压力阈值") @@ -657,6 +666,8 @@ class CrossContextConfig(ValidatedConfigBase): enable: bool = Field(default=False, description="是否启用跨群聊上下文共享功能") groups: List[ContextGroup] = Field(default_factory=list, description="上下文共享组列表") + + class CommandConfig(ValidatedConfigBase): """命令系统配置类""" diff --git a/src/individuality/individuality.py b/src/individuality/individuality.py index 09bd3ad00..342bfaab5 100644 --- a/src/individuality/individuality.py +++ b/src/individuality/individuality.py @@ -88,8 +88,7 @@ class Individuality: # 初始化智能兴趣系统 await interest_scoring_system.initialize_smart_interests( - personality_description=full_personality, - personality_id=self.bot_person_id + personality_description=full_personality, personality_id=self.bot_person_id ) logger.info("智能兴趣系统初始化完成") diff --git a/src/main.py b/src/main.py index 2d2cd6db5..9faee813d 100644 --- a/src/main.py +++ b/src/main.py @@ -130,7 +130,8 @@ class MainSystem: # 停止消息重组器 from src.plugin_system.core.event_manager import event_manager from src.plugin_system import EventType - asyncio.run(event_manager.trigger_event(EventType.ON_STOP,permission_group="SYSTEM")) + + asyncio.run(event_manager.trigger_event(EventType.ON_STOP, permission_group="SYSTEM")) from src.utils.message_chunker import reassembler loop = asyncio.get_event_loop() @@ -216,7 +217,7 @@ MoFox_Bot(第三方修改版) # 添加统计信息输出任务 await async_task_manager.add_task(StatisticOutputTask()) - + # 添加遥测心跳任务 await async_task_manager.add_task(TelemetryHeartBeatTask()) @@ -250,6 +251,7 @@ MoFox_Bot(第三方修改版) # 初始化回复后关系追踪系统 from src.chat.affinity_flow.relationship_integration import initialize_relationship_tracking + relationship_tracker = initialize_relationship_tracking() if relationship_tracker: logger.info("回复后关系追踪系统初始化成功") @@ -273,6 +275,7 @@ MoFox_Bot(第三方修改版) # 初始化LPMM知识库 from src.chat.knowledge.knowledge_lib import initialize_lpmm_knowledge + initialize_lpmm_knowledge() logger.info("LPMM知识库初始化成功") @@ -298,6 +301,7 @@ MoFox_Bot(第三方修改版) # 启动消息管理器 from src.chat.message_manager import message_manager + await message_manager.start() logger.info("消息管理器已启动") diff --git a/src/person_info/person_info.py b/src/person_info/person_info.py index e0c582b10..3a63387cd 100644 --- a/src/person_info/person_info.py +++ b/src/person_info/person_info.py @@ -96,12 +96,13 @@ class PersonInfoManager: # 在此处打一个补丁,如果platform为qq,尝试生成id后检查是否存在,如果不存在,则将平台换为napcat后再次检查,如果存在,则更新原id为platform为qq的id components = [platform, str(user_id)] key = "_".join(components) - + # 如果不是 qq 平台,直接返回计算的 id if platform != "qq": return hashlib.md5(key.encode()).hexdigest() qq_id = hashlib.md5(key.encode()).hexdigest() + # 对于 qq 平台,先检查该 person_id 是否已存在;如果存在直接返回 def _db_check_and_migrate_sync(p_id: str, raw_user_id: str): try: @@ -191,16 +192,16 @@ class PersonInfoManager: # Ensure person_id is correctly set from the argument final_data["person_id"] = person_id # 你们的英文注释是何意味? - + # 检查并修复关键字段为None的情况喵 if final_data.get("user_id") is None: logger.warning(f"user_id为None,使用'unknown'作为默认值 person_id={person_id}") final_data["user_id"] = "unknown" - + if final_data.get("platform") is None: logger.warning(f"platform为None,使用'unknown'作为默认值 person_id={person_id}") final_data["platform"] = "unknown" - + # 这里的目的是为了防止在识别出错的情况下有一个最小回退,不只是针对@消息识别成视频后的报错问题 # Serialize JSON fields @@ -251,12 +252,12 @@ class PersonInfoManager: # Ensure person_id is correctly set from the argument final_data["person_id"] = person_id - + # 检查并修复关键字段为None的情况 if final_data.get("user_id") is None: logger.warning(f"user_id为None,使用'unknown'作为默认值 person_id={person_id}") final_data["user_id"] = "unknown" - + if final_data.get("platform") is None: logger.warning(f"platform为None,使用'unknown'作为默认值 person_id={person_id}") final_data["platform"] = "unknown" @@ -356,12 +357,12 @@ class PersonInfoManager: creation_data["platform"] = data["platform"] if data and "user_id" in data: creation_data["user_id"] = data["user_id"] - + # 额外检查关键字段,如果为None则使用默认值 if creation_data.get("user_id") is None: logger.warning(f"创建用户时user_id为None,使用'unknown'作为默认值 person_id={person_id}") creation_data["user_id"] = "unknown" - + if creation_data.get("platform") is None: logger.warning(f"创建用户时platform为None,使用'unknown'作为默认值 person_id={person_id}") creation_data["platform"] = "unknown" diff --git a/src/person_info/relationship_fetcher.py b/src/person_info/relationship_fetcher.py index 552d0878c..89632dd73 100644 --- a/src/person_info/relationship_fetcher.py +++ b/src/person_info/relationship_fetcher.py @@ -123,7 +123,9 @@ class RelationshipFetcher: all_points = current_points + forgotten_points if all_points: # 按权重和时效性综合排序 - all_points.sort(key=lambda x: (float(x[1]) if len(x) > 1 else 0, float(x[2]) if len(x) > 2 else 0), reverse=True) + all_points.sort( + key=lambda x: (float(x[1]) if len(x) > 1 else 0, float(x[2]) if len(x) > 2 else 0), reverse=True + ) selected_points = all_points[:points_num] points_text = "\n".join([f"- {point[0]}({point[2]})" for point in selected_points if len(point) > 2]) else: @@ -139,15 +141,17 @@ class RelationshipFetcher: # 2. 认识时间和频率 if know_since: from datetime import datetime - know_time = datetime.fromtimestamp(know_since).strftime('%Y年%m月%d日') + + know_time = datetime.fromtimestamp(know_since).strftime("%Y年%m月%d日") relation_parts.append(f"你从{know_time}开始认识{person_name}") - + if know_times > 0: relation_parts.append(f"你们已经交流过{int(know_times)}次") - + if last_know: from datetime import datetime - last_time = datetime.fromtimestamp(last_know).strftime('%m月%d日') + + last_time = datetime.fromtimestamp(last_know).strftime("%m月%d日") relation_parts.append(f"最近一次交流是在{last_time}") # 3. 态度和印象 @@ -156,7 +160,7 @@ class RelationshipFetcher: if short_impression: relation_parts.append(f"你对ta的总体印象:{short_impression}") - + if full_impression: relation_parts.append(f"更详细的了解:{full_impression}") @@ -168,14 +172,14 @@ class RelationshipFetcher: try: from src.common.database.sqlalchemy_database_api import db_query from src.common.database.sqlalchemy_models import UserRelationships - + # 查询用户关系数据 relationships = await db_query( UserRelationships, filters=[UserRelationships.user_id == str(person_info_manager.get_value_sync(person_id, "user_id"))], - limit=1 + limit=1, ) - + if relationships: rel_data = relationships[0] if rel_data.relationship_text: @@ -183,13 +187,15 @@ class RelationshipFetcher: if rel_data.relationship_score: score_desc = self._get_relationship_score_description(rel_data.relationship_score) relation_parts.append(f"关系亲密程度:{score_desc}") - + except Exception as e: logger.debug(f"查询UserRelationships表失败: {e}") # 构建最终的关系信息字符串 if relation_parts: - relation_info = f"关于{person_name},你知道以下信息:\n" + "\n".join([f"• {part}" for part in relation_parts]) + relation_info = f"关于{person_name},你知道以下信息:\n" + "\n".join( + [f"• {part}" for part in relation_parts] + ) else: relation_info = f"你对{person_name}了解不多,这是比较初步的交流。" diff --git a/src/plugin_system/base/base_action.py b/src/plugin_system/base/base_action.py index 9400032f8..725619adb 100644 --- a/src/plugin_system/base/base_action.py +++ b/src/plugin_system/base/base_action.py @@ -93,7 +93,6 @@ class BaseAction(ABC): self.associated_types: list[str] = getattr(self.__class__, "associated_types", []).copy() self.chat_type_allow: ChatType = getattr(self.__class__, "chat_type_allow", ChatType.ALL) - # ============================================================================= # 便捷属性 - 直接在初始化时获取常用聊天信息(带类型注解) # ============================================================================= @@ -398,6 +397,7 @@ class BaseAction(ABC): try: # 1. 从注册中心获取Action类 from src.plugin_system.core.component_registry import component_registry + action_class = component_registry.get_component_class(action_name, ComponentType.ACTION) if not action_class: logger.error(f"{log_prefix} 未找到Action: {action_name}") @@ -406,7 +406,7 @@ class BaseAction(ABC): # 2. 准备实例化参数 # 复用当前Action的大部分上下文信息 called_action_data = action_data if action_data is not None else self.action_data - + component_info = component_registry.get_component_info(action_name, ComponentType.ACTION) if not component_info: logger.warning(f"{log_prefix} 未找到Action组件信息: {action_name}") diff --git a/src/plugin_system/base/base_events_handler.py b/src/plugin_system/base/base_events_handler.py index 6b8ed1d73..517de92c2 100644 --- a/src/plugin_system/base/base_events_handler.py +++ b/src/plugin_system/base/base_events_handler.py @@ -98,7 +98,7 @@ class BaseEventHandler(ABC): weight=cls.weight, intercept_message=cls.intercept_message, ) - + def set_plugin_name(self, plugin_name: str) -> None: """设置插件名称 @@ -107,9 +107,9 @@ class BaseEventHandler(ABC): """ self.plugin_name = plugin_name - def set_plugin_config(self,plugin_config) -> None: + def set_plugin_config(self, plugin_config) -> None: self.plugin_config = plugin_config - + def get_config(self, key: str, default=None): """获取插件配置值,支持嵌套键访问 diff --git a/src/plugin_system/base/component_types.py b/src/plugin_system/base/component_types.py index 0bcb0060e..a939d0ab5 100644 --- a/src/plugin_system/base/component_types.py +++ b/src/plugin_system/base/component_types.py @@ -69,7 +69,7 @@ class EventType(Enum): """ ON_START = "on_start" # 启动事件,用于调用按时任务 - ON_STOP ="on_stop" + ON_STOP = "on_stop" ON_MESSAGE = "on_message" ON_PLAN = "on_plan" POST_LLM = "post_llm" diff --git a/src/plugin_system/core/component_registry.py b/src/plugin_system/core/component_registry.py index 529f327a3..b782a9292 100644 --- a/src/plugin_system/core/component_registry.py +++ b/src/plugin_system/core/component_registry.py @@ -270,7 +270,9 @@ class ComponentRegistry: # 使用EventManager进行事件处理器注册 from src.plugin_system.core.event_manager import event_manager - return event_manager.register_event_handler(handler_class,self.get_plugin_config(handler_info.plugin_name) or {}) + return event_manager.register_event_handler( + handler_class, self.get_plugin_config(handler_info.plugin_name) or {} + ) # === 组件移除相关 === @@ -682,19 +684,20 @@ class ComponentRegistry: plugin_instance = plugin_manager.get_plugin_instance(plugin_name) if plugin_instance and plugin_instance.config: return plugin_instance.config - + # 如果插件实例不存在,尝试从配置文件读取 try: import toml + config_path = Path("config") / "plugins" / plugin_name / "config.toml" if config_path.exists(): - with open(config_path, 'r', encoding='utf-8') as f: + with open(config_path, "r", encoding="utf-8") as f: config_data = toml.load(f) logger.debug(f"从配置文件读取插件 {plugin_name} 的配置") return config_data except Exception as e: logger.debug(f"读取插件 {plugin_name} 配置文件失败: {e}") - + return {} def get_registry_stats(self) -> Dict[str, Any]: diff --git a/src/plugin_system/core/event_manager.py b/src/plugin_system/core/event_manager.py index f359409af..aaf3f3dc2 100644 --- a/src/plugin_system/core/event_manager.py +++ b/src/plugin_system/core/event_manager.py @@ -145,7 +145,9 @@ class EventManager: logger.info(f"事件 {event_name} 已禁用") return True - def register_event_handler(self, handler_class: Type[BaseEventHandler], plugin_config: Optional[dict] = None) -> bool: + def register_event_handler( + self, handler_class: Type[BaseEventHandler], plugin_config: Optional[dict] = None + ) -> bool: """注册事件处理器 Args: @@ -167,7 +169,7 @@ class EventManager: # 创建事件处理器实例,传递插件配置 handler_instance = handler_class() handler_instance.plugin_config = plugin_config - if plugin_config is not None and hasattr(handler_instance, 'set_plugin_config'): + if plugin_config is not None and hasattr(handler_instance, "set_plugin_config"): handler_instance.set_plugin_config(plugin_config) self._event_handlers[handler_name] = handler_instance diff --git a/src/plugin_system/core/plugin_manager.py b/src/plugin_system/core/plugin_manager.py index 07d33b773..05bb8bf1b 100644 --- a/src/plugin_system/core/plugin_manager.py +++ b/src/plugin_system/core/plugin_manager.py @@ -199,9 +199,7 @@ class PluginManager: self._show_plugin_components(plugin_name) # 检查并调用 on_plugin_loaded 钩子(如果存在) - if hasattr(plugin_instance, "on_plugin_loaded") and callable( - plugin_instance.on_plugin_loaded - ): + if hasattr(plugin_instance, "on_plugin_loaded") and callable(plugin_instance.on_plugin_loaded): logger.debug(f"为插件 '{plugin_name}' 调用 on_plugin_loaded 钩子") try: # 使用 asyncio.create_task 确保它不会阻塞加载流程 diff --git a/src/plugins/built_in/at_user_plugin/plugin.py b/src/plugins/built_in/at_user_plugin/plugin.py index ba40903cd..820b37a27 100644 --- a/src/plugins/built_in/at_user_plugin/plugin.py +++ b/src/plugins/built_in/at_user_plugin/plugin.py @@ -64,50 +64,50 @@ class AtAction(BaseAction): # 使用回复器生成艾特回复,而不是直接发送命令 from src.chat.replyer.default_generator import DefaultReplyer from src.chat.message_receive.chat_stream import get_chat_manager - + # 获取当前聊天流 chat_manager = get_chat_manager() chat_stream = self.chat_stream or chat_manager.get_stream(self.chat_id) - + if not chat_stream: logger.error(f"找不到聊天流: {self.chat_stream}") return False, "聊天流不存在" - + # 创建回复器实例 replyer = DefaultReplyer(chat_stream) - + # 构建回复对象,将艾特消息作为回复目标 reply_to = f"{user_name}:{at_message}" extra_info = f"你需要艾特用户 {user_name} 并回复他们说: {at_message}" - + # 使用回复器生成回复 success, llm_response, prompt = await replyer.generate_reply_with_context( reply_to=reply_to, extra_info=extra_info, enable_tool=False, # 艾特回复通常不需要工具调用 - from_plugin=False + from_plugin=False, ) - + if success and llm_response: # 获取生成的回复内容 reply_content = llm_response.get("content", "") if reply_content: # 获取用户QQ号,发送真正的艾特消息 user_id = user_info.get("user_id") - + # 发送真正的艾特命令,使用回复器生成的智能内容 await self.send_command( "SEND_AT_MESSAGE", args={"qq_id": user_id, "text": reply_content}, display_message=f"艾特用户 {user_name} 并发送智能回复: {reply_content}", ) - + await self.store_action_info( action_build_into_prompt=True, action_prompt_display=f"执行了艾特用户动作:艾特用户 {user_name} 并发送智能回复: {reply_content}", action_done=True, ) - + logger.info(f"成功通过回复器生成智能内容并发送真正的艾特消息给 {user_name}: {reply_content}") return True, "智能艾特消息发送成功" else: @@ -116,7 +116,7 @@ class AtAction(BaseAction): else: logger.error("回复器生成回复失败") return False, "回复生成失败" - + except Exception as e: logger.error(f"执行艾特用户动作时发生异常: {e}", exc_info=True) await self.store_action_info( diff --git a/src/plugins/built_in/core_actions/emoji.py b/src/plugins/built_in/core_actions/emoji.py index 4375ae1a2..fe03f4478 100644 --- a/src/plugins/built_in/core_actions/emoji.py +++ b/src/plugins/built_in/core_actions/emoji.py @@ -70,7 +70,9 @@ class EmojiAction(BaseAction): # 2. 获取所有有效的表情包对象 emoji_manager = get_emoji_manager() - all_emojis_obj: list[MaiEmoji] = [e for e in emoji_manager.emoji_objects if not e.is_deleted and e.description] + all_emojis_obj: list[MaiEmoji] = [ + e for e in emoji_manager.emoji_objects if not e.is_deleted and e.description + ] if not all_emojis_obj: logger.warning(f"{self.log_prefix} 无法获取任何带有描述的有效表情包") return False, "无法获取任何带有描述的有效表情包" @@ -91,12 +93,12 @@ class EmojiAction(BaseAction): # 4. 准备情感数据和后备列表 emotion_map = {} all_emojis_data = [] - + for emoji in all_emojis_obj: b64 = image_path_to_base64(emoji.full_path) if not b64: continue - + desc = emoji.description emotions = emoji.emotion all_emojis_data.append((b64, desc)) @@ -168,16 +170,18 @@ class EmojiAction(BaseAction): # 使用模糊匹配来查找最相关的情感标签 matched_key = next((key for key in emotion_map if chosen_emotion in key), None) - + if matched_key: emoji_base64, emoji_description = random.choice(emotion_map[matched_key]) - logger.info(f"{self.log_prefix} 找到匹配情感 '{chosen_emotion}' (匹配到: '{matched_key}') 的表情包: {emoji_description}") + logger.info( + f"{self.log_prefix} 找到匹配情感 '{chosen_emotion}' (匹配到: '{matched_key}') 的表情包: {emoji_description}" + ) else: logger.warning( f"{self.log_prefix} LLM选择的情感 '{chosen_emotion}' 不在可用列表中, 将随机选择一个表情包" ) emoji_base64, emoji_description = random.choice(all_emojis_data) - + elif global_config.emoji.emoji_selection_mode == "description": # --- 详细描述选择模式 --- # 获取最近的5条消息内容用于判断 @@ -226,15 +230,23 @@ class EmojiAction(BaseAction): logger.info(f"{self.log_prefix} LLM选择的描述: {chosen_description}") # 简单关键词匹配 - matched_emoji = next((item for item in all_emojis_data if chosen_description.lower() in item[1].lower() or item[1].lower() in chosen_description.lower()), None) - + matched_emoji = next( + ( + item + for item in all_emojis_data + if chosen_description.lower() in item[1].lower() + or item[1].lower() in chosen_description.lower() + ), + None, + ) + # 如果包含匹配失败,尝试关键词匹配 if not matched_emoji: - keywords = ['惊讶', '困惑', '呆滞', '震惊', '懵', '无语', '萌', '可爱'] + keywords = ["惊讶", "困惑", "呆滞", "震惊", "懵", "无语", "萌", "可爱"] for keyword in keywords: if keyword in chosen_description: for item in all_emojis_data: - if any(k in item[1] for k in ['呆', '萌', '惊', '困惑', '无语']): + if any(k in item[1] for k in ["呆", "萌", "惊", "困惑", "无语"]): matched_emoji = item break if matched_emoji: @@ -255,7 +267,9 @@ class EmojiAction(BaseAction): if not success: logger.error(f"{self.log_prefix} 表情包发送失败") - await self.store_action_info(action_build_into_prompt = True,action_prompt_display =f"发送了一个表情包,但失败了",action_done= False) + await self.store_action_info( + action_build_into_prompt=True, action_prompt_display=f"发送了一个表情包,但失败了", action_done=False + ) return False, "表情包发送失败" # 发送成功后,记录到历史 @@ -263,8 +277,10 @@ class EmojiAction(BaseAction): add_emoji_to_history(self.chat_id, emoji_description) except Exception as e: logger.error(f"{self.log_prefix} 添加表情到历史记录时出错: {e}") - - await self.store_action_info(action_build_into_prompt = True,action_prompt_display =f"发送了一个表情包",action_done= True) + + await self.store_action_info( + action_build_into_prompt=True, action_prompt_display=f"发送了一个表情包", action_done=True + ) return True, f"发送表情包: {emoji_description}" diff --git a/src/plugins/built_in/napcat_adapter_plugin/event_handlers.py b/src/plugins/built_in/napcat_adapter_plugin/event_handlers.py index c4f889712..9fe6f8096 100644 --- a/src/plugins/built_in/napcat_adapter_plugin/event_handlers.py +++ b/src/plugins/built_in/napcat_adapter_plugin/event_handlers.py @@ -1,4 +1,3 @@ - from src.plugin_system import BaseEventHandler from src.plugin_system.base.base_event import HandlerResult @@ -1748,6 +1747,7 @@ class SetGroupSignHandler(BaseEventHandler): logger.error("事件 napcat_set_group_sign 请求失败!") return HandlerResult(False, False, {"status": "error"}) + # ===PERSONAL=== class SetInputStatusHandler(BaseEventHandler): handler_name: str = "napcat_set_input_status_handler" diff --git a/src/plugins/built_in/napcat_adapter_plugin/plugin.py b/src/plugins/built_in/napcat_adapter_plugin/plugin.py index 952fcaccc..1c1138511 100644 --- a/src/plugins/built_in/napcat_adapter_plugin/plugin.py +++ b/src/plugins/built_in/napcat_adapter_plugin/plugin.py @@ -227,7 +227,7 @@ class LauchNapcatAdapterHandler(BaseEventHandler): await reassembler.start_cleanup_task() logger.info("开始启动Napcat Adapter") - + # 创建单独的异步任务,防止阻塞主线程 asyncio.create_task(self._start_maibot_connection()) asyncio.create_task(napcat_server(self.plugin_config)) @@ -238,10 +238,10 @@ class LauchNapcatAdapterHandler(BaseEventHandler): """非阻塞方式启动MaiBot连接,等待主服务启动后再连接""" # 等待一段时间让MaiBot主服务完全启动 await asyncio.sleep(5) - + max_attempts = 10 attempt = 0 - + while attempt < max_attempts: try: logger.info(f"尝试连接MaiBot (第{attempt + 1}次)") @@ -285,7 +285,7 @@ class NapcatAdapterPlugin(BasePlugin): def enable_plugin(self) -> bool: """通过配置文件动态控制插件启用状态""" # 如果已经通过配置加载了状态,使用配置中的值 - if hasattr(self, '_is_enabled'): + if hasattr(self, "_is_enabled"): return self._is_enabled # 否则使用默认值(禁用状态) return False @@ -308,60 +308,107 @@ class NapcatAdapterPlugin(BasePlugin): "nickname": ConfigField(type=str, default="", description="昵称配置(目前未使用)"), }, "napcat_server": { - "mode": ConfigField(type=str, default="reverse", description="连接模式:reverse=反向连接(作为服务器), forward=正向连接(作为客户端)", choices=["reverse", "forward"]), + "mode": ConfigField( + type=str, + default="reverse", + description="连接模式:reverse=反向连接(作为服务器), forward=正向连接(作为客户端)", + choices=["reverse", "forward"], + ), "host": ConfigField(type=str, default="localhost", description="主机地址"), "port": ConfigField(type=int, default=8095, description="端口号"), - "url": ConfigField(type=str, default="", description="正向连接时的完整WebSocket URL,如 ws://localhost:8080/ws (仅在forward模式下使用)"), - "access_token": ConfigField(type=str, default="", description="WebSocket 连接的访问令牌,用于身份验证(可选)"), + "url": ConfigField( + type=str, + default="", + description="正向连接时的完整WebSocket URL,如 ws://localhost:8080/ws (仅在forward模式下使用)", + ), + "access_token": ConfigField( + type=str, default="", description="WebSocket 连接的访问令牌,用于身份验证(可选)" + ), "heartbeat_interval": ConfigField(type=int, default=30, description="心跳间隔时间(按秒计)"), }, "maibot_server": { - "host": ConfigField(type=str, default="localhost", description="麦麦在.env文件中设置的主机地址,即HOST字段"), + "host": ConfigField( + type=str, default="localhost", description="麦麦在.env文件中设置的主机地址,即HOST字段" + ), "port": ConfigField(type=int, default=8000, description="麦麦在.env文件中设置的端口,即PORT字段"), "platform_name": ConfigField(type=str, default="qq", description="平台名称,用于消息路由"), }, "voice": { - "use_tts": ConfigField(type=bool, default=False, description="是否使用tts语音(请确保你配置了tts并有对应的adapter)"), + "use_tts": ConfigField( + type=bool, default=False, description="是否使用tts语音(请确保你配置了tts并有对应的adapter)" + ), }, "slicing": { - "max_frame_size": ConfigField(type=int, default=64, description="WebSocket帧的最大大小,单位为字节,默认64KB"), + "max_frame_size": ConfigField( + type=int, default=64, description="WebSocket帧的最大大小,单位为字节,默认64KB" + ), "delay_ms": ConfigField(type=int, default=10, description="切片发送间隔时间,单位为毫秒"), }, "debug": { - "level": ConfigField(type=str, default="INFO", description="日志等级(DEBUG, INFO, WARNING, ERROR, CRITICAL)", choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]), + "level": ConfigField( + type=str, + default="INFO", + description="日志等级(DEBUG, INFO, WARNING, ERROR, CRITICAL)", + choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"], + ), }, "features": { # 权限设置 - "group_list_type": ConfigField(type=str, default="blacklist", description="群聊列表类型:whitelist(白名单)或 blacklist(黑名单)", choices=["whitelist", "blacklist"]), + "group_list_type": ConfigField( + type=str, + default="blacklist", + description="群聊列表类型:whitelist(白名单)或 blacklist(黑名单)", + choices=["whitelist", "blacklist"], + ), "group_list": ConfigField(type=list, default=[], description="群聊ID列表"), - "private_list_type": ConfigField(type=str, default="blacklist", description="私聊列表类型:whitelist(白名单)或 blacklist(黑名单)", choices=["whitelist", "blacklist"]), + "private_list_type": ConfigField( + type=str, + default="blacklist", + description="私聊列表类型:whitelist(白名单)或 blacklist(黑名单)", + choices=["whitelist", "blacklist"], + ), "private_list": ConfigField(type=list, default=[], description="用户ID列表"), - "ban_user_id": ConfigField(type=list, default=[], description="全局禁止用户ID列表,这些用户无法在任何地方使用机器人"), + "ban_user_id": ConfigField( + type=list, default=[], description="全局禁止用户ID列表,这些用户无法在任何地方使用机器人" + ), "ban_qq_bot": ConfigField(type=bool, default=False, description="是否屏蔽QQ官方机器人消息"), - # 聊天功能设置 "enable_poke": ConfigField(type=bool, default=True, description="是否启用戳一戳功能"), "ignore_non_self_poke": ConfigField(type=bool, default=False, description="是否无视不是针对自己的戳一戳"), - "poke_debounce_seconds": ConfigField(type=int, default=3, description="戳一戳防抖时间(秒),在指定时间内第二次针对机器人的戳一戳将被忽略"), + "poke_debounce_seconds": ConfigField( + type=int, default=3, description="戳一戳防抖时间(秒),在指定时间内第二次针对机器人的戳一戳将被忽略" + ), "enable_reply_at": ConfigField(type=bool, default=True, description="是否启用引用回复时艾特用户的功能"), "reply_at_rate": ConfigField(type=float, default=0.5, description="引用回复时艾特用户的几率 (0.0 ~ 1.0)"), "enable_emoji_like": ConfigField(type=bool, default=True, description="是否启用群聊表情回复功能"), - # 视频处理设置 "enable_video_analysis": ConfigField(type=bool, default=True, description="是否启用视频识别功能"), "max_video_size_mb": ConfigField(type=int, default=100, description="视频文件最大大小限制(MB)"), "download_timeout": ConfigField(type=int, default=60, description="视频下载超时时间(秒)"), - "supported_formats": ConfigField(type=list, default=["mp4", "avi", "mov", "mkv", "flv", "wmv", "webm"], description="支持的视频格式"), - + "supported_formats": ConfigField( + type=list, default=["mp4", "avi", "mov", "mkv", "flv", "wmv", "webm"], description="支持的视频格式" + ), # 消息缓冲设置 "enable_message_buffer": ConfigField(type=bool, default=True, description="是否启用消息缓冲合并功能"), "message_buffer_enable_group": ConfigField(type=bool, default=True, description="是否启用群聊消息缓冲合并"), - "message_buffer_enable_private": ConfigField(type=bool, default=True, description="是否启用私聊消息缓冲合并"), - "message_buffer_interval": ConfigField(type=float, default=3.0, description="消息合并间隔时间(秒),在此时间内的连续消息将被合并"), - "message_buffer_initial_delay": ConfigField(type=float, default=0.5, description="消息缓冲初始延迟(秒),收到第一条消息后等待此时间开始合并"), - "message_buffer_max_components": ConfigField(type=int, default=50, description="单个会话最大缓冲消息组件数量,超过此数量将强制合并"), - "message_buffer_block_prefixes": ConfigField(type=list, default=["/", "!", "!", ".", "。", "#", "%"], description="消息缓冲屏蔽前缀,以这些前缀开头的消息不会被缓冲"), - } + "message_buffer_enable_private": ConfigField( + type=bool, default=True, description="是否启用私聊消息缓冲合并" + ), + "message_buffer_interval": ConfigField( + type=float, default=3.0, description="消息合并间隔时间(秒),在此时间内的连续消息将被合并" + ), + "message_buffer_initial_delay": ConfigField( + type=float, default=0.5, description="消息缓冲初始延迟(秒),收到第一条消息后等待此时间开始合并" + ), + "message_buffer_max_components": ConfigField( + type=int, default=50, description="单个会话最大缓冲消息组件数量,超过此数量将强制合并" + ), + "message_buffer_block_prefixes": ConfigField( + type=list, + default=["/", "!", "!", ".", "。", "#", "%"], + description="消息缓冲屏蔽前缀,以这些前缀开头的消息不会被缓冲", + ), + }, } # 配置节描述 @@ -374,7 +421,7 @@ class NapcatAdapterPlugin(BasePlugin): "voice": "发送语音设置", "slicing": "WebSocket消息切片设置", "debug": "调试设置", - "features": "功能设置(权限控制、聊天功能、视频处理、消息缓冲等)" + "features": "功能设置(权限控制、聊天功能、视频处理、消息缓冲等)", } def register_events(self): @@ -409,6 +456,7 @@ class NapcatAdapterPlugin(BasePlugin): chunker.set_plugin_config(self.config) # 设置response_pool的插件配置 from .src.response_pool import set_plugin_config as set_response_pool_config + set_response_pool_config(self.config) # 设置send_handler的插件配置 send_handler.set_plugin_config(self.config) @@ -418,4 +466,4 @@ class NapcatAdapterPlugin(BasePlugin): notice_handler.set_plugin_config(self.config) # 设置meta_event_handler的插件配置 meta_event_handler.set_plugin_config(self.config) - # 设置其他handler的插件配置(现在由component_registry在注册时自动设置) \ No newline at end of file + # 设置其他handler的插件配置(现在由component_registry在注册时自动设置) diff --git a/src/plugins/built_in/napcat_adapter_plugin/src/message_buffer.py b/src/plugins/built_in/napcat_adapter_plugin/src/message_buffer.py index 64a1e3faa..2bfe9078d 100644 --- a/src/plugins/built_in/napcat_adapter_plugin/src/message_buffer.py +++ b/src/plugins/built_in/napcat_adapter_plugin/src/message_buffer.py @@ -102,7 +102,9 @@ class SimpleMessageBuffer: return True # 检查屏蔽前缀 - block_prefixes = tuple(config_api.get_plugin_config(self.plugin_config, "features.message_buffer_block_prefixes", [])) + block_prefixes = tuple( + config_api.get_plugin_config(self.plugin_config, "features.message_buffer_block_prefixes", []) + ) text = text.strip() if text.startswith(block_prefixes): @@ -134,9 +136,13 @@ class SimpleMessageBuffer: # 检查是否启用对应类型的缓冲 message_type = event_data.get("message_type", "") - if message_type == "group" and not config_api.get_plugin_config(self.plugin_config, "features.message_buffer_enable_group", False): + if message_type == "group" and not config_api.get_plugin_config( + self.plugin_config, "features.message_buffer_enable_group", False + ): return False - elif message_type == "private" and not config_api.get_plugin_config(self.plugin_config, "features.message_buffer_enable_private", False): + elif message_type == "private" and not config_api.get_plugin_config( + self.plugin_config, "features.message_buffer_enable_private", False + ): return False # 提取文本 @@ -158,7 +164,9 @@ class SimpleMessageBuffer: session = self.buffer_pool[session_id] # 检查是否超过最大组件数量 - if len(session.messages) >= config_api.get_plugin_config(self.plugin_config, "features.message_buffer_max_components", 5): + if len(session.messages) >= config_api.get_plugin_config( + self.plugin_config, "features.message_buffer_max_components", 5 + ): logger.debug(f"会话 {session_id} 消息数量达到上限,强制合并") asyncio.create_task(self._force_merge_session(session_id)) self.buffer_pool[session_id] = BufferedSession(session_id=session_id, original_event=original_event) diff --git a/src/plugins/built_in/napcat_adapter_plugin/src/mmc_com_layer.py b/src/plugins/built_in/napcat_adapter_plugin/src/mmc_com_layer.py index 655fff64c..acd12fe01 100644 --- a/src/plugins/built_in/napcat_adapter_plugin/src/mmc_com_layer.py +++ b/src/plugins/built_in/napcat_adapter_plugin/src/mmc_com_layer.py @@ -14,7 +14,7 @@ def create_router(plugin_config: dict): platform_name = config_api.get_plugin_config(plugin_config, "maibot_server.platform_name", "qq") host = config_api.get_plugin_config(plugin_config, "maibot_server.host", "localhost") port = config_api.get_plugin_config(plugin_config, "maibot_server.port", 8000) - + route_config = RouteConfig( route_config={ platform_name: TargetConfig( @@ -32,7 +32,7 @@ async def mmc_start_com(plugin_config: dict = None): logger.info("正在连接MaiBot") if plugin_config: create_router(plugin_config) - + if router: router.register_class_handler(send_handler.handle_message) await router.run() diff --git a/src/plugins/built_in/napcat_adapter_plugin/src/recv_handler/__init__.py b/src/plugins/built_in/napcat_adapter_plugin/src/recv_handler/__init__.py index 48561ffbe..231c0ce39 100644 --- a/src/plugins/built_in/napcat_adapter_plugin/src/recv_handler/__init__.py +++ b/src/plugins/built_in/napcat_adapter_plugin/src/recv_handler/__init__.py @@ -32,7 +32,7 @@ class NoticeType: # 通知事件 group_recall = "group_recall" # 群聊消息撤回 notify = "notify" group_ban = "group_ban" # 群禁言 - group_msg_emoji_like = "group_msg_emoji_like" # 群聊表情回复 + group_msg_emoji_like = "group_msg_emoji_like" # 群聊表情回复 class Notify: poke = "poke" # 戳一戳 diff --git a/src/plugins/built_in/napcat_adapter_plugin/src/recv_handler/message_handler.py b/src/plugins/built_in/napcat_adapter_plugin/src/recv_handler/message_handler.py index 0a644345b..88eb48abc 100644 --- a/src/plugins/built_in/napcat_adapter_plugin/src/recv_handler/message_handler.py +++ b/src/plugins/built_in/napcat_adapter_plugin/src/recv_handler/message_handler.py @@ -100,7 +100,7 @@ class MessageHandler: # 检查群聊黑白名单 group_list_type = config_api.get_plugin_config(self.plugin_config, "features.group_list_type", "blacklist") group_list = config_api.get_plugin_config(self.plugin_config, "features.group_list", []) - + if group_list_type == "whitelist": if group_id not in group_list: logger.warning("群聊不在白名单中,消息被丢弃") @@ -111,9 +111,11 @@ class MessageHandler: return False else: # 检查私聊黑白名单 - private_list_type = config_api.get_plugin_config(self.plugin_config, "features.private_list_type", "blacklist") + private_list_type = config_api.get_plugin_config( + self.plugin_config, "features.private_list_type", "blacklist" + ) private_list = config_api.get_plugin_config(self.plugin_config, "features.private_list", []) - + if private_list_type == "whitelist": if user_id not in private_list: logger.warning("私聊不在白名单中,消息被丢弃") @@ -156,21 +158,23 @@ class MessageHandler: Parameters: raw_message: dict: 原始消息 """ - + # 添加原始消息调试日志,特别关注message字段 - logger.debug(f"收到原始消息: message_type={raw_message.get('message_type')}, message_id={raw_message.get('message_id')}") + logger.debug( + f"收到原始消息: message_type={raw_message.get('message_type')}, message_id={raw_message.get('message_id')}" + ) logger.debug(f"原始消息内容: {raw_message.get('message', [])}") - + # 检查是否包含@或video消息段 - message_segments = raw_message.get('message', []) + message_segments = raw_message.get("message", []) if message_segments: for i, seg in enumerate(message_segments): - seg_type = seg.get('type') - if seg_type in ['at', 'video']: + seg_type = seg.get("type") + if seg_type in ["at", "video"]: logger.info(f"检测到 {seg_type.upper()} 消息段 [{i}]: {seg}") - elif seg_type not in ['text', 'face', 'image']: + elif seg_type not in ["text", "face", "image"]: logger.warning(f"检测到特殊消息段 [{i}]: type={seg_type}, data={seg.get('data', {})}") - + message_type: str = raw_message.get("message_type") message_id: int = raw_message.get("message_id") # message_time: int = raw_message.get("time") @@ -308,9 +312,13 @@ class MessageHandler: message_type = raw_message.get("message_type") should_use_buffer = False - if message_type == "group" and config_api.get_plugin_config(self.plugin_config, "features.message_buffer_enable_group", True): + if message_type == "group" and config_api.get_plugin_config( + self.plugin_config, "features.message_buffer_enable_group", True + ): should_use_buffer = True - elif message_type == "private" and config_api.get_plugin_config(self.plugin_config, "features.message_buffer_enable_private", True): + elif message_type == "private" and config_api.get_plugin_config( + self.plugin_config, "features.message_buffer_enable_private", True + ): should_use_buffer = True if should_use_buffer: @@ -368,10 +376,10 @@ class MessageHandler: for sub_message in real_message: sub_message: dict sub_message_type = sub_message.get("type") - + # 添加详细的消息类型调试信息 logger.debug(f"处理消息段: type={sub_message_type}, data={sub_message.get('data', {})}") - + # 特别关注 at 和 video 消息的识别 if sub_message_type == "at": logger.debug(f"检测到@消息: {sub_message}") @@ -379,7 +387,7 @@ class MessageHandler: logger.debug(f"检测到VIDEO消息: {sub_message}") elif sub_message_type not in ["text", "face", "image", "record"]: logger.warning(f"检测到特殊消息类型: {sub_message_type}, 完整消息: {sub_message}") - + match sub_message_type: case RealMessageType.text: ret_seg = await self.handle_text_message(sub_message) diff --git a/src/plugins/built_in/napcat_adapter_plugin/src/recv_handler/message_sending.py b/src/plugins/built_in/napcat_adapter_plugin/src/recv_handler/message_sending.py index b7ca408d9..ade4c7193 100644 --- a/src/plugins/built_in/napcat_adapter_plugin/src/recv_handler/message_sending.py +++ b/src/plugins/built_in/napcat_adapter_plugin/src/recv_handler/message_sending.py @@ -33,6 +33,7 @@ class MessageSending: try: # 重新导入router from ..mmc_com_layer import router + self.maibot_router = router if self.maibot_router is not None: logger.info("MaiBot router重连成功") @@ -73,14 +74,14 @@ class MessageSending: # 获取对应的客户端并发送切片 platform = message_base.message_info.platform - + # 再次检查router状态(防止运行时被重置) - if self.maibot_router is None or not hasattr(self.maibot_router, 'clients'): + if self.maibot_router is None or not hasattr(self.maibot_router, "clients"): logger.warning("MaiBot router连接已断开,尝试重新连接") if not await self._attempt_reconnect(): logger.error("MaiBot router重连失败,切片发送中止") return False - + if platform not in self.maibot_router.clients: logger.error(f"平台 {platform} 未连接") return False diff --git a/src/plugins/built_in/napcat_adapter_plugin/src/recv_handler/meta_event_handler.py b/src/plugins/built_in/napcat_adapter_plugin/src/recv_handler/meta_event_handler.py index 217347c36..7f310fbfa 100644 --- a/src/plugins/built_in/napcat_adapter_plugin/src/recv_handler/meta_event_handler.py +++ b/src/plugins/built_in/napcat_adapter_plugin/src/recv_handler/meta_event_handler.py @@ -22,7 +22,9 @@ class MetaEventHandler: """设置插件配置""" self.plugin_config = plugin_config # 更新interval值 - self.interval = config_api.get_plugin_config(self.plugin_config, "napcat_server.heartbeat_interval", 5000) / 1000 + self.interval = ( + config_api.get_plugin_config(self.plugin_config, "napcat_server.heartbeat_interval", 5000) / 1000 + ) async def handle_meta_event(self, message: dict) -> None: event_type = message.get("meta_event_type") diff --git a/src/plugins/built_in/napcat_adapter_plugin/src/recv_handler/notice_handler.py b/src/plugins/built_in/napcat_adapter_plugin/src/recv_handler/notice_handler.py index c373a9a10..5ea018f4d 100644 --- a/src/plugins/built_in/napcat_adapter_plugin/src/recv_handler/notice_handler.py +++ b/src/plugins/built_in/napcat_adapter_plugin/src/recv_handler/notice_handler.py @@ -116,9 +116,9 @@ class NoticeHandler: sub_type = raw_message.get("sub_type") match sub_type: case NoticeType.Notify.poke: - if config_api.get_plugin_config(self.plugin_config, "features.enable_poke", True) and await message_handler.check_allow_to_chat( - user_id, group_id, False, False - ): + if config_api.get_plugin_config( + self.plugin_config, "features.enable_poke", True + ) and await message_handler.check_allow_to_chat(user_id, group_id, False, False): logger.debug("处理戳一戳消息") handled_message, user_info = await self.handle_poke_notify(raw_message, group_id, user_id) else: @@ -127,14 +127,18 @@ class NoticeHandler: from src.plugin_system.core.event_manager import event_manager from ...event_types import NapcatEvent - await event_manager.trigger_event(NapcatEvent.ON_RECEIVED.FRIEND_INPUT, permission_group=PLUGIN_NAME) + await event_manager.trigger_event( + NapcatEvent.ON_RECEIVED.FRIEND_INPUT, permission_group=PLUGIN_NAME + ) case _: logger.warning(f"不支持的notify类型: {notice_type}.{sub_type}") - case NoticeType.group_msg_emoji_like: + case NoticeType.group_msg_emoji_like: # 该事件转移到 handle_group_emoji_like_notify函数内触发 if config_api.get_plugin_config(self.plugin_config, "features.enable_emoji_like", True): logger.debug("处理群聊表情回复") - handled_message, user_info = await self.handle_group_emoji_like_notify(raw_message,group_id,user_id) + handled_message, user_info = await self.handle_group_emoji_like_notify( + raw_message, group_id, user_id + ) else: logger.warning("群聊表情回复被禁用,取消群聊表情回复处理") case NoticeType.group_ban: @@ -294,7 +298,7 @@ class NoticeHandler: async def handle_group_emoji_like_notify(self, raw_message: dict, group_id: int, user_id: int): if not group_id: logger.error("群ID不能为空,无法处理群聊表情回复通知") - return None, None + return None, None user_qq_info: dict = await get_member_info(self.get_server_connection(), group_id, user_id) if user_qq_info: @@ -304,37 +308,42 @@ class NoticeHandler: user_name = "QQ用户" user_cardname = "QQ用户" logger.debug("无法获取表情回复对方的用户昵称") - + from src.plugin_system.core.event_manager import event_manager from ...event_types import NapcatEvent - target_message = await event_manager.trigger_event(NapcatEvent.MESSAGE.GET_MSG,message_id=raw_message.get("message_id","")) - target_message_text = target_message.get_message_result().get("data",{}).get("raw_message","") + target_message = await event_manager.trigger_event( + NapcatEvent.MESSAGE.GET_MSG, message_id=raw_message.get("message_id", "") + ) + target_message_text = target_message.get_message_result().get("data", {}).get("raw_message", "") if not target_message: logger.error("未找到对应消息") return None, None if len(target_message_text) > 15: target_message_text = target_message_text[:15] + "..." - + user_info: UserInfo = UserInfo( platform=config_api.get_plugin_config(self.plugin_config, "maibot_server.platform_name", "qq"), user_id=user_id, user_nickname=user_name, user_cardname=user_cardname, ) - + like_emoji_id = raw_message.get("likes")[0].get("emoji_id") await event_manager.trigger_event( - NapcatEvent.ON_RECEIVED.EMOJI_LIEK, - permission_group=PLUGIN_NAME, - group_id=group_id, - user_id=user_id, - message_id=raw_message.get("message_id",""), - emoji_id=like_emoji_id - ) - seg_data = Seg(type="text",data=f"{user_name}使用Emoji表情{QQ_FACE.get(like_emoji_id,"")}回复了你的消息[{target_message_text}]") + NapcatEvent.ON_RECEIVED.EMOJI_LIEK, + permission_group=PLUGIN_NAME, + group_id=group_id, + user_id=user_id, + message_id=raw_message.get("message_id", ""), + emoji_id=like_emoji_id, + ) + seg_data = Seg( + type="text", + data=f"{user_name}使用Emoji表情{QQ_FACE.get(like_emoji_id, '')}回复了你的消息[{target_message_text}]", + ) return seg_data, user_info - + async def handle_ban_notify(self, raw_message: dict, group_id: int) -> Tuple[Seg, UserInfo] | Tuple[None, None]: if not group_id: logger.error("群ID不能为空,无法处理禁言通知") diff --git a/src/plugins/built_in/napcat_adapter_plugin/src/response_pool.py b/src/plugins/built_in/napcat_adapter_plugin/src/response_pool.py index 3e8e5c4a4..7ba313af5 100644 --- a/src/plugins/built_in/napcat_adapter_plugin/src/response_pool.py +++ b/src/plugins/built_in/napcat_adapter_plugin/src/response_pool.py @@ -45,12 +45,12 @@ async def check_timeout_response() -> None: while True: cleaned_message_count: int = 0 now_time = time.time() - + # 获取心跳间隔配置 heartbeat_interval = 30 # 默认值 if plugin_config: heartbeat_interval = config_api.get_plugin_config(plugin_config, "napcat_server.heartbeat_interval", 30) - + for echo_id, response_time in list(response_time_dict.items()): if now_time - response_time > heartbeat_interval: cleaned_message_count += 1 diff --git a/src/plugins/built_in/napcat_adapter_plugin/src/send_handler.py b/src/plugins/built_in/napcat_adapter_plugin/src/send_handler.py index 0bb7435ee..40e144821 100644 --- a/src/plugins/built_in/napcat_adapter_plugin/src/send_handler.py +++ b/src/plugins/built_in/napcat_adapter_plugin/src/send_handler.py @@ -297,9 +297,9 @@ class SendHandler: try: # 检查是否为缓冲消息ID(格式:buffered-{original_id}-{timestamp}) - if id.startswith('buffered-'): + if id.startswith("buffered-"): # 从缓冲消息ID中提取原始消息ID - original_id = id.split('-')[1] + original_id = id.split("-")[1] msg_info_response = await self.send_message_to_napcat("get_msg", {"message_id": int(original_id)}) else: msg_info_response = await self.send_message_to_napcat("get_msg", {"message_id": int(id)}) @@ -363,7 +363,7 @@ class SendHandler: use_tts = False if self.plugin_config: use_tts = config_api.get_plugin_config(self.plugin_config, "voice.use_tts", False) - + if not use_tts: logger.warning("未启用语音消息处理") return {} diff --git a/src/plugins/built_in/napcat_adapter_plugin/src/websocket_manager.py b/src/plugins/built_in/napcat_adapter_plugin/src/websocket_manager.py index 484b9b59e..0ef55a70f 100644 --- a/src/plugins/built_in/napcat_adapter_plugin/src/websocket_manager.py +++ b/src/plugins/built_in/napcat_adapter_plugin/src/websocket_manager.py @@ -18,7 +18,9 @@ class WebSocketManager: self.max_reconnect_attempts = 10 # 最大重连次数 self.plugin_config = None - async def start_connection(self, message_handler: Callable[[Server.ServerConnection], Any], plugin_config: dict) -> None: + async def start_connection( + self, message_handler: Callable[[Server.ServerConnection], Any], plugin_config: dict + ) -> None: """根据配置启动 WebSocket 连接""" self.plugin_config = plugin_config mode = config_api.get_plugin_config(plugin_config, "napcat_server.mode") @@ -72,9 +74,7 @@ class WebSocketManager: # 如果配置了访问令牌,添加到请求头 access_token = config_api.get_plugin_config(self.plugin_config, "napcat_server.access_token") if access_token: - connect_kwargs["additional_headers"] = { - "Authorization": f"Bearer {access_token}" - } + connect_kwargs["additional_headers"] = {"Authorization": f"Bearer {access_token}"} logger.info("已添加访问令牌到连接请求头") async with Server.connect(url, **connect_kwargs) as websocket: diff --git a/src/plugins/built_in/web_search_tool/engines/base.py b/src/plugins/built_in/web_search_tool/engines/base.py index f7641aa2f..30d20a540 100644 --- a/src/plugins/built_in/web_search_tool/engines/base.py +++ b/src/plugins/built_in/web_search_tool/engines/base.py @@ -1,6 +1,7 @@ """ Base search engine interface """ + from abc import ABC, abstractmethod from typing import Dict, List, Any @@ -9,20 +10,20 @@ class BaseSearchEngine(ABC): """ 搜索引擎基类 """ - + @abstractmethod async def search(self, args: Dict[str, Any]) -> List[Dict[str, Any]]: """ 执行搜索 - + Args: args: 搜索参数,包含 query、num_results、time_range 等 - + Returns: 搜索结果列表,每个结果包含 title、url、snippet、provider 字段 """ pass - + @abstractmethod def is_available(self) -> bool: """ diff --git a/src/plugins/built_in/web_search_tool/engines/bing_engine.py b/src/plugins/built_in/web_search_tool/engines/bing_engine.py index ac90956e0..c779ed39c 100644 --- a/src/plugins/built_in/web_search_tool/engines/bing_engine.py +++ b/src/plugins/built_in/web_search_tool/engines/bing_engine.py @@ -1,6 +1,7 @@ """ Bing search engine implementation """ + import asyncio import functools import random @@ -58,21 +59,21 @@ class BingSearchEngine(BaseSearchEngine): """ Bing搜索引擎实现 """ - + def __init__(self): self.session = requests.Session() self.session.headers = HEADERS - + def is_available(self) -> bool: """检查Bing搜索引擎是否可用""" return True # Bing是免费搜索引擎,总是可用 - + async def search(self, args: Dict[str, Any]) -> List[Dict[str, Any]]: """执行Bing搜索""" query = args["query"] num_results = args.get("num_results", 3) time_range = args.get("time_range", "any") - + try: loop = asyncio.get_running_loop() func = functools.partial(self._search_sync, query, num_results, time_range) @@ -81,17 +82,17 @@ class BingSearchEngine(BaseSearchEngine): except Exception as e: logger.error(f"Bing 搜索失败: {e}") return [] - + def _search_sync(self, keyword: str, num_results: int, time_range: str) -> List[Dict[str, Any]]: """同步执行Bing搜索""" if not keyword: return [] list_result = [] - + # 构建搜索URL search_url = bing_search_url + keyword - + # 如果指定了时间范围,添加时间过滤参数 if time_range == "week": search_url += "&qft=+filterui:date-range-7" @@ -181,34 +182,29 @@ class BingSearchEngine(BaseSearchEngine): # 尝试提取搜索结果 # 方法1: 查找标准的搜索结果容器 results = root.select("ol#b_results li.b_algo") - + if results: for _rank, result in enumerate(results, 1): # 提取标题和链接 title_link = result.select_one("h2 a") if not title_link: continue - + title = title_link.get_text().strip() url = title_link.get("href", "") - + # 提取摘要 abstract = "" abstract_elem = result.select_one("div.b_caption p") if abstract_elem: abstract = abstract_elem.get_text().strip() - + # 限制摘要长度 if ABSTRACT_MAX_LENGTH and len(abstract) > ABSTRACT_MAX_LENGTH: abstract = abstract[:ABSTRACT_MAX_LENGTH] + "..." - - list_data.append({ - "title": title, - "url": url, - "snippet": abstract, - "provider": "Bing" - }) - + + list_data.append({"title": title, "url": url, "snippet": abstract, "provider": "Bing"}) + if len(list_data) >= 10: # 限制结果数量 break @@ -216,22 +212,34 @@ class BingSearchEngine(BaseSearchEngine): if not list_data: # 查找所有可能的搜索结果链接 all_links = root.find_all("a") - + for link in all_links: href = link.get("href", "") text = link.get_text().strip() - + # 过滤有效的搜索结果链接 - if (href and text and len(text) > 10 + if ( + href + and text + and len(text) > 10 and not href.startswith("javascript:") and not href.startswith("#") and "http" in href - and not any(x in href for x in [ - "bing.com/search", "bing.com/images", "bing.com/videos", - "bing.com/maps", "bing.com/news", "login", "account", - "microsoft", "javascript" - ])): - + and not any( + x in href + for x in [ + "bing.com/search", + "bing.com/images", + "bing.com/videos", + "bing.com/maps", + "bing.com/news", + "login", + "account", + "microsoft", + "javascript", + ] + ) + ): # 尝试获取摘要 abstract = "" parent = link.parent @@ -239,18 +247,13 @@ class BingSearchEngine(BaseSearchEngine): full_text = parent.get_text().strip() if len(full_text) > len(text): abstract = full_text.replace(text, "", 1).strip() - + # 限制摘要长度 if ABSTRACT_MAX_LENGTH and len(abstract) > ABSTRACT_MAX_LENGTH: abstract = abstract[:ABSTRACT_MAX_LENGTH] + "..." - - list_data.append({ - "title": text, - "url": href, - "snippet": abstract, - "provider": "Bing" - }) - + + list_data.append({"title": text, "url": href, "snippet": abstract, "provider": "Bing"}) + if len(list_data) >= 10: break diff --git a/src/plugins/built_in/web_search_tool/engines/ddg_engine.py b/src/plugins/built_in/web_search_tool/engines/ddg_engine.py index 011935e27..29f03b31a 100644 --- a/src/plugins/built_in/web_search_tool/engines/ddg_engine.py +++ b/src/plugins/built_in/web_search_tool/engines/ddg_engine.py @@ -1,6 +1,7 @@ """ DuckDuckGo search engine implementation """ + from typing import Dict, List, Any from asyncddgs import aDDGS @@ -14,27 +15,22 @@ class DDGSearchEngine(BaseSearchEngine): """ DuckDuckGo搜索引擎实现 """ - + def is_available(self) -> bool: """检查DuckDuckGo搜索引擎是否可用""" return True # DuckDuckGo不需要API密钥,总是可用 - + async def search(self, args: Dict[str, Any]) -> List[Dict[str, Any]]: """执行DuckDuckGo搜索""" query = args["query"] num_results = args.get("num_results", 3) - + try: async with aDDGS() as ddgs: search_response = await ddgs.text(query, max_results=num_results) - + return [ - { - "title": r.get("title"), - "url": r.get("href"), - "snippet": r.get("body"), - "provider": "DuckDuckGo" - } + {"title": r.get("title"), "url": r.get("href"), "snippet": r.get("body"), "provider": "DuckDuckGo"} for r in search_response ] except Exception as e: diff --git a/src/plugins/built_in/web_search_tool/engines/exa_engine.py b/src/plugins/built_in/web_search_tool/engines/exa_engine.py index 7327afaeb..269e32bd1 100644 --- a/src/plugins/built_in/web_search_tool/engines/exa_engine.py +++ b/src/plugins/built_in/web_search_tool/engines/exa_engine.py @@ -1,6 +1,7 @@ """ Exa search engine implementation """ + import asyncio import functools from datetime import datetime, timedelta @@ -19,31 +20,27 @@ class ExaSearchEngine(BaseSearchEngine): """ Exa搜索引擎实现 """ - + def __init__(self): self._initialize_clients() - + def _initialize_clients(self): """初始化Exa客户端""" # 从主配置文件读取API密钥 exa_api_keys = config_api.get_global_config("web_search.exa_api_keys", None) - + # 创建API密钥管理器 - self.api_manager = create_api_key_manager_from_config( - exa_api_keys, - lambda key: Exa(api_key=key), - "Exa" - ) - + self.api_manager = create_api_key_manager_from_config(exa_api_keys, lambda key: Exa(api_key=key), "Exa") + def is_available(self) -> bool: """检查Exa搜索引擎是否可用""" return self.api_manager.is_available() - + async def search(self, args: Dict[str, Any]) -> List[Dict[str, Any]]: """执行Exa搜索""" if not self.is_available(): return [] - + query = args["query"] num_results = args.get("num_results", 3) time_range = args.get("time_range", "any") @@ -52,7 +49,7 @@ class ExaSearchEngine(BaseSearchEngine): if time_range != "any": today = datetime.now() start_date = today - timedelta(days=7 if time_range == "week" else 30) - exa_args["start_published_date"] = start_date.strftime('%Y-%m-%d') + exa_args["start_published_date"] = start_date.strftime("%Y-%m-%d") try: # 使用API密钥管理器获取下一个客户端 @@ -60,17 +57,17 @@ class ExaSearchEngine(BaseSearchEngine): if not exa_client: logger.error("无法获取Exa客户端") return [] - + loop = asyncio.get_running_loop() func = functools.partial(exa_client.search_and_contents, query, **exa_args) search_response = await loop.run_in_executor(None, func) - + return [ { "title": res.title, "url": res.url, - "snippet": " ".join(getattr(res, 'highlights', [])) or (getattr(res, 'text', '')[:250] + '...'), - "provider": "Exa" + "snippet": " ".join(getattr(res, "highlights", [])) or (getattr(res, "text", "")[:250] + "..."), + "provider": "Exa", } for res in search_response.results ] diff --git a/src/plugins/built_in/web_search_tool/engines/tavily_engine.py b/src/plugins/built_in/web_search_tool/engines/tavily_engine.py index d7cf61d6c..2f929284f 100644 --- a/src/plugins/built_in/web_search_tool/engines/tavily_engine.py +++ b/src/plugins/built_in/web_search_tool/engines/tavily_engine.py @@ -1,6 +1,7 @@ """ Tavily search engine implementation """ + import asyncio import functools from typing import Dict, List, Any @@ -18,31 +19,29 @@ class TavilySearchEngine(BaseSearchEngine): """ Tavily搜索引擎实现 """ - + def __init__(self): self._initialize_clients() - + def _initialize_clients(self): """初始化Tavily客户端""" # 从主配置文件读取API密钥 tavily_api_keys = config_api.get_global_config("web_search.tavily_api_keys", None) - + # 创建API密钥管理器 self.api_manager = create_api_key_manager_from_config( - tavily_api_keys, - lambda key: TavilyClient(api_key=key), - "Tavily" + tavily_api_keys, lambda key: TavilyClient(api_key=key), "Tavily" ) - + def is_available(self) -> bool: """检查Tavily搜索引擎是否可用""" return self.api_manager.is_available() - + async def search(self, args: Dict[str, Any]) -> List[Dict[str, Any]]: """执行Tavily搜索""" if not self.is_available(): return [] - + query = args["query"] num_results = args.get("num_results", 3) time_range = args.get("time_range", "any") @@ -53,38 +52,40 @@ class TavilySearchEngine(BaseSearchEngine): if not tavily_client: logger.error("无法获取Tavily客户端") return [] - + # 构建Tavily搜索参数 search_params = { "query": query, "max_results": num_results, "search_depth": "basic", "include_answer": False, - "include_raw_content": False + "include_raw_content": False, } - + # 根据时间范围调整搜索参数 if time_range == "week": search_params["days"] = 7 elif time_range == "month": search_params["days"] = 30 - + loop = asyncio.get_running_loop() func = functools.partial(tavily_client.search, **search_params) search_response = await loop.run_in_executor(None, func) - + results = [] if search_response and "results" in search_response: for res in search_response["results"]: - results.append({ - "title": res.get("title", "无标题"), - "url": res.get("url", ""), - "snippet": res.get("content", "")[:300] + "..." if res.get("content") else "无摘要", - "provider": "Tavily" - }) - + results.append( + { + "title": res.get("title", "无标题"), + "url": res.get("url", ""), + "snippet": res.get("content", "")[:300] + "..." if res.get("content") else "无摘要", + "provider": "Tavily", + } + ) + return results - + except Exception as e: logger.error(f"Tavily 搜索失败: {e}") return [] diff --git a/src/plugins/built_in/web_search_tool/plugin.py b/src/plugins/built_in/web_search_tool/plugin.py index 1789062ae..fadc02a88 100644 --- a/src/plugins/built_in/web_search_tool/plugin.py +++ b/src/plugins/built_in/web_search_tool/plugin.py @@ -3,15 +3,10 @@ Web Search Tool Plugin 一个功能强大的网络搜索和URL解析插件,支持多种搜索引擎和解析策略。 """ + from typing import List, Tuple, Type -from src.plugin_system import ( - BasePlugin, - register_plugin, - ComponentInfo, - ConfigField, - PythonDependency -) +from src.plugin_system import BasePlugin, register_plugin, ComponentInfo, ConfigField, PythonDependency from src.plugin_system.apis import config_api from src.common.logger import get_logger @@ -25,7 +20,7 @@ logger = get_logger("web_search_plugin") class WEBSEARCHPLUGIN(BasePlugin): """ 网络搜索工具插件 - + 提供网络搜索和URL解析功能,支持多种搜索引擎: - Exa (需要API密钥) - Tavily (需要API密钥) @@ -37,11 +32,11 @@ class WEBSEARCHPLUGIN(BasePlugin): plugin_name: str = "web_search_tool" # 内部标识符 enable_plugin: bool = True dependencies: List[str] = [] # 插件依赖列表 - + def __init__(self, *args, **kwargs): """初始化插件,立即加载所有搜索引擎""" super().__init__(*args, **kwargs) - + # 立即初始化所有搜索引擎,触发API密钥管理器的日志输出 logger.info("🚀 正在初始化所有搜索引擎...") try: @@ -49,65 +44,58 @@ class WEBSEARCHPLUGIN(BasePlugin): from .engines.tavily_engine import TavilySearchEngine from .engines.ddg_engine import DDGSearchEngine from .engines.bing_engine import BingSearchEngine - + # 实例化所有搜索引擎,这会触发API密钥管理器的初始化 exa_engine = ExaSearchEngine() tavily_engine = TavilySearchEngine() ddg_engine = DDGSearchEngine() bing_engine = BingSearchEngine() - + # 报告每个引擎的状态 engines_status = { "Exa": exa_engine.is_available(), "Tavily": tavily_engine.is_available(), "DuckDuckGo": ddg_engine.is_available(), - "Bing": bing_engine.is_available() + "Bing": bing_engine.is_available(), } - + available_engines = [name for name, available in engines_status.items() if available] unavailable_engines = [name for name, available in engines_status.items() if not available] - + if available_engines: logger.info(f"✅ 可用搜索引擎: {', '.join(available_engines)}") if unavailable_engines: logger.info(f"❌ 不可用搜索引擎: {', '.join(unavailable_engines)}") - + except Exception as e: logger.error(f"❌ 搜索引擎初始化失败: {e}", exc_info=True) - + # Python包依赖列表 python_dependencies: List[PythonDependency] = [ - PythonDependency( - package_name="asyncddgs", - description="异步DuckDuckGo搜索库", - optional=False - ), + PythonDependency(package_name="asyncddgs", description="异步DuckDuckGo搜索库", optional=False), PythonDependency( package_name="exa_py", description="Exa搜索API客户端库", - optional=True # 如果没有API密钥,这个是可选的 + optional=True, # 如果没有API密钥,这个是可选的 ), PythonDependency( package_name="tavily", install_name="tavily-python", # 安装时使用这个名称 description="Tavily搜索API客户端库", - optional=True # 如果没有API密钥,这个是可选的 + optional=True, # 如果没有API密钥,这个是可选的 ), PythonDependency( package_name="httpx", version=">=0.20.0", install_name="httpx[socks]", # 安装时使用这个名称(包含可选依赖) description="支持SOCKS代理的HTTP客户端库", - optional=False - ) + optional=False, + ), ] config_file_name: str = "config.toml" # 配置文件名 # 配置节描述 - config_section_descriptions = { - "plugin": "插件基本信息", - "proxy": "链接本地解析代理配置" - } + config_section_descriptions = {"plugin": "插件基本信息", "proxy": "链接本地解析代理配置"} # 配置Schema定义 # 注意:EXA配置和组件设置已迁移到主配置文件(bot_config.toml)的[exa]和[web_search]部分 @@ -119,42 +107,32 @@ class WEBSEARCHPLUGIN(BasePlugin): }, "proxy": { "http_proxy": ConfigField( - type=str, - default=None, - description="HTTP代理地址,格式如: http://proxy.example.com:8080" + type=str, default=None, description="HTTP代理地址,格式如: http://proxy.example.com:8080" ), "https_proxy": ConfigField( - type=str, - default=None, - description="HTTPS代理地址,格式如: http://proxy.example.com:8080" + type=str, default=None, description="HTTPS代理地址,格式如: http://proxy.example.com:8080" ), "socks5_proxy": ConfigField( - type=str, - default=None, - description="SOCKS5代理地址,格式如: socks5://proxy.example.com:1080" + type=str, default=None, description="SOCKS5代理地址,格式如: socks5://proxy.example.com:1080" ), - "enable_proxy": ConfigField( - type=bool, - default=False, - description="是否启用代理" - ) + "enable_proxy": ConfigField(type=bool, default=False, description="是否启用代理"), }, } - + def get_plugin_components(self) -> List[Tuple[ComponentInfo, Type]]: """ 获取插件组件列表 - + Returns: 组件信息和类型的元组列表 """ enable_tool = [] - + # 从主配置文件读取组件启用配置 if config_api.get_global_config("web_search.enable_web_search_tool", True): enable_tool.append((WebSurfingTool.get_tool_info(), WebSurfingTool)) - + if config_api.get_global_config("web_search.enable_url_tool", True): enable_tool.append((URLParserTool.get_tool_info(), URLParserTool)) - + return enable_tool diff --git a/src/plugins/built_in/web_search_tool/tools/url_parser.py b/src/plugins/built_in/web_search_tool/tools/url_parser.py index 315e06271..da91c419a 100644 --- a/src/plugins/built_in/web_search_tool/tools/url_parser.py +++ b/src/plugins/built_in/web_search_tool/tools/url_parser.py @@ -1,6 +1,7 @@ """ URL parser tool implementation """ + import asyncio import functools from typing import Any, Dict @@ -24,17 +25,18 @@ class URLParserTool(BaseTool): """ 一个用于解析和总结一个或多个网页URL内容的工具。 """ + name: str = "parse_url" description: str = "当需要理解一个或多个特定网页链接的内容时,使用此工具。例如:'这些网页讲了什么?[https://example.com, https://example2.com]' 或 '帮我总结一下这些文章'" available_for_llm: bool = True parameters = [ ("urls", ToolParamType.STRING, "要理解的网站", True, None), ] - + def __init__(self, plugin_config=None): super().__init__(plugin_config) self._initialize_exa_clients() - + def _initialize_exa_clients(self): """初始化Exa客户端""" # 优先从主配置文件读取,如果没有则从插件配置文件读取 @@ -42,12 +44,10 @@ class URLParserTool(BaseTool): if exa_api_keys is None: # 从插件配置文件读取 exa_api_keys = self.get_config("exa.api_keys", []) - + # 创建API密钥管理器 self.api_manager = create_api_key_manager_from_config( - exa_api_keys, - lambda key: Exa(api_key=key), - "Exa URL Parser" + exa_api_keys, lambda key: Exa(api_key=key), "Exa URL Parser" ) async def _local_parse_and_summarize(self, url: str) -> Dict[str, Any]: @@ -58,12 +58,12 @@ class URLParserTool(BaseTool): # 读取代理配置 enable_proxy = self.get_config("proxy.enable_proxy", False) proxies = None - + if enable_proxy: socks5_proxy = self.get_config("proxy.socks5_proxy", None) http_proxy = self.get_config("proxy.http_proxy", None) https_proxy = self.get_config("proxy.https_proxy", None) - + # 优先使用SOCKS5代理(全协议代理) if socks5_proxy: proxies = socks5_proxy @@ -75,17 +75,17 @@ class URLParserTool(BaseTool): if https_proxy: proxies["https://"] = https_proxy logger.info(f"使用HTTP/HTTPS代理配置: {proxies}") - + client_kwargs = {"timeout": 15.0, "follow_redirects": True} if proxies: client_kwargs["proxies"] = proxies - + async with httpx.AsyncClient(**client_kwargs) as client: response = await client.get(url) response.raise_for_status() soup = BeautifulSoup(response.text, "html.parser") - + title = soup.title.string if soup.title else "无标题" for script in soup(["script", "style"]): script.extract() @@ -104,12 +104,12 @@ class URLParserTool(BaseTool): return {"error": "未配置LLM模型"} success, summary, reasoning, model_name = await llm_api.generate_with_model( - prompt=summary_prompt, - model_config=model_config, - request_type="story.generate", - temperature=0.3, - max_tokens=1000 - ) + prompt=summary_prompt, + model_config=model_config, + request_type="story.generate", + temperature=0.3, + max_tokens=1000, + ) if not success: logger.info(f"生成摘要失败: {summary}") @@ -117,12 +117,7 @@ class URLParserTool(BaseTool): logger.info(f"成功生成摘要内容:'{summary}'") - return { - "title": title, - "url": url, - "snippet": summary, - "source": "local" - } + return {"title": title, "url": url, "snippet": summary, "source": "local"} except httpx.HTTPStatusError as e: logger.warning(f"本地解析URL '{url}' 失败 (HTTP {e.response.status_code})") @@ -137,6 +132,7 @@ class URLParserTool(BaseTool): """ # 获取当前文件路径用于缓存键 import os + current_file_path = os.path.abspath(__file__) # 检查缓存 @@ -144,7 +140,7 @@ class URLParserTool(BaseTool): if cached_result: logger.info(f"缓存命中: {self.name} -> {function_args}") return cached_result - + urls_input = function_args.get("urls") if not urls_input: return {"error": "URL列表不能为空。"} @@ -158,14 +154,14 @@ class URLParserTool(BaseTool): valid_urls = validate_urls(urls) if not valid_urls: return {"error": "未找到有效的URL。"} - + urls = valid_urls logger.info(f"准备解析 {len(urls)} 个URL: {urls}") successful_results = [] error_messages = [] urls_to_retry_locally = [] - + # 步骤 1: 尝试使用 Exa API 进行解析 contents_response = None if self.api_manager.is_available(): @@ -182,41 +178,45 @@ class URLParserTool(BaseTool): contents_response = await loop.run_in_executor(None, func) except Exception as e: logger.error(f"执行 Exa URL解析时发生严重异常: {e}", exc_info=True) - contents_response = None # 确保异常后为None + contents_response = None # 确保异常后为None # 步骤 2: 处理Exa的响应 - if contents_response and hasattr(contents_response, 'statuses'): - results_map = {res.url: res for res in contents_response.results} if hasattr(contents_response, 'results') else {} + if contents_response and hasattr(contents_response, "statuses"): + results_map = ( + {res.url: res for res in contents_response.results} if hasattr(contents_response, "results") else {} + ) if contents_response.statuses: for status in contents_response.statuses: - if status.status == 'success': + if status.status == "success": res = results_map.get(status.id) if res: - summary = getattr(res, 'summary', '') - highlights = " ".join(getattr(res, 'highlights', [])) - text_snippet = (getattr(res, 'text', '')[:300] + '...') if getattr(res, 'text', '') else '' - snippet = summary or highlights or text_snippet or '无摘要' - - successful_results.append({ - "title": getattr(res, 'title', '无标题'), - "url": getattr(res, 'url', status.id), - "snippet": snippet, - "source": "exa" - }) + summary = getattr(res, "summary", "") + highlights = " ".join(getattr(res, "highlights", [])) + text_snippet = (getattr(res, "text", "")[:300] + "...") if getattr(res, "text", "") else "" + snippet = summary or highlights or text_snippet or "无摘要" + + successful_results.append( + { + "title": getattr(res, "title", "无标题"), + "url": getattr(res, "url", status.id), + "snippet": snippet, + "source": "exa", + } + ) else: - error_tag = getattr(status, 'error', '未知错误') + error_tag = getattr(status, "error", "未知错误") logger.warning(f"Exa解析URL '{status.id}' 失败: {error_tag}。准备本地重试。") urls_to_retry_locally.append(status.id) else: # 如果Exa未配置、API调用失败或返回无效响应,则所有URL都进入本地重试 - urls_to_retry_locally.extend(url for url in urls if url not in [res['url'] for res in successful_results]) + urls_to_retry_locally.extend(url for url in urls if url not in [res["url"] for res in successful_results]) # 步骤 3: 对失败的URL进行本地解析 if urls_to_retry_locally: logger.info(f"开始本地解析以下URL: {urls_to_retry_locally}") local_tasks = [self._local_parse_and_summarize(url) for url in urls_to_retry_locally] local_results = await asyncio.gather(*local_tasks) - + for i, res in enumerate(local_results): url = urls_to_retry_locally[i] if "error" in res: @@ -228,13 +228,9 @@ class URLParserTool(BaseTool): return {"error": "无法从所有给定的URL获取内容。", "details": error_messages} formatted_content = format_url_parse_results(successful_results) - - result = { - "type": "url_parse_result", - "content": formatted_content, - "errors": error_messages - } - + + result = {"type": "url_parse_result", "content": formatted_content, "errors": error_messages} + # 保存到缓存 if "error" not in result: await tool_cache.set(self.name, function_args, current_file_path, result) diff --git a/src/plugins/built_in/web_search_tool/tools/web_search.py b/src/plugins/built_in/web_search_tool/tools/web_search.py index c09ad5e92..3e4039cb8 100644 --- a/src/plugins/built_in/web_search_tool/tools/web_search.py +++ b/src/plugins/built_in/web_search_tool/tools/web_search.py @@ -1,6 +1,7 @@ """ Web search tool implementation """ + import asyncio from typing import Any, Dict, List @@ -22,14 +23,23 @@ class WebSurfingTool(BaseTool): """ 网络搜索工具 """ + name: str = "web_search" - description: str = "用于执行网络搜索。当用户明确要求搜索,或者需要获取关于公司、产品、事件的最新信息、新闻或动态时,必须使用此工具" + description: str = ( + "用于执行网络搜索。当用户明确要求搜索,或者需要获取关于公司、产品、事件的最新信息、新闻或动态时,必须使用此工具" + ) available_for_llm: bool = True parameters = [ ("query", ToolParamType.STRING, "要搜索的关键词或问题。", True, None), ("num_results", ToolParamType.INTEGER, "期望每个搜索引擎返回的搜索结果数量,默认为5。", False, None), - ("time_range", ToolParamType.STRING, "指定搜索的时间范围,可以是 'any', 'week', 'month'。默认为 'any'。", False, ["any", "week", "month"]) - ] # type: ignore + ( + "time_range", + ToolParamType.STRING, + "指定搜索的时间范围,可以是 'any', 'week', 'month'。默认为 'any'。", + False, + ["any", "week", "month"], + ), + ] # type: ignore def __init__(self, plugin_config=None): super().__init__(plugin_config) @@ -38,7 +48,7 @@ class WebSurfingTool(BaseTool): "exa": ExaSearchEngine(), "tavily": TavilySearchEngine(), "ddg": DDGSearchEngine(), - "bing": BingSearchEngine() + "bing": BingSearchEngine(), } async def execute(self, function_args: Dict[str, Any]) -> Dict[str, Any]: @@ -48,6 +58,7 @@ class WebSurfingTool(BaseTool): # 获取当前文件路径用于缓存键 import os + current_file_path = os.path.abspath(__file__) # 检查缓存 @@ -59,7 +70,7 @@ class WebSurfingTool(BaseTool): # 读取搜索配置 enabled_engines = config_api.get_global_config("web_search.enabled_engines", ["ddg"]) search_strategy = config_api.get_global_config("web_search.search_strategy", "single") - + logger.info(f"开始搜索,策略: {search_strategy}, 启用引擎: {enabled_engines}, 参数: '{function_args}'") # 根据策略执行搜索 @@ -69,17 +80,19 @@ class WebSurfingTool(BaseTool): result = await self._execute_fallback_search(function_args, enabled_engines) else: # single result = await self._execute_single_search(function_args, enabled_engines) - + # 保存到缓存 if "error" not in result: await tool_cache.set(self.name, function_args, current_file_path, result, semantic_query=query) - + return result - async def _execute_parallel_search(self, function_args: Dict[str, Any], enabled_engines: List[str]) -> Dict[str, Any]: + async def _execute_parallel_search( + self, function_args: Dict[str, Any], enabled_engines: List[str] + ) -> Dict[str, Any]: """并行搜索策略:同时使用所有启用的搜索引擎""" search_tasks = [] - + for engine_name in enabled_engines: engine = self.engines.get(engine_name) if engine and engine.is_available(): @@ -92,7 +105,7 @@ class WebSurfingTool(BaseTool): try: search_results_lists = await asyncio.gather(*search_tasks, return_exceptions=True) - + all_results = [] for result in search_results_lists: if isinstance(result, list): @@ -103,7 +116,7 @@ class WebSurfingTool(BaseTool): # 去重并格式化 unique_results = deduplicate_results(all_results) formatted_content = format_search_results(unique_results) - + return { "type": "web_search_result", "content": formatted_content, @@ -113,30 +126,32 @@ class WebSurfingTool(BaseTool): logger.error(f"执行并行网络搜索时发生异常: {e}", exc_info=True) return {"error": f"执行网络搜索时发生严重错误: {str(e)}"} - async def _execute_fallback_search(self, function_args: Dict[str, Any], enabled_engines: List[str]) -> Dict[str, Any]: + async def _execute_fallback_search( + self, function_args: Dict[str, Any], enabled_engines: List[str] + ) -> Dict[str, Any]: """回退搜索策略:按顺序尝试搜索引擎,失败则尝试下一个""" for engine_name in enabled_engines: engine = self.engines.get(engine_name) if not engine or not engine.is_available(): continue - + try: custom_args = function_args.copy() custom_args["num_results"] = custom_args.get("num_results", 5) - + results = await engine.search(custom_args) - + if results: # 如果有结果,直接返回 formatted_content = format_search_results(results) return { "type": "web_search_result", "content": formatted_content, } - + except Exception as e: logger.warning(f"{engine_name} 搜索失败,尝试下一个引擎: {e}") continue - + return {"error": "所有搜索引擎都失败了。"} async def _execute_single_search(self, function_args: Dict[str, Any], enabled_engines: List[str]) -> Dict[str, Any]: @@ -145,20 +160,20 @@ class WebSurfingTool(BaseTool): engine = self.engines.get(engine_name) if not engine or not engine.is_available(): continue - + try: custom_args = function_args.copy() custom_args["num_results"] = custom_args.get("num_results", 5) - + results = await engine.search(custom_args) formatted_content = format_search_results(results) return { "type": "web_search_result", "content": formatted_content, } - + except Exception as e: logger.error(f"{engine_name} 搜索失败: {e}") return {"error": f"{engine_name} 搜索失败: {str(e)}"} - + return {"error": "没有可用的搜索引擎。"} diff --git a/src/plugins/built_in/web_search_tool/utils/api_key_manager.py b/src/plugins/built_in/web_search_tool/utils/api_key_manager.py index f8e0afa71..07757cdb1 100644 --- a/src/plugins/built_in/web_search_tool/utils/api_key_manager.py +++ b/src/plugins/built_in/web_search_tool/utils/api_key_manager.py @@ -1,24 +1,25 @@ """ API密钥管理器,提供轮询机制 """ + import itertools from typing import List, Optional, TypeVar, Generic, Callable from src.common.logger import get_logger logger = get_logger("api_key_manager") -T = TypeVar('T') +T = TypeVar("T") class APIKeyManager(Generic[T]): """ API密钥管理器,支持轮询机制 """ - + def __init__(self, api_keys: List[str], client_factory: Callable[[str], T], service_name: str = "Unknown"): """ 初始化API密钥管理器 - + Args: api_keys: API密钥列表 client_factory: 客户端工厂函数,接受API密钥参数并返回客户端实例 @@ -27,14 +28,14 @@ class APIKeyManager(Generic[T]): self.service_name = service_name self.clients: List[T] = [] self.client_cycle: Optional[itertools.cycle] = None - + if api_keys: # 过滤有效的API密钥,排除None、空字符串、"None"字符串等 valid_keys = [] for key in api_keys: if isinstance(key, str) and key.strip() and key.strip().lower() not in ("none", "null", ""): valid_keys.append(key.strip()) - + if valid_keys: try: self.clients = [client_factory(key) for key in valid_keys] @@ -48,35 +49,33 @@ class APIKeyManager(Generic[T]): logger.warning(f"⚠️ {service_name} API Keys 配置无效(包含None或空值),{service_name} 功能将不可用") else: logger.warning(f"⚠️ {service_name} API Keys 未配置,{service_name} 功能将不可用") - + def is_available(self) -> bool: """检查是否有可用的客户端""" return bool(self.clients and self.client_cycle) - + def get_next_client(self) -> Optional[T]: """获取下一个客户端(轮询)""" if not self.is_available(): return None return next(self.client_cycle) - + def get_client_count(self) -> int: """获取可用客户端数量""" return len(self.clients) def create_api_key_manager_from_config( - config_keys: Optional[List[str]], - client_factory: Callable[[str], T], - service_name: str + config_keys: Optional[List[str]], client_factory: Callable[[str], T], service_name: str ) -> APIKeyManager[T]: """ 从配置创建API密钥管理器的便捷函数 - + Args: config_keys: 从配置读取的API密钥列表 client_factory: 客户端工厂函数 service_name: 服务名称 - + Returns: API密钥管理器实例 """ diff --git a/src/plugins/built_in/web_search_tool/utils/formatters.py b/src/plugins/built_in/web_search_tool/utils/formatters.py index 434f6f3c8..df1e4ea18 100644 --- a/src/plugins/built_in/web_search_tool/utils/formatters.py +++ b/src/plugins/built_in/web_search_tool/utils/formatters.py @@ -1,6 +1,7 @@ """ Formatters for web search results """ + from typing import List, Dict, Any @@ -13,15 +14,15 @@ def format_search_results(results: List[Dict[str, Any]]) -> str: formatted_string = "根据网络搜索结果:\n\n" for i, res in enumerate(results, 1): - title = res.get("title", '无标题') - url = res.get("url", '#') - snippet = res.get("snippet", '无摘要') + title = res.get("title", "无标题") + url = res.get("url", "#") + snippet = res.get("snippet", "无摘要") provider = res.get("provider", "未知来源") - + formatted_string += f"{i}. **{title}** (来自: {provider})\n" formatted_string += f" - 摘要: {snippet}\n" formatted_string += f" - 来源: {url}\n\n" - + return formatted_string @@ -31,10 +32,10 @@ def format_url_parse_results(results: List[Dict[str, Any]]) -> str: """ formatted_parts = [] for res in results: - title = res.get('title', '无标题') - url = res.get('url', '#') - snippet = res.get('snippet', '无摘要') - source = res.get('source', '未知') + title = res.get("title", "无标题") + url = res.get("url", "#") + snippet = res.get("snippet", "无摘要") + source = res.get("source", "未知") formatted_string = f"**{title}**\n" formatted_string += f"**内容摘要**:\n{snippet}\n" diff --git a/src/plugins/built_in/web_search_tool/utils/url_utils.py b/src/plugins/built_in/web_search_tool/utils/url_utils.py index 74afbc819..5bdde0a55 100644 --- a/src/plugins/built_in/web_search_tool/utils/url_utils.py +++ b/src/plugins/built_in/web_search_tool/utils/url_utils.py @@ -1,6 +1,7 @@ """ URL processing utilities """ + import re from typing import List @@ -12,11 +13,11 @@ def parse_urls_from_input(urls_input) -> List[str]: if isinstance(urls_input, str): # 如果是字符串,尝试解析为URL列表 # 提取所有HTTP/HTTPS URL - url_pattern = r'https?://[^\s\],]+' + url_pattern = r"https?://[^\s\],]+" urls = re.findall(url_pattern, urls_input) if not urls: # 如果没有找到标准URL,将整个字符串作为单个URL - if urls_input.strip().startswith(('http://', 'https://')): + if urls_input.strip().startswith(("http://", "https://")): urls = [urls_input.strip()] else: return [] @@ -24,7 +25,7 @@ def parse_urls_from_input(urls_input) -> List[str]: urls = [url.strip() for url in urls_input if isinstance(url, str) and url.strip()] else: return [] - + return urls @@ -34,6 +35,6 @@ def validate_urls(urls: List[str]) -> List[str]: """ valid_urls = [] for url in urls: - if url.startswith(('http://', 'https://')): + if url.startswith(("http://", "https://")): valid_urls.append(url) return valid_urls diff --git a/src/plugins/reminder_plugin/plugin.py b/src/plugins/reminder_plugin/plugin.py index 8a833f5be..31ea899df 100644 --- a/src/plugins/reminder_plugin/plugin.py +++ b/src/plugins/reminder_plugin/plugin.py @@ -21,8 +21,18 @@ logger = get_logger(__name__) # ============================ AsyncTask ============================ + class ReminderTask(AsyncTask): - def __init__(self, delay: float, stream_id: str, is_group: bool, target_user_id: str, target_user_name: str, event_details: str, creator_name: str): + def __init__( + self, + delay: float, + stream_id: str, + is_group: bool, + target_user_id: str, + target_user_name: str, + event_details: str, + creator_name: str, + ): super().__init__(task_name=f"ReminderTask_{target_user_id}_{datetime.now().timestamp()}") self.delay = delay self.stream_id = stream_id @@ -37,22 +47,22 @@ class ReminderTask(AsyncTask): if self.delay > 0: logger.info(f"等待 {self.delay:.2f} 秒后执行提醒...") await asyncio.sleep(self.delay) - + logger.info(f"执行提醒任务: 给 {self.target_user_name} 发送关于 '{self.event_details}' 的提醒") reminder_text = f"叮咚!这是 {self.creator_name} 让我准时提醒你的事情:\n\n{self.event_details}" if self.is_group: # 在群聊中,构造 @ 消息段并发送 - group_id = self.stream_id.split('_')[-1] if '_' in self.stream_id else self.stream_id + group_id = self.stream_id.split("_")[-1] if "_" in self.stream_id else self.stream_id message_payload = [ {"type": "at", "data": {"qq": self.target_user_id}}, - {"type": "text", "data": {"text": f" {reminder_text}"}} + {"type": "text", "data": {"text": f" {reminder_text}"}}, ] await send_api.adapter_command_to_stream( action="send_group_msg", params={"group_id": group_id, "message": message_payload}, - stream_id=self.stream_id + stream_id=self.stream_id, ) else: # 在私聊中,直接发送文本 @@ -66,6 +76,7 @@ class ReminderTask(AsyncTask): # =============================== Actions =============================== + class RemindAction(BaseAction): """一个能从对话中智能识别并设置定时提醒的动作。""" @@ -95,12 +106,12 @@ class RemindAction(BaseAction): action_parameters = { "user_name": "需要被提醒的人的称呼或名字,如果没有明确指定给某人,则默认为'自己'", "remind_time": "描述提醒时间的自然语言字符串,例如'十分钟后'或'明天下午3点'", - "event_details": "需要提醒的具体事件内容" + "event_details": "需要提醒的具体事件内容", } action_require = [ "当用户请求在未来的某个时间点提醒他/她或别人某件事时使用", "适用于包含明确时间信息和事件描述的对话", - "例如:'10分钟后提醒我收快递'、'明天早上九点喊一下李四参加晨会'" + "例如:'10分钟后提醒我收快递'、'明天早上九点喊一下李四参加晨会'", ] async def execute(self) -> Tuple[bool, str]: @@ -110,7 +121,15 @@ class RemindAction(BaseAction): event_details = self.action_data.get("event_details") if not all([user_name, remind_time_str, event_details]): - missing_params = [p for p, v in {"user_name": user_name, "remind_time": remind_time_str, "event_details": event_details}.items() if not v] + missing_params = [ + p + for p, v in { + "user_name": user_name, + "remind_time": remind_time_str, + "event_details": event_details, + }.items() + if not v + ] error_msg = f"缺少必要的提醒参数: {', '.join(missing_params)}" logger.warning(f"[ReminderPlugin] LLM未能提取完整参数: {error_msg}") return False, error_msg @@ -135,9 +154,9 @@ class RemindAction(BaseAction): person_manager = get_person_info_manager() user_id_to_remind = None user_name_to_remind = "" - + assert isinstance(user_name, str) - + if user_name.strip() in ["自己", "我", "me"]: user_id_to_remind = self.user_id user_name_to_remind = self.user_nickname @@ -154,7 +173,7 @@ class RemindAction(BaseAction): try: assert user_id_to_remind is not None assert event_details is not None - + reminder_task = ReminderTask( delay=delay_seconds, stream_id=self.chat_id, @@ -162,14 +181,14 @@ class RemindAction(BaseAction): target_user_id=str(user_id_to_remind), target_user_name=str(user_name_to_remind), event_details=str(event_details), - creator_name=str(self.user_nickname) + creator_name=str(self.user_nickname), ) await async_task_manager.add_task(reminder_task) - + # 4. 发送确认消息 confirm_message = f"好的,我记下了。\n将在 {target_time.strftime('%Y-%m-%d %H:%M:%S')} 提醒 {user_name_to_remind}:\n{event_details}" await self.send_text(confirm_message) - + return True, "提醒设置成功" except Exception as e: logger.error(f"[ReminderPlugin] 创建提醒任务时出错: {e}", exc_info=True) @@ -179,6 +198,7 @@ class RemindAction(BaseAction): # =============================== Plugin =============================== + @register_plugin class ReminderPlugin(BasePlugin): """一个能从对话中智能识别并设置定时提醒的插件。""" @@ -193,6 +213,4 @@ class ReminderPlugin(BasePlugin): def get_plugin_components(self) -> List[Tuple[ActionInfo, Type[BaseAction]]]: """注册插件的所有功能组件。""" - return [ - (RemindAction.get_action_info(), RemindAction) - ] + return [(RemindAction.get_action_info(), RemindAction)] diff --git a/src/schedule/database.py b/src/schedule/database.py index 88337f4df..a2d9d3046 100644 --- a/src/schedule/database.py +++ b/src/schedule/database.py @@ -290,4 +290,4 @@ def has_active_plans(month: str) -> bool: return count > 0 except Exception as e: logger.error(f"检查 {month} 的有效月度计划时发生错误: {e}") - return False \ No newline at end of file + return False diff --git a/src/schedule/llm_generator.py b/src/schedule/llm_generator.py index 9dda68f80..5703e10da 100644 --- a/src/schedule/llm_generator.py +++ b/src/schedule/llm_generator.py @@ -221,4 +221,4 @@ class MonthlyPlanLLMGenerator: return plans except Exception as e: logger.error(f"解析月度计划响应时发生错误: {e}") - return [] \ No newline at end of file + return [] diff --git a/src/schedule/plan_manager.py b/src/schedule/plan_manager.py index 0fae5c381..d72f55275 100644 --- a/src/schedule/plan_manager.py +++ b/src/schedule/plan_manager.py @@ -102,4 +102,4 @@ class PlanManager: def get_plans_for_schedule(self, month: str, max_count: int) -> List: avoid_days = global_config.planning_system.avoid_repetition_days - return get_smart_plans_for_daily_schedule(month, max_count=max_count, avoid_days=avoid_days) \ No newline at end of file + return get_smart_plans_for_daily_schedule(month, max_count=max_count, avoid_days=avoid_days) diff --git a/src/schedule/schemas.py b/src/schedule/schemas.py index 5eb7c003a..a733731be 100644 --- a/src/schedule/schemas.py +++ b/src/schedule/schemas.py @@ -96,4 +96,4 @@ class ScheduleData(BaseModel): covered[i] = True # 检查是否所有分钟都被覆盖 - return all(covered) \ No newline at end of file + return all(covered)