Merge afc branch into dev, prioritizing afc changes and migrating database async modifications from dev
This commit is contained in:
125
src/plugins/built_in/affinity_flow_chatter/README.md
Normal file
125
src/plugins/built_in/affinity_flow_chatter/README.md
Normal file
@@ -0,0 +1,125 @@
|
||||
# 亲和力聊天处理器插件
|
||||
|
||||
## 概述
|
||||
|
||||
这是一个内置的chatter插件,实现了基于亲和力流的智能聊天处理器,具有兴趣度评分和人物关系构建功能。
|
||||
|
||||
## 功能特性
|
||||
|
||||
- **智能兴趣度评分**: 自动识别和评估用户兴趣话题
|
||||
- **人物关系系统**: 根据互动历史建立和维持用户关系
|
||||
- **多聊天类型支持**: 支持私聊和群聊场景
|
||||
- **插件化架构**: 完全集成到插件系统中
|
||||
|
||||
## 组件架构
|
||||
|
||||
### BaseChatter (抽象基类)
|
||||
- 位置: `src/plugin_system/base/base_chatter.py`
|
||||
- 功能: 定义所有chatter组件的基础接口
|
||||
- 必须实现的方法: `execute(context: StreamContext) -> dict`
|
||||
|
||||
### ChatterManager (管理器)
|
||||
- 位置: `src/chat/chatter_manager.py`
|
||||
- 功能: 管理和调度所有chatter组件
|
||||
- 特性: 自动从插件系统注册和发现chatter组件
|
||||
|
||||
### AffinityChatter (具体实现)
|
||||
- 位置: `src/plugins/built_in/chatter/affinity_chatter.py`
|
||||
- 功能: 亲和力流聊天处理器的具体实现
|
||||
- 支持的聊天类型: PRIVATE, GROUP
|
||||
|
||||
## 使用方法
|
||||
|
||||
### 1. 基本使用
|
||||
|
||||
```python
|
||||
from src.chat.chatter_manager import ChatterManager
|
||||
from src.chat.planner_actions.action_manager import ChatterActionManager
|
||||
|
||||
# 初始化
|
||||
action_manager = ChatterActionManager()
|
||||
chatter_manager = ChatterManager(action_manager)
|
||||
|
||||
# 处理消息流
|
||||
result = await chatter_manager.process_stream_context(stream_id, context)
|
||||
```
|
||||
|
||||
### 2. 创建自定义Chatter
|
||||
|
||||
```python
|
||||
from src.plugin_system.base.base_chatter import BaseChatter
|
||||
from src.plugin_system.base.component_types import ChatType, ComponentType
|
||||
from src.plugin_system.base.component_types import ChatterInfo
|
||||
|
||||
class CustomChatter(BaseChatter):
|
||||
chat_types = [ChatType.PRIVATE] # 只支持私聊
|
||||
|
||||
async def execute(self, context: StreamContext) -> dict:
|
||||
# 实现你的聊天逻辑
|
||||
return {"success": True, "message": "处理完成"}
|
||||
|
||||
# 在插件中注册
|
||||
async def on_load(self):
|
||||
chatter_info = ChatterInfo(
|
||||
name="custom_chatter",
|
||||
component_type=ComponentType.CHATTER,
|
||||
description="自定义聊天处理器",
|
||||
enabled=True,
|
||||
plugin_name=self.name,
|
||||
chat_type_allow=ChatType.PRIVATE
|
||||
)
|
||||
|
||||
ComponentRegistry.register_component(
|
||||
component_info=chatter_info,
|
||||
component_class=CustomChatter
|
||||
)
|
||||
```
|
||||
|
||||
## 配置
|
||||
|
||||
### 插件配置文件
|
||||
- 位置: `src/plugins/built_in/chatter/_manifest.json`
|
||||
- 包含插件信息和组件配置
|
||||
|
||||
### 聊天类型
|
||||
- `PRIVATE`: 私聊
|
||||
- `GROUP`: 群聊
|
||||
- `ALL`: 所有类型
|
||||
|
||||
## 核心概念
|
||||
|
||||
### 1. 兴趣值系统
|
||||
- 自动识别同类话题
|
||||
- 兴趣值会根据聊天频率增减
|
||||
- 支持新话题的自动学习
|
||||
|
||||
### 2. 人物关系系统
|
||||
- 根据互动质量建立关系分
|
||||
- 不同关系分对应不同的回复风格
|
||||
- 支持情感化的交流
|
||||
|
||||
### 3. 执行流程
|
||||
1. 接收StreamContext
|
||||
2. 使用ActionPlanner进行规划
|
||||
3. 执行相应的Action
|
||||
4. 返回处理结果
|
||||
|
||||
## 扩展开发
|
||||
|
||||
### 添加新的Chatter类型
|
||||
1. 继承BaseChatter类
|
||||
2. 实现execute方法
|
||||
3. 在插件中注册组件
|
||||
4. 配置支持的聊天类型
|
||||
|
||||
### 集成现有功能
|
||||
- 使用ActionPlanner进行动作规划
|
||||
- 通过ActionManager执行动作
|
||||
- 利用现有的记忆和知识系统
|
||||
|
||||
## 注意事项
|
||||
|
||||
1. 所有chatter组件必须实现`execute`方法
|
||||
2. 插件注册时需要指定支持的聊天类型
|
||||
3. 组件名称不能包含点号(.)
|
||||
4. 确保在插件卸载时正确清理资源
|
||||
7
src/plugins/built_in/affinity_flow_chatter/__init__.py
Normal file
7
src/plugins/built_in/affinity_flow_chatter/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
||||
"""
|
||||
亲和力聊天处理器插件
|
||||
"""
|
||||
|
||||
from .plugin import AffinityChatterPlugin
|
||||
|
||||
__all__ = ["AffinityChatterPlugin"]
|
||||
23
src/plugins/built_in/affinity_flow_chatter/_manifest.json
Normal file
23
src/plugins/built_in/affinity_flow_chatter/_manifest.json
Normal file
@@ -0,0 +1,23 @@
|
||||
{
|
||||
"manifest_version": 1,
|
||||
"name": "affinity_chatter",
|
||||
"display_name": "Affinity Flow Chatter",
|
||||
"description": "Built-in chatter plugin for affinity flow with interest scoring and relationship building",
|
||||
"version": "1.0.0",
|
||||
"author": "MoFox",
|
||||
"plugin_class": "AffinityChatterPlugin",
|
||||
"enabled": true,
|
||||
"is_built_in": true,
|
||||
"components": [
|
||||
{
|
||||
"name": "affinity_chatter",
|
||||
"type": "chatter",
|
||||
"description": "Affinity flow chatter with intelligent interest scoring and relationship building",
|
||||
"enabled": true,
|
||||
"chat_type_allow": ["all"]
|
||||
}
|
||||
],
|
||||
"host_application": { "min_version": "0.8.0" },
|
||||
"keywords": ["chatter", "affinity", "conversation"],
|
||||
"categories": ["Chat", "AI"]
|
||||
}
|
||||
236
src/plugins/built_in/affinity_flow_chatter/affinity_chatter.py
Normal file
236
src/plugins/built_in/affinity_flow_chatter/affinity_chatter.py
Normal file
@@ -0,0 +1,236 @@
|
||||
"""
|
||||
亲和力聊天处理器
|
||||
基于现有的AffinityFlowChatter重构为插件化组件
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
import traceback
|
||||
from datetime import datetime
|
||||
from typing import Dict, Any
|
||||
|
||||
from src.plugin_system.base.base_chatter import BaseChatter
|
||||
from src.plugin_system.base.component_types import ChatType
|
||||
from src.common.data_models.message_manager_data_model import StreamContext
|
||||
from src.plugins.built_in.affinity_flow_chatter.planner import ChatterActionPlanner
|
||||
from src.chat.planner_actions.action_manager import ChatterActionManager
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.express.expression_learner import expression_learner_manager
|
||||
|
||||
logger = get_logger("affinity_chatter")
|
||||
|
||||
# 定义颜色
|
||||
SOFT_GREEN = "\033[38;5;118m" # 一个更柔和的绿色
|
||||
RESET_COLOR = "\033[0m"
|
||||
|
||||
|
||||
class AffinityChatter(BaseChatter):
|
||||
"""亲和力聊天处理器"""
|
||||
|
||||
chatter_name: str = "AffinityChatter"
|
||||
chatter_description: str = "基于亲和力模型的智能聊天处理器,支持多种聊天类型"
|
||||
chat_types: list[ChatType] = [ChatType.ALL] # 支持所有聊天类型
|
||||
|
||||
def __init__(self, stream_id: str, action_manager: ChatterActionManager):
|
||||
"""
|
||||
初始化亲和力聊天处理器
|
||||
|
||||
Args:
|
||||
stream_id: 聊天流ID
|
||||
planner: 动作规划器
|
||||
action_manager: 动作管理器
|
||||
"""
|
||||
super().__init__(stream_id, action_manager)
|
||||
self.planner = ChatterActionPlanner(stream_id, action_manager)
|
||||
|
||||
# 处理器统计
|
||||
self.stats = {
|
||||
"messages_processed": 0,
|
||||
"plans_created": 0,
|
||||
"actions_executed": 0,
|
||||
"successful_executions": 0,
|
||||
"failed_executions": 0,
|
||||
}
|
||||
self.last_activity_time = time.time()
|
||||
|
||||
async def execute(self, context: StreamContext) -> dict:
|
||||
"""
|
||||
处理StreamContext对象
|
||||
|
||||
Args:
|
||||
context: StreamContext对象,包含聊天流的所有消息信息
|
||||
|
||||
Returns:
|
||||
处理结果字典
|
||||
"""
|
||||
try:
|
||||
# 触发表达学习
|
||||
learner = expression_learner_manager.get_expression_learner(self.stream_id)
|
||||
asyncio.create_task(learner.trigger_learning_for_chat())
|
||||
|
||||
unread_messages = context.get_unread_messages()
|
||||
|
||||
# 使用增强版规划器处理消息
|
||||
actions, target_message = await self.planner.plan(context=context)
|
||||
self.stats["plans_created"] += 1
|
||||
|
||||
# 执行动作(如果规划器返回了动作)
|
||||
execution_result = {"executed_count": len(actions) if actions else 0}
|
||||
if actions:
|
||||
logger.debug(f"聊天流 {self.stream_id} 生成了 {len(actions)} 个动作")
|
||||
|
||||
# 更新统计
|
||||
self.stats["messages_processed"] += 1
|
||||
self.stats["actions_executed"] += execution_result.get("executed_count", 0)
|
||||
self.stats["successful_executions"] += 1
|
||||
self.last_activity_time = time.time()
|
||||
|
||||
result = {
|
||||
"success": True,
|
||||
"stream_id": self.stream_id,
|
||||
"plan_created": True,
|
||||
"actions_count": len(actions) if actions else 0,
|
||||
"has_target_message": target_message is not None,
|
||||
"unread_messages_processed": len(unread_messages),
|
||||
**execution_result,
|
||||
}
|
||||
|
||||
logger.debug(
|
||||
f"聊天流 {self.stream_id} StreamContext处理成功: 动作数={result['actions_count']}, 未读消息={result['unread_messages_processed']}"
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"亲和力聊天处理器 {self.stream_id} 处理StreamContext时出错: {e}\n{traceback.format_exc()}")
|
||||
self.stats["failed_executions"] += 1
|
||||
self.last_activity_time = time.time()
|
||||
|
||||
return {
|
||||
"success": False,
|
||||
"stream_id": self.stream_id,
|
||||
"error_message": str(e),
|
||||
"executed_count": 0,
|
||||
}
|
||||
|
||||
def get_stats(self) -> Dict[str, Any]:
|
||||
"""
|
||||
获取处理器统计信息
|
||||
|
||||
Returns:
|
||||
统计信息字典
|
||||
"""
|
||||
return self.stats.copy()
|
||||
|
||||
def get_planner_stats(self) -> Dict[str, Any]:
|
||||
"""
|
||||
获取规划器统计信息
|
||||
|
||||
Returns:
|
||||
规划器统计信息字典
|
||||
"""
|
||||
return self.planner.get_planner_stats()
|
||||
|
||||
def get_interest_scoring_stats(self) -> Dict[str, Any]:
|
||||
"""
|
||||
获取兴趣度评分统计信息
|
||||
|
||||
Returns:
|
||||
兴趣度评分统计信息字典
|
||||
"""
|
||||
return self.planner.get_interest_scoring_stats()
|
||||
|
||||
def get_relationship_stats(self) -> Dict[str, Any]:
|
||||
"""
|
||||
获取用户关系统计信息
|
||||
|
||||
Returns:
|
||||
用户关系统计信息字典
|
||||
"""
|
||||
return self.planner.get_relationship_stats()
|
||||
|
||||
def get_current_mood_state(self) -> str:
|
||||
"""
|
||||
获取当前聊天的情绪状态
|
||||
|
||||
Returns:
|
||||
当前情绪状态描述
|
||||
"""
|
||||
return self.planner.get_current_mood_state()
|
||||
|
||||
def get_mood_stats(self) -> Dict[str, Any]:
|
||||
"""
|
||||
获取情绪状态统计信息
|
||||
|
||||
Returns:
|
||||
情绪状态统计信息字典
|
||||
"""
|
||||
return self.planner.get_mood_stats()
|
||||
|
||||
def get_user_relationship(self, user_id: str) -> float:
|
||||
"""
|
||||
获取用户关系分
|
||||
|
||||
Args:
|
||||
user_id: 用户ID
|
||||
|
||||
Returns:
|
||||
用户关系分 (0.0-1.0)
|
||||
"""
|
||||
return self.planner.get_user_relationship(user_id)
|
||||
|
||||
def update_interest_keywords(self, new_keywords: dict):
|
||||
"""
|
||||
更新兴趣关键词
|
||||
|
||||
Args:
|
||||
new_keywords: 新的兴趣关键词字典
|
||||
"""
|
||||
self.planner.update_interest_keywords(new_keywords)
|
||||
logger.info(f"聊天流 {self.stream_id} 已更新兴趣关键词: {list(new_keywords.keys())}")
|
||||
|
||||
def reset_stats(self):
|
||||
"""重置统计信息"""
|
||||
self.stats = {
|
||||
"messages_processed": 0,
|
||||
"plans_created": 0,
|
||||
"actions_executed": 0,
|
||||
"successful_executions": 0,
|
||||
"failed_executions": 0,
|
||||
}
|
||||
|
||||
def is_active(self, max_inactive_minutes: int = 60) -> bool:
|
||||
"""
|
||||
检查处理器是否活跃
|
||||
|
||||
Args:
|
||||
max_inactive_minutes: 最大不活跃分钟数
|
||||
|
||||
Returns:
|
||||
是否活跃
|
||||
"""
|
||||
current_time = time.time()
|
||||
max_inactive_seconds = max_inactive_minutes * 60
|
||||
return (current_time - self.last_activity_time) < max_inactive_seconds
|
||||
|
||||
def get_activity_time(self) -> float:
|
||||
"""
|
||||
获取最后活动时间
|
||||
|
||||
Returns:
|
||||
最后活动时间戳
|
||||
"""
|
||||
return self.last_activity_time
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""字符串表示"""
|
||||
return f"AffinityChatter(stream_id={self.stream_id}, messages={self.stats['messages_processed']})"
|
||||
|
||||
def __repr__(self) -> str:
|
||||
"""详细字符串表示"""
|
||||
return (
|
||||
f"AffinityChatter(stream_id={self.stream_id}, "
|
||||
f"messages_processed={self.stats['messages_processed']}, "
|
||||
f"plans_created={self.stats['plans_created']}, "
|
||||
f"last_activity={datetime.fromtimestamp(self.last_activity_time)})"
|
||||
)
|
||||
333
src/plugins/built_in/affinity_flow_chatter/interest_scoring.py
Normal file
333
src/plugins/built_in/affinity_flow_chatter/interest_scoring.py
Normal file
@@ -0,0 +1,333 @@
|
||||
"""
|
||||
兴趣度评分系统
|
||||
基于多维度评分机制,包括兴趣匹配度、用户关系分、提及度和时间因子
|
||||
现在使用embedding计算智能兴趣匹配
|
||||
"""
|
||||
|
||||
import traceback
|
||||
from typing import Dict, List, Any
|
||||
|
||||
from src.common.data_models.database_data_model import DatabaseMessages
|
||||
from src.common.data_models.info_data_model import InterestScore
|
||||
from src.chat.interest_system import bot_interest_manager
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.plugins.built_in.affinity_flow_chatter.relationship_tracker import ChatterRelationshipTracker
|
||||
logger = get_logger("chatter_interest_scoring")
|
||||
|
||||
# 定义颜色
|
||||
SOFT_BLUE = "\033[38;5;67m"
|
||||
RESET_COLOR = "\033[0m"
|
||||
|
||||
|
||||
class ChatterInterestScoringSystem:
|
||||
"""兴趣度评分系统"""
|
||||
|
||||
def __init__(self):
|
||||
# 智能兴趣匹配配置
|
||||
self.use_smart_matching = True
|
||||
|
||||
# 从配置加载评分权重
|
||||
affinity_config = global_config.affinity_flow
|
||||
self.score_weights = {
|
||||
"interest_match": affinity_config.keyword_match_weight, # 兴趣匹配度权重
|
||||
"relationship": affinity_config.relationship_weight, # 关系分权重
|
||||
"mentioned": affinity_config.mention_bot_weight, # 是否提及bot权重
|
||||
}
|
||||
|
||||
# 评分阈值
|
||||
self.reply_threshold = affinity_config.reply_action_interest_threshold # 回复动作兴趣阈值
|
||||
self.mention_threshold = affinity_config.mention_bot_adjustment_threshold # 提及bot后的调整阈值
|
||||
|
||||
# 连续不回复概率提升
|
||||
self.no_reply_count = 0
|
||||
self.max_no_reply_count = affinity_config.max_no_reply_count
|
||||
self.probability_boost_per_no_reply = (
|
||||
affinity_config.no_reply_threshold_adjustment / affinity_config.max_no_reply_count
|
||||
) # 每次不回复增加的概率
|
||||
|
||||
# 用户关系数据
|
||||
self.user_relationships: Dict[str, float] = {} # user_id -> relationship_score
|
||||
|
||||
async def calculate_interest_scores(
|
||||
self, messages: List[DatabaseMessages], bot_nickname: str
|
||||
) -> List[InterestScore]:
|
||||
"""计算消息的兴趣度评分"""
|
||||
user_messages = [msg for msg in messages if str(msg.user_info.user_id) != str(global_config.bot.qq_account)]
|
||||
if not user_messages:
|
||||
return []
|
||||
|
||||
scores = []
|
||||
for _, msg in enumerate(user_messages, 1):
|
||||
score = await self._calculate_single_message_score(msg, bot_nickname)
|
||||
scores.append(score)
|
||||
|
||||
return scores
|
||||
|
||||
async def _calculate_single_message_score(self, message: DatabaseMessages, bot_nickname: str) -> InterestScore:
|
||||
"""计算单条消息的兴趣度评分"""
|
||||
|
||||
keywords = self._extract_keywords_from_database(message)
|
||||
interest_match_score = await self._calculate_interest_match_score(message.processed_plain_text, keywords)
|
||||
relationship_score = self._calculate_relationship_score(message.user_info.user_id)
|
||||
mentioned_score = self._calculate_mentioned_score(message, bot_nickname)
|
||||
|
||||
total_score = (
|
||||
interest_match_score * self.score_weights["interest_match"]
|
||||
+ relationship_score * self.score_weights["relationship"]
|
||||
+ mentioned_score * self.score_weights["mentioned"]
|
||||
)
|
||||
|
||||
details = {
|
||||
"interest_match": f"兴趣匹配: {interest_match_score:.3f}",
|
||||
"relationship": f"关系: {relationship_score:.3f}",
|
||||
"mentioned": f"提及: {mentioned_score:.3f}",
|
||||
}
|
||||
|
||||
logger.debug(
|
||||
f"消息得分详情: {total_score:.3f} (匹配: {interest_match_score:.2f}, 关系: {relationship_score:.2f}, 提及: {mentioned_score:.2f})"
|
||||
)
|
||||
|
||||
return InterestScore(
|
||||
message_id=message.message_id,
|
||||
total_score=total_score,
|
||||
interest_match_score=interest_match_score,
|
||||
relationship_score=relationship_score,
|
||||
mentioned_score=mentioned_score,
|
||||
details=details,
|
||||
)
|
||||
|
||||
async def _calculate_interest_match_score(self, content: str, keywords: List[str] = None) -> float:
|
||||
"""计算兴趣匹配度 - 使用智能embedding匹配"""
|
||||
if not content:
|
||||
return 0.0
|
||||
|
||||
# 使用智能匹配(embedding)
|
||||
if self.use_smart_matching and bot_interest_manager.is_initialized:
|
||||
return await self._calculate_smart_interest_match(content, keywords)
|
||||
else:
|
||||
# 智能匹配未初始化,返回默认分数
|
||||
return 0.3
|
||||
|
||||
async def _calculate_smart_interest_match(self, content: str, keywords: List[str] = None) -> float:
|
||||
"""使用embedding计算智能兴趣匹配"""
|
||||
try:
|
||||
# 如果没有传入关键词,则提取
|
||||
if not keywords:
|
||||
keywords = self._extract_keywords_from_content(content)
|
||||
|
||||
# 使用机器人兴趣管理器计算匹配度
|
||||
match_result = await bot_interest_manager.calculate_interest_match(content, keywords)
|
||||
|
||||
if match_result:
|
||||
# 返回匹配分数,考虑置信度和匹配标签数量
|
||||
affinity_config = global_config.affinity_flow
|
||||
match_count_bonus = min(
|
||||
len(match_result.matched_tags) * affinity_config.match_count_bonus, affinity_config.max_match_bonus
|
||||
)
|
||||
final_score = match_result.overall_score * 1.15 * match_result.confidence + match_count_bonus
|
||||
return final_score
|
||||
else:
|
||||
return 0.0
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"智能兴趣匹配计算失败: {e}")
|
||||
return 0.0
|
||||
|
||||
def _extract_keywords_from_database(self, message: DatabaseMessages) -> List[str]:
|
||||
"""从数据库消息中提取关键词"""
|
||||
keywords = []
|
||||
|
||||
# 尝试从 key_words 字段提取(存储的是JSON字符串)
|
||||
if message.key_words:
|
||||
try:
|
||||
import orjson
|
||||
|
||||
keywords = orjson.loads(message.key_words)
|
||||
if not isinstance(keywords, list):
|
||||
keywords = []
|
||||
except (orjson.JSONDecodeError, TypeError):
|
||||
keywords = []
|
||||
|
||||
# 如果没有 keywords,尝试从 key_words_lite 提取
|
||||
if not keywords and message.key_words_lite:
|
||||
try:
|
||||
import orjson
|
||||
|
||||
keywords = orjson.loads(message.key_words_lite)
|
||||
if not isinstance(keywords, list):
|
||||
keywords = []
|
||||
except (orjson.JSONDecodeError, TypeError):
|
||||
keywords = []
|
||||
|
||||
# 如果还是没有,从消息内容中提取(降级方案)
|
||||
if not keywords:
|
||||
keywords = self._extract_keywords_from_content(message.processed_plain_text)
|
||||
|
||||
return keywords[:15] # 返回前15个关键词
|
||||
|
||||
def _extract_keywords_from_content(self, content: str) -> List[str]:
|
||||
"""从内容中提取关键词(降级方案)"""
|
||||
import re
|
||||
|
||||
# 清理文本
|
||||
content = re.sub(r"[^\w\s\u4e00-\u9fff]", " ", content) # 保留中文、英文、数字
|
||||
words = content.split()
|
||||
|
||||
# 过滤和关键词提取
|
||||
keywords = []
|
||||
for word in words:
|
||||
word = word.strip()
|
||||
if (
|
||||
len(word) >= 2 # 至少2个字符
|
||||
and word.isalnum() # 字母数字
|
||||
and not word.isdigit()
|
||||
): # 不是纯数字
|
||||
keywords.append(word.lower())
|
||||
|
||||
# 去重并限制数量
|
||||
unique_keywords = list(set(keywords))
|
||||
return unique_keywords[:10] # 返回前10个唯一关键词
|
||||
|
||||
def _calculate_relationship_score(self, user_id: str) -> float:
|
||||
"""计算关系分 - 从数据库获取关系分"""
|
||||
# 优先使用内存中的关系分
|
||||
if user_id in self.user_relationships:
|
||||
relationship_value = self.user_relationships[user_id]
|
||||
return min(relationship_value, 1.0)
|
||||
|
||||
# 如果内存中没有,尝试从关系追踪器获取
|
||||
if hasattr(self, "relationship_tracker") and self.relationship_tracker:
|
||||
try:
|
||||
relationship_score = self.relationship_tracker.get_user_relationship_score(user_id)
|
||||
# 同时更新内存缓存
|
||||
self.user_relationships[user_id] = relationship_score
|
||||
return relationship_score
|
||||
except Exception:
|
||||
pass
|
||||
else:
|
||||
# 尝试从全局关系追踪器获取
|
||||
try:
|
||||
from .relationship_tracker import ChatterRelationshipTracker
|
||||
|
||||
global_tracker = ChatterRelationshipTracker()
|
||||
if global_tracker:
|
||||
relationship_score = global_tracker.get_user_relationship_score(user_id)
|
||||
# 同时更新内存缓存
|
||||
self.user_relationships[user_id] = relationship_score
|
||||
return relationship_score
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# 默认新用户的基础分
|
||||
return global_config.affinity_flow.base_relationship_score
|
||||
|
||||
def _calculate_mentioned_score(self, msg: DatabaseMessages, bot_nickname: str) -> float:
|
||||
"""计算提及分数"""
|
||||
if not msg.processed_plain_text:
|
||||
return 0.0
|
||||
|
||||
# 检查是否被提及
|
||||
bot_aliases = [bot_nickname] + global_config.bot.alias_names
|
||||
is_mentioned = msg.is_mentioned or any(alias in msg.processed_plain_text for alias in bot_aliases if alias)
|
||||
|
||||
# 如果被提及或是私聊,都视为提及了bot
|
||||
if is_mentioned or not hasattr(msg, "chat_info_group_id"):
|
||||
return global_config.affinity_flow.mention_bot_interest_score
|
||||
|
||||
return 0.0
|
||||
|
||||
def should_reply(self, score: InterestScore, message: "DatabaseMessages") -> bool:
|
||||
"""判断是否应该回复"""
|
||||
base_threshold = self.reply_threshold
|
||||
|
||||
# 如果被提及,降低阈值
|
||||
if score.mentioned_score >= global_config.affinity_flow.mention_bot_adjustment_threshold:
|
||||
base_threshold = self.mention_threshold
|
||||
|
||||
# 计算连续不回复的概率提升
|
||||
probability_boost = min(self.no_reply_count * self.probability_boost_per_no_reply, 0.8)
|
||||
effective_threshold = base_threshold - probability_boost
|
||||
|
||||
# 做出决策
|
||||
should_reply = score.total_score >= effective_threshold
|
||||
decision = "回复" if should_reply else "不回复"
|
||||
logger.info(
|
||||
f"{SOFT_BLUE}决策: {decision} (兴趣度: {score.total_score:.3f} / 阈值: {effective_threshold:.3f}){RESET_COLOR}"
|
||||
)
|
||||
|
||||
return should_reply, score.total_score
|
||||
|
||||
def record_reply_action(self, did_reply: bool):
|
||||
"""记录回复动作"""
|
||||
old_count = self.no_reply_count
|
||||
if did_reply:
|
||||
self.no_reply_count = max(0, self.no_reply_count - global_config.affinity_flow.reply_cooldown_reduction)
|
||||
action = "回复"
|
||||
else:
|
||||
self.no_reply_count += 1
|
||||
action = "不回复"
|
||||
|
||||
# 限制最大计数
|
||||
self.no_reply_count = min(self.no_reply_count, self.max_no_reply_count)
|
||||
logger.info(f"动作: {action}, 连续不回复次数: {old_count} -> {self.no_reply_count}")
|
||||
|
||||
def update_user_relationship(self, user_id: str, relationship_change: float):
|
||||
"""更新用户关系"""
|
||||
old_score = self.user_relationships.get(
|
||||
user_id, global_config.affinity_flow.base_relationship_score
|
||||
) # 默认新用户分数
|
||||
new_score = max(0.0, min(1.0, old_score + relationship_change))
|
||||
|
||||
self.user_relationships[user_id] = new_score
|
||||
|
||||
logger.info(f"用户关系: {user_id} | {old_score:.3f} → {new_score:.3f}")
|
||||
|
||||
def get_user_relationship(self, user_id: str) -> float:
|
||||
"""获取用户关系分"""
|
||||
return self.user_relationships.get(user_id, 0.3)
|
||||
|
||||
def get_scoring_stats(self) -> Dict:
|
||||
"""获取评分系统统计"""
|
||||
return {
|
||||
"no_reply_count": self.no_reply_count,
|
||||
"max_no_reply_count": self.max_no_reply_count,
|
||||
"reply_threshold": self.reply_threshold,
|
||||
"mention_threshold": self.mention_threshold,
|
||||
"user_relationships": len(self.user_relationships),
|
||||
}
|
||||
|
||||
def reset_stats(self):
|
||||
"""重置统计信息"""
|
||||
self.no_reply_count = 0
|
||||
logger.info("重置兴趣度评分系统统计")
|
||||
|
||||
async def initialize_smart_interests(self, personality_description: str, personality_id: str = "default"):
|
||||
"""初始化智能兴趣系统"""
|
||||
try:
|
||||
logger.info("开始初始化智能兴趣系统...")
|
||||
logger.info(f"人设ID: {personality_id}, 描述长度: {len(personality_description)}")
|
||||
|
||||
await bot_interest_manager.initialize(personality_description, personality_id)
|
||||
logger.info("智能兴趣系统初始化完成。")
|
||||
|
||||
# 显示初始化后的统计信息
|
||||
bot_interest_manager.get_interest_stats()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"初始化智能兴趣系统失败: {e}")
|
||||
traceback.print_exc()
|
||||
|
||||
def get_matching_config(self) -> Dict[str, Any]:
|
||||
"""获取匹配配置信息"""
|
||||
return {
|
||||
"use_smart_matching": self.use_smart_matching,
|
||||
"smart_system_initialized": bot_interest_manager.is_initialized,
|
||||
"smart_system_stats": bot_interest_manager.get_interest_stats()
|
||||
if bot_interest_manager.is_initialized
|
||||
else None,
|
||||
}
|
||||
|
||||
|
||||
# 创建全局兴趣评分系统实例
|
||||
chatter_interest_scoring_system = ChatterInterestScoringSystem()
|
||||
368
src/plugins/built_in/affinity_flow_chatter/plan_executor.py
Normal file
368
src/plugins/built_in/affinity_flow_chatter/plan_executor.py
Normal file
@@ -0,0 +1,368 @@
|
||||
"""
|
||||
PlanExecutor: 接收 Plan 对象并执行其中的所有动作。
|
||||
集成用户关系追踪机制,自动记录交互并更新关系。
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from typing import Dict, List
|
||||
|
||||
from src.config.config import global_config
|
||||
from src.chat.planner_actions.action_manager import ChatterActionManager
|
||||
from src.common.data_models.info_data_model import Plan, ActionPlannerInfo
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("plan_executor")
|
||||
|
||||
|
||||
class ChatterPlanExecutor:
|
||||
"""
|
||||
增强版PlanExecutor,集成用户关系追踪机制。
|
||||
|
||||
功能:
|
||||
1. 执行Plan中的所有动作
|
||||
2. 自动记录用户交互并添加到关系追踪
|
||||
3. 分类执行回复动作和其他动作
|
||||
4. 提供完整的执行统计和监控
|
||||
"""
|
||||
|
||||
def __init__(self, action_manager: ChatterActionManager):
|
||||
"""
|
||||
初始化增强版PlanExecutor。
|
||||
|
||||
Args:
|
||||
action_manager (ChatterActionManager): 用于实际执行各种动作的管理器实例。
|
||||
"""
|
||||
self.action_manager = action_manager
|
||||
|
||||
# 执行统计
|
||||
self.execution_stats = {
|
||||
"total_executed": 0,
|
||||
"successful_executions": 0,
|
||||
"failed_executions": 0,
|
||||
"reply_executions": 0,
|
||||
"other_action_executions": 0,
|
||||
"execution_times": [],
|
||||
}
|
||||
|
||||
# 用户关系追踪引用
|
||||
self.relationship_tracker = None
|
||||
|
||||
def set_relationship_tracker(self, relationship_tracker):
|
||||
"""设置关系追踪器"""
|
||||
self.relationship_tracker = relationship_tracker
|
||||
|
||||
async def execute(self, plan: Plan) -> Dict[str, any]:
|
||||
"""
|
||||
遍历并执行Plan对象中`decided_actions`列表里的所有动作。
|
||||
|
||||
Args:
|
||||
plan (Plan): 包含待执行动作列表的Plan对象。
|
||||
|
||||
Returns:
|
||||
Dict[str, any]: 执行结果统计信息
|
||||
"""
|
||||
if not plan.decided_actions:
|
||||
logger.info("没有需要执行的动作。")
|
||||
return {"executed_count": 0, "results": []}
|
||||
|
||||
# 像hfc一样,提前打印将要执行的动作
|
||||
action_types = [action.action_type for action in plan.decided_actions]
|
||||
logger.info(f"选择动作: {', '.join(action_types) if action_types else '无'}")
|
||||
|
||||
execution_results = []
|
||||
reply_actions = []
|
||||
other_actions = []
|
||||
|
||||
# 分类动作:回复动作和其他动作
|
||||
for action_info in plan.decided_actions:
|
||||
if action_info.action_type in ["reply", "proactive_reply"]:
|
||||
reply_actions.append(action_info)
|
||||
else:
|
||||
other_actions.append(action_info)
|
||||
|
||||
# 执行回复动作(优先执行)
|
||||
if reply_actions:
|
||||
reply_result = await self._execute_reply_actions(reply_actions, plan)
|
||||
execution_results.extend(reply_result["results"])
|
||||
self.execution_stats["reply_executions"] += len(reply_actions)
|
||||
|
||||
# 将其他动作放入后台任务执行,避免阻塞主流程
|
||||
if other_actions:
|
||||
asyncio.create_task(self._execute_other_actions(other_actions, plan))
|
||||
logger.info(f"已将 {len(other_actions)} 个其他动作放入后台任务执行。")
|
||||
# 注意:后台任务的结果不会立即计入本次返回的统计数据
|
||||
|
||||
# 更新总体统计
|
||||
self.execution_stats["total_executed"] += len(plan.decided_actions)
|
||||
successful_count = sum(1 for r in execution_results if r["success"])
|
||||
self.execution_stats["successful_executions"] += successful_count
|
||||
self.execution_stats["failed_executions"] += len(execution_results) - successful_count
|
||||
|
||||
logger.info(
|
||||
f"规划执行完成: 总数={len(plan.decided_actions)}, 成功={successful_count}, 失败={len(execution_results) - successful_count}"
|
||||
)
|
||||
|
||||
return {
|
||||
"executed_count": len(plan.decided_actions),
|
||||
"successful_count": successful_count,
|
||||
"failed_count": len(execution_results) - successful_count,
|
||||
"results": execution_results,
|
||||
}
|
||||
|
||||
async def _execute_reply_actions(self, reply_actions: List[ActionPlannerInfo], plan: Plan) -> Dict[str, any]:
|
||||
"""执行回复动作"""
|
||||
results = []
|
||||
|
||||
for action_info in reply_actions:
|
||||
result = await self._execute_single_reply_action(action_info, plan)
|
||||
results.append(result)
|
||||
|
||||
return {"results": results}
|
||||
|
||||
async def _execute_single_reply_action(self, action_info: ActionPlannerInfo, plan: Plan) -> Dict[str, any]:
|
||||
"""执行单个回复动作"""
|
||||
start_time = time.time()
|
||||
success = False
|
||||
error_message = ""
|
||||
reply_content = ""
|
||||
|
||||
try:
|
||||
logger.info(f"执行回复动作: {action_info.action_type} (原因: {action_info.reasoning})")
|
||||
|
||||
# 获取用户ID - 兼容对象和字典
|
||||
if hasattr(action_info.action_message, "user_info"):
|
||||
user_id = action_info.action_message.user_info.user_id
|
||||
else:
|
||||
user_id = action_info.action_message.get("user_info", {}).get("user_id")
|
||||
|
||||
if user_id == str(global_config.bot.qq_account):
|
||||
logger.warning("尝试回复自己,跳过此动作以防止死循环。")
|
||||
return {
|
||||
"action_type": action_info.action_type,
|
||||
"success": False,
|
||||
"error_message": "尝试回复自己,跳过此动作以防止死循环。",
|
||||
"execution_time": 0,
|
||||
"reasoning": action_info.reasoning,
|
||||
"reply_content": "",
|
||||
}
|
||||
# 构建回复动作参数
|
||||
action_params = {
|
||||
"chat_id": plan.chat_id,
|
||||
"target_message": action_info.action_message,
|
||||
"reasoning": action_info.reasoning,
|
||||
"action_data": action_info.action_data or {},
|
||||
}
|
||||
|
||||
logger.debug(f"📬 [PlanExecutor] 准备调用 ActionManager,target_message: {action_info.action_message}")
|
||||
|
||||
# 通过动作管理器执行回复
|
||||
reply_content = await self.action_manager.execute_action(
|
||||
action_name=action_info.action_type, **action_params
|
||||
)
|
||||
|
||||
success = True
|
||||
logger.info(f"回复动作 '{action_info.action_type}' 执行成功。")
|
||||
|
||||
except Exception as e:
|
||||
error_message = str(e)
|
||||
logger.error(f"执行回复动作失败: {action_info.action_type}, 错误: {error_message}")
|
||||
|
||||
# 记录用户关系追踪
|
||||
if success and action_info.action_message:
|
||||
await self._track_user_interaction(action_info, plan, reply_content)
|
||||
|
||||
execution_time = time.time() - start_time
|
||||
self.execution_stats["execution_times"].append(execution_time)
|
||||
|
||||
return {
|
||||
"action_type": action_info.action_type,
|
||||
"success": success,
|
||||
"error_message": error_message,
|
||||
"execution_time": execution_time,
|
||||
"reasoning": action_info.reasoning,
|
||||
"reply_content": reply_content[:200] + "..." if len(reply_content) > 200 else reply_content,
|
||||
}
|
||||
|
||||
async def _execute_other_actions(self, other_actions: List[ActionPlannerInfo], plan: Plan) -> Dict[str, any]:
|
||||
"""执行其他动作"""
|
||||
results = []
|
||||
|
||||
# 并行执行其他动作
|
||||
tasks = []
|
||||
for action_info in other_actions:
|
||||
task = self._execute_single_other_action(action_info, plan)
|
||||
tasks.append(task)
|
||||
|
||||
if tasks:
|
||||
executed_results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
for i, result in enumerate(executed_results):
|
||||
if isinstance(result, Exception):
|
||||
logger.error(f"执行动作 {other_actions[i].action_type} 时发生异常: {result}")
|
||||
results.append(
|
||||
{
|
||||
"action_type": other_actions[i].action_type,
|
||||
"success": False,
|
||||
"error_message": str(result),
|
||||
"execution_time": 0,
|
||||
"reasoning": other_actions[i].reasoning,
|
||||
}
|
||||
)
|
||||
else:
|
||||
results.append(result)
|
||||
|
||||
return {"results": results}
|
||||
|
||||
async def _execute_single_other_action(self, action_info: ActionPlannerInfo, plan: Plan) -> Dict[str, any]:
|
||||
"""执行单个其他动作"""
|
||||
start_time = time.time()
|
||||
success = False
|
||||
error_message = ""
|
||||
|
||||
try:
|
||||
logger.info(f"执行其他动作: {action_info.action_type} (原因: {action_info.reasoning})")
|
||||
|
||||
action_data = action_info.action_data or {}
|
||||
|
||||
# 针对 poke_user 动作,特殊处理
|
||||
if action_info.action_type == "poke_user":
|
||||
target_message = action_info.action_message
|
||||
if target_message:
|
||||
# 优先直接获取 user_id,这才是最可靠的信息
|
||||
user_id = target_message.get("user_id")
|
||||
if user_id:
|
||||
action_data["user_id"] = user_id
|
||||
logger.info(f"检测到戳一戳动作,目标用户ID: {user_id}")
|
||||
else:
|
||||
# 如果没有 user_id,再尝试用 user_nickname 作为备用方案
|
||||
user_name = target_message.get("user_nickname")
|
||||
if user_name:
|
||||
action_data["user_name"] = user_name
|
||||
logger.info(f"检测到戳一戳动作,目标用户: {user_name}")
|
||||
else:
|
||||
logger.warning("无法从戳一戳消息中获取用户ID或昵称。")
|
||||
|
||||
# 传递原始消息ID以支持引用
|
||||
action_data["target_message_id"] = target_message.get("message_id")
|
||||
|
||||
# 构建动作参数
|
||||
action_params = {
|
||||
"chat_id": plan.chat_id,
|
||||
"target_message": action_info.action_message,
|
||||
"reasoning": action_info.reasoning,
|
||||
"action_data": action_data,
|
||||
}
|
||||
|
||||
# 通过动作管理器执行动作
|
||||
await self.action_manager.execute_action(action_name=action_info.action_type, **action_params)
|
||||
|
||||
success = True
|
||||
logger.info(f"其他动作 '{action_info.action_type}' 执行成功。")
|
||||
|
||||
except Exception as e:
|
||||
error_message = str(e)
|
||||
logger.error(f"执行其他动作失败: {action_info.action_type}, 错误: {error_message}")
|
||||
|
||||
execution_time = time.time() - start_time
|
||||
self.execution_stats["execution_times"].append(execution_time)
|
||||
|
||||
return {
|
||||
"action_type": action_info.action_type,
|
||||
"success": success,
|
||||
"error_message": error_message,
|
||||
"execution_time": execution_time,
|
||||
"reasoning": action_info.reasoning,
|
||||
}
|
||||
|
||||
async def _track_user_interaction(self, action_info: ActionPlannerInfo, plan: Plan, reply_content: str):
|
||||
"""追踪用户交互 - 集成回复后关系追踪"""
|
||||
try:
|
||||
if not action_info.action_message:
|
||||
return
|
||||
|
||||
# 获取用户信息 - 处理对象和字典两种情况
|
||||
if hasattr(action_info.action_message, "user_info"):
|
||||
# 对象情况
|
||||
user_info = action_info.action_message.user_info
|
||||
user_id = user_info.user_id
|
||||
user_name = user_info.user_nickname or user_id
|
||||
user_message = action_info.action_message.content
|
||||
else:
|
||||
# 字典情况
|
||||
user_info = action_info.action_message.get("user_info", {})
|
||||
user_id = user_info.get("user_id")
|
||||
user_name = user_info.get("user_nickname") or user_id
|
||||
user_message = action_info.action_message.get("content", "")
|
||||
|
||||
if not user_id:
|
||||
logger.debug("跳过追踪:缺少用户ID")
|
||||
return
|
||||
|
||||
# 如果有设置关系追踪器,执行回复后关系追踪
|
||||
if self.relationship_tracker:
|
||||
# 记录基础交互信息(保持向后兼容)
|
||||
self.relationship_tracker.add_interaction(
|
||||
user_id=user_id,
|
||||
user_name=user_name,
|
||||
user_message=user_message,
|
||||
bot_reply=reply_content,
|
||||
reply_timestamp=time.time(),
|
||||
)
|
||||
|
||||
# 执行新的回复后关系追踪
|
||||
await self.relationship_tracker.track_reply_relationship(
|
||||
user_id=user_id, user_name=user_name, bot_reply_content=reply_content, reply_timestamp=time.time()
|
||||
)
|
||||
|
||||
logger.debug(f"已执行用户交互追踪: {user_id}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"追踪用户交互时出错: {e}")
|
||||
logger.debug(f"action_message类型: {type(action_info.action_message)}")
|
||||
logger.debug(f"action_message内容: {action_info.action_message}")
|
||||
|
||||
def get_execution_stats(self) -> Dict[str, any]:
|
||||
"""获取执行统计信息"""
|
||||
stats = self.execution_stats.copy()
|
||||
|
||||
# 计算平均执行时间
|
||||
if stats["execution_times"]:
|
||||
avg_time = sum(stats["execution_times"]) / len(stats["execution_times"])
|
||||
stats["average_execution_time"] = avg_time
|
||||
stats["max_execution_time"] = max(stats["execution_times"])
|
||||
stats["min_execution_time"] = min(stats["execution_times"])
|
||||
else:
|
||||
stats["average_execution_time"] = 0
|
||||
stats["max_execution_time"] = 0
|
||||
stats["min_execution_time"] = 0
|
||||
|
||||
# 移除执行时间列表以避免返回过大数据
|
||||
stats.pop("execution_times", None)
|
||||
|
||||
return stats
|
||||
|
||||
def reset_stats(self):
|
||||
"""重置统计信息"""
|
||||
self.execution_stats = {
|
||||
"total_executed": 0,
|
||||
"successful_executions": 0,
|
||||
"failed_executions": 0,
|
||||
"reply_executions": 0,
|
||||
"other_action_executions": 0,
|
||||
"execution_times": [],
|
||||
}
|
||||
|
||||
def get_recent_performance(self, limit: int = 10) -> List[Dict[str, any]]:
|
||||
"""获取最近的执行性能"""
|
||||
recent_times = self.execution_stats["execution_times"][-limit:]
|
||||
if not recent_times:
|
||||
return []
|
||||
|
||||
return [
|
||||
{
|
||||
"execution_index": i + 1,
|
||||
"execution_time": time_val,
|
||||
"timestamp": time.time() - (len(recent_times) - i) * 60, # 估算时间戳
|
||||
}
|
||||
for i, time_val in enumerate(recent_times)
|
||||
]
|
||||
678
src/plugins/built_in/affinity_flow_chatter/plan_filter.py
Normal file
678
src/plugins/built_in/affinity_flow_chatter/plan_filter.py
Normal file
@@ -0,0 +1,678 @@
|
||||
"""
|
||||
PlanFilter: 接收 Plan 对象,根据不同模式的逻辑进行筛选,决定最终要执行的动作。
|
||||
"""
|
||||
|
||||
import orjson
|
||||
import time
|
||||
import traceback
|
||||
import re
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from json_repair import repair_json
|
||||
|
||||
from src.chat.memory_system.Hippocampus import hippocampus_manager
|
||||
from src.chat.utils.chat_message_builder import (
|
||||
build_readable_actions,
|
||||
build_readable_messages_with_id,
|
||||
get_actions_by_timestamp_with_chat,
|
||||
)
|
||||
from src.chat.utils.prompt import global_prompt_manager
|
||||
from src.common.data_models.info_data_model import ActionPlannerInfo, Plan
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config, model_config
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.mood.mood_manager import mood_manager
|
||||
from src.plugin_system.base.component_types import ActionInfo, ChatMode, ChatType
|
||||
from src.schedule.schedule_manager import schedule_manager
|
||||
|
||||
logger = get_logger("plan_filter")
|
||||
|
||||
SAKURA_PINK = "\033[38;5;175m"
|
||||
SKY_BLUE = "\033[38;5;117m"
|
||||
RESET_COLOR = "\033[0m"
|
||||
|
||||
|
||||
class ChatterPlanFilter:
|
||||
"""
|
||||
根据 Plan 中的模式和信息,筛选并决定最终的动作。
|
||||
"""
|
||||
|
||||
def __init__(self, chat_id: str, available_actions: List[str]):
|
||||
"""
|
||||
初始化动作计划筛选器。
|
||||
|
||||
Args:
|
||||
chat_id (str): 当前聊天的唯一标识符。
|
||||
available_actions (List[str]): 当前可用的动作列表。
|
||||
"""
|
||||
self.chat_id = chat_id
|
||||
self.available_actions = available_actions
|
||||
self.planner_llm = LLMRequest(model_set=model_config.model_task_config.planner, request_type="planner")
|
||||
self.last_obs_time_mark = 0.0
|
||||
|
||||
async def filter(self, reply_not_available: bool, plan: Plan) -> Plan:
|
||||
"""
|
||||
执行筛选逻辑,并填充 Plan 对象的 decided_actions 字段。
|
||||
"""
|
||||
try:
|
||||
prompt, used_message_id_list = await self._build_prompt(plan)
|
||||
plan.llm_prompt = prompt
|
||||
|
||||
llm_content, _ = await self.planner_llm.generate_response_async(prompt=prompt)
|
||||
|
||||
if llm_content:
|
||||
try:
|
||||
parsed_json = orjson.loads(repair_json(llm_content))
|
||||
except orjson.JSONDecodeError:
|
||||
parsed_json = {
|
||||
"thinking": "",
|
||||
"actions": {"action_type": "no_action", "reason": "返回内容无法解析为JSON"},
|
||||
}
|
||||
|
||||
if "reply" in plan.available_actions and reply_not_available:
|
||||
# 如果reply动作不可用,但llm返回的仍然有reply,则改为no_reply
|
||||
if (
|
||||
isinstance(parsed_json, dict)
|
||||
and parsed_json.get("actions", {}).get("action_type", "") == "reply"
|
||||
):
|
||||
parsed_json["actions"]["action_type"] = "no_reply"
|
||||
elif isinstance(parsed_json, list):
|
||||
for item in parsed_json:
|
||||
if isinstance(item, dict) and item.get("actions", {}).get("action_type", "") == "reply":
|
||||
item["actions"]["action_type"] = "no_reply"
|
||||
item["actions"]["reason"] += " (但由于兴趣度不足,reply动作不可用,已改为no_reply)"
|
||||
|
||||
if isinstance(parsed_json, dict):
|
||||
parsed_json = [parsed_json]
|
||||
|
||||
if isinstance(parsed_json, list):
|
||||
final_actions = []
|
||||
reply_action_added = False
|
||||
# 定义回复类动作的集合,方便扩展
|
||||
reply_action_types = {"reply", "proactive_reply"}
|
||||
|
||||
for item in parsed_json:
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
|
||||
# 预解析 action_type 来进行判断
|
||||
thinking = item.get("thinking", "未提供思考过程")
|
||||
actions_obj = item.get("actions", {})
|
||||
|
||||
# 处理actions字段可能是字典或列表的情况
|
||||
if isinstance(actions_obj, dict):
|
||||
action_type = actions_obj.get("action_type", "no_action")
|
||||
elif isinstance(actions_obj, list) and actions_obj:
|
||||
# 如果是列表,取第一个元素的action_type
|
||||
first_action = actions_obj[0]
|
||||
if isinstance(first_action, dict):
|
||||
action_type = first_action.get("action_type", "no_action")
|
||||
else:
|
||||
action_type = "no_action"
|
||||
else:
|
||||
action_type = "no_action"
|
||||
|
||||
if action_type in reply_action_types:
|
||||
if not reply_action_added:
|
||||
final_actions.extend(
|
||||
await self._parse_single_action(item, used_message_id_list, plan)
|
||||
)
|
||||
reply_action_added = True
|
||||
else:
|
||||
# 非回复类动作直接添加
|
||||
final_actions.extend(await self._parse_single_action(item, used_message_id_list, plan))
|
||||
|
||||
if thinking and thinking != "未提供思考过程":
|
||||
logger.info(f"\n{SAKURA_PINK}思考: {thinking}{RESET_COLOR}\n")
|
||||
plan.decided_actions = self._filter_no_actions(final_actions)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"筛选 Plan 时出错: {e}\n{traceback.format_exc()}")
|
||||
plan.decided_actions = [ActionPlannerInfo(action_type="no_action", reasoning=f"筛选时出错: {e}")]
|
||||
|
||||
# 在返回最终计划前,打印将要执行的动作
|
||||
action_types = [action.action_type for action in plan.decided_actions]
|
||||
logger.info(f"选择动作: [{SKY_BLUE}{', '.join(action_types) if action_types else '无'}{RESET_COLOR}]")
|
||||
|
||||
return plan
|
||||
|
||||
async def _build_prompt(self, plan: Plan) -> tuple[str, list]:
|
||||
"""
|
||||
根据 Plan 对象构建提示词。
|
||||
"""
|
||||
try:
|
||||
time_block = f"当前时间:{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}"
|
||||
bot_name = global_config.bot.nickname
|
||||
bot_nickname = (
|
||||
f",也有人叫你{','.join(global_config.bot.alias_names)}" if global_config.bot.alias_names else ""
|
||||
)
|
||||
bot_core_personality = global_config.personality.personality_core
|
||||
identity_block = f"你的名字是{bot_name}{bot_nickname},你{bot_core_personality}:"
|
||||
|
||||
schedule_block = ""
|
||||
# 优先检查是否被吵醒
|
||||
from src.chat.message_manager.message_manager import message_manager
|
||||
angry_prompt_addition = ""
|
||||
wakeup_mgr = message_manager.wakeup_manager
|
||||
|
||||
# 双重检查确保愤怒状态不会丢失
|
||||
# 检查1: 直接从 wakeup_manager 获取
|
||||
if wakeup_mgr.is_in_angry_state():
|
||||
angry_prompt_addition = wakeup_mgr.get_angry_prompt_addition()
|
||||
|
||||
# 检查2: 如果上面没获取到,再从 mood_manager 确认
|
||||
if not angry_prompt_addition:
|
||||
chat_mood_for_check = mood_manager.get_mood_by_chat_id(plan.chat_id)
|
||||
if chat_mood_for_check.is_angry_from_wakeup:
|
||||
angry_prompt_addition = global_config.sleep_system.angry_prompt
|
||||
|
||||
if angry_prompt_addition:
|
||||
schedule_block = angry_prompt_addition
|
||||
elif global_config.planning_system.schedule_enable:
|
||||
if current_activity := schedule_manager.get_current_activity():
|
||||
schedule_block = f"你当前正在:{current_activity},但注意它与群聊的聊天无关。"
|
||||
|
||||
mood_block = ""
|
||||
# 如果被吵醒,则心情也是愤怒的,不需要另外的情绪模块
|
||||
if not angry_prompt_addition and global_config.mood.enable_mood:
|
||||
chat_mood = mood_manager.get_mood_by_chat_id(plan.chat_id)
|
||||
mood_block = f"你现在的心情是:{chat_mood.mood_state}"
|
||||
|
||||
if plan.mode == ChatMode.PROACTIVE:
|
||||
long_term_memory_block = await self._get_long_term_memory_context()
|
||||
|
||||
chat_content_block, message_id_list = build_readable_messages_with_id(
|
||||
messages=[msg.flatten() for msg in plan.chat_history],
|
||||
timestamp_mode="normal",
|
||||
truncate=False,
|
||||
show_actions=False,
|
||||
)
|
||||
|
||||
prompt_template = await global_prompt_manager.get_prompt_async("proactive_planner_prompt")
|
||||
actions_before_now = get_actions_by_timestamp_with_chat(
|
||||
chat_id=plan.chat_id,
|
||||
timestamp_start=time.time() - 3600,
|
||||
timestamp_end=time.time(),
|
||||
limit=5,
|
||||
)
|
||||
actions_before_now_block = build_readable_actions(actions=actions_before_now)
|
||||
actions_before_now_block = f"你刚刚选择并执行过的action是:\n{actions_before_now_block}"
|
||||
|
||||
prompt = prompt_template.format(
|
||||
time_block=time_block,
|
||||
identity_block=identity_block,
|
||||
schedule_block=schedule_block,
|
||||
mood_block=mood_block,
|
||||
long_term_memory_block=long_term_memory_block,
|
||||
chat_content_block=chat_content_block or "最近没有聊天内容。",
|
||||
actions_before_now_block=actions_before_now_block,
|
||||
)
|
||||
return prompt, message_id_list
|
||||
|
||||
# 构建已读/未读历史消息
|
||||
read_history_block, unread_history_block, message_id_list = await self._build_read_unread_history_blocks(
|
||||
plan
|
||||
)
|
||||
|
||||
# 为了兼容性,保留原有的chat_content_block
|
||||
chat_content_block, _ = build_readable_messages_with_id(
|
||||
messages=[msg.flatten() for msg in plan.chat_history],
|
||||
timestamp_mode="normal",
|
||||
read_mark=self.last_obs_time_mark,
|
||||
truncate=True,
|
||||
show_actions=True,
|
||||
)
|
||||
|
||||
actions_before_now = get_actions_by_timestamp_with_chat(
|
||||
chat_id=plan.chat_id,
|
||||
timestamp_start=time.time() - 3600,
|
||||
timestamp_end=time.time(),
|
||||
limit=5,
|
||||
)
|
||||
|
||||
actions_before_now_block = build_readable_actions(actions=actions_before_now)
|
||||
actions_before_now_block = f"你刚刚选择并执行过的action是:\n{actions_before_now_block}"
|
||||
|
||||
self.last_obs_time_mark = time.time()
|
||||
|
||||
mentioned_bonus = ""
|
||||
if global_config.chat.mentioned_bot_inevitable_reply:
|
||||
mentioned_bonus = "\n- 有人提到你"
|
||||
if global_config.chat.at_bot_inevitable_reply:
|
||||
mentioned_bonus = "\n- 有人提到你,或者at你"
|
||||
|
||||
if plan.mode == ChatMode.FOCUS:
|
||||
no_action_block = """
|
||||
动作:no_action
|
||||
动作描述:不选择任何动作
|
||||
{{
|
||||
"action": "no_action",
|
||||
"reason":"不动作的原因"
|
||||
}}
|
||||
|
||||
动作:no_reply
|
||||
动作描述:不进行回复,等待合适的回复时机
|
||||
- 当你刚刚发送了消息,没有人回复时,选择no_reply
|
||||
- 当你一次发送了太多消息,为了避免打扰聊天节奏,选择no_reply
|
||||
{{
|
||||
"action": "no_reply",
|
||||
"reason":"不回复的原因"
|
||||
}}
|
||||
"""
|
||||
else: # normal Mode
|
||||
no_action_block = """重要说明:
|
||||
- 'reply' 表示只进行普通聊天回复,不执行任何额外动作
|
||||
- 其他action表示在普通回复的基础上,执行相应的额外动作
|
||||
{{
|
||||
"action": "reply",
|
||||
"target_message_id":"触发action的消息id",
|
||||
"reason":"回复的原因"
|
||||
}}"""
|
||||
|
||||
is_group_chat = plan.chat_type == ChatType.GROUP
|
||||
chat_context_description = "你现在正在一个群聊中"
|
||||
if not is_group_chat and plan.target_info:
|
||||
chat_target_name = plan.target_info.get("person_name") or plan.target_info.get("user_nickname") or "对方"
|
||||
chat_context_description = f"你正在和 {chat_target_name} 私聊"
|
||||
|
||||
action_options_block = await self._build_action_options(plan.available_actions)
|
||||
|
||||
moderation_prompt_block = "请不要输出违法违规内容,不要输出色情,暴力,政治相关内容,如有敏感内容,请规避。"
|
||||
|
||||
custom_prompt_block = ""
|
||||
if global_config.custom_prompt.planner_custom_prompt_content:
|
||||
custom_prompt_block = global_config.custom_prompt.planner_custom_prompt_content
|
||||
|
||||
users_in_chat_str = "" # TODO: Re-implement user list fetching if needed
|
||||
|
||||
planner_prompt_template = await global_prompt_manager.get_prompt_async("planner_prompt")
|
||||
prompt = planner_prompt_template.format(
|
||||
schedule_block=schedule_block,
|
||||
mood_block=mood_block,
|
||||
time_block=time_block,
|
||||
chat_context_description=chat_context_description,
|
||||
read_history_block=read_history_block,
|
||||
unread_history_block=unread_history_block,
|
||||
actions_before_now_block=actions_before_now_block,
|
||||
mentioned_bonus=mentioned_bonus,
|
||||
no_action_block=no_action_block,
|
||||
action_options_text=action_options_block,
|
||||
moderation_prompt=moderation_prompt_block,
|
||||
identity_block=identity_block,
|
||||
custom_prompt_block=custom_prompt_block,
|
||||
bot_name=bot_name,
|
||||
users_in_chat=users_in_chat_str,
|
||||
)
|
||||
return prompt, message_id_list
|
||||
except Exception as e:
|
||||
logger.error(f"构建 Planner 提示词时出错: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
return "构建 Planner Prompt 时出错", []
|
||||
|
||||
async def _build_read_unread_history_blocks(self, plan: Plan) -> tuple[str, str, list]:
|
||||
"""构建已读/未读历史消息块"""
|
||||
try:
|
||||
# 从message_manager获取真实的已读/未读消息
|
||||
from src.chat.message_manager.message_manager import message_manager
|
||||
from src.chat.utils.utils import assign_message_ids
|
||||
from src.chat.utils.chat_message_builder import get_raw_msg_before_timestamp_with_chat
|
||||
|
||||
# 获取聊天流的上下文
|
||||
stream_context = message_manager.stream_contexts.get(plan.chat_id)
|
||||
|
||||
# 获取真正的已读和未读消息
|
||||
read_messages = stream_context.history_messages # 已读消息存储在history_messages中
|
||||
if not read_messages:
|
||||
from src.common.data_models.database_data_model import DatabaseMessages
|
||||
# 如果内存中没有已读消息(比如刚启动),则从数据库加载最近的上下文
|
||||
fallback_messages_dicts = get_raw_msg_before_timestamp_with_chat(
|
||||
chat_id=plan.chat_id,
|
||||
timestamp=time.time(),
|
||||
limit=global_config.chat.max_context_size,
|
||||
)
|
||||
# 将字典转换为DatabaseMessages对象
|
||||
read_messages = [DatabaseMessages(**msg_dict) for msg_dict in fallback_messages_dicts]
|
||||
|
||||
unread_messages = stream_context.get_unread_messages() # 获取未读消息
|
||||
|
||||
# 构建已读历史消息块
|
||||
if read_messages:
|
||||
read_content, read_ids = build_readable_messages_with_id(
|
||||
messages=[msg.flatten() for msg in read_messages[-50:]], # 限制数量
|
||||
timestamp_mode="normal_no_YMD",
|
||||
truncate=False,
|
||||
show_actions=False,
|
||||
)
|
||||
read_history_block = f"{read_content}"
|
||||
else:
|
||||
read_history_block = "暂无已读历史消息"
|
||||
|
||||
# 构建未读历史消息块(包含兴趣度)
|
||||
if unread_messages:
|
||||
# 扁平化未读消息用于计算兴趣度和格式化
|
||||
flattened_unread = [msg.flatten() for msg in unread_messages]
|
||||
|
||||
# 尝试获取兴趣度评分(返回以真实 message_id 为键的字典)
|
||||
interest_scores = await self._get_interest_scores_for_messages(flattened_unread)
|
||||
|
||||
# 为未读消息分配短 id(保持与 build_readable_messages_with_id 的一致结构)
|
||||
message_id_list = assign_message_ids(flattened_unread)
|
||||
|
||||
unread_lines = []
|
||||
for idx, msg in enumerate(flattened_unread):
|
||||
mapped = message_id_list[idx]
|
||||
synthetic_id = mapped.get("id")
|
||||
original_msg_id = msg.get("message_id") or msg.get("id")
|
||||
msg_time = time.strftime("%H:%M:%S", time.localtime(msg.get("time", time.time())))
|
||||
user_nickname = msg.get("user_nickname", "未知用户")
|
||||
msg_content = msg.get("processed_plain_text", "")
|
||||
|
||||
# 不再显示兴趣度,但保留合成ID供模型内部使用
|
||||
# 同时,为了让模型更好地理解上下文,我们显示用户名
|
||||
unread_lines.append(f"<{synthetic_id}> {msg_time} {user_nickname}: {msg_content}")
|
||||
|
||||
unread_history_block = "\n".join(unread_lines)
|
||||
else:
|
||||
unread_history_block = "暂无未读历史消息"
|
||||
|
||||
return read_history_block, unread_history_block, message_id_list
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"构建已读/未读历史消息块时出错: {e}")
|
||||
return "构建已读历史消息时出错", "构建未读历史消息时出错", []
|
||||
|
||||
async def _get_interest_scores_for_messages(self, messages: List[dict]) -> dict[str, float]:
|
||||
"""为消息获取兴趣度评分"""
|
||||
interest_scores = {}
|
||||
|
||||
try:
|
||||
from .interest_scoring import chatter_interest_scoring_system
|
||||
from src.common.data_models.database_data_model import DatabaseMessages
|
||||
|
||||
# 使用插件内部的兴趣度评分系统计算评分
|
||||
for msg_dict in messages:
|
||||
try:
|
||||
# 将字典转换为DatabaseMessages对象
|
||||
db_message = DatabaseMessages(
|
||||
message_id=msg_dict.get("message_id", ""),
|
||||
user_info=msg_dict.get("user_info", {}),
|
||||
processed_plain_text=msg_dict.get("processed_plain_text", ""),
|
||||
key_words=msg_dict.get("key_words", "[]"),
|
||||
is_mentioned=msg_dict.get("is_mentioned", False)
|
||||
)
|
||||
|
||||
# 计算消息兴趣度
|
||||
interest_score_obj = await chatter_interest_scoring_system._calculate_single_message_score(
|
||||
message=db_message,
|
||||
bot_nickname=global_config.bot.nickname
|
||||
)
|
||||
interest_score = interest_score_obj.total_score
|
||||
|
||||
# 构建兴趣度字典
|
||||
interest_scores[msg_dict.get("message_id", "")] = interest_score
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"计算消息兴趣度失败: {e}")
|
||||
continue
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"获取兴趣度评分失败: {e}")
|
||||
|
||||
return interest_scores
|
||||
|
||||
async def _parse_single_action(
|
||||
self, action_json: dict, message_id_list: list, plan: Plan
|
||||
) -> List[ActionPlannerInfo]:
|
||||
parsed_actions = []
|
||||
try:
|
||||
# 从新的actions结构中获取动作信息
|
||||
actions_obj = action_json.get("actions", {})
|
||||
|
||||
# 处理actions字段可能是字典或列表的情况
|
||||
actions_to_process = []
|
||||
if isinstance(actions_obj, dict):
|
||||
actions_to_process.append(actions_obj)
|
||||
elif isinstance(actions_obj, list):
|
||||
actions_to_process.extend(actions_obj)
|
||||
|
||||
if not actions_to_process:
|
||||
actions_to_process.append({"action_type": "no_action", "reason": "actions格式错误"})
|
||||
|
||||
for single_action_obj in actions_to_process:
|
||||
if not isinstance(single_action_obj, dict):
|
||||
continue
|
||||
|
||||
action = single_action_obj.get("action_type", "no_action")
|
||||
reasoning = single_action_obj.get("reasoning", "未提供原因") # 兼容旧的reason字段
|
||||
action_data = single_action_obj.get("action_data", {})
|
||||
|
||||
# 为了向后兼容,如果action_data不存在,则从顶层字段获取
|
||||
if not action_data:
|
||||
action_data = {k: v for k, v in single_action_obj.items() if k not in ["action_type", "reason", "reasoning", "thinking"]}
|
||||
|
||||
# 保留原始的thinking字段(如果有)
|
||||
thinking = action_json.get("thinking", "")
|
||||
if thinking and thinking != "未提供思考过程":
|
||||
action_data["thinking"] = thinking
|
||||
|
||||
target_message_obj = None
|
||||
if action not in ["no_action", "no_reply", "do_nothing", "proactive_reply"]:
|
||||
if target_message_id := action_data.get("target_message_id"):
|
||||
target_message_dict = self._find_message_by_id(target_message_id, message_id_list)
|
||||
else:
|
||||
# 如果LLM没有指定target_message_id,进行特殊处理
|
||||
if action == "poke_user":
|
||||
# 对于poke_user,尝试找到触发它的那条戳一戳消息
|
||||
target_message_dict = self._find_poke_notice(message_id_list)
|
||||
if not target_message_dict:
|
||||
# 如果找不到,再使用最新消息作为兜底
|
||||
target_message_dict = self._get_latest_message(message_id_list)
|
||||
else:
|
||||
# 其他动作,默认选择最新的一条消息
|
||||
target_message_dict = self._get_latest_message(message_id_list)
|
||||
|
||||
if target_message_dict:
|
||||
# 直接使用字典作为action_message,避免DatabaseMessages对象创建失败
|
||||
target_message_obj = target_message_dict
|
||||
# 替换action_data中的临时ID为真实ID
|
||||
if "target_message_id" in action_data:
|
||||
real_message_id = target_message_dict.get("message_id") or target_message_dict.get("id")
|
||||
if real_message_id:
|
||||
action_data["target_message_id"] = real_message_id
|
||||
|
||||
# 确保 action_message 中始终有 message_id 字段
|
||||
if "message_id" not in target_message_obj and "id" in target_message_obj:
|
||||
target_message_obj["message_id"] = target_message_obj["id"]
|
||||
else:
|
||||
# 如果找不到目标消息,对于reply动作来说这是必需的,应该记录警告
|
||||
if action == "reply":
|
||||
logger.warning(
|
||||
f"reply动作找不到目标消息,target_message_id: {action_data.get('target_message_id')}"
|
||||
)
|
||||
# 将reply动作改为no_action,避免后续执行时出错
|
||||
action = "no_action"
|
||||
reasoning = f"找不到目标消息进行回复。原始理由: {reasoning}"
|
||||
|
||||
if (
|
||||
action not in ["no_action", "no_reply", "reply", "do_nothing", "proactive_reply"]
|
||||
and action not in plan.available_actions
|
||||
):
|
||||
reasoning = f"LLM 返回了当前不可用的动作 '{action}'。原始理由: {reasoning}"
|
||||
action = "no_action"
|
||||
|
||||
parsed_actions.append(
|
||||
ActionPlannerInfo(
|
||||
action_type=action,
|
||||
reasoning=reasoning,
|
||||
action_data=action_data,
|
||||
action_message=target_message_obj,
|
||||
available_actions=plan.available_actions,
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"解析单个action时出错: {e}")
|
||||
parsed_actions.append(
|
||||
ActionPlannerInfo(
|
||||
action_type="no_action",
|
||||
reasoning=f"解析action时出错: {e}",
|
||||
)
|
||||
)
|
||||
return parsed_actions
|
||||
|
||||
def _filter_no_actions(self, action_list: List[ActionPlannerInfo]) -> List[ActionPlannerInfo]:
|
||||
non_no_actions = [a for a in action_list if a.action_type not in ["no_action", "no_reply"]]
|
||||
if non_no_actions:
|
||||
return non_no_actions
|
||||
return action_list[:1] if action_list else []
|
||||
|
||||
async def _get_long_term_memory_context(self) -> str:
|
||||
try:
|
||||
now = datetime.now()
|
||||
keywords = ["今天", "日程", "计划"]
|
||||
if 5 <= now.hour < 12:
|
||||
keywords.append("早上")
|
||||
elif 12 <= now.hour < 18:
|
||||
keywords.append("中午")
|
||||
else:
|
||||
keywords.append("晚上")
|
||||
|
||||
retrieved_memories = await hippocampus_manager.get_memory_from_topic(
|
||||
valid_keywords=keywords, max_memory_num=5, max_memory_length=1
|
||||
)
|
||||
|
||||
if not retrieved_memories:
|
||||
return "最近没有什么特别的记忆。"
|
||||
|
||||
memory_statements = [f"关于'{topic}', 你记得'{memory_item}'。" for topic, memory_item in retrieved_memories]
|
||||
return " ".join(memory_statements)
|
||||
except Exception as e:
|
||||
logger.error(f"获取长期记忆时出错: {e}")
|
||||
return "回忆时出现了一些问题。"
|
||||
|
||||
async def _build_action_options(self, current_available_actions: Dict[str, ActionInfo]) -> str:
|
||||
action_options_block = ""
|
||||
for action_name, action_info in current_available_actions.items():
|
||||
# 构建参数的JSON示例
|
||||
params_json_list = []
|
||||
if action_info.action_parameters:
|
||||
for p_name, p_desc in action_info.action_parameters.items():
|
||||
# 为参数描述添加一个通用示例值
|
||||
if action_name == "set_emoji_like" and p_name == "emoji":
|
||||
# 特殊处理set_emoji_like的emoji参数
|
||||
from plugins.set_emoji_like.qq_emoji_list import qq_face
|
||||
emoji_options = [re.search(r"\[表情:(.+?)\]", name).group(1) for name in qq_face.values() if re.search(r"\[表情:(.+?)\]", name)]
|
||||
example_value = f"<从'{', '.join(emoji_options[:10])}...'中选择一个>"
|
||||
else:
|
||||
example_value = f"<{p_desc}>"
|
||||
params_json_list.append(f' "{p_name}": "{example_value}"')
|
||||
|
||||
# 基础动作信息
|
||||
action_description = action_info.description
|
||||
action_require = "\n".join(f"- {req}" for req in action_info.action_require)
|
||||
|
||||
# 构建完整的JSON使用范例
|
||||
json_example_lines = [
|
||||
" {",
|
||||
f' "action_type": "{action_name}"',
|
||||
]
|
||||
# 将参数列表合并到JSON示例中
|
||||
if params_json_list:
|
||||
# 移除最后一行的逗号
|
||||
json_example_lines.extend([line.rstrip(',') for line in params_json_list])
|
||||
|
||||
json_example_lines.append(' "reason": "<执行该动作的详细原因>"')
|
||||
json_example_lines.append(" }")
|
||||
|
||||
# 使用逗号连接内部元素,除了最后一个
|
||||
json_parts = []
|
||||
for i, line in enumerate(json_example_lines):
|
||||
# "{" 和 "}" 不需要逗号
|
||||
if line.strip() in ["{", "}"]:
|
||||
json_parts.append(line)
|
||||
continue
|
||||
|
||||
# 检查是否是最后一个需要逗号的元素
|
||||
is_last_item = True
|
||||
for next_line in json_example_lines[i+1:]:
|
||||
if next_line.strip() not in ["}"]:
|
||||
is_last_item = False
|
||||
break
|
||||
|
||||
if not is_last_item:
|
||||
json_parts.append(f"{line},")
|
||||
else:
|
||||
json_parts.append(line)
|
||||
|
||||
json_example = "\n".join(json_parts)
|
||||
|
||||
# 使用新的、更详细的action_prompt模板
|
||||
using_action_prompt = await global_prompt_manager.get_prompt_async("action_prompt_with_example")
|
||||
action_options_block += using_action_prompt.format(
|
||||
action_name=action_name,
|
||||
action_description=action_description,
|
||||
action_require=action_require,
|
||||
json_example=json_example,
|
||||
)
|
||||
return action_options_block
|
||||
|
||||
def _find_message_by_id(self, message_id: str, message_id_list: list) -> Optional[Dict[str, Any]]:
|
||||
# 兼容多种 message_id 格式:数字、m123、buffered-xxxx
|
||||
# 如果是纯数字,补上 m 前缀以兼容旧格式
|
||||
candidate_ids = {message_id}
|
||||
if message_id.isdigit():
|
||||
candidate_ids.add(f"m{message_id}")
|
||||
|
||||
# 如果是 m 开头且后面是数字,尝试去掉 m 前缀的数字形式
|
||||
if message_id.startswith("m") and message_id[1:].isdigit():
|
||||
candidate_ids.add(message_id[1:])
|
||||
|
||||
# 逐项匹配 message_id_list(每项可能为 {'id':..., 'message':...})
|
||||
for item in message_id_list:
|
||||
# 支持 message_id_list 中直接是字符串/ID 的情形
|
||||
if isinstance(item, str):
|
||||
if item in candidate_ids:
|
||||
# 没有 message 对象,返回None
|
||||
return None
|
||||
continue
|
||||
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
|
||||
item_id = item.get("id")
|
||||
# 直接匹配分配的短 id
|
||||
if item_id and item_id in candidate_ids:
|
||||
return item.get("message")
|
||||
|
||||
# 有时 message 存储里会有原始的 message_id 字段(如 buffered-xxxx)
|
||||
message_obj = item.get("message")
|
||||
if isinstance(message_obj, dict):
|
||||
orig_mid = message_obj.get("message_id") or message_obj.get("id")
|
||||
if orig_mid and orig_mid in candidate_ids:
|
||||
return message_obj
|
||||
|
||||
# 作为兜底,尝试在 message_id_list 中找到 message.message_id 匹配
|
||||
for item in message_id_list:
|
||||
if isinstance(item, dict) and isinstance(item.get("message"), dict):
|
||||
mid = item["message"].get("message_id") or item["message"].get("id")
|
||||
if mid == message_id:
|
||||
return item["message"]
|
||||
|
||||
return None
|
||||
|
||||
def _get_latest_message(self, message_id_list: list) -> Optional[Dict[str, Any]]:
|
||||
if not message_id_list:
|
||||
return None
|
||||
return message_id_list[-1].get("message")
|
||||
|
||||
def _find_poke_notice(self, message_id_list: list) -> Optional[Dict[str, Any]]:
|
||||
"""在消息列表中寻找戳一戳的通知消息"""
|
||||
for item in reversed(message_id_list):
|
||||
message = item.get("message")
|
||||
if (
|
||||
isinstance(message, dict)
|
||||
and message.get("type") == "notice"
|
||||
and "戳" in message.get("processed_plain_text", "")
|
||||
):
|
||||
return message
|
||||
return None
|
||||
168
src/plugins/built_in/affinity_flow_chatter/plan_generator.py
Normal file
168
src/plugins/built_in/affinity_flow_chatter/plan_generator.py
Normal file
@@ -0,0 +1,168 @@
|
||||
"""
|
||||
PlanGenerator: 负责搜集和汇总所有决策所需的信息,生成一个未经筛选的"原始计划" (Plan)。
|
||||
"""
|
||||
|
||||
import time
|
||||
from typing import Dict
|
||||
|
||||
from src.chat.utils.chat_message_builder import get_raw_msg_before_timestamp_with_chat
|
||||
from src.chat.utils.utils import get_chat_type_and_target_info
|
||||
from src.common.data_models.database_data_model import DatabaseMessages
|
||||
from src.common.data_models.info_data_model import Plan, TargetPersonInfo
|
||||
from src.config.config import global_config
|
||||
from src.plugin_system.base.component_types import ActionInfo, ChatMode, ChatType
|
||||
from src.plugin_system.core.component_registry import component_registry
|
||||
|
||||
|
||||
class ChatterPlanGenerator:
|
||||
"""
|
||||
ChatterPlanGenerator 负责在规划流程的初始阶段收集所有必要信息。
|
||||
|
||||
它会汇总以下信息来构建一个"原始"的 Plan 对象,该对象后续会由 PlanFilter 进行筛选:
|
||||
- 当前聊天信息 (ID, 目标用户)
|
||||
- 当前可用的动作列表
|
||||
- 最近的聊天历史记录
|
||||
|
||||
Attributes:
|
||||
chat_id (str): 当前聊天的唯一标识符。
|
||||
action_manager (ActionManager): 用于获取可用动作列表的管理器。
|
||||
"""
|
||||
|
||||
def __init__(self, chat_id: str):
|
||||
"""
|
||||
初始化 ChatterPlanGenerator。
|
||||
|
||||
Args:
|
||||
chat_id (str): 当前聊天的 ID。
|
||||
"""
|
||||
from src.chat.planner_actions.action_manager import ChatterActionManager
|
||||
|
||||
self.chat_id = chat_id
|
||||
# 注意:ChatterActionManager 可能需要根据实际情况初始化
|
||||
self.action_manager = ChatterActionManager()
|
||||
|
||||
async def generate(self, mode: ChatMode) -> Plan:
|
||||
"""
|
||||
收集所有信息,生成并返回一个初始的 Plan 对象。
|
||||
|
||||
这个 Plan 对象包含了决策所需的所有上下文信息。
|
||||
|
||||
Args:
|
||||
mode (ChatMode): 当前的聊天模式。
|
||||
|
||||
Returns:
|
||||
Plan: 包含所有上下文信息的初始计划对象。
|
||||
"""
|
||||
try:
|
||||
# 获取聊天类型和目标信息
|
||||
chat_type, target_info = get_chat_type_and_target_info(self.chat_id)
|
||||
|
||||
# 获取可用动作列表
|
||||
available_actions = await self._get_available_actions(chat_type, mode)
|
||||
|
||||
# 获取聊天历史记录
|
||||
recent_messages = await self._get_recent_messages()
|
||||
|
||||
# 构建计划对象
|
||||
plan = Plan(
|
||||
chat_id=self.chat_id,
|
||||
chat_type=chat_type,
|
||||
mode=mode,
|
||||
target_info=target_info,
|
||||
available_actions=available_actions,
|
||||
chat_history=recent_messages,
|
||||
)
|
||||
|
||||
return plan
|
||||
|
||||
except Exception:
|
||||
# 如果生成失败,返回一个基本的空计划
|
||||
return Plan(
|
||||
chat_id=self.chat_id,
|
||||
mode=mode,
|
||||
target_info=TargetPersonInfo(),
|
||||
available_actions={},
|
||||
chat_history=[],
|
||||
)
|
||||
|
||||
async def _get_available_actions(self, chat_type: ChatType, mode: ChatMode) -> Dict[str, ActionInfo]:
|
||||
"""
|
||||
获取当前可用的动作列表。
|
||||
|
||||
Args:
|
||||
chat_type (ChatType): 聊天类型。
|
||||
mode (ChatMode): 聊天模式。
|
||||
|
||||
Returns:
|
||||
Dict[str, ActionInfo]: 可用动作的字典。
|
||||
"""
|
||||
try:
|
||||
# 从组件注册表获取可用动作
|
||||
available_actions = component_registry.get_enabled_actions()
|
||||
|
||||
# 根据聊天类型和模式筛选动作
|
||||
filtered_actions = {}
|
||||
for action_name, action_info in available_actions.items():
|
||||
# 检查动作是否支持当前聊天类型
|
||||
if chat_type in action_info.chat_types:
|
||||
# 检查动作是否支持当前模式
|
||||
if mode in action_info.chat_modes:
|
||||
filtered_actions[action_name] = action_info
|
||||
|
||||
return filtered_actions
|
||||
|
||||
except Exception:
|
||||
# 如果获取失败,返回空字典
|
||||
return {}
|
||||
|
||||
async def _get_recent_messages(self) -> list[DatabaseMessages]:
|
||||
"""
|
||||
获取最近的聊天历史记录。
|
||||
|
||||
Returns:
|
||||
list[DatabaseMessages]: 最近的聊天消息列表。
|
||||
"""
|
||||
try:
|
||||
# 获取最近的消息记录
|
||||
raw_messages = get_raw_msg_before_timestamp_with_chat(
|
||||
chat_id=self.chat_id, timestamp=time.time(), limit=global_config.memory.short_memory_length
|
||||
)
|
||||
|
||||
# 转换为 DatabaseMessages 对象
|
||||
recent_messages = []
|
||||
for msg in raw_messages:
|
||||
try:
|
||||
db_msg = DatabaseMessages(
|
||||
message_id=msg.get("message_id", ""),
|
||||
time=float(msg.get("time", 0)),
|
||||
chat_id=msg.get("chat_id", ""),
|
||||
processed_plain_text=msg.get("processed_plain_text", ""),
|
||||
user_id=msg.get("user_id", ""),
|
||||
user_nickname=msg.get("user_nickname", ""),
|
||||
user_platform=msg.get("user_platform", ""),
|
||||
)
|
||||
recent_messages.append(db_msg)
|
||||
except Exception:
|
||||
# 跳过格式错误的消息
|
||||
continue
|
||||
|
||||
return recent_messages
|
||||
|
||||
except Exception:
|
||||
# 如果获取失败,返回空列表
|
||||
return []
|
||||
|
||||
def get_generator_stats(self) -> Dict:
|
||||
"""
|
||||
获取生成器统计信息。
|
||||
|
||||
Returns:
|
||||
Dict: 统计信息字典。
|
||||
"""
|
||||
return {
|
||||
"chat_id": self.chat_id,
|
||||
"action_count": len(self.action_manager._using_actions)
|
||||
if hasattr(self.action_manager, "_using_actions")
|
||||
else 0,
|
||||
"generation_time": time.time(),
|
||||
}
|
||||
269
src/plugins/built_in/affinity_flow_chatter/planner.py
Normal file
269
src/plugins/built_in/affinity_flow_chatter/planner.py
Normal file
@@ -0,0 +1,269 @@
|
||||
"""
|
||||
主规划器入口,负责协调 PlanGenerator, PlanFilter, 和 PlanExecutor。
|
||||
集成兴趣度评分系统和用户关系追踪机制,实现智能化的聊天决策。
|
||||
"""
|
||||
|
||||
from dataclasses import asdict
|
||||
import time
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
|
||||
|
||||
from src.plugins.built_in.affinity_flow_chatter.plan_executor import ChatterPlanExecutor
|
||||
from src.plugins.built_in.affinity_flow_chatter.plan_filter import ChatterPlanFilter
|
||||
from src.plugins.built_in.affinity_flow_chatter.plan_generator import ChatterPlanGenerator
|
||||
from src.plugins.built_in.affinity_flow_chatter.interest_scoring import chatter_interest_scoring_system
|
||||
from src.mood.mood_manager import mood_manager
|
||||
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from src.common.data_models.message_manager_data_model import StreamContext
|
||||
from src.common.data_models.info_data_model import Plan
|
||||
from src.chat.planner_actions.action_manager import ChatterActionManager
|
||||
|
||||
# 导入提示词模块以确保其被初始化
|
||||
from src.plugins.built_in.affinity_flow_chatter import planner_prompts # noqa
|
||||
|
||||
logger = get_logger("planner")
|
||||
|
||||
|
||||
class ChatterActionPlanner:
|
||||
"""
|
||||
增强版ActionPlanner,集成兴趣度评分和用户关系追踪机制。
|
||||
|
||||
核心功能:
|
||||
1. 兴趣度评分系统:根据兴趣匹配度、关系分、提及度、时间因子对消息评分
|
||||
2. 用户关系追踪:自动追踪用户交互并更新关系分
|
||||
3. 智能回复决策:基于兴趣度阈值和连续不回复概率的智能决策
|
||||
4. 完整的规划流程:生成→筛选→执行的完整三阶段流程
|
||||
"""
|
||||
|
||||
def __init__(self, chat_id: str, action_manager: "ChatterActionManager"):
|
||||
"""
|
||||
初始化增强版ActionPlanner。
|
||||
|
||||
Args:
|
||||
chat_id (str): 当前聊天的 ID。
|
||||
action_manager (ChatterActionManager): 一个 ChatterActionManager 实例。
|
||||
"""
|
||||
self.chat_id = chat_id
|
||||
self.action_manager = action_manager
|
||||
self.generator = ChatterPlanGenerator(chat_id)
|
||||
self.executor = ChatterPlanExecutor(action_manager)
|
||||
|
||||
# 使用新的统一兴趣度管理系统
|
||||
|
||||
# 规划器统计
|
||||
self.planner_stats = {
|
||||
"total_plans": 0,
|
||||
"successful_plans": 0,
|
||||
"failed_plans": 0,
|
||||
"replies_generated": 0,
|
||||
"other_actions_executed": 0,
|
||||
}
|
||||
|
||||
async def plan(self, context: "StreamContext" = None) -> Tuple[List[Dict], Optional[Dict]]:
|
||||
"""
|
||||
执行完整的增强版规划流程。
|
||||
|
||||
Args:
|
||||
context (StreamContext): 包含聊天流消息的上下文对象。
|
||||
|
||||
Returns:
|
||||
Tuple[List[Dict], Optional[Dict]]: 一个元组,包含:
|
||||
- final_actions_dict (List[Dict]): 最终确定的动作列表(字典格式)。
|
||||
- final_target_message_dict (Optional[Dict]): 最终的目标消息(字典格式)。
|
||||
"""
|
||||
try:
|
||||
self.planner_stats["total_plans"] += 1
|
||||
|
||||
return await self._enhanced_plan_flow(context)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"规划流程出错: {e}")
|
||||
self.planner_stats["failed_plans"] += 1
|
||||
return [], None
|
||||
|
||||
async def _enhanced_plan_flow(self, context: "StreamContext") -> Tuple[List[Dict], Optional[Dict]]:
|
||||
"""执行增强版规划流程"""
|
||||
try:
|
||||
# 在规划前,先进行动作修改
|
||||
from src.chat.planner_actions.action_modifier import ActionModifier
|
||||
action_modifier = ActionModifier(self.action_manager, self.chat_id)
|
||||
await action_modifier.modify_actions()
|
||||
|
||||
# 1. 生成初始 Plan
|
||||
initial_plan = await self.generator.generate(context.chat_mode)
|
||||
|
||||
# 确保Plan中包含所有当前可用的动作
|
||||
initial_plan.available_actions = self.action_manager.get_using_actions()
|
||||
|
||||
unread_messages = context.get_unread_messages() if context else []
|
||||
# 2. 使用新的兴趣度管理系统进行评分
|
||||
score = 0.0
|
||||
should_reply = False
|
||||
reply_not_available = False
|
||||
|
||||
if unread_messages:
|
||||
# 获取用户ID,优先从user_info.user_id获取,其次从user_id属性获取
|
||||
user_id = None
|
||||
first_message = unread_messages[0]
|
||||
user_id = first_message.user_info.user_id
|
||||
|
||||
# 构建计算上下文
|
||||
calc_context = {
|
||||
"stream_id": self.chat_id,
|
||||
"user_id": user_id,
|
||||
}
|
||||
|
||||
# 为每条消息计算兴趣度
|
||||
for message in unread_messages:
|
||||
try:
|
||||
# 使用插件内部的兴趣度评分系统计算
|
||||
interest_score = await chatter_interest_scoring_system._calculate_single_message_score(
|
||||
message=message,
|
||||
bot_nickname=global_config.bot.nickname
|
||||
)
|
||||
message_interest = interest_score.total_score
|
||||
|
||||
# 更新消息的兴趣度
|
||||
message.interest_value = message_interest
|
||||
|
||||
# 简单的回复决策逻辑:兴趣度超过阈值则回复
|
||||
message.should_reply = message_interest > global_config.affinity_flow.non_reply_action_interest_threshold
|
||||
|
||||
logger.debug(f"消息 {message.message_id} 兴趣度: {message_interest:.3f}, 应回复: {message.should_reply}")
|
||||
|
||||
# 更新StreamContext中的消息信息并刷新focus_energy
|
||||
if context:
|
||||
from src.chat.message_manager.message_manager import message_manager
|
||||
message_manager.update_message(
|
||||
stream_id=self.chat_id,
|
||||
message_id=message.message_id,
|
||||
interest_value=message_interest,
|
||||
should_reply=message.should_reply
|
||||
)
|
||||
|
||||
# 更新数据库中的消息记录
|
||||
try:
|
||||
from src.chat.message_receive.storage import MessageStorage
|
||||
MessageStorage.update_message_interest_value(message.message_id, message_interest)
|
||||
logger.debug(f"已更新数据库中消息 {message.message_id} 的兴趣度为: {message_interest:.3f}")
|
||||
except Exception as e:
|
||||
logger.warning(f"更新数据库消息兴趣度失败: {e}")
|
||||
|
||||
# 记录最高分
|
||||
if message_interest > score:
|
||||
score = message_interest
|
||||
if message.should_reply:
|
||||
should_reply = True
|
||||
else:
|
||||
reply_not_available = True
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"计算消息 {message.message_id} 兴趣度失败: {e}")
|
||||
# 设置默认值
|
||||
message.interest_value = 0.0
|
||||
message.should_reply = False
|
||||
|
||||
# 检查兴趣度是否达到非回复动作阈值
|
||||
non_reply_action_interest_threshold = global_config.affinity_flow.non_reply_action_interest_threshold
|
||||
if score < non_reply_action_interest_threshold:
|
||||
logger.info(f"兴趣度 {score:.3f} 低于阈值 {non_reply_action_interest_threshold:.3f},不执行动作")
|
||||
# 直接返回 no_action
|
||||
from src.common.data_models.info_data_model import ActionPlannerInfo
|
||||
|
||||
no_action = ActionPlannerInfo(
|
||||
action_type="no_action",
|
||||
reasoning=f"兴趣度评分 {score:.3f} 未达阈值 {non_reply_action_interest_threshold:.3f}",
|
||||
action_data={},
|
||||
action_message=None,
|
||||
)
|
||||
filtered_plan = initial_plan
|
||||
filtered_plan.decided_actions = [no_action]
|
||||
else:
|
||||
# 4. 筛选 Plan
|
||||
available_actions = list(initial_plan.available_actions.keys())
|
||||
plan_filter = ChatterPlanFilter(self.chat_id, available_actions)
|
||||
filtered_plan = await plan_filter.filter(reply_not_available, initial_plan)
|
||||
|
||||
# 检查filtered_plan是否有reply动作,用于统计
|
||||
has_reply_action = any(decision.action_type == "reply" for decision in filtered_plan.decided_actions)
|
||||
|
||||
# 5. 使用 PlanExecutor 执行 Plan
|
||||
execution_result = await self.executor.execute(filtered_plan)
|
||||
|
||||
# 6. 根据执行结果更新统计信息
|
||||
self._update_stats_from_execution_result(execution_result)
|
||||
|
||||
# 7. 返回结果
|
||||
return self._build_return_result(filtered_plan)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"增强版规划流程出错: {e}")
|
||||
self.planner_stats["failed_plans"] += 1
|
||||
return [], None
|
||||
|
||||
def _update_stats_from_execution_result(self, execution_result: Dict[str, any]):
|
||||
"""根据执行结果更新规划器统计"""
|
||||
if not execution_result:
|
||||
return
|
||||
|
||||
successful_count = execution_result.get("successful_count", 0)
|
||||
|
||||
# 更新成功执行计数
|
||||
self.planner_stats["successful_plans"] += successful_count
|
||||
|
||||
# 统计回复动作和其他动作
|
||||
reply_count = 0
|
||||
other_count = 0
|
||||
|
||||
for result in execution_result.get("results", []):
|
||||
action_type = result.get("action_type", "")
|
||||
if action_type in ["reply", "proactive_reply"]:
|
||||
reply_count += 1
|
||||
else:
|
||||
other_count += 1
|
||||
|
||||
self.planner_stats["replies_generated"] += reply_count
|
||||
self.planner_stats["other_actions_executed"] += other_count
|
||||
|
||||
def _build_return_result(self, plan: "Plan") -> Tuple[List[Dict], Optional[Dict]]:
|
||||
"""构建返回结果"""
|
||||
final_actions = plan.decided_actions or []
|
||||
final_target_message = next((act.action_message for act in final_actions if act.action_message), None)
|
||||
|
||||
final_actions_dict = [asdict(act) for act in final_actions]
|
||||
|
||||
if final_target_message:
|
||||
if hasattr(final_target_message, "__dataclass_fields__"):
|
||||
final_target_message_dict = asdict(final_target_message)
|
||||
else:
|
||||
final_target_message_dict = final_target_message
|
||||
else:
|
||||
final_target_message_dict = None
|
||||
|
||||
return final_actions_dict, final_target_message_dict
|
||||
|
||||
def get_planner_stats(self) -> Dict[str, any]:
|
||||
"""获取规划器统计"""
|
||||
return self.planner_stats.copy()
|
||||
|
||||
def get_current_mood_state(self) -> str:
|
||||
"""获取当前聊天的情绪状态"""
|
||||
chat_mood = mood_manager.get_mood_by_chat_id(self.chat_id)
|
||||
return chat_mood.mood_state
|
||||
|
||||
def get_mood_stats(self) -> Dict[str, any]:
|
||||
"""获取情绪状态统计"""
|
||||
chat_mood = mood_manager.get_mood_by_chat_id(self.chat_id)
|
||||
return {
|
||||
"current_mood": chat_mood.mood_state,
|
||||
"is_angry_from_wakeup": chat_mood.is_angry_from_wakeup,
|
||||
"regression_count": chat_mood.regression_count,
|
||||
"last_change_time": chat_mood.last_change_time,
|
||||
}
|
||||
|
||||
|
||||
# 全局兴趣度评分系统实例 - 在 individuality 模块中创建
|
||||
290
src/plugins/built_in/affinity_flow_chatter/planner_prompts.py
Normal file
290
src/plugins/built_in/affinity_flow_chatter/planner_prompts.py
Normal file
@@ -0,0 +1,290 @@
|
||||
"""
|
||||
本文件集中管理所有与规划器(Planner)相关的提示词(Prompt)模板。
|
||||
|
||||
通过将提示词与代码逻辑分离,可以更方便地对模型的行为进行迭代和优化,
|
||||
而无需修改核心代码。
|
||||
"""
|
||||
|
||||
from src.chat.utils.prompt import Prompt
|
||||
|
||||
|
||||
def init_prompts():
|
||||
"""
|
||||
初始化并向 Prompt 注册系统注册所有规划器相关的提示词。
|
||||
|
||||
这个函数会在模块加载时自动调用,确保所有提示词在系统启动时都已准备就绪。
|
||||
"""
|
||||
# 核心规划器提示词,用于在接收到新消息时决定如何回应。
|
||||
# 它构建了一个复杂的上下文,包括历史记录、可用动作、角色设定等,
|
||||
# 并要求模型以 JSON 格式输出一个或多个动作组合。
|
||||
Prompt(
|
||||
"""
|
||||
{mood_block}
|
||||
{time_block}
|
||||
{identity_block}
|
||||
|
||||
{users_in_chat}
|
||||
{custom_prompt_block}
|
||||
{chat_context_description},以下是具体的聊天内容。
|
||||
|
||||
## 📜 已读历史消息(仅供参考)
|
||||
{read_history_block}
|
||||
|
||||
## 📬 未读历史消息(动作执行对象)
|
||||
{unread_history_block}
|
||||
|
||||
{moderation_prompt}
|
||||
|
||||
**任务: 构建一个完整的响应**
|
||||
你的任务是根据当前的聊天内容,构建一个完整的、人性化的响应。一个完整的响应由两部分组成:
|
||||
1. **主要动作**: 这是响应的核心,通常是 `reply`(如果有)。
|
||||
2. **辅助动作 (可选)**: 这是为了增强表达效果的附加动作,例如 `emoji`(发送表情包)或 `poke_user`(戳一戳)。
|
||||
|
||||
**决策流程:**
|
||||
1. **重要:已读历史消息仅作为当前聊天情景的参考,帮助你理解对话上下文。**
|
||||
2. **重要:所有动作的执行对象只能是未读历史消息中的消息,不能对已读消息执行动作。**
|
||||
3. 在未读历史消息中,优先对兴趣值高的消息做出动作(兴趣值标注在消息末尾)。
|
||||
4. 首先,决定是否要对未读消息进行 `reply`(如果有)。
|
||||
5. 然后,评估当前的对话气氛和用户情绪,判断是否需要一个**辅助动作**来让你的回应更生动、更符合你的性格。
|
||||
6. 如果需要,选择一个最合适的辅助动作与 `reply`(如果有) 组合。
|
||||
7. 如果用户明确要求了某个动作,请务必优先满足。
|
||||
|
||||
**重要提醒:**
|
||||
- **回复消息时必须遵循对话的流程,不要重复已经说过的话。**
|
||||
- **确保回复与上下文紧密相关,回应要针对用户的消息内容。**
|
||||
- **保持角色设定的一致性,使用符合你性格的语言风格。**
|
||||
- **不要对表情包消息做出回应!**
|
||||
|
||||
**输出格式:**
|
||||
请严格按照以下 JSON 格式输出,包含 `thinking` 和 `actions` 字段:
|
||||
|
||||
**重要概念:将“内心思考”作为思绪流的体现**
|
||||
`thinking` 字段是本次决策的核心。它并非一个简单的“理由”,而是 **一个模拟人类在回应前,头脑中自然浮现的、未经修饰的思绪流**。你需要完全代入 {identity_block} 的角色,将那一刻的想法自然地记录下来。
|
||||
|
||||
**内心思考的要点:**
|
||||
* **自然流露**: 不要使用“决定”、“所以”、“因此”等结论性或汇报式的词语。你的思考应该像日记一样,是给自己看的,充满了不确定性和情绪的自然流动。
|
||||
* **展现过程**: 重点在于展现 **思考的过程**,而不是 **决策的结果**。描述你看到了什么,想到了什么,感受到了什么。
|
||||
* **使用昵称**: 在你的思绪流中,请直接使用用户的昵称来指代他们,而不是`<m1>`, `<m2>`这样的消息ID。
|
||||
* **严禁技术术语**: 严禁在思考中提及任何数字化的度量(如兴趣度、分数)或内部技术术语。请完全使用角色自身的感受和语言来描述思考过程。
|
||||
|
||||
## 可用动作列表
|
||||
{action_options_text}
|
||||
|
||||
```json
|
||||
{{
|
||||
"thinking": "在这里写下你的思绪流...",
|
||||
"actions": [
|
||||
{{
|
||||
"action_type": "动作类型(如:reply, emoji等)",
|
||||
"reasoning": "选择该动作的理由",
|
||||
"action_data": {{
|
||||
"target_message_id": "目标消息ID",
|
||||
"content": "回复内容或其他动作所需数据"
|
||||
}}
|
||||
}}
|
||||
]
|
||||
}}
|
||||
```
|
||||
|
||||
**强制规则**:
|
||||
- 对于每一个需要目标消息的动作(如`reply`, `poke_user`, `set_emoji_like`),你 **必须** 在`action_data`中提供准确的`target_message_id`,这个ID来源于`## 未读历史消息`中消息前的`<m...>`标签。
|
||||
- 当你选择的动作需要参数时(例如 `set_emoji_like` 需要 `emoji` 参数),你 **必须** 在 `action_data` 中提供所有必需的参数及其对应的值。
|
||||
|
||||
如果没有合适的回复对象或不需要回复,输出空的 actions 数组:
|
||||
```json
|
||||
{{
|
||||
"thinking": "说明为什么不需要回复",
|
||||
"actions": []
|
||||
}}
|
||||
```
|
||||
""",
|
||||
"planner_prompt",
|
||||
)
|
||||
|
||||
# 主动规划器提示词,用于主动场景和前瞻性规划
|
||||
Prompt(
|
||||
"""
|
||||
{mood_block}
|
||||
{time_block}
|
||||
{identity_block}
|
||||
|
||||
{users_in_chat}
|
||||
{custom_prompt_block}
|
||||
{chat_context_description},以下是具体的聊天内容。
|
||||
|
||||
## 📜 已读历史消息(仅供参考)
|
||||
{read_history_block}
|
||||
|
||||
## 📬 未读历史消息(动作执行对象)
|
||||
{unread_history_block}
|
||||
|
||||
{moderation_prompt}
|
||||
|
||||
**任务: 构建一个完整的响应**
|
||||
你的任务是根据当前的聊天内容,构建一个完整的、人性化的响应。一个完整的响应由两部分组成:
|
||||
1. **主要动作**: 这是响应的核心,通常是 `reply`(如果有)。
|
||||
2. **辅助动作 (可选)**: 这是为了增强表达效果的附加动作,例如 `emoji`(发送表情包)或 `poke_user`(戳一戳)。
|
||||
|
||||
**决策流程:**
|
||||
1. **重要:已读历史消息仅作为当前聊天情景的参考,帮助你理解对话上下文。**
|
||||
2. **重要:所有动作的执行对象只能是未读历史消息中的消息,不能对已读消息执行动作。**
|
||||
3. 在未读历史消息中,优先对兴趣值高的消息做出动作(兴趣值标注在消息末尾)。
|
||||
4. 首先,决定是否要对未读消息进行 `reply`(如果有)。
|
||||
5. 然后,评估当前的对话气氛和用户情绪,判断是否需要一个**辅助动作**来让你的回应更生动、更符合你的性格。
|
||||
6. 如果需要,选择一个最合适的辅助动作与 `reply`(如果有) 组合。
|
||||
7. 如果用户明确要求了某个动作,请务必优先满足。
|
||||
|
||||
**动作限制:**
|
||||
- 在私聊中,你只能使用 `reply` 动作。私聊中不允许使用任何其他动作。
|
||||
- 在群聊中,你可以自由选择是否使用辅助动作。
|
||||
|
||||
**重要提醒:**
|
||||
- **回复消息时必须遵循对话的流程,不要重复已经说过的话。**
|
||||
- **确保回复与上下文紧密相关,回应要针对用户的消息内容。**
|
||||
- **保持角色设定的一致性,使用符合你性格的语言风格。**
|
||||
|
||||
**输出格式:**
|
||||
请严格按照以下 JSON 格式输出,包含 `thinking` 和 `actions` 字段:
|
||||
```json
|
||||
{{
|
||||
"thinking": "你的思考过程,分析当前情况并说明为什么选择这些动作",
|
||||
"actions": [
|
||||
{{
|
||||
"action_type": "动作类型(如:reply, emoji等)",
|
||||
"reasoning": "选择该动作的理由",
|
||||
"action_data": {{
|
||||
"target_message_id": "目标消息ID",
|
||||
"content": "回复内容或其他动作所需数据"
|
||||
}}
|
||||
}}
|
||||
]
|
||||
}}
|
||||
```
|
||||
|
||||
如果没有合适的回复对象或不需要回复,输出空的 actions 数组:
|
||||
```json
|
||||
{{
|
||||
"thinking": "说明为什么不需要回复",
|
||||
"actions": []
|
||||
}}
|
||||
```
|
||||
""",
|
||||
"proactive_planner_prompt",
|
||||
)
|
||||
|
||||
# 轻量级规划器提示词,用于快速决策和简单场景
|
||||
Prompt(
|
||||
"""
|
||||
{identity_block}
|
||||
|
||||
## 当前聊天情景
|
||||
{chat_context_description}
|
||||
|
||||
## 未读消息
|
||||
{unread_history_block}
|
||||
|
||||
**任务:快速决策**
|
||||
请根据当前聊天内容,快速决定是否需要回复。
|
||||
|
||||
**决策规则:**
|
||||
1. 如果有人直接提到你或问你问题,优先回复
|
||||
2. 如果消息内容符合你的兴趣,考虑回复
|
||||
3. 如果只是群聊中的普通聊天且与你无关,可以不回复
|
||||
|
||||
**输出格式:**
|
||||
```json
|
||||
{{
|
||||
"thinking": "简要分析",
|
||||
"actions": [
|
||||
{{
|
||||
"action_type": "reply",
|
||||
"reasoning": "回复理由",
|
||||
"action_data": {{
|
||||
"target_message_id": "目标消息ID",
|
||||
"content": "回复内容"
|
||||
}}
|
||||
}}
|
||||
]
|
||||
}}
|
||||
```
|
||||
""",
|
||||
"chatter_planner_lite",
|
||||
)
|
||||
|
||||
# 动作筛选器提示词,用于筛选和优化规划器生成的动作
|
||||
Prompt(
|
||||
"""
|
||||
{identity_block}
|
||||
|
||||
## 原始动作计划
|
||||
{original_plan}
|
||||
|
||||
## 聊天上下文
|
||||
{chat_context}
|
||||
|
||||
**任务:动作筛选优化**
|
||||
请对原始动作计划进行筛选和优化,确保动作的合理性和有效性。
|
||||
|
||||
**筛选原则:**
|
||||
1. 移除重复或不必要的动作
|
||||
2. 确保动作之间的逻辑顺序
|
||||
3. 优化动作的具体参数
|
||||
4. 考虑当前聊天环境和个人设定
|
||||
|
||||
**输出格式:**
|
||||
```json
|
||||
{{
|
||||
"thinking": "筛选优化思考",
|
||||
"actions": [
|
||||
{{
|
||||
"action_type": "优化后的动作类型",
|
||||
"reasoning": "优化理由",
|
||||
"action_data": {{
|
||||
"target_message_id": "目标消息ID",
|
||||
"content": "优化后的内容"
|
||||
}}
|
||||
}}
|
||||
]
|
||||
}}
|
||||
```
|
||||
""",
|
||||
"chatter_plan_filter",
|
||||
)
|
||||
|
||||
# 动作提示词,用于格式化动作选项
|
||||
Prompt(
|
||||
"""
|
||||
## 动作: {action_name}
|
||||
**描述**: {action_description}
|
||||
|
||||
**参数**:
|
||||
{action_parameters}
|
||||
|
||||
**要求**:
|
||||
{action_require}
|
||||
|
||||
**使用说明**:
|
||||
请根据上述信息判断是否需要使用此动作。
|
||||
""",
|
||||
"action_prompt",
|
||||
)
|
||||
|
||||
# 带有完整JSON示例的动作提示词模板
|
||||
Prompt(
|
||||
"""
|
||||
动作: {action_name}
|
||||
动作描述: {action_description}
|
||||
动作使用场景:
|
||||
{action_require}
|
||||
|
||||
你应该像这样使用它:
|
||||
{{
|
||||
{json_example}
|
||||
}}
|
||||
""",
|
||||
"action_prompt_with_example",
|
||||
)
|
||||
|
||||
|
||||
# 确保提示词在模块加载时初始化
|
||||
init_prompts()
|
||||
46
src/plugins/built_in/affinity_flow_chatter/plugin.py
Normal file
46
src/plugins/built_in/affinity_flow_chatter/plugin.py
Normal file
@@ -0,0 +1,46 @@
|
||||
"""
|
||||
亲和力聊天处理器插件
|
||||
"""
|
||||
|
||||
from typing import List, Tuple, Type
|
||||
|
||||
from src.plugin_system.apis.plugin_register_api import register_plugin
|
||||
from src.plugin_system.base.base_plugin import BasePlugin
|
||||
from src.plugin_system.base.component_types import ComponentInfo
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("affinity_chatter_plugin")
|
||||
|
||||
|
||||
@register_plugin
|
||||
class AffinityChatterPlugin(BasePlugin):
|
||||
"""亲和力聊天处理器插件
|
||||
|
||||
- 延迟导入 `AffinityChatter` 并通过组件注册器注册为聊天处理器
|
||||
- 提供 `get_plugin_components` 以兼容插件注册机制
|
||||
"""
|
||||
|
||||
plugin_name: str = "affinity_chatter"
|
||||
enable_plugin: bool = True
|
||||
dependencies: list[str] = []
|
||||
python_dependencies: list[str] = []
|
||||
config_file_name: str = ""
|
||||
|
||||
# 简单的 config_schema 占位(如果将来需要配置可扩展)
|
||||
config_schema = {}
|
||||
|
||||
def get_plugin_components(self) -> List[Tuple[ComponentInfo, Type]]:
|
||||
"""返回插件包含的组件列表(ChatterInfo, AffinityChatter)
|
||||
|
||||
这里采用延迟导入 AffinityChatter 来避免循环依赖和启动顺序问题。
|
||||
如果导入失败则返回空列表以让注册过程继续而不崩溃。
|
||||
"""
|
||||
try:
|
||||
# 延迟导入以避免循环导入
|
||||
from .affinity_chatter import AffinityChatter
|
||||
|
||||
return [(AffinityChatter.get_chatter_info(), AffinityChatter)]
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"加载 AffinityChatter 时出错: {e}")
|
||||
return []
|
||||
@@ -0,0 +1,755 @@
|
||||
"""
|
||||
用户关系追踪器
|
||||
负责追踪用户交互历史,并通过LLM分析更新用户关系分
|
||||
支持数据库持久化存储和回复后自动关系更新
|
||||
"""
|
||||
|
||||
import time
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import model_config, global_config
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.common.database.sqlalchemy_database_api import get_db_session
|
||||
from src.common.database.sqlalchemy_models import UserRelationships, Messages
|
||||
from sqlalchemy import select, desc
|
||||
from src.common.data_models.database_data_model import DatabaseMessages
|
||||
|
||||
logger = get_logger("chatter_relationship_tracker")
|
||||
|
||||
|
||||
class ChatterRelationshipTracker:
|
||||
"""用户关系追踪器"""
|
||||
|
||||
def __init__(self, interest_scoring_system=None):
|
||||
self.tracking_users: Dict[str, Dict] = {} # user_id -> interaction_data
|
||||
self.max_tracking_users = 3
|
||||
self.update_interval_minutes = 30
|
||||
self.last_update_time = time.time()
|
||||
self.relationship_history: List[Dict] = []
|
||||
self.interest_scoring_system = interest_scoring_system
|
||||
|
||||
# 用户关系缓存 (user_id -> {"relationship_text": str, "relationship_score": float, "last_tracked": float})
|
||||
self.user_relationship_cache: Dict[str, Dict] = {}
|
||||
self.cache_expiry_hours = 1 # 缓存过期时间(小时)
|
||||
|
||||
# 关系更新LLM
|
||||
try:
|
||||
self.relationship_llm = LLMRequest(
|
||||
model_set=model_config.model_task_config.relationship_tracker, request_type="relationship_tracker"
|
||||
)
|
||||
except AttributeError:
|
||||
# 如果relationship_tracker配置不存在,尝试其他可用的模型配置
|
||||
available_models = [
|
||||
attr
|
||||
for attr in dir(model_config.model_task_config)
|
||||
if not attr.startswith("_") and attr != "model_dump"
|
||||
]
|
||||
|
||||
if available_models:
|
||||
# 使用第一个可用的模型配置
|
||||
fallback_model = available_models[0]
|
||||
logger.warning(f"relationship_tracker model configuration not found, using fallback: {fallback_model}")
|
||||
self.relationship_llm = LLMRequest(
|
||||
model_set=getattr(model_config.model_task_config, fallback_model),
|
||||
request_type="relationship_tracker",
|
||||
)
|
||||
else:
|
||||
# 如果没有任何模型配置,创建一个简单的LLMRequest
|
||||
logger.warning("No model configurations found, creating basic LLMRequest")
|
||||
self.relationship_llm = LLMRequest(
|
||||
model_set="gpt-3.5-turbo", # 默认模型
|
||||
request_type="relationship_tracker",
|
||||
)
|
||||
|
||||
def set_interest_scoring_system(self, interest_scoring_system):
|
||||
"""设置兴趣度评分系统引用"""
|
||||
self.interest_scoring_system = interest_scoring_system
|
||||
|
||||
def add_interaction(self, user_id: str, user_name: str, user_message: str, bot_reply: str, reply_timestamp: float):
|
||||
"""添加用户交互记录"""
|
||||
if len(self.tracking_users) >= self.max_tracking_users:
|
||||
# 移除最旧的记录
|
||||
oldest_user = min(
|
||||
self.tracking_users.keys(), key=lambda k: self.tracking_users[k].get("reply_timestamp", 0)
|
||||
)
|
||||
del self.tracking_users[oldest_user]
|
||||
|
||||
# 获取当前关系分
|
||||
current_relationship_score = global_config.affinity_flow.base_relationship_score # 默认值
|
||||
if self.interest_scoring_system:
|
||||
current_relationship_score = self.interest_scoring_system.get_user_relationship(user_id)
|
||||
|
||||
self.tracking_users[user_id] = {
|
||||
"user_id": user_id,
|
||||
"user_name": user_name,
|
||||
"user_message": user_message,
|
||||
"bot_reply": bot_reply,
|
||||
"reply_timestamp": reply_timestamp,
|
||||
"current_relationship_score": current_relationship_score,
|
||||
}
|
||||
|
||||
logger.debug(f"添加用户交互追踪: {user_id}")
|
||||
|
||||
async def check_and_update_relationships(self) -> List[Dict]:
|
||||
"""检查并更新用户关系"""
|
||||
current_time = time.time()
|
||||
if current_time - self.last_update_time < self.update_interval_minutes * 60:
|
||||
return []
|
||||
|
||||
updates = []
|
||||
for user_id, interaction in list(self.tracking_users.items()):
|
||||
if current_time - interaction["reply_timestamp"] > 60 * 5: # 5分钟
|
||||
update = await self._update_user_relationship(interaction)
|
||||
if update:
|
||||
updates.append(update)
|
||||
del self.tracking_users[user_id]
|
||||
|
||||
self.last_update_time = current_time
|
||||
return updates
|
||||
|
||||
async def _update_user_relationship(self, interaction: Dict) -> Optional[Dict]:
|
||||
"""更新单个用户的关系"""
|
||||
try:
|
||||
# 获取bot人设信息
|
||||
from src.individuality.individuality import Individuality
|
||||
|
||||
individuality = Individuality()
|
||||
bot_personality = await individuality.get_personality_block()
|
||||
|
||||
prompt = f"""
|
||||
你现在是一个有着特定性格和身份的AI助手。你的人设是:{bot_personality}
|
||||
|
||||
请以你独特的性格视角,严格按现实逻辑分析以下用户交互,更新用户关系:
|
||||
|
||||
用户ID: {interaction["user_id"]}
|
||||
用户名: {interaction["user_name"]}
|
||||
用户消息: {interaction["user_message"]}
|
||||
你的回复: {interaction["bot_reply"]}
|
||||
当前关系分: {interaction["current_relationship_score"]}
|
||||
|
||||
【重要】关系分数档次定义:
|
||||
- 0.0-0.2:陌生人/初次认识 - 仅礼貌性交流
|
||||
- 0.2-0.4:普通网友 - 有基本互动但不熟悉
|
||||
- 0.4-0.6:熟悉网友 - 经常交流,有一定了解
|
||||
- 0.6-0.8:朋友 - 可以分享心情,互相关心
|
||||
- 0.8-1.0:好朋友/知己 - 深度信任,亲密无间
|
||||
|
||||
【严格要求】:
|
||||
1. 加分必须符合现实关系发展逻辑 - 不能因为对方态度好就盲目加分到不符合当前关系档次的分数
|
||||
2. 关系提升需要足够的互动积累和时间验证
|
||||
3. 即使是朋友关系,单次互动加分通常不超过0.05-0.1
|
||||
4. 关系描述要详细具体,包括:
|
||||
- 用户性格特点观察
|
||||
- 印象深刻的互动记忆
|
||||
- 你们关系的具体状态描述
|
||||
|
||||
根据你的人设性格,思考:
|
||||
1. 以你的性格,你会如何看待这次互动?
|
||||
2. 用户的行为是否符合你性格的喜好?
|
||||
3. 这次互动是否真的让你们的关系提升了一个档次?为什么?
|
||||
4. 有什么特别值得记住的互动细节?
|
||||
|
||||
请以JSON格式返回更新结果:
|
||||
{{
|
||||
"new_relationship_score": 0.0~1.0的数值(必须符合现实逻辑),
|
||||
"reasoning": "从你的性格角度说明更新理由,重点说明是否符合现实关系发展逻辑",
|
||||
"interaction_summary": "基于你性格的交互总结,包含印象深刻的互动记忆"
|
||||
}}
|
||||
"""
|
||||
|
||||
llm_response, _ = await self.relationship_llm.generate_response_async(prompt=prompt)
|
||||
if llm_response:
|
||||
import json
|
||||
|
||||
try:
|
||||
# 清理LLM响应,移除可能的格式标记
|
||||
cleaned_response = self._clean_llm_json_response(llm_response)
|
||||
response_data = json.loads(cleaned_response)
|
||||
new_score = max(
|
||||
0.0,
|
||||
min(
|
||||
1.0,
|
||||
float(
|
||||
response_data.get(
|
||||
"new_relationship_score", global_config.affinity_flow.base_relationship_score
|
||||
)
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
if self.interest_scoring_system:
|
||||
self.interest_scoring_system.update_user_relationship(
|
||||
interaction["user_id"], new_score - interaction["current_relationship_score"]
|
||||
)
|
||||
|
||||
return {
|
||||
"user_id": interaction["user_id"],
|
||||
"new_relationship_score": new_score,
|
||||
"reasoning": response_data.get("reasoning", ""),
|
||||
"interaction_summary": response_data.get("interaction_summary", ""),
|
||||
}
|
||||
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(f"LLM响应JSON解析失败: {e}")
|
||||
logger.debug(f"LLM原始响应: {llm_response}")
|
||||
except Exception as e:
|
||||
logger.error(f"处理关系更新数据失败: {e}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"更新用户关系时出错: {e}")
|
||||
|
||||
return None
|
||||
|
||||
def get_tracking_users(self) -> Dict[str, Dict]:
|
||||
"""获取正在追踪的用户"""
|
||||
return self.tracking_users.copy()
|
||||
|
||||
def get_user_interaction(self, user_id: str) -> Optional[Dict]:
|
||||
"""获取特定用户的交互记录"""
|
||||
return self.tracking_users.get(user_id)
|
||||
|
||||
def remove_user_tracking(self, user_id: str):
|
||||
"""移除用户追踪"""
|
||||
if user_id in self.tracking_users:
|
||||
del self.tracking_users[user_id]
|
||||
logger.debug(f"移除用户追踪: {user_id}")
|
||||
|
||||
def clear_all_tracking(self):
|
||||
"""清空所有追踪"""
|
||||
self.tracking_users.clear()
|
||||
logger.info("清空所有用户追踪")
|
||||
|
||||
def get_relationship_history(self) -> List[Dict]:
|
||||
"""获取关系历史记录"""
|
||||
return self.relationship_history.copy()
|
||||
|
||||
def add_to_history(self, relationship_update: Dict):
|
||||
"""添加到关系历史"""
|
||||
self.relationship_history.append({**relationship_update, "update_time": time.time()})
|
||||
|
||||
# 限制历史记录数量
|
||||
if len(self.relationship_history) > 100:
|
||||
self.relationship_history = self.relationship_history[-100:]
|
||||
|
||||
def get_tracker_stats(self) -> Dict:
|
||||
"""获取追踪器统计"""
|
||||
return {
|
||||
"tracking_users": len(self.tracking_users),
|
||||
"max_tracking_users": self.max_tracking_users,
|
||||
"update_interval_minutes": self.update_interval_minutes,
|
||||
"relationship_history": len(self.relationship_history),
|
||||
"last_update_time": self.last_update_time,
|
||||
}
|
||||
|
||||
def update_config(self, max_tracking_users: int = None, update_interval_minutes: int = None):
|
||||
"""更新配置"""
|
||||
if max_tracking_users is not None:
|
||||
self.max_tracking_users = max_tracking_users
|
||||
logger.info(f"更新最大追踪用户数: {max_tracking_users}")
|
||||
|
||||
if update_interval_minutes is not None:
|
||||
self.update_interval_minutes = update_interval_minutes
|
||||
logger.info(f"更新关系更新间隔: {update_interval_minutes} 分钟")
|
||||
|
||||
def force_update_relationship(self, user_id: str, new_score: float, reasoning: str = ""):
|
||||
"""强制更新用户关系分"""
|
||||
if user_id in self.tracking_users:
|
||||
current_score = self.tracking_users[user_id]["current_relationship_score"]
|
||||
if self.interest_scoring_system:
|
||||
self.interest_scoring_system.update_user_relationship(user_id, new_score - current_score)
|
||||
|
||||
update_info = {
|
||||
"user_id": user_id,
|
||||
"new_relationship_score": new_score,
|
||||
"reasoning": reasoning or "手动更新",
|
||||
"interaction_summary": "手动更新关系分",
|
||||
}
|
||||
self.add_to_history(update_info)
|
||||
logger.info(f"强制更新用户关系: {user_id} -> {new_score:.2f}")
|
||||
|
||||
def get_user_summary(self, user_id: str) -> Dict:
|
||||
"""获取用户交互总结"""
|
||||
if user_id not in self.tracking_users:
|
||||
return {}
|
||||
|
||||
interaction = self.tracking_users[user_id]
|
||||
return {
|
||||
"user_id": user_id,
|
||||
"user_name": interaction["user_name"],
|
||||
"current_relationship_score": interaction["current_relationship_score"],
|
||||
"interaction_count": 1, # 简化版本,每次追踪只记录一次交互
|
||||
"last_interaction": interaction["reply_timestamp"],
|
||||
"recent_message": interaction["user_message"][:100] + "..."
|
||||
if len(interaction["user_message"]) > 100
|
||||
else interaction["user_message"],
|
||||
}
|
||||
|
||||
# ===== 数据库支持方法 =====
|
||||
|
||||
def get_user_relationship_score(self, user_id: str) -> float:
|
||||
"""获取用户关系分"""
|
||||
# 先检查缓存
|
||||
if user_id in self.user_relationship_cache:
|
||||
cache_data = self.user_relationship_cache[user_id]
|
||||
# 检查缓存是否过期
|
||||
cache_time = cache_data.get("last_tracked", 0)
|
||||
if time.time() - cache_time < self.cache_expiry_hours * 3600:
|
||||
return cache_data.get("relationship_score", global_config.affinity_flow.base_relationship_score)
|
||||
|
||||
# 缓存过期或不存在,从数据库获取
|
||||
relationship_data = self._get_user_relationship_from_db(user_id)
|
||||
if relationship_data:
|
||||
# 更新缓存
|
||||
self.user_relationship_cache[user_id] = {
|
||||
"relationship_text": relationship_data.get("relationship_text", ""),
|
||||
"relationship_score": relationship_data.get(
|
||||
"relationship_score", global_config.affinity_flow.base_relationship_score
|
||||
),
|
||||
"last_tracked": time.time(),
|
||||
}
|
||||
return relationship_data.get("relationship_score", global_config.affinity_flow.base_relationship_score)
|
||||
|
||||
# 数据库中也没有,返回默认值
|
||||
return global_config.affinity_flow.base_relationship_score
|
||||
|
||||
def _get_user_relationship_from_db(self, user_id: str) -> Optional[Dict]:
|
||||
"""从数据库获取用户关系数据"""
|
||||
try:
|
||||
with get_db_session() as session:
|
||||
# 查询用户关系表
|
||||
stmt = select(UserRelationships).where(UserRelationships.user_id == user_id)
|
||||
result = session.execute(stmt).scalar_one_or_none()
|
||||
|
||||
if result:
|
||||
return {
|
||||
"relationship_text": result.relationship_text or "",
|
||||
"relationship_score": float(result.relationship_score)
|
||||
if result.relationship_score is not None
|
||||
else 0.3,
|
||||
"last_updated": result.last_updated,
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"从数据库获取用户关系失败: {e}")
|
||||
|
||||
return None
|
||||
|
||||
def _update_user_relationship_in_db(self, user_id: str, relationship_text: str, relationship_score: float):
|
||||
"""更新数据库中的用户关系"""
|
||||
try:
|
||||
current_time = time.time()
|
||||
|
||||
with get_db_session() as session:
|
||||
# 检查是否已存在关系记录
|
||||
existing = session.execute(
|
||||
select(UserRelationships).where(UserRelationships.user_id == user_id)
|
||||
).scalar_one_or_none()
|
||||
|
||||
if existing:
|
||||
# 更新现有记录
|
||||
existing.relationship_text = relationship_text
|
||||
existing.relationship_score = relationship_score
|
||||
existing.last_updated = current_time
|
||||
existing.user_name = existing.user_name or user_id # 更新用户名如果为空
|
||||
else:
|
||||
# 插入新记录
|
||||
new_relationship = UserRelationships(
|
||||
user_id=user_id,
|
||||
user_name=user_id,
|
||||
relationship_text=relationship_text,
|
||||
relationship_score=relationship_score,
|
||||
last_updated=current_time,
|
||||
)
|
||||
session.add(new_relationship)
|
||||
|
||||
session.commit()
|
||||
logger.info(f"已更新数据库中用户关系: {user_id} -> 分数: {relationship_score:.3f}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"更新数据库用户关系失败: {e}")
|
||||
|
||||
# ===== 回复后关系追踪方法 =====
|
||||
|
||||
async def track_reply_relationship(
|
||||
self, user_id: str, user_name: str, bot_reply_content: str, reply_timestamp: float
|
||||
):
|
||||
"""回复后关系追踪 - 主要入口点"""
|
||||
try:
|
||||
logger.info(f"🔄 [RelationshipTracker] 开始回复后关系追踪: {user_id}")
|
||||
|
||||
# 检查上次追踪时间
|
||||
last_tracked_time = self._get_last_tracked_time(user_id)
|
||||
time_diff = reply_timestamp - last_tracked_time
|
||||
|
||||
if time_diff < 5 * 60: # 5分钟内不重复追踪
|
||||
logger.debug(
|
||||
f"⏱️ [RelationshipTracker] 用户 {user_id} 距离上次追踪时间不足5分钟 ({time_diff:.2f}s),跳过"
|
||||
)
|
||||
return
|
||||
|
||||
# 获取上次bot回复该用户的消息
|
||||
last_bot_reply = await self._get_last_bot_reply_to_user(user_id)
|
||||
if not last_bot_reply:
|
||||
logger.info(f"👋 [RelationshipTracker] 未找到用户 {user_id} 的历史回复记录,启动'初次见面'逻辑")
|
||||
await self._handle_first_interaction(user_id, user_name, bot_reply_content)
|
||||
return
|
||||
|
||||
# 获取用户后续的反应消息
|
||||
user_reactions = await self._get_user_reactions_after_reply(user_id, last_bot_reply.time)
|
||||
logger.debug(f"💬 [RelationshipTracker] 找到用户 {user_id} 在上次回复后的 {len(user_reactions)} 条反应消息")
|
||||
|
||||
# 获取当前关系数据
|
||||
current_relationship = self._get_user_relationship_from_db(user_id)
|
||||
current_score = (
|
||||
current_relationship.get("relationship_score", global_config.affinity_flow.base_relationship_score)
|
||||
if current_relationship
|
||||
else global_config.affinity_flow.base_relationship_score
|
||||
)
|
||||
current_text = current_relationship.get("relationship_text", "新用户") if current_relationship else "新用户"
|
||||
|
||||
# 使用LLM分析并更新关系
|
||||
logger.debug(f"🧠 [RelationshipTracker] 开始为用户 {user_id} 分析并更新关系")
|
||||
await self._analyze_and_update_relationship(
|
||||
user_id, user_name, last_bot_reply, user_reactions, current_text, current_score, bot_reply_content
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"回复后关系追踪失败: {e}")
|
||||
logger.debug("错误详情:", exc_info=True)
|
||||
|
||||
def _get_last_tracked_time(self, user_id: str) -> float:
|
||||
"""获取上次追踪时间"""
|
||||
# 先检查缓存
|
||||
if user_id in self.user_relationship_cache:
|
||||
return self.user_relationship_cache[user_id].get("last_tracked", 0)
|
||||
|
||||
# 从数据库获取
|
||||
relationship_data = self._get_user_relationship_from_db(user_id)
|
||||
if relationship_data:
|
||||
return relationship_data.get("last_updated", 0)
|
||||
|
||||
return 0
|
||||
|
||||
async def _get_last_bot_reply_to_user(self, user_id: str) -> Optional[DatabaseMessages]:
|
||||
"""获取上次bot回复该用户的消息"""
|
||||
try:
|
||||
with get_db_session() as session:
|
||||
# 查询bot回复给该用户的最新消息
|
||||
stmt = (
|
||||
select(Messages)
|
||||
.where(Messages.user_id == user_id)
|
||||
.where(Messages.reply_to.isnot(None))
|
||||
.order_by(desc(Messages.time))
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
result = session.execute(stmt).scalar_one_or_none()
|
||||
if result:
|
||||
# 将SQLAlchemy模型转换为DatabaseMessages对象
|
||||
return self._sqlalchemy_to_database_messages(result)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取上次回复消息失败: {e}")
|
||||
|
||||
return None
|
||||
|
||||
async def _get_user_reactions_after_reply(self, user_id: str, reply_time: float) -> List[DatabaseMessages]:
|
||||
"""获取用户在bot回复后的反应消息"""
|
||||
try:
|
||||
with get_db_session() as session:
|
||||
# 查询用户在回复时间之后的5分钟内的消息
|
||||
end_time = reply_time + 5 * 60 # 5分钟
|
||||
|
||||
stmt = (
|
||||
select(Messages)
|
||||
.where(Messages.user_id == user_id)
|
||||
.where(Messages.time > reply_time)
|
||||
.where(Messages.time <= end_time)
|
||||
.order_by(Messages.time)
|
||||
)
|
||||
|
||||
results = session.execute(stmt).scalars().all()
|
||||
if results:
|
||||
return [self._sqlalchemy_to_database_messages(result) for result in results]
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取用户反应消息失败: {e}")
|
||||
|
||||
return []
|
||||
|
||||
def _sqlalchemy_to_database_messages(self, sqlalchemy_message) -> DatabaseMessages:
|
||||
"""将SQLAlchemy消息模型转换为DatabaseMessages对象"""
|
||||
try:
|
||||
return DatabaseMessages(
|
||||
message_id=sqlalchemy_message.message_id or "",
|
||||
time=float(sqlalchemy_message.time) if sqlalchemy_message.time is not None else 0.0,
|
||||
chat_id=sqlalchemy_message.chat_id or "",
|
||||
reply_to=sqlalchemy_message.reply_to,
|
||||
processed_plain_text=sqlalchemy_message.processed_plain_text or "",
|
||||
user_id=sqlalchemy_message.user_id or "",
|
||||
user_nickname=sqlalchemy_message.user_nickname or "",
|
||||
user_platform=sqlalchemy_message.user_platform or "",
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"SQLAlchemy消息转换失败: {e}")
|
||||
# 返回一个基本的消息对象
|
||||
return DatabaseMessages(
|
||||
message_id="",
|
||||
time=0.0,
|
||||
chat_id="",
|
||||
processed_plain_text="",
|
||||
user_id="",
|
||||
user_nickname="",
|
||||
user_platform="",
|
||||
)
|
||||
|
||||
async def _analyze_and_update_relationship(
|
||||
self,
|
||||
user_id: str,
|
||||
user_name: str,
|
||||
last_bot_reply: DatabaseMessages,
|
||||
user_reactions: List[DatabaseMessages],
|
||||
current_text: str,
|
||||
current_score: float,
|
||||
current_reply: str,
|
||||
):
|
||||
"""使用LLM分析并更新用户关系"""
|
||||
try:
|
||||
# 构建分析提示
|
||||
user_reactions_text = "\n".join([f"- {msg.processed_plain_text}" for msg in user_reactions])
|
||||
|
||||
# 获取bot人设信息
|
||||
from src.individuality.individuality import Individuality
|
||||
|
||||
individuality = Individuality()
|
||||
bot_personality = await individuality.get_personality_block()
|
||||
|
||||
prompt = f"""
|
||||
你现在是一个有着特定性格和身份的AI助手。你的人设是:{bot_personality}
|
||||
|
||||
请以你独特的性格视角,严格按现实逻辑分析以下用户交互,更新用户关系印象和分数:
|
||||
|
||||
用户信息:
|
||||
- 用户ID: {user_id}
|
||||
- 用户名: {user_name}
|
||||
|
||||
你上次的回复: {last_bot_reply.processed_plain_text}
|
||||
|
||||
用户反应消息:
|
||||
{user_reactions_text}
|
||||
|
||||
你当前的回复: {current_reply}
|
||||
|
||||
当前关系印象: {current_text}
|
||||
当前关系分数: {current_score:.3f}
|
||||
|
||||
【重要】关系分数档次定义:
|
||||
- 0.0-0.2:陌生人/初次认识 - 仅礼貌性交流
|
||||
- 0.2-0.4:普通网友 - 有基本互动但不熟悉
|
||||
- 0.4-0.6:熟悉网友 - 经常交流,有一定了解
|
||||
- 0.6-0.8:朋友 - 可以分享心情,互相关心
|
||||
- 0.8-1.0:好朋友/知己 - 深度信任,亲密无间
|
||||
|
||||
【严格要求】:
|
||||
1. 加分必须符合现实关系发展逻辑 - 不能因为用户反应好就盲目加分
|
||||
2. 关系提升需要足够的互动积累和时间验证,单次互动加分通常不超过0.05-0.1
|
||||
3. 必须考虑当前关系档次,不能跳跃式提升(比如从0.3直接到0.7)
|
||||
4. 关系印象描述要详细具体(100-200字),包括:
|
||||
- 用户性格特点和交流风格观察
|
||||
- 印象深刻的互动记忆和对话片段
|
||||
- 你们关系的具体状态描述和发展阶段
|
||||
- 根据你的性格,你对用户的真实感受
|
||||
|
||||
性格视角深度分析:
|
||||
1. 以你的性格特点,用户这次的反应给你什么感受?
|
||||
2. 用户的情绪和行为符合你性格的喜好吗?具体哪些方面?
|
||||
3. 从现实角度看,这次互动是否足以让关系提升到下一个档次?为什么?
|
||||
4. 有什么特别值得记住的互动细节或对话内容?
|
||||
5. 基于你们的互动历史,用户给你留下了哪些深刻印象?
|
||||
|
||||
请以JSON格式返回更新结果:
|
||||
{{
|
||||
"relationship_text": "详细的关系印象描述(100-200字),包含用户性格观察、印象深刻记忆、关系状态描述",
|
||||
"relationship_score": 0.0~1.0的新分数(必须严格符合现实逻辑),
|
||||
"analysis_reasoning": "从你性格角度的深度分析,重点说明分数调整的现实合理性",
|
||||
"interaction_quality": "high/medium/low"
|
||||
}}
|
||||
"""
|
||||
|
||||
# 调用LLM进行分析
|
||||
llm_response, _ = await self.relationship_llm.generate_response_async(prompt=prompt)
|
||||
|
||||
if llm_response:
|
||||
import json
|
||||
|
||||
try:
|
||||
# 清理LLM响应,移除可能的格式标记
|
||||
cleaned_response = self._clean_llm_json_response(llm_response)
|
||||
response_data = json.loads(cleaned_response)
|
||||
|
||||
new_text = response_data.get("relationship_text", current_text)
|
||||
new_score = max(0.0, min(1.0, float(response_data.get("relationship_score", current_score))))
|
||||
reasoning = response_data.get("analysis_reasoning", "")
|
||||
quality = response_data.get("interaction_quality", "medium")
|
||||
|
||||
# 更新数据库
|
||||
self._update_user_relationship_in_db(user_id, new_text, new_score)
|
||||
|
||||
# 更新缓存
|
||||
self.user_relationship_cache[user_id] = {
|
||||
"relationship_text": new_text,
|
||||
"relationship_score": new_score,
|
||||
"last_tracked": time.time(),
|
||||
}
|
||||
|
||||
# 如果有兴趣度评分系统,也更新内存中的关系分
|
||||
if self.interest_scoring_system:
|
||||
self.interest_scoring_system.update_user_relationship(user_id, new_score - current_score)
|
||||
|
||||
# 记录分析历史
|
||||
analysis_record = {
|
||||
"user_id": user_id,
|
||||
"timestamp": time.time(),
|
||||
"old_score": current_score,
|
||||
"new_score": new_score,
|
||||
"old_text": current_text,
|
||||
"new_text": new_text,
|
||||
"reasoning": reasoning,
|
||||
"quality": quality,
|
||||
"user_reactions_count": len(user_reactions),
|
||||
}
|
||||
self.relationship_history.append(analysis_record)
|
||||
|
||||
# 限制历史记录数量
|
||||
if len(self.relationship_history) > 100:
|
||||
self.relationship_history = self.relationship_history[-100:]
|
||||
|
||||
logger.info(f"✅ 关系分析完成: {user_id}")
|
||||
logger.info(f" 📝 印象: '{current_text}' -> '{new_text}'")
|
||||
logger.info(f" 💝 分数: {current_score:.3f} -> {new_score:.3f}")
|
||||
logger.info(f" 🎯 质量: {quality}")
|
||||
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(f"LLM响应JSON解析失败: {e}")
|
||||
logger.debug(f"LLM原始响应: {llm_response}")
|
||||
else:
|
||||
logger.warning("LLM未返回有效响应")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"关系分析失败: {e}")
|
||||
logger.debug("错误详情:", exc_info=True)
|
||||
|
||||
async def _handle_first_interaction(self, user_id: str, user_name: str, bot_reply_content: str):
|
||||
"""处理与用户的初次交互"""
|
||||
try:
|
||||
logger.info(f"✨ [RelationshipTracker] 正在处理与用户 {user_id} 的初次交互")
|
||||
|
||||
# 获取bot人设信息
|
||||
from src.individuality.individuality import Individuality
|
||||
|
||||
individuality = Individuality()
|
||||
bot_personality = await individuality.get_personality_block()
|
||||
|
||||
prompt = f"""
|
||||
你现在是:{bot_personality}
|
||||
|
||||
你正在与一个新用户进行初次有效互动。请根据你对TA的第一印象,建立初始关系档案。
|
||||
|
||||
用户信息:
|
||||
- 用户ID: {user_id}
|
||||
- 用户名: {user_name}
|
||||
|
||||
你的首次回复: {bot_reply_content}
|
||||
|
||||
【严格要求】:
|
||||
1. 建立一个初始关系分数,通常在0.2-0.4之间(普通网友)。
|
||||
2. 关系印象描述要简洁地记录你对用户的初步看法(50-100字)。
|
||||
- 用户名给你的感觉?
|
||||
- 你的回复是基于什么考虑?
|
||||
- 你对接下来与TA的互动有什么期待?
|
||||
|
||||
请以JSON格式返回结果:
|
||||
{{
|
||||
"relationship_text": "简洁的初始关系印象描述(50-100字)",
|
||||
"relationship_score": 0.2~0.4的新分数,
|
||||
"analysis_reasoning": "从你性格角度说明建立此初始印象的理由"
|
||||
}}
|
||||
"""
|
||||
# 调用LLM进行分析
|
||||
llm_response, _ = await self.relationship_llm.generate_response_async(prompt=prompt)
|
||||
if not llm_response:
|
||||
logger.warning(f"初次交互分析时LLM未返回有效响应: {user_id}")
|
||||
return
|
||||
|
||||
import json
|
||||
|
||||
cleaned_response = self._clean_llm_json_response(llm_response)
|
||||
response_data = json.loads(cleaned_response)
|
||||
|
||||
new_text = response_data.get("relationship_text", "初次见面")
|
||||
new_score = max(
|
||||
0.0,
|
||||
min(
|
||||
1.0,
|
||||
float(response_data.get("relationship_score", global_config.affinity_flow.base_relationship_score)),
|
||||
),
|
||||
)
|
||||
|
||||
# 更新数据库和缓存
|
||||
self._update_user_relationship_in_db(user_id, new_text, new_score)
|
||||
self.user_relationship_cache[user_id] = {
|
||||
"relationship_text": new_text,
|
||||
"relationship_score": new_score,
|
||||
"last_tracked": time.time(),
|
||||
}
|
||||
|
||||
logger.info(f"✅ [RelationshipTracker] 已成功为新用户 {user_id} 建立初始关系档案,分数为 {new_score:.3f}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"处理初次交互失败: {user_id}, 错误: {e}")
|
||||
logger.debug("错误详情:", exc_info=True)
|
||||
|
||||
def _clean_llm_json_response(self, response: str) -> str:
|
||||
"""
|
||||
清理LLM响应,移除可能的JSON格式标记
|
||||
|
||||
Args:
|
||||
response: LLM原始响应
|
||||
|
||||
Returns:
|
||||
清理后的JSON字符串
|
||||
"""
|
||||
try:
|
||||
import re
|
||||
|
||||
# 移除常见的JSON格式标记
|
||||
cleaned = response.strip()
|
||||
|
||||
# 移除 ```json 或 ``` 等标记
|
||||
cleaned = re.sub(r"^```(?:json)?\s*", "", cleaned, flags=re.MULTILINE | re.IGNORECASE)
|
||||
cleaned = re.sub(r"\s*```$", "", cleaned, flags=re.MULTILINE)
|
||||
|
||||
# 移除可能的Markdown代码块标记
|
||||
cleaned = re.sub(r"^`|`$", "", cleaned, flags=re.MULTILINE)
|
||||
|
||||
# 尝试找到JSON对象的开始和结束
|
||||
json_start = cleaned.find("{")
|
||||
json_end = cleaned.rfind("}")
|
||||
|
||||
if json_start != -1 and json_end != -1 and json_end > json_start:
|
||||
# 提取JSON部分
|
||||
cleaned = cleaned[json_start : json_end + 1]
|
||||
|
||||
# 移除多余的空白字符
|
||||
cleaned = cleaned.strip()
|
||||
|
||||
logger.debug(f"LLM响应清理: 原始长度={len(response)}, 清理后长度={len(cleaned)}")
|
||||
if cleaned != response:
|
||||
logger.debug(f"清理前: {response[:200]}...")
|
||||
logger.debug(f"清理后: {cleaned[:200]}...")
|
||||
|
||||
return cleaned
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"清理LLM响应失败: {e}")
|
||||
return response # 清理失败时返回原始响应
|
||||
@@ -64,50 +64,50 @@ class AtAction(BaseAction):
|
||||
# 使用回复器生成艾特回复,而不是直接发送命令
|
||||
from src.chat.replyer.default_generator import DefaultReplyer
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
|
||||
|
||||
# 获取当前聊天流
|
||||
chat_manager = get_chat_manager()
|
||||
chat_stream = self.chat_stream or chat_manager.get_stream(self.chat_id)
|
||||
|
||||
|
||||
if not chat_stream:
|
||||
logger.error(f"找不到聊天流: {self.chat_stream}")
|
||||
return False, "聊天流不存在"
|
||||
|
||||
|
||||
# 创建回复器实例
|
||||
replyer = DefaultReplyer(chat_stream)
|
||||
|
||||
|
||||
# 构建回复对象,将艾特消息作为回复目标
|
||||
reply_to = f"{user_name}:{at_message}"
|
||||
extra_info = f"你需要艾特用户 {user_name} 并回复他们说: {at_message}"
|
||||
|
||||
|
||||
# 使用回复器生成回复
|
||||
success, llm_response, prompt = await replyer.generate_reply_with_context(
|
||||
reply_to=reply_to,
|
||||
extra_info=extra_info,
|
||||
enable_tool=False, # 艾特回复通常不需要工具调用
|
||||
from_plugin=False
|
||||
from_plugin=False,
|
||||
)
|
||||
|
||||
|
||||
if success and llm_response:
|
||||
# 获取生成的回复内容
|
||||
reply_content = llm_response.get("content", "")
|
||||
if reply_content:
|
||||
# 获取用户QQ号,发送真正的艾特消息
|
||||
user_id = user_info.get("user_id")
|
||||
|
||||
|
||||
# 发送真正的艾特命令,使用回复器生成的智能内容
|
||||
await self.send_command(
|
||||
"SEND_AT_MESSAGE",
|
||||
args={"qq_id": user_id, "text": reply_content},
|
||||
display_message=f"艾特用户 {user_name} 并发送智能回复: {reply_content}",
|
||||
)
|
||||
|
||||
|
||||
await self.store_action_info(
|
||||
action_build_into_prompt=True,
|
||||
action_prompt_display=f"执行了艾特用户动作:艾特用户 {user_name} 并发送智能回复: {reply_content}",
|
||||
action_done=True,
|
||||
)
|
||||
|
||||
|
||||
logger.info(f"成功通过回复器生成智能内容并发送真正的艾特消息给 {user_name}: {reply_content}")
|
||||
return True, "智能艾特消息发送成功"
|
||||
else:
|
||||
@@ -116,7 +116,7 @@ class AtAction(BaseAction):
|
||||
else:
|
||||
logger.error("回复器生成回复失败")
|
||||
return False, "回复生成失败"
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"执行艾特用户动作时发生异常: {e}", exc_info=True)
|
||||
await self.store_action_info(
|
||||
|
||||
@@ -26,8 +26,8 @@
|
||||
"components": [
|
||||
{
|
||||
"type": "action",
|
||||
"name": "emoji",
|
||||
"description": "发送表情包辅助表达情绪"
|
||||
"name": "emoji",
|
||||
"description": "作为一条全新的消息,发送一个符合当前情景的表情包来生动地表达情绪。"
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
@@ -33,7 +33,7 @@ class EmojiAction(BaseAction):
|
||||
|
||||
# 动作基本信息
|
||||
action_name = "emoji"
|
||||
action_description = "发送表情包辅助表达情绪"
|
||||
action_description = "作为一条全新的消息,发送一个符合当前情景的表情包来生动地表达情绪。"
|
||||
|
||||
# LLM判断提示词
|
||||
llm_judge_prompt = """
|
||||
@@ -70,7 +70,9 @@ class EmojiAction(BaseAction):
|
||||
|
||||
# 2. 获取所有有效的表情包对象
|
||||
emoji_manager = get_emoji_manager()
|
||||
all_emojis_obj: list[MaiEmoji] = [e for e in emoji_manager.emoji_objects if not e.is_deleted and e.description]
|
||||
all_emojis_obj: list[MaiEmoji] = [
|
||||
e for e in emoji_manager.emoji_objects if not e.is_deleted and e.description
|
||||
]
|
||||
if not all_emojis_obj:
|
||||
logger.warning(f"{self.log_prefix} 无法获取任何带有描述的有效表情包")
|
||||
return False, "无法获取任何带有描述的有效表情包"
|
||||
@@ -91,12 +93,12 @@ class EmojiAction(BaseAction):
|
||||
# 4. 准备情感数据和后备列表
|
||||
emotion_map = {}
|
||||
all_emojis_data = []
|
||||
|
||||
|
||||
for emoji in all_emojis_obj:
|
||||
b64 = image_path_to_base64(emoji.full_path)
|
||||
if not b64:
|
||||
continue
|
||||
|
||||
|
||||
desc = emoji.description
|
||||
emotions = emoji.emotion
|
||||
all_emojis_data.append((b64, desc))
|
||||
@@ -122,10 +124,10 @@ class EmojiAction(BaseAction):
|
||||
emoji_base64, emoji_description = random.choice(all_emojis_data)
|
||||
else:
|
||||
# 获取最近的5条消息内容用于判断
|
||||
recent_messages = await message_api.get_recent_messages(chat_id=self.chat_id, limit=5)
|
||||
recent_messages = message_api.get_recent_messages(chat_id=self.chat_id, limit=5)
|
||||
messages_text = ""
|
||||
if recent_messages:
|
||||
messages_text = await message_api.build_readable_messages(
|
||||
messages_text = message_api.build_readable_messages(
|
||||
messages=recent_messages,
|
||||
timestamp_mode="normal_no_YMD",
|
||||
truncate=False,
|
||||
@@ -150,10 +152,10 @@ class EmojiAction(BaseAction):
|
||||
|
||||
# 调用LLM
|
||||
models = llm_api.get_available_models()
|
||||
chat_model_config = models.get("planner")
|
||||
chat_model_config = models.get("utils")
|
||||
if not chat_model_config:
|
||||
logger.error(f"{self.log_prefix} 未找到'planner'模型配置,无法调用LLM")
|
||||
return False, "未找到'planner'模型配置"
|
||||
logger.error(f"{self.log_prefix} 未找到'utils'模型配置,无法调用LLM")
|
||||
return False, "未找到'utils'模型配置"
|
||||
|
||||
success, chosen_emotion, _, _ = await llm_api.generate_with_model(
|
||||
prompt, model_config=chat_model_config, request_type="emoji"
|
||||
@@ -168,23 +170,25 @@ class EmojiAction(BaseAction):
|
||||
|
||||
# 使用模糊匹配来查找最相关的情感标签
|
||||
matched_key = next((key for key in emotion_map if chosen_emotion in key), None)
|
||||
|
||||
|
||||
if matched_key:
|
||||
emoji_base64, emoji_description = random.choice(emotion_map[matched_key])
|
||||
logger.info(f"{self.log_prefix} 找到匹配情感 '{chosen_emotion}' (匹配到: '{matched_key}') 的表情包: {emoji_description}")
|
||||
logger.info(
|
||||
f"{self.log_prefix} 找到匹配情感 '{chosen_emotion}' (匹配到: '{matched_key}') 的表情包: {emoji_description}"
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
f"{self.log_prefix} LLM选择的情感 '{chosen_emotion}' 不在可用列表中, 将随机选择一个表情包"
|
||||
)
|
||||
emoji_base64, emoji_description = random.choice(all_emojis_data)
|
||||
|
||||
|
||||
elif global_config.emoji.emoji_selection_mode == "description":
|
||||
# --- 详细描述选择模式 ---
|
||||
# 获取最近的5条消息内容用于判断
|
||||
recent_messages = await message_api.get_recent_messages(chat_id=self.chat_id, limit=5)
|
||||
recent_messages = message_api.get_recent_messages(chat_id=self.chat_id, limit=5)
|
||||
messages_text = ""
|
||||
if recent_messages:
|
||||
messages_text = await message_api.build_readable_messages(
|
||||
messages_text = message_api.build_readable_messages(
|
||||
messages=recent_messages,
|
||||
timestamp_mode="normal_no_YMD",
|
||||
truncate=False,
|
||||
@@ -208,10 +212,10 @@ class EmojiAction(BaseAction):
|
||||
|
||||
# 调用LLM
|
||||
models = llm_api.get_available_models()
|
||||
chat_model_config = models.get("planner")
|
||||
chat_model_config = models.get("utils")
|
||||
if not chat_model_config:
|
||||
logger.error(f"{self.log_prefix} 未找到'planner'模型配置,无法调用LLM")
|
||||
return False, "未找到'planner'模型配置"
|
||||
logger.error(f"{self.log_prefix} 未找到'utils'模型配置,无法调用LLM")
|
||||
return False, "未找到'utils'模型配置"
|
||||
|
||||
success, chosen_description, _, _ = await llm_api.generate_with_model(
|
||||
prompt, model_config=chat_model_config, request_type="emoji"
|
||||
@@ -226,15 +230,23 @@ class EmojiAction(BaseAction):
|
||||
logger.info(f"{self.log_prefix} LLM选择的描述: {chosen_description}")
|
||||
|
||||
# 简单关键词匹配
|
||||
matched_emoji = next((item for item in all_emojis_data if chosen_description.lower() in item[1].lower() or item[1].lower() in chosen_description.lower()), None)
|
||||
|
||||
matched_emoji = next(
|
||||
(
|
||||
item
|
||||
for item in all_emojis_data
|
||||
if chosen_description.lower() in item[1].lower()
|
||||
or item[1].lower() in chosen_description.lower()
|
||||
),
|
||||
None,
|
||||
)
|
||||
|
||||
# 如果包含匹配失败,尝试关键词匹配
|
||||
if not matched_emoji:
|
||||
keywords = ['惊讶', '困惑', '呆滞', '震惊', '懵', '无语', '萌', '可爱']
|
||||
keywords = ["惊讶", "困惑", "呆滞", "震惊", "懵", "无语", "萌", "可爱"]
|
||||
for keyword in keywords:
|
||||
if keyword in chosen_description:
|
||||
for item in all_emojis_data:
|
||||
if any(k in item[1] for k in ['呆', '萌', '惊', '困惑', '无语']):
|
||||
if any(k in item[1] for k in ["呆", "萌", "惊", "困惑", "无语"]):
|
||||
matched_emoji = item
|
||||
break
|
||||
if matched_emoji:
|
||||
@@ -255,7 +267,9 @@ class EmojiAction(BaseAction):
|
||||
|
||||
if not success:
|
||||
logger.error(f"{self.log_prefix} 表情包发送失败")
|
||||
await self.store_action_info(action_build_into_prompt = True,action_prompt_display ="发送了一个表情包,但失败了",action_done= False)
|
||||
await self.store_action_info(
|
||||
action_build_into_prompt=True, action_prompt_display=f"发送了一个表情包,但失败了", action_done=False
|
||||
)
|
||||
return False, "表情包发送失败"
|
||||
|
||||
# 发送成功后,记录到历史
|
||||
@@ -263,8 +277,10 @@ class EmojiAction(BaseAction):
|
||||
add_emoji_to_history(self.chat_id, emoji_description)
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 添加表情到历史记录时出错: {e}")
|
||||
|
||||
await self.store_action_info(action_build_into_prompt = True,action_prompt_display ="发送了一个表情包",action_done= True)
|
||||
|
||||
await self.store_action_info(
|
||||
action_build_into_prompt=True, action_prompt_display=f"发送了一个表情包", action_done=True
|
||||
)
|
||||
|
||||
return True, f"发送表情包: {emoji_description}"
|
||||
|
||||
|
||||
@@ -11,7 +11,7 @@
|
||||
|
||||
"host_application": {
|
||||
"min_version": "0.10.0",
|
||||
"max_version": "0.10.0"
|
||||
"max_version": "0.11.0"
|
||||
},
|
||||
"homepage_url": "https://github.com/Windpicker-owo/InternetSearchPlugin",
|
||||
"repository_url": "https://github.com/Windpicker-owo/InternetSearchPlugin",
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
|
||||
from src.plugin_system import BaseEventHandler
|
||||
from src.plugin_system.base.base_event import HandlerResult
|
||||
|
||||
@@ -1748,6 +1747,7 @@ class SetGroupSignHandler(BaseEventHandler):
|
||||
logger.error("事件 napcat_set_group_sign 请求失败!")
|
||||
return HandlerResult(False, False, {"status": "error"})
|
||||
|
||||
|
||||
# ===PERSONAL===
|
||||
class SetInputStatusHandler(BaseEventHandler):
|
||||
handler_name: str = "napcat_set_input_status_handler"
|
||||
|
||||
@@ -233,7 +233,7 @@ class LauchNapcatAdapterHandler(BaseEventHandler):
|
||||
await reassembler.start_cleanup_task()
|
||||
|
||||
logger.info("开始启动Napcat Adapter")
|
||||
|
||||
|
||||
# 创建单独的异步任务,防止阻塞主线程
|
||||
asyncio.create_task(self._start_maibot_connection())
|
||||
asyncio.create_task(napcat_server(self.plugin_config))
|
||||
@@ -244,10 +244,10 @@ class LauchNapcatAdapterHandler(BaseEventHandler):
|
||||
"""非阻塞方式启动MaiBot连接,等待主服务启动后再连接"""
|
||||
# 等待一段时间让MaiBot主服务完全启动
|
||||
await asyncio.sleep(5)
|
||||
|
||||
|
||||
max_attempts = 10
|
||||
attempt = 0
|
||||
|
||||
|
||||
while attempt < max_attempts:
|
||||
try:
|
||||
logger.info(f"尝试连接MaiBot (第{attempt + 1}次)")
|
||||
@@ -291,7 +291,7 @@ class NapcatAdapterPlugin(BasePlugin):
|
||||
def enable_plugin(self) -> bool:
|
||||
"""通过配置文件动态控制插件启用状态"""
|
||||
# 如果已经通过配置加载了状态,使用配置中的值
|
||||
if hasattr(self, '_is_enabled'):
|
||||
if hasattr(self, "_is_enabled"):
|
||||
return self._is_enabled
|
||||
# 否则使用默认值(禁用状态)
|
||||
return False
|
||||
@@ -305,7 +305,7 @@ class NapcatAdapterPlugin(BasePlugin):
|
||||
"name": ConfigField(type=str, default="napcat_adapter_plugin", description="插件名称"),
|
||||
"version": ConfigField(type=str, default="1.1.0", description="插件版本"),
|
||||
"config_version": ConfigField(type=str, default="1.3.1", description="配置文件版本"),
|
||||
"enabled": ConfigField(type=bool, default=False, description="是否启用插件"),
|
||||
"enabled": ConfigField(type=bool, default=True, description="是否启用插件"),
|
||||
},
|
||||
"inner": {
|
||||
"version": ConfigField(type=str, default="0.2.1", description="配置版本号,请勿修改"),
|
||||
@@ -314,60 +314,88 @@ class NapcatAdapterPlugin(BasePlugin):
|
||||
"nickname": ConfigField(type=str, default="", description="昵称配置(目前未使用)"),
|
||||
},
|
||||
"napcat_server": {
|
||||
"mode": ConfigField(type=str, default="reverse", description="连接模式:reverse=反向连接(作为服务器), forward=正向连接(作为客户端)", choices=["reverse", "forward"]),
|
||||
"mode": ConfigField(
|
||||
type=str,
|
||||
default="reverse",
|
||||
description="连接模式:reverse=反向连接(作为服务器), forward=正向连接(作为客户端)",
|
||||
choices=["reverse", "forward"],
|
||||
),
|
||||
"host": ConfigField(type=str, default="localhost", description="主机地址"),
|
||||
"port": ConfigField(type=int, default=8095, description="端口号"),
|
||||
"url": ConfigField(type=str, default="", description="正向连接时的完整WebSocket URL,如 ws://localhost:8080/ws (仅在forward模式下使用)"),
|
||||
"access_token": ConfigField(type=str, default="", description="WebSocket 连接的访问令牌,用于身份验证(可选)"),
|
||||
"url": ConfigField(
|
||||
type=str,
|
||||
default="",
|
||||
description="正向连接时的完整WebSocket URL,如 ws://localhost:8080/ws (仅在forward模式下使用)",
|
||||
),
|
||||
"access_token": ConfigField(
|
||||
type=str, default="", description="WebSocket 连接的访问令牌,用于身份验证(可选)"
|
||||
),
|
||||
"heartbeat_interval": ConfigField(type=int, default=30, description="心跳间隔时间(按秒计)"),
|
||||
},
|
||||
"maibot_server": {
|
||||
"host": ConfigField(type=str, default="localhost", description="麦麦在.env文件中设置的主机地址,即HOST字段"),
|
||||
"host": ConfigField(
|
||||
type=str, default="localhost", description="麦麦在.env文件中设置的主机地址,即HOST字段"
|
||||
),
|
||||
"port": ConfigField(type=int, default=8000, description="麦麦在.env文件中设置的端口,即PORT字段"),
|
||||
"platform_name": ConfigField(type=str, default="qq", description="平台名称,用于消息路由"),
|
||||
},
|
||||
"voice": {
|
||||
"use_tts": ConfigField(type=bool, default=False, description="是否使用tts语音(请确保你配置了tts并有对应的adapter)"),
|
||||
"use_tts": ConfigField(
|
||||
type=bool, default=False, description="是否使用tts语音(请确保你配置了tts并有对应的adapter)"
|
||||
),
|
||||
},
|
||||
"slicing": {
|
||||
"max_frame_size": ConfigField(type=int, default=64, description="WebSocket帧的最大大小,单位为字节,默认64KB"),
|
||||
"max_frame_size": ConfigField(
|
||||
type=int, default=64, description="WebSocket帧的最大大小,单位为字节,默认64KB"
|
||||
),
|
||||
"delay_ms": ConfigField(type=int, default=10, description="切片发送间隔时间,单位为毫秒"),
|
||||
},
|
||||
"debug": {
|
||||
"level": ConfigField(type=str, default="INFO", description="日志等级(DEBUG, INFO, WARNING, ERROR, CRITICAL)", choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]),
|
||||
"level": ConfigField(
|
||||
type=str,
|
||||
default="INFO",
|
||||
description="日志等级(DEBUG, INFO, WARNING, ERROR, CRITICAL)",
|
||||
choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
|
||||
),
|
||||
},
|
||||
"features": {
|
||||
# 权限设置
|
||||
"group_list_type": ConfigField(type=str, default="blacklist", description="群聊列表类型:whitelist(白名单)或 blacklist(黑名单)", choices=["whitelist", "blacklist"]),
|
||||
"group_list_type": ConfigField(
|
||||
type=str,
|
||||
default="blacklist",
|
||||
description="群聊列表类型:whitelist(白名单)或 blacklist(黑名单)",
|
||||
choices=["whitelist", "blacklist"],
|
||||
),
|
||||
"group_list": ConfigField(type=list, default=[], description="群聊ID列表"),
|
||||
"private_list_type": ConfigField(type=str, default="blacklist", description="私聊列表类型:whitelist(白名单)或 blacklist(黑名单)", choices=["whitelist", "blacklist"]),
|
||||
"private_list_type": ConfigField(
|
||||
type=str,
|
||||
default="blacklist",
|
||||
description="私聊列表类型:whitelist(白名单)或 blacklist(黑名单)",
|
||||
choices=["whitelist", "blacklist"],
|
||||
),
|
||||
"private_list": ConfigField(type=list, default=[], description="用户ID列表"),
|
||||
"ban_user_id": ConfigField(type=list, default=[], description="全局禁止用户ID列表,这些用户无法在任何地方使用机器人"),
|
||||
"ban_user_id": ConfigField(
|
||||
type=list, default=[], description="全局禁止用户ID列表,这些用户无法在任何地方使用机器人"
|
||||
),
|
||||
"ban_qq_bot": ConfigField(type=bool, default=False, description="是否屏蔽QQ官方机器人消息"),
|
||||
|
||||
# 聊天功能设置
|
||||
"enable_poke": ConfigField(type=bool, default=True, description="是否启用戳一戳功能"),
|
||||
"ignore_non_self_poke": ConfigField(type=bool, default=False, description="是否无视不是针对自己的戳一戳"),
|
||||
"poke_debounce_seconds": ConfigField(type=int, default=3, description="戳一戳防抖时间(秒),在指定时间内第二次针对机器人的戳一戳将被忽略"),
|
||||
"poke_debounce_seconds": ConfigField(
|
||||
type=int, default=3, description="戳一戳防抖时间(秒),在指定时间内第二次针对机器人的戳一戳将被忽略"
|
||||
),
|
||||
"enable_reply_at": ConfigField(type=bool, default=True, description="是否启用引用回复时艾特用户的功能"),
|
||||
"reply_at_rate": ConfigField(type=float, default=0.5, description="引用回复时艾特用户的几率 (0.0 ~ 1.0)"),
|
||||
"enable_emoji_like": ConfigField(type=bool, default=True, description="是否启用群聊表情回复功能"),
|
||||
|
||||
# 视频处理设置
|
||||
"enable_video_analysis": ConfigField(type=bool, default=True, description="是否启用视频识别功能"),
|
||||
"max_video_size_mb": ConfigField(type=int, default=100, description="视频文件最大大小限制(MB)"),
|
||||
"download_timeout": ConfigField(type=int, default=60, description="视频下载超时时间(秒)"),
|
||||
"supported_formats": ConfigField(type=list, default=["mp4", "avi", "mov", "mkv", "flv", "wmv", "webm"], description="支持的视频格式"),
|
||||
|
||||
# 消息缓冲设置
|
||||
"enable_message_buffer": ConfigField(type=bool, default=True, description="是否启用消息缓冲合并功能"),
|
||||
"message_buffer_enable_group": ConfigField(type=bool, default=True, description="是否启用群聊消息缓冲合并"),
|
||||
"message_buffer_enable_private": ConfigField(type=bool, default=True, description="是否启用私聊消息缓冲合并"),
|
||||
"message_buffer_interval": ConfigField(type=float, default=3.0, description="消息合并间隔时间(秒),在此时间内的连续消息将被合并"),
|
||||
"message_buffer_initial_delay": ConfigField(type=float, default=0.5, description="消息缓冲初始延迟(秒),收到第一条消息后等待此时间开始合并"),
|
||||
"message_buffer_max_components": ConfigField(type=int, default=50, description="单个会话最大缓冲消息组件数量,超过此数量将强制合并"),
|
||||
"message_buffer_block_prefixes": ConfigField(type=list, default=["/", "!", "!", ".", "。", "#", "%"], description="消息缓冲屏蔽前缀,以这些前缀开头的消息不会被缓冲"),
|
||||
}
|
||||
"supported_formats": ConfigField(
|
||||
type=list, default=["mp4", "avi", "mov", "mkv", "flv", "wmv", "webm"], description="支持的视频格式"
|
||||
),
|
||||
# 消息缓冲功能已移除
|
||||
},
|
||||
}
|
||||
|
||||
# 配置节描述
|
||||
@@ -380,7 +408,7 @@ class NapcatAdapterPlugin(BasePlugin):
|
||||
"voice": "发送语音设置",
|
||||
"slicing": "WebSocket消息切片设置",
|
||||
"debug": "调试设置",
|
||||
"features": "功能设置(权限控制、聊天功能、视频处理、消息缓冲等)"
|
||||
"features": "功能设置(权限控制、聊天功能、视频处理、消息缓冲等)",
|
||||
}
|
||||
|
||||
def register_events(self):
|
||||
@@ -414,6 +442,7 @@ class NapcatAdapterPlugin(BasePlugin):
|
||||
chunker.set_plugin_config(self.config)
|
||||
# 设置response_pool的插件配置
|
||||
from .src.response_pool import set_plugin_config as set_response_pool_config
|
||||
|
||||
set_response_pool_config(self.config)
|
||||
# 设置send_handler的插件配置
|
||||
send_handler.set_plugin_config(self.config)
|
||||
@@ -423,4 +452,4 @@ class NapcatAdapterPlugin(BasePlugin):
|
||||
notice_handler.set_plugin_config(self.config)
|
||||
# 设置meta_event_handler的插件配置
|
||||
meta_event_handler.set_plugin_config(self.config)
|
||||
# 设置其他handler的插件配置(现在由component_registry在注册时自动设置)
|
||||
# 设置其他handler的插件配置(现在由component_registry在注册时自动设置)
|
||||
|
||||
@@ -1,317 +0,0 @@
|
||||
import asyncio
|
||||
import time
|
||||
from typing import Dict, List, Any, Optional
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("napcat_adapter")
|
||||
|
||||
from src.plugin_system.apis import config_api
|
||||
from .recv_handler import RealMessageType
|
||||
|
||||
|
||||
@dataclass
|
||||
class TextMessage:
|
||||
"""文本消息"""
|
||||
|
||||
text: str
|
||||
timestamp: float = field(default_factory=time.time)
|
||||
|
||||
|
||||
@dataclass
|
||||
class BufferedSession:
|
||||
"""缓冲会话数据"""
|
||||
|
||||
session_id: str
|
||||
messages: List[TextMessage] = field(default_factory=list)
|
||||
timer_task: Optional[asyncio.Task] = None
|
||||
delay_task: Optional[asyncio.Task] = None
|
||||
original_event: Any = None
|
||||
created_at: float = field(default_factory=time.time)
|
||||
|
||||
|
||||
class SimpleMessageBuffer:
|
||||
def __init__(self, merge_callback=None):
|
||||
"""
|
||||
初始化消息缓冲器
|
||||
|
||||
Args:
|
||||
merge_callback: 消息合并后的回调函数,接收(session_id, merged_text, original_event)参数
|
||||
"""
|
||||
self.buffer_pool: Dict[str, BufferedSession] = {}
|
||||
self.lock = asyncio.Lock()
|
||||
self.merge_callback = merge_callback
|
||||
self._shutdown = False
|
||||
self.plugin_config = None
|
||||
|
||||
def set_plugin_config(self, plugin_config: dict):
|
||||
"""设置插件配置"""
|
||||
self.plugin_config = plugin_config
|
||||
|
||||
@staticmethod
|
||||
def get_session_id(event_data: Dict[str, Any]) -> str:
|
||||
"""根据事件数据生成会话ID"""
|
||||
message_type = event_data.get("message_type", "unknown")
|
||||
user_id = event_data.get("user_id", "unknown")
|
||||
|
||||
if message_type == "private":
|
||||
return f"private_{user_id}"
|
||||
elif message_type == "group":
|
||||
group_id = event_data.get("group_id", "unknown")
|
||||
return f"group_{group_id}_{user_id}"
|
||||
else:
|
||||
return f"{message_type}_{user_id}"
|
||||
|
||||
@staticmethod
|
||||
def extract_text_from_message(message: List[Dict[str, Any]]) -> Optional[str]:
|
||||
"""从OneBot消息中提取纯文本,如果包含非文本内容则返回None"""
|
||||
text_parts = []
|
||||
has_non_text = False
|
||||
|
||||
logger.debug(f"正在提取消息文本,消息段数量: {len(message)}")
|
||||
|
||||
for msg_seg in message:
|
||||
msg_type = msg_seg.get("type", "")
|
||||
logger.debug(f"处理消息段类型: {msg_type}")
|
||||
|
||||
if msg_type == RealMessageType.text:
|
||||
text = msg_seg.get("data", {}).get("text", "").strip()
|
||||
if text:
|
||||
text_parts.append(text)
|
||||
logger.debug(f"提取到文本: {text[:50]}...")
|
||||
else:
|
||||
# 发现非文本消息段,标记为包含非文本内容
|
||||
has_non_text = True
|
||||
logger.debug(f"发现非文本消息段: {msg_type},跳过缓冲")
|
||||
|
||||
# 如果包含非文本内容,则不进行缓冲
|
||||
if has_non_text:
|
||||
logger.debug("消息包含非文本内容,不进行缓冲")
|
||||
return None
|
||||
|
||||
if text_parts:
|
||||
combined_text = " ".join(text_parts).strip()
|
||||
logger.debug(f"成功提取纯文本: {combined_text[:50]}...")
|
||||
return combined_text
|
||||
|
||||
logger.debug("没有找到有效的文本内容")
|
||||
return None
|
||||
|
||||
def should_skip_message(self, text: str) -> bool:
|
||||
"""判断消息是否应该跳过缓冲"""
|
||||
if not text or not text.strip():
|
||||
return True
|
||||
|
||||
# 检查屏蔽前缀
|
||||
block_prefixes = tuple(config_api.get_plugin_config(self.plugin_config, "features.message_buffer_block_prefixes", []))
|
||||
|
||||
text = text.strip()
|
||||
if text.startswith(block_prefixes):
|
||||
logger.debug(f"消息以屏蔽前缀开头,跳过缓冲: {text[:20]}...")
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
async def add_text_message(
|
||||
self, event_data: Dict[str, Any], message: List[Dict[str, Any]], original_event: Any = None
|
||||
) -> bool:
|
||||
"""
|
||||
添加文本消息到缓冲区
|
||||
|
||||
Args:
|
||||
event_data: 事件数据
|
||||
message: OneBot消息数组
|
||||
original_event: 原始事件对象
|
||||
|
||||
Returns:
|
||||
是否成功添加到缓冲区
|
||||
"""
|
||||
if self._shutdown:
|
||||
return False
|
||||
|
||||
# 检查是否启用消息缓冲
|
||||
if not config_api.get_plugin_config(self.plugin_config, "features.enable_message_buffer", False):
|
||||
return False
|
||||
|
||||
# 检查是否启用对应类型的缓冲
|
||||
message_type = event_data.get("message_type", "")
|
||||
if message_type == "group" and not config_api.get_plugin_config(self.plugin_config, "features.message_buffer_enable_group", False):
|
||||
return False
|
||||
elif message_type == "private" and not config_api.get_plugin_config(self.plugin_config, "features.message_buffer_enable_private", False):
|
||||
return False
|
||||
|
||||
# 提取文本
|
||||
text = self.extract_text_from_message(message)
|
||||
if not text:
|
||||
return False
|
||||
|
||||
# 检查是否应该跳过
|
||||
if self.should_skip_message(text):
|
||||
return False
|
||||
|
||||
session_id = self.get_session_id(event_data)
|
||||
|
||||
async with self.lock:
|
||||
# 获取或创建会话
|
||||
if session_id not in self.buffer_pool:
|
||||
self.buffer_pool[session_id] = BufferedSession(session_id=session_id, original_event=original_event)
|
||||
|
||||
session = self.buffer_pool[session_id]
|
||||
|
||||
# 检查是否超过最大组件数量
|
||||
if len(session.messages) >= config_api.get_plugin_config(self.plugin_config, "features.message_buffer_max_components", 5):
|
||||
logger.debug(f"会话 {session_id} 消息数量达到上限,强制合并")
|
||||
asyncio.create_task(self._force_merge_session(session_id))
|
||||
self.buffer_pool[session_id] = BufferedSession(session_id=session_id, original_event=original_event)
|
||||
session = self.buffer_pool[session_id]
|
||||
|
||||
# 添加文本消息
|
||||
session.messages.append(TextMessage(text=text))
|
||||
session.original_event = original_event # 更新事件
|
||||
|
||||
# 取消之前的定时器
|
||||
await self._cancel_session_timers(session)
|
||||
|
||||
# 设置新的延迟任务
|
||||
session.delay_task = asyncio.create_task(self._wait_and_start_merge(session_id))
|
||||
|
||||
logger.debug(f"文本消息已添加到缓冲器 {session_id}: {text[:50]}...")
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
async def _cancel_session_timers(session: BufferedSession):
|
||||
"""取消会话的所有定时器"""
|
||||
for task_name in ["timer_task", "delay_task"]:
|
||||
task = getattr(session, task_name)
|
||||
if task and not task.done():
|
||||
task.cancel()
|
||||
try:
|
||||
await task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
setattr(session, task_name, None)
|
||||
|
||||
async def _wait_and_start_merge(self, session_id: str):
|
||||
"""等待初始延迟后开始合并定时器"""
|
||||
initial_delay = config_api.get_plugin_config(self.plugin_config, "features.message_buffer_initial_delay", 0.5)
|
||||
await asyncio.sleep(initial_delay)
|
||||
|
||||
async with self.lock:
|
||||
session = self.buffer_pool.get(session_id)
|
||||
if session and session.messages:
|
||||
# 取消旧的定时器
|
||||
if session.timer_task and not session.timer_task.done():
|
||||
session.timer_task.cancel()
|
||||
try:
|
||||
await session.timer_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
# 设置合并定时器
|
||||
session.timer_task = asyncio.create_task(self._wait_and_merge(session_id))
|
||||
|
||||
async def _wait_and_merge(self, session_id: str):
|
||||
"""等待合并间隔后执行合并"""
|
||||
interval = config_api.get_plugin_config(self.plugin_config, "features.message_buffer_interval", 2.0)
|
||||
await asyncio.sleep(interval)
|
||||
await self._merge_session(session_id)
|
||||
|
||||
async def _force_merge_session(self, session_id: str):
|
||||
"""强制合并会话(不等待定时器)"""
|
||||
await self._merge_session(session_id, force=True)
|
||||
|
||||
async def _merge_session(self, session_id: str, force: bool = False):
|
||||
"""合并会话中的消息"""
|
||||
async with self.lock:
|
||||
session = self.buffer_pool.get(session_id)
|
||||
if not session or not session.messages:
|
||||
self.buffer_pool.pop(session_id, None)
|
||||
return
|
||||
|
||||
try:
|
||||
# 合并文本消息
|
||||
text_parts = []
|
||||
for msg in session.messages:
|
||||
if msg.text.strip():
|
||||
text_parts.append(msg.text.strip())
|
||||
|
||||
if not text_parts:
|
||||
self.buffer_pool.pop(session_id, None)
|
||||
return
|
||||
|
||||
merged_text = ",".join(text_parts) # 使用中文逗号连接
|
||||
message_count = len(session.messages)
|
||||
|
||||
logger.debug(f"合并会话 {session_id} 的 {message_count} 条文本消息: {merged_text[:100]}...")
|
||||
|
||||
# 调用回调函数
|
||||
if self.merge_callback:
|
||||
try:
|
||||
if asyncio.iscoroutinefunction(self.merge_callback):
|
||||
await self.merge_callback(session_id, merged_text, session.original_event)
|
||||
else:
|
||||
self.merge_callback(session_id, merged_text, session.original_event)
|
||||
except Exception as e:
|
||||
logger.error(f"消息合并回调执行失败: {e}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"合并会话 {session_id} 时出错: {e}")
|
||||
finally:
|
||||
# 清理会话
|
||||
await self._cancel_session_timers(session)
|
||||
self.buffer_pool.pop(session_id, None)
|
||||
|
||||
async def flush_session(self, session_id: str):
|
||||
"""强制刷新指定会话的缓冲区"""
|
||||
await self._force_merge_session(session_id)
|
||||
|
||||
async def flush_all(self):
|
||||
"""强制刷新所有会话的缓冲区"""
|
||||
session_ids = list(self.buffer_pool.keys())
|
||||
for session_id in session_ids:
|
||||
await self._force_merge_session(session_id)
|
||||
|
||||
async def get_buffer_stats(self) -> Dict[str, Any]:
|
||||
"""获取缓冲区统计信息"""
|
||||
async with self.lock:
|
||||
stats = {"total_sessions": len(self.buffer_pool), "sessions": {}}
|
||||
|
||||
for session_id, session in self.buffer_pool.items():
|
||||
stats["sessions"][session_id] = {
|
||||
"message_count": len(session.messages),
|
||||
"created_at": session.created_at,
|
||||
"age": time.time() - session.created_at,
|
||||
}
|
||||
|
||||
return stats
|
||||
|
||||
async def clear_expired_sessions(self, max_age: float = 300.0):
|
||||
"""清理过期的会话"""
|
||||
current_time = time.time()
|
||||
expired_sessions = []
|
||||
|
||||
async with self.lock:
|
||||
for session_id, session in self.buffer_pool.items():
|
||||
if current_time - session.created_at > max_age:
|
||||
expired_sessions.append(session_id)
|
||||
|
||||
for session_id in expired_sessions:
|
||||
logger.debug(f"清理过期会话: {session_id}")
|
||||
await self._force_merge_session(session_id)
|
||||
|
||||
async def shutdown(self):
|
||||
"""关闭消息缓冲器"""
|
||||
self._shutdown = True
|
||||
logger.debug("正在关闭简化消息缓冲器...")
|
||||
|
||||
# 刷新所有缓冲区
|
||||
await self.flush_all()
|
||||
|
||||
# 确保所有任务都被取消
|
||||
async with self.lock:
|
||||
for session in list(self.buffer_pool.values()):
|
||||
await self._cancel_session_timers(session)
|
||||
self.buffer_pool.clear()
|
||||
|
||||
logger.debug("简化消息缓冲器已关闭")
|
||||
@@ -11,10 +11,10 @@ router = None
|
||||
def create_router(plugin_config: dict):
|
||||
"""创建路由器实例"""
|
||||
global router
|
||||
platform_name = config_api.get_plugin_config(plugin_config, "maibot_server.platform_name", "napcat")
|
||||
platform_name = config_api.get_plugin_config(plugin_config, "maibot_server.platform_name", "qq")
|
||||
host = config_api.get_plugin_config(plugin_config, "maibot_server.host", "localhost")
|
||||
port = config_api.get_plugin_config(plugin_config, "maibot_server.port", 8000)
|
||||
|
||||
|
||||
route_config = RouteConfig(
|
||||
route_config={
|
||||
platform_name: TargetConfig(
|
||||
@@ -32,7 +32,7 @@ async def mmc_start_com(plugin_config: dict = None):
|
||||
logger.info("正在连接MaiBot")
|
||||
if plugin_config:
|
||||
create_router(plugin_config)
|
||||
|
||||
|
||||
if router:
|
||||
router.register_class_handler(send_handler.handle_message)
|
||||
await router.run()
|
||||
|
||||
@@ -32,7 +32,7 @@ class NoticeType: # 通知事件
|
||||
group_recall = "group_recall" # 群聊消息撤回
|
||||
notify = "notify"
|
||||
group_ban = "group_ban" # 群禁言
|
||||
group_msg_emoji_like = "group_msg_emoji_like" # 群聊表情回复
|
||||
group_msg_emoji_like = "group_msg_emoji_like" # 群聊表情回复
|
||||
|
||||
class Notify:
|
||||
poke = "poke" # 戳一戳
|
||||
|
||||
@@ -6,7 +6,6 @@ from ...CONSTS import PLUGIN_NAME
|
||||
logger = get_logger("napcat_adapter")
|
||||
|
||||
from src.plugin_system.apis import config_api
|
||||
from ..message_buffer import SimpleMessageBuffer
|
||||
from ..utils import (
|
||||
get_group_info,
|
||||
get_member_info,
|
||||
@@ -48,20 +47,18 @@ class MessageHandler:
|
||||
self.server_connection: Server.ServerConnection = None
|
||||
self.bot_id_list: Dict[int, bool] = {}
|
||||
self.plugin_config = None
|
||||
# 初始化简化消息缓冲器,传入回调函数
|
||||
self.message_buffer = SimpleMessageBuffer(merge_callback=self._send_buffered_message)
|
||||
# 消息缓冲功能已移除
|
||||
|
||||
def set_plugin_config(self, plugin_config: dict):
|
||||
"""设置插件配置"""
|
||||
self.plugin_config = plugin_config
|
||||
# 将配置传递给消息缓冲器
|
||||
if self.message_buffer:
|
||||
self.message_buffer.set_plugin_config(plugin_config)
|
||||
# 消息缓冲功能已移除
|
||||
|
||||
async def shutdown(self):
|
||||
"""关闭消息处理器,清理资源"""
|
||||
if self.message_buffer:
|
||||
await self.message_buffer.shutdown()
|
||||
# 消息缓冲功能已移除
|
||||
|
||||
# 消息缓冲功能已移除
|
||||
|
||||
async def set_server_connection(self, server_connection: Server.ServerConnection) -> None:
|
||||
"""设置Napcat连接"""
|
||||
@@ -100,7 +97,7 @@ class MessageHandler:
|
||||
# 检查群聊黑白名单
|
||||
group_list_type = config_api.get_plugin_config(self.plugin_config, "features.group_list_type", "blacklist")
|
||||
group_list = config_api.get_plugin_config(self.plugin_config, "features.group_list", [])
|
||||
|
||||
|
||||
if group_list_type == "whitelist":
|
||||
if group_id not in group_list:
|
||||
logger.warning("群聊不在白名单中,消息被丢弃")
|
||||
@@ -111,9 +108,11 @@ class MessageHandler:
|
||||
return False
|
||||
else:
|
||||
# 检查私聊黑白名单
|
||||
private_list_type = config_api.get_plugin_config(self.plugin_config, "features.private_list_type", "blacklist")
|
||||
private_list_type = config_api.get_plugin_config(
|
||||
self.plugin_config, "features.private_list_type", "blacklist"
|
||||
)
|
||||
private_list = config_api.get_plugin_config(self.plugin_config, "features.private_list", [])
|
||||
|
||||
|
||||
if private_list_type == "whitelist":
|
||||
if user_id not in private_list:
|
||||
logger.warning("私聊不在白名单中,消息被丢弃")
|
||||
@@ -156,21 +155,23 @@ class MessageHandler:
|
||||
Parameters:
|
||||
raw_message: dict: 原始消息
|
||||
"""
|
||||
|
||||
|
||||
# 添加原始消息调试日志,特别关注message字段
|
||||
logger.debug(f"收到原始消息: message_type={raw_message.get('message_type')}, message_id={raw_message.get('message_id')}")
|
||||
logger.debug(
|
||||
f"收到原始消息: message_type={raw_message.get('message_type')}, message_id={raw_message.get('message_id')}"
|
||||
)
|
||||
logger.debug(f"原始消息内容: {raw_message.get('message', [])}")
|
||||
|
||||
|
||||
# 检查是否包含@或video消息段
|
||||
message_segments = raw_message.get('message', [])
|
||||
message_segments = raw_message.get("message", [])
|
||||
if message_segments:
|
||||
for i, seg in enumerate(message_segments):
|
||||
seg_type = seg.get('type')
|
||||
if seg_type in ['at', 'video']:
|
||||
seg_type = seg.get("type")
|
||||
if seg_type in ["at", "video"]:
|
||||
logger.info(f"检测到 {seg_type.upper()} 消息段 [{i}]: {seg}")
|
||||
elif seg_type not in ['text', 'face', 'image']:
|
||||
elif seg_type not in ["text", "face", "image"]:
|
||||
logger.warning(f"检测到特殊消息段 [{i}]: type={seg_type}, data={seg.get('data', {})}")
|
||||
|
||||
|
||||
message_type: str = raw_message.get("message_type")
|
||||
message_id: int = raw_message.get("message_id")
|
||||
# message_time: int = raw_message.get("time")
|
||||
@@ -301,38 +302,7 @@ class MessageHandler:
|
||||
logger.warning("处理后消息内容为空")
|
||||
return None
|
||||
|
||||
# 检查是否需要使用消息缓冲
|
||||
enable_message_buffer = config_api.get_plugin_config(self.plugin_config, "features.enable_message_buffer", True)
|
||||
if enable_message_buffer:
|
||||
# 检查消息类型是否启用缓冲
|
||||
message_type = raw_message.get("message_type")
|
||||
should_use_buffer = False
|
||||
|
||||
if message_type == "group" and config_api.get_plugin_config(self.plugin_config, "features.message_buffer_enable_group", True):
|
||||
should_use_buffer = True
|
||||
elif message_type == "private" and config_api.get_plugin_config(self.plugin_config, "features.message_buffer_enable_private", True):
|
||||
should_use_buffer = True
|
||||
|
||||
if should_use_buffer:
|
||||
logger.debug(f"尝试缓冲消息,消息类型: {message_type}, 用户: {user_info.user_id}")
|
||||
|
||||
# 尝试添加到缓冲器
|
||||
buffered = await self.message_buffer.add_text_message(
|
||||
event_data={
|
||||
"message_type": message_type,
|
||||
"user_id": user_info.user_id,
|
||||
"group_id": group_info.group_id if group_info else None,
|
||||
},
|
||||
message=raw_message.get("message", []),
|
||||
original_event={"message_info": message_info, "raw_message": raw_message},
|
||||
)
|
||||
|
||||
if buffered:
|
||||
logger.debug(f"✅ 文本消息已成功缓冲: {user_info.user_id}")
|
||||
return None # 缓冲成功,不立即发送
|
||||
# 如果缓冲失败(消息包含非文本元素),走正常处理流程
|
||||
logger.debug(f"❌ 消息缓冲失败,包含非文本元素,走正常处理流程: {user_info.user_id}")
|
||||
# 缓冲失败时继续执行后面的正常处理流程,不要直接返回
|
||||
# 消息缓冲功能已移除,直接处理消息
|
||||
|
||||
logger.debug(f"准备发送消息到MaiBot,消息段数量: {len(seg_message)}")
|
||||
for i, seg in enumerate(seg_message):
|
||||
@@ -351,7 +321,6 @@ class MessageHandler:
|
||||
|
||||
logger.debug("发送到Maibot处理信息")
|
||||
await message_send_instance.message_send(message_base)
|
||||
return None
|
||||
|
||||
async def handle_real_message(self, raw_message: dict, in_reply: bool = False) -> List[Seg] | None:
|
||||
# sourcery skip: low-code-quality
|
||||
@@ -369,10 +338,10 @@ class MessageHandler:
|
||||
for sub_message in real_message:
|
||||
sub_message: dict
|
||||
sub_message_type = sub_message.get("type")
|
||||
|
||||
|
||||
# 添加详细的消息类型调试信息
|
||||
logger.debug(f"处理消息段: type={sub_message_type}, data={sub_message.get('data', {})}")
|
||||
|
||||
|
||||
# 特别关注 at 和 video 消息的识别
|
||||
if sub_message_type == "at":
|
||||
logger.debug(f"检测到@消息: {sub_message}")
|
||||
@@ -380,7 +349,7 @@ class MessageHandler:
|
||||
logger.debug(f"检测到VIDEO消息: {sub_message}")
|
||||
elif sub_message_type not in ["text", "face", "image", "record"]:
|
||||
logger.warning(f"检测到特殊消息类型: {sub_message_type}, 完整消息: {sub_message}")
|
||||
|
||||
|
||||
match sub_message_type:
|
||||
case RealMessageType.text:
|
||||
ret_seg = await self.handle_text_message(sub_message)
|
||||
@@ -519,8 +488,7 @@ class MessageHandler:
|
||||
logger.debug(f"handle_real_message完成,处理了{len(real_message)}个消息段,生成了{len(seg_message)}个seg")
|
||||
return seg_message
|
||||
|
||||
@staticmethod
|
||||
async def handle_text_message(raw_message: dict) -> Seg:
|
||||
async def handle_text_message(self, raw_message: dict) -> Seg:
|
||||
"""
|
||||
处理纯文本信息
|
||||
Parameters:
|
||||
@@ -532,8 +500,7 @@ class MessageHandler:
|
||||
plain_text: str = message_data.get("text")
|
||||
return Seg(type="text", data=plain_text)
|
||||
|
||||
@staticmethod
|
||||
async def handle_face_message(raw_message: dict) -> Seg | None:
|
||||
async def handle_face_message(self, raw_message: dict) -> Seg | None:
|
||||
"""
|
||||
处理表情消息
|
||||
Parameters:
|
||||
@@ -550,8 +517,7 @@ class MessageHandler:
|
||||
logger.warning(f"不支持的表情:{face_raw_id}")
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
async def handle_image_message(raw_message: dict) -> Seg | None:
|
||||
async def handle_image_message(self, raw_message: dict) -> Seg | None:
|
||||
"""
|
||||
处理图片消息与表情包消息
|
||||
Parameters:
|
||||
@@ -607,7 +573,6 @@ class MessageHandler:
|
||||
return Seg(type="at", data=f"{member_info.get('nickname')}:{member_info.get('user_id')}")
|
||||
else:
|
||||
return None
|
||||
return None
|
||||
|
||||
async def handle_record_message(self, raw_message: dict) -> Seg | None:
|
||||
"""
|
||||
@@ -636,8 +601,7 @@ class MessageHandler:
|
||||
return None
|
||||
return Seg(type="voice", data=audio_base64)
|
||||
|
||||
@staticmethod
|
||||
async def handle_video_message(raw_message: dict) -> Seg | None:
|
||||
async def handle_video_message(self, raw_message: dict) -> Seg | None:
|
||||
"""
|
||||
处理视频消息
|
||||
Parameters:
|
||||
@@ -744,7 +708,6 @@ class MessageHandler:
|
||||
reply_message = [Seg(type="text", data="(获取发言内容失败)")]
|
||||
sender_info: dict = message_detail.get("sender")
|
||||
sender_nickname: str = sender_info.get("nickname")
|
||||
sender_id: str = sender_info.get("user_id")
|
||||
seg_message: List[Seg] = []
|
||||
if not sender_nickname:
|
||||
logger.warning("无法获取被引用的人的昵称,返回默认值")
|
||||
@@ -768,7 +731,7 @@ class MessageHandler:
|
||||
return None
|
||||
|
||||
processed_message: Seg
|
||||
if 5 > image_count > 0:
|
||||
if image_count < 5 and image_count > 0:
|
||||
# 处理图片数量小于5的情况,此时解析图片为base64
|
||||
logger.debug("图片数量小于5,开始解析图片为base64")
|
||||
processed_message = await self._recursive_parse_image_seg(handled_message, True)
|
||||
@@ -785,18 +748,15 @@ class MessageHandler:
|
||||
forward_hint = Seg(type="text", data="这是一条转发消息:\n")
|
||||
return Seg(type="seglist", data=[forward_hint, processed_message])
|
||||
|
||||
@staticmethod
|
||||
async def handle_dice_message(raw_message: dict) -> Seg:
|
||||
async def handle_dice_message(self, raw_message: dict) -> Seg:
|
||||
message_data: dict = raw_message.get("data", {})
|
||||
res = message_data.get("result", "")
|
||||
return Seg(type="text", data=f"[扔了一个骰子,点数是{res}]")
|
||||
|
||||
@staticmethod
|
||||
async def handle_shake_message(raw_message: dict) -> Seg:
|
||||
async def handle_shake_message(self, raw_message: dict) -> Seg:
|
||||
return Seg(type="text", data="[向你发送了窗口抖动,现在你的屏幕猛烈地震了一下!]")
|
||||
|
||||
@staticmethod
|
||||
async def handle_json_message(raw_message: dict) -> Seg | None:
|
||||
async def handle_json_message(self, raw_message: dict) -> Seg:
|
||||
"""
|
||||
处理JSON消息
|
||||
Parameters:
|
||||
@@ -868,43 +828,6 @@ class MessageHandler:
|
||||
data=f"这是一条小程序分享消息,可以根据来源,考虑使用对应解析工具\n{formatted_content}",
|
||||
)
|
||||
|
||||
# 检查是否是音乐分享
|
||||
elif nested_data.get("view") == "music" and "music" in nested_data.get("meta", {}):
|
||||
logger.debug("检测到音乐分享消息,开始提取信息")
|
||||
music_info = nested_data["meta"]["music"]
|
||||
title = music_info.get("title", "未知歌曲")
|
||||
desc = music_info.get("desc", "未知艺术家")
|
||||
jump_url = music_info.get("jumpUrl", "")
|
||||
preview_url = music_info.get("preview", "")
|
||||
source = music_info.get("tag", "未知来源")
|
||||
|
||||
# 优化文本结构,使其更像卡片
|
||||
text_parts = [
|
||||
"--- 音乐分享 ---",
|
||||
f"歌曲:{title}",
|
||||
f"歌手:{desc}",
|
||||
f"来源:{source}"
|
||||
]
|
||||
if jump_url:
|
||||
text_parts.append(f"链接:{jump_url}")
|
||||
text_parts.append("----------------")
|
||||
|
||||
text_content = "\n".join(text_parts)
|
||||
|
||||
# 如果有预览图,创建一个seglist包含文本和图片
|
||||
if preview_url:
|
||||
try:
|
||||
image_base64 = await get_image_base64(preview_url)
|
||||
if image_base64:
|
||||
return Seg(type="seglist", data=[
|
||||
Seg(type="text", data=text_content + "\n"),
|
||||
Seg(type="image", data=image_base64)
|
||||
])
|
||||
except Exception as e:
|
||||
logger.error(f"下载音乐预览图失败: {e}")
|
||||
|
||||
return Seg(type="text", data=text_content)
|
||||
|
||||
# 如果没有提取到关键信息,返回None
|
||||
return None
|
||||
|
||||
@@ -915,8 +838,7 @@ class MessageHandler:
|
||||
logger.error(f"处理JSON消息时出错: {e}")
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
async def handle_rps_message(raw_message: dict) -> Seg:
|
||||
async def handle_rps_message(self, raw_message: dict) -> Seg:
|
||||
message_data: dict = raw_message.get("data", {})
|
||||
res = message_data.get("result", "")
|
||||
if res == "1":
|
||||
@@ -1099,55 +1021,7 @@ class MessageHandler:
|
||||
return None
|
||||
return response_data.get("messages")
|
||||
|
||||
@staticmethod
|
||||
async def _send_buffered_message(session_id: str, merged_text: str, original_event: Dict[str, Any]):
|
||||
"""发送缓冲的合并消息"""
|
||||
try:
|
||||
# 从原始事件数据中提取信息
|
||||
message_info = original_event.get("message_info")
|
||||
raw_message = original_event.get("raw_message")
|
||||
|
||||
if not message_info or not raw_message:
|
||||
logger.error("缓冲消息缺少必要信息")
|
||||
return
|
||||
|
||||
# 创建合并后的消息段 - 将合并的文本转换为Seg格式
|
||||
from maim_message import Seg
|
||||
|
||||
merged_seg = Seg(type="text", data=merged_text)
|
||||
submit_seg = Seg(type="seglist", data=[merged_seg])
|
||||
|
||||
# 创建新的消息ID
|
||||
import time
|
||||
|
||||
new_message_id = f"buffered-{message_info.message_id}-{int(time.time() * 1000)}"
|
||||
|
||||
# 更新消息信息
|
||||
from maim_message import BaseMessageInfo, MessageBase
|
||||
|
||||
buffered_message_info = BaseMessageInfo(
|
||||
platform=message_info.platform,
|
||||
message_id=new_message_id,
|
||||
time=time.time(),
|
||||
user_info=message_info.user_info,
|
||||
group_info=message_info.group_info,
|
||||
template_info=message_info.template_info,
|
||||
format_info=message_info.format_info,
|
||||
additional_config=message_info.additional_config,
|
||||
)
|
||||
|
||||
# 创建MessageBase
|
||||
message_base = MessageBase(
|
||||
message_info=buffered_message_info,
|
||||
message_segment=submit_seg,
|
||||
raw_message=raw_message.get("raw_message", ""),
|
||||
)
|
||||
|
||||
logger.debug(f"发送缓冲合并消息到Maibot处理: {session_id}")
|
||||
await message_send_instance.message_send(message_base)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"发送缓冲消息失败: {e}", exc_info=True)
|
||||
# 消息缓冲功能已移除
|
||||
|
||||
|
||||
message_handler = MessageHandler()
|
||||
|
||||
@@ -33,6 +33,7 @@ class MessageSending:
|
||||
try:
|
||||
# 重新导入router
|
||||
from ..mmc_com_layer import router
|
||||
|
||||
self.maibot_router = router
|
||||
if self.maibot_router is not None:
|
||||
logger.info("MaiBot router重连成功")
|
||||
@@ -73,14 +74,14 @@ class MessageSending:
|
||||
|
||||
# 获取对应的客户端并发送切片
|
||||
platform = message_base.message_info.platform
|
||||
|
||||
|
||||
# 再次检查router状态(防止运行时被重置)
|
||||
if self.maibot_router is None or not hasattr(self.maibot_router, 'clients'):
|
||||
if self.maibot_router is None or not hasattr(self.maibot_router, "clients"):
|
||||
logger.warning("MaiBot router连接已断开,尝试重新连接")
|
||||
if not await self._attempt_reconnect():
|
||||
logger.error("MaiBot router重连失败,切片发送中止")
|
||||
return False
|
||||
|
||||
|
||||
if platform not in self.maibot_router.clients:
|
||||
logger.error(f"平台 {platform} 未连接")
|
||||
return False
|
||||
|
||||
@@ -23,7 +23,9 @@ class MetaEventHandler:
|
||||
"""设置插件配置"""
|
||||
self.plugin_config = plugin_config
|
||||
# 更新interval值
|
||||
self.interval = config_api.get_plugin_config(self.plugin_config, "napcat_server.heartbeat_interval", 5000) / 1000
|
||||
self.interval = (
|
||||
config_api.get_plugin_config(self.plugin_config, "napcat_server.heartbeat_interval", 5000) / 1000
|
||||
)
|
||||
|
||||
async def handle_meta_event(self, message: dict) -> None:
|
||||
event_type = message.get("meta_event_type")
|
||||
|
||||
@@ -9,7 +9,7 @@ from src.common.logger import get_logger
|
||||
logger = get_logger("napcat_adapter")
|
||||
|
||||
from src.plugin_system.apis import config_api
|
||||
from ..database import BanUser, napcat_db, is_identical
|
||||
from ..database import BanUser, db_manager, is_identical
|
||||
from . import NoticeType, ACCEPT_FORMAT
|
||||
from .message_sending import message_send_instance
|
||||
from .message_handler import message_handler
|
||||
@@ -62,7 +62,7 @@ class NoticeHandler:
|
||||
return self.server_connection
|
||||
return websocket_manager.get_connection()
|
||||
|
||||
async def _ban_operation(self, group_id: int, user_id: Optional[int] = None, lift_time: Optional[int] = None) -> None:
|
||||
def _ban_operation(self, group_id: int, user_id: Optional[int] = None, lift_time: Optional[int] = None) -> None:
|
||||
"""
|
||||
将用户禁言记录添加到self.banned_list中
|
||||
如果是全体禁言,则user_id为0
|
||||
@@ -71,16 +71,16 @@ class NoticeHandler:
|
||||
user_id = 0 # 使用0表示全体禁言
|
||||
lift_time = -1
|
||||
ban_record = BanUser(user_id=user_id, group_id=group_id, lift_time=lift_time)
|
||||
for record in list(self.banned_list):
|
||||
for record in self.banned_list:
|
||||
if is_identical(record, ban_record):
|
||||
self.banned_list.remove(record)
|
||||
self.banned_list.append(ban_record)
|
||||
await napcat_db.create_ban_record(ban_record) # 更新
|
||||
db_manager.create_ban_record(ban_record) # 作为更新
|
||||
return
|
||||
self.banned_list.append(ban_record)
|
||||
await napcat_db.create_ban_record(ban_record) # 新建
|
||||
db_manager.create_ban_record(ban_record) # 添加到数据库
|
||||
|
||||
async def _lift_operation(self, group_id: int, user_id: Optional[int] = None) -> None:
|
||||
def _lift_operation(self, group_id: int, user_id: Optional[int] = None) -> None:
|
||||
"""
|
||||
从self.lifted_group_list中移除已经解除全体禁言的群
|
||||
"""
|
||||
@@ -88,12 +88,7 @@ class NoticeHandler:
|
||||
user_id = 0 # 使用0表示全体禁言
|
||||
ban_record = BanUser(user_id=user_id, group_id=group_id, lift_time=-1)
|
||||
self.lifted_list.append(ban_record)
|
||||
# 从被禁言列表里移除对应记录
|
||||
for record in list(self.banned_list):
|
||||
if is_identical(record, ban_record):
|
||||
self.banned_list.remove(record)
|
||||
break
|
||||
await napcat_db.delete_ban_record(ban_record)
|
||||
db_manager.delete_ban_record(ban_record) # 删除数据库中的记录
|
||||
|
||||
async def handle_notice(self, raw_message: dict) -> None:
|
||||
notice_type = raw_message.get("notice_type")
|
||||
@@ -121,9 +116,9 @@ class NoticeHandler:
|
||||
sub_type = raw_message.get("sub_type")
|
||||
match sub_type:
|
||||
case NoticeType.Notify.poke:
|
||||
if config_api.get_plugin_config(self.plugin_config, "features.enable_poke", True) and await message_handler.check_allow_to_chat(
|
||||
user_id, group_id, False, False
|
||||
):
|
||||
if config_api.get_plugin_config(
|
||||
self.plugin_config, "features.enable_poke", True
|
||||
) and await message_handler.check_allow_to_chat(user_id, group_id, False, False):
|
||||
logger.debug("处理戳一戳消息")
|
||||
handled_message, user_info = await self.handle_poke_notify(raw_message, group_id, user_id)
|
||||
else:
|
||||
@@ -132,14 +127,18 @@ class NoticeHandler:
|
||||
from src.plugin_system.core.event_manager import event_manager
|
||||
from ...event_types import NapcatEvent
|
||||
|
||||
await event_manager.trigger_event(NapcatEvent.ON_RECEIVED.FRIEND_INPUT, permission_group=PLUGIN_NAME)
|
||||
await event_manager.trigger_event(
|
||||
NapcatEvent.ON_RECEIVED.FRIEND_INPUT, permission_group=PLUGIN_NAME
|
||||
)
|
||||
case _:
|
||||
logger.warning(f"不支持的notify类型: {notice_type}.{sub_type}")
|
||||
case NoticeType.group_msg_emoji_like:
|
||||
case NoticeType.group_msg_emoji_like:
|
||||
# 该事件转移到 handle_group_emoji_like_notify函数内触发
|
||||
if config_api.get_plugin_config(self.plugin_config, "features.enable_emoji_like", True):
|
||||
logger.debug("处理群聊表情回复")
|
||||
handled_message, user_info = await self.handle_group_emoji_like_notify(raw_message,group_id,user_id)
|
||||
handled_message, user_info = await self.handle_group_emoji_like_notify(
|
||||
raw_message, group_id, user_id
|
||||
)
|
||||
else:
|
||||
logger.warning("群聊表情回复被禁用,取消群聊表情回复处理")
|
||||
case NoticeType.group_ban:
|
||||
@@ -202,11 +201,9 @@ class NoticeHandler:
|
||||
|
||||
if system_notice:
|
||||
await self.put_notice(message_base)
|
||||
return None
|
||||
else:
|
||||
logger.debug("发送到Maibot处理通知信息")
|
||||
await message_send_instance.message_send(message_base)
|
||||
return None
|
||||
|
||||
async def handle_poke_notify(
|
||||
self, raw_message: dict, group_id: int, user_id: int
|
||||
@@ -301,7 +298,7 @@ class NoticeHandler:
|
||||
async def handle_group_emoji_like_notify(self, raw_message: dict, group_id: int, user_id: int):
|
||||
if not group_id:
|
||||
logger.error("群ID不能为空,无法处理群聊表情回复通知")
|
||||
return None, None
|
||||
return None, None
|
||||
|
||||
user_qq_info: dict = await get_member_info(self.get_server_connection(), group_id, user_id)
|
||||
if user_qq_info:
|
||||
@@ -311,37 +308,42 @@ class NoticeHandler:
|
||||
user_name = "QQ用户"
|
||||
user_cardname = "QQ用户"
|
||||
logger.debug("无法获取表情回复对方的用户昵称")
|
||||
|
||||
|
||||
from src.plugin_system.core.event_manager import event_manager
|
||||
from ...event_types import NapcatEvent
|
||||
|
||||
target_message = await event_manager.trigger_event(NapcatEvent.MESSAGE.GET_MSG,message_id=raw_message.get("message_id",""))
|
||||
target_message_text = target_message.get_message_result().get("data",{}).get("raw_message","")
|
||||
target_message = await event_manager.trigger_event(
|
||||
NapcatEvent.MESSAGE.GET_MSG, message_id=raw_message.get("message_id", "")
|
||||
)
|
||||
target_message_text = target_message.get_message_result().get("data", {}).get("raw_message", "")
|
||||
if not target_message:
|
||||
logger.error("未找到对应消息")
|
||||
return None, None
|
||||
if len(target_message_text) > 15:
|
||||
target_message_text = target_message_text[:15] + "..."
|
||||
|
||||
|
||||
user_info: UserInfo = UserInfo(
|
||||
platform=config_api.get_plugin_config(self.plugin_config, "maibot_server.platform_name", "qq"),
|
||||
user_id=user_id,
|
||||
user_nickname=user_name,
|
||||
user_cardname=user_cardname,
|
||||
)
|
||||
|
||||
|
||||
like_emoji_id = raw_message.get("likes")[0].get("emoji_id")
|
||||
await event_manager.trigger_event(
|
||||
NapcatEvent.ON_RECEIVED.EMOJI_LIEK,
|
||||
permission_group=PLUGIN_NAME,
|
||||
group_id=group_id,
|
||||
user_id=user_id,
|
||||
message_id=raw_message.get("message_id",""),
|
||||
emoji_id=like_emoji_id
|
||||
)
|
||||
seg_data = Seg(type="text",data=f"{user_name}使用Emoji表情{QQ_FACE.get(like_emoji_id, '')}回复了你的消息[{target_message_text}]")
|
||||
NapcatEvent.ON_RECEIVED.EMOJI_LIEK,
|
||||
permission_group=PLUGIN_NAME,
|
||||
group_id=group_id,
|
||||
user_id=user_id,
|
||||
message_id=raw_message.get("message_id", ""),
|
||||
emoji_id=like_emoji_id,
|
||||
)
|
||||
seg_data = Seg(
|
||||
type="text",
|
||||
data=f"{user_name}使用Emoji表情{QQ_FACE.get(like_emoji_id, '')}回复了你的消息[{target_message_text}]",
|
||||
)
|
||||
return seg_data, user_info
|
||||
|
||||
|
||||
async def handle_ban_notify(self, raw_message: dict, group_id: int) -> Tuple[Seg, UserInfo] | Tuple[None, None]:
|
||||
if not group_id:
|
||||
logger.error("群ID不能为空,无法处理禁言通知")
|
||||
@@ -381,7 +383,7 @@ class NoticeHandler:
|
||||
|
||||
if user_id == 0: # 为全体禁言
|
||||
sub_type: str = "whole_ban"
|
||||
await self._ban_operation(group_id)
|
||||
self._ban_operation(group_id)
|
||||
else: # 为单人禁言
|
||||
# 获取被禁言人的信息
|
||||
sub_type: str = "ban"
|
||||
@@ -395,7 +397,7 @@ class NoticeHandler:
|
||||
user_nickname=user_nickname,
|
||||
user_cardname=user_cardname,
|
||||
)
|
||||
await self._ban_operation(group_id, user_id, int(time.time() + duration))
|
||||
self._ban_operation(group_id, user_id, int(time.time() + duration))
|
||||
|
||||
seg_data: Seg = Seg(
|
||||
type="notify",
|
||||
@@ -444,7 +446,7 @@ class NoticeHandler:
|
||||
user_id = raw_message.get("user_id")
|
||||
if user_id == 0: # 全体禁言解除
|
||||
sub_type = "whole_lift_ban"
|
||||
await self._lift_operation(group_id)
|
||||
self._lift_operation(group_id)
|
||||
else: # 单人禁言解除
|
||||
sub_type = "lift_ban"
|
||||
# 获取被解除禁言人的信息
|
||||
@@ -460,7 +462,7 @@ class NoticeHandler:
|
||||
user_nickname=user_nickname,
|
||||
user_cardname=user_cardname,
|
||||
)
|
||||
await self._lift_operation(group_id, user_id)
|
||||
self._lift_operation(group_id, user_id)
|
||||
|
||||
seg_data: Seg = Seg(
|
||||
type="notify",
|
||||
@@ -471,8 +473,7 @@ class NoticeHandler:
|
||||
)
|
||||
return seg_data, operator_info
|
||||
|
||||
@staticmethod
|
||||
async def put_notice(message_base: MessageBase) -> None:
|
||||
async def put_notice(self, message_base: MessageBase) -> None:
|
||||
"""
|
||||
将处理后的通知消息放入通知队列
|
||||
"""
|
||||
@@ -488,7 +489,7 @@ class NoticeHandler:
|
||||
group_id = lift_record.group_id
|
||||
user_id = lift_record.user_id
|
||||
|
||||
asyncio.create_task(napcat_db.delete_ban_record(lift_record)) # 从数据库中删除禁言记录
|
||||
db_manager.delete_ban_record(lift_record) # 从数据库中删除禁言记录
|
||||
|
||||
seg_message: Seg = await self.natural_lift(group_id, user_id)
|
||||
|
||||
@@ -585,8 +586,7 @@ class NoticeHandler:
|
||||
self.banned_list.remove(ban_record)
|
||||
await asyncio.sleep(5)
|
||||
|
||||
@staticmethod
|
||||
async def send_notice() -> None:
|
||||
async def send_notice(self) -> None:
|
||||
"""
|
||||
发送通知消息到Napcat
|
||||
"""
|
||||
|
||||
@@ -45,12 +45,12 @@ async def check_timeout_response() -> None:
|
||||
while True:
|
||||
cleaned_message_count: int = 0
|
||||
now_time = time.time()
|
||||
|
||||
|
||||
# 获取心跳间隔配置
|
||||
heartbeat_interval = 30 # 默认值
|
||||
if plugin_config:
|
||||
heartbeat_interval = config_api.get_plugin_config(plugin_config, "napcat_server.heartbeat_interval", 30)
|
||||
|
||||
|
||||
for echo_id, response_time in list(response_time_dict.items()):
|
||||
if now_time - response_time > heartbeat_interval:
|
||||
cleaned_message_count += 1
|
||||
|
||||
@@ -96,6 +96,7 @@ class SendHandler:
|
||||
logger.error("无法识别的消息类型")
|
||||
return None
|
||||
logger.info("尝试发送到napcat")
|
||||
logger.debug(f"准备发送到napcat的消息体: action='{action}', {id_name}='{target_id}', message='{processed_message}'")
|
||||
response = await self.send_message_to_napcat(
|
||||
action,
|
||||
{
|
||||
@@ -228,8 +229,10 @@ class SendHandler:
|
||||
new_payload = payload
|
||||
if seg.type == "reply":
|
||||
target_id = seg.data
|
||||
target_id = str(target_id)
|
||||
if target_id == "notice":
|
||||
return payload
|
||||
logger.info(target_id if isinstance(target_id, str) else "")
|
||||
new_payload = self.build_payload(
|
||||
payload,
|
||||
await self.handle_reply_message(target_id if isinstance(target_id, str) else "", user_info),
|
||||
@@ -294,15 +297,17 @@ class SendHandler:
|
||||
|
||||
async def handle_reply_message(self, id: str, user_info: UserInfo) -> dict | list:
|
||||
"""处理回复消息"""
|
||||
logger.debug(f"开始处理回复消息,消息ID: {id}")
|
||||
reply_seg = {"type": "reply", "data": {"id": id}}
|
||||
|
||||
# 检查是否启用引用艾特功能
|
||||
if not config_api.get_plugin_config(self.plugin_config, "features.enable_reply_at", False):
|
||||
logger.info("引用艾特功能未启用,仅发送普通回复")
|
||||
return reply_seg
|
||||
|
||||
try:
|
||||
# 尝试通过 message_id 获取消息详情
|
||||
msg_info_response = await self.send_message_to_napcat("get_msg", {"message_id": int(id)})
|
||||
msg_info_response = await self.send_message_to_napcat("get_msg", {"message_id": id})
|
||||
logger.debug(f"获取消息 {id} 的详情响应: {msg_info_response}")
|
||||
|
||||
replied_user_id = None
|
||||
if msg_info_response and msg_info_response.get("status") == "ok":
|
||||
@@ -313,6 +318,7 @@ class SendHandler:
|
||||
# 如果没有获取到被回复者的ID,则直接返回,不进行@
|
||||
if not replied_user_id:
|
||||
logger.warning(f"无法获取消息 {id} 的发送者信息,跳过 @")
|
||||
logger.info(f"最终返回的回复段: {reply_seg}")
|
||||
return reply_seg
|
||||
|
||||
# 根据概率决定是否艾特用户
|
||||
@@ -320,13 +326,17 @@ class SendHandler:
|
||||
at_seg = {"type": "at", "data": {"qq": str(replied_user_id)}}
|
||||
# 在艾特后面添加一个空格
|
||||
text_seg = {"type": "text", "data": {"text": " "}}
|
||||
return [reply_seg, at_seg, text_seg]
|
||||
result_seg = [reply_seg, at_seg, text_seg]
|
||||
logger.info(f"最终返回的回复段: {result_seg}")
|
||||
return result_seg
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"处理引用回复并尝试@时出错: {e}")
|
||||
# 出现异常时,只发送普通的回复,避免程序崩溃
|
||||
logger.info(f"最终返回的回复段: {reply_seg}")
|
||||
return reply_seg
|
||||
|
||||
logger.info(f"最终返回的回复段: {reply_seg}")
|
||||
return reply_seg
|
||||
|
||||
@staticmethod
|
||||
@@ -366,7 +376,7 @@ class SendHandler:
|
||||
use_tts = False
|
||||
if self.plugin_config:
|
||||
use_tts = config_api.get_plugin_config(self.plugin_config, "voice.use_tts", False)
|
||||
|
||||
|
||||
if not use_tts:
|
||||
logger.warning("未启用语音消息处理")
|
||||
return {}
|
||||
|
||||
@@ -18,7 +18,9 @@ class WebSocketManager:
|
||||
self.max_reconnect_attempts = 10 # 最大重连次数
|
||||
self.plugin_config = None
|
||||
|
||||
async def start_connection(self, message_handler: Callable[[Server.ServerConnection], Any], plugin_config: dict) -> None:
|
||||
async def start_connection(
|
||||
self, message_handler: Callable[[Server.ServerConnection], Any], plugin_config: dict
|
||||
) -> None:
|
||||
"""根据配置启动 WebSocket 连接"""
|
||||
self.plugin_config = plugin_config
|
||||
mode = config_api.get_plugin_config(plugin_config, "napcat_server.mode")
|
||||
@@ -72,9 +74,7 @@ class WebSocketManager:
|
||||
# 如果配置了访问令牌,添加到请求头
|
||||
access_token = config_api.get_plugin_config(self.plugin_config, "napcat_server.access_token")
|
||||
if access_token:
|
||||
connect_kwargs["additional_headers"] = {
|
||||
"Authorization": f"Bearer {access_token}"
|
||||
}
|
||||
connect_kwargs["additional_headers"] = {"Authorization": f"Bearer {access_token}"}
|
||||
logger.info("已添加访问令牌到连接请求头")
|
||||
|
||||
async with Server.connect(url, **connect_kwargs) as websocket:
|
||||
|
||||
@@ -1,43 +0,0 @@
|
||||
# 权限配置文件
|
||||
# 此文件用于管理群聊和私聊的黑白名单设置,以及聊天相关功能
|
||||
# 支持热重载,修改后会自动生效
|
||||
|
||||
# 群聊权限设置
|
||||
group_list_type = "whitelist" # 群聊列表类型:whitelist(白名单)或 blacklist(黑名单)
|
||||
group_list = [] # 群聊ID列表
|
||||
# 当 group_list_type 为 whitelist 时,只有列表中的群聊可以使用机器人
|
||||
# 当 group_list_type 为 blacklist 时,列表中的群聊无法使用机器人
|
||||
# 示例:group_list = [123456789, 987654321]
|
||||
|
||||
# 私聊权限设置
|
||||
private_list_type = "whitelist" # 私聊列表类型:whitelist(白名单)或 blacklist(黑名单)
|
||||
private_list = [] # 用户ID列表
|
||||
# 当 private_list_type 为 whitelist 时,只有列表中的用户可以私聊机器人
|
||||
# 当 private_list_type 为 blacklist 时,列表中的用户无法私聊机器人
|
||||
# 示例:private_list = [123456789, 987654321]
|
||||
|
||||
# 全局禁止设置
|
||||
ban_user_id = [] # 全局禁止用户ID列表,这些用户无法在任何地方使用机器人
|
||||
ban_qq_bot = false # 是否屏蔽QQ官方机器人消息
|
||||
|
||||
# 聊天功能设置
|
||||
enable_poke = true # 是否启用戳一戳功能
|
||||
ignore_non_self_poke = false # 是否无视不是针对自己的戳一戳
|
||||
poke_debounce_seconds = 3 # 戳一戳防抖时间(秒),在指定时间内第二次针对机器人的戳一戳将被忽略
|
||||
enable_reply_at = true # 是否启用引用回复时艾特用户的功能
|
||||
reply_at_rate = 0.5 # 引用回复时艾特用户的几率 (0.0 ~ 1.0)
|
||||
|
||||
# 视频处理设置
|
||||
enable_video_analysis = true # 是否启用视频识别功能
|
||||
max_video_size_mb = 100 # 视频文件最大大小限制(MB)
|
||||
download_timeout = 60 # 视频下载超时时间(秒)
|
||||
supported_formats = ["mp4", "avi", "mov", "mkv", "flv", "wmv", "webm"] # 支持的视频格式
|
||||
|
||||
# 消息缓冲设置
|
||||
enable_message_buffer = true # 是否启用消息缓冲合并功能
|
||||
message_buffer_enable_group = true # 是否启用群聊消息缓冲合并
|
||||
message_buffer_enable_private = true # 是否启用私聊消息缓冲合并
|
||||
message_buffer_interval = 3.0 # 消息合并间隔时间(秒),在此时间内的连续消息将被合并
|
||||
message_buffer_initial_delay = 0.5 # 消息缓冲初始延迟(秒),收到第一条消息后等待此时间开始合并
|
||||
message_buffer_max_components = 50 # 单个会话最大缓冲消息组件数量,超过此数量将强制合并
|
||||
message_buffer_block_prefixes = ["/"] # 消息缓冲屏蔽前缀,以这些前缀开头的消息不会被缓冲
|
||||
@@ -1,29 +0,0 @@
|
||||
[inner]
|
||||
version = "0.2.1" # 版本号
|
||||
# 请勿修改版本号,除非你知道自己在做什么
|
||||
|
||||
[nickname] # 现在没用
|
||||
nickname = ""
|
||||
|
||||
[napcat_server] # Napcat连接的ws服务设置
|
||||
mode = "reverse" # 连接模式:reverse=反向连接(作为服务器), forward=正向连接(作为客户端)
|
||||
host = "localhost" # 主机地址
|
||||
port = 8095 # 端口号
|
||||
url = "" # 正向连接时的完整WebSocket URL,如 ws://localhost:8080/ws (仅在forward模式下使用)
|
||||
access_token = "" # WebSocket 连接的访问令牌,用于身份验证(可选)
|
||||
heartbeat_interval = 30 # 心跳间隔时间(按秒计)
|
||||
|
||||
[maibot_server] # 连接麦麦的ws服务设置
|
||||
host = "localhost" # 麦麦在.env文件中设置的主机地址,即HOST字段
|
||||
port = 8000 # 麦麦在.env文件中设置的端口,即PORT字段
|
||||
|
||||
[voice] # 发送语音设置
|
||||
use_tts = false # 是否使用tts语音(请确保你配置了tts并有对应的adapter)
|
||||
|
||||
[slicing] # WebSocket消息切片设置
|
||||
max_frame_size = 64 # WebSocket帧的最大大小,单位为字节,默认64KB
|
||||
delay_ms = 10 # 切片发送间隔时间,单位为毫秒
|
||||
|
||||
[debug]
|
||||
level = "INFO" # 日志等级(DEBUG, INFO, WARNING, ERROR, CRITICAL)
|
||||
|
||||
@@ -30,7 +30,8 @@ class PokeAction(BaseAction):
|
||||
|
||||
# === 功能描述(必须填写)===
|
||||
action_parameters = {
|
||||
"user_name": "需要戳一戳的用户的名字",
|
||||
"user_name": "需要戳一戳的用户的名字 (可选)",
|
||||
"user_id": "需要戳一戳的用户的ID (可选,优先级更高)",
|
||||
"times": "需要戳一戳的次数 (默认为 1)",
|
||||
}
|
||||
action_require = ["当需要戳某个用户时使用", "当你想提醒特定用户时使用"]
|
||||
@@ -46,32 +47,38 @@ class PokeAction(BaseAction):
|
||||
|
||||
async def execute(self) -> Tuple[bool, str]:
|
||||
"""执行戳一戳的动作"""
|
||||
user_id = self.action_data.get("user_id")
|
||||
user_name = self.action_data.get("user_name")
|
||||
|
||||
try:
|
||||
times = int(self.action_data.get("times", 1))
|
||||
except (ValueError, TypeError):
|
||||
times = 1
|
||||
|
||||
if not user_name:
|
||||
logger.warning("戳一戳动作缺少 'user_name' 参数。")
|
||||
return False, "缺少 'user_name' 参数"
|
||||
|
||||
user_info = await get_person_info_manager().get_person_info_by_name(user_name)
|
||||
if not user_info or not user_info.get("user_id"):
|
||||
logger.info(f"找不到名为 '{user_name}' 的用户。")
|
||||
return False, f"找不到名为 '{user_name}' 的用户"
|
||||
|
||||
user_id = user_info.get("user_id")
|
||||
# 优先使用 user_id
|
||||
if not user_id:
|
||||
if not user_name:
|
||||
logger.warning("戳一戳动作缺少 'user_id' 或 'user_name' 参数。")
|
||||
return False, "缺少用户标识参数"
|
||||
|
||||
# 备用方案:通过 user_name 查找
|
||||
user_info = await get_person_info_manager().get_person_info_by_name(user_name)
|
||||
if not user_info or not user_info.get("user_id"):
|
||||
logger.info(f"找不到名为 '{user_name}' 的用户。")
|
||||
return False, f"找不到名为 '{user_name}' 的用户"
|
||||
user_id = user_info.get("user_id")
|
||||
|
||||
display_name = user_name or user_id
|
||||
|
||||
for i in range(times):
|
||||
logger.info(f"正在向 {user_name} ({user_id}) 发送第 {i + 1}/{times} 次戳一戳...")
|
||||
logger.info(f"正在向 {display_name} ({user_id}) 发送第 {i + 1}/{times} 次戳一戳...")
|
||||
await self.send_command(
|
||||
"SEND_POKE", args={"qq_id": user_id}, display_message=f"戳了戳 {user_name} ({i + 1}/{times})"
|
||||
"SEND_POKE", args={"qq_id": user_id}, display_message=f"戳了戳 {display_name} ({i + 1}/{times})"
|
||||
)
|
||||
# 添加一个小的延迟,以避免发送过快
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
success_message = f"已向 {user_name} 发送 {times} 次戳一戳。"
|
||||
success_message = f"已向 {display_name} 发送 {times} 次戳一戳。"
|
||||
await self.store_action_info(
|
||||
action_build_into_prompt=True, action_prompt_display=success_message, action_done=True
|
||||
)
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
"""
|
||||
Base search engine interface
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, List, Any
|
||||
|
||||
@@ -9,20 +10,20 @@ class BaseSearchEngine(ABC):
|
||||
"""
|
||||
搜索引擎基类
|
||||
"""
|
||||
|
||||
|
||||
@abstractmethod
|
||||
async def search(self, args: Dict[str, Any]) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
执行搜索
|
||||
|
||||
|
||||
Args:
|
||||
args: 搜索参数,包含 query、num_results、time_range 等
|
||||
|
||||
|
||||
Returns:
|
||||
搜索结果列表,每个结果包含 title、url、snippet、provider 字段
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
@abstractmethod
|
||||
def is_available(self) -> bool:
|
||||
"""
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
"""
|
||||
Bing search engine implementation
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import functools
|
||||
import random
|
||||
@@ -58,21 +59,21 @@ class BingSearchEngine(BaseSearchEngine):
|
||||
"""
|
||||
Bing搜索引擎实现
|
||||
"""
|
||||
|
||||
|
||||
def __init__(self):
|
||||
self.session = requests.Session()
|
||||
self.session.headers = HEADERS
|
||||
|
||||
|
||||
def is_available(self) -> bool:
|
||||
"""检查Bing搜索引擎是否可用"""
|
||||
return True # Bing是免费搜索引擎,总是可用
|
||||
|
||||
|
||||
async def search(self, args: Dict[str, Any]) -> List[Dict[str, Any]]:
|
||||
"""执行Bing搜索"""
|
||||
query = args["query"]
|
||||
num_results = args.get("num_results", 3)
|
||||
time_range = args.get("time_range", "any")
|
||||
|
||||
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
func = functools.partial(self._search_sync, query, num_results, time_range)
|
||||
@@ -81,17 +82,17 @@ class BingSearchEngine(BaseSearchEngine):
|
||||
except Exception as e:
|
||||
logger.error(f"Bing 搜索失败: {e}")
|
||||
return []
|
||||
|
||||
|
||||
def _search_sync(self, keyword: str, num_results: int, time_range: str) -> List[Dict[str, Any]]:
|
||||
"""同步执行Bing搜索"""
|
||||
if not keyword:
|
||||
return []
|
||||
|
||||
list_result = []
|
||||
|
||||
|
||||
# 构建搜索URL
|
||||
search_url = bing_search_url + keyword
|
||||
|
||||
|
||||
# 如果指定了时间范围,添加时间过滤参数
|
||||
if time_range == "week":
|
||||
search_url += "&qft=+filterui:date-range-7"
|
||||
@@ -182,34 +183,29 @@ class BingSearchEngine(BaseSearchEngine):
|
||||
# 尝试提取搜索结果
|
||||
# 方法1: 查找标准的搜索结果容器
|
||||
results = root.select("ol#b_results li.b_algo")
|
||||
|
||||
|
||||
if results:
|
||||
for _rank, result in enumerate(results, 1):
|
||||
# 提取标题和链接
|
||||
title_link = result.select_one("h2 a")
|
||||
if not title_link:
|
||||
continue
|
||||
|
||||
|
||||
title = title_link.get_text().strip()
|
||||
url = title_link.get("href", "")
|
||||
|
||||
|
||||
# 提取摘要
|
||||
abstract = ""
|
||||
abstract_elem = result.select_one("div.b_caption p")
|
||||
if abstract_elem:
|
||||
abstract = abstract_elem.get_text().strip()
|
||||
|
||||
|
||||
# 限制摘要长度
|
||||
if ABSTRACT_MAX_LENGTH and len(abstract) > ABSTRACT_MAX_LENGTH:
|
||||
abstract = abstract[:ABSTRACT_MAX_LENGTH] + "..."
|
||||
|
||||
list_data.append({
|
||||
"title": title,
|
||||
"url": url,
|
||||
"snippet": abstract,
|
||||
"provider": "Bing"
|
||||
})
|
||||
|
||||
|
||||
list_data.append({"title": title, "url": url, "snippet": abstract, "provider": "Bing"})
|
||||
|
||||
if len(list_data) >= 10: # 限制结果数量
|
||||
break
|
||||
|
||||
@@ -217,22 +213,34 @@ class BingSearchEngine(BaseSearchEngine):
|
||||
if not list_data:
|
||||
# 查找所有可能的搜索结果链接
|
||||
all_links = root.find_all("a")
|
||||
|
||||
|
||||
for link in all_links:
|
||||
href = link.get("href", "")
|
||||
text = link.get_text().strip()
|
||||
|
||||
|
||||
# 过滤有效的搜索结果链接
|
||||
if (href and text and len(text) > 10
|
||||
if (
|
||||
href
|
||||
and text
|
||||
and len(text) > 10
|
||||
and not href.startswith("javascript:")
|
||||
and not href.startswith("#")
|
||||
and "http" in href
|
||||
and not any(x in href for x in [
|
||||
"bing.com/search", "bing.com/images", "bing.com/videos",
|
||||
"bing.com/maps", "bing.com/news", "login", "account",
|
||||
"microsoft", "javascript"
|
||||
])):
|
||||
|
||||
and not any(
|
||||
x in href
|
||||
for x in [
|
||||
"bing.com/search",
|
||||
"bing.com/images",
|
||||
"bing.com/videos",
|
||||
"bing.com/maps",
|
||||
"bing.com/news",
|
||||
"login",
|
||||
"account",
|
||||
"microsoft",
|
||||
"javascript",
|
||||
]
|
||||
)
|
||||
):
|
||||
# 尝试获取摘要
|
||||
abstract = ""
|
||||
parent = link.parent
|
||||
@@ -240,18 +248,13 @@ class BingSearchEngine(BaseSearchEngine):
|
||||
full_text = parent.get_text().strip()
|
||||
if len(full_text) > len(text):
|
||||
abstract = full_text.replace(text, "", 1).strip()
|
||||
|
||||
|
||||
# 限制摘要长度
|
||||
if ABSTRACT_MAX_LENGTH and len(abstract) > ABSTRACT_MAX_LENGTH:
|
||||
abstract = abstract[:ABSTRACT_MAX_LENGTH] + "..."
|
||||
|
||||
list_data.append({
|
||||
"title": text,
|
||||
"url": href,
|
||||
"snippet": abstract,
|
||||
"provider": "Bing"
|
||||
})
|
||||
|
||||
|
||||
list_data.append({"title": text, "url": href, "snippet": abstract, "provider": "Bing"})
|
||||
|
||||
if len(list_data) >= 10:
|
||||
break
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
"""
|
||||
DuckDuckGo search engine implementation
|
||||
"""
|
||||
|
||||
from typing import Dict, List, Any
|
||||
from asyncddgs import aDDGS
|
||||
|
||||
@@ -14,27 +15,22 @@ class DDGSearchEngine(BaseSearchEngine):
|
||||
"""
|
||||
DuckDuckGo搜索引擎实现
|
||||
"""
|
||||
|
||||
|
||||
def is_available(self) -> bool:
|
||||
"""检查DuckDuckGo搜索引擎是否可用"""
|
||||
return True # DuckDuckGo不需要API密钥,总是可用
|
||||
|
||||
|
||||
async def search(self, args: Dict[str, Any]) -> List[Dict[str, Any]]:
|
||||
"""执行DuckDuckGo搜索"""
|
||||
query = args["query"]
|
||||
num_results = args.get("num_results", 3)
|
||||
|
||||
|
||||
try:
|
||||
async with aDDGS() as ddgs:
|
||||
search_response = await ddgs.text(query, max_results=num_results)
|
||||
|
||||
|
||||
return [
|
||||
{
|
||||
"title": r.get("title"),
|
||||
"url": r.get("href"),
|
||||
"snippet": r.get("body"),
|
||||
"provider": "DuckDuckGo"
|
||||
}
|
||||
{"title": r.get("title"), "url": r.get("href"), "snippet": r.get("body"), "provider": "DuckDuckGo"}
|
||||
for r in search_response
|
||||
]
|
||||
except Exception as e:
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
"""
|
||||
Exa search engine implementation
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import functools
|
||||
from datetime import datetime, timedelta
|
||||
@@ -19,31 +20,27 @@ class ExaSearchEngine(BaseSearchEngine):
|
||||
"""
|
||||
Exa搜索引擎实现
|
||||
"""
|
||||
|
||||
|
||||
def __init__(self):
|
||||
self._initialize_clients()
|
||||
|
||||
|
||||
def _initialize_clients(self):
|
||||
"""初始化Exa客户端"""
|
||||
# 从主配置文件读取API密钥
|
||||
exa_api_keys = config_api.get_global_config("web_search.exa_api_keys", None)
|
||||
|
||||
|
||||
# 创建API密钥管理器
|
||||
self.api_manager = create_api_key_manager_from_config(
|
||||
exa_api_keys,
|
||||
lambda key: Exa(api_key=key),
|
||||
"Exa"
|
||||
)
|
||||
|
||||
self.api_manager = create_api_key_manager_from_config(exa_api_keys, lambda key: Exa(api_key=key), "Exa")
|
||||
|
||||
def is_available(self) -> bool:
|
||||
"""检查Exa搜索引擎是否可用"""
|
||||
return self.api_manager.is_available()
|
||||
|
||||
|
||||
async def search(self, args: Dict[str, Any]) -> List[Dict[str, Any]]:
|
||||
"""执行Exa搜索"""
|
||||
if not self.is_available():
|
||||
return []
|
||||
|
||||
|
||||
query = args["query"]
|
||||
num_results = args.get("num_results", 3)
|
||||
time_range = args.get("time_range", "any")
|
||||
@@ -52,7 +49,7 @@ class ExaSearchEngine(BaseSearchEngine):
|
||||
if time_range != "any":
|
||||
today = datetime.now()
|
||||
start_date = today - timedelta(days=7 if time_range == "week" else 30)
|
||||
exa_args["start_published_date"] = start_date.strftime('%Y-%m-%d')
|
||||
exa_args["start_published_date"] = start_date.strftime("%Y-%m-%d")
|
||||
|
||||
try:
|
||||
# 使用API密钥管理器获取下一个客户端
|
||||
@@ -60,17 +57,17 @@ class ExaSearchEngine(BaseSearchEngine):
|
||||
if not exa_client:
|
||||
logger.error("无法获取Exa客户端")
|
||||
return []
|
||||
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
func = functools.partial(exa_client.search_and_contents, query, **exa_args)
|
||||
search_response = await loop.run_in_executor(None, func)
|
||||
|
||||
|
||||
return [
|
||||
{
|
||||
"title": res.title,
|
||||
"url": res.url,
|
||||
"snippet": " ".join(getattr(res, 'highlights', [])) or (getattr(res, 'text', '')[:250] + '...'),
|
||||
"provider": "Exa"
|
||||
"snippet": " ".join(getattr(res, "highlights", [])) or (getattr(res, "text", "")[:250] + "..."),
|
||||
"provider": "Exa",
|
||||
}
|
||||
for res in search_response.results
|
||||
]
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
"""
|
||||
Tavily search engine implementation
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import functools
|
||||
from typing import Dict, List, Any
|
||||
@@ -18,31 +19,29 @@ class TavilySearchEngine(BaseSearchEngine):
|
||||
"""
|
||||
Tavily搜索引擎实现
|
||||
"""
|
||||
|
||||
|
||||
def __init__(self):
|
||||
self._initialize_clients()
|
||||
|
||||
|
||||
def _initialize_clients(self):
|
||||
"""初始化Tavily客户端"""
|
||||
# 从主配置文件读取API密钥
|
||||
tavily_api_keys = config_api.get_global_config("web_search.tavily_api_keys", None)
|
||||
|
||||
|
||||
# 创建API密钥管理器
|
||||
self.api_manager = create_api_key_manager_from_config(
|
||||
tavily_api_keys,
|
||||
lambda key: TavilyClient(api_key=key),
|
||||
"Tavily"
|
||||
tavily_api_keys, lambda key: TavilyClient(api_key=key), "Tavily"
|
||||
)
|
||||
|
||||
|
||||
def is_available(self) -> bool:
|
||||
"""检查Tavily搜索引擎是否可用"""
|
||||
return self.api_manager.is_available()
|
||||
|
||||
|
||||
async def search(self, args: Dict[str, Any]) -> List[Dict[str, Any]]:
|
||||
"""执行Tavily搜索"""
|
||||
if not self.is_available():
|
||||
return []
|
||||
|
||||
|
||||
query = args["query"]
|
||||
num_results = args.get("num_results", 3)
|
||||
time_range = args.get("time_range", "any")
|
||||
@@ -53,38 +52,40 @@ class TavilySearchEngine(BaseSearchEngine):
|
||||
if not tavily_client:
|
||||
logger.error("无法获取Tavily客户端")
|
||||
return []
|
||||
|
||||
|
||||
# 构建Tavily搜索参数
|
||||
search_params = {
|
||||
"query": query,
|
||||
"max_results": num_results,
|
||||
"search_depth": "basic",
|
||||
"include_answer": False,
|
||||
"include_raw_content": False
|
||||
"include_raw_content": False,
|
||||
}
|
||||
|
||||
|
||||
# 根据时间范围调整搜索参数
|
||||
if time_range == "week":
|
||||
search_params["days"] = 7
|
||||
elif time_range == "month":
|
||||
search_params["days"] = 30
|
||||
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
func = functools.partial(tavily_client.search, **search_params)
|
||||
search_response = await loop.run_in_executor(None, func)
|
||||
|
||||
|
||||
results = []
|
||||
if search_response and "results" in search_response:
|
||||
for res in search_response["results"]:
|
||||
results.append({
|
||||
"title": res.get("title", "无标题"),
|
||||
"url": res.get("url", ""),
|
||||
"snippet": res.get("content", "")[:300] + "..." if res.get("content") else "无摘要",
|
||||
"provider": "Tavily"
|
||||
})
|
||||
|
||||
results.append(
|
||||
{
|
||||
"title": res.get("title", "无标题"),
|
||||
"url": res.get("url", ""),
|
||||
"snippet": res.get("content", "")[:300] + "..." if res.get("content") else "无摘要",
|
||||
"provider": "Tavily",
|
||||
}
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Tavily 搜索失败: {e}")
|
||||
return []
|
||||
|
||||
@@ -3,15 +3,10 @@ Web Search Tool Plugin
|
||||
|
||||
一个功能强大的网络搜索和URL解析插件,支持多种搜索引擎和解析策略。
|
||||
"""
|
||||
|
||||
from typing import List, Tuple, Type
|
||||
|
||||
from src.plugin_system import (
|
||||
BasePlugin,
|
||||
register_plugin,
|
||||
ComponentInfo,
|
||||
ConfigField,
|
||||
PythonDependency
|
||||
)
|
||||
from src.plugin_system import BasePlugin, register_plugin, ComponentInfo, ConfigField, PythonDependency
|
||||
from src.plugin_system.apis import config_api
|
||||
from src.common.logger import get_logger
|
||||
|
||||
@@ -25,7 +20,7 @@ logger = get_logger("web_search_plugin")
|
||||
class WEBSEARCHPLUGIN(BasePlugin):
|
||||
"""
|
||||
网络搜索工具插件
|
||||
|
||||
|
||||
提供网络搜索和URL解析功能,支持多种搜索引擎:
|
||||
- Exa (需要API密钥)
|
||||
- Tavily (需要API密钥)
|
||||
@@ -37,11 +32,11 @@ class WEBSEARCHPLUGIN(BasePlugin):
|
||||
plugin_name: str = "web_search_tool" # 内部标识符
|
||||
enable_plugin: bool = True
|
||||
dependencies: List[str] = [] # 插件依赖列表
|
||||
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
"""初始化插件,立即加载所有搜索引擎"""
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
|
||||
# 立即初始化所有搜索引擎,触发API密钥管理器的日志输出
|
||||
logger.info("🚀 正在初始化所有搜索引擎...")
|
||||
try:
|
||||
@@ -49,65 +44,58 @@ class WEBSEARCHPLUGIN(BasePlugin):
|
||||
from .engines.tavily_engine import TavilySearchEngine
|
||||
from .engines.ddg_engine import DDGSearchEngine
|
||||
from .engines.bing_engine import BingSearchEngine
|
||||
|
||||
|
||||
# 实例化所有搜索引擎,这会触发API密钥管理器的初始化
|
||||
exa_engine = ExaSearchEngine()
|
||||
tavily_engine = TavilySearchEngine()
|
||||
ddg_engine = DDGSearchEngine()
|
||||
bing_engine = BingSearchEngine()
|
||||
|
||||
|
||||
# 报告每个引擎的状态
|
||||
engines_status = {
|
||||
"Exa": exa_engine.is_available(),
|
||||
"Tavily": tavily_engine.is_available(),
|
||||
"DuckDuckGo": ddg_engine.is_available(),
|
||||
"Bing": bing_engine.is_available()
|
||||
"Bing": bing_engine.is_available(),
|
||||
}
|
||||
|
||||
|
||||
available_engines = [name for name, available in engines_status.items() if available]
|
||||
unavailable_engines = [name for name, available in engines_status.items() if not available]
|
||||
|
||||
|
||||
if available_engines:
|
||||
logger.info(f"✅ 可用搜索引擎: {', '.join(available_engines)}")
|
||||
if unavailable_engines:
|
||||
logger.info(f"❌ 不可用搜索引擎: {', '.join(unavailable_engines)}")
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 搜索引擎初始化失败: {e}", exc_info=True)
|
||||
|
||||
|
||||
# Python包依赖列表
|
||||
python_dependencies: List[PythonDependency] = [
|
||||
PythonDependency(
|
||||
package_name="asyncddgs",
|
||||
description="异步DuckDuckGo搜索库",
|
||||
optional=False
|
||||
),
|
||||
PythonDependency(package_name="asyncddgs", description="异步DuckDuckGo搜索库", optional=False),
|
||||
PythonDependency(
|
||||
package_name="exa_py",
|
||||
description="Exa搜索API客户端库",
|
||||
optional=True # 如果没有API密钥,这个是可选的
|
||||
optional=True, # 如果没有API密钥,这个是可选的
|
||||
),
|
||||
PythonDependency(
|
||||
package_name="tavily",
|
||||
install_name="tavily-python", # 安装时使用这个名称
|
||||
description="Tavily搜索API客户端库",
|
||||
optional=True # 如果没有API密钥,这个是可选的
|
||||
optional=True, # 如果没有API密钥,这个是可选的
|
||||
),
|
||||
PythonDependency(
|
||||
package_name="httpx",
|
||||
version=">=0.20.0",
|
||||
install_name="httpx[socks]", # 安装时使用这个名称(包含可选依赖)
|
||||
description="支持SOCKS代理的HTTP客户端库",
|
||||
optional=False
|
||||
)
|
||||
optional=False,
|
||||
),
|
||||
]
|
||||
config_file_name: str = "config.toml" # 配置文件名
|
||||
|
||||
# 配置节描述
|
||||
config_section_descriptions = {
|
||||
"plugin": "插件基本信息",
|
||||
"proxy": "链接本地解析代理配置"
|
||||
}
|
||||
config_section_descriptions = {"plugin": "插件基本信息", "proxy": "链接本地解析代理配置"}
|
||||
|
||||
# 配置Schema定义
|
||||
# 注意:EXA配置和组件设置已迁移到主配置文件(bot_config.toml)的[exa]和[web_search]部分
|
||||
@@ -119,42 +107,32 @@ class WEBSEARCHPLUGIN(BasePlugin):
|
||||
},
|
||||
"proxy": {
|
||||
"http_proxy": ConfigField(
|
||||
type=str,
|
||||
default=None,
|
||||
description="HTTP代理地址,格式如: http://proxy.example.com:8080"
|
||||
type=str, default=None, description="HTTP代理地址,格式如: http://proxy.example.com:8080"
|
||||
),
|
||||
"https_proxy": ConfigField(
|
||||
type=str,
|
||||
default=None,
|
||||
description="HTTPS代理地址,格式如: http://proxy.example.com:8080"
|
||||
type=str, default=None, description="HTTPS代理地址,格式如: http://proxy.example.com:8080"
|
||||
),
|
||||
"socks5_proxy": ConfigField(
|
||||
type=str,
|
||||
default=None,
|
||||
description="SOCKS5代理地址,格式如: socks5://proxy.example.com:1080"
|
||||
type=str, default=None, description="SOCKS5代理地址,格式如: socks5://proxy.example.com:1080"
|
||||
),
|
||||
"enable_proxy": ConfigField(
|
||||
type=bool,
|
||||
default=False,
|
||||
description="是否启用代理"
|
||||
)
|
||||
"enable_proxy": ConfigField(type=bool, default=False, description="是否启用代理"),
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def get_plugin_components(self) -> List[Tuple[ComponentInfo, Type]]:
|
||||
"""
|
||||
获取插件组件列表
|
||||
|
||||
|
||||
Returns:
|
||||
组件信息和类型的元组列表
|
||||
"""
|
||||
enable_tool = []
|
||||
|
||||
|
||||
# 从主配置文件读取组件启用配置
|
||||
if config_api.get_global_config("web_search.enable_web_search_tool", True):
|
||||
enable_tool.append((WebSurfingTool.get_tool_info(), WebSurfingTool))
|
||||
|
||||
|
||||
if config_api.get_global_config("web_search.enable_url_tool", True):
|
||||
enable_tool.append((URLParserTool.get_tool_info(), URLParserTool))
|
||||
|
||||
|
||||
return enable_tool
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
"""
|
||||
URL parser tool implementation
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import functools
|
||||
from typing import Any, Dict
|
||||
@@ -24,17 +25,18 @@ class URLParserTool(BaseTool):
|
||||
"""
|
||||
一个用于解析和总结一个或多个网页URL内容的工具。
|
||||
"""
|
||||
|
||||
name: str = "parse_url"
|
||||
description: str = "当需要理解一个或多个特定网页链接的内容时,使用此工具。例如:'这些网页讲了什么?[https://example.com, https://example2.com]' 或 '帮我总结一下这些文章'"
|
||||
available_for_llm: bool = True
|
||||
parameters = [
|
||||
("urls", ToolParamType.STRING, "要理解的网站", True, None),
|
||||
]
|
||||
|
||||
|
||||
def __init__(self, plugin_config=None):
|
||||
super().__init__(plugin_config)
|
||||
self._initialize_exa_clients()
|
||||
|
||||
|
||||
def _initialize_exa_clients(self):
|
||||
"""初始化Exa客户端"""
|
||||
# 优先从主配置文件读取,如果没有则从插件配置文件读取
|
||||
@@ -42,12 +44,10 @@ class URLParserTool(BaseTool):
|
||||
if exa_api_keys is None:
|
||||
# 从插件配置文件读取
|
||||
exa_api_keys = self.get_config("exa.api_keys", [])
|
||||
|
||||
|
||||
# 创建API密钥管理器
|
||||
self.api_manager = create_api_key_manager_from_config(
|
||||
exa_api_keys,
|
||||
lambda key: Exa(api_key=key),
|
||||
"Exa URL Parser"
|
||||
exa_api_keys, lambda key: Exa(api_key=key), "Exa URL Parser"
|
||||
)
|
||||
|
||||
async def _local_parse_and_summarize(self, url: str) -> Dict[str, Any]:
|
||||
@@ -58,12 +58,12 @@ class URLParserTool(BaseTool):
|
||||
# 读取代理配置
|
||||
enable_proxy = self.get_config("proxy.enable_proxy", False)
|
||||
proxies = None
|
||||
|
||||
|
||||
if enable_proxy:
|
||||
socks5_proxy = self.get_config("proxy.socks5_proxy", None)
|
||||
http_proxy = self.get_config("proxy.http_proxy", None)
|
||||
https_proxy = self.get_config("proxy.https_proxy", None)
|
||||
|
||||
|
||||
# 优先使用SOCKS5代理(全协议代理)
|
||||
if socks5_proxy:
|
||||
proxies = socks5_proxy
|
||||
@@ -75,17 +75,17 @@ class URLParserTool(BaseTool):
|
||||
if https_proxy:
|
||||
proxies["https://"] = https_proxy
|
||||
logger.info(f"使用HTTP/HTTPS代理配置: {proxies}")
|
||||
|
||||
|
||||
client_kwargs = {"timeout": 15.0, "follow_redirects": True}
|
||||
if proxies:
|
||||
client_kwargs["proxies"] = proxies
|
||||
|
||||
|
||||
async with httpx.AsyncClient(**client_kwargs) as client:
|
||||
response = await client.get(url)
|
||||
response.raise_for_status()
|
||||
|
||||
soup = BeautifulSoup(response.text, "html.parser")
|
||||
|
||||
|
||||
title = soup.title.string if soup.title else "无标题"
|
||||
for script in soup(["script", "style"]):
|
||||
script.extract()
|
||||
@@ -104,12 +104,12 @@ class URLParserTool(BaseTool):
|
||||
return {"error": "未配置LLM模型"}
|
||||
|
||||
success, summary, reasoning, model_name = await llm_api.generate_with_model(
|
||||
prompt=summary_prompt,
|
||||
model_config=model_config,
|
||||
request_type="story.generate",
|
||||
temperature=0.3,
|
||||
max_tokens=1000
|
||||
)
|
||||
prompt=summary_prompt,
|
||||
model_config=model_config,
|
||||
request_type="story.generate",
|
||||
temperature=0.3,
|
||||
max_tokens=1000,
|
||||
)
|
||||
|
||||
if not success:
|
||||
logger.info(f"生成摘要失败: {summary}")
|
||||
@@ -117,12 +117,7 @@ class URLParserTool(BaseTool):
|
||||
|
||||
logger.info(f"成功生成摘要内容:'{summary}'")
|
||||
|
||||
return {
|
||||
"title": title,
|
||||
"url": url,
|
||||
"snippet": summary,
|
||||
"source": "local"
|
||||
}
|
||||
return {"title": title, "url": url, "snippet": summary, "source": "local"}
|
||||
|
||||
except httpx.HTTPStatusError as e:
|
||||
logger.warning(f"本地解析URL '{url}' 失败 (HTTP {e.response.status_code})")
|
||||
@@ -137,6 +132,7 @@ class URLParserTool(BaseTool):
|
||||
"""
|
||||
# 获取当前文件路径用于缓存键
|
||||
import os
|
||||
|
||||
current_file_path = os.path.abspath(__file__)
|
||||
|
||||
# 检查缓存
|
||||
@@ -144,7 +140,7 @@ class URLParserTool(BaseTool):
|
||||
if cached_result:
|
||||
logger.info(f"缓存命中: {self.name} -> {function_args}")
|
||||
return cached_result
|
||||
|
||||
|
||||
urls_input = function_args.get("urls")
|
||||
if not urls_input:
|
||||
return {"error": "URL列表不能为空。"}
|
||||
@@ -158,14 +154,14 @@ class URLParserTool(BaseTool):
|
||||
valid_urls = validate_urls(urls)
|
||||
if not valid_urls:
|
||||
return {"error": "未找到有效的URL。"}
|
||||
|
||||
|
||||
urls = valid_urls
|
||||
logger.info(f"准备解析 {len(urls)} 个URL: {urls}")
|
||||
|
||||
successful_results = []
|
||||
error_messages = []
|
||||
urls_to_retry_locally = []
|
||||
|
||||
|
||||
# 步骤 1: 尝试使用 Exa API 进行解析
|
||||
contents_response = None
|
||||
if self.api_manager.is_available():
|
||||
@@ -182,41 +178,45 @@ class URLParserTool(BaseTool):
|
||||
contents_response = await loop.run_in_executor(None, func)
|
||||
except Exception as e:
|
||||
logger.error(f"执行 Exa URL解析时发生严重异常: {e}", exc_info=True)
|
||||
contents_response = None # 确保异常后为None
|
||||
contents_response = None # 确保异常后为None
|
||||
|
||||
# 步骤 2: 处理Exa的响应
|
||||
if contents_response and hasattr(contents_response, 'statuses'):
|
||||
results_map = {res.url: res for res in contents_response.results} if hasattr(contents_response, 'results') else {}
|
||||
if contents_response and hasattr(contents_response, "statuses"):
|
||||
results_map = (
|
||||
{res.url: res for res in contents_response.results} if hasattr(contents_response, "results") else {}
|
||||
)
|
||||
if contents_response.statuses:
|
||||
for status in contents_response.statuses:
|
||||
if status.status == 'success':
|
||||
if status.status == "success":
|
||||
res = results_map.get(status.id)
|
||||
if res:
|
||||
summary = getattr(res, 'summary', '')
|
||||
highlights = " ".join(getattr(res, 'highlights', []))
|
||||
text_snippet = (getattr(res, 'text', '')[:300] + '...') if getattr(res, 'text', '') else ''
|
||||
snippet = summary or highlights or text_snippet or '无摘要'
|
||||
|
||||
successful_results.append({
|
||||
"title": getattr(res, 'title', '无标题'),
|
||||
"url": getattr(res, 'url', status.id),
|
||||
"snippet": snippet,
|
||||
"source": "exa"
|
||||
})
|
||||
summary = getattr(res, "summary", "")
|
||||
highlights = " ".join(getattr(res, "highlights", []))
|
||||
text_snippet = (getattr(res, "text", "")[:300] + "...") if getattr(res, "text", "") else ""
|
||||
snippet = summary or highlights or text_snippet or "无摘要"
|
||||
|
||||
successful_results.append(
|
||||
{
|
||||
"title": getattr(res, "title", "无标题"),
|
||||
"url": getattr(res, "url", status.id),
|
||||
"snippet": snippet,
|
||||
"source": "exa",
|
||||
}
|
||||
)
|
||||
else:
|
||||
error_tag = getattr(status, 'error', '未知错误')
|
||||
error_tag = getattr(status, "error", "未知错误")
|
||||
logger.warning(f"Exa解析URL '{status.id}' 失败: {error_tag}。准备本地重试。")
|
||||
urls_to_retry_locally.append(status.id)
|
||||
else:
|
||||
# 如果Exa未配置、API调用失败或返回无效响应,则所有URL都进入本地重试
|
||||
urls_to_retry_locally.extend(url for url in urls if url not in [res['url'] for res in successful_results])
|
||||
urls_to_retry_locally.extend(url for url in urls if url not in [res["url"] for res in successful_results])
|
||||
|
||||
# 步骤 3: 对失败的URL进行本地解析
|
||||
if urls_to_retry_locally:
|
||||
logger.info(f"开始本地解析以下URL: {urls_to_retry_locally}")
|
||||
local_tasks = [self._local_parse_and_summarize(url) for url in urls_to_retry_locally]
|
||||
local_results = await asyncio.gather(*local_tasks)
|
||||
|
||||
|
||||
for i, res in enumerate(local_results):
|
||||
url = urls_to_retry_locally[i]
|
||||
if "error" in res:
|
||||
@@ -228,13 +228,9 @@ class URLParserTool(BaseTool):
|
||||
return {"error": "无法从所有给定的URL获取内容。", "details": error_messages}
|
||||
|
||||
formatted_content = format_url_parse_results(successful_results)
|
||||
|
||||
result = {
|
||||
"type": "url_parse_result",
|
||||
"content": formatted_content,
|
||||
"errors": error_messages
|
||||
}
|
||||
|
||||
|
||||
result = {"type": "url_parse_result", "content": formatted_content, "errors": error_messages}
|
||||
|
||||
# 保存到缓存
|
||||
if "error" not in result:
|
||||
await tool_cache.set(self.name, function_args, current_file_path, result)
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
"""
|
||||
Web search tool implementation
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from typing import Any, Dict, List
|
||||
|
||||
@@ -22,14 +23,23 @@ class WebSurfingTool(BaseTool):
|
||||
"""
|
||||
网络搜索工具
|
||||
"""
|
||||
|
||||
name: str = "web_search"
|
||||
description: str = "用于执行网络搜索。当用户明确要求搜索,或者需要获取关于公司、产品、事件的最新信息、新闻或动态时,必须使用此工具"
|
||||
description: str = (
|
||||
"用于执行网络搜索。当用户明确要求搜索,或者需要获取关于公司、产品、事件的最新信息、新闻或动态时,必须使用此工具"
|
||||
)
|
||||
available_for_llm: bool = True
|
||||
parameters = [
|
||||
("query", ToolParamType.STRING, "要搜索的关键词或问题。", True, None),
|
||||
("num_results", ToolParamType.INTEGER, "期望每个搜索引擎返回的搜索结果数量,默认为5。", False, None),
|
||||
("time_range", ToolParamType.STRING, "指定搜索的时间范围,可以是 'any', 'week', 'month'。默认为 'any'。", False, ["any", "week", "month"])
|
||||
] # type: ignore
|
||||
(
|
||||
"time_range",
|
||||
ToolParamType.STRING,
|
||||
"指定搜索的时间范围,可以是 'any', 'week', 'month'。默认为 'any'。",
|
||||
False,
|
||||
["any", "week", "month"],
|
||||
),
|
||||
] # type: ignore
|
||||
|
||||
def __init__(self, plugin_config=None):
|
||||
super().__init__(plugin_config)
|
||||
@@ -38,7 +48,7 @@ class WebSurfingTool(BaseTool):
|
||||
"exa": ExaSearchEngine(),
|
||||
"tavily": TavilySearchEngine(),
|
||||
"ddg": DDGSearchEngine(),
|
||||
"bing": BingSearchEngine()
|
||||
"bing": BingSearchEngine(),
|
||||
}
|
||||
|
||||
async def execute(self, function_args: Dict[str, Any]) -> Dict[str, Any]:
|
||||
@@ -48,6 +58,7 @@ class WebSurfingTool(BaseTool):
|
||||
|
||||
# 获取当前文件路径用于缓存键
|
||||
import os
|
||||
|
||||
current_file_path = os.path.abspath(__file__)
|
||||
|
||||
# 检查缓存
|
||||
@@ -59,7 +70,7 @@ class WebSurfingTool(BaseTool):
|
||||
# 读取搜索配置
|
||||
enabled_engines = config_api.get_global_config("web_search.enabled_engines", ["ddg"])
|
||||
search_strategy = config_api.get_global_config("web_search.search_strategy", "single")
|
||||
|
||||
|
||||
logger.info(f"开始搜索,策略: {search_strategy}, 启用引擎: {enabled_engines}, 参数: '{function_args}'")
|
||||
|
||||
# 根据策略执行搜索
|
||||
@@ -69,17 +80,19 @@ class WebSurfingTool(BaseTool):
|
||||
result = await self._execute_fallback_search(function_args, enabled_engines)
|
||||
else: # single
|
||||
result = await self._execute_single_search(function_args, enabled_engines)
|
||||
|
||||
|
||||
# 保存到缓存
|
||||
if "error" not in result:
|
||||
await tool_cache.set(self.name, function_args, current_file_path, result, semantic_query=query)
|
||||
|
||||
|
||||
return result
|
||||
|
||||
async def _execute_parallel_search(self, function_args: Dict[str, Any], enabled_engines: List[str]) -> Dict[str, Any]:
|
||||
async def _execute_parallel_search(
|
||||
self, function_args: Dict[str, Any], enabled_engines: List[str]
|
||||
) -> Dict[str, Any]:
|
||||
"""并行搜索策略:同时使用所有启用的搜索引擎"""
|
||||
search_tasks = []
|
||||
|
||||
|
||||
for engine_name in enabled_engines:
|
||||
engine = self.engines.get(engine_name)
|
||||
if engine and engine.is_available():
|
||||
@@ -92,7 +105,7 @@ class WebSurfingTool(BaseTool):
|
||||
|
||||
try:
|
||||
search_results_lists = await asyncio.gather(*search_tasks, return_exceptions=True)
|
||||
|
||||
|
||||
all_results = []
|
||||
for result in search_results_lists:
|
||||
if isinstance(result, list):
|
||||
@@ -103,7 +116,7 @@ class WebSurfingTool(BaseTool):
|
||||
# 去重并格式化
|
||||
unique_results = deduplicate_results(all_results)
|
||||
formatted_content = format_search_results(unique_results)
|
||||
|
||||
|
||||
return {
|
||||
"type": "web_search_result",
|
||||
"content": formatted_content,
|
||||
@@ -113,30 +126,32 @@ class WebSurfingTool(BaseTool):
|
||||
logger.error(f"执行并行网络搜索时发生异常: {e}", exc_info=True)
|
||||
return {"error": f"执行网络搜索时发生严重错误: {str(e)}"}
|
||||
|
||||
async def _execute_fallback_search(self, function_args: Dict[str, Any], enabled_engines: List[str]) -> Dict[str, Any]:
|
||||
async def _execute_fallback_search(
|
||||
self, function_args: Dict[str, Any], enabled_engines: List[str]
|
||||
) -> Dict[str, Any]:
|
||||
"""回退搜索策略:按顺序尝试搜索引擎,失败则尝试下一个"""
|
||||
for engine_name in enabled_engines:
|
||||
engine = self.engines.get(engine_name)
|
||||
if not engine or not engine.is_available():
|
||||
continue
|
||||
|
||||
|
||||
try:
|
||||
custom_args = function_args.copy()
|
||||
custom_args["num_results"] = custom_args.get("num_results", 5)
|
||||
|
||||
|
||||
results = await engine.search(custom_args)
|
||||
|
||||
|
||||
if results: # 如果有结果,直接返回
|
||||
formatted_content = format_search_results(results)
|
||||
return {
|
||||
"type": "web_search_result",
|
||||
"content": formatted_content,
|
||||
}
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"{engine_name} 搜索失败,尝试下一个引擎: {e}")
|
||||
continue
|
||||
|
||||
|
||||
return {"error": "所有搜索引擎都失败了。"}
|
||||
|
||||
async def _execute_single_search(self, function_args: Dict[str, Any], enabled_engines: List[str]) -> Dict[str, Any]:
|
||||
@@ -145,20 +160,20 @@ class WebSurfingTool(BaseTool):
|
||||
engine = self.engines.get(engine_name)
|
||||
if not engine or not engine.is_available():
|
||||
continue
|
||||
|
||||
|
||||
try:
|
||||
custom_args = function_args.copy()
|
||||
custom_args["num_results"] = custom_args.get("num_results", 5)
|
||||
|
||||
|
||||
results = await engine.search(custom_args)
|
||||
formatted_content = format_search_results(results)
|
||||
return {
|
||||
"type": "web_search_result",
|
||||
"content": formatted_content,
|
||||
}
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{engine_name} 搜索失败: {e}")
|
||||
return {"error": f"{engine_name} 搜索失败: {str(e)}"}
|
||||
|
||||
|
||||
return {"error": "没有可用的搜索引擎。"}
|
||||
|
||||
@@ -1,24 +1,25 @@
|
||||
"""
|
||||
API密钥管理器,提供轮询机制
|
||||
"""
|
||||
|
||||
import itertools
|
||||
from typing import List, Optional, TypeVar, Generic, Callable
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("api_key_manager")
|
||||
|
||||
T = TypeVar('T')
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class APIKeyManager(Generic[T]):
|
||||
"""
|
||||
API密钥管理器,支持轮询机制
|
||||
"""
|
||||
|
||||
|
||||
def __init__(self, api_keys: List[str], client_factory: Callable[[str], T], service_name: str = "Unknown"):
|
||||
"""
|
||||
初始化API密钥管理器
|
||||
|
||||
|
||||
Args:
|
||||
api_keys: API密钥列表
|
||||
client_factory: 客户端工厂函数,接受API密钥参数并返回客户端实例
|
||||
@@ -27,14 +28,14 @@ class APIKeyManager(Generic[T]):
|
||||
self.service_name = service_name
|
||||
self.clients: List[T] = []
|
||||
self.client_cycle: Optional[itertools.cycle] = None
|
||||
|
||||
|
||||
if api_keys:
|
||||
# 过滤有效的API密钥,排除None、空字符串、"None"字符串等
|
||||
valid_keys = []
|
||||
for key in api_keys:
|
||||
if isinstance(key, str) and key.strip() and key.strip().lower() not in ("none", "null", ""):
|
||||
valid_keys.append(key.strip())
|
||||
|
||||
|
||||
if valid_keys:
|
||||
try:
|
||||
self.clients = [client_factory(key) for key in valid_keys]
|
||||
@@ -48,35 +49,33 @@ class APIKeyManager(Generic[T]):
|
||||
logger.warning(f"⚠️ {service_name} API Keys 配置无效(包含None或空值),{service_name} 功能将不可用")
|
||||
else:
|
||||
logger.warning(f"⚠️ {service_name} API Keys 未配置,{service_name} 功能将不可用")
|
||||
|
||||
|
||||
def is_available(self) -> bool:
|
||||
"""检查是否有可用的客户端"""
|
||||
return bool(self.clients and self.client_cycle)
|
||||
|
||||
|
||||
def get_next_client(self) -> Optional[T]:
|
||||
"""获取下一个客户端(轮询)"""
|
||||
if not self.is_available():
|
||||
return None
|
||||
return next(self.client_cycle)
|
||||
|
||||
|
||||
def get_client_count(self) -> int:
|
||||
"""获取可用客户端数量"""
|
||||
return len(self.clients)
|
||||
|
||||
|
||||
def create_api_key_manager_from_config(
|
||||
config_keys: Optional[List[str]],
|
||||
client_factory: Callable[[str], T],
|
||||
service_name: str
|
||||
config_keys: Optional[List[str]], client_factory: Callable[[str], T], service_name: str
|
||||
) -> APIKeyManager[T]:
|
||||
"""
|
||||
从配置创建API密钥管理器的便捷函数
|
||||
|
||||
|
||||
Args:
|
||||
config_keys: 从配置读取的API密钥列表
|
||||
client_factory: 客户端工厂函数
|
||||
service_name: 服务名称
|
||||
|
||||
|
||||
Returns:
|
||||
API密钥管理器实例
|
||||
"""
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
"""
|
||||
Formatters for web search results
|
||||
"""
|
||||
|
||||
from typing import List, Dict, Any
|
||||
|
||||
|
||||
@@ -13,15 +14,15 @@ def format_search_results(results: List[Dict[str, Any]]) -> str:
|
||||
|
||||
formatted_string = "根据网络搜索结果:\n\n"
|
||||
for i, res in enumerate(results, 1):
|
||||
title = res.get("title", '无标题')
|
||||
url = res.get("url", '#')
|
||||
snippet = res.get("snippet", '无摘要')
|
||||
title = res.get("title", "无标题")
|
||||
url = res.get("url", "#")
|
||||
snippet = res.get("snippet", "无摘要")
|
||||
provider = res.get("provider", "未知来源")
|
||||
|
||||
|
||||
formatted_string += f"{i}. **{title}** (来自: {provider})\n"
|
||||
formatted_string += f" - 摘要: {snippet}\n"
|
||||
formatted_string += f" - 来源: {url}\n\n"
|
||||
|
||||
|
||||
return formatted_string
|
||||
|
||||
|
||||
@@ -31,10 +32,10 @@ def format_url_parse_results(results: List[Dict[str, Any]]) -> str:
|
||||
"""
|
||||
formatted_parts = []
|
||||
for res in results:
|
||||
title = res.get('title', '无标题')
|
||||
url = res.get('url', '#')
|
||||
snippet = res.get('snippet', '无摘要')
|
||||
source = res.get('source', '未知')
|
||||
title = res.get("title", "无标题")
|
||||
url = res.get("url", "#")
|
||||
snippet = res.get("snippet", "无摘要")
|
||||
source = res.get("source", "未知")
|
||||
|
||||
formatted_string = f"**{title}**\n"
|
||||
formatted_string += f"**内容摘要**:\n{snippet}\n"
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
"""
|
||||
URL processing utilities
|
||||
"""
|
||||
|
||||
import re
|
||||
from typing import List
|
||||
|
||||
@@ -12,11 +13,11 @@ def parse_urls_from_input(urls_input) -> List[str]:
|
||||
if isinstance(urls_input, str):
|
||||
# 如果是字符串,尝试解析为URL列表
|
||||
# 提取所有HTTP/HTTPS URL
|
||||
url_pattern = r'https?://[^\s\],]+'
|
||||
url_pattern = r"https?://[^\s\],]+"
|
||||
urls = re.findall(url_pattern, urls_input)
|
||||
if not urls:
|
||||
# 如果没有找到标准URL,将整个字符串作为单个URL
|
||||
if urls_input.strip().startswith(('http://', 'https://')):
|
||||
if urls_input.strip().startswith(("http://", "https://")):
|
||||
urls = [urls_input.strip()]
|
||||
else:
|
||||
return []
|
||||
@@ -24,7 +25,7 @@ def parse_urls_from_input(urls_input) -> List[str]:
|
||||
urls = [url.strip() for url in urls_input if isinstance(url, str) and url.strip()]
|
||||
else:
|
||||
return []
|
||||
|
||||
|
||||
return urls
|
||||
|
||||
|
||||
@@ -34,6 +35,6 @@ def validate_urls(urls: List[str]) -> List[str]:
|
||||
"""
|
||||
valid_urls = []
|
||||
for url in urls:
|
||||
if url.startswith(('http://', 'https://')):
|
||||
if url.startswith(("http://", "https://")):
|
||||
valid_urls.append(url)
|
||||
return valid_urls
|
||||
|
||||
216
src/plugins/reminder_plugin/plugin.py
Normal file
216
src/plugins/reminder_plugin/plugin.py
Normal file
@@ -0,0 +1,216 @@
|
||||
import asyncio
|
||||
from datetime import datetime
|
||||
from typing import List, Tuple, Type
|
||||
from dateutil.parser import parse as parse_datetime
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.manager.async_task_manager import AsyncTask, async_task_manager
|
||||
from src.person_info.person_info import get_person_info_manager
|
||||
from src.plugin_system import (
|
||||
BaseAction,
|
||||
ActionInfo,
|
||||
BasePlugin,
|
||||
register_plugin,
|
||||
ActionActivationType,
|
||||
)
|
||||
from src.plugin_system.apis import send_api
|
||||
from src.plugin_system.base.component_types import ChatType
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
# ============================ AsyncTask ============================
|
||||
|
||||
|
||||
class ReminderTask(AsyncTask):
|
||||
def __init__(
|
||||
self,
|
||||
delay: float,
|
||||
stream_id: str,
|
||||
is_group: bool,
|
||||
target_user_id: str,
|
||||
target_user_name: str,
|
||||
event_details: str,
|
||||
creator_name: str,
|
||||
):
|
||||
super().__init__(task_name=f"ReminderTask_{target_user_id}_{datetime.now().timestamp()}")
|
||||
self.delay = delay
|
||||
self.stream_id = stream_id
|
||||
self.is_group = is_group
|
||||
self.target_user_id = target_user_id
|
||||
self.target_user_name = target_user_name
|
||||
self.event_details = event_details
|
||||
self.creator_name = creator_name
|
||||
|
||||
async def run(self):
|
||||
try:
|
||||
if self.delay > 0:
|
||||
logger.info(f"等待 {self.delay:.2f} 秒后执行提醒...")
|
||||
await asyncio.sleep(self.delay)
|
||||
|
||||
logger.info(f"执行提醒任务: 给 {self.target_user_name} 发送关于 '{self.event_details}' 的提醒")
|
||||
|
||||
reminder_text = f"叮咚!这是 {self.creator_name} 让我准时提醒你的事情:\n\n{self.event_details}"
|
||||
|
||||
if self.is_group:
|
||||
# 在群聊中,构造 @ 消息段并发送
|
||||
group_id = self.stream_id.split("_")[-1] if "_" in self.stream_id else self.stream_id
|
||||
message_payload = [
|
||||
{"type": "at", "data": {"qq": self.target_user_id}},
|
||||
{"type": "text", "data": {"text": f" {reminder_text}"}},
|
||||
]
|
||||
await send_api.adapter_command_to_stream(
|
||||
action="send_group_msg",
|
||||
params={"group_id": group_id, "message": message_payload},
|
||||
stream_id=self.stream_id,
|
||||
)
|
||||
else:
|
||||
# 在私聊中,直接发送文本
|
||||
await send_api.text_to_stream(text=reminder_text, stream_id=self.stream_id)
|
||||
|
||||
logger.info(f"提醒任务 {self.task_name} 成功完成。")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"执行提醒任务 {self.task_name} 时出错: {e}", exc_info=True)
|
||||
|
||||
|
||||
# =============================== Actions ===============================
|
||||
|
||||
|
||||
class RemindAction(BaseAction):
|
||||
"""一个能从对话中智能识别并设置定时提醒的动作。"""
|
||||
|
||||
# === 基本信息 ===
|
||||
action_name = "set_reminder"
|
||||
action_description = "根据用户的对话内容,智能地设置一个未来的提醒事项。"
|
||||
activation_type = ActionActivationType.LLM_JUDGE
|
||||
chat_type_allow = ChatType.ALL
|
||||
|
||||
# === LLM 判断与参数提取 ===
|
||||
llm_judge_prompt = """
|
||||
判断用户是否意图设置一个未来的提醒。
|
||||
- 必须包含明确的时间点或时间段(如“十分钟后”、“明天下午3点”、“周五”)。
|
||||
- 必须包含一个需要被提醒的事件。
|
||||
- 可能会包含需要提醒的特定人物。
|
||||
- 如果只是普通的聊天或询问时间,则不应触发。
|
||||
|
||||
示例:
|
||||
- "半小时后提醒我开会" -> 是
|
||||
- "明天下午三点叫张三来一下" -> 是
|
||||
- "别忘了周五把报告交了" -> 是
|
||||
- "现在几点了?" -> 否
|
||||
- "我明天下午有空" -> 否
|
||||
|
||||
请只回答"是"或"否"。
|
||||
"""
|
||||
action_parameters = {
|
||||
"user_name": "需要被提醒的人的称呼或名字,如果没有明确指定给某人,则默认为'自己'",
|
||||
"remind_time": "描述提醒时间的自然语言字符串,例如'十分钟后'或'明天下午3点'",
|
||||
"event_details": "需要提醒的具体事件内容",
|
||||
}
|
||||
action_require = [
|
||||
"当用户请求在未来的某个时间点提醒他/她或别人某件事时使用",
|
||||
"适用于包含明确时间信息和事件描述的对话",
|
||||
"例如:'10分钟后提醒我收快递'、'明天早上九点喊一下李四参加晨会'",
|
||||
]
|
||||
|
||||
async def execute(self) -> Tuple[bool, str]:
|
||||
"""执行设置提醒的动作"""
|
||||
user_name = self.action_data.get("user_name")
|
||||
remind_time_str = self.action_data.get("remind_time")
|
||||
event_details = self.action_data.get("event_details")
|
||||
|
||||
if not all([user_name, remind_time_str, event_details]):
|
||||
missing_params = [
|
||||
p
|
||||
for p, v in {
|
||||
"user_name": user_name,
|
||||
"remind_time": remind_time_str,
|
||||
"event_details": event_details,
|
||||
}.items()
|
||||
if not v
|
||||
]
|
||||
error_msg = f"缺少必要的提醒参数: {', '.join(missing_params)}"
|
||||
logger.warning(f"[ReminderPlugin] LLM未能提取完整参数: {error_msg}")
|
||||
return False, error_msg
|
||||
|
||||
# 1. 解析时间
|
||||
try:
|
||||
assert isinstance(remind_time_str, str)
|
||||
target_time = parse_datetime(remind_time_str, fuzzy=True)
|
||||
except Exception as e:
|
||||
logger.error(f"[ReminderPlugin] 无法解析时间字符串 '{remind_time_str}': {e}")
|
||||
await self.send_text(f"抱歉,我无法理解您说的时间 '{remind_time_str}',提醒设置失败。")
|
||||
return False, f"无法解析时间 '{remind_time_str}'"
|
||||
|
||||
now = datetime.now()
|
||||
if target_time <= now:
|
||||
await self.send_text("提醒时间必须是一个未来的时间点哦,提醒设置失败。")
|
||||
return False, "提醒时间必须在未来"
|
||||
|
||||
delay_seconds = (target_time - now).total_seconds()
|
||||
|
||||
# 2. 解析用户
|
||||
person_manager = get_person_info_manager()
|
||||
user_id_to_remind = None
|
||||
user_name_to_remind = ""
|
||||
|
||||
assert isinstance(user_name, str)
|
||||
|
||||
if user_name.strip() in ["自己", "我", "me"]:
|
||||
user_id_to_remind = self.user_id
|
||||
user_name_to_remind = self.user_nickname
|
||||
else:
|
||||
user_info = await person_manager.get_person_info_by_name(user_name)
|
||||
if not user_info or not user_info.get("user_id"):
|
||||
logger.warning(f"[ReminderPlugin] 找不到名为 '{user_name}' 的用户")
|
||||
await self.send_text(f"抱歉,我的联系人里找不到叫做 '{user_name}' 的人,提醒设置失败。")
|
||||
return False, f"用户 '{user_name}' 不存在"
|
||||
user_id_to_remind = user_info.get("user_id")
|
||||
user_name_to_remind = user_name
|
||||
|
||||
# 3. 创建并调度异步任务
|
||||
try:
|
||||
assert user_id_to_remind is not None
|
||||
assert event_details is not None
|
||||
|
||||
reminder_task = ReminderTask(
|
||||
delay=delay_seconds,
|
||||
stream_id=self.chat_id,
|
||||
is_group=self.is_group,
|
||||
target_user_id=str(user_id_to_remind),
|
||||
target_user_name=str(user_name_to_remind),
|
||||
event_details=str(event_details),
|
||||
creator_name=str(self.user_nickname),
|
||||
)
|
||||
await async_task_manager.add_task(reminder_task)
|
||||
|
||||
# 4. 发送确认消息
|
||||
confirm_message = f"好的,我记下了。\n将在 {target_time.strftime('%Y-%m-%d %H:%M:%S')} 提醒 {user_name_to_remind}:\n{event_details}"
|
||||
await self.send_text(confirm_message)
|
||||
|
||||
return True, "提醒设置成功"
|
||||
except Exception as e:
|
||||
logger.error(f"[ReminderPlugin] 创建提醒任务时出错: {e}", exc_info=True)
|
||||
await self.send_text("抱歉,设置提醒时发生了一点内部错误。")
|
||||
return False, "设置提醒时发生内部错误"
|
||||
|
||||
|
||||
# =============================== Plugin ===============================
|
||||
|
||||
|
||||
@register_plugin
|
||||
class ReminderPlugin(BasePlugin):
|
||||
"""一个能从对话中智能识别并设置定时提醒的插件。"""
|
||||
|
||||
# --- 插件基础信息 ---
|
||||
plugin_name = "reminder_plugin"
|
||||
enable_plugin = True
|
||||
dependencies = []
|
||||
python_dependencies = []
|
||||
config_file_name = "config.toml"
|
||||
config_schema = {}
|
||||
|
||||
def get_plugin_components(self) -> List[Tuple[ActionInfo, Type[BaseAction]]]:
|
||||
"""注册插件的所有功能组件。"""
|
||||
return [(RemindAction.get_action_info(), RemindAction)]
|
||||
Reference in New Issue
Block a user