re-style: 格式化代码

This commit is contained in:
John Richard
2025-10-02 20:26:01 +08:00
parent ecb02cae31
commit 7923eafef3
263 changed files with 3103 additions and 3123 deletions

View File

@@ -7,15 +7,15 @@ import asyncio
import time
import traceback
from datetime import datetime
from typing import Dict, Any
from typing import Any
from src.chat.express.expression_learner import expression_learner_manager
from src.chat.planner_actions.action_manager import ChatterActionManager
from src.common.data_models.message_manager_data_model import StreamContext
from src.common.logger import get_logger
from src.plugin_system.base.base_chatter import BaseChatter
from src.plugin_system.base.component_types import ChatType
from src.common.data_models.message_manager_data_model import StreamContext
from src.plugins.built_in.affinity_flow_chatter.planner import ChatterActionPlanner
from src.chat.planner_actions.action_manager import ChatterActionManager
from src.common.logger import get_logger
from src.chat.express.expression_learner import expression_learner_manager
logger = get_logger("affinity_chatter")
@@ -113,7 +113,7 @@ class AffinityChatter(BaseChatter):
"executed_count": 0,
}
def get_stats(self) -> Dict[str, Any]:
def get_stats(self) -> dict[str, Any]:
"""
获取处理器统计信息
@@ -122,7 +122,7 @@ class AffinityChatter(BaseChatter):
"""
return self.stats.copy()
def get_planner_stats(self) -> Dict[str, Any]:
def get_planner_stats(self) -> dict[str, Any]:
"""
获取规划器统计信息
@@ -131,7 +131,7 @@ class AffinityChatter(BaseChatter):
"""
return self.planner.get_planner_stats()
def get_interest_scoring_stats(self) -> Dict[str, Any]:
def get_interest_scoring_stats(self) -> dict[str, Any]:
"""
获取兴趣度评分统计信息
@@ -140,7 +140,7 @@ class AffinityChatter(BaseChatter):
"""
return self.planner.get_interest_scoring_stats()
def get_relationship_stats(self) -> Dict[str, Any]:
def get_relationship_stats(self) -> dict[str, Any]:
"""
获取用户关系统计信息
@@ -158,7 +158,7 @@ class AffinityChatter(BaseChatter):
"""
return self.planner.get_current_mood_state()
def get_mood_stats(self) -> Dict[str, Any]:
def get_mood_stats(self) -> dict[str, Any]:
"""
获取情绪状态统计信息

View File

@@ -5,11 +5,11 @@
"""
import traceback
from typing import Dict, List, Any
from typing import Any
from src.chat.interest_system import bot_interest_manager
from src.common.data_models.database_data_model import DatabaseMessages
from src.common.data_models.info_data_model import InterestScore
from src.chat.interest_system import bot_interest_manager
from src.common.logger import get_logger
from src.config.config import global_config
@@ -47,11 +47,11 @@ class ChatterInterestScoringSystem:
) # 每次不回复增加的概率
# 用户关系数据
self.user_relationships: Dict[str, float] = {} # user_id -> relationship_score
self.user_relationships: dict[str, float] = {} # user_id -> relationship_score
async def calculate_interest_scores(
self, messages: List[DatabaseMessages], bot_nickname: str
) -> List[InterestScore]:
self, messages: list[DatabaseMessages], bot_nickname: str
) -> list[InterestScore]:
"""计算消息的兴趣度评分"""
user_messages = [msg for msg in messages if str(msg.user_info.user_id) != str(global_config.bot.qq_account)]
if not user_messages:
@@ -97,7 +97,7 @@ class ChatterInterestScoringSystem:
details=details,
)
async def _calculate_interest_match_score(self, content: str, keywords: List[str] = None) -> float:
async def _calculate_interest_match_score(self, content: str, keywords: list[str] = None) -> float:
"""计算兴趣匹配度 - 使用智能embedding匹配"""
if not content:
return 0.0
@@ -109,7 +109,7 @@ class ChatterInterestScoringSystem:
# 智能匹配未初始化,返回默认分数
return 0.3
async def _calculate_smart_interest_match(self, content: str, keywords: List[str] = None) -> float:
async def _calculate_smart_interest_match(self, content: str, keywords: list[str] = None) -> float:
"""使用embedding计算智能兴趣匹配"""
try:
# 如果没有传入关键词,则提取
@@ -134,7 +134,7 @@ class ChatterInterestScoringSystem:
logger.error(f"智能兴趣匹配计算失败: {e}")
return 0.0
def _extract_keywords_from_database(self, message: DatabaseMessages) -> List[str]:
def _extract_keywords_from_database(self, message: DatabaseMessages) -> list[str]:
"""从数据库消息中提取关键词"""
keywords = []
@@ -166,7 +166,7 @@ class ChatterInterestScoringSystem:
return keywords[:15] # 返回前15个关键词
def _extract_keywords_from_content(self, content: str) -> List[str]:
def _extract_keywords_from_content(self, content: str) -> list[str]:
"""从内容中提取关键词(降级方案)"""
import re
@@ -287,7 +287,7 @@ class ChatterInterestScoringSystem:
"""获取用户关系分"""
return self.user_relationships.get(user_id, 0.3)
def get_scoring_stats(self) -> Dict:
def get_scoring_stats(self) -> dict:
"""获取评分系统统计"""
return {
"no_reply_count": self.no_reply_count,
@@ -318,7 +318,7 @@ class ChatterInterestScoringSystem:
logger.error(f"初始化智能兴趣系统失败: {e}")
traceback.print_exc()
def get_matching_config(self) -> Dict[str, Any]:
def get_matching_config(self) -> dict[str, Any]:
"""获取匹配配置信息"""
return {
"use_smart_matching": self.use_smart_matching,

View File

@@ -5,12 +5,11 @@ PlanExecutor: 接收 Plan 对象并执行其中的所有动作。
import asyncio
import time
from typing import Dict, List
from src.config.config import global_config
from src.chat.planner_actions.action_manager import ChatterActionManager
from src.common.data_models.info_data_model import Plan, ActionPlannerInfo
from src.common.data_models.info_data_model import ActionPlannerInfo, Plan
from src.common.logger import get_logger
from src.config.config import global_config
logger = get_logger("plan_executor")
@@ -52,7 +51,7 @@ class ChatterPlanExecutor:
"""设置关系追踪器"""
self.relationship_tracker = relationship_tracker
async def execute(self, plan: Plan) -> Dict[str, any]:
async def execute(self, plan: Plan) -> dict[str, any]:
"""
遍历并执行Plan对象中`decided_actions`列表里的所有动作。
@@ -110,7 +109,7 @@ class ChatterPlanExecutor:
"results": execution_results,
}
async def _execute_reply_actions(self, reply_actions: List[ActionPlannerInfo], plan: Plan) -> Dict[str, any]:
async def _execute_reply_actions(self, reply_actions: list[ActionPlannerInfo], plan: Plan) -> dict[str, any]:
"""串行执行所有回复动作,增加去重逻辑,避免对同一消息多次回复"""
results = []
@@ -162,7 +161,7 @@ class ChatterPlanExecutor:
async def _execute_single_reply_action(
self, action_info: ActionPlannerInfo, plan: Plan, clear_unread: bool = True
) -> Dict[str, any]:
) -> dict[str, any]:
"""执行单个回复动作"""
start_time = time.time()
success = False
@@ -240,7 +239,7 @@ class ChatterPlanExecutor:
else reply_content,
}
async def _execute_other_actions(self, other_actions: List[ActionPlannerInfo], plan: Plan) -> Dict[str, any]:
async def _execute_other_actions(self, other_actions: list[ActionPlannerInfo], plan: Plan) -> dict[str, any]:
"""执行其他动作"""
results = []
@@ -269,7 +268,7 @@ class ChatterPlanExecutor:
return {"results": results}
async def _execute_single_other_action(self, action_info: ActionPlannerInfo, plan: Plan) -> Dict[str, any]:
async def _execute_single_other_action(self, action_info: ActionPlannerInfo, plan: Plan) -> dict[str, any]:
"""执行单个其他动作"""
start_time = time.time()
success = False
@@ -378,7 +377,7 @@ class ChatterPlanExecutor:
logger.debug(f"action_message类型: {type(action_info.action_message)}")
logger.debug(f"action_message内容: {action_info.action_message}")
def get_execution_stats(self) -> Dict[str, any]:
def get_execution_stats(self) -> dict[str, any]:
"""获取执行统计信息"""
stats = self.execution_stats.copy()
@@ -409,7 +408,7 @@ class ChatterPlanExecutor:
"execution_times": [],
}
def get_recent_performance(self, limit: int = 10) -> List[Dict[str, any]]:
def get_recent_performance(self, limit: int = 10) -> list[dict[str, any]]:
"""获取最近的执行性能"""
recent_times = self.execution_stats["execution_times"][-limit:]
if not recent_times:

View File

@@ -2,13 +2,13 @@
PlanFilter: 接收 Plan 对象,根据不同模式的逻辑进行筛选,决定最终要执行的动作。
"""
import orjson
import re
import time
import traceback
import re
from datetime import datetime
from typing import Any, Dict, List, Optional
from typing import Any
import orjson
from json_repair import repair_json
# 旧的Hippocampus系统已被移除现在使用增强记忆系统
@@ -39,7 +39,7 @@ class ChatterPlanFilter:
根据 Plan 中的模式和信息,筛选并决定最终的动作。
"""
def __init__(self, chat_id: str, available_actions: List[str]):
def __init__(self, chat_id: str, available_actions: list[str]):
"""
初始化动作计划筛选器。
@@ -316,8 +316,8 @@ class ChatterPlanFilter:
"""构建已读/未读历史消息块"""
try:
# 从message_manager获取真实的已读/未读消息
from src.chat.utils.utils import assign_message_ids
from src.chat.utils.chat_message_builder import get_raw_msg_before_timestamp_with_chat
from src.chat.utils.utils import assign_message_ids
# 获取聊天流的上下文
from src.plugin_system.apis.chat_api import get_chat_manager
@@ -392,14 +392,15 @@ class ChatterPlanFilter:
logger.error(f"构建已读/未读历史消息块时出错: {e}")
return "构建已读历史消息时出错", "构建未读历史消息时出错", []
async def _get_interest_scores_for_messages(self, messages: List[dict]) -> dict[str, float]:
async def _get_interest_scores_for_messages(self, messages: list[dict]) -> dict[str, float]:
"""为消息获取兴趣度评分"""
interest_scores = {}
try:
from .interest_scoring import chatter_interest_scoring_system
from src.common.data_models.database_data_model import DatabaseMessages
from .interest_scoring import chatter_interest_scoring_system
# 使用插件内部的兴趣度评分系统计算评分
for msg_dict in messages:
try:
@@ -450,7 +451,7 @@ class ChatterPlanFilter:
async def _parse_single_action(
self, action_json: dict, message_id_list: list, plan: Plan
) -> List[ActionPlannerInfo]:
) -> list[ActionPlannerInfo]:
parsed_actions = []
try:
# 从新的actions结构中获取动作信息
@@ -599,7 +600,7 @@ class ChatterPlanFilter:
)
return parsed_actions
def _filter_no_actions(self, action_list: List[ActionPlannerInfo]) -> List[ActionPlannerInfo]:
def _filter_no_actions(self, action_list: list[ActionPlannerInfo]) -> list[ActionPlannerInfo]:
non_no_actions = [a for a in action_list if a.action_type not in ["no_action", "no_reply"]]
if non_no_actions:
return non_no_actions
@@ -652,7 +653,7 @@ class ChatterPlanFilter:
logger.error(f"获取长期记忆时出错: {e}")
return "回忆时出现了一些问题。"
async def _build_action_options(self, current_available_actions: Dict[str, ActionInfo]) -> str:
async def _build_action_options(self, current_available_actions: dict[str, ActionInfo]) -> str:
action_options_block = ""
for action_name, action_info in current_available_actions.items():
# 构建参数的JSON示例
@@ -723,7 +724,7 @@ class ChatterPlanFilter:
)
return action_options_block
def _find_message_by_id(self, message_id: str, message_id_list: list) -> Optional[Dict[str, Any]]:
def _find_message_by_id(self, message_id: str, message_id_list: list) -> dict[str, Any] | None:
"""
增强的消息查找函数,支持多种格式和模糊匹配
兼容大模型可能返回的各种格式变体
@@ -828,12 +829,12 @@ class ChatterPlanFilter:
logger.warning(f"未找到任何匹配的消息: {original_id} (候选: {candidate_ids})")
return None
def _get_latest_message(self, message_id_list: list) -> Optional[Dict[str, Any]]:
def _get_latest_message(self, message_id_list: list) -> dict[str, Any] | None:
if not message_id_list:
return None
return message_id_list[-1].get("message")
def _find_poke_notice(self, message_id_list: list) -> Optional[Dict[str, Any]]:
def _find_poke_notice(self, message_id_list: list) -> dict[str, Any] | None:
"""在消息列表中寻找戳一戳的通知消息"""
for item in reversed(message_id_list):
message = item.get("message")

View File

@@ -3,7 +3,6 @@ PlanGenerator: 负责搜集和汇总所有决策所需的信息,生成一个
"""
import time
from typing import Dict
from src.chat.utils.chat_message_builder import get_raw_msg_before_timestamp_with_chat
from src.chat.utils.utils import get_chat_type_and_target_info
@@ -85,7 +84,7 @@ class ChatterPlanGenerator:
chat_history=[],
)
async def _get_available_actions(self, chat_type: ChatType, mode: ChatMode) -> Dict[str, ActionInfo]:
async def _get_available_actions(self, chat_type: ChatType, mode: ChatMode) -> dict[str, ActionInfo]:
"""
获取当前可用的动作列表。
@@ -152,7 +151,7 @@ class ChatterPlanGenerator:
# 如果获取失败,返回空列表
return []
def get_generator_stats(self) -> Dict:
def get_generator_stats(self) -> dict:
"""
获取生成器统计信息。

