ruff,私聊视为提及了bot

This commit is contained in:
Windpicker-owo
2025-09-20 22:34:22 +08:00
parent 006f9130b9
commit 444f1ca315
76 changed files with 1066 additions and 882 deletions

11
bot.py
View File

@@ -34,6 +34,7 @@ script_dir = os.path.dirname(os.path.abspath(__file__))
os.chdir(script_dir)
logger.info(f"已设置工作目录为: {script_dir}")
# 检查并创建.env文件
def ensure_env_file():
"""确保.env文件存在如果不存在则从模板创建"""
@@ -44,6 +45,7 @@ def ensure_env_file():
if template_env.exists():
logger.info("未找到.env文件正在从模板创建...")
import shutil
shutil.copy(template_env, env_file)
logger.info("已从template/template.env创建.env文件")
logger.warning("请编辑.env文件将EULA_CONFIRMED设置为true并配置其他必要参数")
@@ -51,6 +53,7 @@ def ensure_env_file():
logger.error("未找到.env文件和template.env模板文件")
sys.exit(1)
# 确保环境文件存在
ensure_env_file()
@@ -130,9 +133,9 @@ async def graceful_shutdown():
def check_eula():
"""检查EULA和隐私条款确认状态 - 环境变量版类似Minecraft"""
# 检查环境变量中的EULA确认
eula_confirmed = os.getenv('EULA_CONFIRMED', '').lower()
eula_confirmed = os.getenv("EULA_CONFIRMED", "").lower()
if eula_confirmed == 'true':
if eula_confirmed == "true":
logger.info("EULA已通过环境变量确认")
return
@@ -148,8 +151,8 @@ def check_eula():
try:
load_dotenv(override=True) # 重新加载.env文件
eula_confirmed = os.getenv('EULA_CONFIRMED', '').lower()
if eula_confirmed == 'true':
eula_confirmed = os.getenv("EULA_CONFIRMED", "").lower()
if eula_confirmed == "true":
confirm_logger.info("EULA确认成功感谢您的同意")
return

View File

@@ -20,16 +20,17 @@ files_to_update = [
"src/mais4u/mais4u_chat/s4u_mood_manager.py",
"src/plugin_system/core/tool_use.py",
"src/chat/memory_system/memory_activator.py",
"src/chat/utils/smart_prompt.py"
"src/chat/utils/smart_prompt.py",
]
def update_prompt_imports(file_path):
"""更新文件中的Prompt导入"""
if not os.path.exists(file_path):
print(f"文件不存在: {file_path}")
return False
with open(file_path, 'r', encoding='utf-8') as f:
with open(file_path, "r", encoding="utf-8") as f:
content = f.read()
# 替换导入语句
@@ -38,7 +39,7 @@ def update_prompt_imports(file_path):
if old_import in content:
new_content = content.replace(old_import, new_import)
with open(file_path, 'w', encoding='utf-8') as f:
with open(file_path, "w", encoding="utf-8") as f:
f.write(new_content)
print(f"已更新: {file_path}")
return True
@@ -46,6 +47,7 @@ def update_prompt_imports(file_path):
print(f"无需更新: {file_path}")
return False
def main():
"""主函数"""
updated_count = 0
@@ -55,5 +57,6 @@ def main():
print(f"\n更新完成!共更新了 {updated_count} 个文件")
if __name__ == "__main__":
main()

View File

@@ -5,4 +5,4 @@
from src.chat.affinity_flow.afc_manager import afc_manager
__all__ = ['afc_manager', 'AFCManager', 'AffinityFlowChatter']
__all__ = ["afc_manager", "AFCManager", "AffinityFlowChatter"]

View File

@@ -2,6 +2,7 @@
亲和力聊天处理流管理器
管理不同聊天流的亲和力聊天处理流,统一获取新消息并分发到对应的亲和力聊天处理流
"""
import time
import traceback
from typing import Dict, Optional, List
@@ -20,7 +21,7 @@ class AFCManager:
def __init__(self):
self.affinity_flow_chatters: Dict[str, "AffinityFlowChatter"] = {}
'''所有聊天流的亲和力聊天处理流stream_id -> affinity_flow_chatter'''
"""所有聊天流的亲和力聊天处理流stream_id -> affinity_flow_chatter"""
# 动作管理器
self.action_manager = ActionManager()
@@ -40,11 +41,7 @@ class AFCManager:
# 创建增强版规划器
planner = ActionPlanner(stream_id, self.action_manager)
chatter = AffinityFlowChatter(
stream_id=stream_id,
planner=planner,
action_manager=self.action_manager
)
chatter = AffinityFlowChatter(stream_id=stream_id, planner=planner, action_manager=self.action_manager)
self.affinity_flow_chatters[stream_id] = chatter
logger.info(f"创建新的亲和力聊天处理器: {stream_id}")
@@ -74,7 +71,6 @@ class AFCManager:
"executed_count": 0,
}
def get_chatter_stats(self, stream_id: str) -> Optional[Dict[str, any]]:
"""获取聊天处理器统计"""
if stream_id in self.affinity_flow_chatters:
@@ -131,4 +127,5 @@ class AFCManager:
self.affinity_flow_chatters[stream_id].update_interest_keywords(new_keywords)
logger.info(f"已更新聊天流 {stream_id} 的兴趣关键词: {list(new_keywords.keys())}")
afc_manager = AFCManager()

View File

@@ -2,6 +2,7 @@
亲和力聊天处理器
单个聊天流的处理器,负责处理特定聊天流的完整交互流程
"""
import time
import traceback
from datetime import datetime
@@ -57,10 +58,7 @@ class AffinityFlowChatter:
unread_messages = context.get_unread_messages()
# 使用增强版规划器处理消息
actions, target_message = await self.planner.plan(
mode=ChatMode.FOCUS,
context=context
)
actions, target_message = await self.planner.plan(mode=ChatMode.FOCUS, context=context)
self.stats["plans_created"] += 1
# 执行动作(如果规划器返回了动作)
@@ -84,7 +82,9 @@ class AffinityFlowChatter:
**execution_result,
}
logger.info(f"聊天流 {self.stream_id} StreamContext处理成功: 动作数={result['actions_count']}, 未读消息={result['unread_messages_processed']}")
logger.info(
f"聊天流 {self.stream_id} StreamContext处理成功: 动作数={result['actions_count']}, 未读消息={result['unread_messages_processed']}"
)
return result
@@ -197,7 +197,9 @@ class AffinityFlowChatter:
def __repr__(self) -> str:
"""详细字符串表示"""
return (f"AffinityFlowChatter(stream_id={self.stream_id}, "
return (
f"AffinityFlowChatter(stream_id={self.stream_id}, "
f"messages_processed={self.stats['messages_processed']}, "
f"plans_created={self.stats['plans_created']}, "
f"last_activity={datetime.fromtimestamp(self.last_activity_time)})")
f"last_activity={datetime.fromtimestamp(self.last_activity_time)})"
)

View File

