Merge branch 'dev' of https://github.com/MoFox-Studio/MoFox_Bot into dev
This commit is contained in:
@@ -7,15 +7,15 @@ import asyncio
|
||||
import time
|
||||
import traceback
|
||||
from datetime import datetime
|
||||
from typing import Dict, Any
|
||||
from typing import Any
|
||||
|
||||
from src.chat.express.expression_learner import expression_learner_manager
|
||||
from src.chat.planner_actions.action_manager import ChatterActionManager
|
||||
from src.common.data_models.message_manager_data_model import StreamContext
|
||||
from src.common.logger import get_logger
|
||||
from src.plugin_system.base.base_chatter import BaseChatter
|
||||
from src.plugin_system.base.component_types import ChatType
|
||||
from src.common.data_models.message_manager_data_model import StreamContext
|
||||
from src.plugins.built_in.affinity_flow_chatter.planner import ChatterActionPlanner
|
||||
from src.chat.planner_actions.action_manager import ChatterActionManager
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.express.expression_learner import expression_learner_manager
|
||||
|
||||
logger = get_logger("affinity_chatter")
|
||||
|
||||
@@ -113,7 +113,7 @@ class AffinityChatter(BaseChatter):
|
||||
"executed_count": 0,
|
||||
}
|
||||
|
||||
def get_stats(self) -> Dict[str, Any]:
|
||||
def get_stats(self) -> dict[str, Any]:
|
||||
"""
|
||||
获取处理器统计信息
|
||||
|
||||
@@ -122,7 +122,7 @@ class AffinityChatter(BaseChatter):
|
||||
"""
|
||||
return self.stats.copy()
|
||||
|
||||
def get_planner_stats(self) -> Dict[str, Any]:
|
||||
def get_planner_stats(self) -> dict[str, Any]:
|
||||
"""
|
||||
获取规划器统计信息
|
||||
|
||||
@@ -131,7 +131,7 @@ class AffinityChatter(BaseChatter):
|
||||
"""
|
||||
return self.planner.get_planner_stats()
|
||||
|
||||
def get_interest_scoring_stats(self) -> Dict[str, Any]:
|
||||
def get_interest_scoring_stats(self) -> dict[str, Any]:
|
||||
"""
|
||||
获取兴趣度评分统计信息
|
||||
|
||||
@@ -140,7 +140,7 @@ class AffinityChatter(BaseChatter):
|
||||
"""
|
||||
return self.planner.get_interest_scoring_stats()
|
||||
|
||||
def get_relationship_stats(self) -> Dict[str, Any]:
|
||||
def get_relationship_stats(self) -> dict[str, Any]:
|
||||
"""
|
||||
获取用户关系统计信息
|
||||
|
||||
@@ -158,7 +158,7 @@ class AffinityChatter(BaseChatter):
|
||||
"""
|
||||
return self.planner.get_current_mood_state()
|
||||
|
||||
def get_mood_stats(self) -> Dict[str, Any]:
|
||||
def get_mood_stats(self) -> dict[str, Any]:
|
||||
"""
|
||||
获取情绪状态统计信息
|
||||
|
||||
|
||||
@@ -5,14 +5,14 @@
|
||||
"""
|
||||
|
||||
import traceback
|
||||
from typing import Dict, List, Any
|
||||
from typing import Any
|
||||
|
||||
from src.chat.interest_system import bot_interest_manager
|
||||
from src.common.data_models.database_data_model import DatabaseMessages
|
||||
from src.common.data_models.info_data_model import InterestScore
|
||||
from src.chat.interest_system import bot_interest_manager
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.plugins.built_in.affinity_flow_chatter.relationship_tracker import ChatterRelationshipTracker
|
||||
|
||||
logger = get_logger("chatter_interest_scoring")
|
||||
|
||||
# 定义颜色
|
||||
@@ -45,13 +45,13 @@ class ChatterInterestScoringSystem:
|
||||
self.probability_boost_per_no_reply = (
|
||||
affinity_config.no_reply_threshold_adjustment / affinity_config.max_no_reply_count
|
||||
) # 每次不回复增加的概率
|
||||
|
||||
|
||||
# 用户关系数据
|
||||
self.user_relationships: Dict[str, float] = {} # user_id -> relationship_score
|
||||
self.user_relationships: dict[str, float] = {} # user_id -> relationship_score
|
||||
|
||||
async def calculate_interest_scores(
|
||||
self, messages: List[DatabaseMessages], bot_nickname: str
|
||||
) -> List[InterestScore]:
|
||||
self, messages: list[DatabaseMessages], bot_nickname: str
|
||||
) -> list[InterestScore]:
|
||||
"""计算消息的兴趣度评分"""
|
||||
user_messages = [msg for msg in messages if str(msg.user_info.user_id) != str(global_config.bot.qq_account)]
|
||||
if not user_messages:
|
||||
@@ -97,7 +97,7 @@ class ChatterInterestScoringSystem:
|
||||
details=details,
|
||||
)
|
||||
|
||||
async def _calculate_interest_match_score(self, content: str, keywords: List[str] = None) -> float:
|
||||
async def _calculate_interest_match_score(self, content: str, keywords: list[str] = None) -> float:
|
||||
"""计算兴趣匹配度 - 使用智能embedding匹配"""
|
||||
if not content:
|
||||
return 0.0
|
||||
@@ -109,7 +109,7 @@ class ChatterInterestScoringSystem:
|
||||
# 智能匹配未初始化,返回默认分数
|
||||
return 0.3
|
||||
|
||||
async def _calculate_smart_interest_match(self, content: str, keywords: List[str] = None) -> float:
|
||||
async def _calculate_smart_interest_match(self, content: str, keywords: list[str] = None) -> float:
|
||||
"""使用embedding计算智能兴趣匹配"""
|
||||
try:
|
||||
# 如果没有传入关键词,则提取
|
||||
@@ -134,7 +134,7 @@ class ChatterInterestScoringSystem:
|
||||
logger.error(f"智能兴趣匹配计算失败: {e}")
|
||||
return 0.0
|
||||
|
||||
def _extract_keywords_from_database(self, message: DatabaseMessages) -> List[str]:
|
||||
def _extract_keywords_from_database(self, message: DatabaseMessages) -> list[str]:
|
||||
"""从数据库消息中提取关键词"""
|
||||
keywords = []
|
||||
|
||||
@@ -166,7 +166,7 @@ class ChatterInterestScoringSystem:
|
||||
|
||||
return keywords[:15] # 返回前15个关键词
|
||||
|
||||
def _extract_keywords_from_content(self, content: str) -> List[str]:
|
||||
def _extract_keywords_from_content(self, content: str) -> list[str]:
|
||||
"""从内容中提取关键词(降级方案)"""
|
||||
import re
|
||||
|
||||
@@ -287,7 +287,7 @@ class ChatterInterestScoringSystem:
|
||||
"""获取用户关系分"""
|
||||
return self.user_relationships.get(user_id, 0.3)
|
||||
|
||||
def get_scoring_stats(self) -> Dict:
|
||||
def get_scoring_stats(self) -> dict:
|
||||
"""获取评分系统统计"""
|
||||
return {
|
||||
"no_reply_count": self.no_reply_count,
|
||||
@@ -318,7 +318,7 @@ class ChatterInterestScoringSystem:
|
||||
logger.error(f"初始化智能兴趣系统失败: {e}")
|
||||
traceback.print_exc()
|
||||
|
||||
def get_matching_config(self) -> Dict[str, Any]:
|
||||
def get_matching_config(self) -> dict[str, Any]:
|
||||
"""获取匹配配置信息"""
|
||||
return {
|
||||
"use_smart_matching": self.use_smart_matching,
|
||||
|
||||
@@ -5,12 +5,11 @@ PlanExecutor: 接收 Plan 对象并执行其中的所有动作。
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from typing import Dict, List
|
||||
|
||||
from src.config.config import global_config
|
||||
from src.chat.planner_actions.action_manager import ChatterActionManager
|
||||
from src.common.data_models.info_data_model import Plan, ActionPlannerInfo
|
||||
from src.common.data_models.info_data_model import ActionPlannerInfo, Plan
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
|
||||
logger = get_logger("plan_executor")
|
||||
|
||||
@@ -52,7 +51,7 @@ class ChatterPlanExecutor:
|
||||
"""设置关系追踪器"""
|
||||
self.relationship_tracker = relationship_tracker
|
||||
|
||||
async def execute(self, plan: Plan) -> Dict[str, any]:
|
||||
async def execute(self, plan: Plan) -> dict[str, any]:
|
||||
"""
|
||||
遍历并执行Plan对象中`decided_actions`列表里的所有动作。
|
||||
|
||||
@@ -65,7 +64,7 @@ class ChatterPlanExecutor:
|
||||
if not plan.decided_actions:
|
||||
logger.info("没有需要执行的动作。")
|
||||
return {"executed_count": 0, "results": []}
|
||||
|
||||
|
||||
# 像hfc一样,提前打印将要执行的动作
|
||||
action_types = [action.action_type for action in plan.decided_actions]
|
||||
logger.info(f"选择动作: {', '.join(action_types) if action_types else '无'}")
|
||||
@@ -110,7 +109,7 @@ class ChatterPlanExecutor:
|
||||
"results": execution_results,
|
||||
}
|
||||
|
||||
async def _execute_reply_actions(self, reply_actions: List[ActionPlannerInfo], plan: Plan) -> Dict[str, any]:
|
||||
async def _execute_reply_actions(self, reply_actions: list[ActionPlannerInfo], plan: Plan) -> dict[str, any]:
|
||||
"""串行执行所有回复动作,增加去重逻辑,避免对同一消息多次回复"""
|
||||
results = []
|
||||
|
||||
@@ -150,17 +149,19 @@ class ChatterPlanExecutor:
|
||||
for i, action_info in enumerate(unique_actions):
|
||||
is_last_action = i == total_actions - 1
|
||||
if total_actions > 1:
|
||||
logger.info(f"[多重回复] 正在执行第 {i+1}/{total_actions} 个回复...")
|
||||
logger.info(f"[多重回复] 正在执行第 {i + 1}/{total_actions} 个回复...")
|
||||
|
||||
# 传递 clear_unread 参数
|
||||
result = await self._execute_single_reply_action(action_info, plan, clear_unread=is_last_action)
|
||||
results.append(result)
|
||||
|
||||
if total_actions > 1:
|
||||
logger.info(f"[多重回复] 所有回复任务执行完毕。")
|
||||
logger.info("[多重回复] 所有回复任务执行完毕。")
|
||||
return {"results": results}
|
||||
|
||||
async def _execute_single_reply_action(self, action_info: ActionPlannerInfo, plan: Plan, clear_unread: bool = True) -> Dict[str, any]:
|
||||
async def _execute_single_reply_action(
|
||||
self, action_info: ActionPlannerInfo, plan: Plan, clear_unread: bool = True
|
||||
) -> dict[str, any]:
|
||||
"""执行单个回复动作"""
|
||||
start_time = time.time()
|
||||
success = False
|
||||
@@ -201,7 +202,7 @@ class ChatterPlanExecutor:
|
||||
execution_result = await self.action_manager.execute_action(
|
||||
action_name=action_info.action_type, **action_params
|
||||
)
|
||||
|
||||
|
||||
# 从返回结果中提取真正的回复文本
|
||||
if isinstance(execution_result, dict):
|
||||
reply_content = execution_result.get("reply_text", "")
|
||||
@@ -233,10 +234,12 @@ class ChatterPlanExecutor:
|
||||
"error_message": error_message,
|
||||
"execution_time": execution_time,
|
||||
"reasoning": action_info.reasoning,
|
||||
"reply_content": reply_content[:200] + "..." if reply_content and len(reply_content) > 200 else reply_content,
|
||||
"reply_content": reply_content[:200] + "..."
|
||||
if reply_content and len(reply_content) > 200
|
||||
else reply_content,
|
||||
}
|
||||
|
||||
async def _execute_other_actions(self, other_actions: List[ActionPlannerInfo], plan: Plan) -> Dict[str, any]:
|
||||
async def _execute_other_actions(self, other_actions: list[ActionPlannerInfo], plan: Plan) -> dict[str, any]:
|
||||
"""执行其他动作"""
|
||||
results = []
|
||||
|
||||
@@ -265,7 +268,7 @@ class ChatterPlanExecutor:
|
||||
|
||||
return {"results": results}
|
||||
|
||||
async def _execute_single_other_action(self, action_info: ActionPlannerInfo, plan: Plan) -> Dict[str, any]:
|
||||
async def _execute_single_other_action(self, action_info: ActionPlannerInfo, plan: Plan) -> dict[str, any]:
|
||||
"""执行单个其他动作"""
|
||||
start_time = time.time()
|
||||
success = False
|
||||
@@ -374,7 +377,7 @@ class ChatterPlanExecutor:
|
||||
logger.debug(f"action_message类型: {type(action_info.action_message)}")
|
||||
logger.debug(f"action_message内容: {action_info.action_message}")
|
||||
|
||||
def get_execution_stats(self) -> Dict[str, any]:
|
||||
def get_execution_stats(self) -> dict[str, any]:
|
||||
"""获取执行统计信息"""
|
||||
stats = self.execution_stats.copy()
|
||||
|
||||
@@ -405,7 +408,7 @@ class ChatterPlanExecutor:
|
||||
"execution_times": [],
|
||||
}
|
||||
|
||||
def get_recent_performance(self, limit: int = 10) -> List[Dict[str, any]]:
|
||||
def get_recent_performance(self, limit: int = 10) -> list[dict[str, any]]:
|
||||
"""获取最近的执行性能"""
|
||||
recent_times = self.execution_stats["execution_times"][-limit:]
|
||||
if not recent_times:
|
||||
|
||||
@@ -2,13 +2,13 @@
|
||||
PlanFilter: 接收 Plan 对象,根据不同模式的逻辑进行筛选,决定最终要执行的动作。
|
||||
"""
|
||||
|
||||
import orjson
|
||||
import re
|
||||
import time
|
||||
import traceback
|
||||
import re
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any
|
||||
|
||||
import orjson
|
||||
from json_repair import repair_json
|
||||
|
||||
# 旧的Hippocampus系统已被移除,现在使用增强记忆系统
|
||||
@@ -39,7 +39,7 @@ class ChatterPlanFilter:
|
||||
根据 Plan 中的模式和信息,筛选并决定最终的动作。
|
||||
"""
|
||||
|
||||
def __init__(self, chat_id: str, available_actions: List[str]):
|
||||
def __init__(self, chat_id: str, available_actions: list[str]):
|
||||
"""
|
||||
初始化动作计划筛选器。
|
||||
|
||||
@@ -100,7 +100,7 @@ class ChatterPlanFilter:
|
||||
# 预解析 action_type 来进行判断
|
||||
thinking = item.get("thinking", "未提供思考过程")
|
||||
actions_obj = item.get("actions", {})
|
||||
|
||||
|
||||
# 处理actions字段可能是字典或列表的情况
|
||||
if isinstance(actions_obj, dict):
|
||||
action_type = actions_obj.get("action_type", "no_action")
|
||||
@@ -116,14 +116,12 @@ class ChatterPlanFilter:
|
||||
|
||||
if action_type in reply_action_types:
|
||||
if not reply_action_added:
|
||||
final_actions.extend(
|
||||
await self._parse_single_action(item, used_message_id_list, plan)
|
||||
)
|
||||
final_actions.extend(await self._parse_single_action(item, used_message_id_list, plan))
|
||||
reply_action_added = True
|
||||
else:
|
||||
# 非回复类动作直接添加
|
||||
final_actions.extend(await self._parse_single_action(item, used_message_id_list, plan))
|
||||
|
||||
|
||||
if thinking and thinking != "未提供思考过程":
|
||||
logger.info(f"\n{SAKURA_PINK}思考: {thinking}{RESET_COLOR}\n")
|
||||
plan.decided_actions = self._filter_no_actions(final_actions)
|
||||
@@ -154,6 +152,7 @@ class ChatterPlanFilter:
|
||||
schedule_block = ""
|
||||
# 优先检查是否被吵醒
|
||||
from src.chat.message_manager.message_manager import message_manager
|
||||
|
||||
angry_prompt_addition = ""
|
||||
wakeup_mgr = message_manager.wakeup_manager
|
||||
|
||||
@@ -161,7 +160,7 @@ class ChatterPlanFilter:
|
||||
# 检查1: 直接从 wakeup_manager 获取
|
||||
if wakeup_mgr.is_in_angry_state():
|
||||
angry_prompt_addition = wakeup_mgr.get_angry_prompt_addition()
|
||||
|
||||
|
||||
# 检查2: 如果上面没获取到,再从 mood_manager 确认
|
||||
if not angry_prompt_addition:
|
||||
chat_mood_for_check = mood_manager.get_mood_by_chat_id(plan.chat_id)
|
||||
@@ -274,7 +273,9 @@ class ChatterPlanFilter:
|
||||
is_group_chat = plan.chat_type == ChatType.GROUP
|
||||
chat_context_description = "你现在正在一个群聊中"
|
||||
if not is_group_chat and plan.target_info:
|
||||
chat_target_name = plan.target_info.get("person_name") or plan.target_info.get("user_nickname") or "对方"
|
||||
chat_target_name = (
|
||||
plan.target_info.get("person_name") or plan.target_info.get("user_nickname") or "对方"
|
||||
)
|
||||
chat_context_description = f"你正在和 {chat_target_name} 私聊"
|
||||
|
||||
action_options_block = await self._build_action_options(plan.available_actions)
|
||||
@@ -315,12 +316,12 @@ class ChatterPlanFilter:
|
||||
"""构建已读/未读历史消息块"""
|
||||
try:
|
||||
# 从message_manager获取真实的已读/未读消息
|
||||
from src.chat.message_manager.message_manager import message_manager
|
||||
from src.chat.utils.utils import assign_message_ids
|
||||
from src.chat.utils.chat_message_builder import get_raw_msg_before_timestamp_with_chat
|
||||
from src.chat.utils.utils import assign_message_ids
|
||||
|
||||
# 获取聊天流的上下文
|
||||
from src.plugin_system.apis.chat_api import get_chat_manager
|
||||
|
||||
chat_manager = get_chat_manager()
|
||||
chat_stream = chat_manager.get_stream(plan.chat_id)
|
||||
if not chat_stream:
|
||||
@@ -333,6 +334,7 @@ class ChatterPlanFilter:
|
||||
read_messages = stream_context.context.history_messages # 已读消息存储在history_messages中
|
||||
if not read_messages:
|
||||
from src.common.data_models.database_data_model import DatabaseMessages
|
||||
|
||||
# 如果内存中没有已读消息(比如刚启动),则从数据库加载最近的上下文
|
||||
fallback_messages_dicts = await get_raw_msg_before_timestamp_with_chat(
|
||||
chat_id=plan.chat_id,
|
||||
@@ -390,14 +392,15 @@ class ChatterPlanFilter:
|
||||
logger.error(f"构建已读/未读历史消息块时出错: {e}")
|
||||
return "构建已读历史消息时出错", "构建未读历史消息时出错", []
|
||||
|
||||
async def _get_interest_scores_for_messages(self, messages: List[dict]) -> dict[str, float]:
|
||||
async def _get_interest_scores_for_messages(self, messages: list[dict]) -> dict[str, float]:
|
||||
"""为消息获取兴趣度评分"""
|
||||
interest_scores = {}
|
||||
|
||||
try:
|
||||
from .interest_scoring import chatter_interest_scoring_system
|
||||
from src.common.data_models.database_data_model import DatabaseMessages
|
||||
|
||||
from .interest_scoring import chatter_interest_scoring_system
|
||||
|
||||
# 使用插件内部的兴趣度评分系统计算评分
|
||||
for msg_dict in messages:
|
||||
try:
|
||||
@@ -414,7 +417,7 @@ class ChatterPlanFilter:
|
||||
processed_plain_text=msg_dict.get("processed_plain_text", ""),
|
||||
key_words=msg_dict.get("key_words", "[]"),
|
||||
is_mentioned=msg_dict.get("is_mentioned", False),
|
||||
**{"user_info": user_info_dict} # 通过kwargs传入user_info
|
||||
**{"user_info": user_info_dict}, # 通过kwargs传入user_info
|
||||
)
|
||||
else:
|
||||
# 如果没有user_info字段,使用平铺的字段(flatten()方法返回的格式)
|
||||
@@ -425,13 +428,12 @@ class ChatterPlanFilter:
|
||||
user_platform=msg_dict.get("user_platform", ""),
|
||||
processed_plain_text=msg_dict.get("processed_plain_text", ""),
|
||||
key_words=msg_dict.get("key_words", "[]"),
|
||||
is_mentioned=msg_dict.get("is_mentioned", False)
|
||||
is_mentioned=msg_dict.get("is_mentioned", False),
|
||||
)
|
||||
|
||||
# 计算消息兴趣度
|
||||
interest_score_obj = await chatter_interest_scoring_system._calculate_single_message_score(
|
||||
message=db_message,
|
||||
bot_nickname=global_config.bot.nickname
|
||||
message=db_message, bot_nickname=global_config.bot.nickname
|
||||
)
|
||||
interest_score = interest_score_obj.total_score
|
||||
|
||||
@@ -449,12 +451,12 @@ class ChatterPlanFilter:
|
||||
|
||||
async def _parse_single_action(
|
||||
self, action_json: dict, message_id_list: list, plan: Plan
|
||||
) -> List[ActionPlannerInfo]:
|
||||
) -> list[ActionPlannerInfo]:
|
||||
parsed_actions = []
|
||||
try:
|
||||
# 从新的actions结构中获取动作信息
|
||||
actions_obj = action_json.get("actions", {})
|
||||
|
||||
|
||||
# 处理actions字段可能是字典或列表的情况
|
||||
actions_to_process = []
|
||||
if isinstance(actions_obj, dict):
|
||||
@@ -463,19 +465,23 @@ class ChatterPlanFilter:
|
||||
actions_to_process.extend(actions_obj)
|
||||
|
||||
if not actions_to_process:
|
||||
actions_to_process.append({"action_type": "no_action", "reason": "actions格式错误"})
|
||||
actions_to_process.append({"action_type": "no_action", "reason": "actions格式错误"})
|
||||
|
||||
for single_action_obj in actions_to_process:
|
||||
if not isinstance(single_action_obj, dict):
|
||||
continue
|
||||
|
||||
action = single_action_obj.get("action_type", "no_action")
|
||||
reasoning = single_action_obj.get("reasoning", "未提供原因") # 兼容旧的reason字段
|
||||
reasoning = single_action_obj.get("reasoning", "未提供原因") # 兼容旧的reason字段
|
||||
action_data = single_action_obj.get("action_data", {})
|
||||
|
||||
|
||||
# 为了向后兼容,如果action_data不存在,则从顶层字段获取
|
||||
if not action_data:
|
||||
action_data = {k: v for k, v in single_action_obj.items() if k not in ["action_type", "reason", "reasoning", "thinking"]}
|
||||
action_data = {
|
||||
k: v
|
||||
for k, v in single_action_obj.items()
|
||||
if k not in ["action_type", "reason", "reasoning", "thinking"]
|
||||
}
|
||||
|
||||
# 保留原始的thinking字段(如果有)
|
||||
thinking = action_json.get("thinking", "")
|
||||
@@ -501,7 +507,9 @@ class ChatterPlanFilter:
|
||||
# reply动作必须有目标消息,使用最新消息作为兜底
|
||||
target_message_dict = self._get_latest_message(message_id_list)
|
||||
if target_message_dict:
|
||||
logger.info(f"[{action}] 使用最新消息作为目标: {target_message_dict.get('message_id')}")
|
||||
logger.info(
|
||||
f"[{action}] 使用最新消息作为目标: {target_message_dict.get('message_id')}"
|
||||
)
|
||||
else:
|
||||
logger.error(f"[{action}] 无法找到任何目标消息,降级为no_action")
|
||||
action = "no_action"
|
||||
@@ -509,15 +517,21 @@ class ChatterPlanFilter:
|
||||
|
||||
elif action in ["poke_user", "set_emoji_like"]:
|
||||
# 这些动作可以尝试其他策略
|
||||
target_message_dict = self._find_poke_notice(message_id_list) or self._get_latest_message(message_id_list)
|
||||
target_message_dict = self._find_poke_notice(
|
||||
message_id_list
|
||||
) or self._get_latest_message(message_id_list)
|
||||
if target_message_dict:
|
||||
logger.info(f"[{action}] 使用替代消息作为目标: {target_message_dict.get('message_id')}")
|
||||
logger.info(
|
||||
f"[{action}] 使用替代消息作为目标: {target_message_dict.get('message_id')}"
|
||||
)
|
||||
|
||||
else:
|
||||
# 其他动作使用最新消息或跳过
|
||||
target_message_dict = self._get_latest_message(message_id_list)
|
||||
if target_message_dict:
|
||||
logger.info(f"[{action}] 使用最新消息作为目标: {target_message_dict.get('message_id')}")
|
||||
logger.info(
|
||||
f"[{action}] 使用最新消息作为目标: {target_message_dict.get('message_id')}"
|
||||
)
|
||||
else:
|
||||
# 如果LLM没有指定target_message_id,进行特殊处理
|
||||
if action == "poke_user":
|
||||
@@ -586,7 +600,7 @@ class ChatterPlanFilter:
|
||||
)
|
||||
return parsed_actions
|
||||
|
||||
def _filter_no_actions(self, action_list: List[ActionPlannerInfo]) -> List[ActionPlannerInfo]:
|
||||
def _filter_no_actions(self, action_list: list[ActionPlannerInfo]) -> list[ActionPlannerInfo]:
|
||||
non_no_actions = [a for a in action_list if a.action_type not in ["no_action", "no_reply"]]
|
||||
if non_no_actions:
|
||||
return non_no_actions
|
||||
@@ -614,7 +628,7 @@ class ChatterPlanFilter:
|
||||
query_text=query,
|
||||
user_id="system", # 系统查询
|
||||
scope_id="system",
|
||||
limit=5
|
||||
limit=5,
|
||||
)
|
||||
|
||||
if not enhanced_memories:
|
||||
@@ -627,7 +641,9 @@ class ChatterPlanFilter:
|
||||
memory_type = memory_chunk.memory_type.value if memory_chunk.memory_type else "unknown"
|
||||
retrieved_memories.append((memory_type, content))
|
||||
|
||||
memory_statements = [f"关于'{topic}', 你记得'{memory_item}'。" for topic, memory_item in retrieved_memories]
|
||||
memory_statements = [
|
||||
f"关于'{topic}', 你记得'{memory_item}'。" for topic, memory_item in retrieved_memories
|
||||
]
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"增强记忆系统检索失败,使用默认回复: {e}")
|
||||
@@ -637,7 +653,7 @@ class ChatterPlanFilter:
|
||||
logger.error(f"获取长期记忆时出错: {e}")
|
||||
return "回忆时出现了一些问题。"
|
||||
|
||||
async def _build_action_options(self, current_available_actions: Dict[str, ActionInfo]) -> str:
|
||||
async def _build_action_options(self, current_available_actions: dict[str, ActionInfo]) -> str:
|
||||
action_options_block = ""
|
||||
for action_name, action_info in current_available_actions.items():
|
||||
# 构建参数的JSON示例
|
||||
@@ -648,12 +664,17 @@ class ChatterPlanFilter:
|
||||
if action_name == "set_emoji_like" and p_name == "emoji":
|
||||
# 特殊处理set_emoji_like的emoji参数
|
||||
from src.plugins.built_in.social_toolkit_plugin.qq_emoji_list import qq_face
|
||||
emoji_options = [re.search(r"\[表情:(.+?)\]", name).group(1) for name in qq_face.values() if re.search(r"\[表情:(.+?)\]", name)]
|
||||
|
||||
emoji_options = [
|
||||
re.search(r"\[表情:(.+?)\]", name).group(1)
|
||||
for name in qq_face.values()
|
||||
if re.search(r"\[表情:(.+?)\]", name)
|
||||
]
|
||||
example_value = f"<从'{', '.join(emoji_options[:10])}...'中选择一个>"
|
||||
else:
|
||||
example_value = f"<{p_desc}>"
|
||||
params_json_list.append(f' "{p_name}": "{example_value}"')
|
||||
|
||||
|
||||
# 基础动作信息
|
||||
action_description = action_info.description
|
||||
action_require = "\n".join(f"- {req}" for req in action_info.action_require)
|
||||
@@ -666,11 +687,11 @@ class ChatterPlanFilter:
|
||||
# 将参数列表合并到JSON示例中
|
||||
if params_json_list:
|
||||
# 移除最后一行的逗号
|
||||
json_example_lines.extend([line.rstrip(',') for line in params_json_list])
|
||||
json_example_lines.extend([line.rstrip(",") for line in params_json_list])
|
||||
|
||||
json_example_lines.append(' "reason": "<执行该动作的详细原因>"')
|
||||
json_example_lines.append(" }")
|
||||
|
||||
|
||||
# 使用逗号连接内部元素,除了最后一个
|
||||
json_parts = []
|
||||
for i, line in enumerate(json_example_lines):
|
||||
@@ -678,14 +699,14 @@ class ChatterPlanFilter:
|
||||
if line.strip() in ["{", "}"]:
|
||||
json_parts.append(line)
|
||||
continue
|
||||
|
||||
|
||||
# 检查是否是最后一个需要逗号的元素
|
||||
is_last_item = True
|
||||
for next_line in json_example_lines[i+1:]:
|
||||
for next_line in json_example_lines[i + 1 :]:
|
||||
if next_line.strip() not in ["}"]:
|
||||
is_last_item = False
|
||||
break
|
||||
|
||||
|
||||
if not is_last_item:
|
||||
json_parts.append(f"{line},")
|
||||
else:
|
||||
@@ -703,7 +724,7 @@ class ChatterPlanFilter:
|
||||
)
|
||||
return action_options_block
|
||||
|
||||
def _find_message_by_id(self, message_id: str, message_id_list: list) -> Optional[Dict[str, Any]]:
|
||||
def _find_message_by_id(self, message_id: str, message_id_list: list) -> dict[str, Any] | None:
|
||||
"""
|
||||
增强的消息查找函数,支持多种格式和模糊匹配
|
||||
兼容大模型可能返回的各种格式变体
|
||||
@@ -713,7 +734,7 @@ class ChatterPlanFilter:
|
||||
|
||||
# 1. 标准化处理:去除可能的格式干扰
|
||||
original_id = str(message_id).strip()
|
||||
normalized_id = original_id.strip('<>"\'').strip()
|
||||
normalized_id = original_id.strip("<>\"'").strip()
|
||||
|
||||
if not normalized_id:
|
||||
return None
|
||||
@@ -731,12 +752,13 @@ class ChatterPlanFilter:
|
||||
|
||||
# 处理包含在文本中的ID格式 (如 "消息m123" -> 提取 m123)
|
||||
import re
|
||||
|
||||
# 尝试提取各种格式的ID
|
||||
id_patterns = [
|
||||
r'm\d+', # m123格式
|
||||
r'\d+', # 纯数字格式
|
||||
r'buffered-[a-f0-9-]+', # buffered-xxxx格式
|
||||
r'[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}', # UUID格式
|
||||
r"m\d+", # m123格式
|
||||
r"\d+", # 纯数字格式
|
||||
r"buffered-[a-f0-9-]+", # buffered-xxxx格式
|
||||
r"[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}", # UUID格式
|
||||
]
|
||||
|
||||
for pattern in id_patterns:
|
||||
@@ -771,12 +793,12 @@ class ChatterPlanFilter:
|
||||
# 4. 尝试模糊匹配(数字部分匹配)
|
||||
for candidate in candidate_ids:
|
||||
# 提取数字部分进行模糊匹配
|
||||
number_part = re.sub(r'[^0-9]', '', candidate)
|
||||
number_part = re.sub(r"[^0-9]", "", candidate)
|
||||
if number_part:
|
||||
for item in message_id_list:
|
||||
if isinstance(item, dict):
|
||||
item_id = item.get("id", "")
|
||||
item_number = re.sub(r'[^0-9]', '', item_id)
|
||||
item_number = re.sub(r"[^0-9]", "", item_id)
|
||||
|
||||
# 数字部分匹配
|
||||
if item_number == number_part:
|
||||
@@ -787,7 +809,7 @@ class ChatterPlanFilter:
|
||||
message_obj = item.get("message")
|
||||
if isinstance(message_obj, dict):
|
||||
orig_mid = message_obj.get("message_id") or message_obj.get("id")
|
||||
orig_number = re.sub(r'[^0-9]', '', str(orig_mid)) if orig_mid else ""
|
||||
orig_number = re.sub(r"[^0-9]", "", str(orig_mid)) if orig_mid else ""
|
||||
if orig_number == number_part:
|
||||
logger.debug(f"模糊匹配成功(消息对象): {candidate} -> {orig_mid}")
|
||||
return message_obj
|
||||
@@ -807,12 +829,12 @@ class ChatterPlanFilter:
|
||||
logger.warning(f"未找到任何匹配的消息: {original_id} (候选: {candidate_ids})")
|
||||
return None
|
||||
|
||||
def _get_latest_message(self, message_id_list: list) -> Optional[Dict[str, Any]]:
|
||||
def _get_latest_message(self, message_id_list: list) -> dict[str, Any] | None:
|
||||
if not message_id_list:
|
||||
return None
|
||||
return message_id_list[-1].get("message")
|
||||
|
||||
def _find_poke_notice(self, message_id_list: list) -> Optional[Dict[str, Any]]:
|
||||
def _find_poke_notice(self, message_id_list: list) -> dict[str, Any] | None:
|
||||
"""在消息列表中寻找戳一戳的通知消息"""
|
||||
for item in reversed(message_id_list):
|
||||
message = item.get("message")
|
||||
|
||||
@@ -3,7 +3,6 @@ PlanGenerator: 负责搜集和汇总所有决策所需的信息,生成一个
|
||||
"""
|
||||
|
||||
import time
|
||||
from typing import Dict
|
||||
|
||||
from src.chat.utils.chat_message_builder import get_raw_msg_before_timestamp_with_chat
|
||||
from src.chat.utils.utils import get_chat_type_and_target_info
|
||||
@@ -85,7 +84,7 @@ class ChatterPlanGenerator:
|
||||
chat_history=[],
|
||||
)
|
||||
|
||||
async def _get_available_actions(self, chat_type: ChatType, mode: ChatMode) -> Dict[str, ActionInfo]:
|
||||
async def _get_available_actions(self, chat_type: ChatType, mode: ChatMode) -> dict[str, ActionInfo]:
|
||||
"""
|
||||
获取当前可用的动作列表。
|
||||
|
||||
@@ -152,7 +151,7 @@ class ChatterPlanGenerator:
|
||||
# 如果获取失败,返回空列表
|
||||
return []
|
||||
|
||||
def get_generator_stats(self) -> Dict:
|
||||
def get_generator_stats(self) -> dict:
|
||||
"""
|
||||
获取生成器统计信息。
|
||||
|
||||
|
||||
@@ -4,23 +4,20 @@
|
||||
"""
|
||||
|
||||
from dataclasses import asdict
|
||||
import time
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
|
||||
|
||||
from src.plugins.built_in.affinity_flow_chatter.plan_executor import ChatterPlanExecutor
|
||||
from src.plugins.built_in.affinity_flow_chatter.plan_filter import ChatterPlanFilter
|
||||
from src.plugins.built_in.affinity_flow_chatter.plan_generator import ChatterPlanGenerator
|
||||
from src.plugins.built_in.affinity_flow_chatter.interest_scoring import chatter_interest_scoring_system
|
||||
from src.mood.mood_manager import mood_manager
|
||||
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.mood.mood_manager import mood_manager
|
||||
from src.plugins.built_in.affinity_flow_chatter.interest_scoring import chatter_interest_scoring_system
|
||||
from src.plugins.built_in.affinity_flow_chatter.plan_executor import ChatterPlanExecutor
|
||||
from src.plugins.built_in.affinity_flow_chatter.plan_filter import ChatterPlanFilter
|
||||
from src.plugins.built_in.affinity_flow_chatter.plan_generator import ChatterPlanGenerator
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from src.common.data_models.message_manager_data_model import StreamContext
|
||||
from src.common.data_models.info_data_model import Plan
|
||||
from src.chat.planner_actions.action_manager import ChatterActionManager
|
||||
from src.common.data_models.info_data_model import Plan
|
||||
from src.common.data_models.message_manager_data_model import StreamContext
|
||||
|
||||
# 导入提示词模块以确保其被初始化
|
||||
from src.plugins.built_in.affinity_flow_chatter import planner_prompts # noqa
|
||||
@@ -63,7 +60,7 @@ class ChatterActionPlanner:
|
||||
"other_actions_executed": 0,
|
||||
}
|
||||
|
||||
async def plan(self, context: "StreamContext" = None) -> Tuple[List[Dict], Optional[Dict]]:
|
||||
async def plan(self, context: "StreamContext" = None) -> tuple[list[dict], dict | None]:
|
||||
"""
|
||||
执行完整的增强版规划流程。
|
||||
|
||||
@@ -85,26 +82,27 @@ class ChatterActionPlanner:
|
||||
self.planner_stats["failed_plans"] += 1
|
||||
return [], None
|
||||
|
||||
async def _enhanced_plan_flow(self, context: "StreamContext") -> Tuple[List[Dict], Optional[Dict]]:
|
||||
async def _enhanced_plan_flow(self, context: "StreamContext") -> tuple[list[dict], dict | None]:
|
||||
"""执行增强版规划流程"""
|
||||
try:
|
||||
# 在规划前,先进行动作修改
|
||||
from src.chat.planner_actions.action_modifier import ActionModifier
|
||||
|
||||
action_modifier = ActionModifier(self.action_manager, self.chat_id)
|
||||
await action_modifier.modify_actions()
|
||||
|
||||
|
||||
# 1. 生成初始 Plan
|
||||
initial_plan = await self.generator.generate(context.chat_mode)
|
||||
|
||||
# 确保Plan中包含所有当前可用的动作
|
||||
initial_plan.available_actions = self.action_manager.get_using_actions()
|
||||
|
||||
|
||||
unread_messages = context.get_unread_messages() if context else []
|
||||
# 2. 使用新的兴趣度管理系统进行评分
|
||||
score = 0.0
|
||||
should_reply = False
|
||||
reply_not_available = False
|
||||
interest_updates: List[Dict[str, Any]] = []
|
||||
interest_updates: list[dict[str, Any]] = []
|
||||
|
||||
if unread_messages:
|
||||
# 为每条消息计算兴趣度,并延迟提交数据库更新
|
||||
@@ -117,7 +115,9 @@ class ChatterActionPlanner:
|
||||
message_interest = interest_score.total_score
|
||||
|
||||
message.interest_value = message_interest
|
||||
message.should_reply = message_interest > global_config.affinity_flow.non_reply_action_interest_threshold
|
||||
message.should_reply = (
|
||||
message_interest > global_config.affinity_flow.non_reply_action_interest_threshold
|
||||
)
|
||||
|
||||
interest_updates.append(
|
||||
{
|
||||
@@ -191,7 +191,7 @@ class ChatterActionPlanner:
|
||||
self.planner_stats["failed_plans"] += 1
|
||||
return [], None
|
||||
|
||||
async def _commit_interest_updates(self, updates: List[Dict[str, Any]]) -> None:
|
||||
async def _commit_interest_updates(self, updates: list[dict[str, Any]]) -> None:
|
||||
"""统一更新消息兴趣度,减少数据库写入次数"""
|
||||
if not updates:
|
||||
return
|
||||
@@ -218,7 +218,7 @@ class ChatterActionPlanner:
|
||||
except Exception as e:
|
||||
logger.warning(f"批量更新数据库兴趣度失败: {e}")
|
||||
|
||||
def _update_stats_from_execution_result(self, execution_result: Dict[str, any]):
|
||||
def _update_stats_from_execution_result(self, execution_result: dict[str, any]):
|
||||
"""根据执行结果更新规划器统计"""
|
||||
if not execution_result:
|
||||
return
|
||||
@@ -242,7 +242,7 @@ class ChatterActionPlanner:
|
||||
self.planner_stats["replies_generated"] += reply_count
|
||||
self.planner_stats["other_actions_executed"] += other_count
|
||||
|
||||
def _build_return_result(self, plan: "Plan") -> Tuple[List[Dict], Optional[Dict]]:
|
||||
def _build_return_result(self, plan: "Plan") -> tuple[list[dict], dict | None]:
|
||||
"""构建返回结果"""
|
||||
final_actions = plan.decided_actions or []
|
||||
final_target_message = next((act.action_message for act in final_actions if act.action_message), None)
|
||||
@@ -259,7 +259,7 @@ class ChatterActionPlanner:
|
||||
|
||||
return final_actions_dict, final_target_message_dict
|
||||
|
||||
def get_planner_stats(self) -> Dict[str, any]:
|
||||
def get_planner_stats(self) -> dict[str, any]:
|
||||
"""获取规划器统计"""
|
||||
return self.planner_stats.copy()
|
||||
|
||||
@@ -268,7 +268,7 @@ class ChatterActionPlanner:
|
||||
chat_mood = mood_manager.get_mood_by_chat_id(self.chat_id)
|
||||
return chat_mood.mood_state
|
||||
|
||||
def get_mood_stats(self) -> Dict[str, any]:
|
||||
def get_mood_stats(self) -> dict[str, any]:
|
||||
"""获取情绪状态统计"""
|
||||
chat_mood = mood_manager.get_mood_by_chat_id(self.chat_id)
|
||||
return {
|
||||
|
||||
@@ -2,12 +2,10 @@
|
||||
亲和力聊天处理器插件
|
||||
"""
|
||||
|
||||
from typing import List, Tuple, Type
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.plugin_system.apis.plugin_register_api import register_plugin
|
||||
from src.plugin_system.base.base_plugin import BasePlugin
|
||||
from src.plugin_system.base.component_types import ComponentInfo
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("affinity_chatter_plugin")
|
||||
|
||||
@@ -29,7 +27,7 @@ class AffinityChatterPlugin(BasePlugin):
|
||||
# 简单的 config_schema 占位(如果将来需要配置可扩展)
|
||||
config_schema = {}
|
||||
|
||||
def get_plugin_components(self) -> List[Tuple[ComponentInfo, Type]]:
|
||||
def get_plugin_components(self) -> list[tuple[ComponentInfo, type]]:
|
||||
"""返回插件包含的组件列表(ChatterInfo, AffinityChatter)
|
||||
|
||||
这里采用延迟导入 AffinityChatter 来避免循环依赖和启动顺序问题。
|
||||
|
||||
@@ -5,15 +5,15 @@
|
||||
"""
|
||||
|
||||
import time
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import model_config, global_config
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.common.database.sqlalchemy_database_api import get_db_session
|
||||
from src.common.database.sqlalchemy_models import UserRelationships, Messages
|
||||
from sqlalchemy import select, desc
|
||||
from sqlalchemy import desc, select
|
||||
|
||||
from src.common.data_models.database_data_model import DatabaseMessages
|
||||
from src.common.database.sqlalchemy_database_api import get_db_session
|
||||
from src.common.database.sqlalchemy_models import Messages, UserRelationships
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config, model_config
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
|
||||
logger = get_logger("chatter_relationship_tracker")
|
||||
|
||||
@@ -22,15 +22,15 @@ class ChatterRelationshipTracker:
|
||||
"""用户关系追踪器"""
|
||||
|
||||
def __init__(self, interest_scoring_system=None):
|
||||
self.tracking_users: Dict[str, Dict] = {} # user_id -> interaction_data
|
||||
self.tracking_users: dict[str, dict] = {} # user_id -> interaction_data
|
||||
self.max_tracking_users = 3
|
||||
self.update_interval_minutes = 30
|
||||
self.last_update_time = time.time()
|
||||
self.relationship_history: List[Dict] = []
|
||||
self.relationship_history: list[dict] = []
|
||||
self.interest_scoring_system = interest_scoring_system
|
||||
|
||||
# 用户关系缓存 (user_id -> {"relationship_text": str, "relationship_score": float, "last_tracked": float})
|
||||
self.user_relationship_cache: Dict[str, Dict] = {}
|
||||
self.user_relationship_cache: dict[str, dict] = {}
|
||||
self.cache_expiry_hours = 1 # 缓存过期时间(小时)
|
||||
|
||||
# 关系更新LLM
|
||||
@@ -91,7 +91,7 @@ class ChatterRelationshipTracker:
|
||||
|
||||
logger.debug(f"添加用户交互追踪: {user_id}")
|
||||
|
||||
async def check_and_update_relationships(self) -> List[Dict]:
|
||||
async def check_and_update_relationships(self) -> list[dict]:
|
||||
"""检查并更新用户关系"""
|
||||
current_time = time.time()
|
||||
if current_time - self.last_update_time < self.update_interval_minutes * 60:
|
||||
@@ -108,7 +108,7 @@ class ChatterRelationshipTracker:
|
||||
self.last_update_time = current_time
|
||||
return updates
|
||||
|
||||
async def _update_user_relationship(self, interaction: Dict) -> Optional[Dict]:
|
||||
async def _update_user_relationship(self, interaction: dict) -> dict | None:
|
||||
"""更新单个用户的关系"""
|
||||
try:
|
||||
# 获取bot人设信息
|
||||
@@ -201,11 +201,11 @@ class ChatterRelationshipTracker:
|
||||
|
||||
return None
|
||||
|
||||
def get_tracking_users(self) -> Dict[str, Dict]:
|
||||
def get_tracking_users(self) -> dict[str, dict]:
|
||||
"""获取正在追踪的用户"""
|
||||
return self.tracking_users.copy()
|
||||
|
||||
def get_user_interaction(self, user_id: str) -> Optional[Dict]:
|
||||
def get_user_interaction(self, user_id: str) -> dict | None:
|
||||
"""获取特定用户的交互记录"""
|
||||
return self.tracking_users.get(user_id)
|
||||
|
||||
@@ -220,11 +220,11 @@ class ChatterRelationshipTracker:
|
||||
self.tracking_users.clear()
|
||||
logger.info("清空所有用户追踪")
|
||||
|
||||
def get_relationship_history(self) -> List[Dict]:
|
||||
def get_relationship_history(self) -> list[dict]:
|
||||
"""获取关系历史记录"""
|
||||
return self.relationship_history.copy()
|
||||
|
||||
def add_to_history(self, relationship_update: Dict):
|
||||
def add_to_history(self, relationship_update: dict):
|
||||
"""添加到关系历史"""
|
||||
self.relationship_history.append({**relationship_update, "update_time": time.time()})
|
||||
|
||||
@@ -232,7 +232,7 @@ class ChatterRelationshipTracker:
|
||||
if len(self.relationship_history) > 100:
|
||||
self.relationship_history = self.relationship_history[-100:]
|
||||
|
||||
def get_tracker_stats(self) -> Dict:
|
||||
def get_tracker_stats(self) -> dict:
|
||||
"""获取追踪器统计"""
|
||||
return {
|
||||
"tracking_users": len(self.tracking_users),
|
||||
@@ -268,7 +268,7 @@ class ChatterRelationshipTracker:
|
||||
self.add_to_history(update_info)
|
||||
logger.info(f"强制更新用户关系: {user_id} -> {new_score:.2f}")
|
||||
|
||||
def get_user_summary(self, user_id: str) -> Dict:
|
||||
def get_user_summary(self, user_id: str) -> dict:
|
||||
"""获取用户交互总结"""
|
||||
if user_id not in self.tracking_users:
|
||||
return {}
|
||||
@@ -313,7 +313,7 @@ class ChatterRelationshipTracker:
|
||||
# 数据库中也没有,返回默认值
|
||||
return global_config.affinity_flow.base_relationship_score
|
||||
|
||||
async def _get_user_relationship_from_db(self, user_id: str) -> Optional[Dict]:
|
||||
async def _get_user_relationship_from_db(self, user_id: str) -> dict | None:
|
||||
"""从数据库获取用户关系数据"""
|
||||
try:
|
||||
async with get_db_session() as session:
|
||||
@@ -431,7 +431,7 @@ class ChatterRelationshipTracker:
|
||||
|
||||
return 0
|
||||
|
||||
async def _get_last_bot_reply_to_user(self, user_id: str) -> Optional[DatabaseMessages]:
|
||||
async def _get_last_bot_reply_to_user(self, user_id: str) -> DatabaseMessages | None:
|
||||
"""获取上次bot回复该用户的消息"""
|
||||
try:
|
||||
async with get_db_session() as session:
|
||||
@@ -455,7 +455,7 @@ class ChatterRelationshipTracker:
|
||||
|
||||
return None
|
||||
|
||||
async def _get_user_reactions_after_reply(self, user_id: str, reply_time: float) -> List[DatabaseMessages]:
|
||||
async def _get_user_reactions_after_reply(self, user_id: str, reply_time: float) -> list[DatabaseMessages]:
|
||||
"""获取用户在bot回复后的反应消息"""
|
||||
try:
|
||||
async with get_db_session() as session:
|
||||
@@ -511,7 +511,7 @@ class ChatterRelationshipTracker:
|
||||
user_id: str,
|
||||
user_name: str,
|
||||
last_bot_reply: DatabaseMessages,
|
||||
user_reactions: List[DatabaseMessages],
|
||||
user_reactions: list[DatabaseMessages],
|
||||
current_text: str,
|
||||
current_score: float,
|
||||
current_reply: str,
|
||||
@@ -596,7 +596,7 @@ class ChatterRelationshipTracker:
|
||||
quality = response_data.get("interaction_quality", "medium")
|
||||
|
||||
# 更新数据库
|
||||
await self._update_user_relationship_in_db(user_id, new_text, new_score)
|
||||
await self._update_user_relationship_in_db(user_id, new_text, new_score)
|
||||
|
||||
# 更新缓存
|
||||
self.user_relationship_cache[user_id] = {
|
||||
|
||||
@@ -8,9 +8,9 @@
|
||||
- 测试功能
|
||||
"""
|
||||
|
||||
from src.plugin_system.base import BaseCommand
|
||||
from src.chat.antipromptinjector import get_anti_injector
|
||||
from src.common.logger import get_logger
|
||||
from src.plugin_system.base import BaseCommand
|
||||
|
||||
logger = get_logger("anti_injector.commands")
|
||||
|
||||
@@ -56,5 +56,5 @@ class AntiInjectorStatusCommand(BaseCommand):
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取反注入系统状态失败: {e}")
|
||||
await self.send_text(f"获取状态失败: {str(e)}")
|
||||
return False, f"获取状态失败: {str(e)}", True
|
||||
await self.send_text(f"获取状态失败: {e!s}")
|
||||
return False, f"获取状态失败: {e!s}", True
|
||||
|
||||
@@ -1,19 +1,18 @@
|
||||
import random
|
||||
from typing import Tuple
|
||||
|
||||
# 导入新插件系统
|
||||
from src.plugin_system import BaseAction, ActionActivationType, ChatMode
|
||||
from src.chat.emoji_system.emoji_history import add_emoji_to_history, get_recent_emojis
|
||||
from src.chat.emoji_system.emoji_manager import MaiEmoji, get_emoji_manager
|
||||
from src.chat.utils.utils_image import image_path_to_base64
|
||||
|
||||
# 导入依赖的系统组件
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
|
||||
# 导入新插件系统
|
||||
from src.plugin_system import ActionActivationType, BaseAction, ChatMode
|
||||
|
||||
# 导入API模块 - 标准Python包方式
|
||||
from src.plugin_system.apis import llm_api, message_api
|
||||
from src.chat.emoji_system.emoji_manager import get_emoji_manager, MaiEmoji
|
||||
from src.chat.utils.utils_image import image_path_to_base64
|
||||
from src.config.config import global_config
|
||||
from src.chat.emoji_system.emoji_history import get_recent_emojis, add_emoji_to_history
|
||||
|
||||
|
||||
logger = get_logger("emoji")
|
||||
|
||||
@@ -59,7 +58,7 @@ class EmojiAction(BaseAction):
|
||||
# 关联类型
|
||||
associated_types = ["emoji"]
|
||||
|
||||
async def execute(self) -> Tuple[bool, str]:
|
||||
async def execute(self) -> tuple[bool, str]:
|
||||
"""执行表情动作"""
|
||||
logger.info(f"{self.log_prefix} 决定发送表情")
|
||||
|
||||
@@ -268,7 +267,7 @@ class EmojiAction(BaseAction):
|
||||
if not success:
|
||||
logger.error(f"{self.log_prefix} 表情包发送失败")
|
||||
await self.store_action_info(
|
||||
action_build_into_prompt=True, action_prompt_display=f"发送了一个表情包,但失败了", action_done=False
|
||||
action_build_into_prompt=True, action_prompt_display="发送了一个表情包,但失败了", action_done=False
|
||||
)
|
||||
return False, "表情包发送失败"
|
||||
|
||||
@@ -279,11 +278,11 @@ class EmojiAction(BaseAction):
|
||||
logger.error(f"{self.log_prefix} 添加表情到历史记录时出错: {e}")
|
||||
|
||||
await self.store_action_info(
|
||||
action_build_into_prompt=True, action_prompt_display=f"发送了一个表情包", action_done=True
|
||||
action_build_into_prompt=True, action_prompt_display="发送了一个表情包", action_done=True
|
||||
)
|
||||
|
||||
return True, f"发送表情包: {emoji_description}"
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 表情动作执行失败: {e}", exc_info=True)
|
||||
return False, f"表情发送失败: {str(e)}"
|
||||
return False, f"表情发送失败: {e!s}"
|
||||
|
||||
@@ -5,19 +5,16 @@
|
||||
这是系统的内置插件,提供基础的聊天交互功能
|
||||
"""
|
||||
|
||||
from typing import List, Tuple, Type
|
||||
|
||||
# 导入新插件系统
|
||||
from src.plugin_system import BasePlugin, register_plugin, ComponentInfo
|
||||
from src.plugin_system.base.config_types import ConfigField
|
||||
|
||||
|
||||
# 导入依赖的系统组件
|
||||
from src.common.logger import get_logger
|
||||
|
||||
# 导入新插件系统
|
||||
from src.plugin_system import BasePlugin, ComponentInfo, register_plugin
|
||||
from src.plugin_system.base.config_types import ConfigField
|
||||
from src.plugins.built_in.core_actions.anti_injector_manager import AntiInjectorStatusCommand
|
||||
|
||||
# 导入API模块 - 标准Python包方式
|
||||
from src.plugins.built_in.core_actions.emoji import EmojiAction
|
||||
from src.plugins.built_in.core_actions.anti_injector_manager import AntiInjectorStatusCommand
|
||||
|
||||
logger = get_logger("core_actions")
|
||||
|
||||
@@ -62,7 +59,7 @@ class CoreActionsPlugin(BasePlugin):
|
||||
},
|
||||
}
|
||||
|
||||
def get_plugin_components(self) -> List[Tuple[ComponentInfo, Type]]:
|
||||
def get_plugin_components(self) -> list[tuple[ComponentInfo, type]]:
|
||||
"""返回插件包含的组件列表"""
|
||||
|
||||
# --- 根据配置注册组件 ---
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
from typing import Dict, Any
|
||||
from typing import Any
|
||||
|
||||
from src.chat.knowledge.knowledge_lib import qa_manager
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.chat.knowledge.knowledge_lib import qa_manager
|
||||
from src.plugin_system import BaseTool, ToolParamType
|
||||
|
||||
logger = get_logger("lpmm_get_knowledge_tool")
|
||||
@@ -19,7 +19,7 @@ class SearchKnowledgeFromLPMMTool(BaseTool):
|
||||
]
|
||||
available_for_llm = global_config.lpmm_knowledge.enable
|
||||
|
||||
async def execute(self, function_args: Dict[str, Any]) -> Dict[str, Any]:
|
||||
async def execute(self, function_args: dict[str, Any]) -> dict[str, Any]:
|
||||
"""执行知识库搜索
|
||||
|
||||
Args:
|
||||
@@ -47,16 +47,16 @@ class SearchKnowledgeFromLPMMTool(BaseTool):
|
||||
knowledge_parts = []
|
||||
for i, item in enumerate(knowledge_info["knowledge_items"]):
|
||||
knowledge_parts.append(f"- {item.get('content', 'N/A')}")
|
||||
|
||||
|
||||
knowledge_text = "\n".join(knowledge_parts)
|
||||
summary = knowledge_info.get('summary', '无总结')
|
||||
summary = knowledge_info.get("summary", "无总结")
|
||||
content = f"关于 '{query}', 你知道以下信息:\n{knowledge_text}\n\n总结: {summary}"
|
||||
else:
|
||||
content = f"关于 '{query}',你的知识库里好像没有相关的信息呢"
|
||||
return {"type": "lpmm_knowledge", "id": query, "content": content}
|
||||
except Exception as e:
|
||||
# 捕获异常并记录错误
|
||||
logger.error(f"知识库搜索工具执行失败: {str(e)}")
|
||||
logger.error(f"知识库搜索工具执行失败: {e!s}")
|
||||
# 在其他异常情况下,确保 id 仍然是 query (如果它被定义了)
|
||||
query_id = query if "query" in locals() else "unknown_query"
|
||||
return {"type": "info", "id": query_id, "content": f"lpmm知识库搜索失败,炸了: {str(e)}"}
|
||||
return {"type": "info", "id": query_id, "content": f"lpmm知识库搜索失败,炸了: {e!s}"}
|
||||
|
||||
@@ -1,9 +1,8 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
让框架能够发现并加载子目录中的组件。
|
||||
"""
|
||||
|
||||
from .plugin import MaiZoneRefactoredPlugin as MaiZoneRefactoredPlugin
|
||||
from .actions.send_feed_action import SendFeedAction as SendFeedAction
|
||||
from .actions.read_feed_action import ReadFeedAction as ReadFeedAction
|
||||
from .actions.send_feed_action import SendFeedAction as SendFeedAction
|
||||
from .commands.send_feed_command import SendFeedCommand as SendFeedCommand
|
||||
from .plugin import MaiZoneRefactoredPlugin as MaiZoneRefactoredPlugin
|
||||
|
||||
@@ -1,14 +1,12 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
阅读说说动作组件
|
||||
"""
|
||||
|
||||
from typing import Tuple
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.plugin_system import BaseAction, ActionActivationType, ChatMode
|
||||
from src.plugin_system import ActionActivationType, BaseAction, ChatMode
|
||||
from src.plugin_system.apis import generator_api
|
||||
from src.plugin_system.apis.permission_api import permission_api
|
||||
|
||||
from ..services.manager import get_qzone_service
|
||||
|
||||
logger = get_logger("MaiZone.ReadFeedAction")
|
||||
@@ -41,7 +39,7 @@ class ReadFeedAction(BaseAction):
|
||||
# 使用权限API检查用户是否有阅读说说的权限
|
||||
return await permission_api.check_permission(platform, user_id, "plugin.maizone.read_feed")
|
||||
|
||||
async def execute(self) -> Tuple[bool, str]:
|
||||
async def execute(self) -> tuple[bool, str]:
|
||||
"""
|
||||
执行动作的核心逻辑。
|
||||
"""
|
||||
|
||||
@@ -1,14 +1,12 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
发送说说动作组件
|
||||
"""
|
||||
|
||||
from typing import Tuple
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.plugin_system import BaseAction, ActionActivationType, ChatMode
|
||||
from src.plugin_system import ActionActivationType, BaseAction, ChatMode
|
||||
from src.plugin_system.apis import generator_api
|
||||
from src.plugin_system.apis.permission_api import permission_api
|
||||
|
||||
from ..services.manager import get_qzone_service
|
||||
|
||||
logger = get_logger("MaiZone.SendFeedAction")
|
||||
@@ -41,7 +39,7 @@ class SendFeedAction(BaseAction):
|
||||
# 使用权限API检查用户是否有发送说说的权限
|
||||
return await permission_api.check_permission(platform, user_id, "plugin.maizone.send_feed")
|
||||
|
||||
async def execute(self) -> Tuple[bool, str]:
|
||||
async def execute(self) -> tuple[bool, str]:
|
||||
"""
|
||||
执行动作的核心逻辑。
|
||||
"""
|
||||
|
||||
@@ -1,15 +1,13 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
发送说说命令 await self.send_text(f"收到!正在为你生成关于"{topic or '随机'}"的说说,请稍候...【热重载测试成功】")件
|
||||
"""
|
||||
|
||||
from typing import Tuple
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.plugin_system.base.plus_command import PlusCommand
|
||||
from src.plugin_system.base.command_args import CommandArgs
|
||||
from src.plugin_system.base.plus_command import PlusCommand
|
||||
from src.plugin_system.utils.permission_decorators import require_permission
|
||||
from ..services.manager import get_qzone_service, get_config_getter
|
||||
|
||||
from ..services.manager import get_config_getter, get_qzone_service
|
||||
|
||||
logger = get_logger("MaiZone.SendFeedCommand")
|
||||
|
||||
@@ -28,7 +26,7 @@ class SendFeedCommand(PlusCommand):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
@require_permission("plugin.maizone.send_feed")
|
||||
async def execute(self, args: CommandArgs) -> Tuple[bool, str, bool]:
|
||||
async def execute(self, args: CommandArgs) -> tuple[bool, str, bool]:
|
||||
"""
|
||||
执行命令的核心逻辑。
|
||||
"""
|
||||
|
||||
@@ -1,28 +1,26 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
MaiZone(麦麦空间)- 重构版
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from pathlib import Path
|
||||
from typing import List, Tuple, Type
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.plugin_system import BasePlugin, ComponentInfo, register_plugin
|
||||
from src.plugin_system.base.config_types import ConfigField
|
||||
from src.plugin_system.apis.permission_api import permission_api
|
||||
from src.plugin_system.base.config_types import ConfigField
|
||||
|
||||
from .actions.read_feed_action import ReadFeedAction
|
||||
from .actions.send_feed_action import SendFeedAction
|
||||
from .commands.send_feed_command import SendFeedCommand
|
||||
from .services.content_service import ContentService
|
||||
from .services.image_service import ImageService
|
||||
from .services.qzone_service import QZoneService
|
||||
from .services.scheduler_service import SchedulerService
|
||||
from .services.monitor_service import MonitorService
|
||||
from .services.cookie_service import CookieService
|
||||
from .services.reply_tracker_service import ReplyTrackerService
|
||||
from .services.image_service import ImageService
|
||||
from .services.manager import register_service
|
||||
from .services.monitor_service import MonitorService
|
||||
from .services.qzone_service import QZoneService
|
||||
from .services.reply_tracker_service import ReplyTrackerService
|
||||
from .services.scheduler_service import SchedulerService
|
||||
|
||||
logger = get_logger("MaiZone.Plugin")
|
||||
|
||||
@@ -35,8 +33,8 @@ class MaiZoneRefactoredPlugin(BasePlugin):
|
||||
plugin_description: str = "重构版的MaiZone插件"
|
||||
config_file_name: str = "config.toml"
|
||||
enable_plugin: bool = True
|
||||
dependencies: List[str] = []
|
||||
python_dependencies: List[str] = []
|
||||
dependencies: list[str] = []
|
||||
python_dependencies: list[str] = []
|
||||
|
||||
config_schema: dict = {
|
||||
"plugin": {"enable": ConfigField(type=bool, default=True, description="是否启用插件")},
|
||||
@@ -87,6 +85,7 @@ class MaiZoneRefactoredPlugin(BasePlugin):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
async def on_plugin_loaded(self):
|
||||
"""插件加载完成后的回调,初始化服务并启动后台任务"""
|
||||
# --- 注册权限节点 ---
|
||||
@@ -124,7 +123,7 @@ class MaiZoneRefactoredPlugin(BasePlugin):
|
||||
asyncio.create_task(monitor_service.start())
|
||||
logger.info("MaiZone后台监控和定时任务已启动。")
|
||||
|
||||
def get_plugin_components(self) -> List[Tuple[ComponentInfo, Type]]:
|
||||
def get_plugin_components(self) -> list[tuple[ComponentInfo, type]]:
|
||||
return [
|
||||
(SendFeedAction.get_action_info(), SendFeedAction),
|
||||
(ReadFeedAction.get_action_info(), ReadFeedAction),
|
||||
|
||||
@@ -1,23 +1,23 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
内容服务模块
|
||||
负责生成所有与QQ空间相关的文本内容,例如说说、评论等。
|
||||
"""
|
||||
|
||||
from typing import Callable, Optional
|
||||
import datetime
|
||||
|
||||
import base64
|
||||
import aiohttp
|
||||
from src.common.logger import get_logger
|
||||
import imghdr
|
||||
import asyncio
|
||||
from src.plugin_system.apis import llm_api, config_api, generator_api
|
||||
from src.plugin_system.apis.cross_context_api import get_chat_history_by_group_name
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
import base64
|
||||
import datetime
|
||||
import imghdr
|
||||
from collections.abc import Callable
|
||||
|
||||
import aiohttp
|
||||
from maim_message import UserInfo
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from src.common.logger import get_logger
|
||||
from src.config.api_ada_configs import TaskConfig
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.plugin_system.apis import config_api, generator_api, llm_api
|
||||
from src.plugin_system.apis.cross_context_api import get_chat_history_by_group_name
|
||||
|
||||
# 导入旧的工具函数,我们稍后会考虑是否也需要重构它
|
||||
from ..utils.history_utils import get_send_history
|
||||
@@ -38,7 +38,7 @@ class ContentService:
|
||||
"""
|
||||
self.get_config = get_config
|
||||
|
||||
async def generate_story(self, topic: str, context: Optional[str] = None) -> str:
|
||||
async def generate_story(self, topic: str, context: str | None = None) -> str:
|
||||
"""
|
||||
根据指定主题和可选的上下文生成一条QQ空间说说。
|
||||
|
||||
@@ -231,7 +231,7 @@ class ContentService:
|
||||
return ""
|
||||
return ""
|
||||
|
||||
async def _describe_image(self, image_url: str) -> Optional[str]:
|
||||
async def _describe_image(self, image_url: str) -> str | None:
|
||||
"""
|
||||
使用LLM识别图片内容。
|
||||
"""
|
||||
|
||||
@@ -1,14 +1,14 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Cookie服务模块
|
||||
负责从多种来源获取、缓存和管理QZone的Cookie。
|
||||
"""
|
||||
|
||||
import orjson
|
||||
from collections.abc import Callable
|
||||
from pathlib import Path
|
||||
from typing import Callable, Optional, Dict
|
||||
|
||||
import aiohttp
|
||||
import orjson
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.plugin_system.apis import send_api
|
||||
|
||||
@@ -29,28 +29,28 @@ class CookieService:
|
||||
"""获取指定QQ账号的cookie文件路径"""
|
||||
return self.cookie_dir / f"cookies-{qq_account}.json"
|
||||
|
||||
def _save_cookies_to_file(self, qq_account: str, cookies: Dict[str, str]):
|
||||
def _save_cookies_to_file(self, qq_account: str, cookies: dict[str, str]):
|
||||
"""将Cookie保存到本地文件"""
|
||||
cookie_file_path = self._get_cookie_file_path(qq_account)
|
||||
try:
|
||||
with open(cookie_file_path, "w", encoding="utf-8") as f:
|
||||
f.write(orjson.dumps(cookies, option=orjson.OPT_INDENT_2).decode("utf-8"))
|
||||
logger.info(f"Cookie已成功缓存至: {cookie_file_path}")
|
||||
except IOError as e:
|
||||
except OSError as e:
|
||||
logger.error(f"无法写入Cookie文件 {cookie_file_path}: {e}")
|
||||
|
||||
def _load_cookies_from_file(self, qq_account: str) -> Optional[Dict[str, str]]:
|
||||
def _load_cookies_from_file(self, qq_account: str) -> dict[str, str] | None:
|
||||
"""从本地文件加载Cookie"""
|
||||
cookie_file_path = self._get_cookie_file_path(qq_account)
|
||||
if cookie_file_path.exists():
|
||||
try:
|
||||
with open(cookie_file_path, "r", encoding="utf-8") as f:
|
||||
with open(cookie_file_path, encoding="utf-8") as f:
|
||||
return orjson.loads(f.read())
|
||||
except (IOError, orjson.JSONDecodeError) as e:
|
||||
except (OSError, orjson.JSONDecodeError) as e:
|
||||
logger.error(f"无法读取或解析Cookie文件 {cookie_file_path}: {e}")
|
||||
return None
|
||||
|
||||
async def _get_cookies_from_adapter(self, stream_id: Optional[str]) -> Optional[Dict[str, str]]:
|
||||
async def _get_cookies_from_adapter(self, stream_id: str | None) -> dict[str, str] | None:
|
||||
"""通过Adapter API获取Cookie"""
|
||||
try:
|
||||
params = {"domain": "user.qzone.qq.com"}
|
||||
@@ -73,7 +73,7 @@ class CookieService:
|
||||
logger.error(f"通过Adapter获取Cookie时发生异常: {e}")
|
||||
return None
|
||||
|
||||
async def _get_cookies_from_http(self) -> Optional[Dict[str, str]]:
|
||||
async def _get_cookies_from_http(self) -> dict[str, str] | None:
|
||||
"""通过备用HTTP端点获取Cookie"""
|
||||
host = self.get_config("cookie.http_fallback_host", "172.20.130.55")
|
||||
port = self.get_config("cookie.http_fallback_port", "9999")
|
||||
@@ -110,7 +110,7 @@ class CookieService:
|
||||
logger.error(f"通过HTTP备用地址 {http_url} 获取Cookie失败: {e}")
|
||||
return None
|
||||
|
||||
async def get_cookies(self, qq_account: str, stream_id: Optional[str]) -> Optional[Dict[str, str]]:
|
||||
async def get_cookies(self, qq_account: str, stream_id: str | None) -> dict[str, str] | None:
|
||||
"""
|
||||
获取Cookie,按以下顺序尝试:
|
||||
1. HTTP备用端点 (更稳定)
|
||||
@@ -140,5 +140,7 @@ class CookieService:
|
||||
self._save_cookies_to_file(qq_account, cookies)
|
||||
return cookies
|
||||
|
||||
logger.error(f"为 {qq_account} 获取Cookie的所有方法均失败。请确保Napcat HTTP服务或Adapter连接至少有一个正常工作,或存在有效的本地Cookie文件。")
|
||||
logger.error(
|
||||
f"为 {qq_account} 获取Cookie的所有方法均失败。请确保Napcat HTTP服务或Adapter连接至少有一个正常工作,或存在有效的本地Cookie文件。"
|
||||
)
|
||||
return None
|
||||
|
||||
@@ -1,12 +1,11 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
图片服务模块
|
||||
负责处理所有与图片相关的任务,特别是AI生成图片。
|
||||
"""
|
||||
|
||||
import base64
|
||||
from collections.abc import Callable
|
||||
from pathlib import Path
|
||||
from typing import Callable
|
||||
|
||||
import aiohttp
|
||||
|
||||
|
||||
@@ -1,14 +1,15 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
服务管理器/定位器
|
||||
这是一个独立的模块,用于注册和获取插件内的全局服务实例,以避免循环导入。
|
||||
"""
|
||||
|
||||
from typing import Dict, Any, Callable
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
from .qzone_service import QZoneService
|
||||
|
||||
# --- 全局服务注册表 ---
|
||||
_services: Dict[str, Any] = {}
|
||||
_services: dict[str, Any] = {}
|
||||
|
||||
|
||||
def register_service(name: str, instance: Any):
|
||||
|
||||
@@ -1,13 +1,13 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
好友动态监控服务
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import traceback
|
||||
from typing import Callable
|
||||
from collections.abc import Callable
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
from .qzone_service import QZoneService
|
||||
|
||||
logger = get_logger("MaiZone.MonitorService")
|
||||
|
||||
@@ -1,32 +1,33 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
QQ空间服务模块
|
||||
封装了所有与QQ空间API的直接交互,是插件的核心业务逻辑层。
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import orjson
|
||||
import base64
|
||||
import os
|
||||
import random
|
||||
import time
|
||||
import base64
|
||||
from collections.abc import Callable
|
||||
from pathlib import Path
|
||||
from typing import Callable, Optional, Dict, Any, List, Tuple
|
||||
from typing import Any
|
||||
|
||||
import aiohttp
|
||||
import bs4
|
||||
import json5
|
||||
from src.common.logger import get_logger
|
||||
from src.plugin_system.apis import config_api, person_api
|
||||
import orjson
|
||||
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from src.chat.utils.chat_message_builder import (
|
||||
build_readable_messages_with_id,
|
||||
get_raw_msg_by_timestamp_with_chat,
|
||||
)
|
||||
from src.common.logger import get_logger
|
||||
from src.plugin_system.apis import config_api, person_api
|
||||
|
||||
from .content_service import ContentService
|
||||
from .image_service import ImageService
|
||||
from .cookie_service import CookieService
|
||||
from .image_service import ImageService
|
||||
from .reply_tracker_service import ReplyTrackerService
|
||||
|
||||
logger = get_logger("MaiZone.QZoneService")
|
||||
@@ -64,7 +65,7 @@ class QZoneService:
|
||||
|
||||
# --- Public Methods (High-Level Business Logic) ---
|
||||
|
||||
async def send_feed(self, topic: str, stream_id: Optional[str]) -> Dict[str, Any]:
|
||||
async def send_feed(self, topic: str, stream_id: str | None) -> dict[str, Any]:
|
||||
"""发送一条说说"""
|
||||
# --- 获取互通组上下文 ---
|
||||
context = await self._get_intercom_context(stream_id) if stream_id else None
|
||||
@@ -92,7 +93,7 @@ class QZoneService:
|
||||
logger.error(f"发布说说时发生异常: {e}", exc_info=True)
|
||||
return {"success": False, "message": f"发布说说异常: {e}"}
|
||||
|
||||
async def send_feed_from_activity(self, activity: str) -> Dict[str, Any]:
|
||||
async def send_feed_from_activity(self, activity: str) -> dict[str, Any]:
|
||||
"""根据日程活动发送一条说说"""
|
||||
story = await self.content_service.generate_story_from_activity(activity)
|
||||
if not story:
|
||||
@@ -118,7 +119,7 @@ class QZoneService:
|
||||
logger.error(f"根据活动发布说说时发生异常: {e}", exc_info=True)
|
||||
return {"success": False, "message": f"发布说说异常: {e}"}
|
||||
|
||||
async def read_and_process_feeds(self, target_name: str, stream_id: Optional[str]) -> Dict[str, Any]:
|
||||
async def read_and_process_feeds(self, target_name: str, stream_id: str | None) -> dict[str, Any]:
|
||||
"""读取并处理指定好友的说说"""
|
||||
target_person_id = await person_api.get_person_id_by_name(target_name)
|
||||
if not target_person_id:
|
||||
@@ -147,7 +148,7 @@ class QZoneService:
|
||||
logger.error(f"读取和处理说说时发生异常: {e}", exc_info=True)
|
||||
return {"success": False, "message": f"处理说说异常: {e}"}
|
||||
|
||||
async def monitor_feeds(self, stream_id: Optional[str] = None):
|
||||
async def monitor_feeds(self, stream_id: str | None = None):
|
||||
"""监控并处理所有好友的动态,包括回复自己说说的评论"""
|
||||
logger.info("开始执行好友动态监控...")
|
||||
qq_account = config_api.get_global_config("bot.qq_account", "")
|
||||
@@ -189,7 +190,7 @@ class QZoneService:
|
||||
|
||||
# --- Internal Helper Methods ---
|
||||
|
||||
async def _get_intercom_context(self, stream_id: str) -> Optional[str]:
|
||||
async def _get_intercom_context(self, stream_id: str) -> str | None:
|
||||
"""
|
||||
根据 stream_id 查找其所属的互通组,并构建该组的聊天上下文。
|
||||
|
||||
@@ -247,7 +248,7 @@ class QZoneService:
|
||||
logger.debug(f"Stream ID '{stream_id}' 未在任何互通组中找到。")
|
||||
return None
|
||||
|
||||
async def _reply_to_own_feed_comments(self, feed: Dict, api_client: Dict):
|
||||
async def _reply_to_own_feed_comments(self, feed: dict, api_client: dict):
|
||||
"""处理对自己说说的评论并进行回复"""
|
||||
qq_account = config_api.get_global_config("bot.qq_account", "")
|
||||
comments = feed.get("comments", [])
|
||||
@@ -290,9 +291,7 @@ class QZoneService:
|
||||
comment_content = comment.get("content", "")
|
||||
|
||||
try:
|
||||
reply_content = await self.content_service.generate_comment_reply(
|
||||
content, comment_content, nickname
|
||||
)
|
||||
reply_content = await self.content_service.generate_comment_reply(content, comment_content, nickname)
|
||||
if reply_content:
|
||||
success = await api_client["reply"](fid, qq_account, nickname, reply_content, comment_tid)
|
||||
if success:
|
||||
@@ -311,7 +310,7 @@ class QZoneService:
|
||||
if comment_key in self.processing_comments:
|
||||
self.processing_comments.remove(comment_key)
|
||||
|
||||
async def _validate_and_cleanup_reply_records(self, fid: str, my_replies: List[Dict]):
|
||||
async def _validate_and_cleanup_reply_records(self, fid: str, my_replies: list[dict]):
|
||||
"""验证并清理已删除的回复记录"""
|
||||
# 获取当前记录中该说说的所有已回复评论ID
|
||||
recorded_replied_comments = self.reply_tracker.get_replied_comments(fid)
|
||||
@@ -335,7 +334,7 @@ class QZoneService:
|
||||
self.reply_tracker.remove_reply_record(fid, comment_tid)
|
||||
logger.debug(f"已清理删除的回复记录: feed_id={fid}, comment_id={comment_tid}")
|
||||
|
||||
async def _process_single_feed(self, feed: Dict, api_client: Dict, target_qq: str, target_name: str):
|
||||
async def _process_single_feed(self, feed: dict, api_client: dict, target_qq: str, target_name: str):
|
||||
"""处理单条说说,决定是否评论和点赞"""
|
||||
content = feed.get("content", "")
|
||||
fid = feed.get("tid", "")
|
||||
@@ -373,7 +372,7 @@ class QZoneService:
|
||||
if random.random() <= self.get_config("read.like_possibility", 1.0):
|
||||
await api_client["like"](target_qq, fid)
|
||||
|
||||
def _load_local_images(self, image_dir: str) -> List[bytes]:
|
||||
def _load_local_images(self, image_dir: str) -> list[bytes]:
|
||||
"""随机加载本地图片(不删除文件)"""
|
||||
images = []
|
||||
if not image_dir or not os.path.exists(image_dir):
|
||||
@@ -434,7 +433,7 @@ class QZoneService:
|
||||
hash_val += (hash_val << 5) + ord(char)
|
||||
return str(hash_val & 2147483647)
|
||||
|
||||
async def _renew_and_load_cookies(self, qq_account: str, stream_id: Optional[str]) -> Optional[Dict[str, str]]:
|
||||
async def _renew_and_load_cookies(self, qq_account: str, stream_id: str | None) -> dict[str, str] | None:
|
||||
cookie_dir = Path(__file__).resolve().parent.parent / "cookies"
|
||||
cookie_dir.mkdir(exist_ok=True)
|
||||
cookie_file_path = cookie_dir / f"cookies-{qq_account}.json"
|
||||
@@ -482,7 +481,7 @@ class QZoneService:
|
||||
logger.error("所有获取Cookie的方式均失败。")
|
||||
return None
|
||||
|
||||
async def _fetch_cookies_http(self, host: str, port: int, napcat_token: str) -> Optional[Dict]:
|
||||
async def _fetch_cookies_http(self, host: str, port: int, napcat_token: str) -> dict | None:
|
||||
"""通过HTTP服务器获取Cookie"""
|
||||
# 从配置中读取主机和端口,如果未提供则使用传入的参数
|
||||
final_host = self.get_config("cookie.http_fallback_host", host)
|
||||
@@ -517,22 +516,24 @@ class QZoneService:
|
||||
|
||||
except aiohttp.ClientError as e:
|
||||
if attempt < max_retries - 1:
|
||||
logger.warning(f"无法连接到Napcat服务(尝试 {attempt + 1}/{max_retries}): {url},错误: {str(e)}")
|
||||
logger.warning(f"无法连接到Napcat服务(尝试 {attempt + 1}/{max_retries}): {url},错误: {e!s}")
|
||||
await asyncio.sleep(retry_delay)
|
||||
retry_delay *= 2
|
||||
continue
|
||||
logger.error(f"无法连接到Napcat服务(最终尝试): {url},错误: {str(e)}")
|
||||
logger.error(f"无法连接到Napcat服务(最终尝试): {url},错误: {e!s}")
|
||||
raise RuntimeError(f"无法连接到Napcat服务: {url}") from e
|
||||
except Exception as e:
|
||||
logger.error(f"获取cookie异常: {str(e)}")
|
||||
logger.error(f"获取cookie异常: {e!s}")
|
||||
raise
|
||||
|
||||
raise RuntimeError(f"无法连接到Napcat服务: 超过最大重试次数({max_retries})")
|
||||
|
||||
async def _get_api_client(self, qq_account: str, stream_id: Optional[str]) -> Optional[Dict]:
|
||||
async def _get_api_client(self, qq_account: str, stream_id: str | None) -> dict | None:
|
||||
cookies = await self.cookie_service.get_cookies(qq_account, stream_id)
|
||||
if not cookies:
|
||||
logger.error("获取API客户端失败:未能获取到Cookie。请检查Napcat连接是否正常,或是否存在有效的本地Cookie文件。")
|
||||
logger.error(
|
||||
"获取API客户端失败:未能获取到Cookie。请检查Napcat连接是否正常,或是否存在有效的本地Cookie文件。"
|
||||
)
|
||||
return None
|
||||
|
||||
p_skey = cookies.get("p_skey") or cookies.get("p_skey".upper())
|
||||
@@ -559,7 +560,7 @@ class QZoneService:
|
||||
response.raise_for_status()
|
||||
return await response.text()
|
||||
|
||||
async def _publish(content: str, images: List[bytes]) -> Tuple[bool, str]:
|
||||
async def _publish(content: str, images: list[bytes]) -> tuple[bool, str]:
|
||||
"""发布说说"""
|
||||
try:
|
||||
post_data = {
|
||||
@@ -660,7 +661,7 @@ class QZoneService:
|
||||
|
||||
return picbo, richval
|
||||
|
||||
async def _upload_image(image_bytes: bytes, index: int) -> Optional[Dict[str, str]]:
|
||||
async def _upload_image(image_bytes: bytes, index: int) -> dict[str, str] | None:
|
||||
"""上传图片到QQ空间(完全按照原版实现)"""
|
||||
try:
|
||||
upload_url = "https://up.qzone.qq.com/cgi-bin/upload/cgi_upload_image"
|
||||
@@ -726,7 +727,8 @@ class QZoneService:
|
||||
return {"pic_bo": picbo, "richval": richval}
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"从上传结果中提取图片参数失败: {e}, 上传结果: {upload_result}", exc_info=True
|
||||
f"从上传结果中提取图片参数失败: {e}, 上传结果: {upload_result}",
|
||||
exc_info=True,
|
||||
)
|
||||
return None
|
||||
else:
|
||||
@@ -744,7 +746,7 @@ class QZoneService:
|
||||
logger.error(f"上传图片 {index + 1} 异常: {e}", exc_info=True)
|
||||
return None
|
||||
|
||||
async def _list_feeds(t_qq: str, num: int) -> List[Dict]:
|
||||
async def _list_feeds(t_qq: str, num: int) -> list[dict]:
|
||||
"""获取指定用户说说列表 (统一接口)"""
|
||||
try:
|
||||
# 统一使用 format=json 获取完整评论
|
||||
@@ -764,7 +766,9 @@ class QZoneService:
|
||||
json_data = orjson.loads(res_text)
|
||||
|
||||
if json_data.get("code") != 0:
|
||||
logger.warning(f"获取说说列表API返回错误: code={json_data.get('code')}, message={json_data.get('message')}")
|
||||
logger.warning(
|
||||
f"获取说说列表API返回错误: code={json_data.get('code')}, message={json_data.get('message')}"
|
||||
)
|
||||
return []
|
||||
|
||||
feeds_list = []
|
||||
@@ -797,7 +801,7 @@ class QZoneService:
|
||||
for c in commentlist:
|
||||
if not isinstance(c, dict):
|
||||
continue
|
||||
|
||||
|
||||
# 添加主评论
|
||||
comments.append(
|
||||
{
|
||||
@@ -822,7 +826,7 @@ class QZoneService:
|
||||
"parent_tid": c.get("tid"), # 父ID是主评论的ID
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
feeds_list.append(
|
||||
{
|
||||
"tid": msg.get("tid", ""),
|
||||
@@ -917,7 +921,7 @@ class QZoneService:
|
||||
logger.error(f"回复评论异常: {e}", exc_info=True)
|
||||
return False
|
||||
|
||||
async def _monitor_list_feeds(num: int) -> List[Dict]:
|
||||
async def _monitor_list_feeds(num: int) -> list[dict]:
|
||||
"""监控好友动态"""
|
||||
try:
|
||||
params = {
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
评论回复跟踪服务
|
||||
负责记录和管理已回复过的评论ID,避免重复回复
|
||||
@@ -7,7 +6,8 @@
|
||||
import json
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Set, Dict, Any, Union
|
||||
from typing import Any
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("MaiZone.ReplyTrackerService")
|
||||
@@ -27,7 +27,7 @@ class ReplyTrackerService:
|
||||
|
||||
# 内存中的已回复评论记录
|
||||
# 格式: {feed_id: {comment_id: timestamp, ...}, ...}
|
||||
self.replied_comments: Dict[str, Dict[str, float]] = {}
|
||||
self.replied_comments: dict[str, dict[str, float]] = {}
|
||||
|
||||
# 数据清理配置
|
||||
self.max_record_days = 30 # 保留30天的记录
|
||||
@@ -64,7 +64,7 @@ class ReplyTrackerService:
|
||||
try:
|
||||
if self.reply_record_file.exists():
|
||||
try:
|
||||
with open(self.reply_record_file, "r", encoding="utf-8") as f:
|
||||
with open(self.reply_record_file, encoding="utf-8") as f:
|
||||
file_content = f.read().strip()
|
||||
if not file_content: # 文件为空
|
||||
logger.warning("回复记录文件为空,将创建新的记录")
|
||||
@@ -173,7 +173,7 @@ class ReplyTrackerService:
|
||||
if total_removed > 0:
|
||||
logger.info(f"清理了 {total_removed} 条超过{self.max_record_days}天的过期回复记录")
|
||||
|
||||
def has_replied(self, feed_id: str, comment_id: Union[str, int]) -> bool:
|
||||
def has_replied(self, feed_id: str, comment_id: str | int) -> bool:
|
||||
"""
|
||||
检查是否已经回复过指定的评论
|
||||
|
||||
@@ -190,7 +190,7 @@ class ReplyTrackerService:
|
||||
comment_id_str = str(comment_id)
|
||||
return feed_id in self.replied_comments and comment_id_str in self.replied_comments[feed_id]
|
||||
|
||||
def mark_as_replied(self, feed_id: str, comment_id: Union[str, int]):
|
||||
def mark_as_replied(self, feed_id: str, comment_id: str | int):
|
||||
"""
|
||||
标记指定评论为已回复
|
||||
|
||||
@@ -219,7 +219,7 @@ class ReplyTrackerService:
|
||||
else:
|
||||
logger.error(f"标记评论时数据验证失败: feed_id={feed_id}, comment_id={comment_id}")
|
||||
|
||||
def get_replied_comments(self, feed_id: str) -> Set[str]:
|
||||
def get_replied_comments(self, feed_id: str) -> set[str]:
|
||||
"""
|
||||
获取指定说说下所有已回复的评论ID
|
||||
|
||||
@@ -234,7 +234,7 @@ class ReplyTrackerService:
|
||||
return {str(comment_id) for comment_id in self.replied_comments[feed_id].keys()}
|
||||
return set()
|
||||
|
||||
def get_stats(self) -> Dict[str, Any]:
|
||||
def get_stats(self) -> dict[str, Any]:
|
||||
"""
|
||||
获取回复记录统计信息
|
||||
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
定时任务服务
|
||||
根据日程表定时发送说说。
|
||||
@@ -8,13 +7,14 @@ import asyncio
|
||||
import datetime
|
||||
import random
|
||||
import traceback
|
||||
from typing import Callable
|
||||
from collections.abc import Callable
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
from src.common.database.sqlalchemy_database_api import get_db_session
|
||||
from src.common.database.sqlalchemy_models import MaiZoneScheduleStatus
|
||||
from src.common.logger import get_logger
|
||||
from src.schedule.schedule_manager import schedule_manager
|
||||
from src.common.database.sqlalchemy_database_api import get_db_session
|
||||
from sqlalchemy import select
|
||||
from src.common.database.sqlalchemy_models import MaiZoneScheduleStatus
|
||||
|
||||
from .qzone_service import QZoneService
|
||||
|
||||
|
||||
@@ -1,15 +1,15 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
历史记录工具模块
|
||||
提供用于获取QQ空间发送历史的功能。
|
||||
"""
|
||||
|
||||
import orjson
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Dict, Any, Optional, List
|
||||
from typing import Any
|
||||
|
||||
import orjson
|
||||
import requests
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("MaiZone.HistoryUtils")
|
||||
@@ -26,11 +26,11 @@ class _CookieManager:
|
||||
return str(cookie_dir / f"cookies-{uin}.json")
|
||||
|
||||
@staticmethod
|
||||
def load_cookies(qq_account: str) -> Optional[Dict[str, str]]:
|
||||
def load_cookies(qq_account: str) -> dict[str, str] | None:
|
||||
cookie_file = _CookieManager.get_cookie_file_path(qq_account)
|
||||
if os.path.exists(cookie_file):
|
||||
try:
|
||||
with open(cookie_file, "r", encoding="utf-8") as f:
|
||||
with open(cookie_file, encoding="utf-8") as f:
|
||||
return orjson.loads(f.read())
|
||||
except Exception as e:
|
||||
logger.error(f"加载Cookie文件失败: {e}")
|
||||
@@ -42,7 +42,7 @@ class _SimpleQZoneAPI:
|
||||
|
||||
LIST_URL = "https://user.qzone.qq.com/proxy/domain/taotao.qq.com/cgi-bin/emotion_cgi_msglist_v6"
|
||||
|
||||
def __init__(self, cookies_dict: Optional[Dict[str, str]] = None):
|
||||
def __init__(self, cookies_dict: dict[str, str] | None = None):
|
||||
self.cookies = cookies_dict or {}
|
||||
self.gtk2 = ""
|
||||
p_skey = self.cookies.get("p_skey") or self.cookies.get("p_skey".upper())
|
||||
@@ -55,7 +55,7 @@ class _SimpleQZoneAPI:
|
||||
hash_val += (hash_val << 5) + ord(char)
|
||||
return str(hash_val & 2147483647)
|
||||
|
||||
def get_feed_list(self, target_qq: str, num: int) -> List[Dict[str, Any]]:
|
||||
def get_feed_list(self, target_qq: str, num: int) -> list[dict[str, Any]]:
|
||||
try:
|
||||
params = {
|
||||
"g_tk": self.gtk2,
|
||||
|
||||
@@ -835,7 +835,7 @@ class MessageHandler:
|
||||
if music:
|
||||
tag = music.get("tag", "未知来源")
|
||||
logger.debug(f"检测到【{tag}】音乐分享消息 (music view),开始提取信息")
|
||||
|
||||
|
||||
title = music.get("title", "未知歌曲")
|
||||
desc = music.get("desc", "未知艺术家")
|
||||
jump_url = music.get("jumpUrl", "")
|
||||
@@ -853,7 +853,7 @@ class MessageHandler:
|
||||
artist = parts[1]
|
||||
else:
|
||||
artist = desc
|
||||
|
||||
|
||||
formatted_content = (
|
||||
f"这是一张来自【{tag}】的音乐分享卡片:\n"
|
||||
f"歌曲: {song_title}\n"
|
||||
@@ -870,12 +870,12 @@ class MessageHandler:
|
||||
if news and "网易云音乐" in news.get("tag", ""):
|
||||
tag = news.get("tag")
|
||||
logger.debug(f"检测到【{tag}】音乐分享消息 (news view),开始提取信息")
|
||||
|
||||
|
||||
title = news.get("title", "未知歌曲")
|
||||
desc = news.get("desc", "未知艺术家")
|
||||
jump_url = news.get("jumpUrl", "")
|
||||
preview_url = news.get("preview", "")
|
||||
|
||||
|
||||
formatted_content = (
|
||||
f"这是一张来自【{tag}】的音乐分享卡片:\n"
|
||||
f"标题: {title}\n"
|
||||
|
||||
@@ -3,7 +3,6 @@ import time
|
||||
import random
|
||||
import websockets as Server
|
||||
import uuid
|
||||
import asyncio
|
||||
from maim_message import (
|
||||
UserInfo,
|
||||
GroupInfo,
|
||||
@@ -96,7 +95,9 @@ class SendHandler:
|
||||
logger.error("无法识别的消息类型")
|
||||
return
|
||||
logger.info("尝试发送到napcat")
|
||||
logger.debug(f"准备发送到napcat的消息体: action='{action}', {id_name}='{target_id}', message='{processed_message}'")
|
||||
logger.debug(
|
||||
f"准备发送到napcat的消息体: action='{action}', {id_name}='{target_id}', message='{processed_message}'"
|
||||
)
|
||||
response = await self.send_message_to_napcat(
|
||||
action,
|
||||
{
|
||||
|
||||
@@ -6,7 +6,6 @@ import urllib3
|
||||
import ssl
|
||||
import io
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from asyncio import Lock
|
||||
|
||||
@@ -75,7 +74,7 @@ async def get_group_info(websocket: Server.ServerConnection, group_id: int) -> d
|
||||
except Exception as e:
|
||||
logger.error(f"获取群信息失败: {e}")
|
||||
return None
|
||||
|
||||
|
||||
data = socket_response.get("data")
|
||||
if data:
|
||||
await set_to_cache(cache_key, data)
|
||||
@@ -114,7 +113,7 @@ async def get_member_info(websocket: Server.ServerConnection, group_id: int, use
|
||||
cached_data = await get_from_cache(cache_key)
|
||||
if cached_data:
|
||||
return cached_data
|
||||
|
||||
|
||||
logger.debug(f"获取群成员信息中 (无缓存): group={group_id}, user={user_id}")
|
||||
request_uuid = str(uuid.uuid4())
|
||||
payload = json.dumps(
|
||||
@@ -133,7 +132,7 @@ async def get_member_info(websocket: Server.ServerConnection, group_id: int, use
|
||||
except Exception as e:
|
||||
logger.error(f"获取成员信息失败: {e}")
|
||||
return None
|
||||
|
||||
|
||||
data = socket_response.get("data")
|
||||
if data:
|
||||
await set_to_cache(cache_key, data)
|
||||
@@ -203,7 +202,7 @@ async def get_self_info(websocket: Server.ServerConnection) -> dict | None:
|
||||
except Exception as e:
|
||||
logger.error(f"获取自身信息失败: {e}")
|
||||
return None
|
||||
|
||||
|
||||
data = response.get("data")
|
||||
if data:
|
||||
await set_to_cache(cache_key, data)
|
||||
|
||||
@@ -6,19 +6,17 @@
|
||||
"""
|
||||
|
||||
import re
|
||||
from typing import List, Optional, Tuple, Type
|
||||
|
||||
from src.plugin_system.apis.logging_api import get_logger
|
||||
from src.plugin_system.apis.permission_api import permission_api
|
||||
from src.plugin_system.apis.plugin_register_api import register_plugin
|
||||
from src.plugin_system.base.base_plugin import BasePlugin
|
||||
from src.plugin_system.base.plus_command import PlusCommand
|
||||
from src.plugin_system.base.command_args import CommandArgs
|
||||
from src.plugin_system.apis.permission_api import permission_api
|
||||
from src.plugin_system.apis.logging_api import get_logger
|
||||
from src.plugin_system.base.component_types import PlusCommandInfo, ChatType
|
||||
from src.plugin_system.base.component_types import ChatType, PlusCommandInfo
|
||||
from src.plugin_system.base.config_types import ConfigField
|
||||
from src.plugin_system.base.plus_command import PlusCommand
|
||||
from src.plugin_system.utils.permission_decorators import require_permission
|
||||
|
||||
|
||||
logger = get_logger("Permission")
|
||||
|
||||
|
||||
@@ -44,7 +42,7 @@ class PermissionCommand(PlusCommand):
|
||||
"plugin.permission.view", "权限查看:可以查看权限节点和用户权限信息", "permission_manager", True
|
||||
)
|
||||
|
||||
async def execute(self, args: CommandArgs) -> Tuple[bool, Optional[str], bool]:
|
||||
async def execute(self, args: CommandArgs) -> tuple[bool, str | None, bool]:
|
||||
"""执行权限管理命令"""
|
||||
if args.is_empty:
|
||||
await self._show_help()
|
||||
@@ -114,7 +112,7 @@ class PermissionCommand(PlusCommand):
|
||||
await self.send_text(help_text)
|
||||
|
||||
@staticmethod
|
||||
def _parse_user_mention(mention: str) -> Optional[str]:
|
||||
def _parse_user_mention(mention: str) -> str | None:
|
||||
"""解析用户提及,提取QQ号
|
||||
|
||||
支持的格式:
|
||||
@@ -134,7 +132,7 @@ class PermissionCommand(PlusCommand):
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def parse_user_from_args(args: CommandArgs, index: int = 0) -> Optional[str]:
|
||||
def parse_user_from_args(args: CommandArgs, index: int = 0) -> str | None:
|
||||
"""从CommandArgs中解析用户ID
|
||||
|
||||
Args:
|
||||
@@ -166,7 +164,7 @@ class PermissionCommand(PlusCommand):
|
||||
return None
|
||||
|
||||
@require_permission("plugin.permission.manage", "❌ 你没有权限管理的权限")
|
||||
async def _grant_permission(self, chat_stream, args: List[str]):
|
||||
async def _grant_permission(self, chat_stream, args: list[str]):
|
||||
"""授权用户权限"""
|
||||
if len(args) < 2:
|
||||
await self.send_text("❌ 用法: /permission grant <@用户|QQ号> <权限节点>")
|
||||
@@ -189,7 +187,7 @@ class PermissionCommand(PlusCommand):
|
||||
await self.send_text("❌ 授权失败,请检查权限节点是否存在")
|
||||
|
||||
@require_permission("plugin.permission.manage", "❌ 你没有权限管理的权限")
|
||||
async def _revoke_permission(self, chat_stream, args: List[str]):
|
||||
async def _revoke_permission(self, chat_stream, args: list[str]):
|
||||
"""撤销用户权限"""
|
||||
if len(args) < 2:
|
||||
await self.send_text("❌ 用法: /permission revoke <@用户|QQ号> <权限节点>")
|
||||
@@ -212,7 +210,7 @@ class PermissionCommand(PlusCommand):
|
||||
await self.send_text("❌ 撤销失败,请检查权限节点是否存在")
|
||||
|
||||
@require_permission("plugin.permission.view", "❌ 你没有查看权限的权限")
|
||||
async def _list_permissions(self, chat_stream, args: List[str]):
|
||||
async def _list_permissions(self, chat_stream, args: list[str]):
|
||||
"""列出用户权限"""
|
||||
target_user_id = None
|
||||
|
||||
@@ -244,7 +242,7 @@ class PermissionCommand(PlusCommand):
|
||||
await self.send_text(response)
|
||||
|
||||
@require_permission("plugin.permission.view", "❌ 你没有查看权限的权限")
|
||||
async def _check_permission(self, chat_stream, args: List[str]):
|
||||
async def _check_permission(self, chat_stream, args: list[str]):
|
||||
"""检查用户权限"""
|
||||
if len(args) < 2:
|
||||
await self.send_text("❌ 用法: /permission check <@用户|QQ号> <权限节点>")
|
||||
@@ -273,7 +271,7 @@ class PermissionCommand(PlusCommand):
|
||||
await self.send_text(response)
|
||||
|
||||
@require_permission("plugin.permission.view", "❌ 你没有查看权限的权限")
|
||||
async def _list_nodes(self, chat_stream, args: List[str]):
|
||||
async def _list_nodes(self, chat_stream, args: list[str]):
|
||||
"""列出权限节点"""
|
||||
plugin_name = args[0] if args else None
|
||||
|
||||
@@ -388,6 +386,6 @@ class PermissionManagerPlugin(BasePlugin):
|
||||
}
|
||||
}
|
||||
|
||||
def get_plugin_components(self) -> List[Tuple[PlusCommandInfo, Type[PlusCommand]]]:
|
||||
def get_plugin_components(self) -> list[tuple[PlusCommandInfo, type[PlusCommand]]]:
|
||||
"""返回插件的PlusCommand组件"""
|
||||
return [(PermissionCommand.get_plus_command_info(), PermissionCommand)]
|
||||
|
||||
@@ -1,19 +1,18 @@
|
||||
import asyncio
|
||||
|
||||
from typing import List, Tuple, Type
|
||||
from src.plugin_system import (
|
||||
BasePlugin,
|
||||
ConfigField,
|
||||
register_plugin,
|
||||
plugin_manage_api,
|
||||
component_manage_api,
|
||||
ComponentInfo,
|
||||
ComponentType,
|
||||
ConfigField,
|
||||
component_manage_api,
|
||||
plugin_manage_api,
|
||||
register_plugin,
|
||||
)
|
||||
from src.plugin_system.base.plus_command import PlusCommand
|
||||
from src.plugin_system.base.command_args import CommandArgs
|
||||
from src.plugin_system.base.component_types import PlusCommandInfo, ChatType
|
||||
from src.plugin_system.apis.permission_api import permission_api
|
||||
from src.plugin_system.base.command_args import CommandArgs
|
||||
from src.plugin_system.base.component_types import ChatType, PlusCommandInfo
|
||||
from src.plugin_system.base.plus_command import PlusCommand
|
||||
from src.plugin_system.utils.permission_decorators import require_permission
|
||||
|
||||
|
||||
@@ -31,7 +30,7 @@ class ManagementCommand(PlusCommand):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
@require_permission("plugin.management.admin", "❌ 你没有插件管理的权限")
|
||||
async def execute(self, args: CommandArgs) -> Tuple[bool, str, bool]:
|
||||
async def execute(self, args: CommandArgs) -> tuple[bool, str, bool]:
|
||||
"""执行插件管理命令"""
|
||||
if args.is_empty:
|
||||
await self._show_help("all")
|
||||
@@ -51,7 +50,7 @@ class ManagementCommand(PlusCommand):
|
||||
await self.send_text(f"❌ 未知的子命令: {subcommand}\n使用 /pm help 查看帮助")
|
||||
return True, "未知子命令", True
|
||||
|
||||
async def _handle_plugin_commands(self, args: List[str]) -> Tuple[bool, str, bool]:
|
||||
async def _handle_plugin_commands(self, args: list[str]) -> tuple[bool, str, bool]:
|
||||
"""处理插件相关命令"""
|
||||
if not args:
|
||||
await self._show_help("plugin")
|
||||
@@ -83,7 +82,7 @@ class ManagementCommand(PlusCommand):
|
||||
|
||||
return True, "插件命令执行完成", True
|
||||
|
||||
async def _handle_component_commands(self, args: List[str]) -> Tuple[bool, str, bool]:
|
||||
async def _handle_component_commands(self, args: list[str]) -> tuple[bool, str, bool]:
|
||||
"""处理组件相关命令"""
|
||||
if not args:
|
||||
await self._show_help("component")
|
||||
@@ -258,9 +257,8 @@ class ManagementCommand(PlusCommand):
|
||||
else:
|
||||
await self.send_text(f"❌ 插件强制重载失败: `{plugin_name}`")
|
||||
except Exception as e:
|
||||
await self.send_text(f"❌ 强制重载过程中发生错误: {str(e)}")
|
||||
await self.send_text(f"❌ 强制重载过程中发生错误: {e!s}")
|
||||
|
||||
|
||||
async def _add_dir(self, dir_path: str):
|
||||
"""添加插件目录"""
|
||||
await self.send_text(f"📁 正在添加插件目录: `{dir_path}`")
|
||||
@@ -272,17 +270,17 @@ class ManagementCommand(PlusCommand):
|
||||
await self.send_text(f"❌ 插件目录添加失败: `{dir_path}`")
|
||||
|
||||
@staticmethod
|
||||
def _fetch_all_registered_components() -> List[ComponentInfo]:
|
||||
def _fetch_all_registered_components() -> list[ComponentInfo]:
|
||||
all_plugin_info = component_manage_api.get_all_plugin_info()
|
||||
if not all_plugin_info:
|
||||
return []
|
||||
|
||||
components_info: List[ComponentInfo] = []
|
||||
components_info: list[ComponentInfo] = []
|
||||
for plugin_info in all_plugin_info.values():
|
||||
components_info.extend(plugin_info.components)
|
||||
return components_info
|
||||
|
||||
def _fetch_locally_disabled_components(self) -> List[str]:
|
||||
def _fetch_locally_disabled_components(self) -> list[str]:
|
||||
"""获取本地禁用的组件列表"""
|
||||
stream_id = self.message.chat_stream.stream_id
|
||||
locally_disabled_components_actions = component_manage_api.get_locally_disabled_components(
|
||||
@@ -501,16 +499,16 @@ class PluginManagementPlugin(BasePlugin):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
# 注册权限节点
|
||||
|
||||
|
||||
async def on_plugin_loaded(self):
|
||||
await permission_api.register_permission_node(
|
||||
"plugin.management.admin",
|
||||
"插件管理:可以管理插件和组件的加载、卸载、启用、禁用等操作",
|
||||
"plugin_management",
|
||||
False,
|
||||
"plugin.management.admin",
|
||||
"插件管理:可以管理插件和组件的加载、卸载、启用、禁用等操作",
|
||||
"plugin_management",
|
||||
False,
|
||||
)
|
||||
|
||||
def get_plugin_components(self) -> List[Tuple[PlusCommandInfo, Type[PlusCommand]]]:
|
||||
def get_plugin_components(self) -> list[tuple[PlusCommandInfo, type[PlusCommand]]]:
|
||||
"""返回插件的PlusCommand组件"""
|
||||
components = []
|
||||
if self.get_config("plugin.enabled", True):
|
||||
|
||||
@@ -1,27 +1,22 @@
|
||||
from typing import List, Tuple, Union, Type, Optional
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.config.official_configs import AffinityFlowConfig
|
||||
from src.plugin_system.base.base_plugin import BasePlugin
|
||||
from src.plugin_system import (
|
||||
BaseEventHandler,
|
||||
BasePlugin,
|
||||
ConfigField,
|
||||
register_plugin,
|
||||
plugin_manage_api,
|
||||
component_manage_api,
|
||||
ComponentInfo,
|
||||
ComponentType,
|
||||
EventHandlerInfo,
|
||||
EventType,
|
||||
BaseEventHandler,
|
||||
register_plugin,
|
||||
)
|
||||
from src.plugin_system.base.base_plugin import BasePlugin
|
||||
|
||||
from .proacive_thinker_event import ProactiveThinkerEventHandler
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@register_plugin
|
||||
class ProactiveThinkerPlugin(BasePlugin):
|
||||
"""一个主动思考的插件,但现在还只是个空壳子"""
|
||||
|
||||
plugin_name: str = "proactive_thinker"
|
||||
enable_plugin: bool = False
|
||||
dependencies: list[str] = []
|
||||
@@ -33,13 +28,13 @@ class ProactiveThinkerPlugin(BasePlugin):
|
||||
"config_version": ConfigField(type=str, default="1.1.0", description="配置文件版本"),
|
||||
},
|
||||
}
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def get_plugin_components(self) -> List[Tuple[EventHandlerInfo, Type[BaseEventHandler]]]:
|
||||
def get_plugin_components(self) -> list[tuple[EventHandlerInfo, type[BaseEventHandler]]]:
|
||||
"""返回插件的EventHandler组件"""
|
||||
components: List[Tuple[EventHandlerInfo, Type[BaseEventHandler]]] = [
|
||||
components: list[tuple[EventHandlerInfo, type[BaseEventHandler]]] = [
|
||||
(ProactiveThinkerEventHandler.get_handler_info(), ProactiveThinkerEventHandler)
|
||||
]
|
||||
return components
|
||||
|
||||
|
||||
@@ -2,17 +2,17 @@ import asyncio
|
||||
import random
|
||||
import time
|
||||
from datetime import datetime
|
||||
from typing import List, Union, Type, Optional
|
||||
|
||||
from maim_message import UserInfo
|
||||
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.manager.async_task_manager import async_task_manager, AsyncTask
|
||||
from src.plugin_system import EventType, BaseEventHandler
|
||||
from src.manager.async_task_manager import AsyncTask, async_task_manager
|
||||
from src.plugin_system import BaseEventHandler, EventType
|
||||
from src.plugin_system.apis import chat_api, person_api
|
||||
from src.plugin_system.base.base_event import HandlerResult
|
||||
|
||||
from .proactive_thinker_executor import ProactiveThinkerExecutor
|
||||
|
||||
logger = get_logger(__name__)
|
||||
@@ -69,7 +69,7 @@ class ColdStartTask(AsyncTask):
|
||||
|
||||
# 创建 UserInfo 对象,这是创建聊天流的必要信息
|
||||
user_info = UserInfo(platform=platform, user_id=str(user_id), user_nickname=user_nickname)
|
||||
|
||||
|
||||
# 【关键步骤】主动创建聊天流。
|
||||
# 创建后,该用户就进入了机器人的“好友列表”,后续将由 ProactiveThinkingTask 接管
|
||||
stream = await self.chat_manager.get_or_create_stream(platform, user_info)
|
||||
@@ -175,10 +175,12 @@ class ProactiveThinkingTask(AsyncTask):
|
||||
# 2. 【核心逻辑】检查聊天冷却时间是否足够长
|
||||
time_since_last_active = time.time() - stream.last_active_time
|
||||
if time_since_last_active > next_interval:
|
||||
logger.info(f"【日常唤醒】聊天流 {stream.stream_id} 已冷却 {time_since_last_active:.2f} 秒,触发主动对话。")
|
||||
|
||||
logger.info(
|
||||
f"【日常唤醒】聊天流 {stream.stream_id} 已冷却 {time_since_last_active:.2f} 秒,触发主动对话。"
|
||||
)
|
||||
|
||||
await self.executor.execute(stream_id=stream.stream_id, start_mode="wake_up")
|
||||
|
||||
|
||||
# 【关键步骤】在触发后,立刻更新活跃时间并保存。
|
||||
# 这可以防止在同一个检查周期内,对同一个目标因为意外的延迟而发送多条消息。
|
||||
stream.update_active_time()
|
||||
@@ -197,7 +199,7 @@ class ProactiveThinkerEventHandler(BaseEventHandler):
|
||||
|
||||
handler_name: str = "proactive_thinker_on_start"
|
||||
handler_description: str = "主动思考插件的启动事件处理器"
|
||||
init_subscribe: List[Union[EventType, str]] = [EventType.ON_START]
|
||||
init_subscribe: list[EventType | str] = [EventType.ON_START]
|
||||
|
||||
async def execute(self, kwargs: dict | None) -> "HandlerResult":
|
||||
"""在机器人启动时执行,根据配置决定是否启动后台任务。"""
|
||||
|
||||
@@ -1,11 +1,21 @@
|
||||
import orjson
|
||||
from typing import Optional, Dict, Any
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
import orjson
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.plugin_system.apis import chat_api, person_api, schedule_api, send_api, llm_api, message_api, generator_api, database_api
|
||||
from src.config.config import global_config, model_config
|
||||
from src.person_info.person_info import get_person_info_manager
|
||||
from src.plugin_system.apis import (
|
||||
chat_api,
|
||||
database_api,
|
||||
generator_api,
|
||||
llm_api,
|
||||
message_api,
|
||||
person_api,
|
||||
schedule_api,
|
||||
send_api,
|
||||
)
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -39,17 +49,16 @@ class ProactiveThinkerExecutor:
|
||||
# 2. 决策阶段
|
||||
decision_result = await self._make_decision(context, start_mode)
|
||||
|
||||
|
||||
if not decision_result or not decision_result.get("should_reply"):
|
||||
reason = decision_result.get("reason", "未提供") if decision_result else "决策过程返回None"
|
||||
logger.info(f"决策结果为:不回复。原因: {reason}")
|
||||
await database_api.store_action_info(
|
||||
chat_stream=self._get_stream_from_id(stream_id),
|
||||
action_name="proactive_decision",
|
||||
action_prompt_display=f"主动思考决定不回复,原因: {reason}",
|
||||
action_done = True,
|
||||
action_data=decision_result
|
||||
)
|
||||
chat_stream=self._get_stream_from_id(stream_id),
|
||||
action_name="proactive_decision",
|
||||
action_prompt_display=f"主动思考决定不回复,原因: {reason}",
|
||||
action_done=True,
|
||||
action_data=decision_result,
|
||||
)
|
||||
return
|
||||
|
||||
# 3. 规划与执行阶段
|
||||
@@ -59,15 +68,17 @@ class ProactiveThinkerExecutor:
|
||||
chat_stream=self._get_stream_from_id(stream_id),
|
||||
action_name="proactive_decision",
|
||||
action_prompt_display=f"主动思考决定回复,原因: {reason},话题:{topic}",
|
||||
action_done = True,
|
||||
action_data=decision_result
|
||||
action_done=True,
|
||||
action_data=decision_result,
|
||||
)
|
||||
logger.info(f"决策结果为:回复。话题: {topic}")
|
||||
|
||||
|
||||
plan_prompt = self._build_plan_prompt(context, start_mode, topic, reason)
|
||||
|
||||
is_success, response, _, _ = await llm_api.generate_with_model(prompt=plan_prompt, model_config=model_config.model_task_config.utils)
|
||||
|
||||
|
||||
is_success, response, _, _ = await llm_api.generate_with_model(
|
||||
prompt=plan_prompt, model_config=model_config.model_task_config.utils
|
||||
)
|
||||
|
||||
if is_success and response:
|
||||
stream = self._get_stream_from_id(stream_id)
|
||||
if stream:
|
||||
@@ -91,7 +102,7 @@ class ProactiveThinkerExecutor:
|
||||
logger.error(f"解析 stream_id ({stream_id}) 或获取 stream 失败: {e}")
|
||||
return None
|
||||
|
||||
async def _gather_context(self, stream_id: str) -> Optional[Dict[str, Any]]:
|
||||
async def _gather_context(self, stream_id: str) -> dict[str, Any] | None:
|
||||
"""
|
||||
收集构建提示词所需的所有上下文信息
|
||||
"""
|
||||
@@ -104,33 +115,41 @@ class ProactiveThinkerExecutor:
|
||||
if not user_info or not user_info.platform or not user_info.user_id:
|
||||
logger.warning(f"Stream {stream_id} 的 user_info 不完整")
|
||||
return None
|
||||
|
||||
|
||||
person_id = person_api.get_person_id(user_info.platform, int(user_info.user_id))
|
||||
person_info_manager = get_person_info_manager()
|
||||
|
||||
# 获取日程
|
||||
schedules = await schedule_api.ScheduleAPI.get_today_schedule()
|
||||
schedule_context = "\n".join([f"- {s['title']} ({s['start_time']}-{s['end_time']})" for s in schedules]) if schedules else "今天没有日程安排。"
|
||||
schedule_context = (
|
||||
"\n".join([f"- {s['title']} ({s['start_time']}-{s['end_time']})" for s in schedules])
|
||||
if schedules
|
||||
else "今天没有日程安排。"
|
||||
)
|
||||
|
||||
# 获取关系信息
|
||||
short_impression = await person_info_manager.get_value(person_id, "short_impression") or "无"
|
||||
impression = await person_info_manager.get_value(person_id, "impression") or "无"
|
||||
attitude = await person_info_manager.get_value(person_id, "attitude") or 50
|
||||
|
||||
|
||||
# 获取最近聊天记录
|
||||
recent_messages = await message_api.get_recent_messages(stream_id, limit=10)
|
||||
recent_chat_history = await message_api.build_readable_messages_to_str(recent_messages) if recent_messages else "无"
|
||||
|
||||
recent_chat_history = (
|
||||
await message_api.build_readable_messages_to_str(recent_messages) if recent_messages else "无"
|
||||
)
|
||||
|
||||
# 获取最近的动作历史
|
||||
action_history = await database_api.db_query(
|
||||
database_api.MODEL_MAPPING["ActionRecords"],
|
||||
filters={"chat_id": stream_id, "action_name": "proactive_decision"},
|
||||
limit=3,
|
||||
order_by=["-time"]
|
||||
order_by=["-time"],
|
||||
)
|
||||
action_history_context = "无"
|
||||
if isinstance(action_history, list):
|
||||
action_history_context = "\n".join([f"- {a['action_data']}" for a in action_history if isinstance(a, dict)]) or "无"
|
||||
action_history_context = (
|
||||
"\n".join([f"- {a['action_data']}" for a in action_history if isinstance(a, dict)]) or "无"
|
||||
)
|
||||
|
||||
return {
|
||||
"person_id": person_id,
|
||||
@@ -138,47 +157,43 @@ class ProactiveThinkerExecutor:
|
||||
"schedule_context": schedule_context,
|
||||
"recent_chat_history": recent_chat_history,
|
||||
"action_history_context": action_history_context,
|
||||
"relationship": {
|
||||
"short_impression": short_impression,
|
||||
"impression": impression,
|
||||
"attitude": attitude
|
||||
},
|
||||
"relationship": {"short_impression": short_impression, "impression": impression, "attitude": attitude},
|
||||
"persona": {
|
||||
"core": global_config.personality.personality_core,
|
||||
"side": global_config.personality.personality_side,
|
||||
"identity": global_config.personality.identity,
|
||||
},
|
||||
"current_time": datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
"current_time": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
|
||||
}
|
||||
|
||||
async def _make_decision(self, context: Dict[str, Any], start_mode: str) -> Optional[Dict[str, Any]]:
|
||||
async def _make_decision(self, context: dict[str, Any], start_mode: str) -> dict[str, Any] | None:
|
||||
"""
|
||||
决策模块:判断是否应该主动发起对话,以及聊什么话题
|
||||
"""
|
||||
persona = context['persona']
|
||||
user_info = context['user_info']
|
||||
relationship = context['relationship']
|
||||
persona = context["persona"]
|
||||
user_info = context["user_info"]
|
||||
relationship = context["relationship"]
|
||||
|
||||
prompt = f"""
|
||||
# 角色
|
||||
你的名字是{global_config.bot.nickname},你的人设如下:
|
||||
- 核心人设: {persona['core']}
|
||||
- 侧面人设: {persona['side']}
|
||||
- 身份: {persona['identity']}
|
||||
- 核心人设: {persona["core"]}
|
||||
- 侧面人设: {persona["side"]}
|
||||
- 身份: {persona["identity"]}
|
||||
|
||||
# 任务
|
||||
现在是 {context['current_time']},你需要根据当前的情境,决定是否要主动向用户 '{user_info.user_nickname}' 发起对话。
|
||||
现在是 {context["current_time"]},你需要根据当前的情境,决定是否要主动向用户 '{user_info.user_nickname}' 发起对话。
|
||||
|
||||
# 情境分析
|
||||
1. **启动模式**: {start_mode} ({'初次见面/很久未见' if start_mode == 'cold_start' else '日常唤醒'})
|
||||
1. **启动模式**: {start_mode} ({"初次见面/很久未见" if start_mode == "cold_start" else "日常唤醒"})
|
||||
2. **你的日程**:
|
||||
{context['schedule_context']}
|
||||
{context["schedule_context"]}
|
||||
3. **你和Ta的关系**:
|
||||
- 简短印象: {relationship['short_impression']}
|
||||
- 详细印象: {relationship['impression']}
|
||||
- 好感度: {relationship['attitude']}/100
|
||||
- 简短印象: {relationship["short_impression"]}
|
||||
- 详细印象: {relationship["impression"]}
|
||||
- 好感度: {relationship["attitude"]}/100
|
||||
4. **最近的聊天摘要**:
|
||||
{context['recent_chat_history']}
|
||||
{context["recent_chat_history"]}
|
||||
|
||||
# 决策指令
|
||||
请综合以上所有信息,做出决策。你的决策需要以JSON格式输出,包含以下字段:
|
||||
@@ -204,9 +219,11 @@ class ProactiveThinkerExecutor:
|
||||
|
||||
请输出你的决策:
|
||||
"""
|
||||
|
||||
is_success, response, _, _ = await llm_api.generate_with_model(prompt=prompt, model_config=model_config.model_task_config.utils)
|
||||
|
||||
|
||||
is_success, response, _, _ = await llm_api.generate_with_model(
|
||||
prompt=prompt, model_config=model_config.model_task_config.utils
|
||||
)
|
||||
|
||||
if not is_success:
|
||||
return {"should_reply": False, "reason": "决策模型生成失败"}
|
||||
|
||||
@@ -218,21 +235,21 @@ class ProactiveThinkerExecutor:
|
||||
logger.error(f"决策LLM返回的JSON格式无效: {response}")
|
||||
return {"should_reply": False, "reason": "决策模型返回格式错误"}
|
||||
|
||||
def _build_plan_prompt(self, context: Dict[str, Any], start_mode: str, topic: str, reason: str) -> str:
|
||||
def _build_plan_prompt(self, context: dict[str, Any], start_mode: str, topic: str, reason: str) -> str:
|
||||
"""
|
||||
根据启动模式和决策话题,构建最终的规划提示词
|
||||
"""
|
||||
persona = context['persona']
|
||||
user_info = context['user_info']
|
||||
relationship = context['relationship']
|
||||
persona = context["persona"]
|
||||
user_info = context["user_info"]
|
||||
relationship = context["relationship"]
|
||||
|
||||
if start_mode == "cold_start":
|
||||
prompt = f"""
|
||||
# 角色
|
||||
你的名字是{global_config.bot.nickname},你的人设如下:
|
||||
- 核心人设: {persona['core']}
|
||||
- 侧面人设: {persona['side']}
|
||||
- 身份: {persona['identity']}
|
||||
- 核心人设: {persona["core"]}
|
||||
- 侧面人设: {persona["side"]}
|
||||
- 身份: {persona["identity"]}
|
||||
|
||||
# 任务
|
||||
你需要主动向一个新朋友 '{user_info.user_nickname}' 发起对话。这是你们的第一次交流,或者很久没聊了。
|
||||
@@ -240,9 +257,9 @@ class ProactiveThinkerExecutor:
|
||||
# 决策上下文
|
||||
- **决策理由**: {reason}
|
||||
- **你和Ta的关系**:
|
||||
- 简短印象: {relationship['short_impression']}
|
||||
- 详细印象: {relationship['impression']}
|
||||
- 好感度: {relationship['attitude']}/100
|
||||
- 简短印象: {relationship["short_impression"]}
|
||||
- 详细印象: {relationship["impression"]}
|
||||
- 好感度: {relationship["attitude"]}/100
|
||||
|
||||
# 对话指引
|
||||
- 你的目标是“破冰”,让对话自然地开始。
|
||||
@@ -254,26 +271,26 @@ class ProactiveThinkerExecutor:
|
||||
prompt = f"""
|
||||
# 角色
|
||||
你的名字是{global_config.bot.nickname},你的人设如下:
|
||||
- 核心人设: {persona['core']}
|
||||
- 侧面人设: {persona['side']}
|
||||
- 身份: {persona['identity']}
|
||||
- 核心人设: {persona["core"]}
|
||||
- 侧面人设: {persona["side"]}
|
||||
- 身份: {persona["identity"]}
|
||||
|
||||
# 任务
|
||||
现在是 {context['current_time']},你需要主动向你的朋友 '{user_info.user_nickname}' 发起对话。
|
||||
现在是 {context["current_time"]},你需要主动向你的朋友 '{user_info.user_nickname}' 发起对话。
|
||||
|
||||
# 决策上下文
|
||||
- **决策理由**: {reason}
|
||||
|
||||
# 情境分析
|
||||
1. **你的日程**:
|
||||
{context['schedule_context']}
|
||||
{context["schedule_context"]}
|
||||
2. **你和Ta的关系**:
|
||||
- 详细印象: {relationship['impression']}
|
||||
- 好感度: {relationship['attitude']}/100
|
||||
- 详细印象: {relationship["impression"]}
|
||||
- 好感度: {relationship["attitude"]}/100
|
||||
3. **最近的聊天摘要**:
|
||||
{context['recent_chat_history']}
|
||||
{context["recent_chat_history"]}
|
||||
4. **你最近的相关动作**:
|
||||
{context['action_history_context']}
|
||||
{context["action_history_context"]}
|
||||
|
||||
# 对话指引
|
||||
- 你决定和Ta聊聊关于“{topic}”的话题。
|
||||
|
||||
@@ -1,26 +1,26 @@
|
||||
import re
|
||||
from typing import List, Tuple, Type, Optional
|
||||
|
||||
from src.plugin_system import (
|
||||
BasePlugin,
|
||||
register_plugin,
|
||||
BaseAction,
|
||||
ComponentInfo,
|
||||
ActionActivationType,
|
||||
ConfigField,
|
||||
)
|
||||
from src.common.logger import get_logger
|
||||
from .qq_emoji_list import qq_face
|
||||
from src.plugin_system.base.component_types import ChatType
|
||||
from src.person_info.person_info import get_person_info_manager
|
||||
from dateutil.parser import parse as parse_datetime
|
||||
from src.manager.async_task_manager import AsyncTask, async_task_manager
|
||||
from src.plugin_system.apis import send_api, llm_api, generator_api
|
||||
from src.plugin_system.base.component_types import ComponentType
|
||||
from typing import Optional
|
||||
from src.chat.message_receive.chat_stream import ChatStream
|
||||
import asyncio
|
||||
import datetime
|
||||
import re
|
||||
from typing import ClassVar
|
||||
|
||||
from dateutil.parser import parse as parse_datetime
|
||||
|
||||
from src.chat.message_receive.chat_stream import ChatStream
|
||||
from src.common.logger import get_logger
|
||||
from src.manager.async_task_manager import AsyncTask, async_task_manager
|
||||
from src.person_info.person_info import get_person_info_manager
|
||||
from src.plugin_system import (
|
||||
ActionActivationType,
|
||||
BaseAction,
|
||||
BasePlugin,
|
||||
ComponentInfo,
|
||||
ConfigField,
|
||||
register_plugin,
|
||||
)
|
||||
from src.plugin_system.apis import generator_api, llm_api, send_api
|
||||
from src.plugin_system.base.component_types import ChatType
|
||||
|
||||
from .qq_emoji_list import qq_face
|
||||
|
||||
logger = get_logger("set_emoji_like_plugin")
|
||||
|
||||
@@ -32,7 +32,7 @@ class ReminderTask(AsyncTask):
|
||||
self,
|
||||
delay: float,
|
||||
stream_id: str,
|
||||
group_id: Optional[str],
|
||||
group_id: str | None,
|
||||
is_group: bool,
|
||||
target_user_id: str,
|
||||
target_user_name: str,
|
||||
@@ -164,7 +164,7 @@ class PokeAction(BaseAction):
|
||||
"""
|
||||
associated_types = ["text"]
|
||||
|
||||
async def execute(self) -> Tuple[bool, str]:
|
||||
async def execute(self) -> tuple[bool, str]:
|
||||
"""执行戳一戳的动作"""
|
||||
user_id = self.action_data.get("user_id")
|
||||
user_name = self.action_data.get("user_name")
|
||||
@@ -244,7 +244,7 @@ class SetEmojiLikeAction(BaseAction):
|
||||
if match:
|
||||
emoji_options.append(match.group(1))
|
||||
|
||||
async def execute(self) -> Tuple[bool, str]:
|
||||
async def execute(self) -> tuple[bool, str]:
|
||||
"""执行设置表情回应的动作"""
|
||||
message_id = None
|
||||
set_like = self.action_data.get("set", True)
|
||||
@@ -362,7 +362,7 @@ class RemindAction(BaseAction):
|
||||
"例如:'10分钟后提醒我收快递'、'明天早上九点喊一下李四参加晨会'",
|
||||
]
|
||||
|
||||
async def execute(self) -> Tuple[bool, str]:
|
||||
async def execute(self) -> tuple[bool, str]:
|
||||
"""执行设置提醒的动作"""
|
||||
user_name = self.action_data.get("user_name")
|
||||
remind_time_str = self.action_data.get("remind_time")
|
||||
@@ -388,14 +388,14 @@ class RemindAction(BaseAction):
|
||||
# 优先尝试直接解析
|
||||
try:
|
||||
target_time = parse_datetime(remind_time_str, fuzzy=True)
|
||||
except Exception:
|
||||
except Exception as e:
|
||||
# 如果直接解析失败,调用 LLM 进行转换
|
||||
logger.info(f"[ReminderPlugin] 直接解析时间 '{remind_time_str}' 失败,尝试使用 LLM 进行转换...")
|
||||
|
||||
# 获取所有可用的模型配置
|
||||
available_models = llm_api.get_available_models()
|
||||
if "utils_small" not in available_models:
|
||||
raise ValueError("未找到 'utils_small' 模型配置,无法解析时间")
|
||||
raise ValueError("未找到 'utils_small' 模型配置,无法解析时间") from e
|
||||
|
||||
# 明确使用 'planner' 模型
|
||||
model_to_use = available_models["utils_small"]
|
||||
@@ -421,7 +421,7 @@ class RemindAction(BaseAction):
|
||||
)
|
||||
|
||||
if not success or not response:
|
||||
raise ValueError(f"LLM未能返回有效的时间字符串: {response}")
|
||||
raise ValueError(f"LLM未能返回有效的时间字符串: {response}") from e
|
||||
|
||||
converted_time_str = response.strip()
|
||||
logger.info(f"[ReminderPlugin] LLM 转换结果: '{converted_time_str}'")
|
||||
@@ -535,15 +535,15 @@ class SetEmojiLikePlugin(BasePlugin):
|
||||
# 插件基本信息
|
||||
plugin_name: str = "social_toolkit_plugin" # 内部标识符
|
||||
enable_plugin: bool = True
|
||||
dependencies: List[str] = [] # 插件依赖列表
|
||||
python_dependencies: List[str] = [] # Python包依赖列表,现在使用内置API
|
||||
dependencies: list[str] = [] # 插件依赖列表
|
||||
python_dependencies: list[str] = [] # Python包依赖列表,现在使用内置API
|
||||
config_file_name: str = "config.toml" # 配置文件名
|
||||
|
||||
# 配置节描述
|
||||
config_section_descriptions = {"plugin": "插件基本信息", "components": "插件组件"}
|
||||
|
||||
# 配置Schema定义
|
||||
config_schema: dict = {
|
||||
config_schema: ClassVar[dict ]= {
|
||||
"plugin": {
|
||||
"name": ConfigField(type=str, default="set_emoji_like", description="插件名称"),
|
||||
"version": ConfigField(type=str, default="1.0.0", description="插件版本"),
|
||||
@@ -557,7 +557,7 @@ class SetEmojiLikePlugin(BasePlugin):
|
||||
},
|
||||
}
|
||||
|
||||
def get_plugin_components(self) -> List[Tuple[ComponentInfo, Type]]:
|
||||
def get_plugin_components(self) -> list[tuple[ComponentInfo, type]]:
|
||||
enable_components = []
|
||||
if self.get_config("components.action_set_emoji_like"):
|
||||
enable_components.append((SetEmojiLikeAction.get_action_info(), SetEmojiLikeAction))
|
||||
|
||||
@@ -1,10 +1,9 @@
|
||||
from src.common.logger import get_logger
|
||||
from src.plugin_system.apis.plugin_register_api import register_plugin
|
||||
from src.plugin_system.base.base_action import ActionActivationType, BaseAction, ChatMode
|
||||
from src.plugin_system.base.base_plugin import BasePlugin
|
||||
from src.plugin_system.base.component_types import ComponentInfo
|
||||
from src.common.logger import get_logger
|
||||
from src.plugin_system.base.base_action import BaseAction, ActionActivationType, ChatMode
|
||||
from src.plugin_system.base.config_types import ConfigField
|
||||
from typing import Tuple, List, Type
|
||||
|
||||
logger = get_logger("tts")
|
||||
|
||||
@@ -44,7 +43,7 @@ class TTSAction(BaseAction):
|
||||
# 关联类型
|
||||
associated_types = ["tts_text"]
|
||||
|
||||
async def execute(self) -> Tuple[bool, str]:
|
||||
async def execute(self) -> tuple[bool, str]:
|
||||
"""处理TTS文本转语音动作"""
|
||||
logger.info(f"{self.log_prefix} 执行TTS动作: {self.reasoning}")
|
||||
|
||||
@@ -140,7 +139,7 @@ class TTSPlugin(BasePlugin):
|
||||
},
|
||||
}
|
||||
|
||||
def get_plugin_components(self) -> List[Tuple[ComponentInfo, Type]]:
|
||||
def get_plugin_components(self) -> list[tuple[ComponentInfo, type]]:
|
||||
"""返回插件包含的组件列表"""
|
||||
|
||||
# 从配置获取组件启用状态
|
||||
|
||||
@@ -3,7 +3,7 @@ Base search engine interface
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, List, Any
|
||||
from typing import Any
|
||||
|
||||
|
||||
class BaseSearchEngine(ABC):
|
||||
@@ -12,7 +12,7 @@ class BaseSearchEngine(ABC):
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def search(self, args: Dict[str, Any]) -> List[Dict[str, Any]]:
|
||||
async def search(self, args: dict[str, Any]) -> list[dict[str, Any]]:
|
||||
"""
|
||||
执行搜索
|
||||
|
||||
|
||||
@@ -6,11 +6,13 @@ import asyncio
|
||||
import functools
|
||||
import random
|
||||
import traceback
|
||||
from typing import Dict, List, Any
|
||||
from typing import Any
|
||||
|
||||
import requests
|
||||
from bs4 import BeautifulSoup
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
from .base import BaseSearchEngine
|
||||
|
||||
logger = get_logger("bing_engine")
|
||||
@@ -68,7 +70,7 @@ class BingSearchEngine(BaseSearchEngine):
|
||||
"""检查Bing搜索引擎是否可用"""
|
||||
return True # Bing是免费搜索引擎,总是可用
|
||||
|
||||
async def search(self, args: Dict[str, Any]) -> List[Dict[str, Any]]:
|
||||
async def search(self, args: dict[str, Any]) -> list[dict[str, Any]]:
|
||||
"""执行Bing搜索"""
|
||||
query = args["query"]
|
||||
num_results = args.get("num_results", 3)
|
||||
@@ -83,7 +85,7 @@ class BingSearchEngine(BaseSearchEngine):
|
||||
logger.error(f"Bing 搜索失败: {e}")
|
||||
return []
|
||||
|
||||
def _search_sync(self, keyword: str, num_results: int, time_range: str) -> List[Dict[str, Any]]:
|
||||
def _search_sync(self, keyword: str, num_results: int, time_range: str) -> list[dict[str, Any]]:
|
||||
"""同步执行Bing搜索"""
|
||||
if not keyword:
|
||||
return []
|
||||
@@ -113,7 +115,7 @@ class BingSearchEngine(BaseSearchEngine):
|
||||
return list_result[:num_results] if len(list_result) > num_results else list_result
|
||||
|
||||
@staticmethod
|
||||
def _parse_html(url: str) -> List[Dict[str, Any]]:
|
||||
def _parse_html(url: str) -> list[dict[str, Any]]:
|
||||
"""解析处理结果"""
|
||||
try:
|
||||
logger.debug(f"访问Bing搜索URL: {url}")
|
||||
@@ -141,11 +143,11 @@ class BingSearchEngine(BaseSearchEngine):
|
||||
try:
|
||||
res = session.get(url=url, timeout=(3.05, 6), verify=True, allow_redirects=True)
|
||||
except (requests.exceptions.Timeout, requests.exceptions.ConnectionError) as e:
|
||||
logger.warning(f"第一次请求超时,正在重试: {str(e)}")
|
||||
logger.warning(f"第一次请求超时,正在重试: {e!s}")
|
||||
try:
|
||||
res = session.get(url=url, timeout=(5, 10), verify=False)
|
||||
except Exception as e2:
|
||||
logger.error(f"第二次请求也失败: {str(e2)}")
|
||||
logger.error(f"第二次请求也失败: {e2!s}")
|
||||
return []
|
||||
|
||||
res.encoding = "utf-8"
|
||||
@@ -175,7 +177,7 @@ class BingSearchEngine(BaseSearchEngine):
|
||||
try:
|
||||
root = BeautifulSoup(res.text, "html.parser")
|
||||
except Exception as e:
|
||||
logger.error(f"HTML解析失败: {str(e)}")
|
||||
logger.error(f"HTML解析失败: {e!s}")
|
||||
return []
|
||||
|
||||
list_data = []
|
||||
@@ -262,6 +264,6 @@ class BingSearchEngine(BaseSearchEngine):
|
||||
return list_data
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"解析Bing页面时出错: {str(e)}")
|
||||
logger.error(f"解析Bing页面时出错: {e!s}")
|
||||
logger.debug(traceback.format_exc())
|
||||
return []
|
||||
|
||||
@@ -2,10 +2,12 @@
|
||||
DuckDuckGo search engine implementation
|
||||
"""
|
||||
|
||||
from typing import Dict, List, Any
|
||||
from typing import Any
|
||||
|
||||
from asyncddgs import aDDGS
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
from .base import BaseSearchEngine
|
||||
|
||||
logger = get_logger("ddg_engine")
|
||||
@@ -20,7 +22,7 @@ class DDGSearchEngine(BaseSearchEngine):
|
||||
"""检查DuckDuckGo搜索引擎是否可用"""
|
||||
return True # DuckDuckGo不需要API密钥,总是可用
|
||||
|
||||
async def search(self, args: Dict[str, Any]) -> List[Dict[str, Any]]:
|
||||
async def search(self, args: dict[str, Any]) -> list[dict[str, Any]]:
|
||||
"""执行DuckDuckGo搜索"""
|
||||
query = args["query"]
|
||||
num_results = args.get("num_results", 3)
|
||||
|
||||
@@ -5,13 +5,15 @@ Exa search engine implementation
|
||||
import asyncio
|
||||
import functools
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, List, Any
|
||||
from typing import Any
|
||||
|
||||
from exa_py import Exa
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.plugin_system.apis import config_api
|
||||
from .base import BaseSearchEngine
|
||||
|
||||
from ..utils.api_key_manager import create_api_key_manager_from_config
|
||||
from .base import BaseSearchEngine
|
||||
|
||||
logger = get_logger("exa_engine")
|
||||
|
||||
@@ -36,7 +38,7 @@ class ExaSearchEngine(BaseSearchEngine):
|
||||
"""检查Exa搜索引擎是否可用"""
|
||||
return self.api_manager.is_available()
|
||||
|
||||
async def search(self, args: Dict[str, Any]) -> List[Dict[str, Any]]:
|
||||
async def search(self, args: dict[str, Any]) -> list[dict[str, Any]]:
|
||||
"""执行Exa搜索"""
|
||||
if not self.is_available():
|
||||
return []
|
||||
|
||||
@@ -4,13 +4,15 @@ Tavily search engine implementation
|
||||
|
||||
import asyncio
|
||||
import functools
|
||||
from typing import Dict, List, Any
|
||||
from typing import Any
|
||||
|
||||
from tavily import TavilyClient
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.plugin_system.apis import config_api
|
||||
from .base import BaseSearchEngine
|
||||
|
||||
from ..utils.api_key_manager import create_api_key_manager_from_config
|
||||
from .base import BaseSearchEngine
|
||||
|
||||
logger = get_logger("tavily_engine")
|
||||
|
||||
@@ -37,7 +39,7 @@ class TavilySearchEngine(BaseSearchEngine):
|
||||
"""检查Tavily搜索引擎是否可用"""
|
||||
return self.api_manager.is_available()
|
||||
|
||||
async def search(self, args: Dict[str, Any]) -> List[Dict[str, Any]]:
|
||||
async def search(self, args: dict[str, Any]) -> list[dict[str, Any]]:
|
||||
"""执行Tavily搜索"""
|
||||
if not self.is_available():
|
||||
return []
|
||||
|
||||
@@ -4,14 +4,12 @@ Web Search Tool Plugin
|
||||
一个功能强大的网络搜索和URL解析插件,支持多种搜索引擎和解析策略。
|
||||
"""
|
||||
|
||||
from typing import List, Tuple, Type
|
||||
|
||||
from src.plugin_system import BasePlugin, register_plugin, ComponentInfo, ConfigField, PythonDependency
|
||||
from src.plugin_system.apis import config_api
|
||||
from src.common.logger import get_logger
|
||||
from src.plugin_system import BasePlugin, ComponentInfo, ConfigField, PythonDependency, register_plugin
|
||||
from src.plugin_system.apis import config_api
|
||||
|
||||
from .tools.web_search import WebSurfingTool
|
||||
from .tools.url_parser import URLParserTool
|
||||
from .tools.web_search import WebSurfingTool
|
||||
|
||||
logger = get_logger("web_search_plugin")
|
||||
|
||||
@@ -31,7 +29,7 @@ class WEBSEARCHPLUGIN(BasePlugin):
|
||||
# 插件基本信息
|
||||
plugin_name: str = "web_search_tool" # 内部标识符
|
||||
enable_plugin: bool = True
|
||||
dependencies: List[str] = [] # 插件依赖列表
|
||||
dependencies: list[str] = [] # 插件依赖列表
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
"""初始化插件,立即加载所有搜索引擎"""
|
||||
@@ -40,10 +38,10 @@ class WEBSEARCHPLUGIN(BasePlugin):
|
||||
# 立即初始化所有搜索引擎,触发API密钥管理器的日志输出
|
||||
logger.info("🚀 正在初始化所有搜索引擎...")
|
||||
try:
|
||||
from .engines.bing_engine import BingSearchEngine
|
||||
from .engines.ddg_engine import DDGSearchEngine
|
||||
from .engines.exa_engine import ExaSearchEngine
|
||||
from .engines.tavily_engine import TavilySearchEngine
|
||||
from .engines.ddg_engine import DDGSearchEngine
|
||||
from .engines.bing_engine import BingSearchEngine
|
||||
|
||||
# 实例化所有搜索引擎,这会触发API密钥管理器的初始化
|
||||
exa_engine = ExaSearchEngine()
|
||||
@@ -71,7 +69,7 @@ class WEBSEARCHPLUGIN(BasePlugin):
|
||||
logger.error(f"❌ 搜索引擎初始化失败: {e}", exc_info=True)
|
||||
|
||||
# Python包依赖列表
|
||||
python_dependencies: List[PythonDependency] = [
|
||||
python_dependencies: list[PythonDependency] = [
|
||||
PythonDependency(package_name="asyncddgs", description="异步DuckDuckGo搜索库", optional=False),
|
||||
PythonDependency(
|
||||
package_name="exa_py",
|
||||
@@ -119,7 +117,7 @@ class WEBSEARCHPLUGIN(BasePlugin):
|
||||
},
|
||||
}
|
||||
|
||||
def get_plugin_components(self) -> List[Tuple[ComponentInfo, Type]]:
|
||||
def get_plugin_components(self) -> list[tuple[ComponentInfo, type]]:
|
||||
"""
|
||||
获取插件组件列表
|
||||
|
||||
|
||||
@@ -4,19 +4,20 @@ URL parser tool implementation
|
||||
|
||||
import asyncio
|
||||
import functools
|
||||
from typing import Any, Dict
|
||||
from exa_py import Exa
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
from bs4 import BeautifulSoup
|
||||
from exa_py import Exa
|
||||
|
||||
from src.common.cache_manager import tool_cache
|
||||
from src.common.logger import get_logger
|
||||
from src.plugin_system import BaseTool, ToolParamType, llm_api
|
||||
from src.plugin_system.apis import config_api
|
||||
from src.common.cache_manager import tool_cache
|
||||
|
||||
from ..utils.api_key_manager import create_api_key_manager_from_config
|
||||
from ..utils.formatters import format_url_parse_results
|
||||
from ..utils.url_utils import parse_urls_from_input, validate_urls
|
||||
from ..utils.api_key_manager import create_api_key_manager_from_config
|
||||
|
||||
logger = get_logger("url_parser_tool")
|
||||
|
||||
@@ -50,7 +51,7 @@ class URLParserTool(BaseTool):
|
||||
exa_api_keys, lambda key: Exa(api_key=key), "Exa URL Parser"
|
||||
)
|
||||
|
||||
async def _local_parse_and_summarize(self, url: str) -> Dict[str, Any]:
|
||||
async def _local_parse_and_summarize(self, url: str) -> dict[str, Any]:
|
||||
"""
|
||||
使用本地库(httpx, BeautifulSoup)解析URL,并调用LLM进行总结。
|
||||
"""
|
||||
@@ -124,9 +125,9 @@ class URLParserTool(BaseTool):
|
||||
return {"error": f"请求失败,状态码: {e.response.status_code}"}
|
||||
except Exception as e:
|
||||
logger.error(f"本地解析或总结URL '{url}' 时发生未知异常: {e}", exc_info=True)
|
||||
return {"error": f"发生未知错误: {str(e)}"}
|
||||
return {"error": f"发生未知错误: {e!s}"}
|
||||
|
||||
async def execute(self, function_args: Dict[str, Any]) -> Dict[str, Any]:
|
||||
async def execute(self, function_args: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
执行URL内容提取和总结。优先使用Exa,失败后尝试本地解析。
|
||||
"""
|
||||
|
||||
@@ -3,18 +3,18 @@ Web search tool implementation
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from typing import Any, Dict, List
|
||||
from typing import Any
|
||||
|
||||
from src.common.cache_manager import tool_cache
|
||||
from src.common.logger import get_logger
|
||||
from src.plugin_system import BaseTool, ToolParamType
|
||||
from src.plugin_system.apis import config_api
|
||||
from src.common.cache_manager import tool_cache
|
||||
|
||||
from ..engines.bing_engine import BingSearchEngine
|
||||
from ..engines.ddg_engine import DDGSearchEngine
|
||||
from ..engines.exa_engine import ExaSearchEngine
|
||||
from ..engines.tavily_engine import TavilySearchEngine
|
||||
from ..engines.ddg_engine import DDGSearchEngine
|
||||
from ..engines.bing_engine import BingSearchEngine
|
||||
from ..utils.formatters import format_search_results, deduplicate_results
|
||||
from ..utils.formatters import deduplicate_results, format_search_results
|
||||
|
||||
logger = get_logger("web_search_tool")
|
||||
|
||||
@@ -51,7 +51,7 @@ class WebSurfingTool(BaseTool):
|
||||
"bing": BingSearchEngine(),
|
||||
}
|
||||
|
||||
async def execute(self, function_args: Dict[str, Any]) -> Dict[str, Any]:
|
||||
async def execute(self, function_args: dict[str, Any]) -> dict[str, Any]:
|
||||
query = function_args.get("query")
|
||||
if not query:
|
||||
return {"error": "搜索查询不能为空。"}
|
||||
@@ -88,8 +88,8 @@ class WebSurfingTool(BaseTool):
|
||||
return result
|
||||
|
||||
async def _execute_parallel_search(
|
||||
self, function_args: Dict[str, Any], enabled_engines: List[str]
|
||||
) -> Dict[str, Any]:
|
||||
self, function_args: dict[str, Any], enabled_engines: list[str]
|
||||
) -> dict[str, Any]:
|
||||
"""并行搜索策略:同时使用所有启用的搜索引擎"""
|
||||
search_tasks = []
|
||||
|
||||
@@ -124,11 +124,11 @@ class WebSurfingTool(BaseTool):
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"执行并行网络搜索时发生异常: {e}", exc_info=True)
|
||||
return {"error": f"执行网络搜索时发生严重错误: {str(e)}"}
|
||||
return {"error": f"执行网络搜索时发生严重错误: {e!s}"}
|
||||
|
||||
async def _execute_fallback_search(
|
||||
self, function_args: Dict[str, Any], enabled_engines: List[str]
|
||||
) -> Dict[str, Any]:
|
||||
self, function_args: dict[str, Any], enabled_engines: list[str]
|
||||
) -> dict[str, Any]:
|
||||
"""回退搜索策略:按顺序尝试搜索引擎,失败则尝试下一个"""
|
||||
for engine_name in enabled_engines:
|
||||
engine = self.engines.get(engine_name)
|
||||
@@ -154,7 +154,7 @@ class WebSurfingTool(BaseTool):
|
||||
|
||||
return {"error": "所有搜索引擎都失败了。"}
|
||||
|
||||
async def _execute_single_search(self, function_args: Dict[str, Any], enabled_engines: List[str]) -> Dict[str, Any]:
|
||||
async def _execute_single_search(self, function_args: dict[str, Any], enabled_engines: list[str]) -> dict[str, Any]:
|
||||
"""单一搜索策略:只使用第一个可用的搜索引擎"""
|
||||
for engine_name in enabled_engines:
|
||||
engine = self.engines.get(engine_name)
|
||||
@@ -174,6 +174,6 @@ class WebSurfingTool(BaseTool):
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{engine_name} 搜索失败: {e}")
|
||||
return {"error": f"{engine_name} 搜索失败: {str(e)}"}
|
||||
return {"error": f"{engine_name} 搜索失败: {e!s}"}
|
||||
|
||||
return {"error": "没有可用的搜索引擎。"}
|
||||
|
||||
@@ -3,7 +3,9 @@ API密钥管理器,提供轮询机制
|
||||
"""
|
||||
|
||||
import itertools
|
||||
from typing import List, Optional, TypeVar, Generic, Callable
|
||||
from collections.abc import Callable
|
||||
from typing import Generic, TypeVar
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("api_key_manager")
|
||||
@@ -16,7 +18,7 @@ class APIKeyManager(Generic[T]):
|
||||
API密钥管理器,支持轮询机制
|
||||
"""
|
||||
|
||||
def __init__(self, api_keys: List[str], client_factory: Callable[[str], T], service_name: str = "Unknown"):
|
||||
def __init__(self, api_keys: list[str], client_factory: Callable[[str], T], service_name: str = "Unknown"):
|
||||
"""
|
||||
初始化API密钥管理器
|
||||
|
||||
@@ -26,8 +28,8 @@ class APIKeyManager(Generic[T]):
|
||||
service_name: 服务名称,用于日志记录
|
||||
"""
|
||||
self.service_name = service_name
|
||||
self.clients: List[T] = []
|
||||
self.client_cycle: Optional[itertools.cycle] = None
|
||||
self.clients: list[T] = []
|
||||
self.client_cycle: itertools.cycle | None = None
|
||||
|
||||
if api_keys:
|
||||
# 过滤有效的API密钥,排除None、空字符串、"None"字符串等
|
||||
@@ -54,7 +56,7 @@ class APIKeyManager(Generic[T]):
|
||||
"""检查是否有可用的客户端"""
|
||||
return bool(self.clients and self.client_cycle)
|
||||
|
||||
def get_next_client(self) -> Optional[T]:
|
||||
def get_next_client(self) -> T | None:
|
||||
"""获取下一个客户端(轮询)"""
|
||||
if not self.is_available():
|
||||
return None
|
||||
@@ -66,7 +68,7 @@ class APIKeyManager(Generic[T]):
|
||||
|
||||
|
||||
def create_api_key_manager_from_config(
|
||||
config_keys: Optional[List[str]], client_factory: Callable[[str], T], service_name: str
|
||||
config_keys: list[str] | None, client_factory: Callable[[str], T], service_name: str
|
||||
) -> APIKeyManager[T]:
|
||||
"""
|
||||
从配置创建API密钥管理器的便捷函数
|
||||
|
||||
@@ -2,10 +2,10 @@
|
||||
Formatters for web search results
|
||||
"""
|
||||
|
||||
from typing import List, Dict, Any
|
||||
from typing import Any
|
||||
|
||||
|
||||
def format_search_results(results: List[Dict[str, Any]]) -> str:
|
||||
def format_search_results(results: list[dict[str, Any]]) -> str:
|
||||
"""
|
||||
格式化搜索结果为字符串
|
||||
"""
|
||||
@@ -26,7 +26,7 @@ def format_search_results(results: List[Dict[str, Any]]) -> str:
|
||||
return formatted_string
|
||||
|
||||
|
||||
def format_url_parse_results(results: List[Dict[str, Any]]) -> str:
|
||||
def format_url_parse_results(results: list[dict[str, Any]]) -> str:
|
||||
"""
|
||||
将成功解析的URL结果列表格式化为一段简洁的文本。
|
||||
"""
|
||||
@@ -45,7 +45,7 @@ def format_url_parse_results(results: List[Dict[str, Any]]) -> str:
|
||||
return "\n---\n".join(formatted_parts)
|
||||
|
||||
|
||||
def deduplicate_results(results: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||
def deduplicate_results(results: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
||||
"""
|
||||
根据URL去重搜索结果
|
||||
"""
|
||||
|
||||
@@ -3,10 +3,9 @@ URL processing utilities
|
||||
"""
|
||||
|
||||
import re
|
||||
from typing import List
|
||||
|
||||
|
||||
def parse_urls_from_input(urls_input) -> List[str]:
|
||||
def parse_urls_from_input(urls_input) -> list[str]:
|
||||
"""
|
||||
从输入中解析URL列表
|
||||
"""
|
||||
@@ -29,7 +28,7 @@ def parse_urls_from_input(urls_input) -> List[str]:
|
||||
return urls
|
||||
|
||||
|
||||
def validate_urls(urls: List[str]) -> List[str]:
|
||||
def validate_urls(urls: list[str]) -> list[str]:
|
||||
"""
|
||||
验证URL格式,返回有效的URL列表
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user