View File

@@ -4,22 +4,20 @@
"""
from dataclasses import asdict
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
from src.plugins.built_in.affinity_flow_chatter.plan_executor import ChatterPlanExecutor
from src.plugins.built_in.affinity_flow_chatter.plan_filter import ChatterPlanFilter
from src.plugins.built_in.affinity_flow_chatter.plan_generator import ChatterPlanGenerator
from src.plugins.built_in.affinity_flow_chatter.interest_scoring import chatter_interest_scoring_system
from src.mood.mood_manager import mood_manager
from typing import TYPE_CHECKING, Any
from src.common.logger import get_logger
from src.config.config import global_config
from src.mood.mood_manager import mood_manager
from src.plugins.built_in.affinity_flow_chatter.interest_scoring import chatter_interest_scoring_system
from src.plugins.built_in.affinity_flow_chatter.plan_executor import ChatterPlanExecutor
from src.plugins.built_in.affinity_flow_chatter.plan_filter import ChatterPlanFilter
from src.plugins.built_in.affinity_flow_chatter.plan_generator import ChatterPlanGenerator
if TYPE_CHECKING:
from src.common.data_models.message_manager_data_model import StreamContext
from src.common.data_models.info_data_model import Plan
from src.chat.planner_actions.action_manager import ChatterActionManager
from src.common.data_models.info_data_model import Plan
from src.common.data_models.message_manager_data_model import StreamContext
# 导入提示词模块以确保其被初始化
from src.plugins.built_in.affinity_flow_chatter import planner_prompts # noqa
@@ -62,7 +60,7 @@ class ChatterActionPlanner:
"other_actions_executed": 0,
}
async def plan(self, context: "StreamContext" = None) -> Tuple[List[Dict], Optional[Dict]]:
async def plan(self, context: "StreamContext" = None) -> tuple[list[dict], dict | None]:
"""
执行完整的增强版规划流程。
@@ -84,7 +82,7 @@ class ChatterActionPlanner:
self.planner_stats["failed_plans"] += 1
return [], None
async def _enhanced_plan_flow(self, context: "StreamContext") -> Tuple[List[Dict], Optional[Dict]]:
async def _enhanced_plan_flow(self, context: "StreamContext") -> tuple[list[dict], dict | None]:
"""执行增强版规划流程"""
try:
# 在规划前,先进行动作修改
@@ -104,7 +102,7 @@ class ChatterActionPlanner:
score = 0.0
should_reply = False
reply_not_available = False
interest_updates: List[Dict[str, Any]] = []
interest_updates: list[dict[str, Any]] = []
if unread_messages:
# 为每条消息计算兴趣度,并延迟提交数据库更新
@@ -193,7 +191,7 @@ class ChatterActionPlanner:
self.planner_stats["failed_plans"] += 1
return [], None
async def _commit_interest_updates(self, updates: List[Dict[str, Any]]) -> None:
async def _commit_interest_updates(self, updates: list[dict[str, Any]]) -> None:
"""统一更新消息兴趣度,减少数据库写入次数"""
if not updates:
return
@@ -220,7 +218,7 @@ class ChatterActionPlanner:
except Exception as e:
logger.warning(f"批量更新数据库兴趣度失败: {e}")
def _update_stats_from_execution_result(self, execution_result: Dict[str, any]):
def _update_stats_from_execution_result(self, execution_result: dict[str, any]):
"""根据执行结果更新规划器统计"""
if not execution_result:
return
@@ -244,7 +242,7 @@ class ChatterActionPlanner:
self.planner_stats["replies_generated"] += reply_count
self.planner_stats["other_actions_executed"] += other_count
def _build_return_result(self, plan: "Plan") -> Tuple[List[Dict], Optional[Dict]]:
def _build_return_result(self, plan: "Plan") -> tuple[list[dict], dict | None]:
"""构建返回结果"""
final_actions = plan.decided_actions or []
final_target_message = next((act.action_message for act in final_actions if act.action_message), None)
@@ -261,7 +259,7 @@ class ChatterActionPlanner:
return final_actions_dict, final_target_message_dict
def get_planner_stats(self) -> Dict[str, any]:
def get_planner_stats(self) -> dict[str, any]:
"""获取规划器统计"""
return self.planner_stats.copy()
@@ -270,7 +268,7 @@ class ChatterActionPlanner:
chat_mood = mood_manager.get_mood_by_chat_id(self.chat_id)
return chat_mood.mood_state
def get_mood_stats(self) -> Dict[str, any]:
def get_mood_stats(self) -> dict[str, any]:
"""获取情绪状态统计"""
chat_mood = mood_manager.get_mood_by_chat_id(self.chat_id)
return {

View File

@@ -2,12 +2,10 @@
亲和力聊天处理器插件
"""
from typing import List, Tuple, Type
from src.common.logger import get_logger
from src.plugin_system.apis.plugin_register_api import register_plugin
from src.plugin_system.base.base_plugin import BasePlugin
from src.plugin_system.base.component_types import ComponentInfo
from src.common.logger import get_logger
logger = get_logger("affinity_chatter_plugin")
@@ -29,7 +27,7 @@ class AffinityChatterPlugin(BasePlugin):
# 简单的 config_schema 占位(如果将来需要配置可扩展)
config_schema = {}
def get_plugin_components(self) -> List[Tuple[ComponentInfo, Type]]:
def get_plugin_components(self) -> list[tuple[ComponentInfo, type]]:
"""返回插件包含的组件列表ChatterInfo, AffinityChatter
这里采用延迟导入 AffinityChatter 来避免循环依赖和启动顺序问题。

View File

@@ -5,15 +5,15 @@
"""
import time
from typing import Dict, List, Optional
from src.common.logger import get_logger
from src.config.config import model_config, global_config
from src.llm_models.utils_model import LLMRequest
from src.common.database.sqlalchemy_database_api import get_db_session
from src.common.database.sqlalchemy_models import UserRelationships, Messages
from sqlalchemy import select, desc
from sqlalchemy import desc, select
from src.common.data_models.database_data_model import DatabaseMessages
from src.common.database.sqlalchemy_database_api import get_db_session
from src.common.database.sqlalchemy_models import Messages, UserRelationships
from src.common.logger import get_logger
from src.config.config import global_config, model_config
from src.llm_models.utils_model import LLMRequest
logger = get_logger("chatter_relationship_tracker")
@@ -22,15 +22,15 @@ class ChatterRelationshipTracker:
"""用户关系追踪器"""
def __init__(self, interest_scoring_system=None):
self.tracking_users: Dict[str, Dict] = {} # user_id -> interaction_data
self.tracking_users: dict[str, dict] = {} # user_id -> interaction_data
self.max_tracking_users = 3
self.update_interval_minutes = 30
self.last_update_time = time.time()
self.relationship_history: List[Dict] = []
self.relationship_history: list[dict] = []
self.interest_scoring_system = interest_scoring_system
# 用户关系缓存 (user_id -> {"relationship_text": str, "relationship_score": float, "last_tracked": float})
self.user_relationship_cache: Dict[str, Dict] = {}
self.user_relationship_cache: dict[str, dict] = {}
self.cache_expiry_hours = 1 # 缓存过期时间(小时)
# 关系更新LLM
@@ -91,7 +91,7 @@ class ChatterRelationshipTracker:
logger.debug(f"添加用户交互追踪: {user_id}")
async def check_and_update_relationships(self) -> List[Dict]:
async def check_and_update_relationships(self) -> list[dict]:
"""检查并更新用户关系"""
current_time = time.time()
if current_time - self.last_update_time < self.update_interval_minutes * 60:
@@ -108,7 +108,7 @@ class ChatterRelationshipTracker:
self.last_update_time = current_time
return updates
async def _update_user_relationship(self, interaction: Dict) -> Optional[Dict]:
async def _update_user_relationship(self, interaction: dict) -> dict | None:
"""更新单个用户的关系"""
try:
# 获取bot人设信息
@@ -201,11 +201,11 @@ class ChatterRelationshipTracker:
return None
def get_tracking_users(self) -> Dict[str, Dict]:
def get_tracking_users(self) -> dict[str, dict]:
"""获取正在追踪的用户"""
return self.tracking_users.copy()
def get_user_interaction(self, user_id: str) -> Optional[Dict]:
def get_user_interaction(self, user_id: str) -> dict | None:
"""获取特定用户的交互记录"""
return self.tracking_users.get(user_id)
@@ -220,11 +220,11 @@ class ChatterRelationshipTracker:
self.tracking_users.clear()
logger.info("清空所有用户追踪")
def get_relationship_history(self) -> List[Dict]:
def get_relationship_history(self) -> list[dict]:
"""获取关系历史记录"""
return self.relationship_history.copy()
def add_to_history(self, relationship_update: Dict):
def add_to_history(self, relationship_update: dict):
"""添加到关系历史"""
self.relationship_history.append({**relationship_update, "update_time": time.time()})
@@ -232,7 +232,7 @@ class ChatterRelationshipTracker:
if len(self.relationship_history) > 100:
self.relationship_history = self.relationship_history[-100:]
def get_tracker_stats(self) -> Dict:
def get_tracker_stats(self) -> dict:
"""获取追踪器统计"""
return {
"tracking_users": len(self.tracking_users),
@@ -268,7 +268,7 @@ class ChatterRelationshipTracker:
self.add_to_history(update_info)
logger.info(f"强制更新用户关系: {user_id} -> {new_score:.2f}")
def get_user_summary(self, user_id: str) -> Dict:
def get_user_summary(self, user_id: str) -> dict:
"""获取用户交互总结"""
if user_id not in self.tracking_users:
return {}
@@ -313,7 +313,7 @@ class ChatterRelationshipTracker:
# 数据库中也没有,返回默认值
return global_config.affinity_flow.base_relationship_score
async def _get_user_relationship_from_db(self, user_id: str) -> Optional[Dict]:
async def _get_user_relationship_from_db(self, user_id: str) -> dict | None:
"""从数据库获取用户关系数据"""
try:
async with get_db_session() as session:
@@ -431,7 +431,7 @@ class ChatterRelationshipTracker:
return 0
async def _get_last_bot_reply_to_user(self, user_id: str) -> Optional[DatabaseMessages]:
async def _get_last_bot_reply_to_user(self, user_id: str) -> DatabaseMessages | None:
"""获取上次bot回复该用户的消息"""
try:
async with get_db_session() as session:
@@ -455,7 +455,7 @@ class ChatterRelationshipTracker:
return None
async def _get_user_reactions_after_reply(self, user_id: str, reply_time: float) -> List[DatabaseMessages]:
async def _get_user_reactions_after_reply(self, user_id: str, reply_time: float) -> list[DatabaseMessages]:
"""获取用户在bot回复后的反应消息"""
try:
async with get_db_session() as session:
@@ -511,7 +511,7 @@ class ChatterRelationshipTracker:
user_id: str,
user_name: str,
last_bot_reply: DatabaseMessages,
user_reactions: List[DatabaseMessages],
user_reactions: list[DatabaseMessages],
current_text: str,
current_score: float,
current_reply: str,

View File

@@ -8,9 +8,9 @@
- 测试功能
"""
from src.plugin_system.base import BaseCommand
from src.chat.antipromptinjector import get_anti_injector
from src.common.logger import get_logger
from src.plugin_system.base import BaseCommand
logger = get_logger("anti_injector.commands")
@@ -56,5 +56,5 @@ class AntiInjectorStatusCommand(BaseCommand):
except Exception as e:
logger.error(f"获取反注入系统状态失败: {e}")
await self.send_text(f"获取状态失败: {str(e)}")
return False, f"获取状态失败: {str(e)}", True
await self.send_text(f"获取状态失败: {e!s}")
return False, f"获取状态失败: {e!s}", True

View File

@@ -1,19 +1,18 @@
import random
from typing import Tuple
# 导入新插件系统
from src.plugin_system import BaseAction, ActionActivationType, ChatMode
from src.chat.emoji_system.emoji_history import add_emoji_to_history, get_recent_emojis
from src.chat.emoji_system.emoji_manager import MaiEmoji, get_emoji_manager
from src.chat.utils.utils_image import image_path_to_base64
# 导入依赖的系统组件
from src.common.logger import get_logger
from src.config.config import global_config
# 导入新插件系统
from src.plugin_system import ActionActivationType, BaseAction, ChatMode
# 导入API模块 - 标准Python包方式
from src.plugin_system.apis import llm_api, message_api
from src.chat.emoji_system.emoji_manager import get_emoji_manager, MaiEmoji
from src.chat.utils.utils_image import image_path_to_base64
from src.config.config import global_config
from src.chat.emoji_system.emoji_history import get_recent_emojis, add_emoji_to_history
logger = get_logger("emoji")
@@ -59,7 +58,7 @@ class EmojiAction(BaseAction):
# 关联类型
associated_types = ["emoji"]
async def execute(self) -> Tuple[bool, str]:
async def execute(self) -> tuple[bool, str]:
"""执行表情动作"""
logger.info(f"{self.log_prefix} 决定发送表情")
@@ -286,4 +285,4 @@ class EmojiAction(BaseAction):
except Exception as e:
logger.error(f"{self.log_prefix} 表情动作执行失败: {e}", exc_info=True)
return False, f"表情发送失败: {str(e)}"
return False, f"表情发送失败: {e!s}"

View File

@@ -5,19 +5,16 @@
这是系统的内置插件,提供基础的聊天交互功能
"""
from typing import List, Tuple, Type
# 导入新插件系统
from src.plugin_system import BasePlugin, register_plugin, ComponentInfo
from src.plugin_system.base.config_types import ConfigField
# 导入依赖的系统组件
from src.common.logger import get_logger
# 导入新插件系统
from src.plugin_system import BasePlugin, ComponentInfo, register_plugin
from src.plugin_system.base.config_types import ConfigField
from src.plugins.built_in.core_actions.anti_injector_manager import AntiInjectorStatusCommand
# 导入API模块 - 标准Python包方式
from src.plugins.built_in.core_actions.emoji import EmojiAction
from src.plugins.built_in.core_actions.anti_injector_manager import AntiInjectorStatusCommand
logger = get_logger("core_actions")
@@ -62,7 +59,7 @@ class CoreActionsPlugin(BasePlugin):
},
}
def get_plugin_components(self) -> List[Tuple[ComponentInfo, Type]]:
def get_plugin_components(self) -> list[tuple[ComponentInfo, type]]:
"""返回插件包含的组件列表"""
# --- 根据配置注册组件 ---

View File

@@ -1,8 +1,8 @@
from typing import Dict, Any
from typing import Any
from src.chat.knowledge.knowledge_lib import qa_manager
from src.common.logger import get_logger
from src.config.config import global_config
from src.chat.knowledge.knowledge_lib import qa_manager
from src.plugin_system import BaseTool, ToolParamType
logger = get_logger("lpmm_get_knowledge_tool")
@@ -19,7 +19,7 @@ class SearchKnowledgeFromLPMMTool(BaseTool):
]
available_for_llm = global_config.lpmm_knowledge.enable
async def execute(self, function_args: Dict[str, Any]) -> Dict[str, Any]:
async def execute(self, function_args: dict[str, Any]) -> dict[str, Any]:
"""执行知识库搜索
Args:
@@ -56,7 +56,7 @@ class SearchKnowledgeFromLPMMTool(BaseTool):
return {"type": "lpmm_knowledge", "id": query, "content": content}
except Exception as e:
# 捕获异常并记录错误
logger.error(f"知识库搜索工具执行失败: {str(e)}")
logger.error(f"知识库搜索工具执行失败: {e!s}")
# 在其他异常情况下,确保 id 仍然是 query (如果它被定义了)
query_id = query if "query" in locals() else "unknown_query"
return {"type": "info", "id": query_id, "content": f"lpmm知识库搜索失败炸了: {str(e)}"}
return {"type": "info", "id": query_id, "content": f"lpmm知识库搜索失败炸了: {e!s}"}

View File

@@ -1,9 +1,8 @@
# -*- coding: utf-8 -*-
"""
让框架能够发现并加载子目录中的组件。
"""
from .plugin import MaiZoneRefactoredPlugin as MaiZoneRefactoredPlugin
from .actions.send_feed_action import SendFeedAction as SendFeedAction
from .actions.read_feed_action import ReadFeedAction as ReadFeedAction
from .actions.send_feed_action import SendFeedAction as SendFeedAction
from .commands.send_feed_command import SendFeedCommand as SendFeedCommand
from .plugin import MaiZoneRefactoredPlugin as MaiZoneRefactoredPlugin

View File

@@ -1,14 +1,12 @@
# -*- coding: utf-8 -*-
"""
阅读说说动作组件
"""
from typing import Tuple
from src.common.logger import get_logger
from src.plugin_system import BaseAction, ActionActivationType, ChatMode
from src.plugin_system import ActionActivationType, BaseAction, ChatMode
from src.plugin_system.apis import generator_api
from src.plugin_system.apis.permission_api import permission_api
from ..services.manager import get_qzone_service
logger = get_logger("MaiZone.ReadFeedAction")
@@ -41,7 +39,7 @@ class ReadFeedAction(BaseAction):
# 使用权限API检查用户是否有阅读说说的权限
return await permission_api.check_permission(platform, user_id, "plugin.maizone.read_feed")
async def execute(self) -> Tuple[bool, str]:
async def execute(self) -> tuple[bool, str]:
"""
执行动作的核心逻辑。
"""

View File

@@ -1,14 +1,12 @@
# -*- coding: utf-8 -*-
"""
发送说说动作组件
"""
from typing import Tuple
from src.common.logger import get_logger
from src.plugin_system import BaseAction, ActionActivationType, ChatMode
from src.plugin_system import ActionActivationType, BaseAction, ChatMode
from src.plugin_system.apis import generator_api
from src.plugin_system.apis.permission_api import permission_api
from ..services.manager import get_qzone_service
logger = get_logger("MaiZone.SendFeedAction")
@@ -41,7 +39,7 @@ class SendFeedAction(BaseAction):
# 使用权限API检查用户是否有发送说说的权限
return await permission_api.check_permission(platform, user_id, "plugin.maizone.send_feed")
async def execute(self) -> Tuple[bool, str]:
async def execute(self) -> tuple[bool, str]:
"""
执行动作的核心逻辑。
"""

View File

@@ -1,15 +1,13 @@
# -*- coding: utf-8 -*-
"""
发送说说命令 await self.send_text(f"收到!正在为你生成关于"{topic or '随机'}"的说说,请稍候...【热重载测试成功】")件
"""
from typing import Tuple
from src.common.logger import get_logger
from src.plugin_system.base.plus_command import PlusCommand
from src.plugin_system.base.command_args import CommandArgs
from src.plugin_system.base.plus_command import PlusCommand
from src.plugin_system.utils.permission_decorators import require_permission
from ..services.manager import get_qzone_service, get_config_getter
from ..services.manager import get_config_getter, get_qzone_service
logger = get_logger("MaiZone.SendFeedCommand")
@@ -28,7 +26,7 @@ class SendFeedCommand(PlusCommand):
super().__init__(*args, **kwargs)
@require_permission("plugin.maizone.send_feed")
async def execute(self, args: CommandArgs) -> Tuple[bool, str, bool]:
async def execute(self, args: CommandArgs) -> tuple[bool, str, bool]:
"""
执行命令的核心逻辑。
"""

View File

@@ -1,28 +1,26 @@
# -*- coding: utf-8 -*-
"""
MaiZone麦麦空间- 重构版
"""
import asyncio
from pathlib import Path
from typing import List, Tuple, Type
from src.common.logger import get_logger
from src.plugin_system import BasePlugin, ComponentInfo, register_plugin
from src.plugin_system.base.config_types import ConfigField
from src.plugin_system.apis.permission_api import permission_api
from src.plugin_system.base.config_types import ConfigField
from .actions.read_feed_action import ReadFeedAction
from .actions.send_feed_action import SendFeedAction
from .commands.send_feed_command import SendFeedCommand
from .services.content_service import ContentService
from .services.image_service import ImageService
from .services.qzone_service import QZoneService
from .services.scheduler_service import SchedulerService
from .services.monitor_service import MonitorService
from .services.cookie_service import CookieService
from .services.reply_tracker_service import ReplyTrackerService
from .services.image_service import ImageService
from .services.manager import register_service
from .services.monitor_service import MonitorService
from .services.qzone_service import QZoneService
from .services.reply_tracker_service import ReplyTrackerService
from .services.scheduler_service import SchedulerService
logger = get_logger("MaiZone.Plugin")
@@ -35,8 +33,8 @@ class MaiZoneRefactoredPlugin(BasePlugin):
plugin_description: str = "重构版的MaiZone插件"
config_file_name: str = "config.toml"
enable_plugin: bool = True
dependencies: List[str] = []
python_dependencies: List[str] = []
dependencies: list[str] = []
python_dependencies: list[str] = []
config_schema: dict = {
"plugin": {"enable": ConfigField(type=bool, default=True, description="是否启用插件")},
@@ -125,7 +123,7 @@ class MaiZoneRefactoredPlugin(BasePlugin):
asyncio.create_task(monitor_service.start())
logger.info("MaiZone后台监控和定时任务已启动。")
def get_plugin_components(self) -> List[Tuple[ComponentInfo, Type]]:
def get_plugin_components(self) -> list[tuple[ComponentInfo, type]]:
return [
(SendFeedAction.get_action_info(), SendFeedAction),
(ReadFeedAction.get_action_info(), ReadFeedAction),

View File

@@ -1,23 +1,23 @@
# -*- coding: utf-8 -*-
"""
内容服务模块
负责生成所有与QQ空间相关的文本内容例如说说、评论等。
"""
from typing import Callable, Optional
import datetime
import base64
import aiohttp
from src.common.logger import get_logger
import imghdr
import asyncio
from src.plugin_system.apis import llm_api, config_api, generator_api
from src.plugin_system.apis.cross_context_api import get_chat_history_by_group_name
from src.chat.message_receive.chat_stream import get_chat_manager
import base64
import datetime
import imghdr
from collections.abc import Callable
import aiohttp
from maim_message import UserInfo
from src.llm_models.utils_model import LLMRequest
from src.chat.message_receive.chat_stream import get_chat_manager
from src.common.logger import get_logger
from src.config.api_ada_configs import TaskConfig
from src.llm_models.utils_model import LLMRequest
from src.plugin_system.apis import config_api, generator_api, llm_api
from src.plugin_system.apis.cross_context_api import get_chat_history_by_group_name
# 导入旧的工具函数,我们稍后会考虑是否也需要重构它
from ..utils.history_utils import get_send_history
@@ -38,7 +38,7 @@ class ContentService:
"""
self.get_config = get_config
async def generate_story(self, topic: str, context: Optional[str] = None) -> str:
async def generate_story(self, topic: str, context: str | None = None) -> str:
"""
根据指定主题和可选的上下文生成一条QQ空间说说。
@@ -231,7 +231,7 @@ class ContentService:
return ""
return ""
async def _describe_image(self, image_url: str) -> Optional[str]:
async def _describe_image(self, image_url: str) -> str | None:
"""
使用LLM识别图片内容。
"""

View File

@@ -1,14 +1,14 @@
# -*- coding: utf-8 -*-
"""
Cookie服务模块
负责从多种来源获取、缓存和管理QZone的Cookie。
"""
import orjson
from collections.abc import Callable
from pathlib import Path
from typing import Callable, Optional, Dict
import aiohttp
import orjson
from src.common.logger import get_logger
from src.plugin_system.apis import send_api
@@ -29,28 +29,28 @@ class CookieService:
"""获取指定QQ账号的cookie文件路径"""
return self.cookie_dir / f"cookies-{qq_account}.json"
def _save_cookies_to_file(self, qq_account: str, cookies: Dict[str, str]):
def _save_cookies_to_file(self, qq_account: str, cookies: dict[str, str]):
"""将Cookie保存到本地文件"""
cookie_file_path = self._get_cookie_file_path(qq_account)
try:
with open(cookie_file_path, "w", encoding="utf-8") as f:
f.write(orjson.dumps(cookies, option=orjson.OPT_INDENT_2).decode("utf-8"))
logger.info(f"Cookie已成功缓存至: {cookie_file_path}")
except IOError as e:
except OSError as e:
logger.error(f"无法写入Cookie文件 {cookie_file_path}: {e}")
def _load_cookies_from_file(self, qq_account: str) -> Optional[Dict[str, str]]:
def _load_cookies_from_file(self, qq_account: str) -> dict[str, str] | None:
"""从本地文件加载Cookie"""
cookie_file_path = self._get_cookie_file_path(qq_account)
if cookie_file_path.exists():
try:
with open(cookie_file_path, "r", encoding="utf-8") as f:
with open(cookie_file_path, encoding="utf-8") as f:
return orjson.loads(f.read())
except (IOError, orjson.JSONDecodeError) as e:
except (OSError, orjson.JSONDecodeError) as e:
logger.error(f"无法读取或解析Cookie文件 {cookie_file_path}: {e}")
return None
async def _get_cookies_from_adapter(self, stream_id: Optional[str]) -> Optional[Dict[str, str]]:
async def _get_cookies_from_adapter(self, stream_id: str | None) -> dict[str, str] | None:
"""通过Adapter API获取Cookie"""
try:
params = {"domain": "user.qzone.qq.com"}
@@ -73,7 +73,7 @@ class CookieService:
logger.error(f"通过Adapter获取Cookie时发生异常: {e}")
return None
async def _get_cookies_from_http(self) -> Optional[Dict[str, str]]:
async def _get_cookies_from_http(self) -> dict[str, str] | None:
"""通过备用HTTP端点获取Cookie"""
host = self.get_config("cookie.http_fallback_host", "172.20.130.55")
port = self.get_config("cookie.http_fallback_port", "9999")
@@ -110,7 +110,7 @@ class CookieService:
logger.error(f"通过HTTP备用地址 {http_url} 获取Cookie失败: {e}")
return None
async def get_cookies(self, qq_account: str, stream_id: Optional[str]) -> Optional[Dict[str, str]]:
async def get_cookies(self, qq_account: str, stream_id: str | None) -> dict[str, str] | None:
"""
获取Cookie按以下顺序尝试
1. HTTP备用端点 (更稳定)

View File

@@ -1,12 +1,11 @@
# -*- coding: utf-8 -*-
"""
图片服务模块
负责处理所有与图片相关的任务特别是AI生成图片。
"""
import base64
from collections.abc import Callable
from pathlib import Path
from typing import Callable
import aiohttp

View File

@@ -1,14 +1,15 @@
# -*- coding: utf-8 -*-
"""
服务管理器/定位器
这是一个独立的模块,用于注册和获取插件内的全局服务实例,以避免循环导入。
"""
from typing import Dict, Any, Callable
from collections.abc import Callable
from typing import Any
from .qzone_service import QZoneService
# --- 全局服务注册表 ---
_services: Dict[str, Any] = {}
_services: dict[str, Any] = {}
def register_service(name: str, instance: Any):

View File

@@ -1,13 +1,13 @@
# -*- coding: utf-8 -*-
"""
好友动态监控服务
"""
import asyncio
import traceback
from typing import Callable
from collections.abc import Callable
from src.common.logger import get_logger
from .qzone_service import QZoneService
logger = get_logger("MaiZone.MonitorService")

View File

@@ -1,32 +1,33 @@
# -*- coding: utf-8 -*-
"""
QQ空间服务模块
封装了所有与QQ空间API的直接交互是插件的核心业务逻辑层。
"""
import asyncio
import orjson
import base64
import os
import random
import time
import base64
from collections.abc import Callable
from pathlib import Path
from typing import Callable, Optional, Dict, Any, List, Tuple
from typing import Any
import aiohttp
import bs4
import json5
from src.common.logger import get_logger
from src.plugin_system.apis import config_api, person_api
import orjson
from src.chat.message_receive.chat_stream import get_chat_manager
from src.chat.utils.chat_message_builder import (
build_readable_messages_with_id,
get_raw_msg_by_timestamp_with_chat,
)
from src.common.logger import get_logger
from src.plugin_system.apis import config_api, person_api
from .content_service import ContentService
from .image_service import ImageService
from .cookie_service import CookieService
from .image_service import ImageService
from .reply_tracker_service import ReplyTrackerService
logger = get_logger("MaiZone.QZoneService")
@@ -64,7 +65,7 @@ class QZoneService:
# --- Public Methods (High-Level Business Logic) ---
async def send_feed(self, topic: str, stream_id: Optional[str]) -> Dict[str, Any]:
async def send_feed(self, topic: str, stream_id: str | None) -> dict[str, Any]:
"""发送一条说说"""
# --- 获取互通组上下文 ---
context = await self._get_intercom_context(stream_id) if stream_id else None
@@ -92,7 +93,7 @@ class QZoneService:
logger.error(f"发布说说时发生异常: {e}", exc_info=True)
return {"success": False, "message": f"发布说说异常: {e}"}
async def send_feed_from_activity(self, activity: str) -> Dict[str, Any]:
async def send_feed_from_activity(self, activity: str) -> dict[str, Any]:
"""根据日程活动发送一条说说"""
story = await self.content_service.generate_story_from_activity(activity)
if not story:
@@ -118,7 +119,7 @@ class QZoneService:
logger.error(f"根据活动发布说说时发生异常: {e}", exc_info=True)
return {"success": False, "message": f"发布说说异常: {e}"}
async def read_and_process_feeds(self, target_name: str, stream_id: Optional[str]) -> Dict[str, Any]:
async def read_and_process_feeds(self, target_name: str, stream_id: str | None) -> dict[str, Any]:
"""读取并处理指定好友的说说"""
target_person_id = await person_api.get_person_id_by_name(target_name)
if not target_person_id:
@@ -147,7 +148,7 @@ class QZoneService:
logger.error(f"读取和处理说说时发生异常: {e}", exc_info=True)
return {"success": False, "message": f"处理说说异常: {e}"}
async def monitor_feeds(self, stream_id: Optional[str] = None):
async def monitor_feeds(self, stream_id: str | None = None):
"""监控并处理所有好友的动态,包括回复自己说说的评论"""
logger.info("开始执行好友动态监控...")
qq_account = config_api.get_global_config("bot.qq_account", "")
@@ -189,7 +190,7 @@ class QZoneService:
# --- Internal Helper Methods ---
async def _get_intercom_context(self, stream_id: str) -> Optional[str]:
async def _get_intercom_context(self, stream_id: str) -> str | None:
"""
根据 stream_id 查找其所属的互通组,并构建该组的聊天上下文。
@@ -247,7 +248,7 @@ class QZoneService:
logger.debug(f"Stream ID '{stream_id}' 未在任何互通组中找到。")
return None
async def _reply_to_own_feed_comments(self, feed: Dict, api_client: Dict):
async def _reply_to_own_feed_comments(self, feed: dict, api_client: dict):
"""处理对自己说说的评论并进行回复"""
qq_account = config_api.get_global_config("bot.qq_account", "")
comments = feed.get("comments", [])
@@ -309,7 +310,7 @@ class QZoneService:
if comment_key in self.processing_comments:
self.processing_comments.remove(comment_key)
async def _validate_and_cleanup_reply_records(self, fid: str, my_replies: List[Dict]):
async def _validate_and_cleanup_reply_records(self, fid: str, my_replies: list[dict]):
"""验证并清理已删除的回复记录"""
# 获取当前记录中该说说的所有已回复评论ID
recorded_replied_comments = self.reply_tracker.get_replied_comments(fid)
@@ -333,7 +334,7 @@ class QZoneService:
self.reply_tracker.remove_reply_record(fid, comment_tid)
logger.debug(f"已清理删除的回复记录: feed_id={fid}, comment_id={comment_tid}")
async def _process_single_feed(self, feed: Dict, api_client: Dict, target_qq: str, target_name: str):
async def _process_single_feed(self, feed: dict, api_client: dict, target_qq: str, target_name: str):
"""处理单条说说,决定是否评论和点赞"""
content = feed.get("content", "")
fid = feed.get("tid", "")
@@ -371,7 +372,7 @@ class QZoneService:
if random.random() <= self.get_config("read.like_possibility", 1.0):
await api_client["like"](target_qq, fid)
def _load_local_images(self, image_dir: str) -> List[bytes]:
def _load_local_images(self, image_dir: str) -> list[bytes]:
"""随机加载本地图片(不删除文件)"""
images = []
if not image_dir or not os.path.exists(image_dir):
@@ -432,7 +433,7 @@ class QZoneService:
hash_val += (hash_val << 5) + ord(char)
return str(hash_val & 2147483647)
async def _renew_and_load_cookies(self, qq_account: str, stream_id: Optional[str]) -> Optional[Dict[str, str]]:
async def _renew_and_load_cookies(self, qq_account: str, stream_id: str | None) -> dict[str, str] | None:
cookie_dir = Path(__file__).resolve().parent.parent / "cookies"
cookie_dir.mkdir(exist_ok=True)
cookie_file_path = cookie_dir / f"cookies-{qq_account}.json"
@@ -480,7 +481,7 @@ class QZoneService:
logger.error("所有获取Cookie的方式均失败。")
return None
async def _fetch_cookies_http(self, host: str, port: int, napcat_token: str) -> Optional[Dict]:
async def _fetch_cookies_http(self, host: str, port: int, napcat_token: str) -> dict | None:
"""通过HTTP服务器获取Cookie"""
# 从配置中读取主机和端口,如果未提供则使用传入的参数
final_host = self.get_config("cookie.http_fallback_host", host)
@@ -515,19 +516,19 @@ class QZoneService:
except aiohttp.ClientError as e:
if attempt < max_retries - 1:
logger.warning(f"无法连接到Napcat服务(尝试 {attempt + 1}/{max_retries}): {url},错误: {str(e)}")
logger.warning(f"无法连接到Napcat服务(尝试 {attempt + 1}/{max_retries}): {url},错误: {e!s}")
await asyncio.sleep(retry_delay)
retry_delay *= 2
continue
logger.error(f"无法连接到Napcat服务(最终尝试): {url},错误: {str(e)}")
logger.error(f"无法连接到Napcat服务(最终尝试): {url},错误: {e!s}")
raise RuntimeError(f"无法连接到Napcat服务: {url}") from e
except Exception as e:
logger.error(f"获取cookie异常: {str(e)}")
logger.error(f"获取cookie异常: {e!s}")
raise
raise RuntimeError(f"无法连接到Napcat服务: 超过最大重试次数({max_retries})")
async def _get_api_client(self, qq_account: str, stream_id: Optional[str]) -> Optional[Dict]:
async def _get_api_client(self, qq_account: str, stream_id: str | None) -> dict | None:
cookies = await self.cookie_service.get_cookies(qq_account, stream_id)
if not cookies:
logger.error(
@@ -559,7 +560,7 @@ class QZoneService:
response.raise_for_status()
return await response.text()
async def _publish(content: str, images: List[bytes]) -> Tuple[bool, str]:
async def _publish(content: str, images: list[bytes]) -> tuple[bool, str]:
"""发布说说"""
try:
post_data = {
@@ -660,7 +661,7 @@ class QZoneService:
return picbo, richval
async def _upload_image(image_bytes: bytes, index: int) -> Optional[Dict[str, str]]:
async def _upload_image(image_bytes: bytes, index: int) -> dict[str, str] | None:
"""上传图片到QQ空间完全按照原版实现"""
try:
upload_url = "https://up.qzone.qq.com/cgi-bin/upload/cgi_upload_image"
@@ -745,7 +746,7 @@ class QZoneService:
logger.error(f"上传图片 {index + 1} 异常: {e}", exc_info=True)
return None
async def _list_feeds(t_qq: str, num: int) -> List[Dict]:
async def _list_feeds(t_qq: str, num: int) -> list[dict]:
"""获取指定用户说说列表 (统一接口)"""
try:
# 统一使用 format=json 获取完整评论
@@ -920,7 +921,7 @@ class QZoneService:
logger.error(f"回复评论异常: {e}", exc_info=True)
return False
async def _monitor_list_feeds(num: int) -> List[Dict]:
async def _monitor_list_feeds(num: int) -> list[dict]:
"""监控好友动态"""
try:
params = {

View File

@@ -1,4 +1,3 @@
# -*- coding: utf-8 -*-
"""
评论回复跟踪服务
负责记录和管理已回复过的评论ID避免重复回复
@@ -7,7 +6,8 @@
import json
import time
from pathlib import Path
from typing import Set, Dict, Any, Union
from typing import Any
from src.common.logger import get_logger
logger = get_logger("MaiZone.ReplyTrackerService")
@@ -27,7 +27,7 @@ class ReplyTrackerService:
# 内存中的已回复评论记录
# 格式: {feed_id: {comment_id: timestamp, ...}, ...}
self.replied_comments: Dict[str, Dict[str, float]] = {}
self.replied_comments: dict[str, dict[str, float]] = {}
# 数据清理配置
self.max_record_days = 30 # 保留30天的记录
@@ -64,7 +64,7 @@ class ReplyTrackerService:
try:
if self.reply_record_file.exists():
try:
with open(self.reply_record_file, "r", encoding="utf-8") as f:
with open(self.reply_record_file, encoding="utf-8") as f:
file_content = f.read().strip()
if not file_content: # 文件为空
logger.warning("回复记录文件为空,将创建新的记录")
@@ -173,7 +173,7 @@ class ReplyTrackerService:
if total_removed > 0:
logger.info(f"清理了 {total_removed} 条超过{self.max_record_days}天的过期回复记录")
def has_replied(self, feed_id: str, comment_id: Union[str, int]) -> bool:
def has_replied(self, feed_id: str, comment_id: str | int) -> bool:
"""
检查是否已经回复过指定的评论
@@ -190,7 +190,7 @@ class ReplyTrackerService:
comment_id_str = str(comment_id)
return feed_id in self.replied_comments and comment_id_str in self.replied_comments[feed_id]
def mark_as_replied(self, feed_id: str, comment_id: Union[str, int]):
def mark_as_replied(self, feed_id: str, comment_id: str | int):
"""
标记指定评论为已回复
@@ -219,7 +219,7 @@ class ReplyTrackerService:
else:
logger.error(f"标记评论时数据验证失败: feed_id={feed_id}, comment_id={comment_id}")
def get_replied_comments(self, feed_id: str) -> Set[str]:
def get_replied_comments(self, feed_id: str) -> set[str]:
"""
获取指定说说下所有已回复的评论ID
@@ -234,7 +234,7 @@ class ReplyTrackerService:
return {str(comment_id) for comment_id in self.replied_comments[feed_id].keys()}
return set()
def get_stats(self) -> Dict[str, Any]:
def get_stats(self) -> dict[str, Any]:
"""
获取回复记录统计信息

View File

@@ -1,4 +1,3 @@
# -*- coding: utf-8 -*-
"""
定时任务服务
根据日程表定时发送说说。
@@ -8,13 +7,14 @@ import asyncio
import datetime
import random
import traceback
from typing import Callable
from collections.abc import Callable
from sqlalchemy import select
from src.common.database.sqlalchemy_database_api import get_db_session
from src.common.database.sqlalchemy_models import MaiZoneScheduleStatus
from src.common.logger import get_logger
from src.schedule.schedule_manager import schedule_manager
from src.common.database.sqlalchemy_database_api import get_db_session
from sqlalchemy import select
from src.common.database.sqlalchemy_models import MaiZoneScheduleStatus
from .qzone_service import QZoneService

View File

@@ -1,15 +1,15 @@
# -*- coding: utf-8 -*-
"""
历史记录工具模块
提供用于获取QQ空间发送历史的功能。
"""
import orjson
import os
from pathlib import Path
from typing import Dict, Any, Optional, List
from typing import Any
import orjson
import requests
from src.common.logger import get_logger
logger = get_logger("MaiZone.HistoryUtils")
@@ -26,11 +26,11 @@ class _CookieManager:
return str(cookie_dir / f"cookies-{uin}.json")
@staticmethod
def load_cookies(qq_account: str) -> Optional[Dict[str, str]]:
def load_cookies(qq_account: str) -> dict[str, str] | None:
cookie_file = _CookieManager.get_cookie_file_path(qq_account)
if os.path.exists(cookie_file):
try:
with open(cookie_file, "r", encoding="utf-8") as f:
with open(cookie_file, encoding="utf-8") as f:
return orjson.loads(f.read())
except Exception as e:
logger.error(f"加载Cookie文件失败: {e}")
@@ -42,7 +42,7 @@ class _SimpleQZoneAPI:
LIST_URL = "https://user.qzone.qq.com/proxy/domain/taotao.qq.com/cgi-bin/emotion_cgi_msglist_v6"
def __init__(self, cookies_dict: Optional[Dict[str, str]] = None):
def __init__(self, cookies_dict: dict[str, str] | None = None):
self.cookies = cookies_dict or {}
self.gtk2 = ""
p_skey = self.cookies.get("p_skey") or self.cookies.get("p_skey".upper())
@@ -55,7 +55,7 @@ class _SimpleQZoneAPI:
hash_val += (hash_val << 5) + ord(char)
return str(hash_val & 2147483647)
def get_feed_list(self, target_qq: str, num: int) -> List[Dict[str, Any]]:
def get_feed_list(self, target_qq: str, num: int) -> list[dict[str, Any]]:
try:
params = {
"g_tk": self.gtk2,

View File

@@ -6,19 +6,17 @@
"""
import re
from typing import List, Optional, Tuple, Type
from src.plugin_system.apis.logging_api import get_logger
from src.plugin_system.apis.permission_api import permission_api
from src.plugin_system.apis.plugin_register_api import register_plugin
from src.plugin_system.base.base_plugin import BasePlugin
from src.plugin_system.base.plus_command import PlusCommand
from src.plugin_system.base.command_args import CommandArgs
from src.plugin_system.apis.permission_api import permission_api
from src.plugin_system.apis.logging_api import get_logger
from src.plugin_system.base.component_types import PlusCommandInfo, ChatType
from src.plugin_system.base.component_types import ChatType, PlusCommandInfo
from src.plugin_system.base.config_types import ConfigField
from src.plugin_system.base.plus_command import PlusCommand
from src.plugin_system.utils.permission_decorators import require_permission
logger = get_logger("Permission")
@@ -44,7 +42,7 @@ class PermissionCommand(PlusCommand):
"plugin.permission.view", "权限查看:可以查看权限节点和用户权限信息", "permission_manager", True
)
async def execute(self, args: CommandArgs) -> Tuple[bool, Optional[str], bool]:
async def execute(self, args: CommandArgs) -> tuple[bool, str | None, bool]:
"""执行权限管理命令"""
if args.is_empty:
await self._show_help()
@@ -114,7 +112,7 @@ class PermissionCommand(PlusCommand):
await self.send_text(help_text)
@staticmethod
def _parse_user_mention(mention: str) -> Optional[str]:
def _parse_user_mention(mention: str) -> str | None:
"""解析用户提及提取QQ号
支持的格式:
@@ -134,7 +132,7 @@ class PermissionCommand(PlusCommand):
return None
@staticmethod
def parse_user_from_args(args: CommandArgs, index: int = 0) -> Optional[str]:
def parse_user_from_args(args: CommandArgs, index: int = 0) -> str | None:
"""从CommandArgs中解析用户ID
Args:
@@ -166,7 +164,7 @@ class PermissionCommand(PlusCommand):
return None
@require_permission("plugin.permission.manage", "❌ 你没有权限管理的权限")
async def _grant_permission(self, chat_stream, args: List[str]):
async def _grant_permission(self, chat_stream, args: list[str]):
"""授权用户权限"""
if len(args) < 2:
await self.send_text("❌ 用法: /permission grant <@用户|QQ号> <权限节点>")
@@ -189,7 +187,7 @@ class PermissionCommand(PlusCommand):
await self.send_text("❌ 授权失败,请检查权限节点是否存在")
@require_permission("plugin.permission.manage", "❌ 你没有权限管理的权限")
async def _revoke_permission(self, chat_stream, args: List[str]):
async def _revoke_permission(self, chat_stream, args: list[str]):
"""撤销用户权限"""
if len(args) < 2:
await self.send_text("❌ 用法: /permission revoke <@用户|QQ号> <权限节点>")
@@ -212,7 +210,7 @@ class PermissionCommand(PlusCommand):
await self.send_text("❌ 撤销失败,请检查权限节点是否存在")
@require_permission("plugin.permission.view", "❌ 你没有查看权限的权限")
async def _list_permissions(self, chat_stream, args: List[str]):
async def _list_permissions(self, chat_stream, args: list[str]):
"""列出用户权限"""
target_user_id = None
@@ -244,7 +242,7 @@ class PermissionCommand(PlusCommand):
await self.send_text(response)
@require_permission("plugin.permission.view", "❌ 你没有查看权限的权限")
async def _check_permission(self, chat_stream, args: List[str]):
async def _check_permission(self, chat_stream, args: list[str]):
"""检查用户权限"""
if len(args) < 2:
await self.send_text("❌ 用法: /permission check <@用户|QQ号> <权限节点>")
@@ -273,7 +271,7 @@ class PermissionCommand(PlusCommand):
await self.send_text(response)
@require_permission("plugin.permission.view", "❌ 你没有查看权限的权限")
async def _list_nodes(self, chat_stream, args: List[str]):
async def _list_nodes(self, chat_stream, args: list[str]):
"""列出权限节点"""
plugin_name = args[0] if args else None
@@ -388,6 +386,6 @@ class PermissionManagerPlugin(BasePlugin):
}
}
def get_plugin_components(self) -> List[Tuple[PlusCommandInfo, Type[PlusCommand]]]:
def get_plugin_components(self) -> list[tuple[PlusCommandInfo, type[PlusCommand]]]:
"""返回插件的PlusCommand组件"""
return [(PermissionCommand.get_plus_command_info(), PermissionCommand)]

View File

@@ -1,19 +1,18 @@
import asyncio
from typing import List, Tuple, Type
from src.plugin_system import (
BasePlugin,
ConfigField,
register_plugin,
plugin_manage_api,
component_manage_api,
ComponentInfo,
ComponentType,
ConfigField,
component_manage_api,
plugin_manage_api,
register_plugin,
)
from src.plugin_system.base.plus_command import PlusCommand
from src.plugin_system.base.command_args import CommandArgs
from src.plugin_system.base.component_types import PlusCommandInfo, ChatType
from src.plugin_system.apis.permission_api import permission_api
from src.plugin_system.base.command_args import CommandArgs
from src.plugin_system.base.component_types import ChatType, PlusCommandInfo
from src.plugin_system.base.plus_command import PlusCommand
from src.plugin_system.utils.permission_decorators import require_permission
@@ -31,7 +30,7 @@ class ManagementCommand(PlusCommand):
super().__init__(*args, **kwargs)
@require_permission("plugin.management.admin", "❌ 你没有插件管理的权限")
async def execute(self, args: CommandArgs) -> Tuple[bool, str, bool]:
async def execute(self, args: CommandArgs) -> tuple[bool, str, bool]:
"""执行插件管理命令"""
if args.is_empty:
await self._show_help("all")
@@ -51,7 +50,7 @@ class ManagementCommand(PlusCommand):
await self.send_text(f"❌ 未知的子命令: {subcommand}\n使用 /pm help 查看帮助")
return True, "未知子命令", True
async def _handle_plugin_commands(self, args: List[str]) -> Tuple[bool, str, bool]:
async def _handle_plugin_commands(self, args: list[str]) -> tuple[bool, str, bool]:
"""处理插件相关命令"""
if not args:
await self._show_help("plugin")
@@ -83,7 +82,7 @@ class ManagementCommand(PlusCommand):
return True, "插件命令执行完成", True
async def _handle_component_commands(self, args: List[str]) -> Tuple[bool, str, bool]:
async def _handle_component_commands(self, args: list[str]) -> tuple[bool, str, bool]:
"""处理组件相关命令"""
if not args:
await self._show_help("component")
@@ -258,7 +257,7 @@ class ManagementCommand(PlusCommand):
else:
await self.send_text(f"❌ 插件强制重载失败: `{plugin_name}`")
except Exception as e:
await self.send_text(f"❌ 强制重载过程中发生错误: {str(e)}")
await self.send_text(f"❌ 强制重载过程中发生错误: {e!s}")
async def _add_dir(self, dir_path: str):
"""添加插件目录"""
@@ -271,17 +270,17 @@ class ManagementCommand(PlusCommand):
await self.send_text(f"❌ 插件目录添加失败: `{dir_path}`")
@staticmethod
def _fetch_all_registered_components() -> List[ComponentInfo]:
def _fetch_all_registered_components() -> list[ComponentInfo]:
all_plugin_info = component_manage_api.get_all_plugin_info()
if not all_plugin_info:
return []
components_info: List[ComponentInfo] = []
components_info: list[ComponentInfo] = []
for plugin_info in all_plugin_info.values():
components_info.extend(plugin_info.components)
return components_info
def _fetch_locally_disabled_components(self) -> List[str]:
def _fetch_locally_disabled_components(self) -> list[str]:
"""获取本地禁用的组件列表"""
stream_id = self.message.chat_stream.stream_id
locally_disabled_components_actions = component_manage_api.get_locally_disabled_components(
@@ -509,7 +508,7 @@ class PluginManagementPlugin(BasePlugin):
False,
)
def get_plugin_components(self) -> List[Tuple[PlusCommandInfo, Type[PlusCommand]]]:
def get_plugin_components(self) -> list[tuple[PlusCommandInfo, type[PlusCommand]]]:
"""返回插件的PlusCommand组件"""
components = []
if self.get_config("plugin.enabled", True):

View File

@@ -1,14 +1,13 @@
from typing import List, Tuple, Type
from src.common.logger import get_logger
from src.plugin_system.base.base_plugin import BasePlugin
from src.plugin_system import (
BaseEventHandler,
BasePlugin,
ConfigField,
register_plugin,
EventHandlerInfo,
BaseEventHandler,
register_plugin,
)
from src.plugin_system.base.base_plugin import BasePlugin
from .proacive_thinker_event import ProactiveThinkerEventHandler
logger = get_logger(__name__)
@@ -33,9 +32,9 @@ class ProactiveThinkerPlugin(BasePlugin):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def get_plugin_components(self) -> List[Tuple[EventHandlerInfo, Type[BaseEventHandler]]]:
def get_plugin_components(self) -> list[tuple[EventHandlerInfo, type[BaseEventHandler]]]:
"""返回插件的EventHandler组件"""
components: List[Tuple[EventHandlerInfo, Type[BaseEventHandler]]] = [
components: list[tuple[EventHandlerInfo, type[BaseEventHandler]]] = [
(ProactiveThinkerEventHandler.get_handler_info(), ProactiveThinkerEventHandler)
]
return components

View File

@@ -2,17 +2,17 @@ import asyncio
import random
import time
from datetime import datetime
from typing import List, Union
from maim_message import UserInfo
from src.chat.message_receive.chat_stream import get_chat_manager
from src.common.logger import get_logger
from src.config.config import global_config
from src.manager.async_task_manager import async_task_manager, AsyncTask
from src.plugin_system import EventType, BaseEventHandler
from src.manager.async_task_manager import AsyncTask, async_task_manager
from src.plugin_system import BaseEventHandler, EventType
from src.plugin_system.apis import chat_api, person_api
from src.plugin_system.base.base_event import HandlerResult
from .proactive_thinker_executor import ProactiveThinkerExecutor
logger = get_logger(__name__)
@@ -199,7 +199,7 @@ class ProactiveThinkerEventHandler(BaseEventHandler):
handler_name: str = "proactive_thinker_on_start"
handler_description: str = "主动思考插件的启动事件处理器"
init_subscribe: List[Union[EventType, str]] = [EventType.ON_START]
init_subscribe: list[EventType | str] = [EventType.ON_START]
async def execute(self, kwargs: dict | None) -> "HandlerResult":
"""在机器人启动时执行,根据配置决定是否启动后台任务。"""

View File

@@ -1,20 +1,21 @@
import orjson
from typing import Optional, Dict, Any
from datetime import datetime
from typing import Any
import orjson
from src.common.logger import get_logger
from src.config.config import global_config, model_config
from src.person_info.person_info import get_person_info_manager
from src.plugin_system.apis import (
chat_api,
database_api,
generator_api,
llm_api,
message_api,
person_api,
schedule_api,
send_api,
llm_api,
message_api,
generator_api,
database_api,
)
from src.config.config import global_config, model_config
from src.person_info.person_info import get_person_info_manager
logger = get_logger(__name__)
@@ -101,7 +102,7 @@ class ProactiveThinkerExecutor:
logger.error(f"解析 stream_id ({stream_id}) 或获取 stream 失败: {e}")
return None
async def _gather_context(self, stream_id: str) -> Optional[Dict[str, Any]]:
async def _gather_context(self, stream_id: str) -> dict[str, Any] | None:
"""
收集构建提示词所需的所有上下文信息
"""
@@ -165,7 +166,7 @@ class ProactiveThinkerExecutor:
"current_time": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
}
async def _make_decision(self, context: Dict[str, Any], start_mode: str) -> Optional[Dict[str, Any]]:
async def _make_decision(self, context: dict[str, Any], start_mode: str) -> dict[str, Any] | None:
"""
决策模块:判断是否应该主动发起对话,以及聊什么话题
"""
@@ -234,7 +235,7 @@ class ProactiveThinkerExecutor:
logger.error(f"决策LLM返回的JSON格式无效: {response}")
return {"should_reply": False, "reason": "决策模型返回格式错误"}
def _build_plan_prompt(self, context: Dict[str, Any], start_mode: str, topic: str, reason: str) -> str:
def _build_plan_prompt(self, context: dict[str, Any], start_mode: str, topic: str, reason: str) -> str:
"""
根据启动模式和决策话题,构建最终的规划提示词
"""

View File

@@ -1,24 +1,25 @@
import re
from typing import List, Tuple, Type, Optional
from src.plugin_system import (
BasePlugin,
register_plugin,
BaseAction,
ComponentInfo,
ActionActivationType,
ConfigField,
)
from src.common.logger import get_logger
from .qq_emoji_list import qq_face
from src.plugin_system.base.component_types import ChatType
from src.person_info.person_info import get_person_info_manager
from dateutil.parser import parse as parse_datetime
from src.manager.async_task_manager import AsyncTask, async_task_manager
from src.plugin_system.apis import send_api, llm_api, generator_api
from src.chat.message_receive.chat_stream import ChatStream
import asyncio
import datetime
import re
from dateutil.parser import parse as parse_datetime
from src.chat.message_receive.chat_stream import ChatStream
from src.common.logger import get_logger
from src.manager.async_task_manager import AsyncTask, async_task_manager
from src.person_info.person_info import get_person_info_manager
from src.plugin_system import (
ActionActivationType,
BaseAction,
BasePlugin,
ComponentInfo,
ConfigField,
register_plugin,
)
from src.plugin_system.apis import generator_api, llm_api, send_api
from src.plugin_system.base.component_types import ChatType
from .qq_emoji_list import qq_face
logger = get_logger("set_emoji_like_plugin")
@@ -30,7 +31,7 @@ class ReminderTask(AsyncTask):
self,
delay: float,
stream_id: str,
group_id: Optional[str],
group_id: str | None,
is_group: bool,
target_user_id: str,
target_user_name: str,
@@ -162,7 +163,7 @@ class PokeAction(BaseAction):
"""
associated_types = ["text"]
async def execute(self) -> Tuple[bool, str]:
async def execute(self) -> tuple[bool, str]:
"""执行戳一戳的动作"""
user_id = self.action_data.get("user_id")
user_name = self.action_data.get("user_name")
@@ -242,7 +243,7 @@ class SetEmojiLikeAction(BaseAction):
if match:
emoji_options.append(match.group(1))
async def execute(self) -> Tuple[bool, str]:
async def execute(self) -> tuple[bool, str]:
"""执行设置表情回应的动作"""
message_id = None
set_like = self.action_data.get("set", True)
@@ -360,7 +361,7 @@ class RemindAction(BaseAction):
"例如:'10分钟后提醒我收快递''明天早上九点喊一下李四参加晨会'",
]
async def execute(self) -> Tuple[bool, str]:
async def execute(self) -> tuple[bool, str]:
"""执行设置提醒的动作"""
user_name = self.action_data.get("user_name")
remind_time_str = self.action_data.get("remind_time")
@@ -386,14 +387,14 @@ class RemindAction(BaseAction):
# 优先尝试直接解析
try:
target_time = parse_datetime(remind_time_str, fuzzy=True)
except Exception:
except Exception as e:
# 如果直接解析失败,调用 LLM 进行转换
logger.info(f"[ReminderPlugin] 直接解析时间 '{remind_time_str}' 失败,尝试使用 LLM 进行转换...")
# 获取所有可用的模型配置
available_models = llm_api.get_available_models()
if "utils_small" not in available_models:
raise ValueError("未找到 'utils_small' 模型配置,无法解析时间")
raise ValueError("未找到 'utils_small' 模型配置,无法解析时间") from e
# 明确使用 'planner' 模型
model_to_use = available_models["utils_small"]
@@ -419,7 +420,7 @@ class RemindAction(BaseAction):
)
if not success or not response:
raise ValueError(f"LLM未能返回有效的时间字符串: {response}")
raise ValueError(f"LLM未能返回有效的时间字符串: {response}") from e
converted_time_str = response.strip()
logger.info(f"[ReminderPlugin] LLM 转换结果: '{converted_time_str}'")
@@ -533,8 +534,8 @@ class SetEmojiLikePlugin(BasePlugin):
# 插件基本信息
plugin_name: str = "social_toolkit_plugin" # 内部标识符
enable_plugin: bool = True
dependencies: List[str] = [] # 插件依赖列表
python_dependencies: List[str] = [] # Python包依赖列表现在使用内置API
dependencies: list[str] = [] # 插件依赖列表
python_dependencies: list[str] = [] # Python包依赖列表现在使用内置API
config_file_name: str = "config.toml" # 配置文件名
# 配置节描述
@@ -555,7 +556,7 @@ class SetEmojiLikePlugin(BasePlugin):
},
}
def get_plugin_components(self) -> List[Tuple[ComponentInfo, Type]]:
def get_plugin_components(self) -> list[tuple[ComponentInfo, type]]:
enable_components = []
if self.get_config("components.action_set_emoji_like"):
enable_components.append((SetEmojiLikeAction.get_action_info(), SetEmojiLikeAction))

View File

@@ -1,10 +1,9 @@
from src.common.logger import get_logger
from src.plugin_system.apis.plugin_register_api import register_plugin
from src.plugin_system.base.base_action import ActionActivationType, BaseAction, ChatMode
from src.plugin_system.base.base_plugin import BasePlugin
from src.plugin_system.base.component_types import ComponentInfo
from src.common.logger import get_logger
from src.plugin_system.base.base_action import BaseAction, ActionActivationType, ChatMode
from src.plugin_system.base.config_types import ConfigField
from typing import Tuple, List, Type
logger = get_logger("tts")
@@ -44,7 +43,7 @@ class TTSAction(BaseAction):
# 关联类型
associated_types = ["tts_text"]
async def execute(self) -> Tuple[bool, str]:
async def execute(self) -> tuple[bool, str]:
"""处理TTS文本转语音动作"""
logger.info(f"{self.log_prefix} 执行TTS动作: {self.reasoning}")
@@ -140,7 +139,7 @@ class TTSPlugin(BasePlugin):
},
}
def get_plugin_components(self) -> List[Tuple[ComponentInfo, Type]]:
def get_plugin_components(self) -> list[tuple[ComponentInfo, type]]:
"""返回插件包含的组件列表"""
# 从配置获取组件启用状态

View File

@@ -3,7 +3,7 @@ Base search engine interface
"""
from abc import ABC, abstractmethod
from typing import Dict, List, Any
from typing import Any
class BaseSearchEngine(ABC):
@@ -12,7 +12,7 @@ class BaseSearchEngine(ABC):
"""
@abstractmethod
async def search(self, args: Dict[str, Any]) -> List[Dict[str, Any]]:
async def search(self, args: dict[str, Any]) -> list[dict[str, Any]]:
"""
执行搜索

View File

@@ -6,11 +6,13 @@ import asyncio
import functools
import random
import traceback
from typing import Dict, List, Any
from typing import Any
import requests
from bs4 import BeautifulSoup
from src.common.logger import get_logger
from .base import BaseSearchEngine
logger = get_logger("bing_engine")
@@ -68,7 +70,7 @@ class BingSearchEngine(BaseSearchEngine):
"""检查Bing搜索引擎是否可用"""
return True # Bing是免费搜索引擎总是可用
async def search(self, args: Dict[str, Any]) -> List[Dict[str, Any]]:
async def search(self, args: dict[str, Any]) -> list[dict[str, Any]]:
"""执行Bing搜索"""
query = args["query"]
num_results = args.get("num_results", 3)
@@ -83,7 +85,7 @@ class BingSearchEngine(BaseSearchEngine):
logger.error(f"Bing 搜索失败: {e}")
return []
def _search_sync(self, keyword: str, num_results: int, time_range: str) -> List[Dict[str, Any]]:
def _search_sync(self, keyword: str, num_results: int, time_range: str) -> list[dict[str, Any]]:
"""同步执行Bing搜索"""
if not keyword:
return []
@@ -113,7 +115,7 @@ class BingSearchEngine(BaseSearchEngine):
return list_result[:num_results] if len(list_result) > num_results else list_result
@staticmethod
def _parse_html(url: str) -> List[Dict[str, Any]]:
def _parse_html(url: str) -> list[dict[str, Any]]:
"""解析处理结果"""
try:
logger.debug(f"访问Bing搜索URL: {url}")
@@ -141,11 +143,11 @@ class BingSearchEngine(BaseSearchEngine):
try:
res = session.get(url=url, timeout=(3.05, 6), verify=True, allow_redirects=True)
except (requests.exceptions.Timeout, requests.exceptions.ConnectionError) as e:
logger.warning(f"第一次请求超时,正在重试: {str(e)}")
logger.warning(f"第一次请求超时,正在重试: {e!s}")
try:
res = session.get(url=url, timeout=(5, 10), verify=False)
except Exception as e2:
logger.error(f"第二次请求也失败: {str(e2)}")
logger.error(f"第二次请求也失败: {e2!s}")
return []
res.encoding = "utf-8"
@@ -175,7 +177,7 @@ class BingSearchEngine(BaseSearchEngine):
try:
root = BeautifulSoup(res.text, "html.parser")
except Exception as e:
logger.error(f"HTML解析失败: {str(e)}")
logger.error(f"HTML解析失败: {e!s}")
return []
list_data = []
@@ -262,6 +264,6 @@ class BingSearchEngine(BaseSearchEngine):
return list_data
except Exception as e:
logger.error(f"解析Bing页面时出错: {str(e)}")
logger.error(f"解析Bing页面时出错: {e!s}")
logger.debug(traceback.format_exc())
return []

View File

@@ -2,10 +2,12 @@
DuckDuckGo search engine implementation
"""
from typing import Dict, List, Any
from typing import Any
from asyncddgs import aDDGS
from src.common.logger import get_logger
from .base import BaseSearchEngine
logger = get_logger("ddg_engine")
@@ -20,7 +22,7 @@ class DDGSearchEngine(BaseSearchEngine):
"""检查DuckDuckGo搜索引擎是否可用"""
return True # DuckDuckGo不需要API密钥总是可用
async def search(self, args: Dict[str, Any]) -> List[Dict[str, Any]]:
async def search(self, args: dict[str, Any]) -> list[dict[str, Any]]:
"""执行DuckDuckGo搜索"""
query = args["query"]
num_results = args.get("num_results", 3)

View File

@@ -5,13 +5,15 @@ Exa search engine implementation
import asyncio
import functools
from datetime import datetime, timedelta
from typing import Dict, List, Any
from typing import Any
from exa_py import Exa
from src.common.logger import get_logger
from src.plugin_system.apis import config_api
from .base import BaseSearchEngine
from ..utils.api_key_manager import create_api_key_manager_from_config
from .base import BaseSearchEngine
logger = get_logger("exa_engine")
@@ -36,7 +38,7 @@ class ExaSearchEngine(BaseSearchEngine):
"""检查Exa搜索引擎是否可用"""
return self.api_manager.is_available()
async def search(self, args: Dict[str, Any]) -> List[Dict[str, Any]]:
async def search(self, args: dict[str, Any]) -> list[dict[str, Any]]:
"""执行Exa搜索"""
if not self.is_available():
return []

View File

@@ -4,13 +4,15 @@ Tavily search engine implementation
import asyncio
import functools
from typing import Dict, List, Any
from typing import Any
from tavily import TavilyClient
from src.common.logger import get_logger
from src.plugin_system.apis import config_api
from .base import BaseSearchEngine
from ..utils.api_key_manager import create_api_key_manager_from_config
from .base import BaseSearchEngine
logger = get_logger("tavily_engine")
@@ -37,7 +39,7 @@ class TavilySearchEngine(BaseSearchEngine):
"""检查Tavily搜索引擎是否可用"""
return self.api_manager.is_available()
async def search(self, args: Dict[str, Any]) -> List[Dict[str, Any]]:
async def search(self, args: dict[str, Any]) -> list[dict[str, Any]]:
"""执行Tavily搜索"""
if not self.is_available():
return []

View File

@@ -4,14 +4,12 @@ Web Search Tool Plugin
一个功能强大的网络搜索和URL解析插件支持多种搜索引擎和解析策略。
"""
from typing import List, Tuple, Type
from src.plugin_system import BasePlugin, register_plugin, ComponentInfo, ConfigField, PythonDependency
from src.plugin_system.apis import config_api
from src.common.logger import get_logger
from src.plugin_system import BasePlugin, ComponentInfo, ConfigField, PythonDependency, register_plugin
from src.plugin_system.apis import config_api
from .tools.web_search import WebSurfingTool
from .tools.url_parser import URLParserTool
from .tools.web_search import WebSurfingTool
logger = get_logger("web_search_plugin")
@@ -31,7 +29,7 @@ class WEBSEARCHPLUGIN(BasePlugin):
# 插件基本信息
plugin_name: str = "web_search_tool" # 内部标识符
enable_plugin: bool = True
dependencies: List[str] = [] # 插件依赖列表
dependencies: list[str] = [] # 插件依赖列表
def __init__(self, *args, **kwargs):
"""初始化插件,立即加载所有搜索引擎"""
@@ -40,10 +38,10 @@ class WEBSEARCHPLUGIN(BasePlugin):
# 立即初始化所有搜索引擎触发API密钥管理器的日志输出
logger.info("🚀 正在初始化所有搜索引擎...")
try:
from .engines.bing_engine import BingSearchEngine
from .engines.ddg_engine import DDGSearchEngine
from .engines.exa_engine import ExaSearchEngine
from .engines.tavily_engine import TavilySearchEngine
from .engines.ddg_engine import DDGSearchEngine
from .engines.bing_engine import BingSearchEngine
# 实例化所有搜索引擎这会触发API密钥管理器的初始化
exa_engine = ExaSearchEngine()
@@ -71,7 +69,7 @@ class WEBSEARCHPLUGIN(BasePlugin):
logger.error(f"❌ 搜索引擎初始化失败: {e}", exc_info=True)
# Python包依赖列表
python_dependencies: List[PythonDependency] = [
python_dependencies: list[PythonDependency] = [
PythonDependency(package_name="asyncddgs", description="异步DuckDuckGo搜索库", optional=False),
PythonDependency(
package_name="exa_py",
@@ -119,7 +117,7 @@ class WEBSEARCHPLUGIN(BasePlugin):
},
}
def get_plugin_components(self) -> List[Tuple[ComponentInfo, Type]]:
def get_plugin_components(self) -> list[tuple[ComponentInfo, type]]:
"""
获取插件组件列表

View File

@@ -4,19 +4,20 @@ URL parser tool implementation
import asyncio
import functools
from typing import Any, Dict
from exa_py import Exa
from typing import Any
import httpx
from bs4 import BeautifulSoup
from exa_py import Exa
from src.common.cache_manager import tool_cache
from src.common.logger import get_logger
from src.plugin_system import BaseTool, ToolParamType, llm_api
from src.plugin_system.apis import config_api
from src.common.cache_manager import tool_cache
from ..utils.api_key_manager import create_api_key_manager_from_config
from ..utils.formatters import format_url_parse_results
from ..utils.url_utils import parse_urls_from_input, validate_urls
from ..utils.api_key_manager import create_api_key_manager_from_config
logger = get_logger("url_parser_tool")
@@ -50,7 +51,7 @@ class URLParserTool(BaseTool):
exa_api_keys, lambda key: Exa(api_key=key), "Exa URL Parser"
)
async def _local_parse_and_summarize(self, url: str) -> Dict[str, Any]:
async def _local_parse_and_summarize(self, url: str) -> dict[str, Any]:
"""
使用本地库(httpx, BeautifulSoup)解析URL并调用LLM进行总结。
"""
@@ -124,9 +125,9 @@ class URLParserTool(BaseTool):
return {"error": f"请求失败,状态码: {e.response.status_code}"}
except Exception as e:
logger.error(f"本地解析或总结URL '{url}' 时发生未知异常: {e}", exc_info=True)
return {"error": f"发生未知错误: {str(e)}"}
return {"error": f"发生未知错误: {e!s}"}
async def execute(self, function_args: Dict[str, Any]) -> Dict[str, Any]:
async def execute(self, function_args: dict[str, Any]) -> dict[str, Any]:
"""
执行URL内容提取和总结。优先使用Exa失败后尝试本地解析。
"""