@@ -38,7 +38,9 @@ class InterestScoringSystem:
# 连续不回复概率提升
self.no_reply_count = 0
self.max_no_reply_count = affinity_config.max_no_reply_count
self.probability_boost_per_no_reply = affinity_config.no_reply_threshold_adjustment / affinity_config.max_no_reply_count # 每次不回复增加的概率
self.probability_boost_per_no_reply = (
affinity_config.no_reply_threshold_adjustment / affinity_config.max_no_reply_count
) # 每次不回复增加的概率
# 用户关系数据
self.user_relationships: Dict[str, float] = {} # user_id -> relationship_score
@@ -153,7 +155,9 @@ class InterestScoringSystem:
# 返回匹配分数,考虑置信度和匹配标签数量
affinity_config = global_config.affinity_flow
match_count_bonus = min(len(match_result.matched_tags) * affinity_config.match_count_bonus, affinity_config.max_match_bonus)
match_count_bonus = min(
len(match_result.matched_tags) * affinity_config.match_count_bonus, affinity_config.max_match_bonus
)
final_score = match_result.overall_score * 1.15 * match_result.confidence + match_count_bonus
logger.debug(
f"⚖️ 最终分数计算: 总分({match_result.overall_score:.3f}) × 1.3 × 置信度({match_result.confidence:.3f}) + 标签数量奖励({match_count_bonus:.3f}) = {final_score:.3f}"
@@ -263,7 +267,17 @@ class InterestScoringSystem:
if not msg.processed_plain_text:
return 0.0
if msg.is_mentioned or (bot_nickname and bot_nickname in msg.processed_plain_text):
# 检查是否被提及
is_mentioned = msg.is_mentioned or (bot_nickname and bot_nickname in msg.processed_plain_text)
# 检查是否为私聊group_info为None表示私聊
is_private_chat = msg.group_info is None
# 如果被提及或是私聊都视为提及了bot
if is_mentioned or is_private_chat:
logger.debug(f"🔍 提及检测 - 被提及: {is_mentioned}, 私聊: {is_private_chat}")
if is_private_chat and not is_mentioned:
logger.debug("💬 私聊消息自动视为提及bot")
return global_config.affinity_flow.mention_bot_interest_score
return 0.0
@@ -282,7 +296,9 @@ class InterestScoringSystem:
logger.debug(f"📋 基础阈值: {base_threshold:.3f}")
# 如果被提及,降低阈值
if score.mentioned_score >= global_config.affinity_flow.mention_bot_interest_score * 0.5: # 使用提及bot兴趣分的一半作为判断阈值
if (
score.mentioned_score >= global_config.affinity_flow.mention_bot_interest_score * 0.5
): # 使用提及bot兴趣分的一半作为判断阈值
base_threshold = self.mention_threshold
logger.debug(f"📣 消息提及了机器人,使用降低阈值: {base_threshold:.3f}")
@@ -325,7 +341,9 @@ class InterestScoringSystem:
def update_user_relationship(self, user_id: str, relationship_change: float):
"""更新用户关系"""
old_score = self.user_relationships.get(user_id, global_config.affinity_flow.base_relationship_score) # 默认新用户分数
old_score = self.user_relationships.get(
user_id, global_config.affinity_flow.base_relationship_score
) # 默认新用户分数
new_score = max(0.0, min(1.0, old_score + relationship_change))
self.user_relationships[user_id] = new_score

View File

@@ -116,6 +116,7 @@ class UserRelationshipTracker:
try:
# 获取bot人设信息
from src.individuality.individuality import Individuality
individuality = Individuality()
bot_personality = await individuality.get_personality_block()
@@ -168,7 +169,17 @@ class UserRelationshipTracker:
# 清理LLM响应移除可能的格式标记
cleaned_response = self._clean_llm_json_response(llm_response)
response_data = json.loads(cleaned_response)
new_score = max(0.0, min(1.0, float(response_data.get("new_relationship_score", global_config.affinity_flow.base_relationship_score))))
new_score = max(
0.0,
min(
1.0,
float(
response_data.get(
"new_relationship_score", global_config.affinity_flow.base_relationship_score
)
),
),
)
if self.interest_scoring_system:
self.interest_scoring_system.update_user_relationship(
@@ -295,7 +306,9 @@ class UserRelationshipTracker:
# 更新缓存
self.user_relationship_cache[user_id] = {
"relationship_text": relationship_data.get("relationship_text", ""),
"relationship_score": relationship_data.get("relationship_score", global_config.affinity_flow.base_relationship_score),
"relationship_score": relationship_data.get(
"relationship_score", global_config.affinity_flow.base_relationship_score
),
"last_tracked": time.time(),
}
return relationship_data.get("relationship_score", global_config.affinity_flow.base_relationship_score)
@@ -386,7 +399,11 @@ class UserRelationshipTracker:
# 获取当前关系数据
current_relationship = self._get_user_relationship_from_db(user_id)
current_score = current_relationship.get("relationship_score", global_config.affinity_flow.base_relationship_score) if current_relationship else global_config.affinity_flow.base_relationship_score
current_score = (
current_relationship.get("relationship_score", global_config.affinity_flow.base_relationship_score)
if current_relationship
else global_config.affinity_flow.base_relationship_score
)
current_text = current_relationship.get("relationship_text", "新用户") if current_relationship else "新用户"
# 使用LLM分析并更新关系
@@ -501,6 +518,7 @@ class UserRelationshipTracker:
# 获取bot人设信息
from src.individuality.individuality import Individuality
individuality = Individuality()
bot_personality = await individuality.get_personality_block()

View File

@@ -2,6 +2,7 @@
"""
表情包发送历史记录模块
"""
import os
from typing import List, Dict
from collections import deque

View File

@@ -524,9 +524,7 @@ class EmojiManager:
self.record_usage(selected_emoji.hash)
_time_end = time.time()
logger.info(
f"找到匹配描述的表情包: {selected_emoji.description}, 耗时: {(_time_end - _time_start):.2f}s"
)
logger.info(f"找到匹配描述的表情包: {selected_emoji.description}, 耗时: {(_time_end - _time_start):.2f}s")
# 8. 返回选中的表情包信息
return selected_emoji.full_path, f"[表情包:{selected_emoji.description}]", text_emotion
@@ -627,8 +625,9 @@ class EmojiManager:
# 无论steal_emoji是否开启都检查emoji文件夹以支持手动注册
# 只有在需要腾出空间或填充表情库时,才真正执行注册
if (self.emoji_num > self.emoji_num_max and global_config.emoji.do_replace) or \
(self.emoji_num < self.emoji_num_max):
if (self.emoji_num > self.emoji_num_max and global_config.emoji.do_replace) or (
self.emoji_num < self.emoji_num_max
):
try:
# 获取目录下所有图片文件
files_to_process = [
@@ -931,16 +930,21 @@ class EmojiManager:
image_base64 = image_base64.encode("ascii", errors="ignore").decode("ascii")
image_bytes = base64.b64decode(image_base64)
image_hash = hashlib.md5(image_bytes).hexdigest()
image_format = Image.open(io.BytesIO(image_bytes)).format.lower() if Image.open(io.BytesIO(image_bytes)).format else "jpeg"
image_format = (
Image.open(io.BytesIO(image_bytes)).format.lower()
if Image.open(io.BytesIO(image_bytes)).format
else "jpeg"
)
# 2. 检查数据库中是否已存在该表情包的描述,实现复用
existing_description = None
try:
with get_db_session() as session:
existing_image = session.query(Images).filter(
(Images.emoji_hash == image_hash) & (Images.type == "emoji")
).one_or_none()
existing_image = (
session.query(Images)
.filter((Images.emoji_hash == image_hash) & (Images.type == "emoji"))
.one_or_none()
)
if existing_image and existing_image.description:
existing_description = existing_image.description
logger.info(f"[复用描述] 找到已有详细描述: {existing_description[:50]}...")

View File

@@ -14,6 +14,7 @@ Chat Frequency Analyzer
- MIN_CHATS_FOR_PEAK: 在一个窗口内需要多少次聊天才能被认为是高峰时段。
- MIN_GAP_BETWEEN_PEAKS_HOURS: 两个独立高峰时段之间的最小间隔(小时)。
"""
import time as time_module
from datetime import datetime, timedelta, time
from typing import List, Tuple, Optional
@@ -71,7 +72,9 @@ class ChatFrequencyAnalyzer:
current_window_end = datetimes[i]
# 合并重叠或相邻的高峰时段
if peak_windows and current_window_start - peak_windows[-1][1] < timedelta(hours=MIN_GAP_BETWEEN_PEAKS_HOURS):
if peak_windows and current_window_start - peak_windows[-1][1] < timedelta(
hours=MIN_GAP_BETWEEN_PEAKS_HOURS
):
# 扩展上一个窗口的结束时间
peak_windows[-1] = (peak_windows[-1][0], current_window_end)
else:

View File

@@ -14,6 +14,7 @@ Frequency-Based Proactive Trigger
- TRIGGER_CHECK_INTERVAL_SECONDS: 触发器检查的周期(秒)。
- COOLDOWN_HOURS: 在同一个高峰时段内触发一次后的冷却时间(小时)。
"""
import asyncio
import time
from datetime import datetime
@@ -21,6 +22,7 @@ from typing import Dict, Optional
from src.common.logger import get_logger
from src.chat.affinity_flow.afc_manager import afc_manager
# TODO: 需要重新实现主动思考和睡眠管理功能
from .analyzer import chat_frequency_analyzer
@@ -74,7 +76,6 @@ class FrequencyBasedTrigger:
# 4. 检查当前是否是该用户的高峰聊天时间
if chat_frequency_analyzer.is_in_peak_time(chat_id, now):
# 5. 检查用户当前是否已有活跃的处理任务
# 亲和力流系统不直接提供循环状态,通过检查最后活动时间来判断是否忙碌
chatter = afc_manager.get_or_create_chatter(chat_id)

View File

@@ -4,14 +4,12 @@
"""
from .bot_interest_manager import BotInterestManager, bot_interest_manager
from src.common.data_models.bot_interest_data_model import (
BotInterestTag, BotPersonalityInterests, InterestMatchResult
)
from src.common.data_models.bot_interest_data_model import BotInterestTag, BotPersonalityInterests, InterestMatchResult
__all__ = [
"BotInterestManager",
"bot_interest_manager",
"BotInterestTag",
"BotPersonalityInterests",
"InterestMatchResult"
"InterestMatchResult",
]

View File

@@ -2,6 +2,7 @@
机器人兴趣标签管理系统
基于人设生成兴趣标签并使用embedding计算匹配度
"""
import orjson
import traceback
from typing import List, Dict, Optional, Any
@@ -10,9 +11,7 @@ import numpy as np
from src.common.logger import get_logger
from src.config.config import global_config
from src.common.data_models.bot_interest_data_model import (
BotPersonalityInterests, BotInterestTag, InterestMatchResult
)
from src.common.data_models.bot_interest_data_model import BotPersonalityInterests, BotInterestTag, InterestMatchResult
logger = get_logger("bot_interest_manager")
@@ -87,7 +86,7 @@ class BotInterestManager:
logger.debug("✅ 成功导入embedding相关模块")
# 检查embedding配置是否存在
if not hasattr(model_config.model_task_config, 'embedding'):
if not hasattr(model_config.model_task_config, "embedding"):
raise RuntimeError("❌ 未找到embedding模型配置")
logger.info("📋 找到embedding模型配置")
@@ -101,7 +100,7 @@ class BotInterestManager:
logger.info(f"🔗 客户端类型: {type(self.embedding_request).__name__}")
# 获取第一个embedding模型的ModelInfo
if hasattr(self.embedding_config, 'model_list') and self.embedding_config.model_list:
if hasattr(self.embedding_config, "model_list") and self.embedding_config.model_list:
first_model_name = self.embedding_config.model_list[0]
logger.info(f"🎯 使用embedding模型: {first_model_name}")
else:
@@ -127,7 +126,9 @@ class BotInterestManager:
# 生成新的兴趣标签
logger.info("🆕 数据库中未找到兴趣标签,开始生成新的...")
logger.info("🤖 正在调用LLM生成个性化兴趣标签...")
generated_interests = await self._generate_interests_from_personality(personality_description, personality_id)
generated_interests = await self._generate_interests_from_personality(
personality_description, personality_id
)
if generated_interests:
self.current_interests = generated_interests
@@ -140,14 +141,16 @@ class BotInterestManager:
else:
raise RuntimeError("❌ 兴趣标签生成失败")
async def _generate_interests_from_personality(self, personality_description: str, personality_id: str) -> Optional[BotPersonalityInterests]:
async def _generate_interests_from_personality(
self, personality_description: str, personality_id: str
) -> Optional[BotPersonalityInterests]:
"""根据人设生成兴趣标签"""
try:
logger.info("🎨 开始根据人设生成兴趣标签...")
logger.info(f"📝 人设长度: {len(personality_description)} 字符")
# 检查embedding客户端是否可用
if not hasattr(self, 'embedding_request'):
if not hasattr(self, "embedding_request"):
raise RuntimeError("❌ Embedding客户端未初始化无法生成兴趣标签")
# 构建提示词
@@ -190,8 +193,7 @@ class BotInterestManager:
interests_data = orjson.loads(response)
bot_interests = BotPersonalityInterests(
personality_id=personality_id,
personality_description=personality_description
personality_id=personality_id, personality_description=personality_description
)
# 解析生成的兴趣标签
@@ -202,10 +204,7 @@ class BotInterestManager:
tag_name = tag_data.get("name", f"标签_{i}")
weight = tag_data.get("weight", 0.5)
tag = BotInterestTag(
tag_name=tag_name,
weight=weight
)
tag = BotInterestTag(tag_name=tag_name, weight=weight)
bot_interests.interest_tags.append(tag)
logger.debug(f" 🏷️ {tag_name} (权重: {weight:.2f})")
@@ -225,7 +224,6 @@ class BotInterestManager:
traceback.print_exc()
raise
async def _call_llm_for_interest_generation(self, prompt: str) -> Optional[str]:
"""调用LLM生成兴趣标签"""
try:
@@ -252,12 +250,14 @@ class BotInterestManager:
model_config=replyer_config,
request_type="interest_generation",
temperature=0.7,
max_tokens=2000
max_tokens=2000,
)
if success and response:
logger.info(f"✅ LLM响应成功模型: {model_name}, 响应长度: {len(response)} 字符")
logger.debug(f"📄 LLM响应内容: {response[:200]}..." if len(response) > 200 else f"📄 LLM响应内容: {response}")
logger.debug(
f"📄 LLM响应内容: {response[:200]}..." if len(response) > 200 else f"📄 LLM响应内容: {response}"
)
if reasoning_content:
logger.debug(f"🧠 推理内容: {reasoning_content[:100]}...")
@@ -279,14 +279,14 @@ class BotInterestManager:
import re
# 移除 ```json 和 ``` 标记
cleaned = re.sub(r'```json\s*', '', response)
cleaned = re.sub(r'\s*```', '', cleaned)
cleaned = re.sub(r"```json\s*", "", response)
cleaned = re.sub(r"\s*```", "", cleaned)
# 移除可能的多余空格和换行
cleaned = cleaned.strip()
# 尝试提取JSON对象如果响应中有其他文本
json_match = re.search(r'\{.*\}', cleaned, re.DOTALL)
json_match = re.search(r"\{.*\}", cleaned, re.DOTALL)
if json_match:
cleaned = json_match.group(0)
@@ -295,7 +295,7 @@ class BotInterestManager:
async def _generate_embeddings_for_tags(self, interests: BotPersonalityInterests):
"""为所有兴趣标签生成embedding"""
if not hasattr(self, 'embedding_request'):
if not hasattr(self, "embedding_request"):
raise RuntimeError("❌ Embedding客户端未初始化无法生成embedding")
total_tags = len(interests.interest_tags)
@@ -342,7 +342,7 @@ class BotInterestManager:
async def _get_embedding(self, text: str) -> List[float]:
"""获取文本的embedding向量"""
if not hasattr(self, 'embedding_request'):
if not hasattr(self, "embedding_request"):
raise RuntimeError("❌ Embedding请求客户端未初始化")
# 检查缓存
@@ -376,7 +376,9 @@ class BotInterestManager:
logger.debug(f"✅ 消息embedding生成成功维度: {len(embedding)}")
return embedding
async def _calculate_similarity_scores(self, result: InterestMatchResult, message_embedding: List[float], keywords: List[str]):
async def _calculate_similarity_scores(
self, result: InterestMatchResult, message_embedding: List[float], keywords: List[str]
):
"""计算消息与兴趣标签的相似度分数"""
try:
if not self.current_interests:
@@ -397,7 +399,9 @@ class BotInterestManager:
# 设置相似度阈值为0.3
if similarity > 0.3:
result.add_match(tag.tag_name, weighted_score, keywords)
logger.debug(f" 🏷️ '{tag.tag_name}': 相似度={similarity:.3f}, 权重={tag.weight:.2f}, 加权分数={weighted_score:.3f}")
logger.debug(
f" 🏷️ '{tag.tag_name}': 相似度={similarity:.3f}, 权重={tag.weight:.2f}, 加权分数={weighted_score:.3f}"
)
except Exception as e:
logger.error(f"❌ 计算相似度分数失败: {e}")
@@ -455,7 +459,9 @@ class BotInterestManager:
match_count += 1
high_similarity_count += 1
result.add_match(tag.tag_name, enhanced_score, [tag.tag_name])
logger.debug(f" 🏷️ '{tag.tag_name}': 相似度={similarity:.3f}, 权重={tag.weight:.2f}, 基础分数={weighted_score:.3f}, 增强分数={enhanced_score:.3f} [高匹配]")
logger.debug(
f" 🏷️ '{tag.tag_name}': 相似度={similarity:.3f}, 权重={tag.weight:.2f}, 基础分数={weighted_score:.3f}, 增强分数={enhanced_score:.3f} [高匹配]"
)
elif similarity > medium_threshold:
# 中相似度:中等加成
@@ -463,7 +469,9 @@ class BotInterestManager:
match_count += 1
medium_similarity_count += 1
result.add_match(tag.tag_name, enhanced_score, [tag.tag_name])
logger.debug(f" 🏷️ '{tag.tag_name}': 相似度={similarity:.3f}, 权重={tag.weight:.2f}, 基础分数={weighted_score:.3f}, 增强分数={enhanced_score:.3f} [中匹配]")
logger.debug(
f" 🏷️ '{tag.tag_name}': 相似度={similarity:.3f}, 权重={tag.weight:.2f}, 基础分数={weighted_score:.3f}, 增强分数={enhanced_score:.3f} [中匹配]"
)
elif similarity > low_threshold:
# 低相似度:轻微加成
@@ -471,7 +479,9 @@ class BotInterestManager:
match_count += 1
low_similarity_count += 1
result.add_match(tag.tag_name, enhanced_score, [tag.tag_name])
logger.debug(f" 🏷️ '{tag.tag_name}': 相似度={similarity:.3f}, 权重={tag.weight:.2f}, 基础分数={weighted_score:.3f}, 增强分数={enhanced_score:.3f} [低匹配]")
logger.debug(
f" 🏷️ '{tag.tag_name}': 相似度={similarity:.3f}, 权重={tag.weight:.2f}, 基础分数={weighted_score:.3f}, 增强分数={enhanced_score:.3f} [低匹配]"
)
logger.info(f"📈 匹配统计: {match_count}/{len(active_tags)} 个标签超过阈值")
logger.info(f"🔥 高相似度匹配(>{high_threshold}): {high_similarity_count}")
@@ -488,7 +498,9 @@ class BotInterestManager:
original_score = result.match_scores[tag_name]
bonus = keyword_bonus[tag_name]
result.match_scores[tag_name] = original_score + bonus
logger.debug(f" 🏷️ '{tag_name}': 原始分数={original_score:.3f}, 奖励={bonus:.3f}, 最终分数={result.match_scores[tag_name]:.3f}")
logger.debug(
f" 🏷️ '{tag_name}': 原始分数={original_score:.3f}, 奖励={bonus:.3f}, 最终分数={result.match_scores[tag_name]:.3f}"
)
# 计算总体分数
result.calculate_overall_score()
@@ -499,10 +511,11 @@ class BotInterestManager:
result.top_tag = top_tag_name
logger.info(f"🏆 最佳匹配标签: '{top_tag_name}' (分数: {result.match_scores[top_tag_name]:.3f})")
logger.info(f"📊 最终结果: 总分={result.overall_score:.3f}, 置信度={result.confidence:.3f}, 匹配标签数={len(result.matched_tags)}")
logger.info(
f"📊 最终结果: 总分={result.overall_score:.3f}, 置信度={result.confidence:.3f}, 匹配标签数={len(result.matched_tags)}"
)
return result
def _calculate_keyword_match_bonus(self, keywords: List[str], matched_tags: List[str]) -> Dict[str, float]:
"""计算关键词直接匹配奖励"""
if not keywords or not matched_tags:
@@ -522,17 +535,25 @@ class BotInterestManager:
# 完全匹配
if keyword_lower == tag_name_lower:
bonus += affinity_config.high_match_interest_threshold * 0.6 # 使用高匹配阈值的60%作为完全匹配奖励
logger.debug(f" 🎯 关键词完全匹配: '{keyword}' == '{tag_name}' (+{affinity_config.high_match_interest_threshold * 0.6:.3f})")
logger.debug(
f" 🎯 关键词完全匹配: '{keyword}' == '{tag_name}' (+{affinity_config.high_match_interest_threshold * 0.6:.3f})"
)
# 包含匹配
elif keyword_lower in tag_name_lower or tag_name_lower in keyword_lower:
bonus += affinity_config.medium_match_interest_threshold * 0.3 # 使用中匹配阈值的30%作为包含匹配奖励
logger.debug(f" 🎯 关键词包含匹配: '{keyword}''{tag_name}' (+{affinity_config.medium_match_interest_threshold * 0.3:.3f})")
bonus += (
affinity_config.medium_match_interest_threshold * 0.3
) # 使用中匹配阈值的30%作为包含匹配奖励
logger.debug(
f" 🎯 关键词包含匹配: '{keyword}''{tag_name}' (+{affinity_config.medium_match_interest_threshold * 0.3:.3f})"
)
# 部分匹配(编辑距离)
elif self._calculate_partial_match(keyword_lower, tag_name_lower):
bonus += affinity_config.low_match_interest_threshold * 0.4 # 使用低匹配阈值的40%作为部分匹配奖励
logger.debug(f" 🎯 关键词部分匹配: '{keyword}''{tag_name}' (+{affinity_config.low_match_interest_threshold * 0.4:.3f})")
logger.debug(
f" 🎯 关键词部分匹配: '{keyword}''{tag_name}' (+{affinity_config.low_match_interest_threshold * 0.4:.3f})"
)
if bonus > 0:
bonus_dict[tag_name] = min(bonus, affinity_config.max_match_bonus) # 使用配置的最大奖励限制
@@ -608,12 +629,12 @@ class BotInterestManager:
with get_db_session() as session:
# 查询最新的兴趣标签配置
db_interests = session.query(DBBotPersonalityInterests).filter(
DBBotPersonalityInterests.personality_id == personality_id
).order_by(
DBBotPersonalityInterests.version.desc(),
DBBotPersonalityInterests.last_updated.desc()
).first()
db_interests = (
session.query(DBBotPersonalityInterests)
.filter(DBBotPersonalityInterests.personality_id == personality_id)
.order_by(DBBotPersonalityInterests.version.desc(), DBBotPersonalityInterests.last_updated.desc())
.first()
)
if db_interests:
logger.info(f"✅ 找到数据库中的兴趣标签配置,版本: {db_interests.version}")
@@ -631,7 +652,7 @@ class BotInterestManager:
personality_description=db_interests.personality_description,
embedding_model=db_interests.embedding_model,
version=db_interests.version,
last_updated=db_interests.last_updated
last_updated=db_interests.last_updated,
)
# 解析兴趣标签
@@ -639,10 +660,14 @@ class BotInterestManager:
tag = BotInterestTag(
tag_name=tag_data.get("tag_name", ""),
weight=tag_data.get("weight", 0.5),
created_at=datetime.fromisoformat(tag_data.get("created_at", datetime.now().isoformat())),
updated_at=datetime.fromisoformat(tag_data.get("updated_at", datetime.now().isoformat())),
created_at=datetime.fromisoformat(
tag_data.get("created_at", datetime.now().isoformat())
),
updated_at=datetime.fromisoformat(
tag_data.get("updated_at", datetime.now().isoformat())
),
is_active=tag_data.get("is_active", True),
embedding=tag_data.get("embedding")
embedding=tag_data.get("embedding"),
)
interests.interest_tags.append(tag)
@@ -685,7 +710,7 @@ class BotInterestManager:
"created_at": tag.created_at.isoformat(),
"updated_at": tag.updated_at.isoformat(),
"is_active": tag.is_active,
"embedding": tag.embedding
"embedding": tag.embedding,
}
tags_data.append(tag_dict)
@@ -694,9 +719,11 @@ class BotInterestManager:
with get_db_session() as session:
# 检查是否已存在相同personality_id的记录
existing_record = session.query(DBBotPersonalityInterests).filter(
DBBotPersonalityInterests.personality_id == interests.personality_id
).first()
existing_record = (
session.query(DBBotPersonalityInterests)
.filter(DBBotPersonalityInterests.personality_id == interests.personality_id)
.first()
)
if existing_record:
# 更新现有记录
@@ -718,7 +745,7 @@ class BotInterestManager:
interest_tags=json_data,
embedding_model=interests.embedding_model,
version=interests.version,
last_updated=interests.last_updated
last_updated=interests.last_updated,
)
session.add(new_record)
session.commit()
@@ -728,9 +755,11 @@ class BotInterestManager:
# 验证保存是否成功
with get_db_session() as session:
saved_record = session.query(DBBotPersonalityInterests).filter(
DBBotPersonalityInterests.personality_id == interests.personality_id
).first()
saved_record = (
session.query(DBBotPersonalityInterests)
.filter(DBBotPersonalityInterests.personality_id == interests.personality_id)
.first()
)
session.commit()
if saved_record:
logger.info(f"✅ 验证成功数据库中存在personality_id为 {interests.personality_id} 的记录")
@@ -760,7 +789,7 @@ class BotInterestManager:
"total_tags": len(active_tags),
"embedding_model": self.current_interests.embedding_model,
"last_updated": self.current_interests.last_updated.isoformat(),
"cache_size": len(self.embedding_cache)
"cache_size": len(self.embedding_cache),
}
async def update_interest_tags(self, new_personality_description: str = None):
@@ -775,8 +804,7 @@ class BotInterestManager:
# 重新生成兴趣标签
new_interests = await self._generate_interests_from_personality(
self.current_interests.personality_description,
self.current_interests.personality_id
self.current_interests.personality_description, self.current_interests.personality_id
)
if new_interests:

View File

@@ -4,13 +4,11 @@
"""
from .message_manager import MessageManager, message_manager
from src.common.data_models.message_manager_data_model import StreamContext, MessageStatus, MessageManagerStats, StreamStats
from src.common.data_models.message_manager_data_model import (
StreamContext,
MessageStatus,
MessageManagerStats,
StreamStats,
)
__all__ = [
"MessageManager",
"message_manager",
"StreamContext",
"MessageStatus",
"MessageManagerStats",
"StreamStats"
]
__all__ = ["MessageManager", "message_manager", "StreamContext", "MessageStatus", "MessageManagerStats", "StreamStats"]

View File

@@ -2,6 +2,7 @@
消息管理模块
管理每个聊天流的上下文信息,包含历史记录和未读消息,定期检查并处理新消息
"""
import asyncio
import time
import traceback
@@ -100,9 +101,7 @@ class MessageManager:
# 如果没有处理任务,创建一个
if not context.processing_task or context.processing_task.done():
context.processing_task = asyncio.create_task(
self._process_stream_messages(stream_id)
)
context.processing_task = asyncio.create_task(self._process_stream_messages(stream_id))
# 更新统计
self.stats.active_streams = active_streams
@@ -175,7 +174,7 @@ class MessageManager:
unread_count=len(context.get_unread_messages()),
history_count=len(context.history_messages),
last_check_time=context.last_check_time,
has_active_task=context.processing_task and not context.processing_task.done()
has_active_task=context.processing_task and not context.processing_task.done(),
)
def get_manager_stats(self) -> Dict[str, Any]:
@@ -186,7 +185,7 @@ class MessageManager:
"total_unread_messages": self.stats.total_unread_messages,
"total_processed_messages": self.stats.total_processed_messages,
"uptime": self.stats.uptime,
"start_time": self.stats.start_time
"start_time": self.stats.start_time,
}
def cleanup_inactive_streams(self, max_inactive_hours: int = 24):
@@ -196,8 +195,7 @@ class MessageManager:
inactive_streams = []
for stream_id, context in self.stream_contexts.items():
if (current_time - context.last_check_time > max_inactive_seconds and
not context.get_unread_messages()):
if current_time - context.last_check_time > max_inactive_seconds and not context.get_unread_messages():
inactive_streams.append(stream_id)
for stream_id in inactive_streams:

View File

@@ -17,6 +17,7 @@ from src.plugin_system.core import component_registry, event_manager, global_ann
from src.plugin_system.base import BaseCommand, EventType
from src.mais4u.mais4u_chat.s4u_msg_processor import S4UMessageProcessor
from src.chat.utils.utils import is_mentioned_bot_in_message
# 导入反注入系统
from src.chat.antipromptinjector import initialize_anti_injector
@@ -511,7 +512,7 @@ class ChatBot:
chat_info_user_id=message.chat_stream.user_info.user_id,
chat_info_user_nickname=message.chat_stream.user_info.user_nickname,
chat_info_user_cardname=message.chat_stream.user_info.user_cardname,
chat_info_user_platform=message.chat_stream.user_info.platform
chat_info_user_platform=message.chat_stream.user_info.platform,
)
# 如果是群聊,添加群组信息

View File

@@ -84,7 +84,9 @@ class ChatStream:
self.saved = False
self.context: ChatMessageContext = None # type: ignore # 用于存储该聊天的上下文信息
# 从配置文件中读取focus_value如果没有则使用默认值1.0
self.focus_energy = data.get("focus_energy", global_config.chat.focus_value) if data else global_config.chat.focus_value
self.focus_energy = (
data.get("focus_energy", global_config.chat.focus_value) if data else global_config.chat.focus_value
)
self.no_reply_consecutive = 0
self.breaking_accumulated_interest = 0.0

View File

@@ -168,7 +168,12 @@ class ActionManager:
if not chat_stream:
logger.error(f"{log_prefix} 无法找到chat_id对应的chat_stream: {chat_id}")
return {"action_type": action_name, "success": False, "reply_text": "", "error": "chat_stream not found"}
return {
"action_type": action_name,
"success": False,
"reply_text": "",
"error": "chat_stream not found",
}
if action_name == "no_action":
return {"action_type": "no_action", "success": True, "reply_text": "", "command": ""}

View File

