re-style: 格式化代码
This commit is contained in:
committed by
Windpicker-owo
parent
00ba07e0e1
commit
a79253c714
@@ -7,15 +7,15 @@ import asyncio
|
||||
import time
|
||||
import traceback
|
||||
from datetime import datetime
|
||||
from typing import Dict, Any
|
||||
from typing import Any
|
||||
|
||||
from src.chat.express.expression_learner import expression_learner_manager
|
||||
from src.chat.planner_actions.action_manager import ChatterActionManager
|
||||
from src.common.data_models.message_manager_data_model import StreamContext
|
||||
from src.common.logger import get_logger
|
||||
from src.plugin_system.base.base_chatter import BaseChatter
|
||||
from src.plugin_system.base.component_types import ChatType
|
||||
from src.common.data_models.message_manager_data_model import StreamContext
|
||||
from src.plugins.built_in.affinity_flow_chatter.planner import ChatterActionPlanner
|
||||
from src.chat.planner_actions.action_manager import ChatterActionManager
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.express.expression_learner import expression_learner_manager
|
||||
|
||||
logger = get_logger("affinity_chatter")
|
||||
|
||||
@@ -113,7 +113,7 @@ class AffinityChatter(BaseChatter):
|
||||
"executed_count": 0,
|
||||
}
|
||||
|
||||
def get_stats(self) -> Dict[str, Any]:
|
||||
def get_stats(self) -> dict[str, Any]:
|
||||
"""
|
||||
获取处理器统计信息
|
||||
|
||||
@@ -122,7 +122,7 @@ class AffinityChatter(BaseChatter):
|
||||
"""
|
||||
return self.stats.copy()
|
||||
|
||||
def get_planner_stats(self) -> Dict[str, Any]:
|
||||
def get_planner_stats(self) -> dict[str, Any]:
|
||||
"""
|
||||
获取规划器统计信息
|
||||
|
||||
@@ -131,7 +131,7 @@ class AffinityChatter(BaseChatter):
|
||||
"""
|
||||
return self.planner.get_planner_stats()
|
||||
|
||||
def get_interest_scoring_stats(self) -> Dict[str, Any]:
|
||||
def get_interest_scoring_stats(self) -> dict[str, Any]:
|
||||
"""
|
||||
获取兴趣度评分统计信息
|
||||
|
||||
@@ -140,7 +140,7 @@ class AffinityChatter(BaseChatter):
|
||||
"""
|
||||
return self.planner.get_interest_scoring_stats()
|
||||
|
||||
def get_relationship_stats(self) -> Dict[str, Any]:
|
||||
def get_relationship_stats(self) -> dict[str, Any]:
|
||||
"""
|
||||
获取用户关系统计信息
|
||||
|
||||
@@ -158,7 +158,7 @@ class AffinityChatter(BaseChatter):
|
||||
"""
|
||||
return self.planner.get_current_mood_state()
|
||||
|
||||
def get_mood_stats(self) -> Dict[str, Any]:
|
||||
def get_mood_stats(self) -> dict[str, Any]:
|
||||
"""
|
||||
获取情绪状态统计信息
|
||||
|
||||
|
||||
@@ -5,11 +5,11 @@
|
||||
"""
|
||||
|
||||
import traceback
|
||||
from typing import Dict, List, Any
|
||||
from typing import Any
|
||||
|
||||
from src.chat.interest_system import bot_interest_manager
|
||||
from src.common.data_models.database_data_model import DatabaseMessages
|
||||
from src.common.data_models.info_data_model import InterestScore
|
||||
from src.chat.interest_system import bot_interest_manager
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
|
||||
@@ -47,11 +47,11 @@ class ChatterInterestScoringSystem:
|
||||
) # 每次不回复增加的概率
|
||||
|
||||
# 用户关系数据
|
||||
self.user_relationships: Dict[str, float] = {} # user_id -> relationship_score
|
||||
self.user_relationships: dict[str, float] = {} # user_id -> relationship_score
|
||||
|
||||
async def calculate_interest_scores(
|
||||
self, messages: List[DatabaseMessages], bot_nickname: str
|
||||
) -> List[InterestScore]:
|
||||
self, messages: list[DatabaseMessages], bot_nickname: str
|
||||
) -> list[InterestScore]:
|
||||
"""计算消息的兴趣度评分"""
|
||||
user_messages = [msg for msg in messages if str(msg.user_info.user_id) != str(global_config.bot.qq_account)]
|
||||
if not user_messages:
|
||||
@@ -97,7 +97,7 @@ class ChatterInterestScoringSystem:
|
||||
details=details,
|
||||
)
|
||||
|
||||
async def _calculate_interest_match_score(self, content: str, keywords: List[str] = None) -> float:
|
||||
async def _calculate_interest_match_score(self, content: str, keywords: list[str] = None) -> float:
|
||||
"""计算兴趣匹配度 - 使用智能embedding匹配"""
|
||||
if not content:
|
||||
return 0.0
|
||||
@@ -109,7 +109,7 @@ class ChatterInterestScoringSystem:
|
||||
# 智能匹配未初始化,返回默认分数
|
||||
return 0.3
|
||||
|
||||
async def _calculate_smart_interest_match(self, content: str, keywords: List[str] = None) -> float:
|
||||
async def _calculate_smart_interest_match(self, content: str, keywords: list[str] = None) -> float:
|
||||
"""使用embedding计算智能兴趣匹配"""
|
||||
try:
|
||||
# 如果没有传入关键词,则提取
|
||||
@@ -134,7 +134,7 @@ class ChatterInterestScoringSystem:
|
||||
logger.error(f"智能兴趣匹配计算失败: {e}")
|
||||
return 0.0
|
||||
|
||||
def _extract_keywords_from_database(self, message: DatabaseMessages) -> List[str]:
|
||||
def _extract_keywords_from_database(self, message: DatabaseMessages) -> list[str]:
|
||||
"""从数据库消息中提取关键词"""
|
||||
keywords = []
|
||||
|
||||
@@ -166,7 +166,7 @@ class ChatterInterestScoringSystem:
|
||||
|
||||
return keywords[:15] # 返回前15个关键词
|
||||
|
||||
def _extract_keywords_from_content(self, content: str) -> List[str]:
|
||||
def _extract_keywords_from_content(self, content: str) -> list[str]:
|
||||
"""从内容中提取关键词(降级方案)"""
|
||||
import re
|
||||
|
||||
@@ -287,7 +287,7 @@ class ChatterInterestScoringSystem:
|
||||
"""获取用户关系分"""
|
||||
return self.user_relationships.get(user_id, 0.3)
|
||||
|
||||
def get_scoring_stats(self) -> Dict:
|
||||
def get_scoring_stats(self) -> dict:
|
||||
"""获取评分系统统计"""
|
||||
return {
|
||||
"no_reply_count": self.no_reply_count,
|
||||
@@ -318,7 +318,7 @@ class ChatterInterestScoringSystem:
|
||||
logger.error(f"初始化智能兴趣系统失败: {e}")
|
||||
traceback.print_exc()
|
||||
|
||||
def get_matching_config(self) -> Dict[str, Any]:
|
||||
def get_matching_config(self) -> dict[str, Any]:
|
||||
"""获取匹配配置信息"""
|
||||
return {
|
||||
"use_smart_matching": self.use_smart_matching,
|
||||
|
||||
@@ -5,12 +5,11 @@ PlanExecutor: 接收 Plan 对象并执行其中的所有动作。
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from typing import Dict, List
|
||||
|
||||
from src.config.config import global_config
|
||||
from src.chat.planner_actions.action_manager import ChatterActionManager
|
||||
from src.common.data_models.info_data_model import Plan, ActionPlannerInfo
|
||||
from src.common.data_models.info_data_model import ActionPlannerInfo, Plan
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
|
||||
logger = get_logger("plan_executor")
|
||||
|
||||
@@ -52,7 +51,7 @@ class ChatterPlanExecutor:
|
||||
"""设置关系追踪器"""
|
||||
self.relationship_tracker = relationship_tracker
|
||||
|
||||
async def execute(self, plan: Plan) -> Dict[str, any]:
|
||||
async def execute(self, plan: Plan) -> dict[str, any]:
|
||||
"""
|
||||
遍历并执行Plan对象中`decided_actions`列表里的所有动作。
|
||||
|
||||
@@ -110,7 +109,7 @@ class ChatterPlanExecutor:
|
||||
"results": execution_results,
|
||||
}
|
||||
|
||||
async def _execute_reply_actions(self, reply_actions: List[ActionPlannerInfo], plan: Plan) -> Dict[str, any]:
|
||||
async def _execute_reply_actions(self, reply_actions: list[ActionPlannerInfo], plan: Plan) -> dict[str, any]:
|
||||
"""串行执行所有回复动作,增加去重逻辑,避免对同一消息多次回复"""
|
||||
results = []
|
||||
|
||||
@@ -162,7 +161,7 @@ class ChatterPlanExecutor:
|
||||
|
||||
async def _execute_single_reply_action(
|
||||
self, action_info: ActionPlannerInfo, plan: Plan, clear_unread: bool = True
|
||||
) -> Dict[str, any]:
|
||||
) -> dict[str, any]:
|
||||
"""执行单个回复动作"""
|
||||
start_time = time.time()
|
||||
success = False
|
||||
@@ -240,7 +239,7 @@ class ChatterPlanExecutor:
|
||||
else reply_content,
|
||||
}
|
||||
|
||||
async def _execute_other_actions(self, other_actions: List[ActionPlannerInfo], plan: Plan) -> Dict[str, any]:
|
||||
async def _execute_other_actions(self, other_actions: list[ActionPlannerInfo], plan: Plan) -> dict[str, any]:
|
||||
"""执行其他动作"""
|
||||
results = []
|
||||
|
||||
@@ -269,7 +268,7 @@ class ChatterPlanExecutor:
|
||||
|
||||
return {"results": results}
|
||||
|
||||
async def _execute_single_other_action(self, action_info: ActionPlannerInfo, plan: Plan) -> Dict[str, any]:
|
||||
async def _execute_single_other_action(self, action_info: ActionPlannerInfo, plan: Plan) -> dict[str, any]:
|
||||
"""执行单个其他动作"""
|
||||
start_time = time.time()
|
||||
success = False
|
||||
@@ -378,7 +377,7 @@ class ChatterPlanExecutor:
|
||||
logger.debug(f"action_message类型: {type(action_info.action_message)}")
|
||||
logger.debug(f"action_message内容: {action_info.action_message}")
|
||||
|
||||
def get_execution_stats(self) -> Dict[str, any]:
|
||||
def get_execution_stats(self) -> dict[str, any]:
|
||||
"""获取执行统计信息"""
|
||||
stats = self.execution_stats.copy()
|
||||
|
||||
@@ -409,7 +408,7 @@ class ChatterPlanExecutor:
|
||||
"execution_times": [],
|
||||
}
|
||||
|
||||
def get_recent_performance(self, limit: int = 10) -> List[Dict[str, any]]:
|
||||
def get_recent_performance(self, limit: int = 10) -> list[dict[str, any]]:
|
||||
"""获取最近的执行性能"""
|
||||
recent_times = self.execution_stats["execution_times"][-limit:]
|
||||
if not recent_times:
|
||||
|
||||
@@ -2,13 +2,13 @@
|
||||
PlanFilter: 接收 Plan 对象,根据不同模式的逻辑进行筛选,决定最终要执行的动作。
|
||||
"""
|
||||
|
||||
import orjson
|
||||
import re
|
||||
import time
|
||||
import traceback
|
||||
import re
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any
|
||||
|
||||
import orjson
|
||||
from json_repair import repair_json
|
||||
|
||||
# 旧的Hippocampus系统已被移除,现在使用增强记忆系统
|
||||
@@ -39,7 +39,7 @@ class ChatterPlanFilter:
|
||||
根据 Plan 中的模式和信息,筛选并决定最终的动作。
|
||||
"""
|
||||
|
||||
def __init__(self, chat_id: str, available_actions: List[str]):
|
||||
def __init__(self, chat_id: str, available_actions: list[str]):
|
||||
"""
|
||||
初始化动作计划筛选器。
|
||||
|
||||
@@ -316,8 +316,8 @@ class ChatterPlanFilter:
|
||||
"""构建已读/未读历史消息块"""
|
||||
try:
|
||||
# 从message_manager获取真实的已读/未读消息
|
||||
from src.chat.utils.utils import assign_message_ids
|
||||
from src.chat.utils.chat_message_builder import get_raw_msg_before_timestamp_with_chat
|
||||
from src.chat.utils.utils import assign_message_ids
|
||||
|
||||
# 获取聊天流的上下文
|
||||
from src.plugin_system.apis.chat_api import get_chat_manager
|
||||
@@ -392,14 +392,15 @@ class ChatterPlanFilter:
|
||||
logger.error(f"构建已读/未读历史消息块时出错: {e}")
|
||||
return "构建已读历史消息时出错", "构建未读历史消息时出错", []
|
||||
|
||||
async def _get_interest_scores_for_messages(self, messages: List[dict]) -> dict[str, float]:
|
||||
async def _get_interest_scores_for_messages(self, messages: list[dict]) -> dict[str, float]:
|
||||
"""为消息获取兴趣度评分"""
|
||||
interest_scores = {}
|
||||
|
||||
try:
|
||||
from .interest_scoring import chatter_interest_scoring_system
|
||||
from src.common.data_models.database_data_model import DatabaseMessages
|
||||
|
||||
from .interest_scoring import chatter_interest_scoring_system
|
||||
|
||||
# 使用插件内部的兴趣度评分系统计算评分
|
||||
for msg_dict in messages:
|
||||
try:
|
||||
@@ -450,7 +451,7 @@ class ChatterPlanFilter:
|
||||
|
||||
async def _parse_single_action(
|
||||
self, action_json: dict, message_id_list: list, plan: Plan
|
||||
) -> List[ActionPlannerInfo]:
|
||||
) -> list[ActionPlannerInfo]:
|
||||
parsed_actions = []
|
||||
try:
|
||||
# 从新的actions结构中获取动作信息
|
||||
@@ -599,7 +600,7 @@ class ChatterPlanFilter:
|
||||
)
|
||||
return parsed_actions
|
||||
|
||||
def _filter_no_actions(self, action_list: List[ActionPlannerInfo]) -> List[ActionPlannerInfo]:
|
||||
def _filter_no_actions(self, action_list: list[ActionPlannerInfo]) -> list[ActionPlannerInfo]:
|
||||
non_no_actions = [a for a in action_list if a.action_type not in ["no_action", "no_reply"]]
|
||||
if non_no_actions:
|
||||
return non_no_actions
|
||||
@@ -653,8 +654,7 @@ class ChatterPlanFilter:
|
||||
logger.error(f"获取长期记忆时出错: {e}")
|
||||
return "回忆时出现了一些问题。"
|
||||
|
||||
@staticmethod
|
||||
async def _build_action_options(current_available_actions: Dict[str, ActionInfo]) -> str:
|
||||
async def _build_action_options(self, current_available_actions: dict[str, ActionInfo]) -> str:
|
||||
action_options_block = ""
|
||||
for action_name, action_info in current_available_actions.items():
|
||||
# 构建参数的JSON示例
|
||||
@@ -725,7 +725,7 @@ class ChatterPlanFilter:
|
||||
)
|
||||
return action_options_block
|
||||
|
||||
def _find_message_by_id(self, message_id: str, message_id_list: list) -> Optional[Dict[str, Any]]:
|
||||
def _find_message_by_id(self, message_id: str, message_id_list: list) -> dict[str, Any] | None:
|
||||
"""
|
||||
增强的消息查找函数,支持多种格式和模糊匹配
|
||||
兼容大模型可能返回的各种格式变体
|
||||
@@ -830,13 +830,12 @@ class ChatterPlanFilter:
|
||||
logger.warning(f"未找到任何匹配的消息: {original_id} (候选: {candidate_ids})")
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _get_latest_message(message_id_list: list) -> Optional[Dict[str, Any]]:
|
||||
def _get_latest_message(self, message_id_list: list) -> dict[str, Any] | None:
|
||||
if not message_id_list:
|
||||
return None
|
||||
return message_id_list[-1].get("message")
|
||||
|
||||
def _find_poke_notice(self, message_id_list: list) -> Optional[Dict[str, Any]]:
|
||||
def _find_poke_notice(self, message_id_list: list) -> dict[str, Any] | None:
|
||||
"""在消息列表中寻找戳一戳的通知消息"""
|
||||
for item in reversed(message_id_list):
|
||||
message = item.get("message")
|
||||
|
||||
@@ -3,7 +3,6 @@ PlanGenerator: 负责搜集和汇总所有决策所需的信息,生成一个
|
||||
"""
|
||||
|
||||
import time
|
||||
from typing import Dict
|
||||
|
||||
from src.chat.utils.chat_message_builder import get_raw_msg_before_timestamp_with_chat
|
||||
from src.chat.utils.utils import get_chat_type_and_target_info
|
||||
@@ -85,7 +84,7 @@ class ChatterPlanGenerator:
|
||||
chat_history=[],
|
||||
)
|
||||
|
||||
async def _get_available_actions(self, chat_type: ChatType, mode: ChatMode) -> Dict[str, ActionInfo]:
|
||||
async def _get_available_actions(self, chat_type: ChatType, mode: ChatMode) -> dict[str, ActionInfo]:
|
||||
"""
|
||||
获取当前可用的动作列表。
|
||||
|
||||
@@ -152,7 +151,7 @@ class ChatterPlanGenerator:
|
||||
# 如果获取失败,返回空列表
|
||||
return []
|
||||
|
||||
def get_generator_stats(self) -> Dict:
|
||||
def get_generator_stats(self) -> dict:
|
||||
"""
|
||||
获取生成器统计信息。
|
||||
|
||||
|
||||
@@ -4,22 +4,20 @@
|
||||
"""
|
||||
|
||||
from dataclasses import asdict
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
|
||||
|
||||
from src.plugins.built_in.affinity_flow_chatter.plan_executor import ChatterPlanExecutor
|
||||
from src.plugins.built_in.affinity_flow_chatter.plan_filter import ChatterPlanFilter
|
||||
from src.plugins.built_in.affinity_flow_chatter.plan_generator import ChatterPlanGenerator
|
||||
from src.plugins.built_in.affinity_flow_chatter.interest_scoring import chatter_interest_scoring_system
|
||||
from src.mood.mood_manager import mood_manager
|
||||
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.mood.mood_manager import mood_manager
|
||||
from src.plugins.built_in.affinity_flow_chatter.interest_scoring import chatter_interest_scoring_system
|
||||
from src.plugins.built_in.affinity_flow_chatter.plan_executor import ChatterPlanExecutor
|
||||
from src.plugins.built_in.affinity_flow_chatter.plan_filter import ChatterPlanFilter
|
||||
from src.plugins.built_in.affinity_flow_chatter.plan_generator import ChatterPlanGenerator
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from src.common.data_models.message_manager_data_model import StreamContext
|
||||
from src.common.data_models.info_data_model import Plan
|
||||
from src.chat.planner_actions.action_manager import ChatterActionManager
|
||||
from src.common.data_models.info_data_model import Plan
|
||||
from src.common.data_models.message_manager_data_model import StreamContext
|
||||
|
||||
# 导入提示词模块以确保其被初始化
|
||||
from src.plugins.built_in.affinity_flow_chatter import planner_prompts # noqa
|
||||
@@ -62,7 +60,7 @@ class ChatterActionPlanner:
|
||||
"other_actions_executed": 0,
|
||||
}
|
||||
|
||||
async def plan(self, context: "StreamContext" = None) -> Tuple[List[Dict], Optional[Dict]]:
|
||||
async def plan(self, context: "StreamContext" = None) -> tuple[list[dict], dict | None]:
|
||||
"""
|
||||
执行完整的增强版规划流程。
|
||||
|
||||
@@ -84,7 +82,7 @@ class ChatterActionPlanner:
|
||||
self.planner_stats["failed_plans"] += 1
|
||||
return [], None
|
||||
|
||||
async def _enhanced_plan_flow(self, context: "StreamContext") -> Tuple[List[Dict], Optional[Dict]]:
|
||||
async def _enhanced_plan_flow(self, context: "StreamContext") -> tuple[list[dict], dict | None]:
|
||||
"""执行增强版规划流程"""
|
||||
try:
|
||||
# 在规划前,先进行动作修改
|
||||
@@ -104,7 +102,7 @@ class ChatterActionPlanner:
|
||||
score = 0.0
|
||||
should_reply = False
|
||||
reply_not_available = False
|
||||
interest_updates: List[Dict[str, Any]] = []
|
||||
interest_updates: list[dict[str, Any]] = []
|
||||
|
||||
if unread_messages:
|
||||
# 为每条消息计算兴趣度,并延迟提交数据库更新
|
||||
@@ -193,7 +191,7 @@ class ChatterActionPlanner:
|
||||
self.planner_stats["failed_plans"] += 1
|
||||
return [], None
|
||||
|
||||
async def _commit_interest_updates(self, updates: List[Dict[str, Any]]) -> None:
|
||||
async def _commit_interest_updates(self, updates: list[dict[str, Any]]) -> None:
|
||||
"""统一更新消息兴趣度,减少数据库写入次数"""
|
||||
if not updates:
|
||||
return
|
||||
@@ -220,7 +218,7 @@ class ChatterActionPlanner:
|
||||
except Exception as e:
|
||||
logger.warning(f"批量更新数据库兴趣度失败: {e}")
|
||||
|
||||
def _update_stats_from_execution_result(self, execution_result: Dict[str, any]):
|
||||
def _update_stats_from_execution_result(self, execution_result: dict[str, any]):
|
||||
"""根据执行结果更新规划器统计"""
|
||||
if not execution_result:
|
||||
return
|
||||
@@ -244,7 +242,7 @@ class ChatterActionPlanner:
|
||||
self.planner_stats["replies_generated"] += reply_count
|
||||
self.planner_stats["other_actions_executed"] += other_count
|
||||
|
||||
def _build_return_result(self, plan: "Plan") -> Tuple[List[Dict], Optional[Dict]]:
|
||||
def _build_return_result(self, plan: "Plan") -> tuple[list[dict], dict | None]:
|
||||
"""构建返回结果"""
|
||||
final_actions = plan.decided_actions or []
|
||||
final_target_message = next((act.action_message for act in final_actions if act.action_message), None)
|
||||
@@ -261,7 +259,7 @@ class ChatterActionPlanner:
|
||||
|
||||
return final_actions_dict, final_target_message_dict
|
||||
|
||||
def get_planner_stats(self) -> Dict[str, any]:
|
||||
def get_planner_stats(self) -> dict[str, any]:
|
||||
"""获取规划器统计"""
|
||||
return self.planner_stats.copy()
|
||||
|
||||
@@ -270,7 +268,7 @@ class ChatterActionPlanner:
|
||||
chat_mood = mood_manager.get_mood_by_chat_id(self.chat_id)
|
||||
return chat_mood.mood_state
|
||||
|
||||
def get_mood_stats(self) -> Dict[str, any]:
|
||||
def get_mood_stats(self) -> dict[str, any]:
|
||||
"""获取情绪状态统计"""
|
||||
chat_mood = mood_manager.get_mood_by_chat_id(self.chat_id)
|
||||
return {
|
||||
|
||||
@@ -2,12 +2,10 @@
|
||||
亲和力聊天处理器插件
|
||||
"""
|
||||
|
||||
from typing import List, Tuple, Type
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.plugin_system.apis.plugin_register_api import register_plugin
|
||||
from src.plugin_system.base.base_plugin import BasePlugin
|
||||
from src.plugin_system.base.component_types import ComponentInfo
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("affinity_chatter_plugin")
|
||||
|
||||
@@ -29,7 +27,7 @@ class AffinityChatterPlugin(BasePlugin):
|
||||
# 简单的 config_schema 占位(如果将来需要配置可扩展)
|
||||
config_schema = {}
|
||||
|
||||
def get_plugin_components(self) -> List[Tuple[ComponentInfo, Type]]:
|
||||
def get_plugin_components(self) -> list[tuple[ComponentInfo, type]]:
|
||||
"""返回插件包含的组件列表(ChatterInfo, AffinityChatter)
|
||||
|
||||
这里采用延迟导入 AffinityChatter 来避免循环依赖和启动顺序问题。
|
||||
|
||||
@@ -5,15 +5,15 @@
|
||||
"""
|
||||
|
||||
import time
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import model_config, global_config
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.common.database.sqlalchemy_database_api import get_db_session
|
||||
from src.common.database.sqlalchemy_models import UserRelationships, Messages
|
||||
from sqlalchemy import select, desc
|
||||
from sqlalchemy import desc, select
|
||||
|
||||
from src.common.data_models.database_data_model import DatabaseMessages
|
||||
from src.common.database.sqlalchemy_database_api import get_db_session
|
||||
from src.common.database.sqlalchemy_models import Messages, UserRelationships
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config, model_config
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
|
||||
logger = get_logger("chatter_relationship_tracker")
|
||||
|
||||
@@ -22,15 +22,15 @@ class ChatterRelationshipTracker:
|
||||
"""用户关系追踪器"""
|
||||
|
||||
def __init__(self, interest_scoring_system=None):
|
||||
self.tracking_users: Dict[str, Dict] = {} # user_id -> interaction_data
|
||||
self.tracking_users: dict[str, dict] = {} # user_id -> interaction_data
|
||||
self.max_tracking_users = 3
|
||||
self.update_interval_minutes = 30
|
||||
self.last_update_time = time.time()
|
||||
self.relationship_history: List[Dict] = []
|
||||
self.relationship_history: list[dict] = []
|
||||
self.interest_scoring_system = interest_scoring_system
|
||||
|
||||
# 用户关系缓存 (user_id -> {"relationship_text": str, "relationship_score": float, "last_tracked": float})
|
||||
self.user_relationship_cache: Dict[str, Dict] = {}
|
||||
self.user_relationship_cache: dict[str, dict] = {}
|
||||
self.cache_expiry_hours = 1 # 缓存过期时间(小时)
|
||||
|
||||
# 关系更新LLM
|
||||
@@ -91,7 +91,7 @@ class ChatterRelationshipTracker:
|
||||
|
||||
logger.debug(f"添加用户交互追踪: {user_id}")
|
||||
|
||||
async def check_and_update_relationships(self) -> List[Dict]:
|
||||
async def check_and_update_relationships(self) -> list[dict]:
|
||||
"""检查并更新用户关系"""
|
||||
current_time = time.time()
|
||||
if current_time - self.last_update_time < self.update_interval_minutes * 60:
|
||||
@@ -108,7 +108,7 @@ class ChatterRelationshipTracker:
|
||||
self.last_update_time = current_time
|
||||
return updates
|
||||
|
||||
async def _update_user_relationship(self, interaction: Dict) -> Optional[Dict]:
|
||||
async def _update_user_relationship(self, interaction: dict) -> dict | None:
|
||||
"""更新单个用户的关系"""
|
||||
try:
|
||||
# 获取bot人设信息
|
||||
@@ -201,11 +201,11 @@ class ChatterRelationshipTracker:
|
||||
|
||||
return None
|
||||
|
||||
def get_tracking_users(self) -> Dict[str, Dict]:
|
||||
def get_tracking_users(self) -> dict[str, dict]:
|
||||
"""获取正在追踪的用户"""
|
||||
return self.tracking_users.copy()
|
||||
|
||||
def get_user_interaction(self, user_id: str) -> Optional[Dict]:
|
||||
def get_user_interaction(self, user_id: str) -> dict | None:
|
||||
"""获取特定用户的交互记录"""
|
||||
return self.tracking_users.get(user_id)
|
||||
|
||||
@@ -220,11 +220,11 @@ class ChatterRelationshipTracker:
|
||||
self.tracking_users.clear()
|
||||
logger.info("清空所有用户追踪")
|
||||
|
||||
def get_relationship_history(self) -> List[Dict]:
|
||||
def get_relationship_history(self) -> list[dict]:
|
||||
"""获取关系历史记录"""
|
||||
return self.relationship_history.copy()
|
||||
|
||||
def add_to_history(self, relationship_update: Dict):
|
||||
def add_to_history(self, relationship_update: dict):
|
||||
"""添加到关系历史"""
|
||||
self.relationship_history.append({**relationship_update, "update_time": time.time()})
|
||||
|
||||
@@ -232,7 +232,7 @@ class ChatterRelationshipTracker:
|
||||
if len(self.relationship_history) > 100:
|
||||
self.relationship_history = self.relationship_history[-100:]
|
||||
|
||||
def get_tracker_stats(self) -> Dict:
|
||||
def get_tracker_stats(self) -> dict:
|
||||
"""获取追踪器统计"""
|
||||
return {
|
||||
"tracking_users": len(self.tracking_users),
|
||||
@@ -268,7 +268,7 @@ class ChatterRelationshipTracker:
|
||||
self.add_to_history(update_info)
|
||||
logger.info(f"强制更新用户关系: {user_id} -> {new_score:.2f}")
|
||||
|
||||
def get_user_summary(self, user_id: str) -> Dict:
|
||||
def get_user_summary(self, user_id: str) -> dict:
|
||||
"""获取用户交互总结"""
|
||||
if user_id not in self.tracking_users:
|
||||
return {}
|
||||
@@ -313,7 +313,7 @@ class ChatterRelationshipTracker:
|
||||
# 数据库中也没有,返回默认值
|
||||
return global_config.affinity_flow.base_relationship_score
|
||||
|
||||
async def _get_user_relationship_from_db(self, user_id: str) -> Optional[Dict]:
|
||||
async def _get_user_relationship_from_db(self, user_id: str) -> dict | None:
|
||||
"""从数据库获取用户关系数据"""
|
||||
try:
|
||||
async with get_db_session() as session:
|
||||
@@ -431,7 +431,7 @@ class ChatterRelationshipTracker:
|
||||
|
||||
return 0
|
||||
|
||||
async def _get_last_bot_reply_to_user(self, user_id: str) -> Optional[DatabaseMessages]:
|
||||
async def _get_last_bot_reply_to_user(self, user_id: str) -> DatabaseMessages | None:
|
||||
"""获取上次bot回复该用户的消息"""
|
||||
try:
|
||||
async with get_db_session() as session:
|
||||
@@ -455,7 +455,7 @@ class ChatterRelationshipTracker:
|
||||
|
||||
return None
|
||||
|
||||
async def _get_user_reactions_after_reply(self, user_id: str, reply_time: float) -> List[DatabaseMessages]:
|
||||
async def _get_user_reactions_after_reply(self, user_id: str, reply_time: float) -> list[DatabaseMessages]:
|
||||
"""获取用户在bot回复后的反应消息"""
|
||||
try:
|
||||
async with get_db_session() as session:
|
||||
@@ -511,7 +511,7 @@ class ChatterRelationshipTracker:
|
||||
user_id: str,
|
||||
user_name: str,
|
||||
last_bot_reply: DatabaseMessages,
|
||||
user_reactions: List[DatabaseMessages],
|
||||
user_reactions: list[DatabaseMessages],
|
||||
current_text: str,
|
||||
current_score: float,
|
||||
current_reply: str,
|
||||
|
||||
@@ -8,9 +8,9 @@
|
||||
- 测试功能
|
||||
"""
|
||||
|
||||
from src.plugin_system.base import BaseCommand
|
||||
from src.chat.antipromptinjector import get_anti_injector
|
||||
from src.common.logger import get_logger
|
||||
from src.plugin_system.base import BaseCommand
|
||||
|
||||
logger = get_logger("anti_injector.commands")
|
||||
|
||||
@@ -56,5 +56,5 @@ class AntiInjectorStatusCommand(BaseCommand):
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取反注入系统状态失败: {e}")
|
||||
await self.send_text(f"获取状态失败: {str(e)}")
|
||||
return False, f"获取状态失败: {str(e)}", True
|
||||
await self.send_text(f"获取状态失败: {e!s}")
|
||||
return False, f"获取状态失败: {e!s}", True
|
||||
|
||||
@@ -1,19 +1,18 @@
|
||||
import random
|
||||
from typing import Tuple
|
||||
|
||||
# 导入新插件系统
|
||||
from src.plugin_system import BaseAction, ActionActivationType, ChatMode
|
||||
from src.chat.emoji_system.emoji_history import add_emoji_to_history, get_recent_emojis
|
||||
from src.chat.emoji_system.emoji_manager import MaiEmoji, get_emoji_manager
|
||||
from src.chat.utils.utils_image import image_path_to_base64
|
||||
|
||||
# 导入依赖的系统组件
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
|
||||
# 导入新插件系统
|
||||
from src.plugin_system import ActionActivationType, BaseAction, ChatMode
|
||||
|
||||
# 导入API模块 - 标准Python包方式
|
||||
from src.plugin_system.apis import llm_api, message_api
|
||||
from src.chat.emoji_system.emoji_manager import get_emoji_manager, MaiEmoji
|
||||
from src.chat.utils.utils_image import image_path_to_base64
|
||||
from src.config.config import global_config
|
||||
from src.chat.emoji_system.emoji_history import get_recent_emojis, add_emoji_to_history
|
||||
|
||||
|
||||
logger = get_logger("emoji")
|
||||
|
||||
@@ -59,7 +58,7 @@ class EmojiAction(BaseAction):
|
||||
# 关联类型
|
||||
associated_types = ["emoji"]
|
||||
|
||||
async def execute(self) -> Tuple[bool, str]:
|
||||
async def execute(self) -> tuple[bool, str]:
|
||||
"""执行表情动作"""
|
||||
logger.info(f"{self.log_prefix} 决定发送表情")
|
||||
|
||||
@@ -286,4 +285,4 @@ class EmojiAction(BaseAction):
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 表情动作执行失败: {e}", exc_info=True)
|
||||
return False, f"表情发送失败: {str(e)}"
|
||||
return False, f"表情发送失败: {e!s}"
|
||||
|
||||
@@ -5,19 +5,16 @@
|
||||
这是系统的内置插件,提供基础的聊天交互功能
|
||||
"""
|
||||
|
||||
from typing import List, Tuple, Type
|
||||
|
||||
# 导入新插件系统
|
||||
from src.plugin_system import BasePlugin, register_plugin, ComponentInfo
|
||||
from src.plugin_system.base.config_types import ConfigField
|
||||
|
||||
|
||||
# 导入依赖的系统组件
|
||||
from src.common.logger import get_logger
|
||||
|
||||
# 导入新插件系统
|
||||
from src.plugin_system import BasePlugin, ComponentInfo, register_plugin
|
||||
from src.plugin_system.base.config_types import ConfigField
|
||||
from src.plugins.built_in.core_actions.anti_injector_manager import AntiInjectorStatusCommand
|
||||
|
||||
# 导入API模块 - 标准Python包方式
|
||||
from src.plugins.built_in.core_actions.emoji import EmojiAction
|
||||
from src.plugins.built_in.core_actions.anti_injector_manager import AntiInjectorStatusCommand
|
||||
|
||||
logger = get_logger("core_actions")
|
||||
|
||||
@@ -62,7 +59,7 @@ class CoreActionsPlugin(BasePlugin):
|
||||
},
|
||||
}
|
||||
|
||||
def get_plugin_components(self) -> List[Tuple[ComponentInfo, Type]]:
|
||||
def get_plugin_components(self) -> list[tuple[ComponentInfo, type]]:
|
||||
"""返回插件包含的组件列表"""
|
||||
|
||||
# --- 根据配置注册组件 ---
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
from typing import Dict, Any
|
||||
from typing import Any
|
||||
|
||||
from src.chat.knowledge.knowledge_lib import qa_manager
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.chat.knowledge.knowledge_lib import qa_manager
|
||||
from src.plugin_system import BaseTool, ToolParamType
|
||||
|
||||
logger = get_logger("lpmm_get_knowledge_tool")
|
||||
@@ -19,7 +19,7 @@ class SearchKnowledgeFromLPMMTool(BaseTool):
|
||||
]
|
||||
available_for_llm = global_config.lpmm_knowledge.enable
|
||||
|
||||
async def execute(self, function_args: Dict[str, Any]) -> Dict[str, Any]:
|
||||
async def execute(self, function_args: dict[str, Any]) -> dict[str, Any]:
|
||||
"""执行知识库搜索
|
||||
|
||||
Args:
|
||||
@@ -56,7 +56,7 @@ class SearchKnowledgeFromLPMMTool(BaseTool):
|
||||
return {"type": "lpmm_knowledge", "id": query, "content": content}
|
||||
except Exception as e:
|
||||
# 捕获异常并记录错误
|
||||
logger.error(f"知识库搜索工具执行失败: {str(e)}")
|
||||
logger.error(f"知识库搜索工具执行失败: {e!s}")
|
||||
# 在其他异常情况下,确保 id 仍然是 query (如果它被定义了)
|
||||
query_id = query if "query" in locals() else "unknown_query"
|
||||
return {"type": "info", "id": query_id, "content": f"lpmm知识库搜索失败,炸了: {str(e)}"}
|
||||
return {"type": "info", "id": query_id, "content": f"lpmm知识库搜索失败,炸了: {e!s}"}
|
||||
|
||||
@@ -1,9 +1,8 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
让框架能够发现并加载子目录中的组件。
|
||||
"""
|
||||
|
||||
from .plugin import MaiZoneRefactoredPlugin as MaiZoneRefactoredPlugin
|
||||
from .actions.send_feed_action import SendFeedAction as SendFeedAction
|
||||
from .actions.read_feed_action import ReadFeedAction as ReadFeedAction
|
||||
from .actions.send_feed_action import SendFeedAction as SendFeedAction
|
||||
from .commands.send_feed_command import SendFeedCommand as SendFeedCommand
|
||||
from .plugin import MaiZoneRefactoredPlugin as MaiZoneRefactoredPlugin
|
||||
|
||||
@@ -1,14 +1,12 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
阅读说说动作组件
|
||||
"""
|
||||
|
||||
from typing import Tuple
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.plugin_system import BaseAction, ActionActivationType, ChatMode
|
||||
from src.plugin_system import ActionActivationType, BaseAction, ChatMode
|
||||
from src.plugin_system.apis import generator_api
|
||||
from src.plugin_system.apis.permission_api import permission_api
|
||||
|
||||
from ..services.manager import get_qzone_service
|
||||
|
||||
logger = get_logger("MaiZone.ReadFeedAction")
|
||||
@@ -41,7 +39,7 @@ class ReadFeedAction(BaseAction):
|
||||
# 使用权限API检查用户是否有阅读说说的权限
|
||||
return await permission_api.check_permission(platform, user_id, "plugin.maizone.read_feed")
|
||||
|
||||
async def execute(self) -> Tuple[bool, str]:
|
||||
async def execute(self) -> tuple[bool, str]:
|
||||
"""
|
||||
执行动作的核心逻辑。
|
||||
"""
|
||||
|
||||
@@ -1,14 +1,12 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
发送说说动作组件
|
||||
"""
|
||||
|
||||
from typing import Tuple
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.plugin_system import BaseAction, ActionActivationType, ChatMode
|
||||
from src.plugin_system import ActionActivationType, BaseAction, ChatMode
|
||||
from src.plugin_system.apis import generator_api
|
||||
from src.plugin_system.apis.permission_api import permission_api
|
||||
|
||||
from ..services.manager import get_qzone_service
|
||||
|
||||
logger = get_logger("MaiZone.SendFeedAction")
|
||||
@@ -41,7 +39,7 @@ class SendFeedAction(BaseAction):
|
||||
# 使用权限API检查用户是否有发送说说的权限
|
||||
return await permission_api.check_permission(platform, user_id, "plugin.maizone.send_feed")
|
||||
|
||||
async def execute(self) -> Tuple[bool, str]:
|
||||
async def execute(self) -> tuple[bool, str]:
|
||||
"""
|
||||
执行动作的核心逻辑。
|
||||
"""
|
||||
|
||||
@@ -1,15 +1,13 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
发送说说命令 await self.send_text(f"收到!正在为你生成关于"{topic or '随机'}"的说说,请稍候...【热重载测试成功】")件
|
||||
"""
|
||||
|
||||
from typing import Tuple
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.plugin_system.base.plus_command import PlusCommand
|
||||
from src.plugin_system.base.command_args import CommandArgs
|
||||
from src.plugin_system.base.plus_command import PlusCommand
|
||||
from src.plugin_system.utils.permission_decorators import require_permission
|
||||
from ..services.manager import get_qzone_service, get_config_getter
|
||||
|
||||
from ..services.manager import get_config_getter, get_qzone_service
|
||||
|
||||
logger = get_logger("MaiZone.SendFeedCommand")
|
||||
|
||||
@@ -28,7 +26,7 @@ class SendFeedCommand(PlusCommand):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
@require_permission("plugin.maizone.send_feed")
|
||||
async def execute(self, args: CommandArgs) -> Tuple[bool, str, bool]:
|
||||
async def execute(self, args: CommandArgs) -> tuple[bool, str, bool]:
|
||||
"""
|
||||
执行命令的核心逻辑。
|
||||
"""
|
||||
|
||||
@@ -1,28 +1,26 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
MaiZone(麦麦空间)- 重构版
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from pathlib import Path
|
||||
from typing import List, Tuple, Type
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.plugin_system import BasePlugin, ComponentInfo, register_plugin
|
||||
from src.plugin_system.base.config_types import ConfigField
|
||||
from src.plugin_system.apis.permission_api import permission_api
|
||||
from src.plugin_system.base.config_types import ConfigField
|
||||
|
||||
from .actions.read_feed_action import ReadFeedAction
|
||||
from .actions.send_feed_action import SendFeedAction
|
||||
from .commands.send_feed_command import SendFeedCommand
|
||||
from .services.content_service import ContentService
|
||||
from .services.image_service import ImageService
|
||||
from .services.qzone_service import QZoneService
|
||||
from .services.scheduler_service import SchedulerService
|
||||
from .services.monitor_service import MonitorService
|
||||
from .services.cookie_service import CookieService
|
||||
from .services.reply_tracker_service import ReplyTrackerService
|
||||
from .services.image_service import ImageService
|
||||
from .services.manager import register_service
|
||||
from .services.monitor_service import MonitorService
|
||||
from .services.qzone_service import QZoneService
|
||||
from .services.reply_tracker_service import ReplyTrackerService
|
||||
from .services.scheduler_service import SchedulerService
|
||||
|
||||
logger = get_logger("MaiZone.Plugin")
|
||||
|
||||
@@ -35,8 +33,8 @@ class MaiZoneRefactoredPlugin(BasePlugin):
|
||||
plugin_description: str = "重构版的MaiZone插件"
|
||||
config_file_name: str = "config.toml"
|
||||
enable_plugin: bool = True
|
||||
dependencies: List[str] = []
|
||||
python_dependencies: List[str] = []
|
||||
dependencies: list[str] = []
|
||||
python_dependencies: list[str] = []
|
||||
|
||||
config_schema: dict = {
|
||||
"plugin": {"enable": ConfigField(type=bool, default=True, description="是否启用插件")},
|
||||
@@ -125,7 +123,7 @@ class MaiZoneRefactoredPlugin(BasePlugin):
|
||||
asyncio.create_task(monitor_service.start())
|
||||
logger.info("MaiZone后台监控和定时任务已启动。")
|
||||
|
||||
def get_plugin_components(self) -> List[Tuple[ComponentInfo, Type]]:
|
||||
def get_plugin_components(self) -> list[tuple[ComponentInfo, type]]:
|
||||
return [
|
||||
(SendFeedAction.get_action_info(), SendFeedAction),
|
||||
(ReadFeedAction.get_action_info(), ReadFeedAction),
|
||||
|
||||
@@ -1,23 +1,23 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
内容服务模块
|
||||
负责生成所有与QQ空间相关的文本内容,例如说说、评论等。
|
||||
"""
|
||||
|
||||
from typing import Callable, Optional
|
||||
import datetime
|
||||
|
||||
import base64
|
||||
import aiohttp
|
||||
from src.common.logger import get_logger
|
||||
import imghdr
|
||||
import asyncio
|
||||
from src.plugin_system.apis import llm_api, config_api, generator_api
|
||||
from src.plugin_system.apis.cross_context_api import get_chat_history_by_group_name
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
import base64
|
||||
import datetime
|
||||
import imghdr
|
||||
from collections.abc import Callable
|
||||
|
||||
import aiohttp
|
||||
from maim_message import UserInfo
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from src.common.logger import get_logger
|
||||
from src.config.api_ada_configs import TaskConfig
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.plugin_system.apis import config_api, generator_api, llm_api
|
||||
from src.plugin_system.apis.cross_context_api import get_chat_history_by_group_name
|
||||
|
||||
# 导入旧的工具函数,我们稍后会考虑是否也需要重构它
|
||||
from ..utils.history_utils import get_send_history
|
||||
@@ -38,7 +38,7 @@ class ContentService:
|
||||
"""
|
||||
self.get_config = get_config
|
||||
|
||||
async def generate_story(self, topic: str, context: Optional[str] = None) -> str:
|
||||
async def generate_story(self, topic: str, context: str | None = None) -> str:
|
||||
"""
|
||||
根据指定主题和可选的上下文生成一条QQ空间说说。
|
||||
|
||||
@@ -231,7 +231,7 @@ class ContentService:
|
||||
return ""
|
||||
return ""
|
||||
|
||||
async def _describe_image(self, image_url: str) -> Optional[str]:
|
||||
async def _describe_image(self, image_url: str) -> str | None:
|
||||
"""
|
||||
使用LLM识别图片内容。
|
||||
"""
|
||||
|
||||
@@ -1,14 +1,14 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Cookie服务模块
|
||||
负责从多种来源获取、缓存和管理QZone的Cookie。
|
||||
"""
|
||||
|
||||
import orjson
|
||||
from collections.abc import Callable
|
||||
from pathlib import Path
|
||||
from typing import Callable, Optional, Dict
|
||||
|
||||
import aiohttp
|
||||
import orjson
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.plugin_system.apis import send_api
|
||||
|
||||
@@ -29,28 +29,28 @@ class CookieService:
|
||||
"""获取指定QQ账号的cookie文件路径"""
|
||||
return self.cookie_dir / f"cookies-{qq_account}.json"
|
||||
|
||||
def _save_cookies_to_file(self, qq_account: str, cookies: Dict[str, str]):
|
||||
def _save_cookies_to_file(self, qq_account: str, cookies: dict[str, str]):
|
||||
"""将Cookie保存到本地文件"""
|
||||
cookie_file_path = self._get_cookie_file_path(qq_account)
|
||||
try:
|
||||
with open(cookie_file_path, "w", encoding="utf-8") as f:
|
||||
f.write(orjson.dumps(cookies, option=orjson.OPT_INDENT_2).decode("utf-8"))
|
||||
logger.info(f"Cookie已成功缓存至: {cookie_file_path}")
|
||||
except IOError as e:
|
||||
except OSError as e:
|
||||
logger.error(f"无法写入Cookie文件 {cookie_file_path}: {e}")
|
||||
|
||||
def _load_cookies_from_file(self, qq_account: str) -> Optional[Dict[str, str]]:
|
||||
def _load_cookies_from_file(self, qq_account: str) -> dict[str, str] | None:
|
||||
"""从本地文件加载Cookie"""
|
||||
cookie_file_path = self._get_cookie_file_path(qq_account)
|
||||
if cookie_file_path.exists():
|
||||
try:
|
||||
with open(cookie_file_path, "r", encoding="utf-8") as f:
|
||||
with open(cookie_file_path, encoding="utf-8") as f:
|
||||
return orjson.loads(f.read())
|
||||
except (IOError, orjson.JSONDecodeError) as e:
|
||||
except (OSError, orjson.JSONDecodeError) as e:
|
||||
logger.error(f"无法读取或解析Cookie文件 {cookie_file_path}: {e}")
|
||||
return None
|
||||
|
||||
async def _get_cookies_from_adapter(self, stream_id: Optional[str]) -> Optional[Dict[str, str]]:
|
||||
async def _get_cookies_from_adapter(self, stream_id: str | None) -> dict[str, str] | None:
|
||||
"""通过Adapter API获取Cookie"""
|
||||
try:
|
||||
params = {"domain": "user.qzone.qq.com"}
|
||||
@@ -73,7 +73,7 @@ class CookieService:
|
||||
logger.error(f"通过Adapter获取Cookie时发生异常: {e}")
|
||||
return None
|
||||
|
||||
async def _get_cookies_from_http(self) -> Optional[Dict[str, str]]:
|
||||
async def _get_cookies_from_http(self) -> dict[str, str] | None:
|
||||
"""通过备用HTTP端点获取Cookie"""
|
||||
host = self.get_config("cookie.http_fallback_host", "172.20.130.55")
|
||||
port = self.get_config("cookie.http_fallback_port", "9999")
|
||||
@@ -110,7 +110,7 @@ class CookieService:
|
||||
logger.error(f"通过HTTP备用地址 {http_url} 获取Cookie失败: {e}")
|
||||
return None
|
||||
|
||||
async def get_cookies(self, qq_account: str, stream_id: Optional[str]) -> Optional[Dict[str, str]]:
|
||||
async def get_cookies(self, qq_account: str, stream_id: str | None) -> dict[str, str] | None:
|
||||
"""
|
||||
获取Cookie,按以下顺序尝试:
|
||||
1. HTTP备用端点 (更稳定)
|
||||
|
||||
@@ -1,12 +1,11 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
图片服务模块
|
||||
负责处理所有与图片相关的任务,特别是AI生成图片。
|
||||
"""
|
||||
|
||||
import base64
|
||||
from collections.abc import Callable
|
||||
from pathlib import Path
|
||||
from typing import Callable
|
||||
|
||||
import aiohttp
|
||||
|
||||
|
||||
@@ -1,14 +1,15 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
服务管理器/定位器
|
||||
这是一个独立的模块,用于注册和获取插件内的全局服务实例,以避免循环导入。
|
||||
"""
|
||||
|
||||
from typing import Dict, Any, Callable
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
from .qzone_service import QZoneService
|
||||
|
||||
# --- 全局服务注册表 ---
|
||||
_services: Dict[str, Any] = {}
|
||||
_services: dict[str, Any] = {}
|
||||
|
||||
|
||||
def register_service(name: str, instance: Any):
|
||||
|
||||
@@ -1,13 +1,13 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
好友动态监控服务
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import traceback
|
||||
from typing import Callable
|
||||
from collections.abc import Callable
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
from .qzone_service import QZoneService
|
||||
|
||||
logger = get_logger("MaiZone.MonitorService")
|
||||
|
||||
@@ -1,32 +1,33 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
QQ空间服务模块
|
||||
封装了所有与QQ空间API的直接交互,是插件的核心业务逻辑层。
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import orjson
|
||||
import base64
|
||||
import os
|
||||
import random
|
||||
import time
|
||||
import base64
|
||||
from collections.abc import Callable
|
||||
from pathlib import Path
|
||||
from typing import Callable, Optional, Dict, Any, List, Tuple
|
||||
from typing import Any
|
||||
|
||||
import aiohttp
|
||||
import bs4
|
||||
import json5
|
||||
from src.common.logger import get_logger
|
||||
from src.plugin_system.apis import config_api, person_api
|
||||
import orjson
|
||||
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from src.chat.utils.chat_message_builder import (
|
||||
build_readable_messages_with_id,
|
||||
get_raw_msg_by_timestamp_with_chat,
|
||||
)
|
||||
from src.common.logger import get_logger
|
||||
from src.plugin_system.apis import config_api, person_api
|
||||
|
||||
from .content_service import ContentService
|
||||
from .image_service import ImageService
|
||||
from .cookie_service import CookieService
|
||||
from .image_service import ImageService
|
||||
from .reply_tracker_service import ReplyTrackerService
|
||||
|
||||
logger = get_logger("MaiZone.QZoneService")
|
||||
@@ -64,7 +65,7 @@ class QZoneService:
|
||||
|
||||
# --- Public Methods (High-Level Business Logic) ---
|
||||
|
||||
async def send_feed(self, topic: str, stream_id: Optional[str]) -> Dict[str, Any]:
|
||||
async def send_feed(self, topic: str, stream_id: str | None) -> dict[str, Any]:
|
||||
"""发送一条说说"""
|
||||
# --- 获取互通组上下文 ---
|
||||
context = await self._get_intercom_context(stream_id) if stream_id else None
|
||||
@@ -92,7 +93,7 @@ class QZoneService:
|
||||
logger.error(f"发布说说时发生异常: {e}", exc_info=True)
|
||||
return {"success": False, "message": f"发布说说异常: {e}"}
|
||||
|
||||
async def send_feed_from_activity(self, activity: str) -> Dict[str, Any]:
|
||||
async def send_feed_from_activity(self, activity: str) -> dict[str, Any]:
|
||||
"""根据日程活动发送一条说说"""
|
||||
story = await self.content_service.generate_story_from_activity(activity)
|
||||
if not story:
|
||||
@@ -118,7 +119,7 @@ class QZoneService:
|
||||
logger.error(f"根据活动发布说说时发生异常: {e}", exc_info=True)
|
||||
return {"success": False, "message": f"发布说说异常: {e}"}
|
||||
|
||||
async def read_and_process_feeds(self, target_name: str, stream_id: Optional[str]) -> Dict[str, Any]:
|
||||
async def read_and_process_feeds(self, target_name: str, stream_id: str | None) -> dict[str, Any]:
|
||||
"""读取并处理指定好友的说说"""
|
||||
target_person_id = await person_api.get_person_id_by_name(target_name)
|
||||
if not target_person_id:
|
||||
@@ -147,7 +148,7 @@ class QZoneService:
|
||||
logger.error(f"读取和处理说说时发生异常: {e}", exc_info=True)
|
||||
return {"success": False, "message": f"处理说说异常: {e}"}
|
||||
|
||||
async def monitor_feeds(self, stream_id: Optional[str] = None):
|
||||
async def monitor_feeds(self, stream_id: str | None = None):
|
||||
"""监控并处理所有好友的动态,包括回复自己说说的评论"""
|
||||
logger.info("开始执行好友动态监控...")
|
||||
qq_account = config_api.get_global_config("bot.qq_account", "")
|
||||
@@ -189,7 +190,7 @@ class QZoneService:
|
||||
|
||||
# --- Internal Helper Methods ---
|
||||
|
||||
async def _get_intercom_context(self, stream_id: str) -> Optional[str]:
|
||||
async def _get_intercom_context(self, stream_id: str) -> str | None:
|
||||
"""
|
||||
根据 stream_id 查找其所属的互通组,并构建该组的聊天上下文。
|
||||
|
||||
@@ -247,7 +248,7 @@ class QZoneService:
|
||||
logger.debug(f"Stream ID '{stream_id}' 未在任何互通组中找到。")
|
||||
return None
|
||||
|
||||
async def _reply_to_own_feed_comments(self, feed: Dict, api_client: Dict):
|
||||
async def _reply_to_own_feed_comments(self, feed: dict, api_client: dict):
|
||||
"""处理对自己说说的评论并进行回复"""
|
||||
qq_account = config_api.get_global_config("bot.qq_account", "")
|
||||
comments = feed.get("comments", [])
|
||||
@@ -309,7 +310,7 @@ class QZoneService:
|
||||
if comment_key in self.processing_comments:
|
||||
self.processing_comments.remove(comment_key)
|
||||
|
||||
async def _validate_and_cleanup_reply_records(self, fid: str, my_replies: List[Dict]):
|
||||
async def _validate_and_cleanup_reply_records(self, fid: str, my_replies: list[dict]):
|
||||
"""验证并清理已删除的回复记录"""
|
||||
# 获取当前记录中该说说的所有已回复评论ID
|
||||
recorded_replied_comments = self.reply_tracker.get_replied_comments(fid)
|
||||
@@ -333,7 +334,7 @@ class QZoneService:
|
||||
self.reply_tracker.remove_reply_record(fid, comment_tid)
|
||||
logger.debug(f"已清理删除的回复记录: feed_id={fid}, comment_id={comment_tid}")
|
||||
|
||||
async def _process_single_feed(self, feed: Dict, api_client: Dict, target_qq: str, target_name: str):
|
||||
async def _process_single_feed(self, feed: dict, api_client: dict, target_qq: str, target_name: str):
|
||||
"""处理单条说说,决定是否评论和点赞"""
|
||||
content = feed.get("content", "")
|
||||
fid = feed.get("tid", "")
|
||||
@@ -371,7 +372,7 @@ class QZoneService:
|
||||
if random.random() <= self.get_config("read.like_possibility", 1.0):
|
||||
await api_client["like"](target_qq, fid)
|
||||
|
||||
def _load_local_images(self, image_dir: str) -> List[bytes]:
|
||||
def _load_local_images(self, image_dir: str) -> list[bytes]:
|
||||
"""随机加载本地图片(不删除文件)"""
|
||||
images = []
|
||||
if not image_dir or not os.path.exists(image_dir):
|
||||
@@ -432,7 +433,7 @@ class QZoneService:
|
||||
hash_val += (hash_val << 5) + ord(char)
|
||||
return str(hash_val & 2147483647)
|
||||
|
||||
async def _renew_and_load_cookies(self, qq_account: str, stream_id: Optional[str]) -> Optional[Dict[str, str]]:
|
||||
async def _renew_and_load_cookies(self, qq_account: str, stream_id: str | None) -> dict[str, str] | None:
|
||||
cookie_dir = Path(__file__).resolve().parent.parent / "cookies"
|
||||
cookie_dir.mkdir(exist_ok=True)
|
||||
cookie_file_path = cookie_dir / f"cookies-{qq_account}.json"
|
||||
@@ -480,7 +481,7 @@ class QZoneService:
|
||||
logger.error("所有获取Cookie的方式均失败。")
|
||||
return None
|
||||
|
||||
async def _fetch_cookies_http(self, host: str, port: int, napcat_token: str) -> Optional[Dict]:
|
||||
async def _fetch_cookies_http(self, host: str, port: int, napcat_token: str) -> dict | None:
|
||||
"""通过HTTP服务器获取Cookie"""
|
||||
# 从配置中读取主机和端口,如果未提供则使用传入的参数
|
||||
final_host = self.get_config("cookie.http_fallback_host", host)
|
||||
@@ -515,19 +516,19 @@ class QZoneService:
|
||||
|
||||
except aiohttp.ClientError as e:
|
||||
if attempt < max_retries - 1:
|
||||
logger.warning(f"无法连接到Napcat服务(尝试 {attempt + 1}/{max_retries}): {url},错误: {str(e)}")
|
||||
logger.warning(f"无法连接到Napcat服务(尝试 {attempt + 1}/{max_retries}): {url},错误: {e!s}")
|
||||
await asyncio.sleep(retry_delay)
|
||||
retry_delay *= 2
|
||||
continue
|
||||
logger.error(f"无法连接到Napcat服务(最终尝试): {url},错误: {str(e)}")
|
||||
logger.error(f"无法连接到Napcat服务(最终尝试): {url},错误: {e!s}")
|
||||
raise RuntimeError(f"无法连接到Napcat服务: {url}") from e
|
||||
except Exception as e:
|
||||
logger.error(f"获取cookie异常: {str(e)}")
|
||||
logger.error(f"获取cookie异常: {e!s}")
|
||||
raise
|
||||
|
||||
raise RuntimeError(f"无法连接到Napcat服务: 超过最大重试次数({max_retries})")
|
||||
|
||||
async def _get_api_client(self, qq_account: str, stream_id: Optional[str]) -> Optional[Dict]:
|
||||
async def _get_api_client(self, qq_account: str, stream_id: str | None) -> dict | None:
|
||||
cookies = await self.cookie_service.get_cookies(qq_account, stream_id)
|
||||
if not cookies:
|
||||
logger.error(
|
||||
@@ -559,7 +560,7 @@ class QZoneService:
|
||||
response.raise_for_status()
|
||||
return await response.text()
|
||||
|
||||
async def _publish(content: str, images: List[bytes]) -> Tuple[bool, str]:
|
||||
async def _publish(content: str, images: list[bytes]) -> tuple[bool, str]:
|
||||
"""发布说说"""
|
||||
try:
|
||||
post_data = {
|
||||
@@ -660,7 +661,7 @@ class QZoneService:
|
||||
|
||||
return picbo, richval
|
||||
|
||||
async def _upload_image(image_bytes: bytes, index: int) -> Optional[Dict[str, str]]:
|
||||
async def _upload_image(image_bytes: bytes, index: int) -> dict[str, str] | None:
|
||||
"""上传图片到QQ空间(完全按照原版实现)"""
|
||||
try:
|
||||
upload_url = "https://up.qzone.qq.com/cgi-bin/upload/cgi_upload_image"
|
||||
@@ -745,7 +746,7 @@ class QZoneService:
|
||||
logger.error(f"上传图片 {index + 1} 异常: {e}", exc_info=True)
|
||||
return None
|
||||
|
||||
async def _list_feeds(t_qq: str, num: int) -> List[Dict]:
|
||||
async def _list_feeds(t_qq: str, num: int) -> list[dict]:
|
||||
"""获取指定用户说说列表 (统一接口)"""
|
||||
try:
|
||||
# 统一使用 format=json 获取完整评论
|
||||
@@ -920,7 +921,7 @@ class QZoneService:
|
||||
logger.error(f"回复评论异常: {e}", exc_info=True)
|
||||
return False
|
||||
|
||||
async def _monitor_list_feeds(num: int) -> List[Dict]:
|
||||
async def _monitor_list_feeds(num: int) -> list[dict]:
|
||||
"""监控好友动态"""
|
||||
try:
|
||||
params = {
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
评论回复跟踪服务
|
||||
负责记录和管理已回复过的评论ID,避免重复回复
|
||||
@@ -7,7 +6,8 @@
|
||||
import json
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Set, Dict, Any, Union
|
||||
from typing import Any
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("MaiZone.ReplyTrackerService")
|
||||
@@ -27,7 +27,7 @@ class ReplyTrackerService:
|
||||
|
||||
# 内存中的已回复评论记录
|
||||
# 格式: {feed_id: {comment_id: timestamp, ...}, ...}
|
||||
self.replied_comments: Dict[str, Dict[str, float]] = {}
|
||||
self.replied_comments: dict[str, dict[str, float]] = {}
|
||||
|
||||
# 数据清理配置
|
||||
self.max_record_days = 30 # 保留30天的记录
|
||||
@@ -64,7 +64,7 @@ class ReplyTrackerService:
|
||||
try:
|
||||
if self.reply_record_file.exists():
|
||||
try:
|
||||
with open(self.reply_record_file, "r", encoding="utf-8") as f:
|
||||
with open(self.reply_record_file, encoding="utf-8") as f:
|
||||
file_content = f.read().strip()
|
||||
if not file_content: # 文件为空
|
||||
logger.warning("回复记录文件为空,将创建新的记录")
|
||||
@@ -173,7 +173,7 @@ class ReplyTrackerService:
|
||||
if total_removed > 0:
|
||||
logger.info(f"清理了 {total_removed} 条超过{self.max_record_days}天的过期回复记录")
|
||||
|
||||
def has_replied(self, feed_id: str, comment_id: Union[str, int]) -> bool:
|
||||
def has_replied(self, feed_id: str, comment_id: str | int) -> bool:
|
||||
"""
|
||||
检查是否已经回复过指定的评论
|
||||
|
||||
@@ -190,7 +190,7 @@ class ReplyTrackerService:
|
||||
comment_id_str = str(comment_id)
|
||||
return feed_id in self.replied_comments and comment_id_str in self.replied_comments[feed_id]
|
||||
|
||||
def mark_as_replied(self, feed_id: str, comment_id: Union[str, int]):
|
||||
def mark_as_replied(self, feed_id: str, comment_id: str | int):
|
||||
"""
|
||||
标记指定评论为已回复
|
||||
|
||||
@@ -219,7 +219,7 @@ class ReplyTrackerService:
|
||||
else:
|
||||
logger.error(f"标记评论时数据验证失败: feed_id={feed_id}, comment_id={comment_id}")
|
||||
|
||||
def get_replied_comments(self, feed_id: str) -> Set[str]:
|
||||
def get_replied_comments(self, feed_id: str) -> set[str]:
|
||||
"""
|
||||
获取指定说说下所有已回复的评论ID
|
||||
|
||||
@@ -234,7 +234,7 @@ class ReplyTrackerService:
|
||||
return {str(comment_id) for comment_id in self.replied_comments[feed_id].keys()}
|
||||
return set()
|
||||
|
||||
def get_stats(self) -> Dict[str, Any]:
|
||||
def get_stats(self) -> dict[str, Any]:
|
||||
"""
|
||||
获取回复记录统计信息
|
||||
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
定时任务服务
|
||||
根据日程表定时发送说说。
|
||||
@@ -8,13 +7,14 @@ import asyncio
|
||||
import datetime
|
||||
import random
|
||||
import traceback
|
||||
from typing import Callable
|
||||
from collections.abc import Callable
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
from src.common.database.sqlalchemy_database_api import get_db_session
|
||||
from src.common.database.sqlalchemy_models import MaiZoneScheduleStatus
|
||||
from src.common.logger import get_logger
|
||||
from src.schedule.schedule_manager import schedule_manager
|
||||
from src.common.database.sqlalchemy_database_api import get_db_session
|
||||
from sqlalchemy import select
|
||||
from src.common.database.sqlalchemy_models import MaiZoneScheduleStatus
|
||||
|
||||
from .qzone_service import QZoneService
|
||||
|
||||
|
||||
@@ -1,15 +1,15 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
历史记录工具模块
|
||||
提供用于获取QQ空间发送历史的功能。
|
||||
"""
|
||||
|
||||
import orjson
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Dict, Any, Optional, List
|
||||
from typing import Any
|
||||
|
||||
import orjson
|
||||
import requests
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("MaiZone.HistoryUtils")
|
||||
@@ -26,11 +26,11 @@ class _CookieManager:
|
||||
return str(cookie_dir / f"cookies-{uin}.json")
|
||||
|
||||
@staticmethod
|
||||
def load_cookies(qq_account: str) -> Optional[Dict[str, str]]:
|
||||
def load_cookies(qq_account: str) -> dict[str, str] | None:
|
||||
cookie_file = _CookieManager.get_cookie_file_path(qq_account)
|
||||
if os.path.exists(cookie_file):
|
||||
try:
|
||||
with open(cookie_file, "r", encoding="utf-8") as f:
|
||||
with open(cookie_file, encoding="utf-8") as f:
|
||||
return orjson.loads(f.read())
|
||||
except Exception as e:
|
||||
logger.error(f"加载Cookie文件失败: {e}")
|
||||
@@ -42,7 +42,7 @@ class _SimpleQZoneAPI:
|
||||
|
||||
LIST_URL = "https://user.qzone.qq.com/proxy/domain/taotao.qq.com/cgi-bin/emotion_cgi_msglist_v6"
|
||||
|
||||
def __init__(self, cookies_dict: Optional[Dict[str, str]] = None):
|
||||
def __init__(self, cookies_dict: dict[str, str] | None = None):
|
||||
self.cookies = cookies_dict or {}
|
||||
self.gtk2 = ""
|
||||
p_skey = self.cookies.get("p_skey") or self.cookies.get("p_skey".upper())
|
||||
@@ -55,7 +55,7 @@ class _SimpleQZoneAPI:
|
||||
hash_val += (hash_val << 5) + ord(char)
|
||||
return str(hash_val & 2147483647)
|
||||
|
||||
def get_feed_list(self, target_qq: str, num: int) -> List[Dict[str, Any]]:
|
||||
def get_feed_list(self, target_qq: str, num: int) -> list[dict[str, Any]]:
|
||||
try:
|
||||
params = {
|
||||
"g_tk": self.gtk2,
|
||||
|
||||
@@ -6,19 +6,17 @@
|
||||
"""
|
||||
|
||||
import re
|
||||
from typing import List, Optional, Tuple, Type
|
||||
|
||||
from src.plugin_system.apis.logging_api import get_logger
|
||||
from src.plugin_system.apis.permission_api import permission_api
|
||||
from src.plugin_system.apis.plugin_register_api import register_plugin
|
||||
from src.plugin_system.base.base_plugin import BasePlugin
|
||||
from src.plugin_system.base.plus_command import PlusCommand
|
||||
from src.plugin_system.base.command_args import CommandArgs
|
||||
from src.plugin_system.apis.permission_api import permission_api
|
||||
from src.plugin_system.apis.logging_api import get_logger
|
||||
from src.plugin_system.base.component_types import PlusCommandInfo, ChatType
|
||||
from src.plugin_system.base.component_types import ChatType, PlusCommandInfo
|
||||
from src.plugin_system.base.config_types import ConfigField
|
||||
from src.plugin_system.base.plus_command import PlusCommand
|
||||
from src.plugin_system.utils.permission_decorators import require_permission
|
||||
|
||||
|
||||
logger = get_logger("Permission")
|
||||
|
||||
|
||||
@@ -44,7 +42,7 @@ class PermissionCommand(PlusCommand):
|
||||
"plugin.permission.view", "权限查看:可以查看权限节点和用户权限信息", "permission_manager", True
|
||||
)
|
||||
|
||||
async def execute(self, args: CommandArgs) -> Tuple[bool, Optional[str], bool]:
|
||||
async def execute(self, args: CommandArgs) -> tuple[bool, str | None, bool]:
|
||||
"""执行权限管理命令"""
|
||||
if args.is_empty:
|
||||
await self._show_help()
|
||||
@@ -114,7 +112,7 @@ class PermissionCommand(PlusCommand):
|
||||
await self.send_text(help_text)
|
||||
|
||||
@staticmethod
|
||||
def _parse_user_mention(mention: str) -> Optional[str]:
|
||||
def _parse_user_mention(mention: str) -> str | None:
|
||||
"""解析用户提及,提取QQ号
|
||||
|
||||
支持的格式:
|
||||
@@ -134,7 +132,7 @@ class PermissionCommand(PlusCommand):
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def parse_user_from_args(args: CommandArgs, index: int = 0) -> Optional[str]:
|
||||
def parse_user_from_args(args: CommandArgs, index: int = 0) -> str | None:
|
||||
"""从CommandArgs中解析用户ID
|
||||
|
||||
Args:
|
||||
@@ -166,7 +164,7 @@ class PermissionCommand(PlusCommand):
|
||||
return None
|
||||
|
||||
@require_permission("plugin.permission.manage", "❌ 你没有权限管理的权限")
|
||||
async def _grant_permission(self, chat_stream, args: List[str]):
|
||||
async def _grant_permission(self, chat_stream, args: list[str]):
|
||||
"""授权用户权限"""
|
||||
if len(args) < 2:
|
||||
await self.send_text("❌ 用法: /permission grant <@用户|QQ号> <权限节点>")
|
||||
@@ -189,7 +187,7 @@ class PermissionCommand(PlusCommand):
|
||||
await self.send_text("❌ 授权失败,请检查权限节点是否存在")
|
||||
|
||||
@require_permission("plugin.permission.manage", "❌ 你没有权限管理的权限")
|
||||
async def _revoke_permission(self, chat_stream, args: List[str]):
|
||||
async def _revoke_permission(self, chat_stream, args: list[str]):
|
||||
"""撤销用户权限"""
|
||||
if len(args) < 2:
|
||||
await self.send_text("❌ 用法: /permission revoke <@用户|QQ号> <权限节点>")
|
||||
@@ -212,7 +210,7 @@ class PermissionCommand(PlusCommand):
|
||||
await self.send_text("❌ 撤销失败,请检查权限节点是否存在")
|
||||
|
||||
@require_permission("plugin.permission.view", "❌ 你没有查看权限的权限")
|
||||
async def _list_permissions(self, chat_stream, args: List[str]):
|
||||
async def _list_permissions(self, chat_stream, args: list[str]):
|
||||
"""列出用户权限"""
|
||||
target_user_id = None
|
||||
|
||||
@@ -244,7 +242,7 @@ class PermissionCommand(PlusCommand):
|
||||
await self.send_text(response)
|
||||
|
||||
@require_permission("plugin.permission.view", "❌ 你没有查看权限的权限")
|
||||
async def _check_permission(self, chat_stream, args: List[str]):
|
||||
async def _check_permission(self, chat_stream, args: list[str]):
|
||||
"""检查用户权限"""
|
||||
if len(args) < 2:
|
||||
await self.send_text("❌ 用法: /permission check <@用户|QQ号> <权限节点>")
|
||||
@@ -273,7 +271,7 @@ class PermissionCommand(PlusCommand):
|
||||
await self.send_text(response)
|
||||
|
||||
@require_permission("plugin.permission.view", "❌ 你没有查看权限的权限")
|
||||
async def _list_nodes(self, chat_stream, args: List[str]):
|
||||
async def _list_nodes(self, chat_stream, args: list[str]):
|
||||
"""列出权限节点"""
|
||||
plugin_name = args[0] if args else None
|
||||
|
||||
@@ -388,6 +386,6 @@ class PermissionManagerPlugin(BasePlugin):
|
||||
}
|
||||
}
|
||||
|
||||
def get_plugin_components(self) -> List[Tuple[PlusCommandInfo, Type[PlusCommand]]]:
|
||||
def get_plugin_components(self) -> list[tuple[PlusCommandInfo, type[PlusCommand]]]:
|
||||
"""返回插件的PlusCommand组件"""
|
||||
return [(PermissionCommand.get_plus_command_info(), PermissionCommand)]
|
||||
|
||||
@@ -1,19 +1,18 @@
|
||||
import asyncio
|
||||
|
||||
from typing import List, Tuple, Type
|
||||
from src.plugin_system import (
|
||||
BasePlugin,
|
||||
ConfigField,
|
||||
register_plugin,
|
||||
plugin_manage_api,
|
||||
component_manage_api,
|
||||
ComponentInfo,
|
||||
ComponentType,
|
||||
ConfigField,
|
||||
component_manage_api,
|
||||
plugin_manage_api,
|
||||
register_plugin,
|
||||
)
|
||||
from src.plugin_system.base.plus_command import PlusCommand
|
||||
from src.plugin_system.base.command_args import CommandArgs
|
||||
from src.plugin_system.base.component_types import PlusCommandInfo, ChatType
|
||||
from src.plugin_system.apis.permission_api import permission_api
|
||||
from src.plugin_system.base.command_args import CommandArgs
|
||||
from src.plugin_system.base.component_types import ChatType, PlusCommandInfo
|
||||
from src.plugin_system.base.plus_command import PlusCommand
|
||||
from src.plugin_system.utils.permission_decorators import require_permission
|
||||
|
||||
|
||||
@@ -31,7 +30,7 @@ class ManagementCommand(PlusCommand):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
@require_permission("plugin.management.admin", "❌ 你没有插件管理的权限")
|
||||
async def execute(self, args: CommandArgs) -> Tuple[bool, str, bool]:
|
||||
async def execute(self, args: CommandArgs) -> tuple[bool, str, bool]:
|
||||
"""执行插件管理命令"""
|
||||
if args.is_empty:
|
||||
await self._show_help("all")
|
||||
@@ -51,7 +50,7 @@ class ManagementCommand(PlusCommand):
|
||||
await self.send_text(f"❌ 未知的子命令: {subcommand}\n使用 /pm help 查看帮助")
|
||||
return True, "未知子命令", True
|
||||
|
||||
async def _handle_plugin_commands(self, args: List[str]) -> Tuple[bool, str, bool]:
|
||||
async def _handle_plugin_commands(self, args: list[str]) -> tuple[bool, str, bool]:
|
||||
"""处理插件相关命令"""
|
||||
if not args:
|
||||
await self._show_help("plugin")
|
||||
@@ -83,7 +82,7 @@ class ManagementCommand(PlusCommand):
|
||||
|
||||
return True, "插件命令执行完成", True
|
||||
|
||||
async def _handle_component_commands(self, args: List[str]) -> Tuple[bool, str, bool]:
|
||||
async def _handle_component_commands(self, args: list[str]) -> tuple[bool, str, bool]:
|
||||
"""处理组件相关命令"""
|
||||
if not args:
|
||||
await self._show_help("component")
|
||||
@@ -258,7 +257,7 @@ class ManagementCommand(PlusCommand):
|
||||
else:
|
||||
await self.send_text(f"❌ 插件强制重载失败: `{plugin_name}`")
|
||||
except Exception as e:
|
||||
await self.send_text(f"❌ 强制重载过程中发生错误: {str(e)}")
|
||||
await self.send_text(f"❌ 强制重载过程中发生错误: {e!s}")
|
||||
|
||||
async def _add_dir(self, dir_path: str):
|
||||
"""添加插件目录"""
|
||||
@@ -271,17 +270,17 @@ class ManagementCommand(PlusCommand):
|
||||
await self.send_text(f"❌ 插件目录添加失败: `{dir_path}`")
|
||||
|
||||
@staticmethod
|
||||
def _fetch_all_registered_components() -> List[ComponentInfo]:
|
||||
def _fetch_all_registered_components() -> list[ComponentInfo]:
|
||||
all_plugin_info = component_manage_api.get_all_plugin_info()
|
||||
if not all_plugin_info:
|
||||
return []
|
||||
|
||||
components_info: List[ComponentInfo] = []
|
||||
components_info: list[ComponentInfo] = []
|
||||
for plugin_info in all_plugin_info.values():
|
||||
components_info.extend(plugin_info.components)
|
||||
return components_info
|
||||
|
||||
def _fetch_locally_disabled_components(self) -> List[str]:
|
||||
def _fetch_locally_disabled_components(self) -> list[str]:
|
||||
"""获取本地禁用的组件列表"""
|
||||
stream_id = self.message.chat_stream.stream_id
|
||||
locally_disabled_components_actions = component_manage_api.get_locally_disabled_components(
|
||||
@@ -509,7 +508,7 @@ class PluginManagementPlugin(BasePlugin):
|
||||
False,
|
||||
)
|
||||
|
||||
def get_plugin_components(self) -> List[Tuple[PlusCommandInfo, Type[PlusCommand]]]:
|
||||
def get_plugin_components(self) -> list[tuple[PlusCommandInfo, type[PlusCommand]]]:
|
||||
"""返回插件的PlusCommand组件"""
|
||||
components = []
|
||||
if self.get_config("plugin.enabled", True):
|
||||
|
||||
@@ -1,14 +1,13 @@
|
||||
from typing import List, Tuple, Type
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.plugin_system.base.base_plugin import BasePlugin
|
||||
from src.plugin_system import (
|
||||
BaseEventHandler,
|
||||
BasePlugin,
|
||||
ConfigField,
|
||||
register_plugin,
|
||||
EventHandlerInfo,
|
||||
BaseEventHandler,
|
||||
register_plugin,
|
||||
)
|
||||
from src.plugin_system.base.base_plugin import BasePlugin
|
||||
|
||||
from .proacive_thinker_event import ProactiveThinkerEventHandler
|
||||
|
||||
logger = get_logger(__name__)
|
||||
@@ -33,9 +32,9 @@ class ProactiveThinkerPlugin(BasePlugin):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def get_plugin_components(self) -> List[Tuple[EventHandlerInfo, Type[BaseEventHandler]]]:
|
||||
def get_plugin_components(self) -> list[tuple[EventHandlerInfo, type[BaseEventHandler]]]:
|
||||
"""返回插件的EventHandler组件"""
|
||||
components: List[Tuple[EventHandlerInfo, Type[BaseEventHandler]]] = [
|
||||
components: list[tuple[EventHandlerInfo, type[BaseEventHandler]]] = [
|
||||
(ProactiveThinkerEventHandler.get_handler_info(), ProactiveThinkerEventHandler)
|
||||
]
|
||||
return components
|
||||
|
||||
@@ -2,17 +2,17 @@ import asyncio
|
||||
import random
|
||||
import time
|
||||
from datetime import datetime
|
||||
from typing import List, Union
|
||||
|
||||
from maim_message import UserInfo
|
||||
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.manager.async_task_manager import async_task_manager, AsyncTask
|
||||
from src.plugin_system import EventType, BaseEventHandler
|
||||
from src.manager.async_task_manager import AsyncTask, async_task_manager
|
||||
from src.plugin_system import BaseEventHandler, EventType
|
||||
from src.plugin_system.apis import chat_api, person_api
|
||||
from src.plugin_system.base.base_event import HandlerResult
|
||||
|
||||
from .proactive_thinker_executor import ProactiveThinkerExecutor
|
||||
|
||||
logger = get_logger(__name__)
|
||||
@@ -199,7 +199,7 @@ class ProactiveThinkerEventHandler(BaseEventHandler):
|
||||
|
||||
handler_name: str = "proactive_thinker_on_start"
|
||||
handler_description: str = "主动思考插件的启动事件处理器"
|
||||
init_subscribe: List[Union[EventType, str]] = [EventType.ON_START]
|
||||
init_subscribe: list[EventType | str] = [EventType.ON_START]
|
||||
|
||||
async def execute(self, kwargs: dict | None) -> "HandlerResult":
|
||||
"""在机器人启动时执行,根据配置决定是否启动后台任务。"""
|
||||
|
||||
@@ -1,20 +1,21 @@
|
||||
import orjson
|
||||
from typing import Optional, Dict, Any
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
import orjson
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config, model_config
|
||||
from src.person_info.person_info import get_person_info_manager
|
||||
from src.plugin_system.apis import (
|
||||
chat_api,
|
||||
database_api,
|
||||
generator_api,
|
||||
llm_api,
|
||||
message_api,
|
||||
person_api,
|
||||
schedule_api,
|
||||
send_api,
|
||||
llm_api,
|
||||
message_api,
|
||||
generator_api,
|
||||
database_api,
|
||||
)
|
||||
from src.config.config import global_config, model_config
|
||||
from src.person_info.person_info import get_person_info_manager
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -101,7 +102,7 @@ class ProactiveThinkerExecutor:
|
||||
logger.error(f"解析 stream_id ({stream_id}) 或获取 stream 失败: {e}")
|
||||
return None
|
||||
|
||||
async def _gather_context(self, stream_id: str) -> Optional[Dict[str, Any]]:
|
||||
async def _gather_context(self, stream_id: str) -> dict[str, Any] | None:
|
||||
"""
|
||||
收集构建提示词所需的所有上下文信息
|
||||
"""
|
||||
@@ -165,7 +166,7 @@ class ProactiveThinkerExecutor:
|
||||
"current_time": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
|
||||
}
|
||||
|
||||
async def _make_decision(self, context: Dict[str, Any], start_mode: str) -> Optional[Dict[str, Any]]:
|
||||
async def _make_decision(self, context: dict[str, Any], start_mode: str) -> dict[str, Any] | None:
|
||||
"""
|
||||
决策模块:判断是否应该主动发起对话,以及聊什么话题
|
||||
"""
|
||||
@@ -234,7 +235,7 @@ class ProactiveThinkerExecutor:
|
||||
logger.error(f"决策LLM返回的JSON格式无效: {response}")
|
||||
return {"should_reply": False, "reason": "决策模型返回格式错误"}
|
||||
|
||||
def _build_plan_prompt(self, context: Dict[str, Any], start_mode: str, topic: str, reason: str) -> str:
|
||||
def _build_plan_prompt(self, context: dict[str, Any], start_mode: str, topic: str, reason: str) -> str:
|
||||
"""
|
||||
根据启动模式和决策话题,构建最终的规划提示词
|
||||
"""
|
||||
|
||||
@@ -1,24 +1,25 @@
|
||||
import re
|
||||
from typing import List, Tuple, Type, Optional
|
||||
|
||||
from src.plugin_system import (
|
||||
BasePlugin,
|
||||
register_plugin,
|
||||
BaseAction,
|
||||
ComponentInfo,
|
||||
ActionActivationType,
|
||||
ConfigField,
|
||||
)
|
||||
from src.common.logger import get_logger
|
||||
from .qq_emoji_list import qq_face
|
||||
from src.plugin_system.base.component_types import ChatType
|
||||
from src.person_info.person_info import get_person_info_manager
|
||||
from dateutil.parser import parse as parse_datetime
|
||||
from src.manager.async_task_manager import AsyncTask, async_task_manager
|
||||
from src.plugin_system.apis import send_api, llm_api, generator_api
|
||||
from src.chat.message_receive.chat_stream import ChatStream
|
||||
import asyncio
|
||||
import datetime
|
||||
import re
|
||||
|
||||
from dateutil.parser import parse as parse_datetime
|
||||
|
||||
from src.chat.message_receive.chat_stream import ChatStream
|
||||
from src.common.logger import get_logger
|
||||
from src.manager.async_task_manager import AsyncTask, async_task_manager
|
||||
from src.person_info.person_info import get_person_info_manager
|
||||
from src.plugin_system import (
|
||||
ActionActivationType,
|
||||
BaseAction,
|
||||
BasePlugin,
|
||||
ComponentInfo,
|
||||
ConfigField,
|
||||
register_plugin,
|
||||
)
|
||||
from src.plugin_system.apis import generator_api, llm_api, send_api
|
||||
from src.plugin_system.base.component_types import ChatType
|
||||
|
||||
from .qq_emoji_list import qq_face
|
||||
|
||||
logger = get_logger("set_emoji_like_plugin")
|
||||
|
||||
@@ -30,7 +31,7 @@ class ReminderTask(AsyncTask):
|
||||
self,
|
||||
delay: float,
|
||||
stream_id: str,
|
||||
group_id: Optional[str],
|
||||
group_id: str | None,
|
||||
is_group: bool,
|
||||
target_user_id: str,
|
||||
target_user_name: str,
|
||||
@@ -162,7 +163,7 @@ class PokeAction(BaseAction):
|
||||
"""
|
||||
associated_types = ["text"]
|
||||
|
||||
async def execute(self) -> Tuple[bool, str]:
|
||||
async def execute(self) -> tuple[bool, str]:
|
||||
"""执行戳一戳的动作"""
|
||||
user_id = self.action_data.get("user_id")
|
||||
user_name = self.action_data.get("user_name")
|
||||
@@ -242,7 +243,7 @@ class SetEmojiLikeAction(BaseAction):
|
||||
if match:
|
||||
emoji_options.append(match.group(1))
|
||||
|
||||
async def execute(self) -> Tuple[bool, str]:
|
||||
async def execute(self) -> tuple[bool, str]:
|
||||
"""执行设置表情回应的动作"""
|
||||
message_id = None
|
||||
set_like = self.action_data.get("set", True)
|
||||
@@ -360,7 +361,7 @@ class RemindAction(BaseAction):
|
||||
"例如:'10分钟后提醒我收快递'、'明天早上九点喊一下李四参加晨会'",
|
||||
]
|
||||
|
||||
async def execute(self) -> Tuple[bool, str]:
|
||||
async def execute(self) -> tuple[bool, str]:
|
||||
"""执行设置提醒的动作"""
|
||||
user_name = self.action_data.get("user_name")
|
||||
remind_time_str = self.action_data.get("remind_time")
|
||||
@@ -386,14 +387,14 @@ class RemindAction(BaseAction):
|
||||
# 优先尝试直接解析
|
||||
try:
|
||||
target_time = parse_datetime(remind_time_str, fuzzy=True)
|
||||
except Exception:
|
||||
except Exception as e:
|
||||
# 如果直接解析失败,调用 LLM 进行转换
|
||||
logger.info(f"[ReminderPlugin] 直接解析时间 '{remind_time_str}' 失败,尝试使用 LLM 进行转换...")
|
||||
|
||||
# 获取所有可用的模型配置
|
||||
available_models = llm_api.get_available_models()
|
||||
if "utils_small" not in available_models:
|
||||
raise ValueError("未找到 'utils_small' 模型配置,无法解析时间")
|
||||
raise ValueError("未找到 'utils_small' 模型配置,无法解析时间") from e
|
||||
|
||||
# 明确使用 'planner' 模型
|
||||
model_to_use = available_models["utils_small"]
|
||||
@@ -419,7 +420,7 @@ class RemindAction(BaseAction):
|
||||
)
|
||||
|
||||
if not success or not response:
|
||||
raise ValueError(f"LLM未能返回有效的时间字符串: {response}")
|
||||
raise ValueError(f"LLM未能返回有效的时间字符串: {response}") from e
|
||||
|
||||
converted_time_str = response.strip()
|
||||
logger.info(f"[ReminderPlugin] LLM 转换结果: '{converted_time_str}'")
|
||||
@@ -533,8 +534,8 @@ class SetEmojiLikePlugin(BasePlugin):
|
||||
# 插件基本信息
|
||||
plugin_name: str = "social_toolkit_plugin" # 内部标识符
|
||||
enable_plugin: bool = True
|
||||
dependencies: List[str] = [] # 插件依赖列表
|
||||
python_dependencies: List[str] = [] # Python包依赖列表,现在使用内置API
|
||||
dependencies: list[str] = [] # 插件依赖列表
|
||||
python_dependencies: list[str] = [] # Python包依赖列表,现在使用内置API
|
||||
config_file_name: str = "config.toml" # 配置文件名
|
||||
|
||||
# 配置节描述
|
||||
@@ -555,7 +556,7 @@ class SetEmojiLikePlugin(BasePlugin):
|
||||
},
|
||||
}
|
||||
|
||||
def get_plugin_components(self) -> List[Tuple[ComponentInfo, Type]]:
|
||||
def get_plugin_components(self) -> list[tuple[ComponentInfo, type]]:
|
||||
enable_components = []
|
||||
if self.get_config("components.action_set_emoji_like"):
|
||||
enable_components.append((SetEmojiLikeAction.get_action_info(), SetEmojiLikeAction))
|
||||
|
||||
@@ -1,10 +1,9 @@
|
||||
from src.common.logger import get_logger
|
||||
from src.plugin_system.apis.plugin_register_api import register_plugin
|
||||
from src.plugin_system.base.base_action import ActionActivationType, BaseAction, ChatMode
|
||||
from src.plugin_system.base.base_plugin import BasePlugin
|
||||
from src.plugin_system.base.component_types import ComponentInfo
|
||||
from src.common.logger import get_logger
|
||||
from src.plugin_system.base.base_action import BaseAction, ActionActivationType, ChatMode
|
||||
from src.plugin_system.base.config_types import ConfigField
|
||||
from typing import Tuple, List, Type
|
||||
|
||||
logger = get_logger("tts")
|
||||
|
||||
@@ -44,7 +43,7 @@ class TTSAction(BaseAction):
|
||||
# 关联类型
|
||||
associated_types = ["tts_text"]
|
||||
|
||||
async def execute(self) -> Tuple[bool, str]:
|
||||
async def execute(self) -> tuple[bool, str]:
|
||||
"""处理TTS文本转语音动作"""
|
||||
logger.info(f"{self.log_prefix} 执行TTS动作: {self.reasoning}")
|
||||
|
||||
@@ -140,7 +139,7 @@ class TTSPlugin(BasePlugin):
|
||||
},
|
||||
}
|
||||
|
||||
def get_plugin_components(self) -> List[Tuple[ComponentInfo, Type]]:
|
||||
def get_plugin_components(self) -> list[tuple[ComponentInfo, type]]:
|
||||
"""返回插件包含的组件列表"""
|
||||
|
||||
# 从配置获取组件启用状态
|
||||
|
||||
@@ -3,7 +3,7 @@ Base search engine interface
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, List, Any
|
||||
from typing import Any
|
||||
|
||||
|
||||
class BaseSearchEngine(ABC):
|
||||
@@ -12,7 +12,7 @@ class BaseSearchEngine(ABC):
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def search(self, args: Dict[str, Any]) -> List[Dict[str, Any]]:
|
||||
async def search(self, args: dict[str, Any]) -> list[dict[str, Any]]:
|
||||
"""
|
||||
执行搜索
|
||||
|
||||
|
||||
@@ -6,11 +6,13 @@ import asyncio
|
||||
import functools
|
||||
import random
|
||||
import traceback
|
||||
from typing import Dict, List, Any
|
||||
from typing import Any
|
||||
|
||||
import requests
|
||||
from bs4 import BeautifulSoup
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
from .base import BaseSearchEngine
|
||||
|
||||
logger = get_logger("bing_engine")
|
||||
@@ -68,7 +70,7 @@ class BingSearchEngine(BaseSearchEngine):
|
||||
"""检查Bing搜索引擎是否可用"""
|
||||
return True # Bing是免费搜索引擎,总是可用
|
||||
|
||||
async def search(self, args: Dict[str, Any]) -> List[Dict[str, Any]]:
|
||||
async def search(self, args: dict[str, Any]) -> list[dict[str, Any]]:
|
||||
"""执行Bing搜索"""
|
||||
query = args["query"]
|
||||
num_results = args.get("num_results", 3)
|
||||
@@ -83,7 +85,7 @@ class BingSearchEngine(BaseSearchEngine):
|
||||
logger.error(f"Bing 搜索失败: {e}")
|
||||
return []
|
||||
|
||||
def _search_sync(self, keyword: str, num_results: int, time_range: str) -> List[Dict[str, Any]]:
|
||||
def _search_sync(self, keyword: str, num_results: int, time_range: str) -> list[dict[str, Any]]:
|
||||
"""同步执行Bing搜索"""
|
||||
if not keyword:
|
||||
return []
|
||||
@@ -113,7 +115,7 @@ class BingSearchEngine(BaseSearchEngine):
|
||||
return list_result[:num_results] if len(list_result) > num_results else list_result
|
||||
|
||||
@staticmethod
|
||||
def _parse_html(url: str) -> List[Dict[str, Any]]:
|
||||
def _parse_html(url: str) -> list[dict[str, Any]]:
|
||||
"""解析处理结果"""
|
||||
try:
|
||||
logger.debug(f"访问Bing搜索URL: {url}")
|
||||
@@ -141,11 +143,11 @@ class BingSearchEngine(BaseSearchEngine):
|
||||
try:
|
||||
res = session.get(url=url, timeout=(3.05, 6), verify=True, allow_redirects=True)
|
||||
except (requests.exceptions.Timeout, requests.exceptions.ConnectionError) as e:
|
||||
logger.warning(f"第一次请求超时,正在重试: {str(e)}")
|
||||
logger.warning(f"第一次请求超时,正在重试: {e!s}")
|
||||
try:
|
||||
res = session.get(url=url, timeout=(5, 10), verify=False)
|
||||
except Exception as e2:
|
||||
logger.error(f"第二次请求也失败: {str(e2)}")
|
||||
logger.error(f"第二次请求也失败: {e2!s}")
|
||||
return []
|
||||
|
||||
res.encoding = "utf-8"
|
||||
@@ -175,7 +177,7 @@ class BingSearchEngine(BaseSearchEngine):
|
||||
try:
|
||||
root = BeautifulSoup(res.text, "html.parser")
|
||||
except Exception as e:
|
||||
logger.error(f"HTML解析失败: {str(e)}")
|
||||
logger.error(f"HTML解析失败: {e!s}")
|
||||
return []
|
||||
|
||||
list_data = []
|
||||
@@ -262,6 +264,6 @@ class BingSearchEngine(BaseSearchEngine):
|
||||
return list_data
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"解析Bing页面时出错: {str(e)}")
|
||||
logger.error(f"解析Bing页面时出错: {e!s}")
|
||||
logger.debug(traceback.format_exc())
|
||||
return []
|
||||
|
||||
@@ -2,10 +2,12 @@
|
||||
DuckDuckGo search engine implementation
|
||||
"""
|
||||
|
||||
from typing import Dict, List, Any
|
||||
from typing import Any
|
||||
|
||||
from asyncddgs import aDDGS
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
from .base import BaseSearchEngine
|
||||
|
||||
logger = get_logger("ddg_engine")
|
||||
@@ -20,7 +22,7 @@ class DDGSearchEngine(BaseSearchEngine):
|
||||
"""检查DuckDuckGo搜索引擎是否可用"""
|
||||
return True # DuckDuckGo不需要API密钥,总是可用
|
||||
|
||||
async def search(self, args: Dict[str, Any]) -> List[Dict[str, Any]]:
|
||||
async def search(self, args: dict[str, Any]) -> list[dict[str, Any]]:
|
||||
"""执行DuckDuckGo搜索"""
|
||||
query = args["query"]
|
||||
num_results = args.get("num_results", 3)
|
||||
|
||||
@@ -5,13 +5,15 @@ Exa search engine implementation
|
||||
import asyncio
|
||||
import functools
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, List, Any
|
||||
from typing import Any
|
||||
|
||||
from exa_py import Exa
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.plugin_system.apis import config_api
|
||||
from .base import BaseSearchEngine
|
||||
|
||||
from ..utils.api_key_manager import create_api_key_manager_from_config
|
||||
from .base import BaseSearchEngine
|
||||
|
||||
logger = get_logger("exa_engine")
|
||||
|
||||
@@ -36,7 +38,7 @@ class ExaSearchEngine(BaseSearchEngine):
|
||||
"""检查Exa搜索引擎是否可用"""
|
||||
return self.api_manager.is_available()
|
||||
|
||||
async def search(self, args: Dict[str, Any]) -> List[Dict[str, Any]]:
|
||||
async def search(self, args: dict[str, Any]) -> list[dict[str, Any]]:
|
||||
"""执行Exa搜索"""
|
||||
if not self.is_available():
|
||||
return []
|
||||
|
||||
@@ -4,13 +4,15 @@ Tavily search engine implementation
|
||||
|
||||
import asyncio
|
||||
import functools
|
||||
from typing import Dict, List, Any
|
||||
from typing import Any
|
||||
|
||||
from tavily import TavilyClient
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.plugin_system.apis import config_api
|
||||
from .base import BaseSearchEngine
|
||||
|
||||
from ..utils.api_key_manager import create_api_key_manager_from_config
|
||||
from .base import BaseSearchEngine
|
||||
|
||||
logger = get_logger("tavily_engine")
|
||||
|
||||
@@ -37,7 +39,7 @@ class TavilySearchEngine(BaseSearchEngine):
|
||||
"""检查Tavily搜索引擎是否可用"""
|
||||
return self.api_manager.is_available()
|
||||
|
||||
async def search(self, args: Dict[str, Any]) -> List[Dict[str, Any]]:
|
||||
async def search(self, args: dict[str, Any]) -> list[dict[str, Any]]:
|
||||
"""执行Tavily搜索"""
|
||||
if not self.is_available():
|
||||
return []
|
||||
|
||||
@@ -4,14 +4,12 @@ Web Search Tool Plugin
|
||||
一个功能强大的网络搜索和URL解析插件,支持多种搜索引擎和解析策略。
|
||||
"""
|
||||
|
||||
from typing import List, Tuple, Type
|
||||
|
||||
from src.plugin_system import BasePlugin, register_plugin, ComponentInfo, ConfigField, PythonDependency
|
||||
from src.plugin_system.apis import config_api
|
||||
from src.common.logger import get_logger
|
||||
from src.plugin_system import BasePlugin, ComponentInfo, ConfigField, PythonDependency, register_plugin
|
||||
from src.plugin_system.apis import config_api
|
||||
|
||||
from .tools.web_search import WebSurfingTool
|
||||
from .tools.url_parser import URLParserTool
|
||||
from .tools.web_search import WebSurfingTool
|
||||
|
||||
logger = get_logger("web_search_plugin")
|
||||
|
||||
@@ -31,7 +29,7 @@ class WEBSEARCHPLUGIN(BasePlugin):
|
||||
# 插件基本信息
|
||||
plugin_name: str = "web_search_tool" # 内部标识符
|
||||
enable_plugin: bool = True
|
||||
dependencies: List[str] = [] # 插件依赖列表
|
||||
dependencies: list[str] = [] # 插件依赖列表
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
"""初始化插件,立即加载所有搜索引擎"""
|
||||
@@ -40,10 +38,10 @@ class WEBSEARCHPLUGIN(BasePlugin):
|
||||
# 立即初始化所有搜索引擎,触发API密钥管理器的日志输出
|
||||
logger.info("🚀 正在初始化所有搜索引擎...")
|
||||
try:
|
||||
from .engines.bing_engine import BingSearchEngine
|
||||
from .engines.ddg_engine import DDGSearchEngine
|
||||
from .engines.exa_engine import ExaSearchEngine
|
||||
from .engines.tavily_engine import TavilySearchEngine
|
||||
from .engines.ddg_engine import DDGSearchEngine
|
||||
from .engines.bing_engine import BingSearchEngine
|
||||
|
||||
# 实例化所有搜索引擎,这会触发API密钥管理器的初始化
|
||||
exa_engine = ExaSearchEngine()
|
||||
@@ -71,7 +69,7 @@ class WEBSEARCHPLUGIN(BasePlugin):
|
||||
logger.error(f"❌ 搜索引擎初始化失败: {e}", exc_info=True)
|
||||
|
||||
# Python包依赖列表
|
||||
python_dependencies: List[PythonDependency] = [
|
||||
python_dependencies: list[PythonDependency] = [
|
||||
PythonDependency(package_name="asyncddgs", description="异步DuckDuckGo搜索库", optional=False),
|
||||
PythonDependency(
|
||||
package_name="exa_py",
|
||||
@@ -119,7 +117,7 @@ class WEBSEARCHPLUGIN(BasePlugin):
|
||||
},
|
||||
}
|
||||
|
||||
def get_plugin_components(self) -> List[Tuple[ComponentInfo, Type]]:
|
||||
def get_plugin_components(self) -> list[tuple[ComponentInfo, type]]:
|
||||
"""
|
||||
获取插件组件列表
|
||||
|
||||
|
||||
@@ -4,19 +4,20 @@ URL parser tool implementation
|
||||
|
||||
import asyncio
|
||||
import functools
|
||||
from typing import Any, Dict
|
||||
from exa_py import Exa
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
from bs4 import BeautifulSoup
|
||||
from exa_py import Exa
|
||||
|
||||
from src.common.cache_manager import tool_cache
|
||||
from src.common.logger import get_logger
|
||||
from src.plugin_system import BaseTool, ToolParamType, llm_api
|
||||
from src.plugin_system.apis import config_api
|
||||
from src.common.cache_manager import tool_cache
|
||||
|
||||
from ..utils.api_key_manager import create_api_key_manager_from_config
|
||||
from ..utils.formatters import format_url_parse_results
|
||||
from ..utils.url_utils import parse_urls_from_input, validate_urls
|
||||
from ..utils.api_key_manager import create_api_key_manager_from_config
|
||||
|
||||
logger = get_logger("url_parser_tool")
|
||||
|
||||
@@ -50,7 +51,7 @@ class URLParserTool(BaseTool):
|
||||
exa_api_keys, lambda key: Exa(api_key=key), "Exa URL Parser"
|
||||
)
|
||||
|
||||
async def _local_parse_and_summarize(self, url: str) -> Dict[str, Any]:
|
||||
async def _local_parse_and_summarize(self, url: str) -> dict[str, Any]:
|
||||
"""
|
||||
使用本地库(httpx, BeautifulSoup)解析URL,并调用LLM进行总结。
|
||||
"""
|
||||
@@ -124,9 +125,9 @@ class URLParserTool(BaseTool):
|
||||
return {"error": f"请求失败,状态码: {e.response.status_code}"}
|
||||
except Exception as e:
|
||||
logger.error(f"本地解析或总结URL '{url}' 时发生未知异常: {e}", exc_info=True)
|
||||
return {"error": f"发生未知错误: {str(e)}"}
|
||||
return {"error": f"发生未知错误: {e!s}"}
|
||||
|
||||
async def execute(self, function_args: Dict[str, Any]) -> Dict[str, Any]:
|
||||
async def execute(self, function_args: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
执行URL内容提取和总结。优先使用Exa,失败后尝试本地解析。
|
||||
"""
|
||||
|
||||
@@ -3,18 +3,18 @@ Web search tool implementation
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from typing import Any, Dict, List
|
||||
from typing import Any
|
||||
|
||||
from src.common.cache_manager import tool_cache
|
||||
from src.common.logger import get_logger
|
||||
from src.plugin_system import BaseTool, ToolParamType
|
||||
from src.plugin_system.apis import config_api
|
||||
from src.common.cache_manager import tool_cache
|
||||
|
||||
from ..engines.bing_engine import BingSearchEngine
|
||||
from ..engines.ddg_engine import DDGSearchEngine
|
||||
from ..engines.exa_engine import ExaSearchEngine
|
||||
from ..engines.tavily_engine import TavilySearchEngine
|
||||
from ..engines.ddg_engine import DDGSearchEngine
|
||||
from ..engines.bing_engine import BingSearchEngine
|
||||
from ..utils.formatters import format_search_results, deduplicate_results
|
||||
from ..utils.formatters import deduplicate_results, format_search_results
|
||||
|
||||
logger = get_logger("web_search_tool")
|
||||
|
||||
@@ -51,7 +51,7 @@ class WebSurfingTool(BaseTool):
|
||||
"bing": BingSearchEngine(),
|
||||
}
|
||||
|
||||
async def execute(self, function_args: Dict[str, Any]) -> Dict[str, Any]:
|
||||
async def execute(self, function_args: dict[str, Any]) -> dict[str, Any]:
|
||||
query = function_args.get("query")
|
||||
if not query:
|
||||
return {"error": "搜索查询不能为空。"}
|
||||
@@ -88,8 +88,8 @@ class WebSurfingTool(BaseTool):
|
||||
return result
|
||||
|
||||
async def _execute_parallel_search(
|
||||
self, function_args: Dict[str, Any], enabled_engines: List[str]
|
||||
) -> Dict[str, Any]:
|
||||
self, function_args: dict[str, Any], enabled_engines: list[str]
|
||||
) -> dict[str, Any]:
|
||||
"""并行搜索策略:同时使用所有启用的搜索引擎"""
|
||||
search_tasks = []
|
||||
|
||||
@@ -124,11 +124,11 @@ class WebSurfingTool(BaseTool):
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"执行并行网络搜索时发生异常: {e}", exc_info=True)
|
||||
return {"error": f"执行网络搜索时发生严重错误: {str(e)}"}
|
||||
return {"error": f"执行网络搜索时发生严重错误: {e!s}"}
|
||||
|
||||
async def _execute_fallback_search(
|
||||
self, function_args: Dict[str, Any], enabled_engines: List[str]
|
||||
) -> Dict[str, Any]:
|
||||
self, function_args: dict[str, Any], enabled_engines: list[str]
|
||||
) -> dict[str, Any]:
|
||||
"""回退搜索策略:按顺序尝试搜索引擎,失败则尝试下一个"""
|
||||
for engine_name in enabled_engines:
|
||||
engine = self.engines.get(engine_name)
|
||||
@@ -154,7 +154,7 @@ class WebSurfingTool(BaseTool):
|
||||
|
||||
return {"error": "所有搜索引擎都失败了。"}
|
||||
|
||||
async def _execute_single_search(self, function_args: Dict[str, Any], enabled_engines: List[str]) -> Dict[str, Any]:
|
||||
async def _execute_single_search(self, function_args: dict[str, Any], enabled_engines: list[str]) -> dict[str, Any]:
|
||||
"""单一搜索策略:只使用第一个可用的搜索引擎"""
|
||||
for engine_name in enabled_engines:
|
||||
engine = self.engines.get(engine_name)
|
||||
@@ -174,6 +174,6 @@ class WebSurfingTool(BaseTool):
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{engine_name} 搜索失败: {e}")
|
||||
return {"error": f"{engine_name} 搜索失败: {str(e)}"}
|
||||
return {"error": f"{engine_name} 搜索失败: {e!s}"}
|
||||
|
||||
return {"error": "没有可用的搜索引擎。"}
|
||||
|
||||
@@ -3,7 +3,9 @@ API密钥管理器,提供轮询机制
|
||||
"""
|
||||
|
||||
import itertools
|
||||
from typing import List, Optional, TypeVar, Generic, Callable
|
||||
from collections.abc import Callable
|
||||
from typing import Generic, TypeVar
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("api_key_manager")
|
||||
@@ -16,7 +18,7 @@ class APIKeyManager(Generic[T]):
|
||||
API密钥管理器,支持轮询机制
|
||||
"""
|
||||
|
||||
def __init__(self, api_keys: List[str], client_factory: Callable[[str], T], service_name: str = "Unknown"):
|
||||
def __init__(self, api_keys: list[str], client_factory: Callable[[str], T], service_name: str = "Unknown"):
|
||||
"""
|
||||
初始化API密钥管理器
|
||||
|
||||
@@ -26,8 +28,8 @@ class APIKeyManager(Generic[T]):
|
||||
service_name: 服务名称,用于日志记录
|
||||
"""
|
||||
self.service_name = service_name
|
||||
self.clients: List[T] = []
|
||||
self.client_cycle: Optional[itertools.cycle] = None
|
||||
self.clients: list[T] = []
|
||||
self.client_cycle: itertools.cycle | None = None
|
||||
|
||||
if api_keys:
|
||||
# 过滤有效的API密钥,排除None、空字符串、"None"字符串等
|
||||
@@ -54,7 +56,7 @@ class APIKeyManager(Generic[T]):
|
||||
"""检查是否有可用的客户端"""
|
||||
return bool(self.clients and self.client_cycle)
|
||||
|
||||
def get_next_client(self) -> Optional[T]:
|
||||
def get_next_client(self) -> T | None:
|
||||
"""获取下一个客户端(轮询)"""
|
||||
if not self.is_available():
|
||||
return None
|
||||
@@ -66,7 +68,7 @@ class APIKeyManager(Generic[T]):
|
||||
|
||||
|
||||
def create_api_key_manager_from_config(
|
||||
config_keys: Optional[List[str]], client_factory: Callable[[str], T], service_name: str
|
||||
config_keys: list[str] | None, client_factory: Callable[[str], T], service_name: str
|
||||
) -> APIKeyManager[T]:
|
||||
"""
|
||||
从配置创建API密钥管理器的便捷函数
|
||||
|
||||
@@ -2,10 +2,10 @@
|
||||
Formatters for web search results
|
||||
"""
|
||||
|
||||
from typing import List, Dict, Any
|
||||
from typing import Any
|
||||
|
||||
|
||||
def format_search_results(results: List[Dict[str, Any]]) -> str:
|
||||
def format_search_results(results: list[dict[str, Any]]) -> str:
|
||||
"""
|
||||
格式化搜索结果为字符串
|
||||
"""
|
||||
@@ -26,7 +26,7 @@ def format_search_results(results: List[Dict[str, Any]]) -> str:
|
||||
return formatted_string
|
||||
|
||||
|
||||
def format_url_parse_results(results: List[Dict[str, Any]]) -> str:
|
||||
def format_url_parse_results(results: list[dict[str, Any]]) -> str:
|
||||
"""
|
||||
将成功解析的URL结果列表格式化为一段简洁的文本。
|
||||
"""
|
||||
@@ -45,7 +45,7 @@ def format_url_parse_results(results: List[Dict[str, Any]]) -> str:
|
||||
return "\n---\n".join(formatted_parts)
|
||||
|
||||
|
||||
def deduplicate_results(results: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||
def deduplicate_results(results: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
||||
"""
|
||||
根据URL去重搜索结果
|
||||
"""
|
||||
|
||||
@@ -3,10 +3,9 @@ URL processing utilities
|
||||
"""
|
||||
|
||||
import re
|
||||
from typing import List
|
||||
|
||||
|
||||
def parse_urls_from_input(urls_input) -> List[str]:
|
||||
def parse_urls_from_input(urls_input) -> list[str]:
|
||||
"""
|
||||
从输入中解析URL列表
|
||||
"""
|
||||
@@ -29,7 +28,7 @@ def parse_urls_from_input(urls_input) -> List[str]:
|
||||
return urls
|
||||
|
||||
|
||||
def validate_urls(urls: List[str]) -> List[str]:
|
||||
def validate_urls(urls: list[str]) -> list[str]:
|
||||
"""
|
||||
验证URL格式,返回有效的URL列表
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user