ruff,私聊视为提及了bot
This commit is contained in:
25
bot.py
25
bot.py
@@ -34,16 +34,18 @@ script_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
os.chdir(script_dir)
|
||||
logger.info(f"已设置工作目录为: {script_dir}")
|
||||
|
||||
|
||||
# 检查并创建.env文件
|
||||
def ensure_env_file():
|
||||
"""确保.env文件存在,如果不存在则从模板创建"""
|
||||
env_file = Path(".env")
|
||||
template_env = Path("template/template.env")
|
||||
|
||||
|
||||
if not env_file.exists():
|
||||
if template_env.exists():
|
||||
logger.info("未找到.env文件,正在从模板创建...")
|
||||
import shutil
|
||||
|
||||
shutil.copy(template_env, env_file)
|
||||
logger.info("已从template/template.env创建.env文件")
|
||||
logger.warning("请编辑.env文件,将EULA_CONFIRMED设置为true并配置其他必要参数")
|
||||
@@ -51,6 +53,7 @@ def ensure_env_file():
|
||||
logger.error("未找到.env文件和template.env模板文件")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
# 确保环境文件存在
|
||||
ensure_env_file()
|
||||
|
||||
@@ -130,32 +133,32 @@ async def graceful_shutdown():
|
||||
def check_eula():
|
||||
"""检查EULA和隐私条款确认状态 - 环境变量版(类似Minecraft)"""
|
||||
# 检查环境变量中的EULA确认
|
||||
eula_confirmed = os.getenv('EULA_CONFIRMED', '').lower()
|
||||
|
||||
if eula_confirmed == 'true':
|
||||
eula_confirmed = os.getenv("EULA_CONFIRMED", "").lower()
|
||||
|
||||
if eula_confirmed == "true":
|
||||
logger.info("EULA已通过环境变量确认")
|
||||
return
|
||||
|
||||
|
||||
# 如果没有确认,提示用户
|
||||
confirm_logger.critical("您需要同意EULA和隐私条款才能使用MoFox_Bot")
|
||||
confirm_logger.critical("请阅读以下文件:")
|
||||
confirm_logger.critical(" - EULA.md (用户许可协议)")
|
||||
confirm_logger.critical(" - PRIVACY.md (隐私条款)")
|
||||
confirm_logger.critical("然后编辑 .env 文件,将 'EULA_CONFIRMED=false' 改为 'EULA_CONFIRMED=true'")
|
||||
|
||||
|
||||
# 等待用户确认
|
||||
while True:
|
||||
try:
|
||||
load_dotenv(override=True) # 重新加载.env文件
|
||||
|
||||
eula_confirmed = os.getenv('EULA_CONFIRMED', '').lower()
|
||||
if eula_confirmed == 'true':
|
||||
|
||||
eula_confirmed = os.getenv("EULA_CONFIRMED", "").lower()
|
||||
if eula_confirmed == "true":
|
||||
confirm_logger.info("EULA确认成功,感谢您的同意")
|
||||
return
|
||||
|
||||
|
||||
confirm_logger.critical("请修改 .env 文件中的 EULA_CONFIRMED=true 后重新启动程序")
|
||||
input("按Enter键检查.env文件状态...")
|
||||
|
||||
|
||||
except KeyboardInterrupt:
|
||||
confirm_logger.info("用户取消,程序退出")
|
||||
sys.exit(0)
|
||||
|
||||
@@ -20,25 +20,26 @@ files_to_update = [
|
||||
"src/mais4u/mais4u_chat/s4u_mood_manager.py",
|
||||
"src/plugin_system/core/tool_use.py",
|
||||
"src/chat/memory_system/memory_activator.py",
|
||||
"src/chat/utils/smart_prompt.py"
|
||||
"src/chat/utils/smart_prompt.py",
|
||||
]
|
||||
|
||||
|
||||
def update_prompt_imports(file_path):
|
||||
"""更新文件中的Prompt导入"""
|
||||
if not os.path.exists(file_path):
|
||||
print(f"文件不存在: {file_path}")
|
||||
return False
|
||||
|
||||
with open(file_path, 'r', encoding='utf-8') as f:
|
||||
|
||||
with open(file_path, "r", encoding="utf-8") as f:
|
||||
content = f.read()
|
||||
|
||||
|
||||
# 替换导入语句
|
||||
old_import = "from src.chat.utils.prompt_builder import Prompt, global_prompt_manager"
|
||||
new_import = "from src.chat.utils.prompt import Prompt, global_prompt_manager"
|
||||
|
||||
|
||||
if old_import in content:
|
||||
new_content = content.replace(old_import, new_import)
|
||||
with open(file_path, 'w', encoding='utf-8') as f:
|
||||
with open(file_path, "w", encoding="utf-8") as f:
|
||||
f.write(new_content)
|
||||
print(f"已更新: {file_path}")
|
||||
return True
|
||||
@@ -46,14 +47,16 @@ def update_prompt_imports(file_path):
|
||||
print(f"无需更新: {file_path}")
|
||||
return False
|
||||
|
||||
|
||||
def main():
|
||||
"""主函数"""
|
||||
updated_count = 0
|
||||
for file_path in files_to_update:
|
||||
if update_prompt_imports(file_path):
|
||||
updated_count += 1
|
||||
|
||||
|
||||
print(f"\n更新完成!共更新了 {updated_count} 个文件")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
main()
|
||||
|
||||
@@ -5,4 +5,4 @@
|
||||
|
||||
from src.chat.affinity_flow.afc_manager import afc_manager
|
||||
|
||||
__all__ = ['afc_manager', 'AFCManager', 'AffinityFlowChatter']
|
||||
__all__ = ["afc_manager", "AFCManager", "AffinityFlowChatter"]
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
亲和力聊天处理流管理器
|
||||
管理不同聊天流的亲和力聊天处理流,统一获取新消息并分发到对应的亲和力聊天处理流
|
||||
"""
|
||||
|
||||
import time
|
||||
import traceback
|
||||
from typing import Dict, Optional, List
|
||||
@@ -20,7 +21,7 @@ class AFCManager:
|
||||
|
||||
def __init__(self):
|
||||
self.affinity_flow_chatters: Dict[str, "AffinityFlowChatter"] = {}
|
||||
'''所有聊天流的亲和力聊天处理流,stream_id -> affinity_flow_chatter'''
|
||||
"""所有聊天流的亲和力聊天处理流,stream_id -> affinity_flow_chatter"""
|
||||
|
||||
# 动作管理器
|
||||
self.action_manager = ActionManager()
|
||||
@@ -40,11 +41,7 @@ class AFCManager:
|
||||
# 创建增强版规划器
|
||||
planner = ActionPlanner(stream_id, self.action_manager)
|
||||
|
||||
chatter = AffinityFlowChatter(
|
||||
stream_id=stream_id,
|
||||
planner=planner,
|
||||
action_manager=self.action_manager
|
||||
)
|
||||
chatter = AffinityFlowChatter(stream_id=stream_id, planner=planner, action_manager=self.action_manager)
|
||||
self.affinity_flow_chatters[stream_id] = chatter
|
||||
logger.info(f"创建新的亲和力聊天处理器: {stream_id}")
|
||||
|
||||
@@ -74,7 +71,6 @@ class AFCManager:
|
||||
"executed_count": 0,
|
||||
}
|
||||
|
||||
|
||||
def get_chatter_stats(self, stream_id: str) -> Optional[Dict[str, any]]:
|
||||
"""获取聊天处理器统计"""
|
||||
if stream_id in self.affinity_flow_chatters:
|
||||
@@ -131,4 +127,5 @@ class AFCManager:
|
||||
self.affinity_flow_chatters[stream_id].update_interest_keywords(new_keywords)
|
||||
logger.info(f"已更新聊天流 {stream_id} 的兴趣关键词: {list(new_keywords.keys())}")
|
||||
|
||||
afc_manager = AFCManager()
|
||||
|
||||
afc_manager = AFCManager()
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
亲和力聊天处理器
|
||||
单个聊天流的处理器,负责处理特定聊天流的完整交互流程
|
||||
"""
|
||||
|
||||
import time
|
||||
import traceback
|
||||
from datetime import datetime
|
||||
@@ -57,10 +58,7 @@ class AffinityFlowChatter:
|
||||
unread_messages = context.get_unread_messages()
|
||||
|
||||
# 使用增强版规划器处理消息
|
||||
actions, target_message = await self.planner.plan(
|
||||
mode=ChatMode.FOCUS,
|
||||
context=context
|
||||
)
|
||||
actions, target_message = await self.planner.plan(mode=ChatMode.FOCUS, context=context)
|
||||
self.stats["plans_created"] += 1
|
||||
|
||||
# 执行动作(如果规划器返回了动作)
|
||||
@@ -84,7 +82,9 @@ class AffinityFlowChatter:
|
||||
**execution_result,
|
||||
}
|
||||
|
||||
logger.info(f"聊天流 {self.stream_id} StreamContext处理成功: 动作数={result['actions_count']}, 未读消息={result['unread_messages_processed']}")
|
||||
logger.info(
|
||||
f"聊天流 {self.stream_id} StreamContext处理成功: 动作数={result['actions_count']}, 未读消息={result['unread_messages_processed']}"
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
@@ -197,7 +197,9 @@ class AffinityFlowChatter:
|
||||
|
||||
def __repr__(self) -> str:
|
||||
"""详细字符串表示"""
|
||||
return (f"AffinityFlowChatter(stream_id={self.stream_id}, "
|
||||
f"messages_processed={self.stats['messages_processed']}, "
|
||||
f"plans_created={self.stats['plans_created']}, "
|
||||
f"last_activity={datetime.fromtimestamp(self.last_activity_time)})")
|
||||
return (
|
||||
f"AffinityFlowChatter(stream_id={self.stream_id}, "
|
||||
f"messages_processed={self.stats['messages_processed']}, "
|
||||
f"plans_created={self.stats['plans_created']}, "
|
||||
f"last_activity={datetime.fromtimestamp(self.last_activity_time)})"
|
||||
)
|
||||
|
||||
@@ -38,7 +38,9 @@ class InterestScoringSystem:
|
||||
# 连续不回复概率提升
|
||||
self.no_reply_count = 0
|
||||
self.max_no_reply_count = affinity_config.max_no_reply_count
|
||||
self.probability_boost_per_no_reply = affinity_config.no_reply_threshold_adjustment / affinity_config.max_no_reply_count # 每次不回复增加的概率
|
||||
self.probability_boost_per_no_reply = (
|
||||
affinity_config.no_reply_threshold_adjustment / affinity_config.max_no_reply_count
|
||||
) # 每次不回复增加的概率
|
||||
|
||||
# 用户关系数据
|
||||
self.user_relationships: Dict[str, float] = {} # user_id -> relationship_score
|
||||
@@ -153,7 +155,9 @@ class InterestScoringSystem:
|
||||
|
||||
# 返回匹配分数,考虑置信度和匹配标签数量
|
||||
affinity_config = global_config.affinity_flow
|
||||
match_count_bonus = min(len(match_result.matched_tags) * affinity_config.match_count_bonus, affinity_config.max_match_bonus)
|
||||
match_count_bonus = min(
|
||||
len(match_result.matched_tags) * affinity_config.match_count_bonus, affinity_config.max_match_bonus
|
||||
)
|
||||
final_score = match_result.overall_score * 1.15 * match_result.confidence + match_count_bonus
|
||||
logger.debug(
|
||||
f"⚖️ 最终分数计算: 总分({match_result.overall_score:.3f}) × 1.3 × 置信度({match_result.confidence:.3f}) + 标签数量奖励({match_count_bonus:.3f}) = {final_score:.3f}"
|
||||
@@ -263,7 +267,17 @@ class InterestScoringSystem:
|
||||
if not msg.processed_plain_text:
|
||||
return 0.0
|
||||
|
||||
if msg.is_mentioned or (bot_nickname and bot_nickname in msg.processed_plain_text):
|
||||
# 检查是否被提及
|
||||
is_mentioned = msg.is_mentioned or (bot_nickname and bot_nickname in msg.processed_plain_text)
|
||||
|
||||
# 检查是否为私聊(group_info为None表示私聊)
|
||||
is_private_chat = msg.group_info is None
|
||||
|
||||
# 如果被提及或是私聊,都视为提及了bot
|
||||
if is_mentioned or is_private_chat:
|
||||
logger.debug(f"🔍 提及检测 - 被提及: {is_mentioned}, 私聊: {is_private_chat}")
|
||||
if is_private_chat and not is_mentioned:
|
||||
logger.debug("💬 私聊消息自动视为提及bot")
|
||||
return global_config.affinity_flow.mention_bot_interest_score
|
||||
|
||||
return 0.0
|
||||
@@ -282,7 +296,9 @@ class InterestScoringSystem:
|
||||
logger.debug(f"📋 基础阈值: {base_threshold:.3f}")
|
||||
|
||||
# 如果被提及,降低阈值
|
||||
if score.mentioned_score >= global_config.affinity_flow.mention_bot_interest_score * 0.5: # 使用提及bot兴趣分的一半作为判断阈值
|
||||
if (
|
||||
score.mentioned_score >= global_config.affinity_flow.mention_bot_interest_score * 0.5
|
||||
): # 使用提及bot兴趣分的一半作为判断阈值
|
||||
base_threshold = self.mention_threshold
|
||||
logger.debug(f"📣 消息提及了机器人,使用降低阈值: {base_threshold:.3f}")
|
||||
|
||||
@@ -325,7 +341,9 @@ class InterestScoringSystem:
|
||||
|
||||
def update_user_relationship(self, user_id: str, relationship_change: float):
|
||||
"""更新用户关系"""
|
||||
old_score = self.user_relationships.get(user_id, global_config.affinity_flow.base_relationship_score) # 默认新用户分数
|
||||
old_score = self.user_relationships.get(
|
||||
user_id, global_config.affinity_flow.base_relationship_score
|
||||
) # 默认新用户分数
|
||||
new_score = max(0.0, min(1.0, old_score + relationship_change))
|
||||
|
||||
self.user_relationships[user_id] = new_score
|
||||
|
||||
@@ -116,6 +116,7 @@ class UserRelationshipTracker:
|
||||
try:
|
||||
# 获取bot人设信息
|
||||
from src.individuality.individuality import Individuality
|
||||
|
||||
individuality = Individuality()
|
||||
bot_personality = await individuality.get_personality_block()
|
||||
|
||||
@@ -168,7 +169,17 @@ class UserRelationshipTracker:
|
||||
# 清理LLM响应,移除可能的格式标记
|
||||
cleaned_response = self._clean_llm_json_response(llm_response)
|
||||
response_data = json.loads(cleaned_response)
|
||||
new_score = max(0.0, min(1.0, float(response_data.get("new_relationship_score", global_config.affinity_flow.base_relationship_score))))
|
||||
new_score = max(
|
||||
0.0,
|
||||
min(
|
||||
1.0,
|
||||
float(
|
||||
response_data.get(
|
||||
"new_relationship_score", global_config.affinity_flow.base_relationship_score
|
||||
)
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
if self.interest_scoring_system:
|
||||
self.interest_scoring_system.update_user_relationship(
|
||||
@@ -295,7 +306,9 @@ class UserRelationshipTracker:
|
||||
# 更新缓存
|
||||
self.user_relationship_cache[user_id] = {
|
||||
"relationship_text": relationship_data.get("relationship_text", ""),
|
||||
"relationship_score": relationship_data.get("relationship_score", global_config.affinity_flow.base_relationship_score),
|
||||
"relationship_score": relationship_data.get(
|
||||
"relationship_score", global_config.affinity_flow.base_relationship_score
|
||||
),
|
||||
"last_tracked": time.time(),
|
||||
}
|
||||
return relationship_data.get("relationship_score", global_config.affinity_flow.base_relationship_score)
|
||||
@@ -386,7 +399,11 @@ class UserRelationshipTracker:
|
||||
|
||||
# 获取当前关系数据
|
||||
current_relationship = self._get_user_relationship_from_db(user_id)
|
||||
current_score = current_relationship.get("relationship_score", global_config.affinity_flow.base_relationship_score) if current_relationship else global_config.affinity_flow.base_relationship_score
|
||||
current_score = (
|
||||
current_relationship.get("relationship_score", global_config.affinity_flow.base_relationship_score)
|
||||
if current_relationship
|
||||
else global_config.affinity_flow.base_relationship_score
|
||||
)
|
||||
current_text = current_relationship.get("relationship_text", "新用户") if current_relationship else "新用户"
|
||||
|
||||
# 使用LLM分析并更新关系
|
||||
@@ -501,6 +518,7 @@ class UserRelationshipTracker:
|
||||
|
||||
# 获取bot人设信息
|
||||
from src.individuality.individuality import Individuality
|
||||
|
||||
individuality = Individuality()
|
||||
bot_personality = await individuality.get_personality_block()
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
"""
|
||||
表情包发送历史记录模块
|
||||
"""
|
||||
|
||||
import os
|
||||
from typing import List, Dict
|
||||
from collections import deque
|
||||
@@ -26,15 +27,15 @@ def add_emoji_to_history(chat_id: str, emoji_description: str):
|
||||
"""
|
||||
if not chat_id or not emoji_description:
|
||||
return
|
||||
|
||||
|
||||
# 如果当前聊天还没有历史记录,则创建一个新的 deque
|
||||
if chat_id not in _history_cache:
|
||||
_history_cache[chat_id] = deque(maxlen=MAX_HISTORY_SIZE)
|
||||
|
||||
|
||||
# 添加新表情到历史记录
|
||||
history = _history_cache[chat_id]
|
||||
history.append(emoji_description)
|
||||
|
||||
|
||||
logger.debug(f"已将表情 '{emoji_description}' 添加到聊天 {chat_id} 的内存历史中")
|
||||
|
||||
|
||||
@@ -50,10 +51,10 @@ def get_recent_emojis(chat_id: str, limit: int = 5) -> List[str]:
|
||||
return []
|
||||
|
||||
history = _history_cache[chat_id]
|
||||
|
||||
|
||||
# 从 deque 的右侧(即最近添加的)开始取
|
||||
num_to_get = min(limit, len(history))
|
||||
recent_emojis = [history[-i] for i in range(1, num_to_get + 1)]
|
||||
|
||||
|
||||
logger.debug(f"为聊天 {chat_id} 从内存中获取到最近 {len(recent_emojis)} 个表情: {recent_emojis}")
|
||||
return recent_emojis
|
||||
|
||||
@@ -477,7 +477,7 @@ class EmojiManager:
|
||||
emoji_options_str = ""
|
||||
for i, emoji in enumerate(candidate_emojis):
|
||||
# 为每个表情包创建一个编号和它的详细描述
|
||||
emoji_options_str += f"编号: {i+1}\n描述: {emoji.description}\n\n"
|
||||
emoji_options_str += f"编号: {i + 1}\n描述: {emoji.description}\n\n"
|
||||
|
||||
# 精心设计的prompt,引导LLM做出选择
|
||||
prompt = f"""
|
||||
@@ -524,10 +524,8 @@ class EmojiManager:
|
||||
self.record_usage(selected_emoji.hash)
|
||||
_time_end = time.time()
|
||||
|
||||
logger.info(
|
||||
f"找到匹配描述的表情包: {selected_emoji.description}, 耗时: {(_time_end - _time_start):.2f}s"
|
||||
)
|
||||
|
||||
logger.info(f"找到匹配描述的表情包: {selected_emoji.description}, 耗时: {(_time_end - _time_start):.2f}s")
|
||||
|
||||
# 8. 返回选中的表情包信息
|
||||
return selected_emoji.full_path, f"[表情包:{selected_emoji.description}]", text_emotion
|
||||
|
||||
@@ -627,8 +625,9 @@ class EmojiManager:
|
||||
|
||||
# 无论steal_emoji是否开启,都检查emoji文件夹以支持手动注册
|
||||
# 只有在需要腾出空间或填充表情库时,才真正执行注册
|
||||
if (self.emoji_num > self.emoji_num_max and global_config.emoji.do_replace) or \
|
||||
(self.emoji_num < self.emoji_num_max):
|
||||
if (self.emoji_num > self.emoji_num_max and global_config.emoji.do_replace) or (
|
||||
self.emoji_num < self.emoji_num_max
|
||||
):
|
||||
try:
|
||||
# 获取目录下所有图片文件
|
||||
files_to_process = [
|
||||
@@ -931,16 +930,21 @@ class EmojiManager:
|
||||
image_base64 = image_base64.encode("ascii", errors="ignore").decode("ascii")
|
||||
image_bytes = base64.b64decode(image_base64)
|
||||
image_hash = hashlib.md5(image_bytes).hexdigest()
|
||||
image_format = Image.open(io.BytesIO(image_bytes)).format.lower() if Image.open(io.BytesIO(image_bytes)).format else "jpeg"
|
||||
|
||||
image_format = (
|
||||
Image.open(io.BytesIO(image_bytes)).format.lower()
|
||||
if Image.open(io.BytesIO(image_bytes)).format
|
||||
else "jpeg"
|
||||
)
|
||||
|
||||
# 2. 检查数据库中是否已存在该表情包的描述,实现复用
|
||||
existing_description = None
|
||||
try:
|
||||
with get_db_session() as session:
|
||||
existing_image = session.query(Images).filter(
|
||||
(Images.emoji_hash == image_hash) & (Images.type == "emoji")
|
||||
).one_or_none()
|
||||
existing_image = (
|
||||
session.query(Images)
|
||||
.filter((Images.emoji_hash == image_hash) & (Images.type == "emoji"))
|
||||
.one_or_none()
|
||||
)
|
||||
if existing_image and existing_image.description:
|
||||
existing_description = existing_image.description
|
||||
logger.info(f"[复用描述] 找到已有详细描述: {existing_description[:50]}...")
|
||||
|
||||
@@ -14,6 +14,7 @@ Chat Frequency Analyzer
|
||||
- MIN_CHATS_FOR_PEAK: 在一个窗口内需要多少次聊天才能被认为是高峰时段。
|
||||
- MIN_GAP_BETWEEN_PEAKS_HOURS: 两个独立高峰时段之间的最小间隔(小时)。
|
||||
"""
|
||||
|
||||
import time as time_module
|
||||
from datetime import datetime, timedelta, time
|
||||
from typing import List, Tuple, Optional
|
||||
@@ -71,12 +72,14 @@ class ChatFrequencyAnalyzer:
|
||||
current_window_end = datetimes[i]
|
||||
|
||||
# 合并重叠或相邻的高峰时段
|
||||
if peak_windows and current_window_start - peak_windows[-1][1] < timedelta(hours=MIN_GAP_BETWEEN_PEAKS_HOURS):
|
||||
if peak_windows and current_window_start - peak_windows[-1][1] < timedelta(
|
||||
hours=MIN_GAP_BETWEEN_PEAKS_HOURS
|
||||
):
|
||||
# 扩展上一个窗口的结束时间
|
||||
peak_windows[-1] = (peak_windows[-1][0], current_window_end)
|
||||
else:
|
||||
peak_windows.append((current_window_start, current_window_end))
|
||||
|
||||
|
||||
return peak_windows
|
||||
|
||||
def get_peak_chat_times(self, chat_id: str) -> List[Tuple[time, time]]:
|
||||
@@ -99,7 +102,7 @@ class ChatFrequencyAnalyzer:
|
||||
return []
|
||||
|
||||
peak_datetime_windows = self._find_peak_windows(timestamps)
|
||||
|
||||
|
||||
# 将 datetime 窗口转换为 time 窗口,并进行归一化处理
|
||||
peak_time_windows = []
|
||||
for start_dt, end_dt in peak_datetime_windows:
|
||||
@@ -109,7 +112,7 @@ class ChatFrequencyAnalyzer:
|
||||
|
||||
# 更新缓存
|
||||
self._analysis_cache[chat_id] = (time_module.time(), peak_time_windows)
|
||||
|
||||
|
||||
return peak_time_windows
|
||||
|
||||
def is_in_peak_time(self, chat_id: str, now: Optional[datetime] = None) -> bool:
|
||||
@@ -125,7 +128,7 @@ class ChatFrequencyAnalyzer:
|
||||
"""
|
||||
if now is None:
|
||||
now = datetime.now()
|
||||
|
||||
|
||||
now_time = now.time()
|
||||
peak_times = self.get_peak_chat_times(chat_id)
|
||||
|
||||
@@ -136,7 +139,7 @@ class ChatFrequencyAnalyzer:
|
||||
else: # 跨天
|
||||
if now_time >= start_time or now_time <= end_time:
|
||||
return True
|
||||
|
||||
|
||||
return False
|
||||
|
||||
|
||||
|
||||
@@ -55,7 +55,7 @@ class ChatFrequencyTracker:
|
||||
now = time.time()
|
||||
if chat_id not in self._timestamps:
|
||||
self._timestamps[chat_id] = []
|
||||
|
||||
|
||||
self._timestamps[chat_id].append(now)
|
||||
logger.debug(f"为 chat_id '{chat_id}' 记录了新的聊天时间: {now}")
|
||||
self._save_timestamps()
|
||||
|
||||
@@ -14,6 +14,7 @@ Frequency-Based Proactive Trigger
|
||||
- TRIGGER_CHECK_INTERVAL_SECONDS: 触发器检查的周期(秒)。
|
||||
- COOLDOWN_HOURS: 在同一个高峰时段内触发一次后的冷却时间(小时)。
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from datetime import datetime
|
||||
@@ -21,6 +22,7 @@ from typing import Dict, Optional
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.affinity_flow.afc_manager import afc_manager
|
||||
|
||||
# TODO: 需要重新实现主动思考和睡眠管理功能
|
||||
from .analyzer import chat_frequency_analyzer
|
||||
|
||||
@@ -65,7 +67,7 @@ class FrequencyBasedTrigger:
|
||||
continue
|
||||
|
||||
now = datetime.now()
|
||||
|
||||
|
||||
for chat_id in all_chat_ids:
|
||||
# 3. 检查是否处于冷却时间内
|
||||
last_triggered_time = self._last_triggered.get(chat_id, 0)
|
||||
@@ -74,7 +76,6 @@ class FrequencyBasedTrigger:
|
||||
|
||||
# 4. 检查当前是否是该用户的高峰聊天时间
|
||||
if chat_frequency_analyzer.is_in_peak_time(chat_id, now):
|
||||
|
||||
# 5. 检查用户当前是否已有活跃的处理任务
|
||||
# 亲和力流系统不直接提供循环状态,通过检查最后活动时间来判断是否忙碌
|
||||
chatter = afc_manager.get_or_create_chatter(chat_id)
|
||||
@@ -87,13 +88,13 @@ class FrequencyBasedTrigger:
|
||||
if current_time - chatter.get_activity_time() < 60:
|
||||
logger.debug(f"用户 {chat_id} 的亲和力处理器正忙,本次不触发。")
|
||||
continue
|
||||
|
||||
|
||||
logger.info(f"检测到用户 {chat_id} 处于聊天高峰期,且处理器空闲,准备触发主动思考。")
|
||||
|
||||
|
||||
# 6. TODO: 亲和力流系统的主动思考机制需要另行实现
|
||||
# 目前先记录日志,等待后续实现
|
||||
logger.info(f"用户 {chat_id} 处于高峰期,但亲和力流的主动思考功能暂未实现")
|
||||
|
||||
|
||||
# 7. 更新触发时间,进入冷却
|
||||
self._last_triggered[chat_id] = time.time()
|
||||
|
||||
|
||||
@@ -4,14 +4,12 @@
|
||||
"""
|
||||
|
||||
from .bot_interest_manager import BotInterestManager, bot_interest_manager
|
||||
from src.common.data_models.bot_interest_data_model import (
|
||||
BotInterestTag, BotPersonalityInterests, InterestMatchResult
|
||||
)
|
||||
from src.common.data_models.bot_interest_data_model import BotInterestTag, BotPersonalityInterests, InterestMatchResult
|
||||
|
||||
__all__ = [
|
||||
"BotInterestManager",
|
||||
"bot_interest_manager",
|
||||
"BotInterestTag",
|
||||
"BotPersonalityInterests",
|
||||
"InterestMatchResult"
|
||||
]
|
||||
"InterestMatchResult",
|
||||
]
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
机器人兴趣标签管理系统
|
||||
基于人设生成兴趣标签,并使用embedding计算匹配度
|
||||
"""
|
||||
|
||||
import orjson
|
||||
import traceback
|
||||
from typing import List, Dict, Optional, Any
|
||||
@@ -10,9 +11,7 @@ import numpy as np
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.common.data_models.bot_interest_data_model import (
|
||||
BotPersonalityInterests, BotInterestTag, InterestMatchResult
|
||||
)
|
||||
from src.common.data_models.bot_interest_data_model import BotPersonalityInterests, BotInterestTag, InterestMatchResult
|
||||
|
||||
logger = get_logger("bot_interest_manager")
|
||||
|
||||
@@ -87,7 +86,7 @@ class BotInterestManager:
|
||||
logger.debug("✅ 成功导入embedding相关模块")
|
||||
|
||||
# 检查embedding配置是否存在
|
||||
if not hasattr(model_config.model_task_config, 'embedding'):
|
||||
if not hasattr(model_config.model_task_config, "embedding"):
|
||||
raise RuntimeError("❌ 未找到embedding模型配置")
|
||||
|
||||
logger.info("📋 找到embedding模型配置")
|
||||
@@ -101,7 +100,7 @@ class BotInterestManager:
|
||||
logger.info(f"🔗 客户端类型: {type(self.embedding_request).__name__}")
|
||||
|
||||
# 获取第一个embedding模型的ModelInfo
|
||||
if hasattr(self.embedding_config, 'model_list') and self.embedding_config.model_list:
|
||||
if hasattr(self.embedding_config, "model_list") and self.embedding_config.model_list:
|
||||
first_model_name = self.embedding_config.model_list[0]
|
||||
logger.info(f"🎯 使用embedding模型: {first_model_name}")
|
||||
else:
|
||||
@@ -127,7 +126,9 @@ class BotInterestManager:
|
||||
# 生成新的兴趣标签
|
||||
logger.info("🆕 数据库中未找到兴趣标签,开始生成新的...")
|
||||
logger.info("🤖 正在调用LLM生成个性化兴趣标签...")
|
||||
generated_interests = await self._generate_interests_from_personality(personality_description, personality_id)
|
||||
generated_interests = await self._generate_interests_from_personality(
|
||||
personality_description, personality_id
|
||||
)
|
||||
|
||||
if generated_interests:
|
||||
self.current_interests = generated_interests
|
||||
@@ -140,14 +141,16 @@ class BotInterestManager:
|
||||
else:
|
||||
raise RuntimeError("❌ 兴趣标签生成失败")
|
||||
|
||||
async def _generate_interests_from_personality(self, personality_description: str, personality_id: str) -> Optional[BotPersonalityInterests]:
|
||||
async def _generate_interests_from_personality(
|
||||
self, personality_description: str, personality_id: str
|
||||
) -> Optional[BotPersonalityInterests]:
|
||||
"""根据人设生成兴趣标签"""
|
||||
try:
|
||||
logger.info("🎨 开始根据人设生成兴趣标签...")
|
||||
logger.info(f"📝 人设长度: {len(personality_description)} 字符")
|
||||
|
||||
# 检查embedding客户端是否可用
|
||||
if not hasattr(self, 'embedding_request'):
|
||||
if not hasattr(self, "embedding_request"):
|
||||
raise RuntimeError("❌ Embedding客户端未初始化,无法生成兴趣标签")
|
||||
|
||||
# 构建提示词
|
||||
@@ -190,8 +193,7 @@ class BotInterestManager:
|
||||
interests_data = orjson.loads(response)
|
||||
|
||||
bot_interests = BotPersonalityInterests(
|
||||
personality_id=personality_id,
|
||||
personality_description=personality_description
|
||||
personality_id=personality_id, personality_description=personality_description
|
||||
)
|
||||
|
||||
# 解析生成的兴趣标签
|
||||
@@ -202,10 +204,7 @@ class BotInterestManager:
|
||||
tag_name = tag_data.get("name", f"标签_{i}")
|
||||
weight = tag_data.get("weight", 0.5)
|
||||
|
||||
tag = BotInterestTag(
|
||||
tag_name=tag_name,
|
||||
weight=weight
|
||||
)
|
||||
tag = BotInterestTag(tag_name=tag_name, weight=weight)
|
||||
bot_interests.interest_tags.append(tag)
|
||||
|
||||
logger.debug(f" 🏷️ {tag_name} (权重: {weight:.2f})")
|
||||
@@ -225,7 +224,6 @@ class BotInterestManager:
|
||||
traceback.print_exc()
|
||||
raise
|
||||
|
||||
|
||||
async def _call_llm_for_interest_generation(self, prompt: str) -> Optional[str]:
|
||||
"""调用LLM生成兴趣标签"""
|
||||
try:
|
||||
@@ -241,10 +239,10 @@ class BotInterestManager:
|
||||
{prompt}
|
||||
|
||||
请确保返回格式为有效的JSON,不要包含任何额外的文本、解释或代码块标记。只返回JSON对象本身。"""
|
||||
|
||||
|
||||
# 使用replyer模型配置
|
||||
replyer_config = model_config.model_task_config.replyer
|
||||
|
||||
|
||||
# 调用LLM API
|
||||
logger.info("🚀 正在通过LLM API发送请求...")
|
||||
success, response, reasoning_content, model_name = await llm_api.generate_with_model(
|
||||
@@ -252,15 +250,17 @@ class BotInterestManager:
|
||||
model_config=replyer_config,
|
||||
request_type="interest_generation",
|
||||
temperature=0.7,
|
||||
max_tokens=2000
|
||||
max_tokens=2000,
|
||||
)
|
||||
|
||||
if success and response:
|
||||
logger.info(f"✅ LLM响应成功,模型: {model_name}, 响应长度: {len(response)} 字符")
|
||||
logger.debug(f"📄 LLM响应内容: {response[:200]}..." if len(response) > 200 else f"📄 LLM响应内容: {response}")
|
||||
logger.debug(
|
||||
f"📄 LLM响应内容: {response[:200]}..." if len(response) > 200 else f"📄 LLM响应内容: {response}"
|
||||
)
|
||||
if reasoning_content:
|
||||
logger.debug(f"🧠 推理内容: {reasoning_content[:100]}...")
|
||||
|
||||
|
||||
# 清理响应内容,移除可能的代码块标记
|
||||
cleaned_response = self._clean_llm_response(response)
|
||||
return cleaned_response
|
||||
@@ -277,25 +277,25 @@ class BotInterestManager:
|
||||
def _clean_llm_response(self, response: str) -> str:
|
||||
"""清理LLM响应,移除代码块标记和其他非JSON内容"""
|
||||
import re
|
||||
|
||||
|
||||
# 移除 ```json 和 ``` 标记
|
||||
cleaned = re.sub(r'```json\s*', '', response)
|
||||
cleaned = re.sub(r'\s*```', '', cleaned)
|
||||
|
||||
cleaned = re.sub(r"```json\s*", "", response)
|
||||
cleaned = re.sub(r"\s*```", "", cleaned)
|
||||
|
||||
# 移除可能的多余空格和换行
|
||||
cleaned = cleaned.strip()
|
||||
|
||||
|
||||
# 尝试提取JSON对象(如果响应中有其他文本)
|
||||
json_match = re.search(r'\{.*\}', cleaned, re.DOTALL)
|
||||
json_match = re.search(r"\{.*\}", cleaned, re.DOTALL)
|
||||
if json_match:
|
||||
cleaned = json_match.group(0)
|
||||
|
||||
|
||||
logger.debug(f"🧹 清理后的响应: {cleaned[:200]}..." if len(cleaned) > 200 else f"🧹 清理后的响应: {cleaned}")
|
||||
return cleaned
|
||||
|
||||
async def _generate_embeddings_for_tags(self, interests: BotPersonalityInterests):
|
||||
"""为所有兴趣标签生成embedding"""
|
||||
if not hasattr(self, 'embedding_request'):
|
||||
if not hasattr(self, "embedding_request"):
|
||||
raise RuntimeError("❌ Embedding客户端未初始化,无法生成embedding")
|
||||
|
||||
total_tags = len(interests.interest_tags)
|
||||
@@ -342,7 +342,7 @@ class BotInterestManager:
|
||||
|
||||
async def _get_embedding(self, text: str) -> List[float]:
|
||||
"""获取文本的embedding向量"""
|
||||
if not hasattr(self, 'embedding_request'):
|
||||
if not hasattr(self, "embedding_request"):
|
||||
raise RuntimeError("❌ Embedding请求客户端未初始化")
|
||||
|
||||
# 检查缓存
|
||||
@@ -376,7 +376,9 @@ class BotInterestManager:
|
||||
logger.debug(f"✅ 消息embedding生成成功,维度: {len(embedding)}")
|
||||
return embedding
|
||||
|
||||
async def _calculate_similarity_scores(self, result: InterestMatchResult, message_embedding: List[float], keywords: List[str]):
|
||||
async def _calculate_similarity_scores(
|
||||
self, result: InterestMatchResult, message_embedding: List[float], keywords: List[str]
|
||||
):
|
||||
"""计算消息与兴趣标签的相似度分数"""
|
||||
try:
|
||||
if not self.current_interests:
|
||||
@@ -397,7 +399,9 @@ class BotInterestManager:
|
||||
# 设置相似度阈值为0.3
|
||||
if similarity > 0.3:
|
||||
result.add_match(tag.tag_name, weighted_score, keywords)
|
||||
logger.debug(f" 🏷️ '{tag.tag_name}': 相似度={similarity:.3f}, 权重={tag.weight:.2f}, 加权分数={weighted_score:.3f}")
|
||||
logger.debug(
|
||||
f" 🏷️ '{tag.tag_name}': 相似度={similarity:.3f}, 权重={tag.weight:.2f}, 加权分数={weighted_score:.3f}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 计算相似度分数失败: {e}")
|
||||
@@ -455,7 +459,9 @@ class BotInterestManager:
|
||||
match_count += 1
|
||||
high_similarity_count += 1
|
||||
result.add_match(tag.tag_name, enhanced_score, [tag.tag_name])
|
||||
logger.debug(f" 🏷️ '{tag.tag_name}': 相似度={similarity:.3f}, 权重={tag.weight:.2f}, 基础分数={weighted_score:.3f}, 增强分数={enhanced_score:.3f} [高匹配]")
|
||||
logger.debug(
|
||||
f" 🏷️ '{tag.tag_name}': 相似度={similarity:.3f}, 权重={tag.weight:.2f}, 基础分数={weighted_score:.3f}, 增强分数={enhanced_score:.3f} [高匹配]"
|
||||
)
|
||||
|
||||
elif similarity > medium_threshold:
|
||||
# 中相似度:中等加成
|
||||
@@ -463,7 +469,9 @@ class BotInterestManager:
|
||||
match_count += 1
|
||||
medium_similarity_count += 1
|
||||
result.add_match(tag.tag_name, enhanced_score, [tag.tag_name])
|
||||
logger.debug(f" 🏷️ '{tag.tag_name}': 相似度={similarity:.3f}, 权重={tag.weight:.2f}, 基础分数={weighted_score:.3f}, 增强分数={enhanced_score:.3f} [中匹配]")
|
||||
logger.debug(
|
||||
f" 🏷️ '{tag.tag_name}': 相似度={similarity:.3f}, 权重={tag.weight:.2f}, 基础分数={weighted_score:.3f}, 增强分数={enhanced_score:.3f} [中匹配]"
|
||||
)
|
||||
|
||||
elif similarity > low_threshold:
|
||||
# 低相似度:轻微加成
|
||||
@@ -471,7 +479,9 @@ class BotInterestManager:
|
||||
match_count += 1
|
||||
low_similarity_count += 1
|
||||
result.add_match(tag.tag_name, enhanced_score, [tag.tag_name])
|
||||
logger.debug(f" 🏷️ '{tag.tag_name}': 相似度={similarity:.3f}, 权重={tag.weight:.2f}, 基础分数={weighted_score:.3f}, 增强分数={enhanced_score:.3f} [低匹配]")
|
||||
logger.debug(
|
||||
f" 🏷️ '{tag.tag_name}': 相似度={similarity:.3f}, 权重={tag.weight:.2f}, 基础分数={weighted_score:.3f}, 增强分数={enhanced_score:.3f} [低匹配]"
|
||||
)
|
||||
|
||||
logger.info(f"📈 匹配统计: {match_count}/{len(active_tags)} 个标签超过阈值")
|
||||
logger.info(f"🔥 高相似度匹配(>{high_threshold}): {high_similarity_count} 个")
|
||||
@@ -488,7 +498,9 @@ class BotInterestManager:
|
||||
original_score = result.match_scores[tag_name]
|
||||
bonus = keyword_bonus[tag_name]
|
||||
result.match_scores[tag_name] = original_score + bonus
|
||||
logger.debug(f" 🏷️ '{tag_name}': 原始分数={original_score:.3f}, 奖励={bonus:.3f}, 最终分数={result.match_scores[tag_name]:.3f}")
|
||||
logger.debug(
|
||||
f" 🏷️ '{tag_name}': 原始分数={original_score:.3f}, 奖励={bonus:.3f}, 最终分数={result.match_scores[tag_name]:.3f}"
|
||||
)
|
||||
|
||||
# 计算总体分数
|
||||
result.calculate_overall_score()
|
||||
@@ -499,10 +511,11 @@ class BotInterestManager:
|
||||
result.top_tag = top_tag_name
|
||||
logger.info(f"🏆 最佳匹配标签: '{top_tag_name}' (分数: {result.match_scores[top_tag_name]:.3f})")
|
||||
|
||||
logger.info(f"📊 最终结果: 总分={result.overall_score:.3f}, 置信度={result.confidence:.3f}, 匹配标签数={len(result.matched_tags)}")
|
||||
logger.info(
|
||||
f"📊 最终结果: 总分={result.overall_score:.3f}, 置信度={result.confidence:.3f}, 匹配标签数={len(result.matched_tags)}"
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
def _calculate_keyword_match_bonus(self, keywords: List[str], matched_tags: List[str]) -> Dict[str, float]:
|
||||
"""计算关键词直接匹配奖励"""
|
||||
if not keywords or not matched_tags:
|
||||
@@ -522,17 +535,25 @@ class BotInterestManager:
|
||||
# 完全匹配
|
||||
if keyword_lower == tag_name_lower:
|
||||
bonus += affinity_config.high_match_interest_threshold * 0.6 # 使用高匹配阈值的60%作为完全匹配奖励
|
||||
logger.debug(f" 🎯 关键词完全匹配: '{keyword}' == '{tag_name}' (+{affinity_config.high_match_interest_threshold * 0.6:.3f})")
|
||||
logger.debug(
|
||||
f" 🎯 关键词完全匹配: '{keyword}' == '{tag_name}' (+{affinity_config.high_match_interest_threshold * 0.6:.3f})"
|
||||
)
|
||||
|
||||
# 包含匹配
|
||||
elif keyword_lower in tag_name_lower or tag_name_lower in keyword_lower:
|
||||
bonus += affinity_config.medium_match_interest_threshold * 0.3 # 使用中匹配阈值的30%作为包含匹配奖励
|
||||
logger.debug(f" 🎯 关键词包含匹配: '{keyword}' ⊃ '{tag_name}' (+{affinity_config.medium_match_interest_threshold * 0.3:.3f})")
|
||||
bonus += (
|
||||
affinity_config.medium_match_interest_threshold * 0.3
|
||||
) # 使用中匹配阈值的30%作为包含匹配奖励
|
||||
logger.debug(
|
||||
f" 🎯 关键词包含匹配: '{keyword}' ⊃ '{tag_name}' (+{affinity_config.medium_match_interest_threshold * 0.3:.3f})"
|
||||
)
|
||||
|
||||
# 部分匹配(编辑距离)
|
||||
elif self._calculate_partial_match(keyword_lower, tag_name_lower):
|
||||
bonus += affinity_config.low_match_interest_threshold * 0.4 # 使用低匹配阈值的40%作为部分匹配奖励
|
||||
logger.debug(f" 🎯 关键词部分匹配: '{keyword}' ≈ '{tag_name}' (+{affinity_config.low_match_interest_threshold * 0.4:.3f})")
|
||||
logger.debug(
|
||||
f" 🎯 关键词部分匹配: '{keyword}' ≈ '{tag_name}' (+{affinity_config.low_match_interest_threshold * 0.4:.3f})"
|
||||
)
|
||||
|
||||
if bonus > 0:
|
||||
bonus_dict[tag_name] = min(bonus, affinity_config.max_match_bonus) # 使用配置的最大奖励限制
|
||||
@@ -608,12 +629,12 @@ class BotInterestManager:
|
||||
|
||||
with get_db_session() as session:
|
||||
# 查询最新的兴趣标签配置
|
||||
db_interests = session.query(DBBotPersonalityInterests).filter(
|
||||
DBBotPersonalityInterests.personality_id == personality_id
|
||||
).order_by(
|
||||
DBBotPersonalityInterests.version.desc(),
|
||||
DBBotPersonalityInterests.last_updated.desc()
|
||||
).first()
|
||||
db_interests = (
|
||||
session.query(DBBotPersonalityInterests)
|
||||
.filter(DBBotPersonalityInterests.personality_id == personality_id)
|
||||
.order_by(DBBotPersonalityInterests.version.desc(), DBBotPersonalityInterests.last_updated.desc())
|
||||
.first()
|
||||
)
|
||||
|
||||
if db_interests:
|
||||
logger.info(f"✅ 找到数据库中的兴趣标签配置,版本: {db_interests.version}")
|
||||
@@ -631,7 +652,7 @@ class BotInterestManager:
|
||||
personality_description=db_interests.personality_description,
|
||||
embedding_model=db_interests.embedding_model,
|
||||
version=db_interests.version,
|
||||
last_updated=db_interests.last_updated
|
||||
last_updated=db_interests.last_updated,
|
||||
)
|
||||
|
||||
# 解析兴趣标签
|
||||
@@ -639,10 +660,14 @@ class BotInterestManager:
|
||||
tag = BotInterestTag(
|
||||
tag_name=tag_data.get("tag_name", ""),
|
||||
weight=tag_data.get("weight", 0.5),
|
||||
created_at=datetime.fromisoformat(tag_data.get("created_at", datetime.now().isoformat())),
|
||||
updated_at=datetime.fromisoformat(tag_data.get("updated_at", datetime.now().isoformat())),
|
||||
created_at=datetime.fromisoformat(
|
||||
tag_data.get("created_at", datetime.now().isoformat())
|
||||
),
|
||||
updated_at=datetime.fromisoformat(
|
||||
tag_data.get("updated_at", datetime.now().isoformat())
|
||||
),
|
||||
is_active=tag_data.get("is_active", True),
|
||||
embedding=tag_data.get("embedding")
|
||||
embedding=tag_data.get("embedding"),
|
||||
)
|
||||
interests.interest_tags.append(tag)
|
||||
|
||||
@@ -685,7 +710,7 @@ class BotInterestManager:
|
||||
"created_at": tag.created_at.isoformat(),
|
||||
"updated_at": tag.updated_at.isoformat(),
|
||||
"is_active": tag.is_active,
|
||||
"embedding": tag.embedding
|
||||
"embedding": tag.embedding,
|
||||
}
|
||||
tags_data.append(tag_dict)
|
||||
|
||||
@@ -694,9 +719,11 @@ class BotInterestManager:
|
||||
|
||||
with get_db_session() as session:
|
||||
# 检查是否已存在相同personality_id的记录
|
||||
existing_record = session.query(DBBotPersonalityInterests).filter(
|
||||
DBBotPersonalityInterests.personality_id == interests.personality_id
|
||||
).first()
|
||||
existing_record = (
|
||||
session.query(DBBotPersonalityInterests)
|
||||
.filter(DBBotPersonalityInterests.personality_id == interests.personality_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
if existing_record:
|
||||
# 更新现有记录
|
||||
@@ -718,7 +745,7 @@ class BotInterestManager:
|
||||
interest_tags=json_data,
|
||||
embedding_model=interests.embedding_model,
|
||||
version=interests.version,
|
||||
last_updated=interests.last_updated
|
||||
last_updated=interests.last_updated,
|
||||
)
|
||||
session.add(new_record)
|
||||
session.commit()
|
||||
@@ -728,9 +755,11 @@ class BotInterestManager:
|
||||
|
||||
# 验证保存是否成功
|
||||
with get_db_session() as session:
|
||||
saved_record = session.query(DBBotPersonalityInterests).filter(
|
||||
DBBotPersonalityInterests.personality_id == interests.personality_id
|
||||
).first()
|
||||
saved_record = (
|
||||
session.query(DBBotPersonalityInterests)
|
||||
.filter(DBBotPersonalityInterests.personality_id == interests.personality_id)
|
||||
.first()
|
||||
)
|
||||
session.commit()
|
||||
if saved_record:
|
||||
logger.info(f"✅ 验证成功:数据库中存在personality_id为 {interests.personality_id} 的记录")
|
||||
@@ -760,7 +789,7 @@ class BotInterestManager:
|
||||
"total_tags": len(active_tags),
|
||||
"embedding_model": self.current_interests.embedding_model,
|
||||
"last_updated": self.current_interests.last_updated.isoformat(),
|
||||
"cache_size": len(self.embedding_cache)
|
||||
"cache_size": len(self.embedding_cache),
|
||||
}
|
||||
|
||||
async def update_interest_tags(self, new_personality_description: str = None):
|
||||
@@ -775,8 +804,7 @@ class BotInterestManager:
|
||||
|
||||
# 重新生成兴趣标签
|
||||
new_interests = await self._generate_interests_from_personality(
|
||||
self.current_interests.personality_description,
|
||||
self.current_interests.personality_id
|
||||
self.current_interests.personality_description, self.current_interests.personality_id
|
||||
)
|
||||
|
||||
if new_interests:
|
||||
@@ -791,4 +819,4 @@ class BotInterestManager:
|
||||
|
||||
|
||||
# 创建全局实例(重新创建以包含新的属性)
|
||||
bot_interest_manager = BotInterestManager()
|
||||
bot_interest_manager = BotInterestManager()
|
||||
|
||||
@@ -4,13 +4,11 @@
|
||||
"""
|
||||
|
||||
from .message_manager import MessageManager, message_manager
|
||||
from src.common.data_models.message_manager_data_model import StreamContext, MessageStatus, MessageManagerStats, StreamStats
|
||||
from src.common.data_models.message_manager_data_model import (
|
||||
StreamContext,
|
||||
MessageStatus,
|
||||
MessageManagerStats,
|
||||
StreamStats,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"MessageManager",
|
||||
"message_manager",
|
||||
"StreamContext",
|
||||
"MessageStatus",
|
||||
"MessageManagerStats",
|
||||
"StreamStats"
|
||||
]
|
||||
__all__ = ["MessageManager", "message_manager", "StreamContext", "MessageStatus", "MessageManagerStats", "StreamStats"]
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
消息管理模块
|
||||
管理每个聊天流的上下文信息,包含历史记录和未读消息,定期检查并处理新消息
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
import traceback
|
||||
@@ -100,9 +101,7 @@ class MessageManager:
|
||||
|
||||
# 如果没有处理任务,创建一个
|
||||
if not context.processing_task or context.processing_task.done():
|
||||
context.processing_task = asyncio.create_task(
|
||||
self._process_stream_messages(stream_id)
|
||||
)
|
||||
context.processing_task = asyncio.create_task(self._process_stream_messages(stream_id))
|
||||
|
||||
# 更新统计
|
||||
self.stats.active_streams = active_streams
|
||||
@@ -128,11 +127,11 @@ class MessageManager:
|
||||
try:
|
||||
# 发送到AFC处理器,传递StreamContext对象
|
||||
results = await afc_manager.process_stream_context(stream_id, context)
|
||||
|
||||
|
||||
# 处理结果,标记消息为已读
|
||||
if results.get("success", False):
|
||||
self._clear_all_unread_messages(context)
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"处理聊天流 {stream_id} 时发生异常,将清除所有未读消息: {e}")
|
||||
raise
|
||||
@@ -175,7 +174,7 @@ class MessageManager:
|
||||
unread_count=len(context.get_unread_messages()),
|
||||
history_count=len(context.history_messages),
|
||||
last_check_time=context.last_check_time,
|
||||
has_active_task=context.processing_task and not context.processing_task.done()
|
||||
has_active_task=context.processing_task and not context.processing_task.done(),
|
||||
)
|
||||
|
||||
def get_manager_stats(self) -> Dict[str, Any]:
|
||||
@@ -186,7 +185,7 @@ class MessageManager:
|
||||
"total_unread_messages": self.stats.total_unread_messages,
|
||||
"total_processed_messages": self.stats.total_processed_messages,
|
||||
"uptime": self.stats.uptime,
|
||||
"start_time": self.stats.start_time
|
||||
"start_time": self.stats.start_time,
|
||||
}
|
||||
|
||||
def cleanup_inactive_streams(self, max_inactive_hours: int = 24):
|
||||
@@ -196,8 +195,7 @@ class MessageManager:
|
||||
|
||||
inactive_streams = []
|
||||
for stream_id, context in self.stream_contexts.items():
|
||||
if (current_time - context.last_check_time > max_inactive_seconds and
|
||||
not context.get_unread_messages()):
|
||||
if current_time - context.last_check_time > max_inactive_seconds and not context.get_unread_messages():
|
||||
inactive_streams.append(stream_id)
|
||||
|
||||
for stream_id in inactive_streams:
|
||||
@@ -210,9 +208,9 @@ class MessageManager:
|
||||
unread_messages = context.get_unread_messages()
|
||||
if not unread_messages:
|
||||
return
|
||||
|
||||
|
||||
logger.warning(f"正在清除 {len(unread_messages)} 条未读消息")
|
||||
|
||||
|
||||
# 将所有未读消息标记为已读并移动到历史记录
|
||||
for msg in unread_messages[:]: # 使用切片复制避免迭代时修改列表
|
||||
try:
|
||||
@@ -224,4 +222,4 @@ class MessageManager:
|
||||
|
||||
|
||||
# 创建全局消息管理器实例
|
||||
message_manager = MessageManager()
|
||||
message_manager = MessageManager()
|
||||
|
||||
@@ -17,6 +17,7 @@ from src.plugin_system.core import component_registry, event_manager, global_ann
|
||||
from src.plugin_system.base import BaseCommand, EventType
|
||||
from src.mais4u.mais4u_chat.s4u_msg_processor import S4UMessageProcessor
|
||||
from src.chat.utils.utils import is_mentioned_bot_in_message
|
||||
|
||||
# 导入反注入系统
|
||||
from src.chat.antipromptinjector import initialize_anti_injector
|
||||
|
||||
@@ -511,7 +512,7 @@ class ChatBot:
|
||||
chat_info_user_id=message.chat_stream.user_info.user_id,
|
||||
chat_info_user_nickname=message.chat_stream.user_info.user_nickname,
|
||||
chat_info_user_cardname=message.chat_stream.user_info.user_cardname,
|
||||
chat_info_user_platform=message.chat_stream.user_info.platform
|
||||
chat_info_user_platform=message.chat_stream.user_info.platform,
|
||||
)
|
||||
|
||||
# 如果是群聊,添加群组信息
|
||||
|
||||
@@ -84,7 +84,9 @@ class ChatStream:
|
||||
self.saved = False
|
||||
self.context: ChatMessageContext = None # type: ignore # 用于存储该聊天的上下文信息
|
||||
# 从配置文件中读取focus_value,如果没有则使用默认值1.0
|
||||
self.focus_energy = data.get("focus_energy", global_config.chat.focus_value) if data else global_config.chat.focus_value
|
||||
self.focus_energy = (
|
||||
data.get("focus_energy", global_config.chat.focus_value) if data else global_config.chat.focus_value
|
||||
)
|
||||
self.no_reply_consecutive = 0
|
||||
self.breaking_accumulated_interest = 0.0
|
||||
|
||||
|
||||
@@ -165,10 +165,15 @@ class ActionManager:
|
||||
# 通过chat_id获取chat_stream
|
||||
chat_manager = get_chat_manager()
|
||||
chat_stream = chat_manager.get_stream(chat_id)
|
||||
|
||||
|
||||
if not chat_stream:
|
||||
logger.error(f"{log_prefix} 无法找到chat_id对应的chat_stream: {chat_id}")
|
||||
return {"action_type": action_name, "success": False, "reply_text": "", "error": "chat_stream not found"}
|
||||
return {
|
||||
"action_type": action_name,
|
||||
"success": False,
|
||||
"reply_text": "",
|
||||
"error": "chat_stream not found",
|
||||
}
|
||||
|
||||
if action_name == "no_action":
|
||||
return {"action_type": "no_action", "success": True, "reply_text": "", "command": ""}
|
||||
@@ -177,7 +182,7 @@ class ActionManager:
|
||||
# 直接处理no_reply逻辑,不再通过动作系统
|
||||
reason = reasoning or "选择不回复"
|
||||
logger.info(f"{log_prefix} 选择不回复,原因: {reason}")
|
||||
|
||||
|
||||
# 存储no_reply信息到数据库
|
||||
await database_api.store_action_info(
|
||||
chat_stream=chat_stream,
|
||||
@@ -396,7 +401,7 @@ class ActionManager:
|
||||
}
|
||||
|
||||
return loop_info, reply_text, cycle_timers
|
||||
|
||||
|
||||
async def send_response(self, chat_stream, reply_set, thinking_start_time, message_data) -> str:
|
||||
"""
|
||||
发送回复内容的具体实现
|
||||
@@ -471,4 +476,4 @@ class ActionManager:
|
||||
typing=True,
|
||||
)
|
||||
|
||||
return reply_text
|
||||
return reply_text
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
"""
|
||||
PlanGenerator: 负责搜集和汇总所有决策所需的信息,生成一个未经筛选的“原始计划” (Plan)。
|
||||
"""
|
||||
|
||||
import time
|
||||
from typing import Dict
|
||||
|
||||
@@ -35,6 +36,7 @@ class PlanGenerator:
|
||||
chat_id (str): 当前聊天的 ID。
|
||||
"""
|
||||
from src.chat.planner_actions.action_manager import ActionManager
|
||||
|
||||
self.chat_id = chat_id
|
||||
# 注意:ActionManager 可能需要根据实际情况初始化
|
||||
self.action_manager = ActionManager()
|
||||
@@ -52,7 +54,7 @@ class PlanGenerator:
|
||||
Plan: 一个填充了初始上下文信息的 Plan 对象。
|
||||
"""
|
||||
_is_group_chat, chat_target_info_dict = get_chat_type_and_target_info(self.chat_id)
|
||||
|
||||
|
||||
target_info = None
|
||||
if chat_target_info_dict:
|
||||
target_info = TargetPersonInfo(**chat_target_info_dict)
|
||||
@@ -65,7 +67,6 @@ class PlanGenerator:
|
||||
)
|
||||
chat_history = [DatabaseMessages(**msg) for msg in chat_history_raw]
|
||||
|
||||
|
||||
plan = Plan(
|
||||
chat_id=self.chat_id,
|
||||
mode=mode,
|
||||
@@ -86,10 +87,10 @@ class PlanGenerator:
|
||||
Dict[str, "ActionInfo"]: 一个字典,键是动作名称,值是 ActionInfo 对象。
|
||||
"""
|
||||
current_available_actions_dict = self.action_manager.get_using_actions()
|
||||
all_registered_actions: Dict[str, ActionInfo] = component_registry.get_components_by_type( # type: ignore
|
||||
all_registered_actions: Dict[str, ActionInfo] = component_registry.get_components_by_type( # type: ignore
|
||||
ComponentType.ACTION
|
||||
)
|
||||
|
||||
|
||||
current_available_actions = {}
|
||||
for action_name in current_available_actions_dict:
|
||||
if action_name in all_registered_actions:
|
||||
@@ -99,16 +100,13 @@ class PlanGenerator:
|
||||
name="reply",
|
||||
component_type=ComponentType.ACTION,
|
||||
description="系统级动作:选择回复消息的决策",
|
||||
action_parameters={
|
||||
"content": "回复的文本内容",
|
||||
"reply_to_message_id": "要回复的消息ID"
|
||||
},
|
||||
action_parameters={"content": "回复的文本内容", "reply_to_message_id": "要回复的消息ID"},
|
||||
action_require=[
|
||||
"你想要闲聊或者随便附和",
|
||||
"当用户提到你或艾特你时",
|
||||
"当需要回答用户的问题时",
|
||||
"当你想参与对话时",
|
||||
"当用户分享有趣的内容时"
|
||||
"当用户分享有趣的内容时",
|
||||
],
|
||||
activation_type=ActionActivationType.ALWAYS,
|
||||
activation_keywords=[],
|
||||
@@ -131,4 +129,4 @@ class PlanGenerator:
|
||||
)
|
||||
current_available_actions["no_reply"] = no_reply_info
|
||||
current_available_actions["reply"] = reply_info
|
||||
return current_available_actions
|
||||
return current_available_actions
|
||||
|
||||
@@ -109,9 +109,7 @@ class ActionPlanner:
|
||||
self.planner_stats["failed_plans"] += 1
|
||||
return [], None
|
||||
|
||||
async def _enhanced_plan_flow(
|
||||
self, mode: ChatMode, context: StreamContext
|
||||
) -> Tuple[List[Dict], Optional[Dict]]:
|
||||
async def _enhanced_plan_flow(self, mode: ChatMode, context: StreamContext) -> Tuple[List[Dict], Optional[Dict]]:
|
||||
"""执行增强版规划流程"""
|
||||
try:
|
||||
# 1. 生成初始 Plan
|
||||
@@ -137,7 +135,9 @@ class ActionPlanner:
|
||||
# 检查兴趣度是否达到非回复动作阈值
|
||||
non_reply_action_interest_threshold = global_config.affinity_flow.non_reply_action_interest_threshold
|
||||
if score < non_reply_action_interest_threshold:
|
||||
logger.info(f"❌ 兴趣度不足非回复动作阈值: {score:.3f} < {non_reply_action_interest_threshold:.3f},直接返回no_action")
|
||||
logger.info(
|
||||
f"❌ 兴趣度不足非回复动作阈值: {score:.3f} < {non_reply_action_interest_threshold:.3f},直接返回no_action"
|
||||
)
|
||||
logger.info(f"📊 最低要求: {non_reply_action_interest_threshold:.3f}")
|
||||
# 直接返回 no_action
|
||||
from src.common.data_models.info_data_model import ActionPlannerInfo
|
||||
|
||||
@@ -303,7 +303,7 @@ class DefaultReplyer:
|
||||
"model": model_name,
|
||||
"tool_calls": tool_call,
|
||||
}
|
||||
|
||||
|
||||
# 触发 AFTER_LLM 事件
|
||||
if not from_plugin:
|
||||
result = await event_manager.trigger_event(
|
||||
@@ -598,6 +598,7 @@ class DefaultReplyer:
|
||||
def _parse_reply_target(self, target_message: str) -> Tuple[str, str]:
|
||||
"""解析回复目标消息 - 使用共享工具"""
|
||||
from src.chat.utils.prompt import Prompt
|
||||
|
||||
if target_message is None:
|
||||
logger.warning("target_message为None,返回默认值")
|
||||
return "未知用户", "(无消息内容)"
|
||||
@@ -704,22 +705,24 @@ class DefaultReplyer:
|
||||
unread_history_prompt = ""
|
||||
if unread_messages:
|
||||
# 尝试获取兴趣度评分
|
||||
interest_scores = await self._get_interest_scores_for_messages([msg.flatten() for msg in unread_messages])
|
||||
interest_scores = await self._get_interest_scores_for_messages(
|
||||
[msg.flatten() for msg in unread_messages]
|
||||
)
|
||||
|
||||
unread_lines = []
|
||||
for msg in unread_messages:
|
||||
msg_id = msg.message_id
|
||||
msg_time = time.strftime('%H:%M:%S', time.localtime(msg.time))
|
||||
msg_time = time.strftime("%H:%M:%S", time.localtime(msg.time))
|
||||
msg_content = msg.processed_plain_text
|
||||
|
||||
|
||||
# 使用与已读历史消息相同的方法获取用户名
|
||||
from src.person_info.person_info import PersonInfoManager, get_person_info_manager
|
||||
|
||||
|
||||
# 获取用户信息
|
||||
user_info = getattr(msg, 'user_info', {})
|
||||
platform = getattr(user_info, 'platform', '') or getattr(msg, 'platform', '')
|
||||
user_id = getattr(user_info, 'user_id', '') or getattr(msg, 'user_id', '')
|
||||
|
||||
user_info = getattr(msg, "user_info", {})
|
||||
platform = getattr(user_info, "platform", "") or getattr(msg, "platform", "")
|
||||
user_id = getattr(user_info, "user_id", "") or getattr(msg, "user_id", "")
|
||||
|
||||
# 获取用户名
|
||||
if platform and user_id:
|
||||
person_id = PersonInfoManager.get_person_id(platform, user_id)
|
||||
@@ -727,11 +730,11 @@ class DefaultReplyer:
|
||||
sender_name = person_info_manager.get_value_sync(person_id, "person_name") or "未知用户"
|
||||
else:
|
||||
sender_name = "未知用户"
|
||||
|
||||
|
||||
# 添加兴趣度信息
|
||||
interest_score = interest_scores.get(msg_id, 0.0)
|
||||
interest_text = f" [兴趣度: {interest_score:.3f}]" if interest_score > 0 else ""
|
||||
|
||||
|
||||
unread_lines.append(f"{msg_time} {sender_name}: {msg_content}{interest_text}")
|
||||
|
||||
unread_history_prompt_str = "\n".join(unread_lines)
|
||||
@@ -808,17 +811,17 @@ class DefaultReplyer:
|
||||
unread_lines = []
|
||||
for msg in unread_messages:
|
||||
msg_id = msg.get("message_id", "")
|
||||
msg_time = time.strftime('%H:%M:%S', time.localtime(msg.get("time", time.time())))
|
||||
msg_time = time.strftime("%H:%M:%S", time.localtime(msg.get("time", time.time())))
|
||||
msg_content = msg.get("processed_plain_text", "")
|
||||
|
||||
|
||||
# 使用与已读历史消息相同的方法获取用户名
|
||||
from src.person_info.person_info import PersonInfoManager, get_person_info_manager
|
||||
|
||||
|
||||
# 获取用户信息
|
||||
user_info = msg.get("user_info", {})
|
||||
platform = user_info.get("platform") or msg.get("platform", "")
|
||||
user_id = user_info.get("user_id") or msg.get("user_id", "")
|
||||
|
||||
|
||||
# 获取用户名
|
||||
if platform and user_id:
|
||||
person_id = PersonInfoManager.get_person_id(platform, user_id)
|
||||
@@ -834,7 +837,9 @@ class DefaultReplyer:
|
||||
unread_lines.append(f"{msg_time} {sender_name}: {msg_content}{interest_text}")
|
||||
|
||||
unread_history_prompt_str = "\n".join(unread_lines)
|
||||
unread_history_prompt = f"这是未读历史消息,包含兴趣度评分,请优先对兴趣值高的消息做出动作:\n{unread_history_prompt_str}"
|
||||
unread_history_prompt = (
|
||||
f"这是未读历史消息,包含兴趣度评分,请优先对兴趣值高的消息做出动作:\n{unread_history_prompt_str}"
|
||||
)
|
||||
else:
|
||||
unread_history_prompt = "暂无未读历史消息"
|
||||
|
||||
@@ -982,7 +987,7 @@ class DefaultReplyer:
|
||||
reply_message.get("user_id"), # type: ignore
|
||||
)
|
||||
person_name = await person_info_manager.get_value(person_id, "person_name")
|
||||
|
||||
|
||||
# 如果person_name为None,使用fallback值
|
||||
if person_name is None:
|
||||
# 尝试从reply_message获取用户名
|
||||
@@ -990,12 +995,12 @@ class DefaultReplyer:
|
||||
logger.warning(f"未知用户,将存储用户信息:{fallback_name}")
|
||||
person_name = str(fallback_name)
|
||||
person_info_manager.set_value(person_id, "person_name", fallback_name)
|
||||
|
||||
|
||||
# 检查是否是bot自己的名字,如果是则替换为"(你)"
|
||||
bot_user_id = str(global_config.bot.qq_account)
|
||||
current_user_id = person_info_manager.get_value_sync(person_id, "user_id")
|
||||
current_platform = reply_message.get("chat_info_platform")
|
||||
|
||||
|
||||
if current_user_id == bot_user_id and current_platform == global_config.bot.platform:
|
||||
sender = f"{person_name}(你)"
|
||||
else:
|
||||
@@ -1050,8 +1055,9 @@ class DefaultReplyer:
|
||||
target_user_info = None
|
||||
if sender:
|
||||
target_user_info = await person_info_manager.get_person_info_by_name(sender)
|
||||
|
||||
|
||||
from src.chat.utils.prompt import Prompt
|
||||
|
||||
# 并行执行六个构建任务
|
||||
task_results = await asyncio.gather(
|
||||
self._time_and_run_task(
|
||||
@@ -1127,6 +1133,7 @@ class DefaultReplyer:
|
||||
schedule_block = ""
|
||||
if global_config.planning_system.schedule_enable:
|
||||
from src.schedule.schedule_manager import schedule_manager
|
||||
|
||||
current_activity = schedule_manager.get_current_activity()
|
||||
if current_activity:
|
||||
schedule_block = f"你当前正在:{current_activity}。"
|
||||
@@ -1139,7 +1146,7 @@ class DefaultReplyer:
|
||||
safety_guidelines = global_config.personality.safety_guidelines
|
||||
safety_guidelines_block = ""
|
||||
if safety_guidelines:
|
||||
guidelines_text = "\n".join(f"{i+1}. {line}" for i, line in enumerate(safety_guidelines))
|
||||
guidelines_text = "\n".join(f"{i + 1}. {line}" for i, line in enumerate(safety_guidelines))
|
||||
safety_guidelines_block = f"""### 安全与互动底线
|
||||
在任何情况下,你都必须遵守以下由你的设定者为你定义的原则:
|
||||
{guidelines_text}
|
||||
@@ -1212,7 +1219,7 @@ class DefaultReplyer:
|
||||
template_name = "normal_style_prompt"
|
||||
elif current_prompt_mode == "minimal":
|
||||
template_name = "default_expressor_prompt"
|
||||
|
||||
|
||||
# 获取模板内容
|
||||
template_prompt = await global_prompt_manager.get_prompt_async(template_name)
|
||||
prompt = Prompt(template=template_prompt.template, parameters=prompt_parameters)
|
||||
@@ -1488,19 +1495,19 @@ class DefaultReplyer:
|
||||
# 使用AFC关系追踪器获取关系信息
|
||||
try:
|
||||
from src.chat.affinity_flow.relationship_integration import get_relationship_tracker
|
||||
|
||||
|
||||
relationship_tracker = get_relationship_tracker()
|
||||
if relationship_tracker:
|
||||
# 获取用户信息以获取真实的user_id
|
||||
user_info = await person_info_manager.get_values(person_id, ["user_id", "platform"])
|
||||
user_id = user_info.get("user_id", "unknown")
|
||||
|
||||
|
||||
# 从数据库获取关系数据
|
||||
relationship_data = relationship_tracker._get_user_relationship_from_db(user_id)
|
||||
if relationship_data:
|
||||
relationship_text = relationship_data.get("relationship_text", "")
|
||||
relationship_score = relationship_data.get("relationship_score", 0.3)
|
||||
|
||||
|
||||
# 构建丰富的关系信息描述
|
||||
if relationship_text:
|
||||
# 转换关系分数为描述性文本
|
||||
@@ -1514,7 +1521,7 @@ class DefaultReplyer:
|
||||
relationship_level = "认识的人"
|
||||
else:
|
||||
relationship_level = "陌生人"
|
||||
|
||||
|
||||
return f"你与{sender}的关系:{relationship_level}(关系分:{relationship_score:.2f}/1.0)。{relationship_text}"
|
||||
else:
|
||||
return f"你与{sender}是初次见面,关系分:{relationship_score:.2f}/1.0。"
|
||||
@@ -1523,7 +1530,7 @@ class DefaultReplyer:
|
||||
else:
|
||||
logger.warning("AFC关系追踪器未初始化,使用默认关系信息")
|
||||
return f"你与{sender}是普通朋友关系。"
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取AFC关系信息失败: {e}")
|
||||
return f"你与{sender}是普通朋友关系。"
|
||||
|
||||
@@ -37,7 +37,7 @@ def replace_user_references_sync(
|
||||
"""
|
||||
if not content:
|
||||
return ""
|
||||
|
||||
|
||||
if name_resolver is None:
|
||||
person_info_manager = get_person_info_manager()
|
||||
|
||||
@@ -821,7 +821,7 @@ def build_pic_mapping_info(pic_id_mapping: Dict[str, str]) -> str:
|
||||
try:
|
||||
with get_db_session() as session:
|
||||
image = session.execute(select(Images).where(Images.image_id == pic_id)).scalar_one_or_none()
|
||||
if image and image.description: # type: ignore
|
||||
if image and image.description: # type: ignore
|
||||
description = image.description
|
||||
except Exception:
|
||||
# 如果查询失败,保持默认描述
|
||||
|
||||
@@ -25,7 +25,7 @@ logger = get_logger("unified_prompt")
|
||||
@dataclass
|
||||
class PromptParameters:
|
||||
"""统一提示词参数系统"""
|
||||
|
||||
|
||||
# 基础参数
|
||||
chat_id: str = ""
|
||||
is_group_chat: bool = False
|
||||
@@ -34,7 +34,7 @@ class PromptParameters:
|
||||
reply_to: str = ""
|
||||
extra_info: str = ""
|
||||
prompt_mode: Literal["s4u", "normal", "minimal"] = "s4u"
|
||||
|
||||
|
||||
# 功能开关
|
||||
enable_tool: bool = True
|
||||
enable_memory: bool = True
|
||||
@@ -42,20 +42,20 @@ class PromptParameters:
|
||||
enable_relation: bool = True
|
||||
enable_cross_context: bool = True
|
||||
enable_knowledge: bool = True
|
||||
|
||||
|
||||
# 性能控制
|
||||
max_context_messages: int = 50
|
||||
|
||||
|
||||
# 调试选项
|
||||
debug_mode: bool = False
|
||||
|
||||
|
||||
# 聊天历史和上下文
|
||||
chat_target_info: Optional[Dict[str, Any]] = None
|
||||
message_list_before_now_long: List[Dict[str, Any]] = field(default_factory=list)
|
||||
message_list_before_short: List[Dict[str, Any]] = field(default_factory=list)
|
||||
chat_talking_prompt_short: str = ""
|
||||
target_user_info: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
# 已构建的内容块
|
||||
expression_habits_block: str = ""
|
||||
relation_info_block: str = ""
|
||||
@@ -63,7 +63,7 @@ class PromptParameters:
|
||||
tool_info_block: str = ""
|
||||
knowledge_prompt: str = ""
|
||||
cross_context_block: str = ""
|
||||
|
||||
|
||||
# 其他内容块
|
||||
keywords_reaction_prompt: str = ""
|
||||
extra_info_block: str = ""
|
||||
@@ -75,10 +75,10 @@ class PromptParameters:
|
||||
reply_target_block: str = ""
|
||||
mood_prompt: str = ""
|
||||
action_descriptions: str = ""
|
||||
|
||||
|
||||
# 可用动作信息
|
||||
available_actions: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
def validate(self) -> List[str]:
|
||||
"""参数验证"""
|
||||
errors = []
|
||||
@@ -93,22 +93,22 @@ class PromptParameters:
|
||||
|
||||
class PromptContext:
|
||||
"""提示词上下文管理器"""
|
||||
|
||||
|
||||
def __init__(self):
|
||||
self._context_prompts: Dict[str, Dict[str, "Prompt"]] = {}
|
||||
self._current_context_var = contextvars.ContextVar("current_context", default=None)
|
||||
self._context_lock = asyncio.Lock()
|
||||
|
||||
|
||||
@property
|
||||
def _current_context(self) -> Optional[str]:
|
||||
"""获取当前协程的上下文ID"""
|
||||
return self._current_context_var.get()
|
||||
|
||||
|
||||
@_current_context.setter
|
||||
def _current_context(self, value: Optional[str]):
|
||||
"""设置当前协程的上下文ID"""
|
||||
self._current_context_var.set(value) # type: ignore
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def async_scope(self, context_id: Optional[str] = None):
|
||||
"""创建一个异步的临时提示模板作用域"""
|
||||
@@ -123,13 +123,13 @@ class PromptContext:
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning(f"获取上下文锁超时,context_id: {context_id}")
|
||||
context_id = None
|
||||
|
||||
|
||||
previous_context = self._current_context
|
||||
token = self._current_context_var.set(context_id) if context_id else None
|
||||
else:
|
||||
previous_context = self._current_context
|
||||
token = None
|
||||
|
||||
|
||||
try:
|
||||
yield self
|
||||
finally:
|
||||
@@ -142,7 +142,7 @@ class PromptContext:
|
||||
self._current_context = previous_context
|
||||
except Exception:
|
||||
...
|
||||
|
||||
|
||||
async def get_prompt_async(self, name: str) -> Optional["Prompt"]:
|
||||
"""异步获取当前作用域中的提示模板"""
|
||||
async with self._context_lock:
|
||||
@@ -155,7 +155,7 @@ class PromptContext:
|
||||
):
|
||||
return self._context_prompts[current_context][name]
|
||||
return None
|
||||
|
||||
|
||||
async def register_async(self, prompt: "Prompt", context_id: Optional[str] = None) -> None:
|
||||
"""异步注册提示模板到指定作用域"""
|
||||
async with self._context_lock:
|
||||
@@ -166,49 +166,49 @@ class PromptContext:
|
||||
|
||||
class PromptManager:
|
||||
"""统一提示词管理器"""
|
||||
|
||||
|
||||
def __init__(self):
|
||||
self._prompts = {}
|
||||
self._counter = 0
|
||||
self._context = PromptContext()
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def async_message_scope(self, message_id: Optional[str] = None):
|
||||
"""为消息处理创建异步临时作用域"""
|
||||
async with self._context.async_scope(message_id):
|
||||
yield self
|
||||
|
||||
|
||||
async def get_prompt_async(self, name: str) -> "Prompt":
|
||||
"""异步获取提示模板"""
|
||||
context_prompt = await self._context.get_prompt_async(name)
|
||||
if context_prompt is not None:
|
||||
logger.debug(f"从上下文中获取提示词: {name} {context_prompt}")
|
||||
return context_prompt
|
||||
|
||||
|
||||
async with self._lock:
|
||||
if name not in self._prompts:
|
||||
raise KeyError(f"Prompt '{name}' not found")
|
||||
return self._prompts[name]
|
||||
|
||||
|
||||
def generate_name(self, template: str) -> str:
|
||||
"""为未命名的prompt生成名称"""
|
||||
self._counter += 1
|
||||
return f"prompt_{self._counter}"
|
||||
|
||||
|
||||
def register(self, prompt: "Prompt") -> None:
|
||||
"""注册一个prompt"""
|
||||
if not prompt.name:
|
||||
prompt.name = self.generate_name(prompt.template)
|
||||
self._prompts[prompt.name] = prompt
|
||||
|
||||
|
||||
def add_prompt(self, name: str, fstr: str) -> "Prompt":
|
||||
"""添加新提示模板"""
|
||||
prompt = Prompt(fstr, name=name)
|
||||
if prompt.name:
|
||||
self._prompts[prompt.name] = prompt
|
||||
return prompt
|
||||
|
||||
|
||||
async def format_prompt(self, name: str, **kwargs) -> str:
|
||||
"""格式化提示模板"""
|
||||
prompt = await self.get_prompt_async(name)
|
||||
@@ -225,21 +225,21 @@ class Prompt:
|
||||
统一提示词类 - 合并模板管理和智能构建功能
|
||||
真正的Prompt类,支持模板管理和智能上下文构建
|
||||
"""
|
||||
|
||||
|
||||
# 临时标记,作为类常量
|
||||
_TEMP_LEFT_BRACE = "__ESCAPED_LEFT_BRACE__"
|
||||
_TEMP_RIGHT_BRACE = "__ESCAPED_RIGHT_BRACE__"
|
||||
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
template: str,
|
||||
name: Optional[str] = None,
|
||||
parameters: Optional[PromptParameters] = None,
|
||||
should_register: bool = True
|
||||
should_register: bool = True,
|
||||
):
|
||||
"""
|
||||
初始化统一提示词
|
||||
|
||||
|
||||
Args:
|
||||
template: 提示词模板字符串
|
||||
name: 提示词名称
|
||||
@@ -251,14 +251,14 @@ class Prompt:
|
||||
self.parameters = parameters or PromptParameters()
|
||||
self.args = self._parse_template_args(template)
|
||||
self._formatted_result = ""
|
||||
|
||||
|
||||
# 预处理模板中的转义花括号
|
||||
self._processed_template = self._process_escaped_braces(template)
|
||||
|
||||
|
||||
# 自动注册
|
||||
if should_register and not global_prompt_manager._context._current_context:
|
||||
global_prompt_manager.register(self)
|
||||
|
||||
|
||||
@staticmethod
|
||||
def _process_escaped_braces(template) -> str:
|
||||
"""处理模板中的转义花括号"""
|
||||
@@ -266,14 +266,14 @@ class Prompt:
|
||||
template = "\n".join(str(item) for item in template)
|
||||
elif not isinstance(template, str):
|
||||
template = str(template)
|
||||
|
||||
|
||||
return template.replace("\\{", Prompt._TEMP_LEFT_BRACE).replace("\\}", Prompt._TEMP_RIGHT_BRACE)
|
||||
|
||||
|
||||
@staticmethod
|
||||
def _restore_escaped_braces(template: str) -> str:
|
||||
"""将临时标记还原为实际的花括号字符"""
|
||||
return template.replace(Prompt._TEMP_LEFT_BRACE, "{").replace(Prompt._TEMP_RIGHT_BRACE, "}")
|
||||
|
||||
|
||||
def _parse_template_args(self, template: str) -> List[str]:
|
||||
"""解析模板参数"""
|
||||
template_args = []
|
||||
@@ -283,11 +283,11 @@ class Prompt:
|
||||
if expr and expr not in template_args:
|
||||
template_args.append(expr)
|
||||
return template_args
|
||||
|
||||
|
||||
async def build(self) -> str:
|
||||
"""
|
||||
构建完整的提示词,包含智能上下文
|
||||
|
||||
|
||||
Returns:
|
||||
str: 构建完成的提示词文本
|
||||
"""
|
||||
@@ -296,38 +296,38 @@ class Prompt:
|
||||
if errors:
|
||||
logger.error(f"参数验证失败: {', '.join(errors)}")
|
||||
raise ValueError(f"参数验证失败: {', '.join(errors)}")
|
||||
|
||||
|
||||
start_time = time.time()
|
||||
try:
|
||||
# 构建上下文数据
|
||||
context_data = await self._build_context_data()
|
||||
|
||||
|
||||
# 格式化模板
|
||||
result = await self._format_with_context(context_data)
|
||||
|
||||
|
||||
total_time = time.time() - start_time
|
||||
logger.debug(f"Prompt构建完成,模式: {self.parameters.prompt_mode}, 耗时: {total_time:.2f}s")
|
||||
|
||||
|
||||
self._formatted_result = result
|
||||
return result
|
||||
|
||||
|
||||
except asyncio.TimeoutError as e:
|
||||
logger.error(f"构建Prompt超时: {e}")
|
||||
raise TimeoutError(f"构建Prompt超时: {e}") from e
|
||||
except Exception as e:
|
||||
logger.error(f"构建Prompt失败: {e}")
|
||||
raise RuntimeError(f"构建Prompt失败: {e}") from e
|
||||
|
||||
|
||||
async def _build_context_data(self) -> Dict[str, Any]:
|
||||
"""构建智能上下文数据"""
|
||||
# 并行执行所有构建任务
|
||||
start_time = time.time()
|
||||
|
||||
|
||||
try:
|
||||
# 准备构建任务
|
||||
tasks = []
|
||||
task_names = []
|
||||
|
||||
|
||||
# 初始化预构建参数
|
||||
pre_built_params = {}
|
||||
if self.parameters.expression_habits_block:
|
||||
@@ -342,32 +342,32 @@ class Prompt:
|
||||
pre_built_params["knowledge_prompt"] = self.parameters.knowledge_prompt
|
||||
if self.parameters.cross_context_block:
|
||||
pre_built_params["cross_context_block"] = self.parameters.cross_context_block
|
||||
|
||||
|
||||
# 根据参数确定要构建的项
|
||||
if self.parameters.enable_expression and not pre_built_params.get("expression_habits_block"):
|
||||
tasks.append(self._build_expression_habits())
|
||||
task_names.append("expression_habits")
|
||||
|
||||
|
||||
if self.parameters.enable_memory and not pre_built_params.get("memory_block"):
|
||||
tasks.append(self._build_memory_block())
|
||||
task_names.append("memory_block")
|
||||
|
||||
|
||||
if self.parameters.enable_relation and not pre_built_params.get("relation_info_block"):
|
||||
tasks.append(self._build_relation_info())
|
||||
task_names.append("relation_info")
|
||||
|
||||
|
||||
if self.parameters.enable_tool and not pre_built_params.get("tool_info_block"):
|
||||
tasks.append(self._build_tool_info())
|
||||
task_names.append("tool_info")
|
||||
|
||||
|
||||
if self.parameters.enable_knowledge and not pre_built_params.get("knowledge_prompt"):
|
||||
tasks.append(self._build_knowledge_info())
|
||||
task_names.append("knowledge_info")
|
||||
|
||||
|
||||
if self.parameters.enable_cross_context and not pre_built_params.get("cross_context_block"):
|
||||
tasks.append(self._build_cross_context())
|
||||
task_names.append("cross_context")
|
||||
|
||||
|
||||
# 性能优化
|
||||
base_timeout = 10.0
|
||||
task_timeout = 2.0
|
||||
@@ -375,13 +375,13 @@ class Prompt:
|
||||
max(base_timeout, len(tasks) * task_timeout),
|
||||
30.0,
|
||||
)
|
||||
|
||||
|
||||
max_concurrent_tasks = 5
|
||||
if len(tasks) > max_concurrent_tasks:
|
||||
results = []
|
||||
for i in range(0, len(tasks), max_concurrent_tasks):
|
||||
batch_tasks = tasks[i : i + max_concurrent_tasks]
|
||||
|
||||
|
||||
batch_results = await asyncio.wait_for(
|
||||
asyncio.gather(*batch_tasks, return_exceptions=True), timeout=timeout_seconds
|
||||
)
|
||||
@@ -390,53 +390,55 @@ class Prompt:
|
||||
results = await asyncio.wait_for(
|
||||
asyncio.gather(*tasks, return_exceptions=True), timeout=timeout_seconds
|
||||
)
|
||||
|
||||
|
||||
# 处理结果
|
||||
context_data = {}
|
||||
for i, result in enumerate(results):
|
||||
task_name = task_names[i] if i < len(task_names) else f"task_{i}"
|
||||
|
||||
|
||||
if isinstance(result, Exception):
|
||||
logger.error(f"构建任务{task_name}失败: {str(result)}")
|
||||
elif isinstance(result, dict):
|
||||
context_data.update(result)
|
||||
|
||||
|
||||
# 添加预构建的参数
|
||||
for key, value in pre_built_params.items():
|
||||
if value:
|
||||
context_data[key] = value
|
||||
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
logger.error(f"构建超时 ({timeout_seconds}s)")
|
||||
context_data = {}
|
||||
for key, value in pre_built_params.items():
|
||||
if value:
|
||||
context_data[key] = value
|
||||
|
||||
|
||||
# 构建聊天历史
|
||||
if self.parameters.prompt_mode == "s4u":
|
||||
await self._build_s4u_chat_context(context_data)
|
||||
else:
|
||||
await self._build_normal_chat_context(context_data)
|
||||
|
||||
|
||||
# 补充基础信息
|
||||
context_data.update({
|
||||
"keywords_reaction_prompt": self.parameters.keywords_reaction_prompt,
|
||||
"extra_info_block": self.parameters.extra_info_block,
|
||||
"time_block": self.parameters.time_block or f"当前时间:{time.strftime('%Y-%m-%d %H:%M:%S')}",
|
||||
"identity": self.parameters.identity_block,
|
||||
"schedule_block": self.parameters.schedule_block,
|
||||
"moderation_prompt": self.parameters.moderation_prompt_block,
|
||||
"reply_target_block": self.parameters.reply_target_block,
|
||||
"mood_state": self.parameters.mood_prompt,
|
||||
"action_descriptions": self.parameters.action_descriptions,
|
||||
})
|
||||
|
||||
context_data.update(
|
||||
{
|
||||
"keywords_reaction_prompt": self.parameters.keywords_reaction_prompt,
|
||||
"extra_info_block": self.parameters.extra_info_block,
|
||||
"time_block": self.parameters.time_block or f"当前时间:{time.strftime('%Y-%m-%d %H:%M:%S')}",
|
||||
"identity": self.parameters.identity_block,
|
||||
"schedule_block": self.parameters.schedule_block,
|
||||
"moderation_prompt": self.parameters.moderation_prompt_block,
|
||||
"reply_target_block": self.parameters.reply_target_block,
|
||||
"mood_state": self.parameters.mood_prompt,
|
||||
"action_descriptions": self.parameters.action_descriptions,
|
||||
}
|
||||
)
|
||||
|
||||
total_time = time.time() - start_time
|
||||
logger.debug(f"上下文构建完成,总耗时: {total_time:.2f}s")
|
||||
|
||||
|
||||
return context_data
|
||||
|
||||
|
||||
async def _build_s4u_chat_context(self, context_data: Dict[str, Any]) -> None:
|
||||
"""构建S4U模式的聊天上下文"""
|
||||
if not self.parameters.message_list_before_now_long:
|
||||
@@ -446,20 +448,20 @@ class Prompt:
|
||||
self.parameters.message_list_before_now_long,
|
||||
self.parameters.target_user_info.get("user_id") if self.parameters.target_user_info else "",
|
||||
self.parameters.sender,
|
||||
self.parameters.chat_id
|
||||
self.parameters.chat_id,
|
||||
)
|
||||
|
||||
context_data["read_history_prompt"] = read_history_prompt
|
||||
context_data["unread_history_prompt"] = unread_history_prompt
|
||||
|
||||
|
||||
async def _build_normal_chat_context(self, context_data: Dict[str, Any]) -> None:
|
||||
"""构建normal模式的聊天上下文"""
|
||||
if not self.parameters.chat_talking_prompt_short:
|
||||
return
|
||||
|
||||
|
||||
context_data["chat_info"] = f"""群里的聊天内容:
|
||||
{self.parameters.chat_talking_prompt_short}"""
|
||||
|
||||
|
||||
async def _build_s4u_chat_history_prompts(
|
||||
self, message_list_before_now: List[Dict[str, Any]], target_user_id: str, sender: str, chat_id: str
|
||||
) -> Tuple[str, str]:
|
||||
@@ -476,101 +478,92 @@ class Prompt:
|
||||
except Exception as e:
|
||||
logger.error(f"构建S4U历史消息prompt失败: {e}")
|
||||
|
||||
|
||||
|
||||
async def _build_expression_habits(self) -> Dict[str, Any]:
|
||||
"""构建表达习惯"""
|
||||
if not global_config.expression.enable_expression:
|
||||
return {"expression_habits_block": ""}
|
||||
|
||||
|
||||
try:
|
||||
from src.chat.express.expression_selector import ExpressionSelector
|
||||
|
||||
|
||||
# 获取聊天历史用于表情选择
|
||||
chat_history = ""
|
||||
if self.parameters.message_list_before_now_long:
|
||||
recent_messages = self.parameters.message_list_before_now_long[-10:]
|
||||
chat_history = build_readable_messages(
|
||||
recent_messages,
|
||||
replace_bot_name=True,
|
||||
timestamp_mode="normal",
|
||||
truncate=True
|
||||
recent_messages, replace_bot_name=True, timestamp_mode="normal", truncate=True
|
||||
)
|
||||
|
||||
|
||||
# 创建表情选择器
|
||||
expression_selector = ExpressionSelector(self.parameters.chat_id)
|
||||
|
||||
|
||||
# 选择合适的表情
|
||||
selected_expressions = await expression_selector.select_suitable_expressions_llm(
|
||||
chat_history=chat_history,
|
||||
current_message=self.parameters.target,
|
||||
emotional_tone="neutral",
|
||||
topic_type="general"
|
||||
topic_type="general",
|
||||
)
|
||||
|
||||
|
||||
# 构建表达习惯块
|
||||
if selected_expressions:
|
||||
style_habits_str = "\n".join([f"- {expr}" for expr in selected_expressions])
|
||||
expression_habits_block = f"- 你可以参考以下的语言习惯,当情景合适就使用,但不要生硬使用,以合理的方式结合到你的回复中:\n{style_habits_str}"
|
||||
else:
|
||||
expression_habits_block = ""
|
||||
|
||||
|
||||
return {"expression_habits_block": expression_habits_block}
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"构建表达习惯失败: {e}")
|
||||
return {"expression_habits_block": ""}
|
||||
|
||||
|
||||
async def _build_memory_block(self) -> Dict[str, Any]:
|
||||
"""构建记忆块"""
|
||||
if not global_config.memory.enable_memory:
|
||||
return {"memory_block": ""}
|
||||
|
||||
|
||||
try:
|
||||
from src.chat.memory_system.memory_activator import MemoryActivator
|
||||
from src.chat.memory_system.async_instant_memory_wrapper import get_async_instant_memory
|
||||
|
||||
|
||||
# 获取聊天历史
|
||||
chat_history = ""
|
||||
if self.parameters.message_list_before_now_long:
|
||||
recent_messages = self.parameters.message_list_before_now_long[-20:]
|
||||
chat_history = build_readable_messages(
|
||||
recent_messages,
|
||||
replace_bot_name=True,
|
||||
timestamp_mode="normal",
|
||||
truncate=True
|
||||
recent_messages, replace_bot_name=True, timestamp_mode="normal", truncate=True
|
||||
)
|
||||
|
||||
|
||||
# 激活长期记忆
|
||||
memory_activator = MemoryActivator()
|
||||
running_memories = await memory_activator.activate_memory_with_chat_history(
|
||||
target_message=self.parameters.target,
|
||||
chat_history_prompt=chat_history
|
||||
target_message=self.parameters.target, chat_history_prompt=chat_history
|
||||
)
|
||||
|
||||
|
||||
# 获取即时记忆
|
||||
async_memory_wrapper = get_async_instant_memory(self.parameters.chat_id)
|
||||
instant_memory = await async_memory_wrapper.get_memory_with_fallback(self.parameters.target)
|
||||
|
||||
|
||||
# 构建记忆块
|
||||
memory_parts = []
|
||||
|
||||
|
||||
if running_memories:
|
||||
memory_parts.append("以下是当前在聊天中,你回忆起的记忆:")
|
||||
for memory in running_memories:
|
||||
memory_parts.append(f"- {memory['content']}")
|
||||
|
||||
|
||||
if instant_memory:
|
||||
memory_parts.append(f"- {instant_memory}")
|
||||
|
||||
|
||||
memory_block = "\n".join(memory_parts) if memory_parts else ""
|
||||
|
||||
|
||||
return {"memory_block": memory_block}
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"构建记忆块失败: {e}")
|
||||
return {"memory_block": ""}
|
||||
|
||||
|
||||
async def _build_relation_info(self) -> Dict[str, Any]:
|
||||
"""构建关系信息"""
|
||||
try:
|
||||
@@ -579,110 +572,104 @@ class Prompt:
|
||||
except Exception as e:
|
||||
logger.error(f"构建关系信息失败: {e}")
|
||||
return {"relation_info_block": ""}
|
||||
|
||||
|
||||
async def _build_tool_info(self) -> Dict[str, Any]:
|
||||
"""构建工具信息"""
|
||||
if not global_config.tool.enable_tool:
|
||||
return {"tool_info_block": ""}
|
||||
|
||||
|
||||
try:
|
||||
from src.plugin_system.core.tool_use import ToolExecutor
|
||||
|
||||
|
||||
# 获取聊天历史
|
||||
chat_history = ""
|
||||
if self.parameters.message_list_before_now_long:
|
||||
recent_messages = self.parameters.message_list_before_now_long[-15:]
|
||||
chat_history = build_readable_messages(
|
||||
recent_messages,
|
||||
replace_bot_name=True,
|
||||
timestamp_mode="normal",
|
||||
truncate=True
|
||||
recent_messages, replace_bot_name=True, timestamp_mode="normal", truncate=True
|
||||
)
|
||||
|
||||
|
||||
# 创建工具执行器
|
||||
tool_executor = ToolExecutor(chat_id=self.parameters.chat_id)
|
||||
|
||||
|
||||
# 执行工具获取信息
|
||||
tool_results, _, _ = await tool_executor.execute_from_chat_message(
|
||||
sender=self.parameters.sender,
|
||||
target_message=self.parameters.target,
|
||||
chat_history=chat_history,
|
||||
return_details=False
|
||||
return_details=False,
|
||||
)
|
||||
|
||||
|
||||
# 构建工具信息块
|
||||
if tool_results:
|
||||
tool_info_parts = ["## 工具信息","以下是你通过工具获取到的实时信息:"]
|
||||
tool_info_parts = ["## 工具信息", "以下是你通过工具获取到的实时信息:"]
|
||||
for tool_result in tool_results:
|
||||
tool_name = tool_result.get("tool_name", "unknown")
|
||||
content = tool_result.get("content", "")
|
||||
result_type = tool_result.get("type", "tool_result")
|
||||
|
||||
|
||||
tool_info_parts.append(f"- 【{tool_name}】{result_type}: {content}")
|
||||
|
||||
|
||||
tool_info_parts.append("以上是你获取到的实时信息,请在回复时参考这些信息。")
|
||||
tool_info_block = "\n".join(tool_info_parts)
|
||||
else:
|
||||
tool_info_block = ""
|
||||
|
||||
|
||||
return {"tool_info_block": tool_info_block}
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"构建工具信息失败: {e}")
|
||||
return {"tool_info_block": ""}
|
||||
|
||||
|
||||
async def _build_knowledge_info(self) -> Dict[str, Any]:
|
||||
"""构建知识信息"""
|
||||
if not global_config.lpmm_knowledge.enable:
|
||||
return {"knowledge_prompt": ""}
|
||||
|
||||
|
||||
try:
|
||||
from src.chat.knowledge.knowledge_lib import QAManager
|
||||
|
||||
|
||||
# 获取问题文本(当前消息)
|
||||
question = self.parameters.target or ""
|
||||
if not question:
|
||||
return {"knowledge_prompt": ""}
|
||||
|
||||
|
||||
# 创建QA管理器
|
||||
qa_manager = QAManager()
|
||||
|
||||
|
||||
# 搜索相关知识
|
||||
knowledge_results = await qa_manager.get_knowledge(
|
||||
question=question,
|
||||
chat_id=self.parameters.chat_id,
|
||||
max_results=5,
|
||||
min_similarity=0.5
|
||||
question=question, chat_id=self.parameters.chat_id, max_results=5, min_similarity=0.5
|
||||
)
|
||||
|
||||
|
||||
# 构建知识块
|
||||
if knowledge_results and knowledge_results.get("knowledge_items"):
|
||||
knowledge_parts = ["## 知识库信息","以下是与你当前对话相关的知识信息:"]
|
||||
|
||||
knowledge_parts = ["## 知识库信息", "以下是与你当前对话相关的知识信息:"]
|
||||
|
||||
for item in knowledge_results["knowledge_items"]:
|
||||
content = item.get("content", "")
|
||||
source = item.get("source", "")
|
||||
relevance = item.get("relevance", 0.0)
|
||||
|
||||
|
||||
if content:
|
||||
if source:
|
||||
knowledge_parts.append(f"- [{relevance:.2f}] {content} (来源: {source})")
|
||||
else:
|
||||
knowledge_parts.append(f"- [{relevance:.2f}] {content}")
|
||||
|
||||
|
||||
if knowledge_results.get("summary"):
|
||||
knowledge_parts.append(f"\n知识总结: {knowledge_results['summary']}")
|
||||
|
||||
|
||||
knowledge_prompt = "\n".join(knowledge_parts)
|
||||
else:
|
||||
knowledge_prompt = ""
|
||||
|
||||
|
||||
return {"knowledge_prompt": knowledge_prompt}
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"构建知识信息失败: {e}")
|
||||
return {"knowledge_prompt": ""}
|
||||
|
||||
|
||||
async def _build_cross_context(self) -> Dict[str, Any]:
|
||||
"""构建跨群上下文"""
|
||||
try:
|
||||
@@ -693,7 +680,7 @@ class Prompt:
|
||||
except Exception as e:
|
||||
logger.error(f"构建跨群上下文失败: {e}")
|
||||
return {"cross_context_block": ""}
|
||||
|
||||
|
||||
async def _format_with_context(self, context_data: Dict[str, Any]) -> str:
|
||||
"""使用上下文数据格式化模板"""
|
||||
if self.parameters.prompt_mode == "s4u":
|
||||
@@ -702,9 +689,9 @@ class Prompt:
|
||||
params = self._prepare_normal_params(context_data)
|
||||
else:
|
||||
params = self._prepare_default_params(context_data)
|
||||
|
||||
|
||||
return await global_prompt_manager.format_prompt(self.name, **params) if self.name else self.format(**params)
|
||||
|
||||
|
||||
def _prepare_s4u_params(self, context_data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""准备S4U模式的参数"""
|
||||
return {
|
||||
@@ -725,11 +712,13 @@ class Prompt:
|
||||
"time_block": context_data.get("time_block", ""),
|
||||
"reply_target_block": context_data.get("reply_target_block", ""),
|
||||
"reply_style": global_config.personality.reply_style,
|
||||
"keywords_reaction_prompt": self.parameters.keywords_reaction_prompt or context_data.get("keywords_reaction_prompt", ""),
|
||||
"keywords_reaction_prompt": self.parameters.keywords_reaction_prompt
|
||||
or context_data.get("keywords_reaction_prompt", ""),
|
||||
"moderation_prompt": self.parameters.moderation_prompt_block or context_data.get("moderation_prompt", ""),
|
||||
"safety_guidelines_block": self.parameters.safety_guidelines_block or context_data.get("safety_guidelines_block", ""),
|
||||
"safety_guidelines_block": self.parameters.safety_guidelines_block
|
||||
or context_data.get("safety_guidelines_block", ""),
|
||||
}
|
||||
|
||||
|
||||
def _prepare_normal_params(self, context_data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""准备Normal模式的参数"""
|
||||
return {
|
||||
@@ -749,11 +738,13 @@ class Prompt:
|
||||
"reply_target_block": context_data.get("reply_target_block", ""),
|
||||
"config_expression_style": global_config.personality.reply_style,
|
||||
"mood_state": self.parameters.mood_prompt or context_data.get("mood_state", ""),
|
||||
"keywords_reaction_prompt": self.parameters.keywords_reaction_prompt or context_data.get("keywords_reaction_prompt", ""),
|
||||
"keywords_reaction_prompt": self.parameters.keywords_reaction_prompt
|
||||
or context_data.get("keywords_reaction_prompt", ""),
|
||||
"moderation_prompt": self.parameters.moderation_prompt_block or context_data.get("moderation_prompt", ""),
|
||||
"safety_guidelines_block": self.parameters.safety_guidelines_block or context_data.get("safety_guidelines_block", ""),
|
||||
"safety_guidelines_block": self.parameters.safety_guidelines_block
|
||||
or context_data.get("safety_guidelines_block", ""),
|
||||
}
|
||||
|
||||
|
||||
def _prepare_default_params(self, context_data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""准备默认模式的参数"""
|
||||
return {
|
||||
@@ -769,11 +760,13 @@ class Prompt:
|
||||
"reason": "",
|
||||
"mood_state": self.parameters.mood_prompt or context_data.get("mood_state", ""),
|
||||
"reply_style": global_config.personality.reply_style,
|
||||
"keywords_reaction_prompt": self.parameters.keywords_reaction_prompt or context_data.get("keywords_reaction_prompt", ""),
|
||||
"keywords_reaction_prompt": self.parameters.keywords_reaction_prompt
|
||||
or context_data.get("keywords_reaction_prompt", ""),
|
||||
"moderation_prompt": self.parameters.moderation_prompt_block or context_data.get("moderation_prompt", ""),
|
||||
"safety_guidelines_block": self.parameters.safety_guidelines_block or context_data.get("safety_guidelines_block", ""),
|
||||
"safety_guidelines_block": self.parameters.safety_guidelines_block
|
||||
or context_data.get("safety_guidelines_block", ""),
|
||||
}
|
||||
|
||||
|
||||
def format(self, *args, **kwargs) -> str:
|
||||
"""格式化模板,支持位置参数和关键字参数"""
|
||||
try:
|
||||
@@ -786,21 +779,21 @@ class Prompt:
|
||||
processed_template = self._processed_template.format(**formatted_args)
|
||||
else:
|
||||
processed_template = self._processed_template
|
||||
|
||||
|
||||
# 再用关键字参数格式化
|
||||
if kwargs:
|
||||
processed_template = processed_template.format(**kwargs)
|
||||
|
||||
|
||||
# 将临时标记还原为实际的花括号
|
||||
result = self._restore_escaped_braces(processed_template)
|
||||
return result
|
||||
except (IndexError, KeyError) as e:
|
||||
raise ValueError(f"格式化模板失败: {self.template}, args={args}, kwargs={kwargs} {str(e)}") from e
|
||||
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""返回格式化后的结果或原始模板"""
|
||||
return self._formatted_result if self._formatted_result else self.template
|
||||
|
||||
|
||||
def __repr__(self) -> str:
|
||||
"""返回提示词的表示形式"""
|
||||
return f"Prompt(template='{self.template}', name='{self.name}')"
|
||||
@@ -872,9 +865,7 @@ class Prompt:
|
||||
return await relationship_fetcher.build_relation_info(person_id, points_num=5)
|
||||
|
||||
@staticmethod
|
||||
async def build_cross_context(
|
||||
chat_id: str, prompt_mode: str, target_user_info: Optional[Dict[str, Any]]
|
||||
) -> str:
|
||||
async def build_cross_context(chat_id: str, prompt_mode: str, target_user_info: Optional[Dict[str, Any]]) -> str:
|
||||
"""
|
||||
构建跨群聊上下文 - 统一实现
|
||||
|
||||
@@ -890,7 +881,7 @@ class Prompt:
|
||||
return ""
|
||||
|
||||
from src.plugin_system.apis import cross_context_api
|
||||
|
||||
|
||||
other_chat_raw_ids = cross_context_api.get_context_groups(chat_id)
|
||||
if not other_chat_raw_ids:
|
||||
return ""
|
||||
@@ -937,10 +928,7 @@ class Prompt:
|
||||
|
||||
# 工厂函数
|
||||
def create_prompt(
|
||||
template: str,
|
||||
name: Optional[str] = None,
|
||||
parameters: Optional[PromptParameters] = None,
|
||||
**kwargs
|
||||
template: str, name: Optional[str] = None, parameters: Optional[PromptParameters] = None, **kwargs
|
||||
) -> Prompt:
|
||||
"""快速创建Prompt实例的工厂函数"""
|
||||
if parameters is None:
|
||||
@@ -949,14 +937,10 @@ def create_prompt(
|
||||
|
||||
|
||||
async def create_prompt_async(
|
||||
template: str,
|
||||
name: Optional[str] = None,
|
||||
parameters: Optional[PromptParameters] = None,
|
||||
**kwargs
|
||||
template: str, name: Optional[str] = None, parameters: Optional[PromptParameters] = None, **kwargs
|
||||
) -> Prompt:
|
||||
"""异步创建Prompt实例"""
|
||||
prompt = create_prompt(template, name, parameters, **kwargs)
|
||||
if global_prompt_manager._context._current_context:
|
||||
await global_prompt_manager._context.register_async(prompt)
|
||||
return prompt
|
||||
|
||||
|
||||
@@ -332,9 +332,9 @@ def process_llm_response(text: str, enable_splitter: bool = True, enable_chinese
|
||||
|
||||
if global_config.response_splitter.enable and enable_splitter:
|
||||
logger.info(f"回复分割器已启用,模式: {global_config.response_splitter.split_mode}。")
|
||||
|
||||
|
||||
split_mode = global_config.response_splitter.split_mode
|
||||
|
||||
|
||||
if split_mode == "llm" and "[SPLIT]" in cleaned_text:
|
||||
logger.debug("检测到 [SPLIT] 标记,使用 LLM 自定义分割。")
|
||||
split_sentences_raw = cleaned_text.split("[SPLIT]")
|
||||
@@ -343,7 +343,7 @@ def process_llm_response(text: str, enable_splitter: bool = True, enable_chinese
|
||||
if split_mode == "llm":
|
||||
logger.debug("未检测到 [SPLIT] 标记,本次不进行分割。")
|
||||
split_sentences = [cleaned_text]
|
||||
else: # mode == "punctuation"
|
||||
else: # mode == "punctuation"
|
||||
logger.debug("使用基于标点的传统模式进行分割。")
|
||||
split_sentences = split_into_sentences_w_remove_punctuation(cleaned_text)
|
||||
else:
|
||||
|
||||
@@ -6,6 +6,7 @@ class BaseDataModel:
|
||||
def deepcopy(self):
|
||||
return copy.deepcopy(self)
|
||||
|
||||
|
||||
def temporarily_transform_class_to_dict(obj: Any) -> Any:
|
||||
# sourcery skip: assign-if-exp, reintroduce-else
|
||||
"""
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
机器人兴趣标签数据模型
|
||||
定义机器人的兴趣标签和相关的embedding数据结构
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import List, Dict, Optional, Any
|
||||
from datetime import datetime
|
||||
@@ -12,6 +13,7 @@ from . import BaseDataModel
|
||||
@dataclass
|
||||
class BotInterestTag(BaseDataModel):
|
||||
"""机器人兴趣标签"""
|
||||
|
||||
tag_name: str
|
||||
weight: float = 1.0 # 权重,表示对这个兴趣的喜好程度 (0.0-1.0)
|
||||
embedding: Optional[List[float]] = None # 标签的embedding向量
|
||||
@@ -27,7 +29,7 @@ class BotInterestTag(BaseDataModel):
|
||||
"embedding": self.embedding,
|
||||
"created_at": self.created_at.isoformat(),
|
||||
"updated_at": self.updated_at.isoformat(),
|
||||
"is_active": self.is_active
|
||||
"is_active": self.is_active,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
@@ -39,13 +41,14 @@ class BotInterestTag(BaseDataModel):
|
||||
embedding=data.get("embedding"),
|
||||
created_at=datetime.fromisoformat(data["created_at"]) if data.get("created_at") else datetime.now(),
|
||||
updated_at=datetime.fromisoformat(data["updated_at"]) if data.get("updated_at") else datetime.now(),
|
||||
is_active=data.get("is_active", True)
|
||||
is_active=data.get("is_active", True),
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class BotPersonalityInterests(BaseDataModel):
|
||||
"""机器人人格化兴趣配置"""
|
||||
|
||||
personality_id: str
|
||||
personality_description: str # 人设描述文本
|
||||
interest_tags: List[BotInterestTag] = field(default_factory=list)
|
||||
@@ -57,7 +60,6 @@ class BotPersonalityInterests(BaseDataModel):
|
||||
"""获取活跃的兴趣标签"""
|
||||
return [tag for tag in self.interest_tags if tag.is_active]
|
||||
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""转换为字典格式"""
|
||||
return {
|
||||
@@ -66,7 +68,7 @@ class BotPersonalityInterests(BaseDataModel):
|
||||
"interest_tags": [tag.to_dict() for tag in self.interest_tags],
|
||||
"embedding_model": self.embedding_model,
|
||||
"last_updated": self.last_updated.isoformat(),
|
||||
"version": self.version
|
||||
"version": self.version,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
@@ -78,13 +80,14 @@ class BotPersonalityInterests(BaseDataModel):
|
||||
interest_tags=[BotInterestTag.from_dict(tag_data) for tag_data in data.get("interest_tags", [])],
|
||||
embedding_model=data.get("embedding_model", "text-embedding-ada-002"),
|
||||
last_updated=datetime.fromisoformat(data["last_updated"]) if data.get("last_updated") else datetime.now(),
|
||||
version=data.get("version", 1)
|
||||
version=data.get("version", 1),
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class InterestMatchResult(BaseDataModel):
|
||||
"""兴趣匹配结果"""
|
||||
|
||||
message_id: str
|
||||
matched_tags: List[str] = field(default_factory=list)
|
||||
match_scores: Dict[str, float] = field(default_factory=dict) # tag_name -> score
|
||||
@@ -120,7 +123,9 @@ class InterestMatchResult(BaseDataModel):
|
||||
# 计算置信度(基于匹配标签数量和分数分布)
|
||||
if len(self.match_scores) > 0:
|
||||
avg_score = self.overall_score
|
||||
score_variance = sum((score - avg_score) ** 2 for score in self.match_scores.values()) / len(self.match_scores)
|
||||
score_variance = sum((score - avg_score) ** 2 for score in self.match_scores.values()) / len(
|
||||
self.match_scores
|
||||
)
|
||||
# 分数越集中,置信度越高
|
||||
self.confidence = max(0.0, 1.0 - score_variance)
|
||||
else:
|
||||
@@ -129,4 +134,4 @@ class InterestMatchResult(BaseDataModel):
|
||||
def get_top_matches(self, top_n: int = 3) -> List[tuple]:
|
||||
"""获取前N个最佳匹配"""
|
||||
sorted_matches = sorted(self.match_scores.items(), key=lambda x: x[1], reverse=True)
|
||||
return sorted_matches[:top_n]
|
||||
return sorted_matches[:top_n]
|
||||
|
||||
@@ -208,6 +208,7 @@ class DatabaseMessages(BaseDataModel):
|
||||
"chat_info_user_cardname": self.chat_info.user_info.user_cardname,
|
||||
}
|
||||
|
||||
|
||||
@dataclass(init=False)
|
||||
class DatabaseActionRecords(BaseDataModel):
|
||||
def __init__(
|
||||
@@ -235,4 +236,4 @@ class DatabaseActionRecords(BaseDataModel):
|
||||
self.action_prompt_display = action_prompt_display
|
||||
self.chat_id = chat_id
|
||||
self.chat_info_stream_id = chat_info_stream_id
|
||||
self.chat_info_platform = chat_info_platform
|
||||
self.chat_info_platform = chat_info_platform
|
||||
|
||||
@@ -28,6 +28,7 @@ class ActionPlannerInfo(BaseDataModel):
|
||||
@dataclass
|
||||
class InterestScore(BaseDataModel):
|
||||
"""兴趣度评分结果"""
|
||||
|
||||
message_id: str
|
||||
total_score: float
|
||||
interest_match_score: float
|
||||
@@ -41,6 +42,7 @@ class Plan(BaseDataModel):
|
||||
"""
|
||||
统一规划数据模型
|
||||
"""
|
||||
|
||||
chat_id: str
|
||||
mode: "ChatMode"
|
||||
|
||||
|
||||
@@ -2,9 +2,11 @@ from dataclasses import dataclass
|
||||
from typing import Optional, List, Tuple, TYPE_CHECKING, Any
|
||||
|
||||
from . import BaseDataModel
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from src.llm_models.payload_content.tool_option import ToolCall
|
||||
|
||||
|
||||
@dataclass
|
||||
class LLMGenerationDataModel(BaseDataModel):
|
||||
content: Optional[str] = None
|
||||
@@ -13,4 +15,4 @@ class LLMGenerationDataModel(BaseDataModel):
|
||||
tool_calls: Optional[List["ToolCall"]] = None
|
||||
prompt: Optional[str] = None
|
||||
selected_expressions: Optional[List[int]] = None
|
||||
reply_set: Optional[List[Tuple[str, Any]]] = None
|
||||
reply_set: Optional[List[Tuple[str, Any]]] = None
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
消息管理模块数据模型
|
||||
定义消息管理器使用的数据结构
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
@@ -16,14 +17,16 @@ if TYPE_CHECKING:
|
||||
|
||||
class MessageStatus(Enum):
|
||||
"""消息状态枚举"""
|
||||
UNREAD = "unread" # 未读消息
|
||||
READ = "read" # 已读消息
|
||||
|
||||
UNREAD = "unread" # 未读消息
|
||||
READ = "read" # 已读消息
|
||||
PROCESSING = "processing" # 处理中
|
||||
|
||||
|
||||
@dataclass
|
||||
class StreamContext(BaseDataModel):
|
||||
"""聊天流上下文信息"""
|
||||
|
||||
stream_id: str
|
||||
unread_messages: List["DatabaseMessages"] = field(default_factory=list)
|
||||
history_messages: List["DatabaseMessages"] = field(default_factory=list)
|
||||
@@ -59,6 +62,7 @@ class StreamContext(BaseDataModel):
|
||||
@dataclass
|
||||
class MessageManagerStats(BaseDataModel):
|
||||
"""消息管理器统计信息"""
|
||||
|
||||
total_streams: int = 0
|
||||
active_streams: int = 0
|
||||
total_unread_messages: int = 0
|
||||
@@ -74,9 +78,10 @@ class MessageManagerStats(BaseDataModel):
|
||||
@dataclass
|
||||
class StreamStats(BaseDataModel):
|
||||
"""聊天流统计信息"""
|
||||
|
||||
stream_id: str
|
||||
is_active: bool
|
||||
unread_count: int
|
||||
history_count: int
|
||||
last_check_time: float
|
||||
has_active_task: bool
|
||||
has_active_task: bool
|
||||
|
||||
@@ -23,15 +23,15 @@ def get_global_api() -> MessageServer: # sourcery skip: extract-method
|
||||
maim_message_config = global_config.maim_message
|
||||
|
||||
# 设置基本参数
|
||||
|
||||
|
||||
host = os.getenv("HOST", "127.0.0.1")
|
||||
port_str = os.getenv("PORT", "8000")
|
||||
|
||||
|
||||
try:
|
||||
port = int(port_str)
|
||||
except ValueError:
|
||||
port = 8000
|
||||
|
||||
|
||||
kwargs = {
|
||||
"host": host,
|
||||
"port": port,
|
||||
|
||||
@@ -31,7 +31,9 @@ class TelemetryHeartBeatTask(AsyncTask):
|
||||
self.client_uuid: str | None = local_storage["mofox_uuid"] if "mofox_uuid" in local_storage else None # type: ignore
|
||||
"""客户端UUID"""
|
||||
|
||||
self.private_key_pem: str | None = local_storage["mofox_private_key"] if "mofox_private_key" in local_storage else None # type: ignore
|
||||
self.private_key_pem: str | None = (
|
||||
local_storage["mofox_private_key"] if "mofox_private_key" in local_storage else None
|
||||
) # type: ignore
|
||||
"""客户端私钥"""
|
||||
|
||||
self.info_dict = self._get_sys_info()
|
||||
@@ -61,78 +63,65 @@ class TelemetryHeartBeatTask(AsyncTask):
|
||||
def _generate_signature(self, request_body: dict) -> tuple[str, str]:
|
||||
"""
|
||||
生成RSA签名
|
||||
|
||||
|
||||
Returns:
|
||||
tuple[str, str]: (timestamp, signature_b64)
|
||||
"""
|
||||
if not self.private_key_pem:
|
||||
raise ValueError("私钥未初始化")
|
||||
|
||||
|
||||
# 生成时间戳
|
||||
timestamp = datetime.now(timezone.utc).isoformat()
|
||||
|
||||
|
||||
# 创建签名数据字符串
|
||||
sign_data = f"{self.client_uuid}:{timestamp}:{json.dumps(request_body, separators=(',', ':'))}"
|
||||
|
||||
|
||||
# 加载私钥
|
||||
private_key = serialization.load_pem_private_key(
|
||||
self.private_key_pem.encode('utf-8'),
|
||||
password=None
|
||||
)
|
||||
|
||||
private_key = serialization.load_pem_private_key(self.private_key_pem.encode("utf-8"), password=None)
|
||||
|
||||
# 确保是RSA私钥
|
||||
if not isinstance(private_key, rsa.RSAPrivateKey):
|
||||
raise ValueError("私钥必须是RSA格式")
|
||||
|
||||
|
||||
# 生成签名
|
||||
signature = private_key.sign(
|
||||
sign_data.encode('utf-8'),
|
||||
padding.PSS(
|
||||
mgf=padding.MGF1(hashes.SHA256()),
|
||||
salt_length=padding.PSS.MAX_LENGTH
|
||||
),
|
||||
hashes.SHA256()
|
||||
sign_data.encode("utf-8"),
|
||||
padding.PSS(mgf=padding.MGF1(hashes.SHA256()), salt_length=padding.PSS.MAX_LENGTH),
|
||||
hashes.SHA256(),
|
||||
)
|
||||
|
||||
|
||||
# Base64编码
|
||||
signature_b64 = base64.b64encode(signature).decode('utf-8')
|
||||
|
||||
signature_b64 = base64.b64encode(signature).decode("utf-8")
|
||||
|
||||
return timestamp, signature_b64
|
||||
|
||||
def _decrypt_challenge(self, challenge_b64: str) -> str:
|
||||
"""
|
||||
解密挑战数据
|
||||
|
||||
|
||||
Args:
|
||||
challenge_b64: Base64编码的挑战数据
|
||||
|
||||
|
||||
Returns:
|
||||
str: 解密后的UUID字符串
|
||||
"""
|
||||
if not self.private_key_pem:
|
||||
raise ValueError("私钥未初始化")
|
||||
|
||||
|
||||
# 加载私钥
|
||||
private_key = serialization.load_pem_private_key(
|
||||
self.private_key_pem.encode('utf-8'),
|
||||
password=None
|
||||
)
|
||||
|
||||
private_key = serialization.load_pem_private_key(self.private_key_pem.encode("utf-8"), password=None)
|
||||
|
||||
# 确保是RSA私钥
|
||||
if not isinstance(private_key, rsa.RSAPrivateKey):
|
||||
raise ValueError("私钥必须是RSA格式")
|
||||
|
||||
|
||||
# 解密挑战数据
|
||||
decrypted_bytes = private_key.decrypt(
|
||||
base64.b64decode(challenge_b64),
|
||||
padding.OAEP(
|
||||
mgf=padding.MGF1(hashes.SHA256()),
|
||||
algorithm=hashes.SHA256(),
|
||||
label=None
|
||||
)
|
||||
padding.OAEP(mgf=padding.MGF1(hashes.SHA256()), algorithm=hashes.SHA256(), label=None),
|
||||
)
|
||||
|
||||
return decrypted_bytes.decode('utf-8')
|
||||
|
||||
return decrypted_bytes.decode("utf-8")
|
||||
|
||||
async def _req_uuid(self) -> bool:
|
||||
"""
|
||||
@@ -155,28 +144,26 @@ class TelemetryHeartBeatTask(AsyncTask):
|
||||
|
||||
if response.status != 200:
|
||||
response_text = await response.text()
|
||||
logger.error(
|
||||
f"注册步骤1失败,状态码: {response.status}, 响应内容: {response_text}"
|
||||
)
|
||||
logger.error(f"注册步骤1失败,状态码: {response.status}, 响应内容: {response_text}")
|
||||
raise aiohttp.ClientResponseError(
|
||||
request_info=response.request_info,
|
||||
history=response.history,
|
||||
status=response.status,
|
||||
message=f"Step1 failed: {response_text}"
|
||||
message=f"Step1 failed: {response_text}",
|
||||
)
|
||||
|
||||
step1_data = await response.json()
|
||||
temp_uuid = step1_data.get("temp_uuid")
|
||||
private_key = step1_data.get("private_key")
|
||||
challenge = step1_data.get("challenge")
|
||||
|
||||
|
||||
if not all([temp_uuid, private_key, challenge]):
|
||||
logger.error("Step1响应缺少必要字段:temp_uuid, private_key 或 challenge")
|
||||
raise ValueError("Step1响应数据不完整")
|
||||
|
||||
# 临时保存私钥用于解密
|
||||
self.private_key_pem = private_key
|
||||
|
||||
|
||||
# 解密挑战数据
|
||||
logger.debug("解密挑战数据...")
|
||||
try:
|
||||
@@ -184,21 +171,18 @@ class TelemetryHeartBeatTask(AsyncTask):
|
||||
except Exception as e:
|
||||
logger.error(f"解密挑战数据失败: {e}")
|
||||
raise
|
||||
|
||||
|
||||
# 验证解密结果
|
||||
if decrypted_uuid != temp_uuid:
|
||||
logger.error(f"解密结果验证失败: 期望 {temp_uuid}, 实际 {decrypted_uuid}")
|
||||
raise ValueError("解密结果与临时UUID不匹配")
|
||||
|
||||
|
||||
logger.debug("挑战数据解密成功,开始注册步骤2")
|
||||
|
||||
# Step 2: 发送解密结果完成注册
|
||||
async with session.post(
|
||||
f"{TELEMETRY_SERVER_URL}/stat/reg_client_step2",
|
||||
json={
|
||||
"temp_uuid": temp_uuid,
|
||||
"decrypted_uuid": decrypted_uuid
|
||||
},
|
||||
json={"temp_uuid": temp_uuid, "decrypted_uuid": decrypted_uuid},
|
||||
timeout=aiohttp.ClientTimeout(total=5),
|
||||
) as response:
|
||||
logger.debug(f"Step2 Response status: {response.status}")
|
||||
@@ -206,7 +190,7 @@ class TelemetryHeartBeatTask(AsyncTask):
|
||||
if response.status == 200:
|
||||
step2_data = await response.json()
|
||||
mofox_uuid = step2_data.get("mofox_uuid")
|
||||
|
||||
|
||||
if mofox_uuid:
|
||||
# 将正式UUID和私钥存储到本地
|
||||
local_storage["mofox_uuid"] = mofox_uuid
|
||||
@@ -225,23 +209,19 @@ class TelemetryHeartBeatTask(AsyncTask):
|
||||
raise ValueError(f"Step2失败: {response_text}")
|
||||
else:
|
||||
response_text = await response.text()
|
||||
logger.error(
|
||||
f"注册步骤2失败,状态码: {response.status}, 响应内容: {response_text}"
|
||||
)
|
||||
logger.error(f"注册步骤2失败,状态码: {response.status}, 响应内容: {response_text}")
|
||||
raise aiohttp.ClientResponseError(
|
||||
request_info=response.request_info,
|
||||
history=response.history,
|
||||
status=response.status,
|
||||
message=f"Step2 failed: {response_text}"
|
||||
message=f"Step2 failed: {response_text}",
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
import traceback
|
||||
|
||||
error_msg = str(e) or "未知错误"
|
||||
logger.warning(
|
||||
f"注册客户端出错,不过你还是可以正常使用墨狐: {type(e).__name__}: {error_msg}"
|
||||
)
|
||||
logger.warning(f"注册客户端出错,不过你还是可以正常使用墨狐: {type(e).__name__}: {error_msg}")
|
||||
logger.debug(f"完整错误信息: {traceback.format_exc()}")
|
||||
|
||||
# 请求失败,重试次数+1
|
||||
@@ -264,13 +244,13 @@ class TelemetryHeartBeatTask(AsyncTask):
|
||||
try:
|
||||
# 生成签名
|
||||
timestamp, signature = self._generate_signature(self.info_dict)
|
||||
|
||||
|
||||
headers = {
|
||||
"X-mofox-UUID": self.client_uuid,
|
||||
"X-mofox-Signature": signature,
|
||||
"X-mofox-Timestamp": timestamp,
|
||||
"User-Agent": f"MofoxClient/{self.client_uuid[:8]}",
|
||||
"Content-Type": "application/json"
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
logger.debug(f"正在发送心跳到服务器: {self.server_url}")
|
||||
@@ -347,4 +327,4 @@ class TelemetryHeartBeatTask(AsyncTask):
|
||||
logger.warning("客户端注册失败,跳过此次心跳")
|
||||
return
|
||||
|
||||
await self._send_heartbeat()
|
||||
await self._send_heartbeat()
|
||||
|
||||
@@ -99,14 +99,13 @@ def get_global_server() -> Server:
|
||||
"""获取全局服务器实例"""
|
||||
global global_server
|
||||
if global_server is None:
|
||||
|
||||
host = os.getenv("HOST", "127.0.0.1")
|
||||
port_str = os.getenv("PORT", "8000")
|
||||
|
||||
|
||||
try:
|
||||
port = int(port_str)
|
||||
except ValueError:
|
||||
port = 8000
|
||||
|
||||
|
||||
global_server = Server(host=host, port=port)
|
||||
return global_server
|
||||
|
||||
@@ -44,7 +44,7 @@ from src.config.official_configs import (
|
||||
PermissionConfig,
|
||||
CommandConfig,
|
||||
PlanningSystemConfig,
|
||||
AffinityFlowConfig
|
||||
AffinityFlowConfig,
|
||||
)
|
||||
|
||||
from .api_ada_configs import (
|
||||
@@ -399,9 +399,7 @@ class Config(ValidatedConfigBase):
|
||||
cross_context: CrossContextConfig = Field(
|
||||
default_factory=lambda: CrossContextConfig(), description="跨群聊上下文共享配置"
|
||||
)
|
||||
affinity_flow: AffinityFlowConfig = Field(
|
||||
default_factory=lambda: AffinityFlowConfig(), description="亲和流配置"
|
||||
)
|
||||
affinity_flow: AffinityFlowConfig = Field(default_factory=lambda: AffinityFlowConfig(), description="亲和流配置")
|
||||
|
||||
|
||||
class APIAdapterConfig(ValidatedConfigBase):
|
||||
|
||||
@@ -51,8 +51,12 @@ class PersonalityConfig(ValidatedConfigBase):
|
||||
personality_core: str = Field(..., description="核心人格")
|
||||
personality_side: str = Field(..., description="人格侧写")
|
||||
identity: str = Field(default="", description="身份特征")
|
||||
background_story: str = Field(default="", description="世界观背景故事,这部分内容会作为背景知识,LLM被指导不应主动复述")
|
||||
safety_guidelines: List[str] = Field(default_factory=list, description="安全与互动底线,Bot在任何情况下都必须遵守的原则")
|
||||
background_story: str = Field(
|
||||
default="", description="世界观背景故事,这部分内容会作为背景知识,LLM被指导不应主动复述"
|
||||
)
|
||||
safety_guidelines: List[str] = Field(
|
||||
default_factory=list, description="安全与互动底线,Bot在任何情况下都必须遵守的原则"
|
||||
)
|
||||
reply_style: str = Field(default="", description="表达风格")
|
||||
prompt_mode: Literal["s4u", "normal"] = Field(default="s4u", description="Prompt模式")
|
||||
compress_personality: bool = Field(default=True, description="是否压缩人格")
|
||||
@@ -79,7 +83,8 @@ class ChatConfig(ValidatedConfigBase):
|
||||
talk_frequency_adjust: list[list[str]] = Field(default_factory=lambda: [], description="聊天频率调整")
|
||||
focus_value: float = Field(default=1.0, description="专注值")
|
||||
focus_mode_quiet_groups: List[str] = Field(
|
||||
default_factory=list, description='专注模式下需要保持安静的群组列表, 格式: ["platform:group_id1", "platform:group_id2"]'
|
||||
default_factory=list,
|
||||
description='专注模式下需要保持安静的群组列表, 格式: ["platform:group_id1", "platform:group_id2"]',
|
||||
)
|
||||
force_reply_private: bool = Field(default=False, description="强制回复私聊")
|
||||
group_chat_mode: Literal["auto", "normal", "focus"] = Field(default="auto", description="群聊模式")
|
||||
@@ -343,6 +348,7 @@ class ExpressionConfig(ValidatedConfigBase):
|
||||
# 如果都没有匹配,返回默认值
|
||||
return True, True, 1.0
|
||||
|
||||
|
||||
class ToolConfig(ValidatedConfigBase):
|
||||
"""工具配置类"""
|
||||
|
||||
@@ -477,7 +483,6 @@ class ExperimentalConfig(ValidatedConfigBase):
|
||||
pfc_chatting: bool = Field(default=False, description="启用PFC聊天")
|
||||
|
||||
|
||||
|
||||
class MaimMessageConfig(ValidatedConfigBase):
|
||||
"""maim_message配置类"""
|
||||
|
||||
@@ -602,8 +607,12 @@ class SleepSystemConfig(ValidatedConfigBase):
|
||||
sleep_by_schedule: bool = Field(default=True, description="是否根据日程表进行睡觉")
|
||||
fixed_sleep_time: str = Field(default="23:00", description="固定的睡觉时间")
|
||||
fixed_wake_up_time: str = Field(default="07:00", description="固定的起床时间")
|
||||
sleep_time_offset_minutes: int = Field(default=15, ge=0, le=60, description="睡觉时间随机偏移量范围(分钟),实际睡觉时间会在±该值范围内随机")
|
||||
wake_up_time_offset_minutes: int = Field(default=15, ge=0, le=60, description="起床时间随机偏移量范围(分钟),实际起床时间会在±该值范围内随机")
|
||||
sleep_time_offset_minutes: int = Field(
|
||||
default=15, ge=0, le=60, description="睡觉时间随机偏移量范围(分钟),实际睡觉时间会在±该值范围内随机"
|
||||
)
|
||||
wake_up_time_offset_minutes: int = Field(
|
||||
default=15, ge=0, le=60, description="起床时间随机偏移量范围(分钟),实际起床时间会在±该值范围内随机"
|
||||
)
|
||||
wakeup_threshold: float = Field(default=15.0, ge=1.0, description="唤醒阈值,达到此值时会被唤醒")
|
||||
private_message_increment: float = Field(default=3.0, ge=0.1, description="私聊消息增加的唤醒度")
|
||||
group_mention_increment: float = Field(default=2.0, ge=0.1, description="群聊艾特增加的唤醒度")
|
||||
@@ -618,10 +627,10 @@ class SleepSystemConfig(ValidatedConfigBase):
|
||||
# --- 失眠机制相关参数 ---
|
||||
enable_insomnia_system: bool = Field(default=True, description="是否启用失眠系统")
|
||||
insomnia_trigger_delay_minutes: List[int] = Field(
|
||||
default_factory=lambda:[30, 60], description="入睡后触发失眠判定的延迟时间范围(分钟)"
|
||||
default_factory=lambda: [30, 60], description="入睡后触发失眠判定的延迟时间范围(分钟)"
|
||||
)
|
||||
insomnia_duration_minutes: List[int] = Field(
|
||||
default_factory=lambda:[15, 45], description="单次失眠状态的持续时间范围(分钟)"
|
||||
default_factory=lambda: [15, 45], description="单次失眠状态的持续时间范围(分钟)"
|
||||
)
|
||||
sleep_pressure_threshold: float = Field(default=30.0, description="触发“压力不足型失眠”的睡眠压力阈值")
|
||||
deep_sleep_threshold: float = Field(default=80.0, description="进入“深度睡眠”的睡眠压力阈值")
|
||||
@@ -657,6 +666,8 @@ class CrossContextConfig(ValidatedConfigBase):
|
||||
|
||||
enable: bool = Field(default=False, description="是否启用跨群聊上下文共享功能")
|
||||
groups: List[ContextGroup] = Field(default_factory=list, description="上下文共享组列表")
|
||||
|
||||
|
||||
class CommandConfig(ValidatedConfigBase):
|
||||
"""命令系统配置类"""
|
||||
|
||||
|
||||
@@ -88,8 +88,7 @@ class Individuality:
|
||||
|
||||
# 初始化智能兴趣系统
|
||||
await interest_scoring_system.initialize_smart_interests(
|
||||
personality_description=full_personality,
|
||||
personality_id=self.bot_person_id
|
||||
personality_description=full_personality, personality_id=self.bot_person_id
|
||||
)
|
||||
|
||||
logger.info("智能兴趣系统初始化完成")
|
||||
|
||||
@@ -130,7 +130,8 @@ class MainSystem:
|
||||
# 停止消息重组器
|
||||
from src.plugin_system.core.event_manager import event_manager
|
||||
from src.plugin_system import EventType
|
||||
asyncio.run(event_manager.trigger_event(EventType.ON_STOP,permission_group="SYSTEM"))
|
||||
|
||||
asyncio.run(event_manager.trigger_event(EventType.ON_STOP, permission_group="SYSTEM"))
|
||||
from src.utils.message_chunker import reassembler
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
@@ -216,7 +217,7 @@ MoFox_Bot(第三方修改版)
|
||||
|
||||
# 添加统计信息输出任务
|
||||
await async_task_manager.add_task(StatisticOutputTask())
|
||||
|
||||
|
||||
# 添加遥测心跳任务
|
||||
await async_task_manager.add_task(TelemetryHeartBeatTask())
|
||||
|
||||
@@ -250,6 +251,7 @@ MoFox_Bot(第三方修改版)
|
||||
|
||||
# 初始化回复后关系追踪系统
|
||||
from src.chat.affinity_flow.relationship_integration import initialize_relationship_tracking
|
||||
|
||||
relationship_tracker = initialize_relationship_tracking()
|
||||
if relationship_tracker:
|
||||
logger.info("回复后关系追踪系统初始化成功")
|
||||
@@ -273,6 +275,7 @@ MoFox_Bot(第三方修改版)
|
||||
|
||||
# 初始化LPMM知识库
|
||||
from src.chat.knowledge.knowledge_lib import initialize_lpmm_knowledge
|
||||
|
||||
initialize_lpmm_knowledge()
|
||||
logger.info("LPMM知识库初始化成功")
|
||||
|
||||
@@ -298,6 +301,7 @@ MoFox_Bot(第三方修改版)
|
||||
|
||||
# 启动消息管理器
|
||||
from src.chat.message_manager import message_manager
|
||||
|
||||
await message_manager.start()
|
||||
logger.info("消息管理器已启动")
|
||||
|
||||
|
||||
@@ -96,12 +96,13 @@ class PersonInfoManager:
|
||||
# 在此处打一个补丁,如果platform为qq,尝试生成id后检查是否存在,如果不存在,则将平台换为napcat后再次检查,如果存在,则更新原id为platform为qq的id
|
||||
components = [platform, str(user_id)]
|
||||
key = "_".join(components)
|
||||
|
||||
|
||||
# 如果不是 qq 平台,直接返回计算的 id
|
||||
if platform != "qq":
|
||||
return hashlib.md5(key.encode()).hexdigest()
|
||||
|
||||
qq_id = hashlib.md5(key.encode()).hexdigest()
|
||||
|
||||
# 对于 qq 平台,先检查该 person_id 是否已存在;如果存在直接返回
|
||||
def _db_check_and_migrate_sync(p_id: str, raw_user_id: str):
|
||||
try:
|
||||
@@ -191,16 +192,16 @@ class PersonInfoManager:
|
||||
# Ensure person_id is correctly set from the argument
|
||||
final_data["person_id"] = person_id
|
||||
# 你们的英文注释是何意味?
|
||||
|
||||
|
||||
# 检查并修复关键字段为None的情况喵
|
||||
if final_data.get("user_id") is None:
|
||||
logger.warning(f"user_id为None,使用'unknown'作为默认值 person_id={person_id}")
|
||||
final_data["user_id"] = "unknown"
|
||||
|
||||
|
||||
if final_data.get("platform") is None:
|
||||
logger.warning(f"platform为None,使用'unknown'作为默认值 person_id={person_id}")
|
||||
final_data["platform"] = "unknown"
|
||||
|
||||
|
||||
# 这里的目的是为了防止在识别出错的情况下有一个最小回退,不只是针对@消息识别成视频后的报错问题
|
||||
|
||||
# Serialize JSON fields
|
||||
@@ -251,12 +252,12 @@ class PersonInfoManager:
|
||||
|
||||
# Ensure person_id is correctly set from the argument
|
||||
final_data["person_id"] = person_id
|
||||
|
||||
|
||||
# 检查并修复关键字段为None的情况
|
||||
if final_data.get("user_id") is None:
|
||||
logger.warning(f"user_id为None,使用'unknown'作为默认值 person_id={person_id}")
|
||||
final_data["user_id"] = "unknown"
|
||||
|
||||
|
||||
if final_data.get("platform") is None:
|
||||
logger.warning(f"platform为None,使用'unknown'作为默认值 person_id={person_id}")
|
||||
final_data["platform"] = "unknown"
|
||||
@@ -356,12 +357,12 @@ class PersonInfoManager:
|
||||
creation_data["platform"] = data["platform"]
|
||||
if data and "user_id" in data:
|
||||
creation_data["user_id"] = data["user_id"]
|
||||
|
||||
|
||||
# 额外检查关键字段,如果为None则使用默认值
|
||||
if creation_data.get("user_id") is None:
|
||||
logger.warning(f"创建用户时user_id为None,使用'unknown'作为默认值 person_id={person_id}")
|
||||
creation_data["user_id"] = "unknown"
|
||||
|
||||
|
||||
if creation_data.get("platform") is None:
|
||||
logger.warning(f"创建用户时platform为None,使用'unknown'作为默认值 person_id={person_id}")
|
||||
creation_data["platform"] = "unknown"
|
||||
|
||||
@@ -123,7 +123,9 @@ class RelationshipFetcher:
|
||||
all_points = current_points + forgotten_points
|
||||
if all_points:
|
||||
# 按权重和时效性综合排序
|
||||
all_points.sort(key=lambda x: (float(x[1]) if len(x) > 1 else 0, float(x[2]) if len(x) > 2 else 0), reverse=True)
|
||||
all_points.sort(
|
||||
key=lambda x: (float(x[1]) if len(x) > 1 else 0, float(x[2]) if len(x) > 2 else 0), reverse=True
|
||||
)
|
||||
selected_points = all_points[:points_num]
|
||||
points_text = "\n".join([f"- {point[0]}({point[2]})" for point in selected_points if len(point) > 2])
|
||||
else:
|
||||
@@ -139,15 +141,17 @@ class RelationshipFetcher:
|
||||
# 2. 认识时间和频率
|
||||
if know_since:
|
||||
from datetime import datetime
|
||||
know_time = datetime.fromtimestamp(know_since).strftime('%Y年%m月%d日')
|
||||
|
||||
know_time = datetime.fromtimestamp(know_since).strftime("%Y年%m月%d日")
|
||||
relation_parts.append(f"你从{know_time}开始认识{person_name}")
|
||||
|
||||
|
||||
if know_times > 0:
|
||||
relation_parts.append(f"你们已经交流过{int(know_times)}次")
|
||||
|
||||
|
||||
if last_know:
|
||||
from datetime import datetime
|
||||
last_time = datetime.fromtimestamp(last_know).strftime('%m月%d日')
|
||||
|
||||
last_time = datetime.fromtimestamp(last_know).strftime("%m月%d日")
|
||||
relation_parts.append(f"最近一次交流是在{last_time}")
|
||||
|
||||
# 3. 态度和印象
|
||||
@@ -156,7 +160,7 @@ class RelationshipFetcher:
|
||||
|
||||
if short_impression:
|
||||
relation_parts.append(f"你对ta的总体印象:{short_impression}")
|
||||
|
||||
|
||||
if full_impression:
|
||||
relation_parts.append(f"更详细的了解:{full_impression}")
|
||||
|
||||
@@ -168,14 +172,14 @@ class RelationshipFetcher:
|
||||
try:
|
||||
from src.common.database.sqlalchemy_database_api import db_query
|
||||
from src.common.database.sqlalchemy_models import UserRelationships
|
||||
|
||||
|
||||
# 查询用户关系数据
|
||||
relationships = await db_query(
|
||||
UserRelationships,
|
||||
filters=[UserRelationships.user_id == str(person_info_manager.get_value_sync(person_id, "user_id"))],
|
||||
limit=1
|
||||
limit=1,
|
||||
)
|
||||
|
||||
|
||||
if relationships:
|
||||
rel_data = relationships[0]
|
||||
if rel_data.relationship_text:
|
||||
@@ -183,13 +187,15 @@ class RelationshipFetcher:
|
||||
if rel_data.relationship_score:
|
||||
score_desc = self._get_relationship_score_description(rel_data.relationship_score)
|
||||
relation_parts.append(f"关系亲密程度:{score_desc}")
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.debug(f"查询UserRelationships表失败: {e}")
|
||||
|
||||
# 构建最终的关系信息字符串
|
||||
if relation_parts:
|
||||
relation_info = f"关于{person_name},你知道以下信息:\n" + "\n".join([f"• {part}" for part in relation_parts])
|
||||
relation_info = f"关于{person_name},你知道以下信息:\n" + "\n".join(
|
||||
[f"• {part}" for part in relation_parts]
|
||||
)
|
||||
else:
|
||||
relation_info = f"你对{person_name}了解不多,这是比较初步的交流。"
|
||||
|
||||
|
||||
@@ -93,7 +93,6 @@ class BaseAction(ABC):
|
||||
self.associated_types: list[str] = getattr(self.__class__, "associated_types", []).copy()
|
||||
self.chat_type_allow: ChatType = getattr(self.__class__, "chat_type_allow", ChatType.ALL)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 便捷属性 - 直接在初始化时获取常用聊天信息(带类型注解)
|
||||
# =============================================================================
|
||||
@@ -398,6 +397,7 @@ class BaseAction(ABC):
|
||||
try:
|
||||
# 1. 从注册中心获取Action类
|
||||
from src.plugin_system.core.component_registry import component_registry
|
||||
|
||||
action_class = component_registry.get_component_class(action_name, ComponentType.ACTION)
|
||||
if not action_class:
|
||||
logger.error(f"{log_prefix} 未找到Action: {action_name}")
|
||||
@@ -406,7 +406,7 @@ class BaseAction(ABC):
|
||||
# 2. 准备实例化参数
|
||||
# 复用当前Action的大部分上下文信息
|
||||
called_action_data = action_data if action_data is not None else self.action_data
|
||||
|
||||
|
||||
component_info = component_registry.get_component_info(action_name, ComponentType.ACTION)
|
||||
if not component_info:
|
||||
logger.warning(f"{log_prefix} 未找到Action组件信息: {action_name}")
|
||||
|
||||
@@ -98,7 +98,7 @@ class BaseEventHandler(ABC):
|
||||
weight=cls.weight,
|
||||
intercept_message=cls.intercept_message,
|
||||
)
|
||||
|
||||
|
||||
def set_plugin_name(self, plugin_name: str) -> None:
|
||||
"""设置插件名称
|
||||
|
||||
@@ -107,9 +107,9 @@ class BaseEventHandler(ABC):
|
||||
"""
|
||||
self.plugin_name = plugin_name
|
||||
|
||||
def set_plugin_config(self,plugin_config) -> None:
|
||||
def set_plugin_config(self, plugin_config) -> None:
|
||||
self.plugin_config = plugin_config
|
||||
|
||||
|
||||
def get_config(self, key: str, default=None):
|
||||
"""获取插件配置值,支持嵌套键访问
|
||||
|
||||
|
||||
@@ -69,7 +69,7 @@ class EventType(Enum):
|
||||
"""
|
||||
|
||||
ON_START = "on_start" # 启动事件,用于调用按时任务
|
||||
ON_STOP ="on_stop"
|
||||
ON_STOP = "on_stop"
|
||||
ON_MESSAGE = "on_message"
|
||||
ON_PLAN = "on_plan"
|
||||
POST_LLM = "post_llm"
|
||||
|
||||
@@ -270,7 +270,9 @@ class ComponentRegistry:
|
||||
# 使用EventManager进行事件处理器注册
|
||||
from src.plugin_system.core.event_manager import event_manager
|
||||
|
||||
return event_manager.register_event_handler(handler_class,self.get_plugin_config(handler_info.plugin_name) or {})
|
||||
return event_manager.register_event_handler(
|
||||
handler_class, self.get_plugin_config(handler_info.plugin_name) or {}
|
||||
)
|
||||
|
||||
# === 组件移除相关 ===
|
||||
|
||||
@@ -682,19 +684,20 @@ class ComponentRegistry:
|
||||
plugin_instance = plugin_manager.get_plugin_instance(plugin_name)
|
||||
if plugin_instance and plugin_instance.config:
|
||||
return plugin_instance.config
|
||||
|
||||
|
||||
# 如果插件实例不存在,尝试从配置文件读取
|
||||
try:
|
||||
import toml
|
||||
|
||||
config_path = Path("config") / "plugins" / plugin_name / "config.toml"
|
||||
if config_path.exists():
|
||||
with open(config_path, 'r', encoding='utf-8') as f:
|
||||
with open(config_path, "r", encoding="utf-8") as f:
|
||||
config_data = toml.load(f)
|
||||
logger.debug(f"从配置文件读取插件 {plugin_name} 的配置")
|
||||
return config_data
|
||||
except Exception as e:
|
||||
logger.debug(f"读取插件 {plugin_name} 配置文件失败: {e}")
|
||||
|
||||
|
||||
return {}
|
||||
|
||||
def get_registry_stats(self) -> Dict[str, Any]:
|
||||
|
||||
@@ -145,7 +145,9 @@ class EventManager:
|
||||
logger.info(f"事件 {event_name} 已禁用")
|
||||
return True
|
||||
|
||||
def register_event_handler(self, handler_class: Type[BaseEventHandler], plugin_config: Optional[dict] = None) -> bool:
|
||||
def register_event_handler(
|
||||
self, handler_class: Type[BaseEventHandler], plugin_config: Optional[dict] = None
|
||||
) -> bool:
|
||||
"""注册事件处理器
|
||||
|
||||
Args:
|
||||
@@ -167,7 +169,7 @@ class EventManager:
|
||||
# 创建事件处理器实例,传递插件配置
|
||||
handler_instance = handler_class()
|
||||
handler_instance.plugin_config = plugin_config
|
||||
if plugin_config is not None and hasattr(handler_instance, 'set_plugin_config'):
|
||||
if plugin_config is not None and hasattr(handler_instance, "set_plugin_config"):
|
||||
handler_instance.set_plugin_config(plugin_config)
|
||||
|
||||
self._event_handlers[handler_name] = handler_instance
|
||||
|
||||
@@ -199,9 +199,7 @@ class PluginManager:
|
||||
self._show_plugin_components(plugin_name)
|
||||
|
||||
# 检查并调用 on_plugin_loaded 钩子(如果存在)
|
||||
if hasattr(plugin_instance, "on_plugin_loaded") and callable(
|
||||
plugin_instance.on_plugin_loaded
|
||||
):
|
||||
if hasattr(plugin_instance, "on_plugin_loaded") and callable(plugin_instance.on_plugin_loaded):
|
||||
logger.debug(f"为插件 '{plugin_name}' 调用 on_plugin_loaded 钩子")
|
||||
try:
|
||||
# 使用 asyncio.create_task 确保它不会阻塞加载流程
|
||||
|
||||
@@ -64,50 +64,50 @@ class AtAction(BaseAction):
|
||||
# 使用回复器生成艾特回复,而不是直接发送命令
|
||||
from src.chat.replyer.default_generator import DefaultReplyer
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
|
||||
|
||||
# 获取当前聊天流
|
||||
chat_manager = get_chat_manager()
|
||||
chat_stream = self.chat_stream or chat_manager.get_stream(self.chat_id)
|
||||
|
||||
|
||||
if not chat_stream:
|
||||
logger.error(f"找不到聊天流: {self.chat_stream}")
|
||||
return False, "聊天流不存在"
|
||||
|
||||
|
||||
# 创建回复器实例
|
||||
replyer = DefaultReplyer(chat_stream)
|
||||
|
||||
|
||||
# 构建回复对象,将艾特消息作为回复目标
|
||||
reply_to = f"{user_name}:{at_message}"
|
||||
extra_info = f"你需要艾特用户 {user_name} 并回复他们说: {at_message}"
|
||||
|
||||
|
||||
# 使用回复器生成回复
|
||||
success, llm_response, prompt = await replyer.generate_reply_with_context(
|
||||
reply_to=reply_to,
|
||||
extra_info=extra_info,
|
||||
enable_tool=False, # 艾特回复通常不需要工具调用
|
||||
from_plugin=False
|
||||
from_plugin=False,
|
||||
)
|
||||
|
||||
|
||||
if success and llm_response:
|
||||
# 获取生成的回复内容
|
||||
reply_content = llm_response.get("content", "")
|
||||
if reply_content:
|
||||
# 获取用户QQ号,发送真正的艾特消息
|
||||
user_id = user_info.get("user_id")
|
||||
|
||||
|
||||
# 发送真正的艾特命令,使用回复器生成的智能内容
|
||||
await self.send_command(
|
||||
"SEND_AT_MESSAGE",
|
||||
args={"qq_id": user_id, "text": reply_content},
|
||||
display_message=f"艾特用户 {user_name} 并发送智能回复: {reply_content}",
|
||||
)
|
||||
|
||||
|
||||
await self.store_action_info(
|
||||
action_build_into_prompt=True,
|
||||
action_prompt_display=f"执行了艾特用户动作:艾特用户 {user_name} 并发送智能回复: {reply_content}",
|
||||
action_done=True,
|
||||
)
|
||||
|
||||
|
||||
logger.info(f"成功通过回复器生成智能内容并发送真正的艾特消息给 {user_name}: {reply_content}")
|
||||
return True, "智能艾特消息发送成功"
|
||||
else:
|
||||
@@ -116,7 +116,7 @@ class AtAction(BaseAction):
|
||||
else:
|
||||
logger.error("回复器生成回复失败")
|
||||
return False, "回复生成失败"
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"执行艾特用户动作时发生异常: {e}", exc_info=True)
|
||||
await self.store_action_info(
|
||||
|
||||
@@ -70,7 +70,9 @@ class EmojiAction(BaseAction):
|
||||
|
||||
# 2. 获取所有有效的表情包对象
|
||||
emoji_manager = get_emoji_manager()
|
||||
all_emojis_obj: list[MaiEmoji] = [e for e in emoji_manager.emoji_objects if not e.is_deleted and e.description]
|
||||
all_emojis_obj: list[MaiEmoji] = [
|
||||
e for e in emoji_manager.emoji_objects if not e.is_deleted and e.description
|
||||
]
|
||||
if not all_emojis_obj:
|
||||
logger.warning(f"{self.log_prefix} 无法获取任何带有描述的有效表情包")
|
||||
return False, "无法获取任何带有描述的有效表情包"
|
||||
@@ -91,12 +93,12 @@ class EmojiAction(BaseAction):
|
||||
# 4. 准备情感数据和后备列表
|
||||
emotion_map = {}
|
||||
all_emojis_data = []
|
||||
|
||||
|
||||
for emoji in all_emojis_obj:
|
||||
b64 = image_path_to_base64(emoji.full_path)
|
||||
if not b64:
|
||||
continue
|
||||
|
||||
|
||||
desc = emoji.description
|
||||
emotions = emoji.emotion
|
||||
all_emojis_data.append((b64, desc))
|
||||
@@ -168,16 +170,18 @@ class EmojiAction(BaseAction):
|
||||
|
||||
# 使用模糊匹配来查找最相关的情感标签
|
||||
matched_key = next((key for key in emotion_map if chosen_emotion in key), None)
|
||||
|
||||
|
||||
if matched_key:
|
||||
emoji_base64, emoji_description = random.choice(emotion_map[matched_key])
|
||||
logger.info(f"{self.log_prefix} 找到匹配情感 '{chosen_emotion}' (匹配到: '{matched_key}') 的表情包: {emoji_description}")
|
||||
logger.info(
|
||||
f"{self.log_prefix} 找到匹配情感 '{chosen_emotion}' (匹配到: '{matched_key}') 的表情包: {emoji_description}"
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
f"{self.log_prefix} LLM选择的情感 '{chosen_emotion}' 不在可用列表中, 将随机选择一个表情包"
|
||||
)
|
||||
emoji_base64, emoji_description = random.choice(all_emojis_data)
|
||||
|
||||
|
||||
elif global_config.emoji.emoji_selection_mode == "description":
|
||||
# --- 详细描述选择模式 ---
|
||||
# 获取最近的5条消息内容用于判断
|
||||
@@ -226,15 +230,23 @@ class EmojiAction(BaseAction):
|
||||
logger.info(f"{self.log_prefix} LLM选择的描述: {chosen_description}")
|
||||
|
||||
# 简单关键词匹配
|
||||
matched_emoji = next((item for item in all_emojis_data if chosen_description.lower() in item[1].lower() or item[1].lower() in chosen_description.lower()), None)
|
||||
|
||||
matched_emoji = next(
|
||||
(
|
||||
item
|
||||
for item in all_emojis_data
|
||||
if chosen_description.lower() in item[1].lower()
|
||||
or item[1].lower() in chosen_description.lower()
|
||||
),
|
||||
None,
|
||||
)
|
||||
|
||||
# 如果包含匹配失败,尝试关键词匹配
|
||||
if not matched_emoji:
|
||||
keywords = ['惊讶', '困惑', '呆滞', '震惊', '懵', '无语', '萌', '可爱']
|
||||
keywords = ["惊讶", "困惑", "呆滞", "震惊", "懵", "无语", "萌", "可爱"]
|
||||
for keyword in keywords:
|
||||
if keyword in chosen_description:
|
||||
for item in all_emojis_data:
|
||||
if any(k in item[1] for k in ['呆', '萌', '惊', '困惑', '无语']):
|
||||
if any(k in item[1] for k in ["呆", "萌", "惊", "困惑", "无语"]):
|
||||
matched_emoji = item
|
||||
break
|
||||
if matched_emoji:
|
||||
@@ -255,7 +267,9 @@ class EmojiAction(BaseAction):
|
||||
|
||||
if not success:
|
||||
logger.error(f"{self.log_prefix} 表情包发送失败")
|
||||
await self.store_action_info(action_build_into_prompt = True,action_prompt_display =f"发送了一个表情包,但失败了",action_done= False)
|
||||
await self.store_action_info(
|
||||
action_build_into_prompt=True, action_prompt_display=f"发送了一个表情包,但失败了", action_done=False
|
||||
)
|
||||
return False, "表情包发送失败"
|
||||
|
||||
# 发送成功后,记录到历史
|
||||
@@ -263,8 +277,10 @@ class EmojiAction(BaseAction):
|
||||
add_emoji_to_history(self.chat_id, emoji_description)
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 添加表情到历史记录时出错: {e}")
|
||||
|
||||
await self.store_action_info(action_build_into_prompt = True,action_prompt_display =f"发送了一个表情包",action_done= True)
|
||||
|
||||
await self.store_action_info(
|
||||
action_build_into_prompt=True, action_prompt_display=f"发送了一个表情包", action_done=True
|
||||
)
|
||||
|
||||
return True, f"发送表情包: {emoji_description}"
|
||||
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
|
||||
from src.plugin_system import BaseEventHandler
|
||||
from src.plugin_system.base.base_event import HandlerResult
|
||||
|
||||
@@ -1748,6 +1747,7 @@ class SetGroupSignHandler(BaseEventHandler):
|
||||
logger.error("事件 napcat_set_group_sign 请求失败!")
|
||||
return HandlerResult(False, False, {"status": "error"})
|
||||
|
||||
|
||||
# ===PERSONAL===
|
||||
class SetInputStatusHandler(BaseEventHandler):
|
||||
handler_name: str = "napcat_set_input_status_handler"
|
||||
|
||||
@@ -227,7 +227,7 @@ class LauchNapcatAdapterHandler(BaseEventHandler):
|
||||
await reassembler.start_cleanup_task()
|
||||
|
||||
logger.info("开始启动Napcat Adapter")
|
||||
|
||||
|
||||
# 创建单独的异步任务,防止阻塞主线程
|
||||
asyncio.create_task(self._start_maibot_connection())
|
||||
asyncio.create_task(napcat_server(self.plugin_config))
|
||||
@@ -238,10 +238,10 @@ class LauchNapcatAdapterHandler(BaseEventHandler):
|
||||
"""非阻塞方式启动MaiBot连接,等待主服务启动后再连接"""
|
||||
# 等待一段时间让MaiBot主服务完全启动
|
||||
await asyncio.sleep(5)
|
||||
|
||||
|
||||
max_attempts = 10
|
||||
attempt = 0
|
||||
|
||||
|
||||
while attempt < max_attempts:
|
||||
try:
|
||||
logger.info(f"尝试连接MaiBot (第{attempt + 1}次)")
|
||||
@@ -285,7 +285,7 @@ class NapcatAdapterPlugin(BasePlugin):
|
||||
def enable_plugin(self) -> bool:
|
||||
"""通过配置文件动态控制插件启用状态"""
|
||||
# 如果已经通过配置加载了状态,使用配置中的值
|
||||
if hasattr(self, '_is_enabled'):
|
||||
if hasattr(self, "_is_enabled"):
|
||||
return self._is_enabled
|
||||
# 否则使用默认值(禁用状态)
|
||||
return False
|
||||
@@ -308,60 +308,107 @@ class NapcatAdapterPlugin(BasePlugin):
|
||||
"nickname": ConfigField(type=str, default="", description="昵称配置(目前未使用)"),
|
||||
},
|
||||
"napcat_server": {
|
||||
"mode": ConfigField(type=str, default="reverse", description="连接模式:reverse=反向连接(作为服务器), forward=正向连接(作为客户端)", choices=["reverse", "forward"]),
|
||||
"mode": ConfigField(
|
||||
type=str,
|
||||
default="reverse",
|
||||
description="连接模式:reverse=反向连接(作为服务器), forward=正向连接(作为客户端)",
|
||||
choices=["reverse", "forward"],
|
||||
),
|
||||
"host": ConfigField(type=str, default="localhost", description="主机地址"),
|
||||
"port": ConfigField(type=int, default=8095, description="端口号"),
|
||||
"url": ConfigField(type=str, default="", description="正向连接时的完整WebSocket URL,如 ws://localhost:8080/ws (仅在forward模式下使用)"),
|
||||
"access_token": ConfigField(type=str, default="", description="WebSocket 连接的访问令牌,用于身份验证(可选)"),
|
||||
"url": ConfigField(
|
||||
type=str,
|
||||
default="",
|
||||
description="正向连接时的完整WebSocket URL,如 ws://localhost:8080/ws (仅在forward模式下使用)",
|
||||
),
|
||||
"access_token": ConfigField(
|
||||
type=str, default="", description="WebSocket 连接的访问令牌,用于身份验证(可选)"
|
||||
),
|
||||
"heartbeat_interval": ConfigField(type=int, default=30, description="心跳间隔时间(按秒计)"),
|
||||
},
|
||||
"maibot_server": {
|
||||
"host": ConfigField(type=str, default="localhost", description="麦麦在.env文件中设置的主机地址,即HOST字段"),
|
||||
"host": ConfigField(
|
||||
type=str, default="localhost", description="麦麦在.env文件中设置的主机地址,即HOST字段"
|
||||
),
|
||||
"port": ConfigField(type=int, default=8000, description="麦麦在.env文件中设置的端口,即PORT字段"),
|
||||
"platform_name": ConfigField(type=str, default="qq", description="平台名称,用于消息路由"),
|
||||
},
|
||||
"voice": {
|
||||
"use_tts": ConfigField(type=bool, default=False, description="是否使用tts语音(请确保你配置了tts并有对应的adapter)"),
|
||||
"use_tts": ConfigField(
|
||||
type=bool, default=False, description="是否使用tts语音(请确保你配置了tts并有对应的adapter)"
|
||||
),
|
||||
},
|
||||
"slicing": {
|
||||
"max_frame_size": ConfigField(type=int, default=64, description="WebSocket帧的最大大小,单位为字节,默认64KB"),
|
||||
"max_frame_size": ConfigField(
|
||||
type=int, default=64, description="WebSocket帧的最大大小,单位为字节,默认64KB"
|
||||
),
|
||||
"delay_ms": ConfigField(type=int, default=10, description="切片发送间隔时间,单位为毫秒"),
|
||||
},
|
||||
"debug": {
|
||||
"level": ConfigField(type=str, default="INFO", description="日志等级(DEBUG, INFO, WARNING, ERROR, CRITICAL)", choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]),
|
||||
"level": ConfigField(
|
||||
type=str,
|
||||
default="INFO",
|
||||
description="日志等级(DEBUG, INFO, WARNING, ERROR, CRITICAL)",
|
||||
choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
|
||||
),
|
||||
},
|
||||
"features": {
|
||||
# 权限设置
|
||||
"group_list_type": ConfigField(type=str, default="blacklist", description="群聊列表类型:whitelist(白名单)或 blacklist(黑名单)", choices=["whitelist", "blacklist"]),
|
||||
"group_list_type": ConfigField(
|
||||
type=str,
|
||||
default="blacklist",
|
||||
description="群聊列表类型:whitelist(白名单)或 blacklist(黑名单)",
|
||||
choices=["whitelist", "blacklist"],
|
||||
),
|
||||
"group_list": ConfigField(type=list, default=[], description="群聊ID列表"),
|
||||
"private_list_type": ConfigField(type=str, default="blacklist", description="私聊列表类型:whitelist(白名单)或 blacklist(黑名单)", choices=["whitelist", "blacklist"]),
|
||||
"private_list_type": ConfigField(
|
||||
type=str,
|
||||
default="blacklist",
|
||||
description="私聊列表类型:whitelist(白名单)或 blacklist(黑名单)",
|
||||
choices=["whitelist", "blacklist"],
|
||||
),
|
||||
"private_list": ConfigField(type=list, default=[], description="用户ID列表"),
|
||||
"ban_user_id": ConfigField(type=list, default=[], description="全局禁止用户ID列表,这些用户无法在任何地方使用机器人"),
|
||||
"ban_user_id": ConfigField(
|
||||
type=list, default=[], description="全局禁止用户ID列表,这些用户无法在任何地方使用机器人"
|
||||
),
|
||||
"ban_qq_bot": ConfigField(type=bool, default=False, description="是否屏蔽QQ官方机器人消息"),
|
||||
|
||||
# 聊天功能设置
|
||||
"enable_poke": ConfigField(type=bool, default=True, description="是否启用戳一戳功能"),
|
||||
"ignore_non_self_poke": ConfigField(type=bool, default=False, description="是否无视不是针对自己的戳一戳"),
|
||||
"poke_debounce_seconds": ConfigField(type=int, default=3, description="戳一戳防抖时间(秒),在指定时间内第二次针对机器人的戳一戳将被忽略"),
|
||||
"poke_debounce_seconds": ConfigField(
|
||||
type=int, default=3, description="戳一戳防抖时间(秒),在指定时间内第二次针对机器人的戳一戳将被忽略"
|
||||
),
|
||||
"enable_reply_at": ConfigField(type=bool, default=True, description="是否启用引用回复时艾特用户的功能"),
|
||||
"reply_at_rate": ConfigField(type=float, default=0.5, description="引用回复时艾特用户的几率 (0.0 ~ 1.0)"),
|
||||
"enable_emoji_like": ConfigField(type=bool, default=True, description="是否启用群聊表情回复功能"),
|
||||
|
||||
# 视频处理设置
|
||||
"enable_video_analysis": ConfigField(type=bool, default=True, description="是否启用视频识别功能"),
|
||||
"max_video_size_mb": ConfigField(type=int, default=100, description="视频文件最大大小限制(MB)"),
|
||||
"download_timeout": ConfigField(type=int, default=60, description="视频下载超时时间(秒)"),
|
||||
"supported_formats": ConfigField(type=list, default=["mp4", "avi", "mov", "mkv", "flv", "wmv", "webm"], description="支持的视频格式"),
|
||||
|
||||
"supported_formats": ConfigField(
|
||||
type=list, default=["mp4", "avi", "mov", "mkv", "flv", "wmv", "webm"], description="支持的视频格式"
|
||||
),
|
||||
# 消息缓冲设置
|
||||
"enable_message_buffer": ConfigField(type=bool, default=True, description="是否启用消息缓冲合并功能"),
|
||||
"message_buffer_enable_group": ConfigField(type=bool, default=True, description="是否启用群聊消息缓冲合并"),
|
||||
"message_buffer_enable_private": ConfigField(type=bool, default=True, description="是否启用私聊消息缓冲合并"),
|
||||
"message_buffer_interval": ConfigField(type=float, default=3.0, description="消息合并间隔时间(秒),在此时间内的连续消息将被合并"),
|
||||
"message_buffer_initial_delay": ConfigField(type=float, default=0.5, description="消息缓冲初始延迟(秒),收到第一条消息后等待此时间开始合并"),
|
||||
"message_buffer_max_components": ConfigField(type=int, default=50, description="单个会话最大缓冲消息组件数量,超过此数量将强制合并"),
|
||||
"message_buffer_block_prefixes": ConfigField(type=list, default=["/", "!", "!", ".", "。", "#", "%"], description="消息缓冲屏蔽前缀,以这些前缀开头的消息不会被缓冲"),
|
||||
}
|
||||
"message_buffer_enable_private": ConfigField(
|
||||
type=bool, default=True, description="是否启用私聊消息缓冲合并"
|
||||
),
|
||||
"message_buffer_interval": ConfigField(
|
||||
type=float, default=3.0, description="消息合并间隔时间(秒),在此时间内的连续消息将被合并"
|
||||
),
|
||||
"message_buffer_initial_delay": ConfigField(
|
||||
type=float, default=0.5, description="消息缓冲初始延迟(秒),收到第一条消息后等待此时间开始合并"
|
||||
),
|
||||
"message_buffer_max_components": ConfigField(
|
||||
type=int, default=50, description="单个会话最大缓冲消息组件数量,超过此数量将强制合并"
|
||||
),
|
||||
"message_buffer_block_prefixes": ConfigField(
|
||||
type=list,
|
||||
default=["/", "!", "!", ".", "。", "#", "%"],
|
||||
description="消息缓冲屏蔽前缀,以这些前缀开头的消息不会被缓冲",
|
||||
),
|
||||
},
|
||||
}
|
||||
|
||||
# 配置节描述
|
||||
@@ -374,7 +421,7 @@ class NapcatAdapterPlugin(BasePlugin):
|
||||
"voice": "发送语音设置",
|
||||
"slicing": "WebSocket消息切片设置",
|
||||
"debug": "调试设置",
|
||||
"features": "功能设置(权限控制、聊天功能、视频处理、消息缓冲等)"
|
||||
"features": "功能设置(权限控制、聊天功能、视频处理、消息缓冲等)",
|
||||
}
|
||||
|
||||
def register_events(self):
|
||||
@@ -409,6 +456,7 @@ class NapcatAdapterPlugin(BasePlugin):
|
||||
chunker.set_plugin_config(self.config)
|
||||
# 设置response_pool的插件配置
|
||||
from .src.response_pool import set_plugin_config as set_response_pool_config
|
||||
|
||||
set_response_pool_config(self.config)
|
||||
# 设置send_handler的插件配置
|
||||
send_handler.set_plugin_config(self.config)
|
||||
@@ -418,4 +466,4 @@ class NapcatAdapterPlugin(BasePlugin):
|
||||
notice_handler.set_plugin_config(self.config)
|
||||
# 设置meta_event_handler的插件配置
|
||||
meta_event_handler.set_plugin_config(self.config)
|
||||
# 设置其他handler的插件配置(现在由component_registry在注册时自动设置)
|
||||
# 设置其他handler的插件配置(现在由component_registry在注册时自动设置)
|
||||
|
||||
@@ -102,7 +102,9 @@ class SimpleMessageBuffer:
|
||||
return True
|
||||
|
||||
# 检查屏蔽前缀
|
||||
block_prefixes = tuple(config_api.get_plugin_config(self.plugin_config, "features.message_buffer_block_prefixes", []))
|
||||
block_prefixes = tuple(
|
||||
config_api.get_plugin_config(self.plugin_config, "features.message_buffer_block_prefixes", [])
|
||||
)
|
||||
|
||||
text = text.strip()
|
||||
if text.startswith(block_prefixes):
|
||||
@@ -134,9 +136,13 @@ class SimpleMessageBuffer:
|
||||
|
||||
# 检查是否启用对应类型的缓冲
|
||||
message_type = event_data.get("message_type", "")
|
||||
if message_type == "group" and not config_api.get_plugin_config(self.plugin_config, "features.message_buffer_enable_group", False):
|
||||
if message_type == "group" and not config_api.get_plugin_config(
|
||||
self.plugin_config, "features.message_buffer_enable_group", False
|
||||
):
|
||||
return False
|
||||
elif message_type == "private" and not config_api.get_plugin_config(self.plugin_config, "features.message_buffer_enable_private", False):
|
||||
elif message_type == "private" and not config_api.get_plugin_config(
|
||||
self.plugin_config, "features.message_buffer_enable_private", False
|
||||
):
|
||||
return False
|
||||
|
||||
# 提取文本
|
||||
@@ -158,7 +164,9 @@ class SimpleMessageBuffer:
|
||||
session = self.buffer_pool[session_id]
|
||||
|
||||
# 检查是否超过最大组件数量
|
||||
if len(session.messages) >= config_api.get_plugin_config(self.plugin_config, "features.message_buffer_max_components", 5):
|
||||
if len(session.messages) >= config_api.get_plugin_config(
|
||||
self.plugin_config, "features.message_buffer_max_components", 5
|
||||
):
|
||||
logger.debug(f"会话 {session_id} 消息数量达到上限,强制合并")
|
||||
asyncio.create_task(self._force_merge_session(session_id))
|
||||
self.buffer_pool[session_id] = BufferedSession(session_id=session_id, original_event=original_event)
|
||||
|
||||
@@ -14,7 +14,7 @@ def create_router(plugin_config: dict):
|
||||
platform_name = config_api.get_plugin_config(plugin_config, "maibot_server.platform_name", "qq")
|
||||
host = config_api.get_plugin_config(plugin_config, "maibot_server.host", "localhost")
|
||||
port = config_api.get_plugin_config(plugin_config, "maibot_server.port", 8000)
|
||||
|
||||
|
||||
route_config = RouteConfig(
|
||||
route_config={
|
||||
platform_name: TargetConfig(
|
||||
@@ -32,7 +32,7 @@ async def mmc_start_com(plugin_config: dict = None):
|
||||
logger.info("正在连接MaiBot")
|
||||
if plugin_config:
|
||||
create_router(plugin_config)
|
||||
|
||||
|
||||
if router:
|
||||
router.register_class_handler(send_handler.handle_message)
|
||||
await router.run()
|
||||
|
||||
@@ -32,7 +32,7 @@ class NoticeType: # 通知事件
|
||||
group_recall = "group_recall" # 群聊消息撤回
|
||||
notify = "notify"
|
||||
group_ban = "group_ban" # 群禁言
|
||||
group_msg_emoji_like = "group_msg_emoji_like" # 群聊表情回复
|
||||
group_msg_emoji_like = "group_msg_emoji_like" # 群聊表情回复
|
||||
|
||||
class Notify:
|
||||
poke = "poke" # 戳一戳
|
||||
|
||||
@@ -100,7 +100,7 @@ class MessageHandler:
|
||||
# 检查群聊黑白名单
|
||||
group_list_type = config_api.get_plugin_config(self.plugin_config, "features.group_list_type", "blacklist")
|
||||
group_list = config_api.get_plugin_config(self.plugin_config, "features.group_list", [])
|
||||
|
||||
|
||||
if group_list_type == "whitelist":
|
||||
if group_id not in group_list:
|
||||
logger.warning("群聊不在白名单中,消息被丢弃")
|
||||
@@ -111,9 +111,11 @@ class MessageHandler:
|
||||
return False
|
||||
else:
|
||||
# 检查私聊黑白名单
|
||||
private_list_type = config_api.get_plugin_config(self.plugin_config, "features.private_list_type", "blacklist")
|
||||
private_list_type = config_api.get_plugin_config(
|
||||
self.plugin_config, "features.private_list_type", "blacklist"
|
||||
)
|
||||
private_list = config_api.get_plugin_config(self.plugin_config, "features.private_list", [])
|
||||
|
||||
|
||||
if private_list_type == "whitelist":
|
||||
if user_id not in private_list:
|
||||
logger.warning("私聊不在白名单中,消息被丢弃")
|
||||
@@ -156,21 +158,23 @@ class MessageHandler:
|
||||
Parameters:
|
||||
raw_message: dict: 原始消息
|
||||
"""
|
||||
|
||||
|
||||
# 添加原始消息调试日志,特别关注message字段
|
||||
logger.debug(f"收到原始消息: message_type={raw_message.get('message_type')}, message_id={raw_message.get('message_id')}")
|
||||
logger.debug(
|
||||
f"收到原始消息: message_type={raw_message.get('message_type')}, message_id={raw_message.get('message_id')}"
|
||||
)
|
||||
logger.debug(f"原始消息内容: {raw_message.get('message', [])}")
|
||||
|
||||
|
||||
# 检查是否包含@或video消息段
|
||||
message_segments = raw_message.get('message', [])
|
||||
message_segments = raw_message.get("message", [])
|
||||
if message_segments:
|
||||
for i, seg in enumerate(message_segments):
|
||||
seg_type = seg.get('type')
|
||||
if seg_type in ['at', 'video']:
|
||||
seg_type = seg.get("type")
|
||||
if seg_type in ["at", "video"]:
|
||||
logger.info(f"检测到 {seg_type.upper()} 消息段 [{i}]: {seg}")
|
||||
elif seg_type not in ['text', 'face', 'image']:
|
||||
elif seg_type not in ["text", "face", "image"]:
|
||||
logger.warning(f"检测到特殊消息段 [{i}]: type={seg_type}, data={seg.get('data', {})}")
|
||||
|
||||
|
||||
message_type: str = raw_message.get("message_type")
|
||||
message_id: int = raw_message.get("message_id")
|
||||
# message_time: int = raw_message.get("time")
|
||||
@@ -308,9 +312,13 @@ class MessageHandler:
|
||||
message_type = raw_message.get("message_type")
|
||||
should_use_buffer = False
|
||||
|
||||
if message_type == "group" and config_api.get_plugin_config(self.plugin_config, "features.message_buffer_enable_group", True):
|
||||
if message_type == "group" and config_api.get_plugin_config(
|
||||
self.plugin_config, "features.message_buffer_enable_group", True
|
||||
):
|
||||
should_use_buffer = True
|
||||
elif message_type == "private" and config_api.get_plugin_config(self.plugin_config, "features.message_buffer_enable_private", True):
|
||||
elif message_type == "private" and config_api.get_plugin_config(
|
||||
self.plugin_config, "features.message_buffer_enable_private", True
|
||||
):
|
||||
should_use_buffer = True
|
||||
|
||||
if should_use_buffer:
|
||||
@@ -368,10 +376,10 @@ class MessageHandler:
|
||||
for sub_message in real_message:
|
||||
sub_message: dict
|
||||
sub_message_type = sub_message.get("type")
|
||||
|
||||
|
||||
# 添加详细的消息类型调试信息
|
||||
logger.debug(f"处理消息段: type={sub_message_type}, data={sub_message.get('data', {})}")
|
||||
|
||||
|
||||
# 特别关注 at 和 video 消息的识别
|
||||
if sub_message_type == "at":
|
||||
logger.debug(f"检测到@消息: {sub_message}")
|
||||
@@ -379,7 +387,7 @@ class MessageHandler:
|
||||
logger.debug(f"检测到VIDEO消息: {sub_message}")
|
||||
elif sub_message_type not in ["text", "face", "image", "record"]:
|
||||
logger.warning(f"检测到特殊消息类型: {sub_message_type}, 完整消息: {sub_message}")
|
||||
|
||||
|
||||
match sub_message_type:
|
||||
case RealMessageType.text:
|
||||
ret_seg = await self.handle_text_message(sub_message)
|
||||
|
||||
@@ -33,6 +33,7 @@ class MessageSending:
|
||||
try:
|
||||
# 重新导入router
|
||||
from ..mmc_com_layer import router
|
||||
|
||||
self.maibot_router = router
|
||||
if self.maibot_router is not None:
|
||||
logger.info("MaiBot router重连成功")
|
||||
@@ -73,14 +74,14 @@ class MessageSending:
|
||||
|
||||
# 获取对应的客户端并发送切片
|
||||
platform = message_base.message_info.platform
|
||||
|
||||
|
||||
# 再次检查router状态(防止运行时被重置)
|
||||
if self.maibot_router is None or not hasattr(self.maibot_router, 'clients'):
|
||||
if self.maibot_router is None or not hasattr(self.maibot_router, "clients"):
|
||||
logger.warning("MaiBot router连接已断开,尝试重新连接")
|
||||
if not await self._attempt_reconnect():
|
||||
logger.error("MaiBot router重连失败,切片发送中止")
|
||||
return False
|
||||
|
||||
|
||||
if platform not in self.maibot_router.clients:
|
||||
logger.error(f"平台 {platform} 未连接")
|
||||
return False
|
||||
|
||||
@@ -22,7 +22,9 @@ class MetaEventHandler:
|
||||
"""设置插件配置"""
|
||||
self.plugin_config = plugin_config
|
||||
# 更新interval值
|
||||
self.interval = config_api.get_plugin_config(self.plugin_config, "napcat_server.heartbeat_interval", 5000) / 1000
|
||||
self.interval = (
|
||||
config_api.get_plugin_config(self.plugin_config, "napcat_server.heartbeat_interval", 5000) / 1000
|
||||
)
|
||||
|
||||
async def handle_meta_event(self, message: dict) -> None:
|
||||
event_type = message.get("meta_event_type")
|
||||
|
||||
@@ -116,9 +116,9 @@ class NoticeHandler:
|
||||
sub_type = raw_message.get("sub_type")
|
||||
match sub_type:
|
||||
case NoticeType.Notify.poke:
|
||||
if config_api.get_plugin_config(self.plugin_config, "features.enable_poke", True) and await message_handler.check_allow_to_chat(
|
||||
user_id, group_id, False, False
|
||||
):
|
||||
if config_api.get_plugin_config(
|
||||
self.plugin_config, "features.enable_poke", True
|
||||
) and await message_handler.check_allow_to_chat(user_id, group_id, False, False):
|
||||
logger.debug("处理戳一戳消息")
|
||||
handled_message, user_info = await self.handle_poke_notify(raw_message, group_id, user_id)
|
||||
else:
|
||||
@@ -127,14 +127,18 @@ class NoticeHandler:
|
||||
from src.plugin_system.core.event_manager import event_manager
|
||||
from ...event_types import NapcatEvent
|
||||
|
||||
await event_manager.trigger_event(NapcatEvent.ON_RECEIVED.FRIEND_INPUT, permission_group=PLUGIN_NAME)
|
||||
await event_manager.trigger_event(
|
||||
NapcatEvent.ON_RECEIVED.FRIEND_INPUT, permission_group=PLUGIN_NAME
|
||||
)
|
||||
case _:
|
||||
logger.warning(f"不支持的notify类型: {notice_type}.{sub_type}")
|
||||
case NoticeType.group_msg_emoji_like:
|
||||
case NoticeType.group_msg_emoji_like:
|
||||
# 该事件转移到 handle_group_emoji_like_notify函数内触发
|
||||
if config_api.get_plugin_config(self.plugin_config, "features.enable_emoji_like", True):
|
||||
logger.debug("处理群聊表情回复")
|
||||
handled_message, user_info = await self.handle_group_emoji_like_notify(raw_message,group_id,user_id)
|
||||
handled_message, user_info = await self.handle_group_emoji_like_notify(
|
||||
raw_message, group_id, user_id
|
||||
)
|
||||
else:
|
||||
logger.warning("群聊表情回复被禁用,取消群聊表情回复处理")
|
||||
case NoticeType.group_ban:
|
||||
@@ -294,7 +298,7 @@ class NoticeHandler:
|
||||
async def handle_group_emoji_like_notify(self, raw_message: dict, group_id: int, user_id: int):
|
||||
if not group_id:
|
||||
logger.error("群ID不能为空,无法处理群聊表情回复通知")
|
||||
return None, None
|
||||
return None, None
|
||||
|
||||
user_qq_info: dict = await get_member_info(self.get_server_connection(), group_id, user_id)
|
||||
if user_qq_info:
|
||||
@@ -304,37 +308,42 @@ class NoticeHandler:
|
||||
user_name = "QQ用户"
|
||||
user_cardname = "QQ用户"
|
||||
logger.debug("无法获取表情回复对方的用户昵称")
|
||||
|
||||
|
||||
from src.plugin_system.core.event_manager import event_manager
|
||||
from ...event_types import NapcatEvent
|
||||
|
||||
target_message = await event_manager.trigger_event(NapcatEvent.MESSAGE.GET_MSG,message_id=raw_message.get("message_id",""))
|
||||
target_message_text = target_message.get_message_result().get("data",{}).get("raw_message","")
|
||||
target_message = await event_manager.trigger_event(
|
||||
NapcatEvent.MESSAGE.GET_MSG, message_id=raw_message.get("message_id", "")
|
||||
)
|
||||
target_message_text = target_message.get_message_result().get("data", {}).get("raw_message", "")
|
||||
if not target_message:
|
||||
logger.error("未找到对应消息")
|
||||
return None, None
|
||||
if len(target_message_text) > 15:
|
||||
target_message_text = target_message_text[:15] + "..."
|
||||
|
||||
|
||||
user_info: UserInfo = UserInfo(
|
||||
platform=config_api.get_plugin_config(self.plugin_config, "maibot_server.platform_name", "qq"),
|
||||
user_id=user_id,
|
||||
user_nickname=user_name,
|
||||
user_cardname=user_cardname,
|
||||
)
|
||||
|
||||
|
||||
like_emoji_id = raw_message.get("likes")[0].get("emoji_id")
|
||||
await event_manager.trigger_event(
|
||||
NapcatEvent.ON_RECEIVED.EMOJI_LIEK,
|
||||
permission_group=PLUGIN_NAME,
|
||||
group_id=group_id,
|
||||
user_id=user_id,
|
||||
message_id=raw_message.get("message_id",""),
|
||||
emoji_id=like_emoji_id
|
||||
)
|
||||
seg_data = Seg(type="text",data=f"{user_name}使用Emoji表情{QQ_FACE.get(like_emoji_id,"")}回复了你的消息[{target_message_text}]")
|
||||
NapcatEvent.ON_RECEIVED.EMOJI_LIEK,
|
||||
permission_group=PLUGIN_NAME,
|
||||
group_id=group_id,
|
||||
user_id=user_id,
|
||||
message_id=raw_message.get("message_id", ""),
|
||||
emoji_id=like_emoji_id,
|
||||
)
|
||||
seg_data = Seg(
|
||||
type="text",
|
||||
data=f"{user_name}使用Emoji表情{QQ_FACE.get(like_emoji_id, '')}回复了你的消息[{target_message_text}]",
|
||||
)
|
||||
return seg_data, user_info
|
||||
|
||||
|
||||
async def handle_ban_notify(self, raw_message: dict, group_id: int) -> Tuple[Seg, UserInfo] | Tuple[None, None]:
|
||||
if not group_id:
|
||||
logger.error("群ID不能为空,无法处理禁言通知")
|
||||
|
||||
@@ -45,12 +45,12 @@ async def check_timeout_response() -> None:
|
||||
while True:
|
||||
cleaned_message_count: int = 0
|
||||
now_time = time.time()
|
||||
|
||||
|
||||
# 获取心跳间隔配置
|
||||
heartbeat_interval = 30 # 默认值
|
||||
if plugin_config:
|
||||
heartbeat_interval = config_api.get_plugin_config(plugin_config, "napcat_server.heartbeat_interval", 30)
|
||||
|
||||
|
||||
for echo_id, response_time in list(response_time_dict.items()):
|
||||
if now_time - response_time > heartbeat_interval:
|
||||
cleaned_message_count += 1
|
||||
|
||||
@@ -297,9 +297,9 @@ class SendHandler:
|
||||
|
||||
try:
|
||||
# 检查是否为缓冲消息ID(格式:buffered-{original_id}-{timestamp})
|
||||
if id.startswith('buffered-'):
|
||||
if id.startswith("buffered-"):
|
||||
# 从缓冲消息ID中提取原始消息ID
|
||||
original_id = id.split('-')[1]
|
||||
original_id = id.split("-")[1]
|
||||
msg_info_response = await self.send_message_to_napcat("get_msg", {"message_id": int(original_id)})
|
||||
else:
|
||||
msg_info_response = await self.send_message_to_napcat("get_msg", {"message_id": int(id)})
|
||||
@@ -363,7 +363,7 @@ class SendHandler:
|
||||
use_tts = False
|
||||
if self.plugin_config:
|
||||
use_tts = config_api.get_plugin_config(self.plugin_config, "voice.use_tts", False)
|
||||
|
||||
|
||||
if not use_tts:
|
||||
logger.warning("未启用语音消息处理")
|
||||
return {}
|
||||
|
||||
@@ -18,7 +18,9 @@ class WebSocketManager:
|
||||
self.max_reconnect_attempts = 10 # 最大重连次数
|
||||
self.plugin_config = None
|
||||
|
||||
async def start_connection(self, message_handler: Callable[[Server.ServerConnection], Any], plugin_config: dict) -> None:
|
||||
async def start_connection(
|
||||
self, message_handler: Callable[[Server.ServerConnection], Any], plugin_config: dict
|
||||
) -> None:
|
||||
"""根据配置启动 WebSocket 连接"""
|
||||
self.plugin_config = plugin_config
|
||||
mode = config_api.get_plugin_config(plugin_config, "napcat_server.mode")
|
||||
@@ -72,9 +74,7 @@ class WebSocketManager:
|
||||
# 如果配置了访问令牌,添加到请求头
|
||||
access_token = config_api.get_plugin_config(self.plugin_config, "napcat_server.access_token")
|
||||
if access_token:
|
||||
connect_kwargs["additional_headers"] = {
|
||||
"Authorization": f"Bearer {access_token}"
|
||||
}
|
||||
connect_kwargs["additional_headers"] = {"Authorization": f"Bearer {access_token}"}
|
||||
logger.info("已添加访问令牌到连接请求头")
|
||||
|
||||
async with Server.connect(url, **connect_kwargs) as websocket:
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
"""
|
||||
Base search engine interface
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, List, Any
|
||||
|
||||
@@ -9,20 +10,20 @@ class BaseSearchEngine(ABC):
|
||||
"""
|
||||
搜索引擎基类
|
||||
"""
|
||||
|
||||
|
||||
@abstractmethod
|
||||
async def search(self, args: Dict[str, Any]) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
执行搜索
|
||||
|
||||
|
||||
Args:
|
||||
args: 搜索参数,包含 query、num_results、time_range 等
|
||||
|
||||
|
||||
Returns:
|
||||
搜索结果列表,每个结果包含 title、url、snippet、provider 字段
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
@abstractmethod
|
||||
def is_available(self) -> bool:
|
||||
"""
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
"""
|
||||
Bing search engine implementation
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import functools
|
||||
import random
|
||||
@@ -58,21 +59,21 @@ class BingSearchEngine(BaseSearchEngine):
|
||||
"""
|
||||
Bing搜索引擎实现
|
||||
"""
|
||||
|
||||
|
||||
def __init__(self):
|
||||
self.session = requests.Session()
|
||||
self.session.headers = HEADERS
|
||||
|
||||
|
||||
def is_available(self) -> bool:
|
||||
"""检查Bing搜索引擎是否可用"""
|
||||
return True # Bing是免费搜索引擎,总是可用
|
||||
|
||||
|
||||
async def search(self, args: Dict[str, Any]) -> List[Dict[str, Any]]:
|
||||
"""执行Bing搜索"""
|
||||
query = args["query"]
|
||||
num_results = args.get("num_results", 3)
|
||||
time_range = args.get("time_range", "any")
|
||||
|
||||
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
func = functools.partial(self._search_sync, query, num_results, time_range)
|
||||
@@ -81,17 +82,17 @@ class BingSearchEngine(BaseSearchEngine):
|
||||
except Exception as e:
|
||||
logger.error(f"Bing 搜索失败: {e}")
|
||||
return []
|
||||
|
||||
|
||||
def _search_sync(self, keyword: str, num_results: int, time_range: str) -> List[Dict[str, Any]]:
|
||||
"""同步执行Bing搜索"""
|
||||
if not keyword:
|
||||
return []
|
||||
|
||||
list_result = []
|
||||
|
||||
|
||||
# 构建搜索URL
|
||||
search_url = bing_search_url + keyword
|
||||
|
||||
|
||||
# 如果指定了时间范围,添加时间过滤参数
|
||||
if time_range == "week":
|
||||
search_url += "&qft=+filterui:date-range-7"
|
||||
@@ -181,34 +182,29 @@ class BingSearchEngine(BaseSearchEngine):
|
||||
# 尝试提取搜索结果
|
||||
# 方法1: 查找标准的搜索结果容器
|
||||
results = root.select("ol#b_results li.b_algo")
|
||||
|
||||
|
||||
if results:
|
||||
for _rank, result in enumerate(results, 1):
|
||||
# 提取标题和链接
|
||||
title_link = result.select_one("h2 a")
|
||||
if not title_link:
|
||||
continue
|
||||
|
||||
|
||||
title = title_link.get_text().strip()
|
||||
url = title_link.get("href", "")
|
||||
|
||||
|
||||
# 提取摘要
|
||||
abstract = ""
|
||||
abstract_elem = result.select_one("div.b_caption p")
|
||||
if abstract_elem:
|
||||
abstract = abstract_elem.get_text().strip()
|
||||
|
||||
|
||||
# 限制摘要长度
|
||||
if ABSTRACT_MAX_LENGTH and len(abstract) > ABSTRACT_MAX_LENGTH:
|
||||
abstract = abstract[:ABSTRACT_MAX_LENGTH] + "..."
|
||||
|
||||
list_data.append({
|
||||
"title": title,
|
||||
"url": url,
|
||||
"snippet": abstract,
|
||||
"provider": "Bing"
|
||||
})
|
||||
|
||||
|
||||
list_data.append({"title": title, "url": url, "snippet": abstract, "provider": "Bing"})
|
||||
|
||||
if len(list_data) >= 10: # 限制结果数量
|
||||
break
|
||||
|
||||
@@ -216,22 +212,34 @@ class BingSearchEngine(BaseSearchEngine):
|
||||
if not list_data:
|
||||
# 查找所有可能的搜索结果链接
|
||||
all_links = root.find_all("a")
|
||||
|
||||
|
||||
for link in all_links:
|
||||
href = link.get("href", "")
|
||||
text = link.get_text().strip()
|
||||
|
||||
|
||||
# 过滤有效的搜索结果链接
|
||||
if (href and text and len(text) > 10
|
||||
if (
|
||||
href
|
||||
and text
|
||||
and len(text) > 10
|
||||
and not href.startswith("javascript:")
|
||||
and not href.startswith("#")
|
||||
and "http" in href
|
||||
and not any(x in href for x in [
|
||||
"bing.com/search", "bing.com/images", "bing.com/videos",
|
||||
"bing.com/maps", "bing.com/news", "login", "account",
|
||||
"microsoft", "javascript"
|
||||
])):
|
||||
|
||||
and not any(
|
||||
x in href
|
||||
for x in [
|
||||
"bing.com/search",
|
||||
"bing.com/images",
|
||||
"bing.com/videos",
|
||||
"bing.com/maps",
|
||||
"bing.com/news",
|
||||
"login",
|
||||
"account",
|
||||
"microsoft",
|
||||
"javascript",
|
||||
]
|
||||
)
|
||||
):
|
||||
# 尝试获取摘要
|
||||
abstract = ""
|
||||
parent = link.parent
|
||||
@@ -239,18 +247,13 @@ class BingSearchEngine(BaseSearchEngine):
|
||||
full_text = parent.get_text().strip()
|
||||
if len(full_text) > len(text):
|
||||
abstract = full_text.replace(text, "", 1).strip()
|
||||
|
||||
|
||||
# 限制摘要长度
|
||||
if ABSTRACT_MAX_LENGTH and len(abstract) > ABSTRACT_MAX_LENGTH:
|
||||
abstract = abstract[:ABSTRACT_MAX_LENGTH] + "..."
|
||||
|
||||
list_data.append({
|
||||
"title": text,
|
||||
"url": href,
|
||||
"snippet": abstract,
|
||||
"provider": "Bing"
|
||||
})
|
||||
|
||||
|
||||
list_data.append({"title": text, "url": href, "snippet": abstract, "provider": "Bing"})
|
||||
|
||||
if len(list_data) >= 10:
|
||||
break
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
"""
|
||||
DuckDuckGo search engine implementation
|
||||
"""
|
||||
|
||||
from typing import Dict, List, Any
|
||||
from asyncddgs import aDDGS
|
||||
|
||||
@@ -14,27 +15,22 @@ class DDGSearchEngine(BaseSearchEngine):
|
||||
"""
|
||||
DuckDuckGo搜索引擎实现
|
||||
"""
|
||||
|
||||
|
||||
def is_available(self) -> bool:
|
||||
"""检查DuckDuckGo搜索引擎是否可用"""
|
||||
return True # DuckDuckGo不需要API密钥,总是可用
|
||||
|
||||
|
||||
async def search(self, args: Dict[str, Any]) -> List[Dict[str, Any]]:
|
||||
"""执行DuckDuckGo搜索"""
|
||||
query = args["query"]
|
||||
num_results = args.get("num_results", 3)
|
||||
|
||||
|
||||
try:
|
||||
async with aDDGS() as ddgs:
|
||||
search_response = await ddgs.text(query, max_results=num_results)
|
||||
|
||||
|
||||
return [
|
||||
{
|
||||
"title": r.get("title"),
|
||||
"url": r.get("href"),
|
||||
"snippet": r.get("body"),
|
||||
"provider": "DuckDuckGo"
|
||||
}
|
||||
{"title": r.get("title"), "url": r.get("href"), "snippet": r.get("body"), "provider": "DuckDuckGo"}
|
||||
for r in search_response
|
||||
]
|
||||
except Exception as e:
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
"""
|
||||
Exa search engine implementation
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import functools
|
||||
from datetime import datetime, timedelta
|
||||
@@ -19,31 +20,27 @@ class ExaSearchEngine(BaseSearchEngine):
|
||||
"""
|
||||
Exa搜索引擎实现
|
||||
"""
|
||||
|
||||
|
||||
def __init__(self):
|
||||
self._initialize_clients()
|
||||
|
||||
|
||||
def _initialize_clients(self):
|
||||
"""初始化Exa客户端"""
|
||||
# 从主配置文件读取API密钥
|
||||
exa_api_keys = config_api.get_global_config("web_search.exa_api_keys", None)
|
||||
|
||||
|
||||
# 创建API密钥管理器
|
||||
self.api_manager = create_api_key_manager_from_config(
|
||||
exa_api_keys,
|
||||
lambda key: Exa(api_key=key),
|
||||
"Exa"
|
||||
)
|
||||
|
||||
self.api_manager = create_api_key_manager_from_config(exa_api_keys, lambda key: Exa(api_key=key), "Exa")
|
||||
|
||||
def is_available(self) -> bool:
|
||||
"""检查Exa搜索引擎是否可用"""
|
||||
return self.api_manager.is_available()
|
||||
|
||||
|
||||
async def search(self, args: Dict[str, Any]) -> List[Dict[str, Any]]:
|
||||
"""执行Exa搜索"""
|
||||
if not self.is_available():
|
||||
return []
|
||||
|
||||
|
||||
query = args["query"]
|
||||
num_results = args.get("num_results", 3)
|
||||
time_range = args.get("time_range", "any")
|
||||
@@ -52,7 +49,7 @@ class ExaSearchEngine(BaseSearchEngine):
|
||||
if time_range != "any":
|
||||
today = datetime.now()
|
||||
start_date = today - timedelta(days=7 if time_range == "week" else 30)
|
||||
exa_args["start_published_date"] = start_date.strftime('%Y-%m-%d')
|
||||
exa_args["start_published_date"] = start_date.strftime("%Y-%m-%d")
|
||||
|
||||
try:
|
||||
# 使用API密钥管理器获取下一个客户端
|
||||
@@ -60,17 +57,17 @@ class ExaSearchEngine(BaseSearchEngine):
|
||||
if not exa_client:
|
||||
logger.error("无法获取Exa客户端")
|
||||
return []
|
||||
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
func = functools.partial(exa_client.search_and_contents, query, **exa_args)
|
||||
search_response = await loop.run_in_executor(None, func)
|
||||
|
||||
|
||||
return [
|
||||
{
|
||||
"title": res.title,
|
||||
"url": res.url,
|
||||
"snippet": " ".join(getattr(res, 'highlights', [])) or (getattr(res, 'text', '')[:250] + '...'),
|
||||
"provider": "Exa"
|
||||
"snippet": " ".join(getattr(res, "highlights", [])) or (getattr(res, "text", "")[:250] + "..."),
|
||||
"provider": "Exa",
|
||||
}
|
||||
for res in search_response.results
|
||||
]
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
"""
|
||||
Tavily search engine implementation
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import functools
|
||||
from typing import Dict, List, Any
|
||||
@@ -18,31 +19,29 @@ class TavilySearchEngine(BaseSearchEngine):
|
||||
"""
|
||||
Tavily搜索引擎实现
|
||||
"""
|
||||
|
||||
|
||||
def __init__(self):
|
||||
self._initialize_clients()
|
||||
|
||||
|
||||
def _initialize_clients(self):
|
||||
"""初始化Tavily客户端"""
|
||||
# 从主配置文件读取API密钥
|
||||
tavily_api_keys = config_api.get_global_config("web_search.tavily_api_keys", None)
|
||||
|
||||
|
||||
# 创建API密钥管理器
|
||||
self.api_manager = create_api_key_manager_from_config(
|
||||
tavily_api_keys,
|
||||
lambda key: TavilyClient(api_key=key),
|
||||
"Tavily"
|
||||
tavily_api_keys, lambda key: TavilyClient(api_key=key), "Tavily"
|
||||
)
|
||||
|
||||
|
||||
def is_available(self) -> bool:
|
||||
"""检查Tavily搜索引擎是否可用"""
|
||||
return self.api_manager.is_available()
|
||||
|
||||
|
||||
async def search(self, args: Dict[str, Any]) -> List[Dict[str, Any]]:
|
||||
"""执行Tavily搜索"""
|
||||
if not self.is_available():
|
||||
return []
|
||||
|
||||
|
||||
query = args["query"]
|
||||
num_results = args.get("num_results", 3)
|
||||
time_range = args.get("time_range", "any")
|
||||
@@ -53,38 +52,40 @@ class TavilySearchEngine(BaseSearchEngine):
|
||||
if not tavily_client:
|
||||
logger.error("无法获取Tavily客户端")
|
||||
return []
|
||||
|
||||
|
||||
# 构建Tavily搜索参数
|
||||
search_params = {
|
||||
"query": query,
|
||||
"max_results": num_results,
|
||||
"search_depth": "basic",
|
||||
"include_answer": False,
|
||||
"include_raw_content": False
|
||||
"include_raw_content": False,
|
||||
}
|
||||
|
||||
|
||||
# 根据时间范围调整搜索参数
|
||||
if time_range == "week":
|
||||
search_params["days"] = 7
|
||||
elif time_range == "month":
|
||||
search_params["days"] = 30
|
||||
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
func = functools.partial(tavily_client.search, **search_params)
|
||||
search_response = await loop.run_in_executor(None, func)
|
||||
|
||||
|
||||
results = []
|
||||
if search_response and "results" in search_response:
|
||||
for res in search_response["results"]:
|
||||
results.append({
|
||||
"title": res.get("title", "无标题"),
|
||||
"url": res.get("url", ""),
|
||||
"snippet": res.get("content", "")[:300] + "..." if res.get("content") else "无摘要",
|
||||
"provider": "Tavily"
|
||||
})
|
||||
|
||||
results.append(
|
||||
{
|
||||
"title": res.get("title", "无标题"),
|
||||
"url": res.get("url", ""),
|
||||
"snippet": res.get("content", "")[:300] + "..." if res.get("content") else "无摘要",
|
||||
"provider": "Tavily",
|
||||
}
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Tavily 搜索失败: {e}")
|
||||
return []
|
||||
|
||||
@@ -3,15 +3,10 @@ Web Search Tool Plugin
|
||||
|
||||
一个功能强大的网络搜索和URL解析插件,支持多种搜索引擎和解析策略。
|
||||
"""
|
||||
|
||||
from typing import List, Tuple, Type
|
||||
|
||||
from src.plugin_system import (
|
||||
BasePlugin,
|
||||
register_plugin,
|
||||
ComponentInfo,
|
||||
ConfigField,
|
||||
PythonDependency
|
||||
)
|
||||
from src.plugin_system import BasePlugin, register_plugin, ComponentInfo, ConfigField, PythonDependency
|
||||
from src.plugin_system.apis import config_api
|
||||
from src.common.logger import get_logger
|
||||
|
||||
@@ -25,7 +20,7 @@ logger = get_logger("web_search_plugin")
|
||||
class WEBSEARCHPLUGIN(BasePlugin):
|
||||
"""
|
||||
网络搜索工具插件
|
||||
|
||||
|
||||
提供网络搜索和URL解析功能,支持多种搜索引擎:
|
||||
- Exa (需要API密钥)
|
||||
- Tavily (需要API密钥)
|
||||
@@ -37,11 +32,11 @@ class WEBSEARCHPLUGIN(BasePlugin):
|
||||
plugin_name: str = "web_search_tool" # 内部标识符
|
||||
enable_plugin: bool = True
|
||||
dependencies: List[str] = [] # 插件依赖列表
|
||||
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
"""初始化插件,立即加载所有搜索引擎"""
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
|
||||
# 立即初始化所有搜索引擎,触发API密钥管理器的日志输出
|
||||
logger.info("🚀 正在初始化所有搜索引擎...")
|
||||
try:
|
||||
@@ -49,65 +44,58 @@ class WEBSEARCHPLUGIN(BasePlugin):
|
||||
from .engines.tavily_engine import TavilySearchEngine
|
||||
from .engines.ddg_engine import DDGSearchEngine
|
||||
from .engines.bing_engine import BingSearchEngine
|
||||
|
||||
|
||||
# 实例化所有搜索引擎,这会触发API密钥管理器的初始化
|
||||
exa_engine = ExaSearchEngine()
|
||||
tavily_engine = TavilySearchEngine()
|
||||
ddg_engine = DDGSearchEngine()
|
||||
bing_engine = BingSearchEngine()
|
||||
|
||||
|
||||
# 报告每个引擎的状态
|
||||
engines_status = {
|
||||
"Exa": exa_engine.is_available(),
|
||||
"Tavily": tavily_engine.is_available(),
|
||||
"DuckDuckGo": ddg_engine.is_available(),
|
||||
"Bing": bing_engine.is_available()
|
||||
"Bing": bing_engine.is_available(),
|
||||
}
|
||||
|
||||
|
||||
available_engines = [name for name, available in engines_status.items() if available]
|
||||
unavailable_engines = [name for name, available in engines_status.items() if not available]
|
||||
|
||||
|
||||
if available_engines:
|
||||
logger.info(f"✅ 可用搜索引擎: {', '.join(available_engines)}")
|
||||
if unavailable_engines:
|
||||
logger.info(f"❌ 不可用搜索引擎: {', '.join(unavailable_engines)}")
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 搜索引擎初始化失败: {e}", exc_info=True)
|
||||
|
||||
|
||||
# Python包依赖列表
|
||||
python_dependencies: List[PythonDependency] = [
|
||||
PythonDependency(
|
||||
package_name="asyncddgs",
|
||||
description="异步DuckDuckGo搜索库",
|
||||
optional=False
|
||||
),
|
||||
PythonDependency(package_name="asyncddgs", description="异步DuckDuckGo搜索库", optional=False),
|
||||
PythonDependency(
|
||||
package_name="exa_py",
|
||||
description="Exa搜索API客户端库",
|
||||
optional=True # 如果没有API密钥,这个是可选的
|
||||
optional=True, # 如果没有API密钥,这个是可选的
|
||||
),
|
||||
PythonDependency(
|
||||
package_name="tavily",
|
||||
install_name="tavily-python", # 安装时使用这个名称
|
||||
description="Tavily搜索API客户端库",
|
||||
optional=True # 如果没有API密钥,这个是可选的
|
||||
optional=True, # 如果没有API密钥,这个是可选的
|
||||
),
|
||||
PythonDependency(
|
||||
package_name="httpx",
|
||||
version=">=0.20.0",
|
||||
install_name="httpx[socks]", # 安装时使用这个名称(包含可选依赖)
|
||||
description="支持SOCKS代理的HTTP客户端库",
|
||||
optional=False
|
||||
)
|
||||
optional=False,
|
||||
),
|
||||
]
|
||||
config_file_name: str = "config.toml" # 配置文件名
|
||||
|
||||
# 配置节描述
|
||||
config_section_descriptions = {
|
||||
"plugin": "插件基本信息",
|
||||
"proxy": "链接本地解析代理配置"
|
||||
}
|
||||
config_section_descriptions = {"plugin": "插件基本信息", "proxy": "链接本地解析代理配置"}
|
||||
|
||||
# 配置Schema定义
|
||||
# 注意:EXA配置和组件设置已迁移到主配置文件(bot_config.toml)的[exa]和[web_search]部分
|
||||
@@ -119,42 +107,32 @@ class WEBSEARCHPLUGIN(BasePlugin):
|
||||
},
|
||||
"proxy": {
|
||||
"http_proxy": ConfigField(
|
||||
type=str,
|
||||
default=None,
|
||||
description="HTTP代理地址,格式如: http://proxy.example.com:8080"
|
||||
type=str, default=None, description="HTTP代理地址,格式如: http://proxy.example.com:8080"
|
||||
),
|
||||
"https_proxy": ConfigField(
|
||||
type=str,
|
||||
default=None,
|
||||
description="HTTPS代理地址,格式如: http://proxy.example.com:8080"
|
||||
type=str, default=None, description="HTTPS代理地址,格式如: http://proxy.example.com:8080"
|
||||
),
|
||||
"socks5_proxy": ConfigField(
|
||||
type=str,
|
||||
default=None,
|
||||
description="SOCKS5代理地址,格式如: socks5://proxy.example.com:1080"
|
||||
type=str, default=None, description="SOCKS5代理地址,格式如: socks5://proxy.example.com:1080"
|
||||
),
|
||||
"enable_proxy": ConfigField(
|
||||
type=bool,
|
||||
default=False,
|
||||
description="是否启用代理"
|
||||
)
|
||||
"enable_proxy": ConfigField(type=bool, default=False, description="是否启用代理"),
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def get_plugin_components(self) -> List[Tuple[ComponentInfo, Type]]:
|
||||
"""
|
||||
获取插件组件列表
|
||||
|
||||
|
||||
Returns:
|
||||
组件信息和类型的元组列表
|
||||
"""
|
||||
enable_tool = []
|
||||
|
||||
|
||||
# 从主配置文件读取组件启用配置
|
||||
if config_api.get_global_config("web_search.enable_web_search_tool", True):
|
||||
enable_tool.append((WebSurfingTool.get_tool_info(), WebSurfingTool))
|
||||
|
||||
|
||||
if config_api.get_global_config("web_search.enable_url_tool", True):
|
||||
enable_tool.append((URLParserTool.get_tool_info(), URLParserTool))
|
||||
|
||||
|
||||
return enable_tool
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
"""
|
||||
URL parser tool implementation
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import functools
|
||||
from typing import Any, Dict
|
||||
@@ -24,17 +25,18 @@ class URLParserTool(BaseTool):
|
||||
"""
|
||||
一个用于解析和总结一个或多个网页URL内容的工具。
|
||||
"""
|
||||
|
||||
name: str = "parse_url"
|
||||
description: str = "当需要理解一个或多个特定网页链接的内容时,使用此工具。例如:'这些网页讲了什么?[https://example.com, https://example2.com]' 或 '帮我总结一下这些文章'"
|
||||
available_for_llm: bool = True
|
||||
parameters = [
|
||||
("urls", ToolParamType.STRING, "要理解的网站", True, None),
|
||||
]
|
||||
|
||||
|
||||
def __init__(self, plugin_config=None):
|
||||
super().__init__(plugin_config)
|
||||
self._initialize_exa_clients()
|
||||
|
||||
|
||||
def _initialize_exa_clients(self):
|
||||
"""初始化Exa客户端"""
|
||||
# 优先从主配置文件读取,如果没有则从插件配置文件读取
|
||||
@@ -42,12 +44,10 @@ class URLParserTool(BaseTool):
|
||||
if exa_api_keys is None:
|
||||
# 从插件配置文件读取
|
||||
exa_api_keys = self.get_config("exa.api_keys", [])
|
||||
|
||||
|
||||
# 创建API密钥管理器
|
||||
self.api_manager = create_api_key_manager_from_config(
|
||||
exa_api_keys,
|
||||
lambda key: Exa(api_key=key),
|
||||
"Exa URL Parser"
|
||||
exa_api_keys, lambda key: Exa(api_key=key), "Exa URL Parser"
|
||||
)
|
||||
|
||||
async def _local_parse_and_summarize(self, url: str) -> Dict[str, Any]:
|
||||
@@ -58,12 +58,12 @@ class URLParserTool(BaseTool):
|
||||
# 读取代理配置
|
||||
enable_proxy = self.get_config("proxy.enable_proxy", False)
|
||||
proxies = None
|
||||
|
||||
|
||||
if enable_proxy:
|
||||
socks5_proxy = self.get_config("proxy.socks5_proxy", None)
|
||||
http_proxy = self.get_config("proxy.http_proxy", None)
|
||||
https_proxy = self.get_config("proxy.https_proxy", None)
|
||||
|
||||
|
||||
# 优先使用SOCKS5代理(全协议代理)
|
||||
if socks5_proxy:
|
||||
proxies = socks5_proxy
|
||||
@@ -75,17 +75,17 @@ class URLParserTool(BaseTool):
|
||||
if https_proxy:
|
||||
proxies["https://"] = https_proxy
|
||||
logger.info(f"使用HTTP/HTTPS代理配置: {proxies}")
|
||||
|
||||
|
||||
client_kwargs = {"timeout": 15.0, "follow_redirects": True}
|
||||
if proxies:
|
||||
client_kwargs["proxies"] = proxies
|
||||
|
||||
|
||||
async with httpx.AsyncClient(**client_kwargs) as client:
|
||||
response = await client.get(url)
|
||||
response.raise_for_status()
|
||||
|
||||
soup = BeautifulSoup(response.text, "html.parser")
|
||||
|
||||
|
||||
title = soup.title.string if soup.title else "无标题"
|
||||
for script in soup(["script", "style"]):
|
||||
script.extract()
|
||||
@@ -104,12 +104,12 @@ class URLParserTool(BaseTool):
|
||||
return {"error": "未配置LLM模型"}
|
||||
|
||||
success, summary, reasoning, model_name = await llm_api.generate_with_model(
|
||||
prompt=summary_prompt,
|
||||
model_config=model_config,
|
||||
request_type="story.generate",
|
||||
temperature=0.3,
|
||||
max_tokens=1000
|
||||
)
|
||||
prompt=summary_prompt,
|
||||
model_config=model_config,
|
||||
request_type="story.generate",
|
||||
temperature=0.3,
|
||||
max_tokens=1000,
|
||||
)
|
||||
|
||||
if not success:
|
||||
logger.info(f"生成摘要失败: {summary}")
|
||||
@@ -117,12 +117,7 @@ class URLParserTool(BaseTool):
|
||||
|
||||
logger.info(f"成功生成摘要内容:'{summary}'")
|
||||
|
||||
return {
|
||||
"title": title,
|
||||
"url": url,
|
||||
"snippet": summary,
|
||||
"source": "local"
|
||||
}
|
||||
return {"title": title, "url": url, "snippet": summary, "source": "local"}
|
||||
|
||||
except httpx.HTTPStatusError as e:
|
||||
logger.warning(f"本地解析URL '{url}' 失败 (HTTP {e.response.status_code})")
|
||||
@@ -137,6 +132,7 @@ class URLParserTool(BaseTool):
|
||||
"""
|
||||
# 获取当前文件路径用于缓存键
|
||||
import os
|
||||
|
||||
current_file_path = os.path.abspath(__file__)
|
||||
|
||||
# 检查缓存
|
||||
@@ -144,7 +140,7 @@ class URLParserTool(BaseTool):
|
||||
if cached_result:
|
||||
logger.info(f"缓存命中: {self.name} -> {function_args}")
|
||||
return cached_result
|
||||
|
||||
|
||||
urls_input = function_args.get("urls")
|
||||
if not urls_input:
|
||||
return {"error": "URL列表不能为空。"}
|
||||
@@ -158,14 +154,14 @@ class URLParserTool(BaseTool):
|
||||
valid_urls = validate_urls(urls)
|
||||
if not valid_urls:
|
||||
return {"error": "未找到有效的URL。"}
|
||||
|
||||
|
||||
urls = valid_urls
|
||||
logger.info(f"准备解析 {len(urls)} 个URL: {urls}")
|
||||
|
||||
successful_results = []
|
||||
error_messages = []
|
||||
urls_to_retry_locally = []
|
||||
|
||||
|
||||
# 步骤 1: 尝试使用 Exa API 进行解析
|
||||
contents_response = None
|
||||
if self.api_manager.is_available():
|
||||
@@ -182,41 +178,45 @@ class URLParserTool(BaseTool):
|
||||
contents_response = await loop.run_in_executor(None, func)
|
||||
except Exception as e:
|
||||
logger.error(f"执行 Exa URL解析时发生严重异常: {e}", exc_info=True)
|
||||
contents_response = None # 确保异常后为None
|
||||
contents_response = None # 确保异常后为None
|
||||
|
||||
# 步骤 2: 处理Exa的响应
|
||||
if contents_response and hasattr(contents_response, 'statuses'):
|
||||
results_map = {res.url: res for res in contents_response.results} if hasattr(contents_response, 'results') else {}
|
||||
if contents_response and hasattr(contents_response, "statuses"):
|
||||
results_map = (
|
||||
{res.url: res for res in contents_response.results} if hasattr(contents_response, "results") else {}
|
||||
)
|
||||
if contents_response.statuses:
|
||||
for status in contents_response.statuses:
|
||||
if status.status == 'success':
|
||||
if status.status == "success":
|
||||
res = results_map.get(status.id)
|
||||
if res:
|
||||
summary = getattr(res, 'summary', '')
|
||||
highlights = " ".join(getattr(res, 'highlights', []))
|
||||
text_snippet = (getattr(res, 'text', '')[:300] + '...') if getattr(res, 'text', '') else ''
|
||||
snippet = summary or highlights or text_snippet or '无摘要'
|
||||
|
||||
successful_results.append({
|
||||
"title": getattr(res, 'title', '无标题'),
|
||||
"url": getattr(res, 'url', status.id),
|
||||
"snippet": snippet,
|
||||
"source": "exa"
|
||||
})
|
||||
summary = getattr(res, "summary", "")
|
||||
highlights = " ".join(getattr(res, "highlights", []))
|
||||
text_snippet = (getattr(res, "text", "")[:300] + "...") if getattr(res, "text", "") else ""
|
||||
snippet = summary or highlights or text_snippet or "无摘要"
|
||||
|
||||
successful_results.append(
|
||||
{
|
||||
"title": getattr(res, "title", "无标题"),
|
||||
"url": getattr(res, "url", status.id),
|
||||
"snippet": snippet,
|
||||
"source": "exa",
|
||||
}
|
||||
)
|
||||
else:
|
||||
error_tag = getattr(status, 'error', '未知错误')
|
||||
error_tag = getattr(status, "error", "未知错误")
|
||||
logger.warning(f"Exa解析URL '{status.id}' 失败: {error_tag}。准备本地重试。")
|
||||
urls_to_retry_locally.append(status.id)
|
||||
else:
|
||||
# 如果Exa未配置、API调用失败或返回无效响应,则所有URL都进入本地重试
|
||||
urls_to_retry_locally.extend(url for url in urls if url not in [res['url'] for res in successful_results])
|
||||
urls_to_retry_locally.extend(url for url in urls if url not in [res["url"] for res in successful_results])
|
||||
|
||||
# 步骤 3: 对失败的URL进行本地解析
|
||||
if urls_to_retry_locally:
|
||||
logger.info(f"开始本地解析以下URL: {urls_to_retry_locally}")
|
||||
local_tasks = [self._local_parse_and_summarize(url) for url in urls_to_retry_locally]
|
||||
local_results = await asyncio.gather(*local_tasks)
|
||||
|
||||
|
||||
for i, res in enumerate(local_results):
|
||||
url = urls_to_retry_locally[i]
|
||||
if "error" in res:
|
||||
@@ -228,13 +228,9 @@ class URLParserTool(BaseTool):
|
||||
return {"error": "无法从所有给定的URL获取内容。", "details": error_messages}
|
||||
|
||||
formatted_content = format_url_parse_results(successful_results)
|
||||
|
||||
result = {
|
||||
"type": "url_parse_result",
|
||||
"content": formatted_content,
|
||||
"errors": error_messages
|
||||
}
|
||||
|
||||
|
||||
result = {"type": "url_parse_result", "content": formatted_content, "errors": error_messages}
|
||||
|
||||
# 保存到缓存
|
||||
if "error" not in result:
|
||||
await tool_cache.set(self.name, function_args, current_file_path, result)
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
"""
|
||||
Web search tool implementation
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from typing import Any, Dict, List
|
||||
|
||||
@@ -22,14 +23,23 @@ class WebSurfingTool(BaseTool):
|
||||
"""
|
||||
网络搜索工具
|
||||
"""
|
||||
|
||||
name: str = "web_search"
|
||||
description: str = "用于执行网络搜索。当用户明确要求搜索,或者需要获取关于公司、产品、事件的最新信息、新闻或动态时,必须使用此工具"
|
||||
description: str = (
|
||||
"用于执行网络搜索。当用户明确要求搜索,或者需要获取关于公司、产品、事件的最新信息、新闻或动态时,必须使用此工具"
|
||||
)
|
||||
available_for_llm: bool = True
|
||||
parameters = [
|
||||
("query", ToolParamType.STRING, "要搜索的关键词或问题。", True, None),
|
||||
("num_results", ToolParamType.INTEGER, "期望每个搜索引擎返回的搜索结果数量,默认为5。", False, None),
|
||||
("time_range", ToolParamType.STRING, "指定搜索的时间范围,可以是 'any', 'week', 'month'。默认为 'any'。", False, ["any", "week", "month"])
|
||||
] # type: ignore
|
||||
(
|
||||
"time_range",
|
||||
ToolParamType.STRING,
|
||||
"指定搜索的时间范围,可以是 'any', 'week', 'month'。默认为 'any'。",
|
||||
False,
|
||||
["any", "week", "month"],
|
||||
),
|
||||
] # type: ignore
|
||||
|
||||
def __init__(self, plugin_config=None):
|
||||
super().__init__(plugin_config)
|
||||
@@ -38,7 +48,7 @@ class WebSurfingTool(BaseTool):
|
||||
"exa": ExaSearchEngine(),
|
||||
"tavily": TavilySearchEngine(),
|
||||
"ddg": DDGSearchEngine(),
|
||||
"bing": BingSearchEngine()
|
||||
"bing": BingSearchEngine(),
|
||||
}
|
||||
|
||||
async def execute(self, function_args: Dict[str, Any]) -> Dict[str, Any]:
|
||||
@@ -48,6 +58,7 @@ class WebSurfingTool(BaseTool):
|
||||
|
||||
# 获取当前文件路径用于缓存键
|
||||
import os
|
||||
|
||||
current_file_path = os.path.abspath(__file__)
|
||||
|
||||
# 检查缓存
|
||||
@@ -59,7 +70,7 @@ class WebSurfingTool(BaseTool):
|
||||
# 读取搜索配置
|
||||
enabled_engines = config_api.get_global_config("web_search.enabled_engines", ["ddg"])
|
||||
search_strategy = config_api.get_global_config("web_search.search_strategy", "single")
|
||||
|
||||
|
||||
logger.info(f"开始搜索,策略: {search_strategy}, 启用引擎: {enabled_engines}, 参数: '{function_args}'")
|
||||
|
||||
# 根据策略执行搜索
|
||||
@@ -69,17 +80,19 @@ class WebSurfingTool(BaseTool):
|
||||
result = await self._execute_fallback_search(function_args, enabled_engines)
|
||||
else: # single
|
||||
result = await self._execute_single_search(function_args, enabled_engines)
|
||||
|
||||
|
||||
# 保存到缓存
|
||||
if "error" not in result:
|
||||
await tool_cache.set(self.name, function_args, current_file_path, result, semantic_query=query)
|
||||
|
||||
|
||||
return result
|
||||
|
||||
async def _execute_parallel_search(self, function_args: Dict[str, Any], enabled_engines: List[str]) -> Dict[str, Any]:
|
||||
async def _execute_parallel_search(
|
||||
self, function_args: Dict[str, Any], enabled_engines: List[str]
|
||||
) -> Dict[str, Any]:
|
||||
"""并行搜索策略:同时使用所有启用的搜索引擎"""
|
||||
search_tasks = []
|
||||
|
||||
|
||||
for engine_name in enabled_engines:
|
||||
engine = self.engines.get(engine_name)
|
||||
if engine and engine.is_available():
|
||||
@@ -92,7 +105,7 @@ class WebSurfingTool(BaseTool):
|
||||
|
||||
try:
|
||||
search_results_lists = await asyncio.gather(*search_tasks, return_exceptions=True)
|
||||
|
||||
|
||||
all_results = []
|
||||
for result in search_results_lists:
|
||||
if isinstance(result, list):
|
||||
@@ -103,7 +116,7 @@ class WebSurfingTool(BaseTool):
|
||||
# 去重并格式化
|
||||
unique_results = deduplicate_results(all_results)
|
||||
formatted_content = format_search_results(unique_results)
|
||||
|
||||
|
||||
return {
|
||||
"type": "web_search_result",
|
||||
"content": formatted_content,
|
||||
@@ -113,30 +126,32 @@ class WebSurfingTool(BaseTool):
|
||||
logger.error(f"执行并行网络搜索时发生异常: {e}", exc_info=True)
|
||||
return {"error": f"执行网络搜索时发生严重错误: {str(e)}"}
|
||||
|
||||
async def _execute_fallback_search(self, function_args: Dict[str, Any], enabled_engines: List[str]) -> Dict[str, Any]:
|
||||
async def _execute_fallback_search(
|
||||
self, function_args: Dict[str, Any], enabled_engines: List[str]
|
||||
) -> Dict[str, Any]:
|
||||
"""回退搜索策略:按顺序尝试搜索引擎,失败则尝试下一个"""
|
||||
for engine_name in enabled_engines:
|
||||
engine = self.engines.get(engine_name)
|
||||
if not engine or not engine.is_available():
|
||||
continue
|
||||
|
||||
|
||||
try:
|
||||
custom_args = function_args.copy()
|
||||
custom_args["num_results"] = custom_args.get("num_results", 5)
|
||||
|
||||
|
||||
results = await engine.search(custom_args)
|
||||
|
||||
|
||||
if results: # 如果有结果,直接返回
|
||||
formatted_content = format_search_results(results)
|
||||
return {
|
||||
"type": "web_search_result",
|
||||
"content": formatted_content,
|
||||
}
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"{engine_name} 搜索失败,尝试下一个引擎: {e}")
|
||||
continue
|
||||
|
||||
|
||||
return {"error": "所有搜索引擎都失败了。"}
|
||||
|
||||
async def _execute_single_search(self, function_args: Dict[str, Any], enabled_engines: List[str]) -> Dict[str, Any]:
|
||||
@@ -145,20 +160,20 @@ class WebSurfingTool(BaseTool):
|
||||
engine = self.engines.get(engine_name)
|
||||
if not engine or not engine.is_available():
|
||||
continue
|
||||
|
||||
|
||||
try:
|
||||
custom_args = function_args.copy()
|
||||
custom_args["num_results"] = custom_args.get("num_results", 5)
|
||||
|
||||
|
||||
results = await engine.search(custom_args)
|
||||
formatted_content = format_search_results(results)
|
||||
return {
|
||||
"type": "web_search_result",
|
||||
"content": formatted_content,
|
||||
}
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{engine_name} 搜索失败: {e}")
|
||||
return {"error": f"{engine_name} 搜索失败: {str(e)}"}
|
||||
|
||||
|
||||
return {"error": "没有可用的搜索引擎。"}
|
||||
|
||||
@@ -1,24 +1,25 @@
|
||||
"""
|
||||
API密钥管理器,提供轮询机制
|
||||
"""
|
||||
|
||||
import itertools
|
||||
from typing import List, Optional, TypeVar, Generic, Callable
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("api_key_manager")
|
||||
|
||||
T = TypeVar('T')
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class APIKeyManager(Generic[T]):
|
||||
"""
|
||||
API密钥管理器,支持轮询机制
|
||||
"""
|
||||
|
||||
|
||||
def __init__(self, api_keys: List[str], client_factory: Callable[[str], T], service_name: str = "Unknown"):
|
||||
"""
|
||||
初始化API密钥管理器
|
||||
|
||||
|
||||
Args:
|
||||
api_keys: API密钥列表
|
||||
client_factory: 客户端工厂函数,接受API密钥参数并返回客户端实例
|
||||
@@ -27,14 +28,14 @@ class APIKeyManager(Generic[T]):
|
||||
self.service_name = service_name
|
||||
self.clients: List[T] = []
|
||||
self.client_cycle: Optional[itertools.cycle] = None
|
||||
|
||||
|
||||
if api_keys:
|
||||
# 过滤有效的API密钥,排除None、空字符串、"None"字符串等
|
||||
valid_keys = []
|
||||
for key in api_keys:
|
||||
if isinstance(key, str) and key.strip() and key.strip().lower() not in ("none", "null", ""):
|
||||
valid_keys.append(key.strip())
|
||||
|
||||
|
||||
if valid_keys:
|
||||
try:
|
||||
self.clients = [client_factory(key) for key in valid_keys]
|
||||
@@ -48,35 +49,33 @@ class APIKeyManager(Generic[T]):
|
||||
logger.warning(f"⚠️ {service_name} API Keys 配置无效(包含None或空值),{service_name} 功能将不可用")
|
||||
else:
|
||||
logger.warning(f"⚠️ {service_name} API Keys 未配置,{service_name} 功能将不可用")
|
||||
|
||||
|
||||
def is_available(self) -> bool:
|
||||
"""检查是否有可用的客户端"""
|
||||
return bool(self.clients and self.client_cycle)
|
||||
|
||||
|
||||
def get_next_client(self) -> Optional[T]:
|
||||
"""获取下一个客户端(轮询)"""
|
||||
if not self.is_available():
|
||||
return None
|
||||
return next(self.client_cycle)
|
||||
|
||||
|
||||
def get_client_count(self) -> int:
|
||||
"""获取可用客户端数量"""
|
||||
return len(self.clients)
|
||||
|
||||
|
||||
def create_api_key_manager_from_config(
|
||||
config_keys: Optional[List[str]],
|
||||
client_factory: Callable[[str], T],
|
||||
service_name: str
|
||||
config_keys: Optional[List[str]], client_factory: Callable[[str], T], service_name: str
|
||||
) -> APIKeyManager[T]:
|
||||
"""
|
||||
从配置创建API密钥管理器的便捷函数
|
||||
|
||||
|
||||
Args:
|
||||
config_keys: 从配置读取的API密钥列表
|
||||
client_factory: 客户端工厂函数
|
||||
service_name: 服务名称
|
||||
|
||||
|
||||
Returns:
|
||||
API密钥管理器实例
|
||||
"""
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
"""
|
||||
Formatters for web search results
|
||||
"""
|
||||
|
||||
from typing import List, Dict, Any
|
||||
|
||||
|
||||
@@ -13,15 +14,15 @@ def format_search_results(results: List[Dict[str, Any]]) -> str:
|
||||
|
||||
formatted_string = "根据网络搜索结果:\n\n"
|
||||
for i, res in enumerate(results, 1):
|
||||
title = res.get("title", '无标题')
|
||||
url = res.get("url", '#')
|
||||
snippet = res.get("snippet", '无摘要')
|
||||
title = res.get("title", "无标题")
|
||||
url = res.get("url", "#")
|
||||
snippet = res.get("snippet", "无摘要")
|
||||
provider = res.get("provider", "未知来源")
|
||||
|
||||
|
||||
formatted_string += f"{i}. **{title}** (来自: {provider})\n"
|
||||
formatted_string += f" - 摘要: {snippet}\n"
|
||||
formatted_string += f" - 来源: {url}\n\n"
|
||||
|
||||
|
||||
return formatted_string
|
||||
|
||||
|
||||
@@ -31,10 +32,10 @@ def format_url_parse_results(results: List[Dict[str, Any]]) -> str:
|
||||
"""
|
||||
formatted_parts = []
|
||||
for res in results:
|
||||
title = res.get('title', '无标题')
|
||||
url = res.get('url', '#')
|
||||
snippet = res.get('snippet', '无摘要')
|
||||
source = res.get('source', '未知')
|
||||
title = res.get("title", "无标题")
|
||||
url = res.get("url", "#")
|
||||
snippet = res.get("snippet", "无摘要")
|
||||
source = res.get("source", "未知")
|
||||
|
||||
formatted_string = f"**{title}**\n"
|
||||
formatted_string += f"**内容摘要**:\n{snippet}\n"
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
"""
|
||||
URL processing utilities
|
||||
"""
|
||||
|
||||
import re
|
||||
from typing import List
|
||||
|
||||
@@ -12,11 +13,11 @@ def parse_urls_from_input(urls_input) -> List[str]:
|
||||
if isinstance(urls_input, str):
|
||||
# 如果是字符串,尝试解析为URL列表
|
||||
# 提取所有HTTP/HTTPS URL
|
||||
url_pattern = r'https?://[^\s\],]+'
|
||||
url_pattern = r"https?://[^\s\],]+"
|
||||
urls = re.findall(url_pattern, urls_input)
|
||||
if not urls:
|
||||
# 如果没有找到标准URL,将整个字符串作为单个URL
|
||||
if urls_input.strip().startswith(('http://', 'https://')):
|
||||
if urls_input.strip().startswith(("http://", "https://")):
|
||||
urls = [urls_input.strip()]
|
||||
else:
|
||||
return []
|
||||
@@ -24,7 +25,7 @@ def parse_urls_from_input(urls_input) -> List[str]:
|
||||
urls = [url.strip() for url in urls_input if isinstance(url, str) and url.strip()]
|
||||
else:
|
||||
return []
|
||||
|
||||
|
||||
return urls
|
||||
|
||||
|
||||
@@ -34,6 +35,6 @@ def validate_urls(urls: List[str]) -> List[str]:
|
||||
"""
|
||||
valid_urls = []
|
||||
for url in urls:
|
||||
if url.startswith(('http://', 'https://')):
|
||||
if url.startswith(("http://", "https://")):
|
||||
valid_urls.append(url)
|
||||
return valid_urls
|
||||
|
||||
@@ -21,8 +21,18 @@ logger = get_logger(__name__)
|
||||
|
||||
# ============================ AsyncTask ============================
|
||||
|
||||
|
||||
class ReminderTask(AsyncTask):
|
||||
def __init__(self, delay: float, stream_id: str, is_group: bool, target_user_id: str, target_user_name: str, event_details: str, creator_name: str):
|
||||
def __init__(
|
||||
self,
|
||||
delay: float,
|
||||
stream_id: str,
|
||||
is_group: bool,
|
||||
target_user_id: str,
|
||||
target_user_name: str,
|
||||
event_details: str,
|
||||
creator_name: str,
|
||||
):
|
||||
super().__init__(task_name=f"ReminderTask_{target_user_id}_{datetime.now().timestamp()}")
|
||||
self.delay = delay
|
||||
self.stream_id = stream_id
|
||||
@@ -37,22 +47,22 @@ class ReminderTask(AsyncTask):
|
||||
if self.delay > 0:
|
||||
logger.info(f"等待 {self.delay:.2f} 秒后执行提醒...")
|
||||
await asyncio.sleep(self.delay)
|
||||
|
||||
|
||||
logger.info(f"执行提醒任务: 给 {self.target_user_name} 发送关于 '{self.event_details}' 的提醒")
|
||||
|
||||
reminder_text = f"叮咚!这是 {self.creator_name} 让我准时提醒你的事情:\n\n{self.event_details}"
|
||||
|
||||
if self.is_group:
|
||||
# 在群聊中,构造 @ 消息段并发送
|
||||
group_id = self.stream_id.split('_')[-1] if '_' in self.stream_id else self.stream_id
|
||||
group_id = self.stream_id.split("_")[-1] if "_" in self.stream_id else self.stream_id
|
||||
message_payload = [
|
||||
{"type": "at", "data": {"qq": self.target_user_id}},
|
||||
{"type": "text", "data": {"text": f" {reminder_text}"}}
|
||||
{"type": "text", "data": {"text": f" {reminder_text}"}},
|
||||
]
|
||||
await send_api.adapter_command_to_stream(
|
||||
action="send_group_msg",
|
||||
params={"group_id": group_id, "message": message_payload},
|
||||
stream_id=self.stream_id
|
||||
stream_id=self.stream_id,
|
||||
)
|
||||
else:
|
||||
# 在私聊中,直接发送文本
|
||||
@@ -66,6 +76,7 @@ class ReminderTask(AsyncTask):
|
||||
|
||||
# =============================== Actions ===============================
|
||||
|
||||
|
||||
class RemindAction(BaseAction):
|
||||
"""一个能从对话中智能识别并设置定时提醒的动作。"""
|
||||
|
||||
@@ -95,12 +106,12 @@ class RemindAction(BaseAction):
|
||||
action_parameters = {
|
||||
"user_name": "需要被提醒的人的称呼或名字,如果没有明确指定给某人,则默认为'自己'",
|
||||
"remind_time": "描述提醒时间的自然语言字符串,例如'十分钟后'或'明天下午3点'",
|
||||
"event_details": "需要提醒的具体事件内容"
|
||||
"event_details": "需要提醒的具体事件内容",
|
||||
}
|
||||
action_require = [
|
||||
"当用户请求在未来的某个时间点提醒他/她或别人某件事时使用",
|
||||
"适用于包含明确时间信息和事件描述的对话",
|
||||
"例如:'10分钟后提醒我收快递'、'明天早上九点喊一下李四参加晨会'"
|
||||
"例如:'10分钟后提醒我收快递'、'明天早上九点喊一下李四参加晨会'",
|
||||
]
|
||||
|
||||
async def execute(self) -> Tuple[bool, str]:
|
||||
@@ -110,7 +121,15 @@ class RemindAction(BaseAction):
|
||||
event_details = self.action_data.get("event_details")
|
||||
|
||||
if not all([user_name, remind_time_str, event_details]):
|
||||
missing_params = [p for p, v in {"user_name": user_name, "remind_time": remind_time_str, "event_details": event_details}.items() if not v]
|
||||
missing_params = [
|
||||
p
|
||||
for p, v in {
|
||||
"user_name": user_name,
|
||||
"remind_time": remind_time_str,
|
||||
"event_details": event_details,
|
||||
}.items()
|
||||
if not v
|
||||
]
|
||||
error_msg = f"缺少必要的提醒参数: {', '.join(missing_params)}"
|
||||
logger.warning(f"[ReminderPlugin] LLM未能提取完整参数: {error_msg}")
|
||||
return False, error_msg
|
||||
@@ -135,9 +154,9 @@ class RemindAction(BaseAction):
|
||||
person_manager = get_person_info_manager()
|
||||
user_id_to_remind = None
|
||||
user_name_to_remind = ""
|
||||
|
||||
|
||||
assert isinstance(user_name, str)
|
||||
|
||||
|
||||
if user_name.strip() in ["自己", "我", "me"]:
|
||||
user_id_to_remind = self.user_id
|
||||
user_name_to_remind = self.user_nickname
|
||||
@@ -154,7 +173,7 @@ class RemindAction(BaseAction):
|
||||
try:
|
||||
assert user_id_to_remind is not None
|
||||
assert event_details is not None
|
||||
|
||||
|
||||
reminder_task = ReminderTask(
|
||||
delay=delay_seconds,
|
||||
stream_id=self.chat_id,
|
||||
@@ -162,14 +181,14 @@ class RemindAction(BaseAction):
|
||||
target_user_id=str(user_id_to_remind),
|
||||
target_user_name=str(user_name_to_remind),
|
||||
event_details=str(event_details),
|
||||
creator_name=str(self.user_nickname)
|
||||
creator_name=str(self.user_nickname),
|
||||
)
|
||||
await async_task_manager.add_task(reminder_task)
|
||||
|
||||
|
||||
# 4. 发送确认消息
|
||||
confirm_message = f"好的,我记下了。\n将在 {target_time.strftime('%Y-%m-%d %H:%M:%S')} 提醒 {user_name_to_remind}:\n{event_details}"
|
||||
await self.send_text(confirm_message)
|
||||
|
||||
|
||||
return True, "提醒设置成功"
|
||||
except Exception as e:
|
||||
logger.error(f"[ReminderPlugin] 创建提醒任务时出错: {e}", exc_info=True)
|
||||
@@ -179,6 +198,7 @@ class RemindAction(BaseAction):
|
||||
|
||||
# =============================== Plugin ===============================
|
||||
|
||||
|
||||
@register_plugin
|
||||
class ReminderPlugin(BasePlugin):
|
||||
"""一个能从对话中智能识别并设置定时提醒的插件。"""
|
||||
@@ -193,6 +213,4 @@ class ReminderPlugin(BasePlugin):
|
||||
|
||||
def get_plugin_components(self) -> List[Tuple[ActionInfo, Type[BaseAction]]]:
|
||||
"""注册插件的所有功能组件。"""
|
||||
return [
|
||||
(RemindAction.get_action_info(), RemindAction)
|
||||
]
|
||||
return [(RemindAction.get_action_info(), RemindAction)]
|
||||
|
||||
@@ -290,4 +290,4 @@ def has_active_plans(month: str) -> bool:
|
||||
return count > 0
|
||||
except Exception as e:
|
||||
logger.error(f"检查 {month} 的有效月度计划时发生错误: {e}")
|
||||
return False
|
||||
return False
|
||||
|
||||
@@ -221,4 +221,4 @@ class MonthlyPlanLLMGenerator:
|
||||
return plans
|
||||
except Exception as e:
|
||||
logger.error(f"解析月度计划响应时发生错误: {e}")
|
||||
return []
|
||||
return []
|
||||
|
||||
@@ -102,4 +102,4 @@ class PlanManager:
|
||||
|
||||
def get_plans_for_schedule(self, month: str, max_count: int) -> List:
|
||||
avoid_days = global_config.planning_system.avoid_repetition_days
|
||||
return get_smart_plans_for_daily_schedule(month, max_count=max_count, avoid_days=avoid_days)
|
||||
return get_smart_plans_for_daily_schedule(month, max_count=max_count, avoid_days=avoid_days)
|
||||
|
||||
@@ -96,4 +96,4 @@ class ScheduleData(BaseModel):
|
||||
covered[i] = True
|
||||
|
||||
# 检查是否所有分钟都被覆盖
|
||||
return all(covered)
|
||||
return all(covered)
|
||||
|
||||
Reference in New Issue
Block a user