@@ -1,6 +1,7 @@
"""
PlanGenerator: 负责搜集和汇总所有决策所需的信息,生成一个未经筛选的“原始计划” (Plan)。
"""
import time
from typing import Dict
@@ -35,6 +36,7 @@ class PlanGenerator:
chat_id (str): 当前聊天的 ID。
"""
from src.chat.planner_actions.action_manager import ActionManager
self.chat_id = chat_id
# 注意ActionManager 可能需要根据实际情况初始化
self.action_manager = ActionManager()
@@ -65,7 +67,6 @@ class PlanGenerator:
)
chat_history = [DatabaseMessages(**msg) for msg in chat_history_raw]
plan = Plan(
chat_id=self.chat_id,
mode=mode,
@@ -99,16 +100,13 @@ class PlanGenerator:
name="reply",
component_type=ComponentType.ACTION,
description="系统级动作:选择回复消息的决策",
action_parameters={
"content": "回复的文本内容",
"reply_to_message_id": "要回复的消息ID"
},
action_parameters={"content": "回复的文本内容", "reply_to_message_id": "要回复的消息ID"},
action_require=[
"你想要闲聊或者随便附和",
"当用户提到你或艾特你时",
"当需要回答用户的问题时",
"当你想参与对话时",
"当用户分享有趣的内容时"
"当用户分享有趣的内容时",
],
activation_type=ActionActivationType.ALWAYS,
activation_keywords=[],

View File

@@ -109,9 +109,7 @@ class ActionPlanner:
self.planner_stats["failed_plans"] += 1
return [], None
async def _enhanced_plan_flow(
self, mode: ChatMode, context: StreamContext
) -> Tuple[List[Dict], Optional[Dict]]:
async def _enhanced_plan_flow(self, mode: ChatMode, context: StreamContext) -> Tuple[List[Dict], Optional[Dict]]:
"""执行增强版规划流程"""
try:
# 1. 生成初始 Plan
@@ -137,7 +135,9 @@ class ActionPlanner:
# 检查兴趣度是否达到非回复动作阈值
non_reply_action_interest_threshold = global_config.affinity_flow.non_reply_action_interest_threshold
if score < non_reply_action_interest_threshold:
logger.info(f"❌ 兴趣度不足非回复动作阈值: {score:.3f} < {non_reply_action_interest_threshold:.3f}直接返回no_action")
logger.info(
f"❌ 兴趣度不足非回复动作阈值: {score:.3f} < {non_reply_action_interest_threshold:.3f}直接返回no_action"
)
logger.info(f"📊 最低要求: {non_reply_action_interest_threshold:.3f}")
# 直接返回 no_action
from src.common.data_models.info_data_model import ActionPlannerInfo

View File

@@ -598,6 +598,7 @@ class DefaultReplyer:
def _parse_reply_target(self, target_message: str) -> Tuple[str, str]:
"""解析回复目标消息 - 使用共享工具"""
from src.chat.utils.prompt import Prompt
if target_message is None:
logger.warning("target_message为None返回默认值")
return "未知用户", "(无消息内容)"
@@ -704,21 +705,23 @@ class DefaultReplyer:
unread_history_prompt = ""
if unread_messages:
# 尝试获取兴趣度评分
interest_scores = await self._get_interest_scores_for_messages([msg.flatten() for msg in unread_messages])
interest_scores = await self._get_interest_scores_for_messages(
[msg.flatten() for msg in unread_messages]
)
unread_lines = []
for msg in unread_messages:
msg_id = msg.message_id
msg_time = time.strftime('%H:%M:%S', time.localtime(msg.time))
msg_time = time.strftime("%H:%M:%S", time.localtime(msg.time))
msg_content = msg.processed_plain_text
# 使用与已读历史消息相同的方法获取用户名
from src.person_info.person_info import PersonInfoManager, get_person_info_manager
# 获取用户信息
user_info = getattr(msg, 'user_info', {})
platform = getattr(user_info, 'platform', '') or getattr(msg, 'platform', '')
user_id = getattr(user_info, 'user_id', '') or getattr(msg, 'user_id', '')
user_info = getattr(msg, "user_info", {})
platform = getattr(user_info, "platform", "") or getattr(msg, "platform", "")
user_id = getattr(user_info, "user_id", "") or getattr(msg, "user_id", "")
# 获取用户名
if platform and user_id:
@@ -808,7 +811,7 @@ class DefaultReplyer:
unread_lines = []
for msg in unread_messages:
msg_id = msg.get("message_id", "")
msg_time = time.strftime('%H:%M:%S', time.localtime(msg.get("time", time.time())))
msg_time = time.strftime("%H:%M:%S", time.localtime(msg.get("time", time.time())))
msg_content = msg.get("processed_plain_text", "")
# 使用与已读历史消息相同的方法获取用户名
@@ -834,7 +837,9 @@ class DefaultReplyer:
unread_lines.append(f"{msg_time} {sender_name}: {msg_content}{interest_text}")
unread_history_prompt_str = "\n".join(unread_lines)
unread_history_prompt = f"这是未读历史消息,包含兴趣度评分,请优先对兴趣值高的消息做出动作:\n{unread_history_prompt_str}"
unread_history_prompt = (
f"这是未读历史消息,包含兴趣度评分,请优先对兴趣值高的消息做出动作:\n{unread_history_prompt_str}"
)
else:
unread_history_prompt = "暂无未读历史消息"
@@ -1052,6 +1057,7 @@ class DefaultReplyer:
target_user_info = await person_info_manager.get_person_info_by_name(sender)
from src.chat.utils.prompt import Prompt
# 并行执行六个构建任务
task_results = await asyncio.gather(
self._time_and_run_task(
@@ -1127,6 +1133,7 @@ class DefaultReplyer:
schedule_block = ""
if global_config.planning_system.schedule_enable:
from src.schedule.schedule_manager import schedule_manager
current_activity = schedule_manager.get_current_activity()
if current_activity:
schedule_block = f"你当前正在:{current_activity}"

View File

@@ -235,7 +235,7 @@ class Prompt:
template: str,
name: Optional[str] = None,
parameters: Optional[PromptParameters] = None,
should_register: bool = True
should_register: bool = True,
):
"""
初始化统一提示词
@@ -420,7 +420,8 @@ class Prompt:
await self._build_normal_chat_context(context_data)
# 补充基础信息
context_data.update({
context_data.update(
{
"keywords_reaction_prompt": self.parameters.keywords_reaction_prompt,
"extra_info_block": self.parameters.extra_info_block,
"time_block": self.parameters.time_block or f"当前时间:{time.strftime('%Y-%m-%d %H:%M:%S')}",
@@ -430,7 +431,8 @@ class Prompt:
"reply_target_block": self.parameters.reply_target_block,
"mood_state": self.parameters.mood_prompt,
"action_descriptions": self.parameters.action_descriptions,
})
}
)
total_time = time.time() - start_time
logger.debug(f"上下文构建完成,总耗时: {total_time:.2f}s")
@@ -446,7 +448,7 @@ class Prompt:
self.parameters.message_list_before_now_long,
self.parameters.target_user_info.get("user_id") if self.parameters.target_user_info else "",
self.parameters.sender,
self.parameters.chat_id
self.parameters.chat_id,
)
context_data["read_history_prompt"] = read_history_prompt
@@ -476,8 +478,6 @@ class Prompt:
except Exception as e:
logger.error(f"构建S4U历史消息prompt失败: {e}")
async def _build_expression_habits(self) -> Dict[str, Any]:
"""构建表达习惯"""
if not global_config.expression.enable_expression:
@@ -491,10 +491,7 @@ class Prompt:
if self.parameters.message_list_before_now_long:
recent_messages = self.parameters.message_list_before_now_long[-10:]
chat_history = build_readable_messages(
recent_messages,
replace_bot_name=True,
timestamp_mode="normal",
truncate=True
recent_messages, replace_bot_name=True, timestamp_mode="normal", truncate=True
)
# 创建表情选择器
@@ -505,7 +502,7 @@ class Prompt:
chat_history=chat_history,
current_message=self.parameters.target,
emotional_tone="neutral",
topic_type="general"
topic_type="general",
)
# 构建表达习惯块
@@ -535,17 +532,13 @@ class Prompt:
if self.parameters.message_list_before_now_long:
recent_messages = self.parameters.message_list_before_now_long[-20:]
chat_history = build_readable_messages(
recent_messages,
replace_bot_name=True,
timestamp_mode="normal",
truncate=True
recent_messages, replace_bot_name=True, timestamp_mode="normal", truncate=True
)
# 激活长期记忆
memory_activator = MemoryActivator()
running_memories = await memory_activator.activate_memory_with_chat_history(
target_message=self.parameters.target,
chat_history_prompt=chat_history
target_message=self.parameters.target, chat_history_prompt=chat_history
)
# 获取即时记忆
@@ -593,10 +586,7 @@ class Prompt:
if self.parameters.message_list_before_now_long:
recent_messages = self.parameters.message_list_before_now_long[-15:]
chat_history = build_readable_messages(
recent_messages,
replace_bot_name=True,
timestamp_mode="normal",
truncate=True
recent_messages, replace_bot_name=True, timestamp_mode="normal", truncate=True
)
# 创建工具执行器
@@ -607,7 +597,7 @@ class Prompt:
sender=self.parameters.sender,
target_message=self.parameters.target,
chat_history=chat_history,
return_details=False
return_details=False,
)
# 构建工具信息块
@@ -649,10 +639,7 @@ class Prompt:
# 搜索相关知识
knowledge_results = await qa_manager.get_knowledge(
question=question,
chat_id=self.parameters.chat_id,
max_results=5,
min_similarity=0.5
question=question, chat_id=self.parameters.chat_id, max_results=5, min_similarity=0.5
)
# 构建知识块
@@ -725,9 +712,11 @@ class Prompt:
"time_block": context_data.get("time_block", ""),
"reply_target_block": context_data.get("reply_target_block", ""),
"reply_style": global_config.personality.reply_style,
"keywords_reaction_prompt": self.parameters.keywords_reaction_prompt or context_data.get("keywords_reaction_prompt", ""),
"keywords_reaction_prompt": self.parameters.keywords_reaction_prompt
or context_data.get("keywords_reaction_prompt", ""),
"moderation_prompt": self.parameters.moderation_prompt_block or context_data.get("moderation_prompt", ""),
"safety_guidelines_block": self.parameters.safety_guidelines_block or context_data.get("safety_guidelines_block", ""),
"safety_guidelines_block": self.parameters.safety_guidelines_block
or context_data.get("safety_guidelines_block", ""),
}
def _prepare_normal_params(self, context_data: Dict[str, Any]) -> Dict[str, Any]:
@@ -749,9 +738,11 @@ class Prompt:
"reply_target_block": context_data.get("reply_target_block", ""),
"config_expression_style": global_config.personality.reply_style,
"mood_state": self.parameters.mood_prompt or context_data.get("mood_state", ""),
"keywords_reaction_prompt": self.parameters.keywords_reaction_prompt or context_data.get("keywords_reaction_prompt", ""),
"keywords_reaction_prompt": self.parameters.keywords_reaction_prompt
or context_data.get("keywords_reaction_prompt", ""),
"moderation_prompt": self.parameters.moderation_prompt_block or context_data.get("moderation_prompt", ""),
"safety_guidelines_block": self.parameters.safety_guidelines_block or context_data.get("safety_guidelines_block", ""),
"safety_guidelines_block": self.parameters.safety_guidelines_block
or context_data.get("safety_guidelines_block", ""),
}
def _prepare_default_params(self, context_data: Dict[str, Any]) -> Dict[str, Any]:
@@ -769,9 +760,11 @@ class Prompt:
"reason": "",
"mood_state": self.parameters.mood_prompt or context_data.get("mood_state", ""),
"reply_style": global_config.personality.reply_style,
"keywords_reaction_prompt": self.parameters.keywords_reaction_prompt or context_data.get("keywords_reaction_prompt", ""),
"keywords_reaction_prompt": self.parameters.keywords_reaction_prompt
or context_data.get("keywords_reaction_prompt", ""),
"moderation_prompt": self.parameters.moderation_prompt_block or context_data.get("moderation_prompt", ""),
"safety_guidelines_block": self.parameters.safety_guidelines_block or context_data.get("safety_guidelines_block", ""),
"safety_guidelines_block": self.parameters.safety_guidelines_block
or context_data.get("safety_guidelines_block", ""),
}
def format(self, *args, **kwargs) -> str:
@@ -872,9 +865,7 @@ class Prompt:
return await relationship_fetcher.build_relation_info(person_id, points_num=5)
@staticmethod
async def build_cross_context(
chat_id: str, prompt_mode: str, target_user_info: Optional[Dict[str, Any]]
) -> str:
async def build_cross_context(chat_id: str, prompt_mode: str, target_user_info: Optional[Dict[str, Any]]) -> str:
"""
构建跨群聊上下文 - 统一实现
@@ -937,10 +928,7 @@ class Prompt:
# 工厂函数
def create_prompt(
template: str,
name: Optional[str] = None,
parameters: Optional[PromptParameters] = None,
**kwargs
template: str, name: Optional[str] = None, parameters: Optional[PromptParameters] = None, **kwargs
) -> Prompt:
"""快速创建Prompt实例的工厂函数"""
if parameters is None:
@@ -949,14 +937,10 @@ def create_prompt(
async def create_prompt_async(
template: str,
name: Optional[str] = None,
parameters: Optional[PromptParameters] = None,
**kwargs
template: str, name: Optional[str] = None, parameters: Optional[PromptParameters] = None, **kwargs
) -> Prompt:
"""异步创建Prompt实例"""
prompt = create_prompt(template, name, parameters, **kwargs)
if global_prompt_manager._context._current_context:
await global_prompt_manager._context.register_async(prompt)
return prompt

View File

@@ -6,6 +6,7 @@ class BaseDataModel:
def deepcopy(self):
return copy.deepcopy(self)
def temporarily_transform_class_to_dict(obj: Any) -> Any:
# sourcery skip: assign-if-exp, reintroduce-else
"""

View File

@@ -2,6 +2,7 @@
机器人兴趣标签数据模型
定义机器人的兴趣标签和相关的embedding数据结构
"""
from dataclasses import dataclass, field
from typing import List, Dict, Optional, Any
from datetime import datetime
@@ -12,6 +13,7 @@ from . import BaseDataModel
@dataclass
class BotInterestTag(BaseDataModel):
"""机器人兴趣标签"""
tag_name: str
weight: float = 1.0 # 权重,表示对这个兴趣的喜好程度 (0.0-1.0)
embedding: Optional[List[float]] = None # 标签的embedding向量
@@ -27,7 +29,7 @@ class BotInterestTag(BaseDataModel):
"embedding": self.embedding,
"created_at": self.created_at.isoformat(),
"updated_at": self.updated_at.isoformat(),
"is_active": self.is_active
"is_active": self.is_active,
}
@classmethod
@@ -39,13 +41,14 @@ class BotInterestTag(BaseDataModel):
embedding=data.get("embedding"),
created_at=datetime.fromisoformat(data["created_at"]) if data.get("created_at") else datetime.now(),
updated_at=datetime.fromisoformat(data["updated_at"]) if data.get("updated_at") else datetime.now(),
is_active=data.get("is_active", True)
is_active=data.get("is_active", True),
)
@dataclass
class BotPersonalityInterests(BaseDataModel):
"""机器人人格化兴趣配置"""
personality_id: str
personality_description: str # 人设描述文本
interest_tags: List[BotInterestTag] = field(default_factory=list)
@@ -57,7 +60,6 @@ class BotPersonalityInterests(BaseDataModel):
"""获取活跃的兴趣标签"""
return [tag for tag in self.interest_tags if tag.is_active]
def to_dict(self) -> Dict[str, Any]:
"""转换为字典格式"""
return {
@@ -66,7 +68,7 @@ class BotPersonalityInterests(BaseDataModel):
"interest_tags": [tag.to_dict() for tag in self.interest_tags],
"embedding_model": self.embedding_model,
"last_updated": self.last_updated.isoformat(),
"version": self.version
"version": self.version,
}
@classmethod
@@ -78,13 +80,14 @@ class BotPersonalityInterests(BaseDataModel):
interest_tags=[BotInterestTag.from_dict(tag_data) for tag_data in data.get("interest_tags", [])],
embedding_model=data.get("embedding_model", "text-embedding-ada-002"),
last_updated=datetime.fromisoformat(data["last_updated"]) if data.get("last_updated") else datetime.now(),
version=data.get("version", 1)
version=data.get("version", 1),
)
@dataclass
class InterestMatchResult(BaseDataModel):
"""兴趣匹配结果"""
message_id: str
matched_tags: List[str] = field(default_factory=list)
match_scores: Dict[str, float] = field(default_factory=dict) # tag_name -> score
@@ -120,7 +123,9 @@ class InterestMatchResult(BaseDataModel):
# 计算置信度(基于匹配标签数量和分数分布)
if len(self.match_scores) > 0:
avg_score = self.overall_score
score_variance = sum((score - avg_score) ** 2 for score in self.match_scores.values()) / len(self.match_scores)
score_variance = sum((score - avg_score) ** 2 for score in self.match_scores.values()) / len(
self.match_scores
)
# 分数越集中,置信度越高
self.confidence = max(0.0, 1.0 - score_variance)
else:

View File

@@ -208,6 +208,7 @@ class DatabaseMessages(BaseDataModel):
"chat_info_user_cardname": self.chat_info.user_info.user_cardname,
}
@dataclass(init=False)
class DatabaseActionRecords(BaseDataModel):
def __init__(

View File

@@ -28,6 +28,7 @@ class ActionPlannerInfo(BaseDataModel):
@dataclass
class InterestScore(BaseDataModel):
"""兴趣度评分结果"""
message_id: str
total_score: float
interest_match_score: float
@@ -41,6 +42,7 @@ class Plan(BaseDataModel):
"""
统一规划数据模型
"""
chat_id: str
mode: "ChatMode"

View File

@@ -2,9 +2,11 @@ from dataclasses import dataclass
from typing import Optional, List, Tuple, TYPE_CHECKING, Any
from . import BaseDataModel
if TYPE_CHECKING:
from src.llm_models.payload_content.tool_option import ToolCall
@dataclass
class LLMGenerationDataModel(BaseDataModel):
content: Optional[str] = None

View File

@@ -2,6 +2,7 @@
消息管理模块数据模型
定义消息管理器使用的数据结构
"""
import asyncio
import time
from dataclasses import dataclass, field
@@ -16,6 +17,7 @@ if TYPE_CHECKING:
class MessageStatus(Enum):
"""消息状态枚举"""
UNREAD = "unread" # 未读消息
READ = "read" # 已读消息
PROCESSING = "processing" # 处理中
@@ -24,6 +26,7 @@ class MessageStatus(Enum):
@dataclass
class StreamContext(BaseDataModel):
"""聊天流上下文信息"""
stream_id: str
unread_messages: List["DatabaseMessages"] = field(default_factory=list)
history_messages: List["DatabaseMessages"] = field(default_factory=list)
@@ -59,6 +62,7 @@ class StreamContext(BaseDataModel):
@dataclass
class MessageManagerStats(BaseDataModel):
"""消息管理器统计信息"""
total_streams: int = 0
active_streams: int = 0
total_unread_messages: int = 0
@@ -74,6 +78,7 @@ class MessageManagerStats(BaseDataModel):
@dataclass
class StreamStats(BaseDataModel):
"""聊天流统计信息"""
stream_id: str
is_active: bool
unread_count: int

View File