View File

@@ -3,18 +3,18 @@ Web search tool implementation
"""
import asyncio
from typing import Any, Dict, List
from typing import Any
from src.common.cache_manager import tool_cache
from src.common.logger import get_logger
from src.plugin_system import BaseTool, ToolParamType
from src.plugin_system.apis import config_api
from src.common.cache_manager import tool_cache
from ..engines.bing_engine import BingSearchEngine
from ..engines.ddg_engine import DDGSearchEngine
from ..engines.exa_engine import ExaSearchEngine
from ..engines.tavily_engine import TavilySearchEngine
from ..engines.ddg_engine import DDGSearchEngine
from ..engines.bing_engine import BingSearchEngine
from ..utils.formatters import format_search_results, deduplicate_results
from ..utils.formatters import deduplicate_results, format_search_results
logger = get_logger("web_search_tool")
@@ -51,7 +51,7 @@ class WebSurfingTool(BaseTool):
"bing": BingSearchEngine(),
}
async def execute(self, function_args: Dict[str, Any]) -> Dict[str, Any]:
async def execute(self, function_args: dict[str, Any]) -> dict[str, Any]:
query = function_args.get("query")
if not query:
return {"error": "搜索查询不能为空。"}
@@ -88,8 +88,8 @@ class WebSurfingTool(BaseTool):
return result
async def _execute_parallel_search(
self, function_args: Dict[str, Any], enabled_engines: List[str]
) -> Dict[str, Any]:
self, function_args: dict[str, Any], enabled_engines: list[str]
) -> dict[str, Any]:
"""并行搜索策略:同时使用所有启用的搜索引擎"""
search_tasks = []
@@ -124,11 +124,11 @@ class WebSurfingTool(BaseTool):
except Exception as e:
logger.error(f"执行并行网络搜索时发生异常: {e}", exc_info=True)
return {"error": f"执行网络搜索时发生严重错误: {str(e)}"}
return {"error": f"执行网络搜索时发生严重错误: {e!s}"}
async def _execute_fallback_search(
self, function_args: Dict[str, Any], enabled_engines: List[str]
) -> Dict[str, Any]:
self, function_args: dict[str, Any], enabled_engines: list[str]
) -> dict[str, Any]:
"""回退搜索策略:按顺序尝试搜索引擎,失败则尝试下一个"""
for engine_name in enabled_engines:
engine = self.engines.get(engine_name)
@@ -154,7 +154,7 @@ class WebSurfingTool(BaseTool):
return {"error": "所有搜索引擎都失败了。"}
async def _execute_single_search(self, function_args: Dict[str, Any], enabled_engines: List[str]) -> Dict[str, Any]:
async def _execute_single_search(self, function_args: dict[str, Any], enabled_engines: list[str]) -> dict[str, Any]:
"""单一搜索策略:只使用第一个可用的搜索引擎"""
for engine_name in enabled_engines:
engine = self.engines.get(engine_name)
@@ -174,6 +174,6 @@ class WebSurfingTool(BaseTool):
except Exception as e:
logger.error(f"{engine_name} 搜索失败: {e}")
return {"error": f"{engine_name} 搜索失败: {str(e)}"}
return {"error": f"{engine_name} 搜索失败: {e!s}"}
return {"error": "没有可用的搜索引擎。"}

View File

@@ -3,7 +3,9 @@ API密钥管理器提供轮询机制
"""
import itertools
from typing import List, Optional, TypeVar, Generic, Callable
from collections.abc import Callable
from typing import Generic, TypeVar
from src.common.logger import get_logger
logger = get_logger("api_key_manager")
@@ -16,7 +18,7 @@ class APIKeyManager(Generic[T]):
API密钥管理器支持轮询机制
"""
def __init__(self, api_keys: List[str], client_factory: Callable[[str], T], service_name: str = "Unknown"):
def __init__(self, api_keys: list[str], client_factory: Callable[[str], T], service_name: str = "Unknown"):
"""
初始化API密钥管理器
@@ -26,8 +28,8 @@ class APIKeyManager(Generic[T]):
service_name: 服务名称,用于日志记录
"""
self.service_name = service_name
self.clients: List[T] = []
self.client_cycle: Optional[itertools.cycle] = None
self.clients: list[T] = []
self.client_cycle: itertools.cycle | None = None
if api_keys:
# 过滤有效的API密钥排除None、空字符串、"None"字符串等
@@ -54,7 +56,7 @@ class APIKeyManager(Generic[T]):
"""检查是否有可用的客户端"""
return bool(self.clients and self.client_cycle)
def get_next_client(self) -> Optional[T]:
def get_next_client(self) -> T | None:
"""获取下一个客户端(轮询)"""
if not self.is_available():
return None
@@ -66,7 +68,7 @@ class APIKeyManager(Generic[T]):
def create_api_key_manager_from_config(
config_keys: Optional[List[str]], client_factory: Callable[[str], T], service_name: str
config_keys: list[str] | None, client_factory: Callable[[str], T], service_name: str
) -> APIKeyManager[T]:
"""
从配置创建API密钥管理器的便捷函数