@@ -31,7 +31,9 @@ class TelemetryHeartBeatTask(AsyncTask):
self.client_uuid: str | None = local_storage["mofox_uuid"] if "mofox_uuid" in local_storage else None # type: ignore
"""客户端UUID"""
self.private_key_pem: str | None = local_storage["mofox_private_key"] if "mofox_private_key" in local_storage else None # type: ignore
self.private_key_pem: str | None = (
local_storage["mofox_private_key"] if "mofox_private_key" in local_storage else None
) # type: ignore
"""客户端私钥"""
self.info_dict = self._get_sys_info()
@@ -75,10 +77,7 @@ class TelemetryHeartBeatTask(AsyncTask):
sign_data = f"{self.client_uuid}:{timestamp}:{json.dumps(request_body, separators=(',', ':'))}"
# 加载私钥
private_key = serialization.load_pem_private_key(
self.private_key_pem.encode('utf-8'),
password=None
)
private_key = serialization.load_pem_private_key(self.private_key_pem.encode("utf-8"), password=None)
# 确保是RSA私钥
if not isinstance(private_key, rsa.RSAPrivateKey):
@@ -86,16 +85,13 @@ class TelemetryHeartBeatTask(AsyncTask):
# 生成签名
signature = private_key.sign(
sign_data.encode('utf-8'),
padding.PSS(
mgf=padding.MGF1(hashes.SHA256()),
salt_length=padding.PSS.MAX_LENGTH
),
hashes.SHA256()
sign_data.encode("utf-8"),
padding.PSS(mgf=padding.MGF1(hashes.SHA256()), salt_length=padding.PSS.MAX_LENGTH),
hashes.SHA256(),
)
# Base64编码
signature_b64 = base64.b64encode(signature).decode('utf-8')
signature_b64 = base64.b64encode(signature).decode("utf-8")
return timestamp, signature_b64
@@ -113,10 +109,7 @@ class TelemetryHeartBeatTask(AsyncTask):
raise ValueError("私钥未初始化")
# 加载私钥
private_key = serialization.load_pem_private_key(
self.private_key_pem.encode('utf-8'),
password=None
)
private_key = serialization.load_pem_private_key(self.private_key_pem.encode("utf-8"), password=None)
# 确保是RSA私钥
if not isinstance(private_key, rsa.RSAPrivateKey):
@@ -125,14 +118,10 @@ class TelemetryHeartBeatTask(AsyncTask):
# 解密挑战数据
decrypted_bytes = private_key.decrypt(
base64.b64decode(challenge_b64),
padding.OAEP(
mgf=padding.MGF1(hashes.SHA256()),
algorithm=hashes.SHA256(),
label=None
)
padding.OAEP(mgf=padding.MGF1(hashes.SHA256()), algorithm=hashes.SHA256(), label=None),
)
return decrypted_bytes.decode('utf-8')
return decrypted_bytes.decode("utf-8")
async def _req_uuid(self) -> bool:
"""
@@ -155,14 +144,12 @@ class TelemetryHeartBeatTask(AsyncTask):
if response.status != 200:
response_text = await response.text()
logger.error(
f"注册步骤1失败状态码: {response.status}, 响应内容: {response_text}"
)
logger.error(f"注册步骤1失败状态码: {response.status}, 响应内容: {response_text}")
raise aiohttp.ClientResponseError(
request_info=response.request_info,
history=response.history,
status=response.status,
message=f"Step1 failed: {response_text}"
message=f"Step1 failed: {response_text}",
)
step1_data = await response.json()
@@ -195,10 +182,7 @@ class TelemetryHeartBeatTask(AsyncTask):
# Step 2: 发送解密结果完成注册
async with session.post(
f"{TELEMETRY_SERVER_URL}/stat/reg_client_step2",
json={
"temp_uuid": temp_uuid,
"decrypted_uuid": decrypted_uuid
},
json={"temp_uuid": temp_uuid, "decrypted_uuid": decrypted_uuid},
timeout=aiohttp.ClientTimeout(total=5),
) as response:
logger.debug(f"Step2 Response status: {response.status}")
@@ -225,23 +209,19 @@ class TelemetryHeartBeatTask(AsyncTask):
raise ValueError(f"Step2失败: {response_text}")
else:
response_text = await response.text()
logger.error(
f"注册步骤2失败状态码: {response.status}, 响应内容: {response_text}"
)
logger.error(f"注册步骤2失败状态码: {response.status}, 响应内容: {response_text}")
raise aiohttp.ClientResponseError(
request_info=response.request_info,
history=response.history,
status=response.status,
message=f"Step2 failed: {response_text}"
message=f"Step2 failed: {response_text}",
)
except Exception as e:
import traceback
error_msg = str(e) or "未知错误"
logger.warning(
f"注册客户端出错,不过你还是可以正常使用墨狐: {type(e).__name__}: {error_msg}"
)
logger.warning(f"注册客户端出错,不过你还是可以正常使用墨狐: {type(e).__name__}: {error_msg}")
logger.debug(f"完整错误信息: {traceback.format_exc()}")
# 请求失败,重试次数+1
@@ -270,7 +250,7 @@ class TelemetryHeartBeatTask(AsyncTask):
"X-mofox-Signature": signature,
"X-mofox-Timestamp": timestamp,
"User-Agent": f"MofoxClient/{self.client_uuid[:8]}",
"Content-Type": "application/json"
"Content-Type": "application/json",
}
logger.debug(f"正在发送心跳到服务器: {self.server_url}")

View File

@@ -99,7 +99,6 @@ def get_global_server() -> Server:
"""获取全局服务器实例"""
global global_server
if global_server is None:
host = os.getenv("HOST", "127.0.0.1")
port_str = os.getenv("PORT", "8000")

View File

@@ -44,7 +44,7 @@ from src.config.official_configs import (
PermissionConfig,
CommandConfig,
PlanningSystemConfig,
AffinityFlowConfig
AffinityFlowConfig,
)
from .api_ada_configs import (
@@ -399,9 +399,7 @@ class Config(ValidatedConfigBase):
cross_context: CrossContextConfig = Field(
default_factory=lambda: CrossContextConfig(), description="跨群聊上下文共享配置"
)
affinity_flow: AffinityFlowConfig = Field(
default_factory=lambda: AffinityFlowConfig(), description="亲和流配置"
)
affinity_flow: AffinityFlowConfig = Field(default_factory=lambda: AffinityFlowConfig(), description="亲和流配置")
class APIAdapterConfig(ValidatedConfigBase):

View File

@@ -51,8 +51,12 @@ class PersonalityConfig(ValidatedConfigBase):
personality_core: str = Field(..., description="核心人格")
personality_side: str = Field(..., description="人格侧写")
identity: str = Field(default="", description="身份特征")
background_story: str = Field(default="", description="世界观背景故事这部分内容会作为背景知识LLM被指导不应主动复述")
safety_guidelines: List[str] = Field(default_factory=list, description="安全与互动底线Bot在任何情况下都必须遵守的原则")
background_story: str = Field(
default="", description="世界观背景故事这部分内容会作为背景知识LLM被指导不应主动复述"
)
safety_guidelines: List[str] = Field(
default_factory=list, description="安全与互动底线Bot在任何情况下都必须遵守的原则"
)
reply_style: str = Field(default="", description="表达风格")
prompt_mode: Literal["s4u", "normal"] = Field(default="s4u", description="Prompt模式")
compress_personality: bool = Field(default=True, description="是否压缩人格")
@@ -79,7 +83,8 @@ class ChatConfig(ValidatedConfigBase):
talk_frequency_adjust: list[list[str]] = Field(default_factory=lambda: [], description="聊天频率调整")
focus_value: float = Field(default=1.0, description="专注值")
focus_mode_quiet_groups: List[str] = Field(
default_factory=list, description='专注模式下需要保持安静的群组列表, 格式: ["platform:group_id1", "platform:group_id2"]'
default_factory=list,
description='专注模式下需要保持安静的群组列表, 格式: ["platform:group_id1", "platform:group_id2"]',
)
force_reply_private: bool = Field(default=False, description="强制回复私聊")
group_chat_mode: Literal["auto", "normal", "focus"] = Field(default="auto", description="群聊模式")
@@ -343,6 +348,7 @@ class ExpressionConfig(ValidatedConfigBase):
# 如果都没有匹配,返回默认值
return True, True, 1.0
class ToolConfig(ValidatedConfigBase):
"""工具配置类"""
@@ -477,7 +483,6 @@ class ExperimentalConfig(ValidatedConfigBase):
pfc_chatting: bool = Field(default=False, description="启用PFC聊天")
class MaimMessageConfig(ValidatedConfigBase):
"""maim_message配置类"""
@@ -602,8 +607,12 @@ class SleepSystemConfig(ValidatedConfigBase):
sleep_by_schedule: bool = Field(default=True, description="是否根据日程表进行睡觉")
fixed_sleep_time: str = Field(default="23:00", description="固定的睡觉时间")
fixed_wake_up_time: str = Field(default="07:00", description="固定的起床时间")
sleep_time_offset_minutes: int = Field(default=15, ge=0, le=60, description="睡觉时间随机偏移量范围(分钟),实际睡觉时间会在±该值范围内随机")
wake_up_time_offset_minutes: int = Field(default=15, ge=0, le=60, description="起床时间随机偏移量范围(分钟),实际起床时间会在±该值范围内随机")
sleep_time_offset_minutes: int = Field(
default=15, ge=0, le=60, description="睡觉时间随机偏移量范围(分钟),实际睡觉时间会在±该值范围内随机"
)
wake_up_time_offset_minutes: int = Field(
default=15, ge=0, le=60, description="起床时间随机偏移量范围(分钟),实际起床时间会在±该值范围内随机"
)
wakeup_threshold: float = Field(default=15.0, ge=1.0, description="唤醒阈值,达到此值时会被唤醒")
private_message_increment: float = Field(default=3.0, ge=0.1, description="私聊消息增加的唤醒度")
group_mention_increment: float = Field(default=2.0, ge=0.1, description="群聊艾特增加的唤醒度")
@@ -657,6 +666,8 @@ class CrossContextConfig(ValidatedConfigBase):
enable: bool = Field(default=False, description="是否启用跨群聊上下文共享功能")
groups: List[ContextGroup] = Field(default_factory=list, description="上下文共享组列表")
class CommandConfig(ValidatedConfigBase):
"""命令系统配置类"""

View File

@@ -88,8 +88,7 @@ class Individuality:
# 初始化智能兴趣系统
await interest_scoring_system.initialize_smart_interests(
personality_description=full_personality,
personality_id=self.bot_person_id
personality_description=full_personality, personality_id=self.bot_person_id
)
logger.info("智能兴趣系统初始化完成")

View File

@@ -130,6 +130,7 @@ class MainSystem:
# 停止消息重组器
from src.plugin_system.core.event_manager import event_manager
from src.plugin_system import EventType
asyncio.run(event_manager.trigger_event(EventType.ON_STOP, permission_group="SYSTEM"))
from src.utils.message_chunker import reassembler
@@ -250,6 +251,7 @@ MoFox_Bot(第三方修改版)
# 初始化回复后关系追踪系统
from src.chat.affinity_flow.relationship_integration import initialize_relationship_tracking
relationship_tracker = initialize_relationship_tracking()
if relationship_tracker:
logger.info("回复后关系追踪系统初始化成功")
@@ -273,6 +275,7 @@ MoFox_Bot(第三方修改版)
# 初始化LPMM知识库
from src.chat.knowledge.knowledge_lib import initialize_lpmm_knowledge
initialize_lpmm_knowledge()
logger.info("LPMM知识库初始化成功")
@@ -298,6 +301,7 @@ MoFox_Bot(第三方修改版)
# 启动消息管理器
from src.chat.message_manager import message_manager
await message_manager.start()
logger.info("消息管理器已启动")

View File

@@ -102,6 +102,7 @@ class PersonInfoManager:
return hashlib.md5(key.encode()).hexdigest()
qq_id = hashlib.md5(key.encode()).hexdigest()
# 对于 qq 平台,先检查该 person_id 是否已存在;如果存在直接返回
def _db_check_and_migrate_sync(p_id: str, raw_user_id: str):
try:

View File

@@ -123,7 +123,9 @@ class RelationshipFetcher:
all_points = current_points + forgotten_points
if all_points:
# 按权重和时效性综合排序
all_points.sort(key=lambda x: (float(x[1]) if len(x) > 1 else 0, float(x[2]) if len(x) > 2 else 0), reverse=True)
all_points.sort(
key=lambda x: (float(x[1]) if len(x) > 1 else 0, float(x[2]) if len(x) > 2 else 0), reverse=True
)
selected_points = all_points[:points_num]
points_text = "\n".join([f"- {point[0]}{point[2]}" for point in selected_points if len(point) > 2])
else:
@@ -139,7 +141,8 @@ class RelationshipFetcher:
# 2. 认识时间和频率
if know_since:
from datetime import datetime
know_time = datetime.fromtimestamp(know_since).strftime('%Y年%m月%d')
know_time = datetime.fromtimestamp(know_since).strftime("%Y年%m月%d")
relation_parts.append(f"你从{know_time}开始认识{person_name}")
if know_times > 0:
@@ -147,7 +150,8 @@ class RelationshipFetcher:
if last_know:
from datetime import datetime
last_time = datetime.fromtimestamp(last_know).strftime('%m月%d')
last_time = datetime.fromtimestamp(last_know).strftime("%m月%d")
relation_parts.append(f"最近一次交流是在{last_time}")
# 3. 态度和印象
@@ -173,7 +177,7 @@ class RelationshipFetcher:
relationships = await db_query(
UserRelationships,
filters=[UserRelationships.user_id == str(person_info_manager.get_value_sync(person_id, "user_id"))],
limit=1
limit=1,
)
if relationships:
@@ -189,7 +193,9 @@ class RelationshipFetcher:
# 构建最终的关系信息字符串
if relation_parts:
relation_info = f"关于{person_name},你知道以下信息:\n" + "\n".join([f"{part}" for part in relation_parts])
relation_info = f"关于{person_name},你知道以下信息:\n" + "\n".join(
[f"{part}" for part in relation_parts]
)
else:
relation_info = f"你对{person_name}了解不多,这是比较初步的交流。"

View File

@@ -93,7 +93,6 @@ class BaseAction(ABC):
self.associated_types: list[str] = getattr(self.__class__, "associated_types", []).copy()
self.chat_type_allow: ChatType = getattr(self.__class__, "chat_type_allow", ChatType.ALL)
# =============================================================================
# 便捷属性 - 直接在初始化时获取常用聊天信息(带类型注解)
# =============================================================================
@@ -398,6 +397,7 @@ class BaseAction(ABC):
try:
# 1. 从注册中心获取Action类
from src.plugin_system.core.component_registry import component_registry
action_class = component_registry.get_component_class(action_name, ComponentType.ACTION)
if not action_class:
logger.error(f"{log_prefix} 未找到Action: {action_name}")

View File

@@ -270,7 +270,9 @@ class ComponentRegistry:
# 使用EventManager进行事件处理器注册
from src.plugin_system.core.event_manager import event_manager
return event_manager.register_event_handler(handler_class,self.get_plugin_config(handler_info.plugin_name) or {})
return event_manager.register_event_handler(
handler_class, self.get_plugin_config(handler_info.plugin_name) or {}
)
# === 组件移除相关 ===
@@ -686,9 +688,10 @@ class ComponentRegistry:
# 如果插件实例不存在,尝试从配置文件读取
try:
import toml
config_path = Path("config") / "plugins" / plugin_name / "config.toml"
if config_path.exists():
with open(config_path, 'r', encoding='utf-8') as f:
with open(config_path, "r", encoding="utf-8") as f:
config_data = toml.load(f)
logger.debug(f"从配置文件读取插件 {plugin_name} 的配置")
return config_data

View File

@@ -145,7 +145,9 @@ class EventManager:
logger.info(f"事件 {event_name} 已禁用")
return True
def register_event_handler(self, handler_class: Type[BaseEventHandler], plugin_config: Optional[dict] = None) -> bool:
def register_event_handler(
self, handler_class: Type[BaseEventHandler], plugin_config: Optional[dict] = None
) -> bool:
"""注册事件处理器
Args:
@@ -167,7 +169,7 @@ class EventManager:
# 创建事件处理器实例,传递插件配置
handler_instance = handler_class()
handler_instance.plugin_config = plugin_config
if plugin_config is not None and hasattr(handler_instance, 'set_plugin_config'):
if plugin_config is not None and hasattr(handler_instance, "set_plugin_config"):
handler_instance.set_plugin_config(plugin_config)
self._event_handlers[handler_name] = handler_instance

View File

@@ -199,9 +199,7 @@ class PluginManager:
self._show_plugin_components(plugin_name)
# 检查并调用 on_plugin_loaded 钩子(如果存在)
if hasattr(plugin_instance, "on_plugin_loaded") and callable(
plugin_instance.on_plugin_loaded
):
if hasattr(plugin_instance, "on_plugin_loaded") and callable(plugin_instance.on_plugin_loaded):
logger.debug(f"为插件 '{plugin_name}' 调用 on_plugin_loaded 钩子")
try:
# 使用 asyncio.create_task 确保它不会阻塞加载流程

View File

@@ -85,7 +85,7 @@ class AtAction(BaseAction):
reply_to=reply_to,
extra_info=extra_info,
enable_tool=False, # 艾特回复通常不需要工具调用
from_plugin=False
from_plugin=False,
)
if success and llm_response:

View File

@@ -70,7 +70,9 @@ class EmojiAction(BaseAction):
# 2. 获取所有有效的表情包对象
emoji_manager = get_emoji_manager()
all_emojis_obj: list[MaiEmoji] = [e for e in emoji_manager.emoji_objects if not e.is_deleted and e.description]
all_emojis_obj: list[MaiEmoji] = [
e for e in emoji_manager.emoji_objects if not e.is_deleted and e.description
]
if not all_emojis_obj:
logger.warning(f"{self.log_prefix} 无法获取任何带有描述的有效表情包")
return False, "无法获取任何带有描述的有效表情包"
@@ -171,7 +173,9 @@ class EmojiAction(BaseAction):
if matched_key:
emoji_base64, emoji_description = random.choice(emotion_map[matched_key])
logger.info(f"{self.log_prefix} 找到匹配情感 '{chosen_emotion}' (匹配到: '{matched_key}') 的表情包: {emoji_description}")
logger.info(
f"{self.log_prefix} 找到匹配情感 '{chosen_emotion}' (匹配到: '{matched_key}') 的表情包: {emoji_description}"
)
else:
logger.warning(
f"{self.log_prefix} LLM选择的情感 '{chosen_emotion}' 不在可用列表中, 将随机选择一个表情包"
@@ -226,15 +230,23 @@ class EmojiAction(BaseAction):
logger.info(f"{self.log_prefix} LLM选择的描述: {chosen_description}")
# 简单关键词匹配
matched_emoji = next((item for item in all_emojis_data if chosen_description.lower() in item[1].lower() or item[1].lower() in chosen_description.lower()), None)
matched_emoji = next(
(
item
for item in all_emojis_data
if chosen_description.lower() in item[1].lower()
or item[1].lower() in chosen_description.lower()
),
None,
)
# 如果包含匹配失败,尝试关键词匹配
if not matched_emoji:
keywords = ['惊讶', '困惑', '呆滞', '震惊', '', '无语', '', '可爱']
keywords = ["惊讶", "困惑", "呆滞", "震惊", "", "无语", "", "可爱"]
for keyword in keywords:
if keyword in chosen_description:
for item in all_emojis_data:
if any(k in item[1] for k in ['', '', '', '困惑', '无语']):
if any(k in item[1] for k in ["", "", "", "困惑", "无语"]):
matched_emoji = item
break
if matched_emoji:
@@ -255,7 +267,9 @@ class EmojiAction(BaseAction):
if not success:
logger.error(f"{self.log_prefix} 表情包发送失败")
await self.store_action_info(action_build_into_prompt = True,action_prompt_display =f"发送了一个表情包,但失败了",action_done= False)
await self.store_action_info(
action_build_into_prompt=True, action_prompt_display=f"发送了一个表情包,但失败了", action_done=False
)
return False, "表情包发送失败"
# 发送成功后,记录到历史
@@ -264,7 +278,9 @@ class EmojiAction(BaseAction):
except Exception as e:
logger.error(f"{self.log_prefix} 添加表情到历史记录时出错: {e}")
await self.store_action_info(action_build_into_prompt = True,action_prompt_display =f"发送了一个表情包",action_done= True)
await self.store_action_info(
action_build_into_prompt=True, action_prompt_display=f"发送了一个表情包", action_done=True
)
return True, f"发送表情包: {emoji_description}"

View File

@@ -1,4 +1,3 @@
from src.plugin_system import BaseEventHandler
from src.plugin_system.base.base_event import HandlerResult
@@ -1748,6 +1747,7 @@ class SetGroupSignHandler(BaseEventHandler):
logger.error("事件 napcat_set_group_sign 请求失败!")
return HandlerResult(False, False, {"status": "error"})
# ===PERSONAL===
class SetInputStatusHandler(BaseEventHandler):
handler_name: str = "napcat_set_input_status_handler"

View File

@@ -285,7 +285,7 @@ class NapcatAdapterPlugin(BasePlugin):
def enable_plugin(self) -> bool:
"""通过配置文件动态控制插件启用状态"""
# 如果已经通过配置加载了状态,使用配置中的值
if hasattr(self, '_is_enabled'):
if hasattr(self, "_is_enabled"):
return self._is_enabled
# 否则使用默认值(禁用状态)
return False
@@ -308,60 +308,107 @@ class NapcatAdapterPlugin(BasePlugin):
"nickname": ConfigField(type=str, default="", description="昵称配置(目前未使用)"),
},
"napcat_server": {
"mode": ConfigField(type=str, default="reverse", description="连接模式reverse=反向连接(作为服务器), forward=正向连接(作为客户端)", choices=["reverse", "forward"]),
"mode": ConfigField(
type=str,
default="reverse",
description="连接模式reverse=反向连接(作为服务器), forward=正向连接(作为客户端)",
choices=["reverse", "forward"],
),
"host": ConfigField(type=str, default="localhost", description="主机地址"),
"port": ConfigField(type=int, default=8095, description="端口号"),
"url": ConfigField(type=str, default="", description="正向连接时的完整WebSocket URL如 ws://localhost:8080/ws (仅在forward模式下使用)"),
"access_token": ConfigField(type=str, default="", description="WebSocket 连接的访问令牌,用于身份验证(可选)"),
"url": ConfigField(
type=str,
default="",
description="正向连接时的完整WebSocket URL如 ws://localhost:8080/ws (仅在forward模式下使用)",
),
"access_token": ConfigField(
type=str, default="", description="WebSocket 连接的访问令牌,用于身份验证(可选)"
),
"heartbeat_interval": ConfigField(type=int, default=30, description="心跳间隔时间(按秒计)"),
},
"maibot_server": {
"host": ConfigField(type=str, default="localhost", description="麦麦在.env文件中设置的主机地址即HOST字段"),
"host": ConfigField(
type=str, default="localhost", description="麦麦在.env文件中设置的主机地址即HOST字段"
),
"port": ConfigField(type=int, default=8000, description="麦麦在.env文件中设置的端口即PORT字段"),
"platform_name": ConfigField(type=str, default="qq", description="平台名称,用于消息路由"),
},
"voice": {
"use_tts": ConfigField(type=bool, default=False, description="是否使用tts语音请确保你配置了tts并有对应的adapter"),
"use_tts": ConfigField(
type=bool, default=False, description="是否使用tts语音请确保你配置了tts并有对应的adapter"
),
},
"slicing": {
"max_frame_size": ConfigField(type=int, default=64, description="WebSocket帧的最大大小单位为字节默认64KB"),
"max_frame_size": ConfigField(
type=int, default=64, description="WebSocket帧的最大大小单位为字节默认64KB"
),
"delay_ms": ConfigField(type=int, default=10, description="切片发送间隔时间,单位为毫秒"),
},
"debug": {
"level": ConfigField(type=str, default="INFO", description="日志等级DEBUG, INFO, WARNING, ERROR, CRITICAL", choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]),
"level": ConfigField(
type=str,
default="INFO",
description="日志等级DEBUG, INFO, WARNING, ERROR, CRITICAL",
choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
),
},
"features": {
# 权限设置
"group_list_type": ConfigField(type=str, default="blacklist", description="群聊列表类型whitelist白名单或 blacklist黑名单", choices=["whitelist", "blacklist"]),
"group_list_type": ConfigField(
type=str,
default="blacklist",
description="群聊列表类型whitelist白名单或 blacklist黑名单",
choices=["whitelist", "blacklist"],
),
"group_list": ConfigField(type=list, default=[], description="群聊ID列表"),
"private_list_type": ConfigField(type=str, default="blacklist", description="私聊列表类型whitelist白名单或 blacklist黑名单", choices=["whitelist", "blacklist"]),
"private_list_type": ConfigField(
type=str,
default="blacklist",
description="私聊列表类型whitelist白名单或 blacklist黑名单",
choices=["whitelist", "blacklist"],
),
"private_list": ConfigField(type=list, default=[], description="用户ID列表"),
"ban_user_id": ConfigField(type=list, default=[], description="全局禁止用户ID列表这些用户无法在任何地方使用机器人"),
"ban_user_id": ConfigField(
type=list, default=[], description="全局禁止用户ID列表这些用户无法在任何地方使用机器人"
),
"ban_qq_bot": ConfigField(type=bool, default=False, description="是否屏蔽QQ官方机器人消息"),
# 聊天功能设置
"enable_poke": ConfigField(type=bool, default=True, description="是否启用戳一戳功能"),
"ignore_non_self_poke": ConfigField(type=bool, default=False, description="是否无视不是针对自己的戳一戳"),
"poke_debounce_seconds": ConfigField(type=int, default=3, description="戳一戳防抖时间(秒),在指定时间内第二次针对机器人的戳一戳将被忽略"),
"poke_debounce_seconds": ConfigField(
type=int, default=3, description="戳一戳防抖时间(秒),在指定时间内第二次针对机器人的戳一戳将被忽略"
),
"enable_reply_at": ConfigField(type=bool, default=True, description="是否启用引用回复时艾特用户的功能"),
"reply_at_rate": ConfigField(type=float, default=0.5, description="引用回复时艾特用户的几率 (0.0 ~ 1.0)"),
"enable_emoji_like": ConfigField(type=bool, default=True, description="是否启用群聊表情回复功能"),
# 视频处理设置
"enable_video_analysis": ConfigField(type=bool, default=True, description="是否启用视频识别功能"),
"max_video_size_mb": ConfigField(type=int, default=100, description="视频文件最大大小限制MB"),
"download_timeout": ConfigField(type=int, default=60, description="视频下载超时时间(秒)"),
"supported_formats": ConfigField(type=list, default=["mp4", "avi", "mov", "mkv", "flv", "wmv", "webm"], description="支持的视频格式"),
"supported_formats": ConfigField(
type=list, default=["mp4", "avi", "mov", "mkv", "flv", "wmv", "webm"], description="支持的视频格式"
),
# 消息缓冲设置
"enable_message_buffer": ConfigField(type=bool, default=True, description="是否启用消息缓冲合并功能"),
"message_buffer_enable_group": ConfigField(type=bool, default=True, description="是否启用群聊消息缓冲合并"),
"message_buffer_enable_private": ConfigField(type=bool, default=True, description="是否启用私聊消息缓冲合并"),
"message_buffer_interval": ConfigField(type=float, default=3.0, description="消息合并间隔时间(秒),在此时间内的连续消息将被合并"),
"message_buffer_initial_delay": ConfigField(type=float, default=0.5, description="消息缓冲初始延迟(秒),收到第一条消息后等待此时间开始合并"),
"message_buffer_max_components": ConfigField(type=int, default=50, description="单个会话最大缓冲消息组件数量,超过此数量将强制合并"),
"message_buffer_block_prefixes": ConfigField(type=list, default=["/", "!", "", ".", "", "#", "%"], description="消息缓冲屏蔽前缀,以这些前缀开头的消息不会被缓冲"),
}
"message_buffer_enable_private": ConfigField(
type=bool, default=True, description="是否启用私聊消息缓冲合并"
),
"message_buffer_interval": ConfigField(
type=float, default=3.0, description="消息合并间隔时间(秒),在此时间内的连续消息将被合并"
),
"message_buffer_initial_delay": ConfigField(
type=float, default=0.5, description="消息缓冲初始延迟(秒),收到第一条消息后等待此时间开始合并"
),
"message_buffer_max_components": ConfigField(
type=int, default=50, description="单个会话最大缓冲消息组件数量,超过此数量将强制合并"
),
"message_buffer_block_prefixes": ConfigField(
type=list,
default=["/", "!", "", ".", "", "#", "%"],
description="消息缓冲屏蔽前缀,以这些前缀开头的消息不会被缓冲",
),
},
}
# 配置节描述
@@ -374,7 +421,7 @@ class NapcatAdapterPlugin(BasePlugin):
"voice": "发送语音设置",
"slicing": "WebSocket消息切片设置",
"debug": "调试设置",
"features": "功能设置(权限控制、聊天功能、视频处理、消息缓冲等)"
"features": "功能设置(权限控制、聊天功能、视频处理、消息缓冲等)",
}
def register_events(self):
@@ -409,6 +456,7 @@ class NapcatAdapterPlugin(BasePlugin):
chunker.set_plugin_config(self.config)
# 设置response_pool的插件配置
from .src.response_pool import set_plugin_config as set_response_pool_config
set_response_pool_config(self.config)
# 设置send_handler的插件配置
send_handler.set_plugin_config(self.config)

View File

@@ -102,7 +102,9 @@ class SimpleMessageBuffer:
return True
# 检查屏蔽前缀
block_prefixes = tuple(config_api.get_plugin_config(self.plugin_config, "features.message_buffer_block_prefixes", []))
block_prefixes = tuple(
config_api.get_plugin_config(self.plugin_config, "features.message_buffer_block_prefixes", [])
)
text = text.strip()
if text.startswith(block_prefixes):
@@ -134,9 +136,13 @@ class SimpleMessageBuffer:
# 检查是否启用对应类型的缓冲
message_type = event_data.get("message_type", "")
if message_type == "group" and not config_api.get_plugin_config(self.plugin_config, "features.message_buffer_enable_group", False):
if message_type == "group" and not config_api.get_plugin_config(
self.plugin_config, "features.message_buffer_enable_group", False
):
return False
elif message_type == "private" and not config_api.get_plugin_config(self.plugin_config, "features.message_buffer_enable_private", False):
elif message_type == "private" and not config_api.get_plugin_config(
self.plugin_config, "features.message_buffer_enable_private", False
):
return False
# 提取文本
@@ -158,7 +164,9 @@ class SimpleMessageBuffer:
session = self.buffer_pool[session_id]
# 检查是否超过最大组件数量
if len(session.messages) >= config_api.get_plugin_config(self.plugin_config, "features.message_buffer_max_components", 5):
if len(session.messages) >= config_api.get_plugin_config(
self.plugin_config, "features.message_buffer_max_components", 5
):
logger.debug(f"会话 {session_id} 消息数量达到上限,强制合并")
asyncio.create_task(self._force_merge_session(session_id))
self.buffer_pool[session_id] = BufferedSession(session_id=session_id, original_event=original_event)

View File

@@ -111,7 +111,9 @@ class MessageHandler:
return False
else:
# 检查私聊黑白名单
private_list_type = config_api.get_plugin_config(self.plugin_config, "features.private_list_type", "blacklist")
private_list_type = config_api.get_plugin_config(
self.plugin_config, "features.private_list_type", "blacklist"
)
private_list = config_api.get_plugin_config(self.plugin_config, "features.private_list", [])
if private_list_type == "whitelist":
@@ -158,17 +160,19 @@ class MessageHandler:
"""
# 添加原始消息调试日志特别关注message字段
logger.debug(f"收到原始消息: message_type={raw_message.get('message_type')}, message_id={raw_message.get('message_id')}")
logger.debug(
f"收到原始消息: message_type={raw_message.get('message_type')}, message_id={raw_message.get('message_id')}"
)
logger.debug(f"原始消息内容: {raw_message.get('message', [])}")
# 检查是否包含@或video消息段
message_segments = raw_message.get('message', [])
message_segments = raw_message.get("message", [])
if message_segments:
for i, seg in enumerate(message_segments):
seg_type = seg.get('type')
if seg_type in ['at', 'video']:
seg_type = seg.get("type")
if seg_type in ["at", "video"]:
logger.info(f"检测到 {seg_type.upper()} 消息段 [{i}]: {seg}")
elif seg_type not in ['text', 'face', 'image']:
elif seg_type not in ["text", "face", "image"]:
logger.warning(f"检测到特殊消息段 [{i}]: type={seg_type}, data={seg.get('data', {})}")
message_type: str = raw_message.get("message_type")
@@ -308,9 +312,13 @@ class MessageHandler:
message_type = raw_message.get("message_type")
should_use_buffer = False
if message_type == "group" and config_api.get_plugin_config(self.plugin_config, "features.message_buffer_enable_group", True):
if message_type == "group" and config_api.get_plugin_config(
self.plugin_config, "features.message_buffer_enable_group", True
):
should_use_buffer = True
elif message_type == "private" and config_api.get_plugin_config(self.plugin_config, "features.message_buffer_enable_private", True):
elif message_type == "private" and config_api.get_plugin_config(
self.plugin_config, "features.message_buffer_enable_private", True
):
should_use_buffer = True
if should_use_buffer:

View File

@@ -33,6 +33,7 @@ class MessageSending:
try:
# 重新导入router
from ..mmc_com_layer import router
self.maibot_router = router
if self.maibot_router is not None:
logger.info("MaiBot router重连成功")
@@ -75,7 +76,7 @@ class MessageSending:
platform = message_base.message_info.platform
# 再次检查router状态防止运行时被重置
if self.maibot_router is None or not hasattr(self.maibot_router, 'clients'):
if self.maibot_router is None or not hasattr(self.maibot_router, "clients"):
logger.warning("MaiBot router连接已断开尝试重新连接")
if not await self._attempt_reconnect():
logger.error("MaiBot router重连失败切片发送中止")

View File

@@ -22,7 +22,9 @@ class MetaEventHandler:
"""设置插件配置"""
self.plugin_config = plugin_config
# 更新interval值
self.interval = config_api.get_plugin_config(self.plugin_config, "napcat_server.heartbeat_interval", 5000) / 1000
self.interval = (
config_api.get_plugin_config(self.plugin_config, "napcat_server.heartbeat_interval", 5000) / 1000
)
async def handle_meta_event(self, message: dict) -> None:
event_type = message.get("meta_event_type")

View File

@@ -116,9 +116,9 @@ class NoticeHandler:
sub_type = raw_message.get("sub_type")
match sub_type:
case NoticeType.Notify.poke:
if config_api.get_plugin_config(self.plugin_config, "features.enable_poke", True) and await message_handler.check_allow_to_chat(
user_id, group_id, False, False
):
if config_api.get_plugin_config(
self.plugin_config, "features.enable_poke", True
) and await message_handler.check_allow_to_chat(user_id, group_id, False, False):
logger.debug("处理戳一戳消息")
handled_message, user_info = await self.handle_poke_notify(raw_message, group_id, user_id)
else:
@@ -127,14 +127,18 @@ class NoticeHandler:
from src.plugin_system.core.event_manager import event_manager
from ...event_types import NapcatEvent
await event_manager.trigger_event(NapcatEvent.ON_RECEIVED.FRIEND_INPUT, permission_group=PLUGIN_NAME)
await event_manager.trigger_event(
NapcatEvent.ON_RECEIVED.FRIEND_INPUT, permission_group=PLUGIN_NAME
)
case _:
logger.warning(f"不支持的notify类型: {notice_type}.{sub_type}")
case NoticeType.group_msg_emoji_like:
# 该事件转移到 handle_group_emoji_like_notify函数内触发
if config_api.get_plugin_config(self.plugin_config, "features.enable_emoji_like", True):
logger.debug("处理群聊表情回复")
handled_message, user_info = await self.handle_group_emoji_like_notify(raw_message,group_id,user_id)
handled_message, user_info = await self.handle_group_emoji_like_notify(
raw_message, group_id, user_id
)
else:
logger.warning("群聊表情回复被禁用,取消群聊表情回复处理")
case NoticeType.group_ban:
@@ -308,7 +312,9 @@ class NoticeHandler:
from src.plugin_system.core.event_manager import event_manager
from ...event_types import NapcatEvent
target_message = await event_manager.trigger_event(NapcatEvent.MESSAGE.GET_MSG,message_id=raw_message.get("message_id",""))
target_message = await event_manager.trigger_event(
NapcatEvent.MESSAGE.GET_MSG, message_id=raw_message.get("message_id", "")
)
target_message_text = target_message.get_message_result().get("data", {}).get("raw_message", "")
if not target_message:
logger.error("未找到对应消息")
@@ -330,9 +336,12 @@ class NoticeHandler:
group_id=group_id,
user_id=user_id,
message_id=raw_message.get("message_id", ""),
emoji_id=like_emoji_id
emoji_id=like_emoji_id,
)
seg_data = Seg(
type="text",
data=f"{user_name}使用Emoji表情{QQ_FACE.get(like_emoji_id, '')}回复了你的消息[{target_message_text}]",
)
seg_data = Seg(type="text",data=f"{user_name}使用Emoji表情{QQ_FACE.get(like_emoji_id,"")}回复了你的消息[{target_message_text}]")
return seg_data, user_info
async def handle_ban_notify(self, raw_message: dict, group_id: int) -> Tuple[Seg, UserInfo] | Tuple[None, None]:

View File

@@ -297,9 +297,9 @@ class SendHandler:
try:
# 检查是否为缓冲消息ID格式buffered-{original_id}-{timestamp}
if id.startswith('buffered-'):
if id.startswith("buffered-"):
# 从缓冲消息ID中提取原始消息ID
original_id = id.split('-')[1]
original_id = id.split("-")[1]
msg_info_response = await self.send_message_to_napcat("get_msg", {"message_id": int(original_id)})
else:
msg_info_response = await self.send_message_to_napcat("get_msg", {"message_id": int(id)})

View File

@@ -18,7 +18,9 @@ class WebSocketManager:
self.max_reconnect_attempts = 10 # 最大重连次数
self.plugin_config = None
async def start_connection(self, message_handler: Callable[[Server.ServerConnection], Any], plugin_config: dict) -> None:
async def start_connection(
self, message_handler: Callable[[Server.ServerConnection], Any], plugin_config: dict
) -> None:
"""根据配置启动 WebSocket 连接"""
self.plugin_config = plugin_config
mode = config_api.get_plugin_config(plugin_config, "napcat_server.mode")
@@ -72,9 +74,7 @@ class WebSocketManager:
# 如果配置了访问令牌,添加到请求头
access_token = config_api.get_plugin_config(self.plugin_config, "napcat_server.access_token")
if access_token:
connect_kwargs["additional_headers"] = {
"Authorization": f"Bearer {access_token}"
}
connect_kwargs["additional_headers"] = {"Authorization": f"Bearer {access_token}"}
logger.info("已添加访问令牌到连接请求头")
async with Server.connect(url, **connect_kwargs) as websocket:

View File

@@ -1,6 +1,7 @@
"""
Base search engine interface
"""
from abc import ABC, abstractmethod
from typing import Dict, List, Any

View File

@@ -1,6 +1,7 @@
"""
Bing search engine implementation
"""
import asyncio
import functools
import random
@@ -202,12 +203,7 @@ class BingSearchEngine(BaseSearchEngine):
if ABSTRACT_MAX_LENGTH and len(abstract) > ABSTRACT_MAX_LENGTH:
abstract = abstract[:ABSTRACT_MAX_LENGTH] + "..."
list_data.append({
"title": title,
"url": url,
"snippet": abstract,
"provider": "Bing"
})
list_data.append({"title": title, "url": url, "snippet": abstract, "provider": "Bing"})
if len(list_data) >= 10: # 限制结果数量
break
@@ -222,16 +218,28 @@ class BingSearchEngine(BaseSearchEngine):
text = link.get_text().strip()
# 过滤有效的搜索结果链接
if (href and text and len(text) > 10
if (
href
and text
and len(text) > 10
and not href.startswith("javascript:")
and not href.startswith("#")
and "http" in href
and not any(x in href for x in [
"bing.com/search", "bing.com/images", "bing.com/videos",
"bing.com/maps", "bing.com/news", "login", "account",
"microsoft", "javascript"
])):
and not any(
x in href
for x in [
"bing.com/search",
"bing.com/images",
"bing.com/videos",
"bing.com/maps",
"bing.com/news",
"login",
"account",
"microsoft",
"javascript",
]
)
):
# 尝试获取摘要
abstract = ""
parent = link.parent
@@ -244,12 +252,7 @@ class BingSearchEngine(BaseSearchEngine):
if ABSTRACT_MAX_LENGTH and len(abstract) > ABSTRACT_MAX_LENGTH:
abstract = abstract[:ABSTRACT_MAX_LENGTH] + "..."
list_data.append({
"title": text,
"url": href,
"snippet": abstract,
"provider": "Bing"
})
list_data.append({"title": text, "url": href, "snippet": abstract, "provider": "Bing"})
if len(list_data) >= 10:
break

View File

@@ -1,6 +1,7 @@
"""
DuckDuckGo search engine implementation
"""
from typing import Dict, List, Any
from asyncddgs import aDDGS
@@ -29,12 +30,7 @@ class DDGSearchEngine(BaseSearchEngine):
search_response = await ddgs.text(query, max_results=num_results)
return [
{
"title": r.get("title"),
"url": r.get("href"),
"snippet": r.get("body"),
"provider": "DuckDuckGo"
}
{"title": r.get("title"), "url": r.get("href"), "snippet": r.get("body"), "provider": "DuckDuckGo"}
for r in search_response
]
except Exception as e:

View File

@@ -1,6 +1,7 @@
"""
Exa search engine implementation
"""
import asyncio
import functools
from datetime import datetime, timedelta
@@ -29,11 +30,7 @@ class ExaSearchEngine(BaseSearchEngine):
exa_api_keys = config_api.get_global_config("web_search.exa_api_keys", None)
# 创建API密钥管理器
self.api_manager = create_api_key_manager_from_config(
exa_api_keys,
lambda key: Exa(api_key=key),
"Exa"
)
self.api_manager = create_api_key_manager_from_config(exa_api_keys, lambda key: Exa(api_key=key), "Exa")
def is_available(self) -> bool:
"""检查Exa搜索引擎是否可用"""
@@ -52,7 +49,7 @@ class ExaSearchEngine(BaseSearchEngine):
if time_range != "any":
today = datetime.now()
start_date = today - timedelta(days=7 if time_range == "week" else 30)
exa_args["start_published_date"] = start_date.strftime('%Y-%m-%d')
exa_args["start_published_date"] = start_date.strftime("%Y-%m-%d")
try:
# 使用API密钥管理器获取下一个客户端
@@ -69,8 +66,8 @@ class ExaSearchEngine(BaseSearchEngine):
{
"title": res.title,
"url": res.url,
"snippet": " ".join(getattr(res, 'highlights', [])) or (getattr(res, 'text', '')[:250] + '...'),
"provider": "Exa"
"snippet": " ".join(getattr(res, "highlights", [])) or (getattr(res, "text", "")[:250] + "..."),
"provider": "Exa",
}
for res in search_response.results
]

View File

@@ -1,6 +1,7 @@
"""
Tavily search engine implementation
"""
import asyncio
import functools
from typing import Dict, List, Any
@@ -29,9 +30,7 @@ class TavilySearchEngine(BaseSearchEngine):
# 创建API密钥管理器
self.api_manager = create_api_key_manager_from_config(
tavily_api_keys,
lambda key: TavilyClient(api_key=key),
"Tavily"
tavily_api_keys, lambda key: TavilyClient(api_key=key), "Tavily"
)
def is_available(self) -> bool:
@@ -60,7 +59,7 @@ class TavilySearchEngine(BaseSearchEngine):
"max_results": num_results,
"search_depth": "basic",
"include_answer": False,
"include_raw_content": False
"include_raw_content": False,
}
# 根据时间范围调整搜索参数
@@ -76,12 +75,14 @@ class TavilySearchEngine(BaseSearchEngine):
results = []
if search_response and "results" in search_response:
for res in search_response["results"]:
results.append({
results.append(
{
"title": res.get("title", "无标题"),
"url": res.get("url", ""),
"snippet": res.get("content", "")[:300] + "..." if res.get("content") else "无摘要",
"provider": "Tavily"
})
"provider": "Tavily",
}
)
return results

View File

@@ -3,15 +3,10 @@ Web Search Tool Plugin
一个功能强大的网络搜索和URL解析插件支持多种搜索引擎和解析策略。
"""
from typing import List, Tuple, Type
from src.plugin_system import (
BasePlugin,
register_plugin,
ComponentInfo,
ConfigField,
PythonDependency
)
from src.plugin_system import BasePlugin, register_plugin, ComponentInfo, ConfigField, PythonDependency
from src.plugin_system.apis import config_api
from src.common.logger import get_logger
@@ -61,7 +56,7 @@ class WEBSEARCHPLUGIN(BasePlugin):
"Exa": exa_engine.is_available(),
"Tavily": tavily_engine.is_available(),
"DuckDuckGo": ddg_engine.is_available(),
"Bing": bing_engine.is_available()
"Bing": bing_engine.is_available(),
}
available_engines = [name for name, available in engines_status.items() if available]
@@ -77,37 +72,30 @@ class WEBSEARCHPLUGIN(BasePlugin):
# Python包依赖列表
python_dependencies: List[PythonDependency] = [
PythonDependency(
package_name="asyncddgs",
description="异步DuckDuckGo搜索库",
optional=False
),
PythonDependency(package_name="asyncddgs", description="异步DuckDuckGo搜索库", optional=False),
PythonDependency(
package_name="exa_py",
description="Exa搜索API客户端库",
optional=True # 如果没有API密钥这个是可选的
optional=True, # 如果没有API密钥这个是可选的
),
PythonDependency(
package_name="tavily",
install_name="tavily-python", # 安装时使用这个名称
description="Tavily搜索API客户端库",
optional=True # 如果没有API密钥这个是可选的
optional=True, # 如果没有API密钥这个是可选的
),
PythonDependency(
package_name="httpx",
version=">=0.20.0",
install_name="httpx[socks]", # 安装时使用这个名称(包含可选依赖)
description="支持SOCKS代理的HTTP客户端库",
optional=False
)
optional=False,
),
]
config_file_name: str = "config.toml" # 配置文件名
# 配置节描述
config_section_descriptions = {
"plugin": "插件基本信息",
"proxy": "链接本地解析代理配置"
}
config_section_descriptions = {"plugin": "插件基本信息", "proxy": "链接本地解析代理配置"}
# 配置Schema定义
# 注意EXA配置和组件设置已迁移到主配置文件(bot_config.toml)的[exa]和[web_search]部分
@@ -119,25 +107,15 @@ class WEBSEARCHPLUGIN(BasePlugin):
},
"proxy": {
"http_proxy": ConfigField(
type=str,
default=None,
description="HTTP代理地址格式如: http://proxy.example.com:8080"
type=str, default=None, description="HTTP代理地址格式如: http://proxy.example.com:8080"
),
"https_proxy": ConfigField(
type=str,
default=None,
description="HTTPS代理地址格式如: http://proxy.example.com:8080"
type=str, default=None, description="HTTPS代理地址格式如: http://proxy.example.com:8080"
),
"socks5_proxy": ConfigField(
type=str,
default=None,
description="SOCKS5代理地址格式如: socks5://proxy.example.com:1080"
type=str, default=None, description="SOCKS5代理地址格式如: socks5://proxy.example.com:1080"
),
"enable_proxy": ConfigField(
type=bool,
default=False,
description="是否启用代理"
)
"enable_proxy": ConfigField(type=bool, default=False, description="是否启用代理"),
},
}