View File

@@ -2,10 +2,10 @@
Formatters for web search results
"""
from typing import List, Dict, Any
from typing import Any
def format_search_results(results: List[Dict[str, Any]]) -> str:
def format_search_results(results: list[dict[str, Any]]) -> str:
"""
格式化搜索结果为字符串
"""
@@ -26,7 +26,7 @@ def format_search_results(results: List[Dict[str, Any]]) -> str:
return formatted_string
def format_url_parse_results(results: List[Dict[str, Any]]) -> str:
def format_url_parse_results(results: list[dict[str, Any]]) -> str:
"""
将成功解析的URL结果列表格式化为一段简洁的文本。
"""
@@ -45,7 +45,7 @@ def format_url_parse_results(results: List[Dict[str, Any]]) -> str:
return "\n---\n".join(formatted_parts)
def deduplicate_results(results: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
def deduplicate_results(results: list[dict[str, Any]]) -> list[dict[str, Any]]:
"""
根据URL去重搜索结果
"""

View File

@@ -3,10 +3,9 @@ URL processing utilities
"""
import re
from typing import List
def parse_urls_from_input(urls_input) -> List[str]:
def parse_urls_from_input(urls_input) -> list[str]:
"""
从输入中解析URL列表
"""
@@ -29,7 +28,7 @@ def parse_urls_from_input(urls_input) -> List[str]:
return urls
def validate_urls(urls: List[str]) -> List[str]:
def validate_urls(urls: list[str]) -> list[str]:
"""
验证URL格式返回有效的URL列表
"""