View File

@@ -1,6 +1,7 @@
"""
URL parser tool implementation
"""
import asyncio
import functools
from typing import Any, Dict
@@ -24,6 +25,7 @@ class URLParserTool(BaseTool):
"""
一个用于解析和总结一个或多个网页URL内容的工具。
"""
name: str = "parse_url"
description: str = "当需要理解一个或多个特定网页链接的内容时,使用此工具。例如:'这些网页讲了什么?[https://example.com, https://example2.com]''帮我总结一下这些文章'"
available_for_llm: bool = True
@@ -45,9 +47,7 @@ class URLParserTool(BaseTool):
# 创建API密钥管理器
self.api_manager = create_api_key_manager_from_config(
exa_api_keys,
lambda key: Exa(api_key=key),
"Exa URL Parser"
exa_api_keys, lambda key: Exa(api_key=key), "Exa URL Parser"
)
async def _local_parse_and_summarize(self, url: str) -> Dict[str, Any]:
@@ -108,7 +108,7 @@ class URLParserTool(BaseTool):
model_config=model_config,
request_type="story.generate",
temperature=0.3,
max_tokens=1000
max_tokens=1000,
)
if not success:
@@ -117,12 +117,7 @@ class URLParserTool(BaseTool):
logger.info(f"成功生成摘要内容:'{summary}'")
return {
"title": title,
"url": url,
"snippet": summary,
"source": "local"
}
return {"title": title, "url": url, "snippet": summary, "source": "local"}
except httpx.HTTPStatusError as e:
logger.warning(f"本地解析URL '{url}' 失败 (HTTP {e.response.status_code})")
@@ -137,6 +132,7 @@ class URLParserTool(BaseTool):
"""
# 获取当前文件路径用于缓存键
import os
current_file_path = os.path.abspath(__file__)
# 检查缓存
@@ -185,31 +181,35 @@ class URLParserTool(BaseTool):
contents_response = None # 确保异常后为None
# 步骤 2: 处理Exa的响应
if contents_response and hasattr(contents_response, 'statuses'):
results_map = {res.url: res for res in contents_response.results} if hasattr(contents_response, 'results') else {}
if contents_response and hasattr(contents_response, "statuses"):
results_map = (
{res.url: res for res in contents_response.results} if hasattr(contents_response, "results") else {}
)
if contents_response.statuses:
for status in contents_response.statuses:
if status.status == 'success':
if status.status == "success":
res = results_map.get(status.id)
if res:
summary = getattr(res, 'summary', '')
highlights = " ".join(getattr(res, 'highlights', []))
text_snippet = (getattr(res, 'text', '')[:300] + '...') if getattr(res, 'text', '') else ''
snippet = summary or highlights or text_snippet or '无摘要'
summary = getattr(res, "summary", "")
highlights = " ".join(getattr(res, "highlights", []))
text_snippet = (getattr(res, "text", "")[:300] + "...") if getattr(res, "text", "") else ""
snippet = summary or highlights or text_snippet or "无摘要"
successful_results.append({
"title": getattr(res, 'title', '无标题'),
"url": getattr(res, 'url', status.id),
successful_results.append(
{
"title": getattr(res, "title", "无标题"),
"url": getattr(res, "url", status.id),
"snippet": snippet,
"source": "exa"
})
"source": "exa",
}
)
else:
error_tag = getattr(status, 'error', '未知错误')
error_tag = getattr(status, "error", "未知错误")
logger.warning(f"Exa解析URL '{status.id}' 失败: {error_tag}。准备本地重试。")
urls_to_retry_locally.append(status.id)
else:
# 如果Exa未配置、API调用失败或返回无效响应则所有URL都进入本地重试
urls_to_retry_locally.extend(url for url in urls if url not in [res['url'] for res in successful_results])
urls_to_retry_locally.extend(url for url in urls if url not in [res["url"] for res in successful_results])
# 步骤 3: 对失败的URL进行本地解析
if urls_to_retry_locally:
@@ -229,11 +229,7 @@ class URLParserTool(BaseTool):
formatted_content = format_url_parse_results(successful_results)
result = {
"type": "url_parse_result",
"content": formatted_content,
"errors": error_messages
}
result = {"type": "url_parse_result", "content": formatted_content, "errors": error_messages}
# 保存到缓存
if "error" not in result:

View File

@@ -1,6 +1,7 @@
"""
Web search tool implementation
"""
import asyncio
from typing import Any, Dict, List
@@ -22,13 +23,22 @@ class WebSurfingTool(BaseTool):
"""
网络搜索工具
"""
name: str = "web_search"
description: str = "用于执行网络搜索。当用户明确要求搜索,或者需要获取关于公司、产品、事件的最新信息、新闻或动态时,必须使用此工具"
description: str = (
"用于执行网络搜索。当用户明确要求搜索,或者需要获取关于公司、产品、事件的最新信息、新闻或动态时,必须使用此工具"
)
available_for_llm: bool = True
parameters = [
("query", ToolParamType.STRING, "要搜索的关键词或问题。", True, None),
("num_results", ToolParamType.INTEGER, "期望每个搜索引擎返回的搜索结果数量默认为5。", False, None),
("time_range", ToolParamType.STRING, "指定搜索的时间范围,可以是 'any', 'week', 'month'。默认为 'any'", False, ["any", "week", "month"])
(
"time_range",
ToolParamType.STRING,
"指定搜索的时间范围,可以是 'any', 'week', 'month'。默认为 'any'",
False,
["any", "week", "month"],
),
] # type: ignore
def __init__(self, plugin_config=None):
@@ -38,7 +48,7 @@ class WebSurfingTool(BaseTool):
"exa": ExaSearchEngine(),
"tavily": TavilySearchEngine(),
"ddg": DDGSearchEngine(),
"bing": BingSearchEngine()
"bing": BingSearchEngine(),
}
async def execute(self, function_args: Dict[str, Any]) -> Dict[str, Any]:
@@ -48,6 +58,7 @@ class WebSurfingTool(BaseTool):
# 获取当前文件路径用于缓存键
import os
current_file_path = os.path.abspath(__file__)
# 检查缓存
@@ -76,7 +87,9 @@ class WebSurfingTool(BaseTool):
return result
async def _execute_parallel_search(self, function_args: Dict[str, Any], enabled_engines: List[str]) -> Dict[str, Any]:
async def _execute_parallel_search(
self, function_args: Dict[str, Any], enabled_engines: List[str]
) -> Dict[str, Any]:
"""并行搜索策略:同时使用所有启用的搜索引擎"""
search_tasks = []
@@ -113,7 +126,9 @@ class WebSurfingTool(BaseTool):
logger.error(f"执行并行网络搜索时发生异常: {e}", exc_info=True)
return {"error": f"执行网络搜索时发生严重错误: {str(e)}"}
async def _execute_fallback_search(self, function_args: Dict[str, Any], enabled_engines: List[str]) -> Dict[str, Any]:
async def _execute_fallback_search(
self, function_args: Dict[str, Any], enabled_engines: List[str]
) -> Dict[str, Any]:
"""回退搜索策略:按顺序尝试搜索引擎,失败则尝试下一个"""
for engine_name in enabled_engines:
engine = self.engines.get(engine_name)

View File

@@ -1,13 +1,14 @@
"""
API密钥管理器提供轮询机制
"""
import itertools
from typing import List, Optional, TypeVar, Generic, Callable
from src.common.logger import get_logger
logger = get_logger("api_key_manager")
T = TypeVar('T')
T = TypeVar("T")
class APIKeyManager(Generic[T]):
@@ -65,9 +66,7 @@ class APIKeyManager(Generic[T]):
def create_api_key_manager_from_config(
config_keys: Optional[List[str]],
client_factory: Callable[[str], T],
service_name: str
config_keys: Optional[List[str]], client_factory: Callable[[str], T], service_name: str
) -> APIKeyManager[T]:
"""
从配置创建API密钥管理器的便捷函数

View File

@@ -1,6 +1,7 @@
"""
Formatters for web search results
"""
from typing import List, Dict, Any
@@ -13,9 +14,9 @@ def format_search_results(results: List[Dict[str, Any]]) -> str:
formatted_string = "根据网络搜索结果:\n\n"
for i, res in enumerate(results, 1):
title = res.get("title", '无标题')
url = res.get("url", '#')
snippet = res.get("snippet", '无摘要')
title = res.get("title", "无标题")
url = res.get("url", "#")
snippet = res.get("snippet", "无摘要")
provider = res.get("provider", "未知来源")
formatted_string += f"{i}. **{title}** (来自: {provider})\n"
@@ -31,10 +32,10 @@ def format_url_parse_results(results: List[Dict[str, Any]]) -> str:
"""
formatted_parts = []
for res in results:
title = res.get('title', '无标题')
url = res.get('url', '#')
snippet = res.get('snippet', '无摘要')
source = res.get('source', '未知')
title = res.get("title", "无标题")
url = res.get("url", "#")
snippet = res.get("snippet", "无摘要")
source = res.get("source", "未知")
formatted_string = f"**{title}**\n"
formatted_string += f"**内容摘要**:\n{snippet}\n"

View File

@@ -1,6 +1,7 @@
"""
URL processing utilities
"""
import re
from typing import List
@@ -12,11 +13,11 @@ def parse_urls_from_input(urls_input) -> List[str]:
if isinstance(urls_input, str):
# 如果是字符串尝试解析为URL列表
# 提取所有HTTP/HTTPS URL
url_pattern = r'https?://[^\s\],]+'
url_pattern = r"https?://[^\s\],]+"
urls = re.findall(url_pattern, urls_input)
if not urls:
# 如果没有找到标准URL将整个字符串作为单个URL
if urls_input.strip().startswith(('http://', 'https://')):
if urls_input.strip().startswith(("http://", "https://")):
urls = [urls_input.strip()]
else:
return []
@@ -34,6 +35,6 @@ def validate_urls(urls: List[str]) -> List[str]:
"""
valid_urls = []
for url in urls:
if url.startswith(('http://', 'https://')):
if url.startswith(("http://", "https://")):
valid_urls.append(url)
return valid_urls

View File

@@ -21,8 +21,18 @@ logger = get_logger(__name__)
# ============================ AsyncTask ============================
class ReminderTask(AsyncTask):
def __init__(self, delay: float, stream_id: str, is_group: bool, target_user_id: str, target_user_name: str, event_details: str, creator_name: str):
def __init__(
self,
delay: float,
stream_id: str,
is_group: bool,
target_user_id: str,
target_user_name: str,
event_details: str,
creator_name: str,
):
super().__init__(task_name=f"ReminderTask_{target_user_id}_{datetime.now().timestamp()}")
self.delay = delay
self.stream_id = stream_id
@@ -44,15 +54,15 @@ class ReminderTask(AsyncTask):
if self.is_group:
# 在群聊中,构造 @ 消息段并发送
group_id = self.stream_id.split('_')[-1] if '_' in self.stream_id else self.stream_id
group_id = self.stream_id.split("_")[-1] if "_" in self.stream_id else self.stream_id
message_payload = [
{"type": "at", "data": {"qq": self.target_user_id}},
{"type": "text", "data": {"text": f" {reminder_text}"}}
{"type": "text", "data": {"text": f" {reminder_text}"}},
]
await send_api.adapter_command_to_stream(
action="send_group_msg",
params={"group_id": group_id, "message": message_payload},
stream_id=self.stream_id
stream_id=self.stream_id,
)
else:
# 在私聊中,直接发送文本
@@ -66,6 +76,7 @@ class ReminderTask(AsyncTask):
# =============================== Actions ===============================
class RemindAction(BaseAction):
"""一个能从对话中智能识别并设置定时提醒的动作。"""
@@ -95,12 +106,12 @@ class RemindAction(BaseAction):
action_parameters = {
"user_name": "需要被提醒的人的称呼或名字,如果没有明确指定给某人,则默认为'自己'",
"remind_time": "描述提醒时间的自然语言字符串,例如'十分钟后''明天下午3点'",
"event_details": "需要提醒的具体事件内容"
"event_details": "需要提醒的具体事件内容",
}
action_require = [
"当用户请求在未来的某个时间点提醒他/她或别人某件事时使用",
"适用于包含明确时间信息和事件描述的对话",
"例如:'10分钟后提醒我收快递''明天早上九点喊一下李四参加晨会'"
"例如:'10分钟后提醒我收快递''明天早上九点喊一下李四参加晨会'",
]
async def execute(self) -> Tuple[bool, str]:
@@ -110,7 +121,15 @@ class RemindAction(BaseAction):
event_details = self.action_data.get("event_details")
if not all([user_name, remind_time_str, event_details]):
missing_params = [p for p, v in {"user_name": user_name, "remind_time": remind_time_str, "event_details": event_details}.items() if not v]
missing_params = [
p
for p, v in {
"user_name": user_name,
"remind_time": remind_time_str,
"event_details": event_details,
}.items()
if not v
]
error_msg = f"缺少必要的提醒参数: {', '.join(missing_params)}"
logger.warning(f"[ReminderPlugin] LLM未能提取完整参数: {error_msg}")
return False, error_msg
@@ -162,7 +181,7 @@ class RemindAction(BaseAction):
target_user_id=str(user_id_to_remind),
target_user_name=str(user_name_to_remind),
event_details=str(event_details),
creator_name=str(self.user_nickname)
creator_name=str(self.user_nickname),
)
await async_task_manager.add_task(reminder_task)
@@ -179,6 +198,7 @@ class RemindAction(BaseAction):
# =============================== Plugin ===============================
@register_plugin
class ReminderPlugin(BasePlugin):
"""一个能从对话中智能识别并设置定时提醒的插件。"""
@@ -193,6 +213,4 @@ class ReminderPlugin(BasePlugin):
def get_plugin_components(self) -> List[Tuple[ActionInfo, Type[BaseAction]]]:
"""注册插件的所有功能组件。"""
return [
(RemindAction.get_action_info(), RemindAction)
]
return [(RemindAction.get_action_info(), RemindAction)]