refactor: 统一类型注解风格并优化代码结构
- 将裸 except 改为显式 Exception 捕获 - 用列表推导式替换冗余 for 循环 - 为类属性添加 ClassVar 注解 - 统一 Union/Optional 写法为 | - 移除未使用的导入 - 修复 SQLAlchemy 空值比较语法 - 优化字符串拼接与字典更新逻辑 - 补充缺失的 noqa 注释与异常链 BREAKING CHANGE: 所有插件基类的类级字段现要求显式 ClassVar 注解,自定义插件需同步更新
This commit is contained in:
@@ -4,7 +4,7 @@ Bilibili 视频观看体验工具
|
||||
支持哔哩哔哩视频链接解析和AI视频内容分析
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
from typing import Any, ClassVar
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.plugin_system import BasePlugin, BaseTool, ComponentInfo, ConfigField, ToolParamType, register_plugin
|
||||
@@ -21,7 +21,7 @@ class BilibiliTool(BaseTool):
|
||||
description = "观看用户分享的哔哩哔哩视频,以真实用户视角给出观看感受和评价"
|
||||
available_for_llm = True
|
||||
|
||||
parameters = [
|
||||
parameters: ClassVar = [
|
||||
(
|
||||
"url",
|
||||
ToolParamType.STRING,
|
||||
@@ -166,7 +166,7 @@ class BilibiliTool(BaseTool):
|
||||
return "(有点长,适合闲时观看)"
|
||||
else:
|
||||
return "(超长视频,需要耐心)"
|
||||
except:
|
||||
except Exception:
|
||||
return ""
|
||||
return ""
|
||||
|
||||
@@ -191,16 +191,16 @@ class BilibiliPlugin(BasePlugin):
|
||||
|
||||
# 插件基本信息
|
||||
plugin_name: str = "bilibili_video_watcher"
|
||||
enable_plugin: bool = False
|
||||
dependencies: list[str] = []
|
||||
python_dependencies: list[str] = []
|
||||
enable_plugin: bool = True
|
||||
dependencies: ClassVar[list[str] ] = []
|
||||
python_dependencies: ClassVar[list[str] ] = []
|
||||
config_file_name: str = "config.toml"
|
||||
|
||||
# 配置节描述
|
||||
config_section_descriptions = {"plugin": "插件基本信息", "bilibili": "哔哩哔哩视频观看配置", "tool": "工具配置"}
|
||||
config_section_descriptions: ClassVar[dict] = {"plugin": "插件基本信息", "bilibili": "哔哩哔哩视频观看配置", "tool": "工具配置"}
|
||||
|
||||
# 配置Schema定义
|
||||
config_schema: dict = {
|
||||
config_schema: ClassVar[dict] = {
|
||||
"plugin": {
|
||||
"name": ConfigField(type=str, default="bilibili_video_watcher", description="插件名称"),
|
||||
"version": ConfigField(type=str, default="2.0.0", description="插件版本"),
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import random
|
||||
from typing import Any
|
||||
from typing import Any, ClassVar
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
@@ -29,7 +29,7 @@ class StartupMessageHandler(BaseEventHandler):
|
||||
|
||||
handler_name = "hello_world_startup_handler"
|
||||
handler_description = "在机器人启动时打印一条日志。"
|
||||
init_subscribe = [EventType.ON_START]
|
||||
init_subscribe: ClassVar[list[EventType]] = [EventType.ON_START]
|
||||
|
||||
async def execute(self, params: dict) -> HandlerResult:
|
||||
logger.info("🎉 Hello World 插件已启动,准备就绪!")
|
||||
@@ -42,7 +42,7 @@ class GetSystemInfoTool(BaseTool):
|
||||
name = "get_system_info"
|
||||
description = "获取当前系统的模拟版本和状态信息。"
|
||||
available_for_llm = True
|
||||
parameters = [
|
||||
parameters: ClassVar = [
|
||||
("query", ToolParamType.STRING, "要搜索的关键词或问题。", True, None),
|
||||
("num_results", ToolParamType.INTEGER, "期望每个搜索引擎返回的搜索结果数量,默认为5。", False, None),
|
||||
(
|
||||
@@ -63,7 +63,7 @@ class HelloCommand(PlusCommand):
|
||||
|
||||
command_name = "hello"
|
||||
command_description = "向机器人发送一个简单的问候。"
|
||||
command_aliases = ["hi", "你好"]
|
||||
command_aliases: ClassVar[list[str]] = ["hi", "你好"]
|
||||
chat_type_allow = ChatType.ALL
|
||||
|
||||
async def execute(self, args: CommandArgs) -> tuple[bool, str | None, bool]:
|
||||
@@ -85,8 +85,8 @@ class KeywordActivationExampleAction(BaseAction):
|
||||
|
||||
action_name = "keyword_example"
|
||||
action_description = "当检测到特定关键词时发送回应"
|
||||
action_require = ["用户提到了问候语"]
|
||||
associated_types = ["text"]
|
||||
action_require: ClassVar[list[str]] = ["用户提到了问候语"]
|
||||
associated_types: ClassVar[list[str]] = ["text"]
|
||||
|
||||
async def go_activate(self, chat_content: str = "", llm_judge_model=None) -> bool:
|
||||
"""关键词激活:检测到"你好"、"hello"或"hi"时激活"""
|
||||
@@ -109,8 +109,8 @@ class LLMJudgeExampleAction(BaseAction):
|
||||
|
||||
action_name = "llm_judge_example"
|
||||
action_description = "当用户表达情绪低落时提供安慰"
|
||||
action_require = ["用户情绪低落", "需要情感支持"]
|
||||
associated_types = ["text"]
|
||||
action_require: ClassVar[list[str]] = ["用户情绪低落", "需要情感支持"]
|
||||
associated_types: ClassVar[list[str]] = ["text"]
|
||||
|
||||
async def go_activate(self, chat_content: str = "", llm_judge_model=None) -> bool:
|
||||
"""LLM 判断激活:判断用户是否情绪低落"""
|
||||
@@ -139,8 +139,8 @@ class CombinedActivationExampleAction(BaseAction):
|
||||
|
||||
action_name = "combined_example"
|
||||
action_description = "展示如何组合多种激活条件"
|
||||
action_require = ["展示灵活的激活逻辑"]
|
||||
associated_types = ["text"]
|
||||
action_require: ClassVar[list[str]] = ["展示灵活的激活逻辑"]
|
||||
associated_types: ClassVar[list[str]] = ["text"]
|
||||
|
||||
async def go_activate(self, chat_content: str = "", llm_judge_model=None) -> bool:
|
||||
"""组合激活:随机 20% 概率,或者匹配特定关键词"""
|
||||
@@ -168,8 +168,8 @@ class RandomEmojiAction(BaseAction):
|
||||
|
||||
action_name = "random_emoji"
|
||||
action_description = "随机发送一个表情符号,增加聊天的趣味性。"
|
||||
action_require = ["当对话气氛轻松时", "可以用来回应简单的情感表达"]
|
||||
associated_types = ["text"]
|
||||
action_require: ClassVar[list[str]] = ["当对话气氛轻松时", "可以用来回应简单的情感表达"]
|
||||
associated_types: ClassVar[list[str]] = ["text"]
|
||||
|
||||
async def go_activate(self, llm_judge_model=None) -> bool:
|
||||
"""使用新的激活方式:10% 的概率激活
|
||||
@@ -189,7 +189,7 @@ class WeatherPrompt(BasePrompt):
|
||||
|
||||
prompt_name = "weather_info_prompt"
|
||||
prompt_description = "向Planner注入当前天气信息,以丰富对话上下文。"
|
||||
injection_rules = [InjectionRule(target_prompt="planner_prompt", injection_type=InjectionType.REPLACE, target_content="## 可用动作列表")]
|
||||
injection_rules: ClassVar[list[InjectionRule]] = [InjectionRule(target_prompt="planner_prompt", injection_type=InjectionType.REPLACE, target_content="## 可用动作列表")]
|
||||
|
||||
async def execute(self) -> str:
|
||||
# 在实际应用中,这里可以调用天气API
|
||||
@@ -203,11 +203,11 @@ class HelloWorldPlugin(BasePlugin):
|
||||
|
||||
plugin_name = "hello_world_plugin"
|
||||
enable_plugin = True
|
||||
dependencies = []
|
||||
python_dependencies = []
|
||||
dependencies: ClassVar = []
|
||||
python_dependencies: ClassVar = []
|
||||
config_file_name = "config.toml"
|
||||
|
||||
config_schema = {
|
||||
config_schema: ClassVar = {
|
||||
"meta": {
|
||||
"config_version": ConfigField(type=int, default=1, description="配置文件版本,请勿手动修改。"),
|
||||
},
|
||||
@@ -224,7 +224,7 @@ class HelloWorldPlugin(BasePlugin):
|
||||
|
||||
def get_plugin_components(self) -> list[tuple[ComponentInfo, type]]:
|
||||
"""根据配置文件动态注册插件的功能组件。"""
|
||||
components: list[tuple[ComponentInfo, type]] = []
|
||||
components: ClassVar[list[tuple[ComponentInfo, type]] ] = []
|
||||
|
||||
components.append((StartupMessageHandler.get_handler_info(), StartupMessageHandler))
|
||||
components.append((GetSystemInfoTool.get_tool_info(), GetSystemInfoTool))
|
||||
|
||||
@@ -63,12 +63,12 @@ async def check_database():
|
||||
null_situation = await session.execute(
|
||||
select(func.count())
|
||||
.select_from(Expression)
|
||||
.where(Expression.situation == None)
|
||||
.where(Expression.situation is None)
|
||||
)
|
||||
null_style = await session.execute(
|
||||
select(func.count())
|
||||
.select_from(Expression)
|
||||
.where(Expression.style == None)
|
||||
.where(Expression.style is None)
|
||||
)
|
||||
|
||||
null_sit_count = null_situation.scalar()
|
||||
@@ -102,7 +102,7 @@ async def check_database():
|
||||
.limit(20)
|
||||
)
|
||||
|
||||
styles = [s for s in unique_styles.scalars()]
|
||||
styles = list(unique_styles.scalars())
|
||||
for style in styles:
|
||||
print(f" - {style}")
|
||||
|
||||
|
||||
@@ -29,15 +29,14 @@ async def analyze_style_fields():
|
||||
print(f"\n总共检查 {len(expressions)} 条记录\n")
|
||||
|
||||
# 按类型分类
|
||||
style_examples = []
|
||||
|
||||
for expr in expressions:
|
||||
if expr.type == "style":
|
||||
style_examples.append({
|
||||
style_examples = [
|
||||
{
|
||||
"situation": expr.situation,
|
||||
"style": expr.style,
|
||||
"length": len(expr.style) if expr.style else 0
|
||||
})
|
||||
}
|
||||
for expr in expressions if expr.type == "style"
|
||||
]
|
||||
|
||||
print("📋 Style 类型样例 (前15条):")
|
||||
print("="*60)
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import random
|
||||
from collections.abc import Sequence
|
||||
from typing import List, Optional
|
||||
|
||||
from colorama import Fore, init
|
||||
|
||||
|
||||
@@ -430,7 +430,7 @@ class ExpressionSelector:
|
||||
.where(Expression.type == "style")
|
||||
.distinct()
|
||||
)
|
||||
db_chat_ids = [cid for cid in db_chat_ids_result.scalars()]
|
||||
db_chat_ids = list(db_chat_ids_result.scalars())
|
||||
logger.info(f"数据库中有表达方式的chat_ids ({len(db_chat_ids)}个): {db_chat_ids}")
|
||||
|
||||
# 获取所有相关 chat_id 的表达方式(用于模糊匹配)
|
||||
@@ -509,15 +509,16 @@ class ExpressionSelector:
|
||||
)
|
||||
|
||||
# 转换为字典格式
|
||||
expressions = []
|
||||
for expr in expressions_objs:
|
||||
expressions.append({
|
||||
expressions = [
|
||||
{
|
||||
"situation": expr.situation or "",
|
||||
"style": expr.style or "",
|
||||
"type": expr.type or "style",
|
||||
"count": float(expr.count) if expr.count else 0.0,
|
||||
"last_active_time": expr.last_active_time or 0.0
|
||||
})
|
||||
}
|
||||
for expr in expressions_objs
|
||||
]
|
||||
|
||||
logger.debug(f"从数据库获取了 {len(expressions)} 个表达方式")
|
||||
return expressions
|
||||
@@ -606,7 +607,7 @@ class ExpressionSelector:
|
||||
|
||||
# 对选中的所有表达方式,一次性更新count数
|
||||
if valid_expressions:
|
||||
asyncio.create_task(self.update_expressions_count_batch(valid_expressions, 0.006))
|
||||
asyncio.create_task(self.update_expressions_count_batch(valid_expressions, 0.006)) # noqa: RUF006
|
||||
|
||||
# logger.info(f"LLM从{len(all_expressions)}个情境中选择了{len(valid_expressions)}个")
|
||||
return valid_expressions, selected_ids
|
||||
|
||||
@@ -61,7 +61,7 @@ class ExpressorModel:
|
||||
if cid not in self.nb.token_counts:
|
||||
self.nb.token_counts[cid] = defaultdict(float)
|
||||
|
||||
def predict(self, text: str, k: int = None) -> tuple[str | None, dict[str, float]]:
|
||||
def predict(self, text: str, k: int | None = None) -> tuple[str | None, dict[str, float]]:
|
||||
"""
|
||||
直接对所有候选进行朴素贝叶斯评分
|
||||
|
||||
|
||||
@@ -10,7 +10,7 @@ logger = get_logger("expressor.tokenizer")
|
||||
class Tokenizer:
|
||||
"""文本分词器,支持中文Jieba分词"""
|
||||
|
||||
def __init__(self, stopwords: set = None, use_jieba: bool = True):
|
||||
def __init__(self, stopwords: set | None = None, use_jieba: bool = True):
|
||||
"""
|
||||
Args:
|
||||
stopwords: 停用词集合
|
||||
@@ -21,7 +21,7 @@ class Tokenizer:
|
||||
|
||||
if use_jieba:
|
||||
try:
|
||||
import rjieba
|
||||
import rjieba # noqa: F401
|
||||
|
||||
# rjieba 会自动初始化,无需手动调用
|
||||
logger.info("RJieba分词器初始化成功")
|
||||
|
||||
@@ -391,7 +391,7 @@ class StyleLearnerManager:
|
||||
是否全部保存成功
|
||||
"""
|
||||
success = True
|
||||
for chat_id, learner in self.learners.items():
|
||||
for learner in self.learners.values():
|
||||
if not learner.save(self.model_save_path):
|
||||
success = False
|
||||
|
||||
|
||||
@@ -306,10 +306,8 @@ class EmbeddingStore:
|
||||
|
||||
def save_to_file(self) -> None:
|
||||
"""保存到文件"""
|
||||
data = []
|
||||
logger.info(f"正在保存{self.namespace}嵌入库到文件{self.embedding_file_path}")
|
||||
for item in self.store.values():
|
||||
data.append(item.to_dict())
|
||||
data = [item.to_dict() for item in self.store.values()]
|
||||
data_frame = pd.DataFrame(data)
|
||||
|
||||
if not os.path.exists(self.dir):
|
||||
|
||||
@@ -15,15 +15,14 @@ def dyn_select_top_k(
|
||||
# 归一化
|
||||
max_score = sorted_score[0][1]
|
||||
min_score = sorted_score[-1][1]
|
||||
normalized_score = []
|
||||
for score_item in sorted_score:
|
||||
normalized_score.append(
|
||||
normalized_score = [
|
||||
(
|
||||
score_item[0],
|
||||
score_item[1],
|
||||
(score_item[1] - min_score) / (max_score - min_score),
|
||||
)
|
||||
)
|
||||
for score_item in sorted_score
|
||||
]
|
||||
|
||||
# 寻找跳变点:score变化最大的位置
|
||||
jump_idx = 0
|
||||
|
||||
@@ -468,10 +468,10 @@ class HippocampusSampler:
|
||||
merged_groups.append(current_group)
|
||||
|
||||
# 过滤掉只有一条消息的组(除非内容较长)
|
||||
result_groups = []
|
||||
for group in merged_groups:
|
||||
if len(group) > 1 or any(len(msg.get("processed_plain_text", "")) > 100 for msg in group):
|
||||
result_groups.append(group)
|
||||
result_groups = [
|
||||
group for group in merged_groups
|
||||
if len(group) > 1 or any(len(msg.get("processed_plain_text", "")) > 100 for msg in group)
|
||||
]
|
||||
|
||||
return result_groups
|
||||
|
||||
|
||||
@@ -634,9 +634,7 @@ class MemoryBuilder:
|
||||
if cleaned:
|
||||
participants.append(cleaned)
|
||||
elif isinstance(value, str):
|
||||
for part in self._split_subject_string(value):
|
||||
if part:
|
||||
participants.append(part)
|
||||
participants.extend(part for part in self._split_subject_string(value) if part)
|
||||
|
||||
fallback = self._resolve_user_display(context, user_id)
|
||||
if fallback:
|
||||
|
||||
@@ -1267,9 +1267,7 @@ class MemorySystem:
|
||||
)
|
||||
|
||||
if relevant_memories:
|
||||
memory_contexts = []
|
||||
for memory in relevant_memories:
|
||||
memory_contexts.append(f"[历史记忆] {memory.text_content}")
|
||||
memory_contexts = [f"[历史记忆] {memory.text_content}" for memory in relevant_memories]
|
||||
|
||||
memory_transcript = "\n".join(memory_contexts)
|
||||
cleaned_fallback = (fallback_text or "").strip()
|
||||
|
||||
@@ -122,8 +122,7 @@ class MessageCollectionStorage:
|
||||
|
||||
collections = []
|
||||
if results and results.get("ids") and results["ids"][0]:
|
||||
for metadata in results["metadatas"][0]:
|
||||
collections.append(MessageCollection.from_dict(metadata))
|
||||
collections.extend(MessageCollection.from_dict(metadata) for metadata in results["metadatas"][0])
|
||||
|
||||
return collections
|
||||
except Exception as e:
|
||||
|
||||
@@ -161,7 +161,7 @@ class GlobalNoticeManager:
|
||||
self._cleanup_expired_notices()
|
||||
|
||||
# 收集可访问的notice
|
||||
for storage_key, notices in self._notices.items():
|
||||
for notices in self._notices.values():
|
||||
for notice in notices:
|
||||
if notice.is_expired():
|
||||
continue
|
||||
|
||||
@@ -98,7 +98,7 @@ async def process_message_from_dict(message_dict: dict[str, Any], stream_id: str
|
||||
mentioned_value = processing_state.get("is_mentioned")
|
||||
if isinstance(mentioned_value, bool):
|
||||
is_mentioned = mentioned_value
|
||||
elif isinstance(mentioned_value, (int, float)):
|
||||
elif isinstance(mentioned_value, int | float):
|
||||
is_mentioned = mentioned_value != 0
|
||||
|
||||
db_message = DatabaseMessages(
|
||||
|
||||
@@ -276,8 +276,8 @@ class MessageStorage:
|
||||
except Exception as e:
|
||||
logger.error(f"更新消息ID失败: {e}")
|
||||
logger.error(
|
||||
f"消息信息: message_id={getattr(message.message_info, 'message_id', 'N/A')}, "
|
||||
f"segment_type={getattr(message.message_segment, 'type', 'N/A')}"
|
||||
f"消息信息: message_id={message_data.get('message_info', {}).get('message_id', 'N/A')}, "
|
||||
f"segment_type={message_data.get('message_segment', {}).get('type', 'N/A')}"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
|
||||
@@ -47,7 +47,7 @@ async def send_message(message: MessageSending, show_log=True) -> bool:
|
||||
logger.error(f"[事件触发] 异步触发事件失败: {e}", exc_info=True)
|
||||
|
||||
# 创建异步任务,不等待完成
|
||||
asyncio.create_task(trigger_event_async())
|
||||
asyncio.create_task(trigger_event_async()) # noqa: RUF006
|
||||
logger.info("[发送完成] AFTER_SEND 事件已提交到异步任务")
|
||||
except Exception as event_error:
|
||||
logger.error(f"触发 AFTER_SEND 事件时出错: {event_error}", exc_info=True)
|
||||
|
||||
@@ -204,7 +204,7 @@ class ChatterActionManager:
|
||||
action_prompt_display=reason,
|
||||
)
|
||||
else:
|
||||
asyncio.create_task(
|
||||
asyncio.create_task( # noqa: RUF006
|
||||
database_api.store_action_info(
|
||||
chat_stream=chat_stream,
|
||||
action_build_into_prompt=False,
|
||||
@@ -217,7 +217,7 @@ class ChatterActionManager:
|
||||
)
|
||||
|
||||
# 自动清空所有未读消息
|
||||
asyncio.create_task(self._clear_all_unread_messages(chat_stream.stream_id, "no_reply"))
|
||||
asyncio.create_task(self._clear_all_unread_messages(chat_stream.stream_id, "no_reply")) # noqa: RUF006
|
||||
|
||||
return {"action_type": "no_reply", "success": True, "reply_text": "", "command": ""}
|
||||
|
||||
@@ -235,14 +235,14 @@ class ChatterActionManager:
|
||||
|
||||
# 记录执行的动作到目标消息
|
||||
if success:
|
||||
asyncio.create_task(
|
||||
asyncio.create_task( # noqa: RUF006
|
||||
self._record_action_to_message(chat_stream, action_name, target_message, action_data)
|
||||
)
|
||||
# 自动清空所有未读消息
|
||||
if clear_unread_messages:
|
||||
asyncio.create_task(self._clear_all_unread_messages(chat_stream.stream_id, action_name))
|
||||
asyncio.create_task(self._clear_all_unread_messages(chat_stream.stream_id, action_name)) # noqa: RUF006
|
||||
# 重置打断计数
|
||||
asyncio.create_task(self._reset_interruption_count_after_action(chat_stream.stream_id))
|
||||
asyncio.create_task(self._reset_interruption_count_after_action(chat_stream.stream_id)) # noqa: RUF006
|
||||
|
||||
return {
|
||||
"action_type": action_name,
|
||||
@@ -295,13 +295,13 @@ class ChatterActionManager:
|
||||
)
|
||||
|
||||
# 记录回复动作到目标消息
|
||||
asyncio.create_task(self._record_action_to_message(chat_stream, "reply", target_message, action_data))
|
||||
asyncio.create_task(self._record_action_to_message(chat_stream, "reply", target_message, action_data)) # noqa: RUF006
|
||||
|
||||
if clear_unread_messages:
|
||||
asyncio.create_task(self._clear_all_unread_messages(chat_stream.stream_id, "reply"))
|
||||
asyncio.create_task(self._clear_all_unread_messages(chat_stream.stream_id, "reply")) # noqa: RUF006
|
||||
|
||||
# 回复成功,重置打断计数
|
||||
asyncio.create_task(self._reset_interruption_count_after_action(chat_stream.stream_id))
|
||||
asyncio.create_task(self._reset_interruption_count_after_action(chat_stream.stream_id)) # noqa: RUF006
|
||||
|
||||
return {"action_type": "reply", "success": True, "reply_text": reply_text, "loop_info": loop_info}
|
||||
|
||||
|
||||
@@ -393,8 +393,7 @@ class ActionModifier:
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix}并行激活判断失败: {e}")
|
||||
# 如果并行执行失败,为所有任务默认不激活
|
||||
for action_name in task_action_names:
|
||||
deactivated_actions.append((action_name, f"并行判断失败: {e}"))
|
||||
deactivated_actions.extend((action_name, f"并行判断失败: {e}") for action_name in task_action_names)
|
||||
|
||||
return deactivated_actions
|
||||
|
||||
|
||||
@@ -505,9 +505,7 @@ class Prompt:
|
||||
context_data.update(result)
|
||||
|
||||
# 合并预构建的参数,这会覆盖任何同名的实时构建结果
|
||||
for key, value in pre_built_params.items():
|
||||
if value:
|
||||
context_data[key] = value
|
||||
context_data.update({key: value for key, value in pre_built_params.items() if value})
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
# 这是一个不太可能发生的、总体的构建超时,作为最后的保障
|
||||
|
||||
@@ -246,7 +246,7 @@ class StatisticOutputTask(AsyncTask):
|
||||
logger.exception(f"后台统计数据输出过程中发生异常:{e}")
|
||||
|
||||
# 创建后台任务,立即返回
|
||||
asyncio.create_task(_async_collect_and_output())
|
||||
asyncio.create_task(_async_collect_and_output()) # noqa: RUF006
|
||||
|
||||
# -- 以下为统计数据收集方法 --
|
||||
|
||||
|
||||
@@ -190,11 +190,10 @@ class ConnectionPoolManager:
|
||||
async def _cleanup_expired_connections_locked(self):
|
||||
"""清理过期连接(需要在锁内调用)"""
|
||||
time.time()
|
||||
expired_connections = []
|
||||
|
||||
for connection_info in list(self._connections):
|
||||
if connection_info.is_expired(self.max_lifetime, self.max_idle) and not connection_info.in_use:
|
||||
expired_connections.append(connection_info)
|
||||
expired_connections = [
|
||||
connection_info for connection_info in list(self._connections)
|
||||
if connection_info.is_expired(self.max_lifetime, self.max_idle) and not connection_info.in_use
|
||||
]
|
||||
|
||||
for connection_info in expired_connections:
|
||||
await connection_info.close()
|
||||
|
||||
@@ -258,10 +258,7 @@ def remove_duplicate_handlers(): # sourcery skip: for-append-to-extend, list-co
|
||||
root_logger = logging.getLogger()
|
||||
|
||||
# 收集所有时间戳文件handler
|
||||
file_handlers = []
|
||||
for handler in root_logger.handlers[:]:
|
||||
if isinstance(handler, TimestampedFileHandler):
|
||||
file_handlers.append(handler)
|
||||
file_handlers = [handler for handler in root_logger.handlers[:] if isinstance(handler, TimestampedFileHandler)]
|
||||
|
||||
# 如果有多个文件handler,保留第一个,关闭其他的
|
||||
if len(file_handlers) > 1:
|
||||
|
||||
@@ -117,10 +117,7 @@ class PersonalityEvaluatorDirect:
|
||||
使用 DeepSeek AI 评估用户对特定场景的反应
|
||||
"""
|
||||
# 构建维度描述
|
||||
dimension_descriptions = []
|
||||
for dim in dimensions:
|
||||
if desc := FACTOR_DESCRIPTIONS.get(dim, ""):
|
||||
dimension_descriptions.append(f"- {dim}:{desc}")
|
||||
dimension_descriptions = [f"- {dim}:{desc}" for dim in dimensions if (desc := FACTOR_DESCRIPTIONS.get(dim, ""))]
|
||||
|
||||
dimensions_text = "\n".join(dimension_descriptions)
|
||||
|
||||
|
||||
@@ -372,10 +372,7 @@ def _default_normal_response_parser(
|
||||
|
||||
# 解析文本内容
|
||||
if "content" in candidate and "parts" in candidate["content"]:
|
||||
content_parts = []
|
||||
for part in candidate["content"]["parts"]:
|
||||
if "text" in part:
|
||||
content_parts.append(part["text"])
|
||||
content_parts = [part["text"] for part in candidate["content"]["parts"] if "text" in part]
|
||||
|
||||
if content_parts:
|
||||
api_response.content = "".join(content_parts)
|
||||
|
||||
@@ -3,7 +3,7 @@ import base64
|
||||
import io
|
||||
import re
|
||||
from collections.abc import Callable, Coroutine, Iterable
|
||||
from typing import Any
|
||||
from typing import Any, ClassVar
|
||||
|
||||
import orjson
|
||||
from json_repair import repair_json
|
||||
@@ -384,7 +384,7 @@ def _default_normal_response_parser(
|
||||
@client_registry.register_client_class("openai")
|
||||
class OpenaiClient(BaseClient):
|
||||
# 类级别的全局缓存:所有 OpenaiClient 实例共享
|
||||
_global_client_cache: dict[int, AsyncOpenAI] = {}
|
||||
_global_client_cache: ClassVar[dict[int, AsyncOpenAI] ] = {}
|
||||
"""全局 AsyncOpenAI 客户端缓存:config_hash -> AsyncOpenAI 实例"""
|
||||
|
||||
def __init__(self, api_provider: APIProvider):
|
||||
|
||||
@@ -535,7 +535,7 @@ class _RequestExecutor:
|
||||
retry_interval = api_provider.retry_interval
|
||||
|
||||
if isinstance(e, NetworkConnectionError | ReqAbortException):
|
||||
return self._check_retry(remain_try, retry_interval, "连接异常", model_name)
|
||||
return await self._check_retry(remain_try, retry_interval, "连接异常", model_name)
|
||||
elif isinstance(e, RespNotOkException):
|
||||
return await self._handle_resp_not_ok(e, model_info, api_provider, remain_try, messages_info)
|
||||
elif isinstance(e, RespParseException):
|
||||
@@ -1064,7 +1064,8 @@ class LLMRequest:
|
||||
# 遍历工具的参数
|
||||
for param in tool.get("parameters", []):
|
||||
# 严格验证参数格式是否为包含5个元素的元组
|
||||
assert isinstance(param, tuple) and len(param) == 5, "参数必须是包含5个元素的元组"
|
||||
assert isinstance(param, tuple), "参数必须是元组"
|
||||
assert len(param) == 5, "参数必须包含5个元素"
|
||||
builder.add_param(
|
||||
name=param[0],
|
||||
param_type=param[1],
|
||||
|
||||
@@ -511,15 +511,11 @@ class PersonInfoManager:
|
||||
final_data = {"person_id": person_id}
|
||||
|
||||
# Start with defaults for all model fields
|
||||
for key, default_value in _person_info_default.items():
|
||||
if key in model_fields:
|
||||
final_data[key] = default_value
|
||||
final_data.update({key: default_value for key, default_value in _person_info_default.items() if key in model_fields})
|
||||
|
||||
# Override with provided data
|
||||
if data:
|
||||
for key, value in data.items():
|
||||
if key in model_fields:
|
||||
final_data[key] = value
|
||||
final_data.update({key: value for key, value in data.items() if key in model_fields})
|
||||
|
||||
# Ensure person_id is correctly set from the argument
|
||||
final_data["person_id"] = person_id
|
||||
@@ -572,15 +568,11 @@ class PersonInfoManager:
|
||||
final_data = {"person_id": person_id}
|
||||
|
||||
# Start with defaults for all model fields
|
||||
for key, default_value in _person_info_default.items():
|
||||
if key in model_fields:
|
||||
final_data[key] = default_value
|
||||
final_data.update({key: default_value for key, default_value in _person_info_default.items() if key in model_fields})
|
||||
|
||||
# Override with provided data
|
||||
if data:
|
||||
for key, value in data.items():
|
||||
if key in model_fields:
|
||||
final_data[key] = value
|
||||
final_data.update({key: value for key, value in data.items() if key in model_fields})
|
||||
|
||||
# Ensure person_id is correctly set from the argument
|
||||
final_data["person_id"] = person_id
|
||||
|
||||
@@ -401,9 +401,9 @@ class RelationshipBuilder:
|
||||
for person_id in users_to_build_relationship:
|
||||
segments = self.person_engaged_cache[person_id]
|
||||
# 异步执行关系构建
|
||||
person = Person(person_id=person_id)
|
||||
if person.is_known:
|
||||
asyncio.create_task(self.update_impression_on_segments(person_id, self.chat_id, segments))
|
||||
import asyncio
|
||||
|
||||
asyncio.create_task(self.update_impression_on_segments(person_id, self.chat_id, segments)) # noqa: RUF006
|
||||
# 移除已处理的用户缓存
|
||||
del self.person_engaged_cache[person_id]
|
||||
self._save_cache()
|
||||
|
||||
@@ -44,7 +44,6 @@ from .base import (
|
||||
# 新增的增强命令系统
|
||||
PlusCommand,
|
||||
PlusCommandAdapter,
|
||||
PlusCommandInfo,
|
||||
PythonDependency,
|
||||
ToolInfo,
|
||||
ToolParamType,
|
||||
|
||||
@@ -48,9 +48,10 @@ class ChatManager:
|
||||
raise TypeError("platform 必须是字符串或是 SpecialTypes 枚举")
|
||||
streams = []
|
||||
try:
|
||||
for stream in get_chat_manager().streams.values():
|
||||
if platform == SpecialTypes.ALL_PLATFORMS or stream.platform == platform:
|
||||
streams.append(stream)
|
||||
streams.extend(
|
||||
stream for stream in get_chat_manager().streams.values()
|
||||
if platform == SpecialTypes.ALL_PLATFORMS or stream.platform == platform
|
||||
)
|
||||
logger.debug(f"[ChatAPI] 获取到 {len(streams)} 个 {platform} 平台的聊天流")
|
||||
except Exception as e:
|
||||
logger.error(f"[ChatAPI] 获取聊天流失败: {e}")
|
||||
@@ -71,9 +72,10 @@ class ChatManager:
|
||||
raise TypeError("platform 必须是字符串或是 SpecialTypes 枚举")
|
||||
streams = []
|
||||
try:
|
||||
for stream in get_chat_manager().streams.values():
|
||||
if (platform == SpecialTypes.ALL_PLATFORMS or stream.platform == platform) and stream.group_info:
|
||||
streams.append(stream)
|
||||
streams.extend(
|
||||
stream for stream in get_chat_manager().streams.values()
|
||||
if (platform == SpecialTypes.ALL_PLATFORMS or stream.platform == platform) and stream.group_info
|
||||
)
|
||||
logger.debug(f"[ChatAPI] 获取到 {len(streams)} 个 {platform} 平台的群聊流")
|
||||
except Exception as e:
|
||||
logger.error(f"[ChatAPI] 获取群聊流失败: {e}")
|
||||
@@ -97,9 +99,10 @@ class ChatManager:
|
||||
raise TypeError("platform 必须是字符串或是 SpecialTypes 枚举")
|
||||
streams = []
|
||||
try:
|
||||
for stream in get_chat_manager().streams.values():
|
||||
if (platform == SpecialTypes.ALL_PLATFORMS or stream.platform == platform) and not stream.group_info:
|
||||
streams.append(stream)
|
||||
streams.extend(
|
||||
stream for stream in get_chat_manager().streams.values()
|
||||
if (platform == SpecialTypes.ALL_PLATFORMS or stream.platform == platform) and not stream.group_info
|
||||
)
|
||||
logger.debug(f"[ChatAPI] 获取到 {len(streams)} 个 {platform} 平台的私聊流")
|
||||
except Exception as e:
|
||||
logger.error(f"[ChatAPI] 获取私聊流失败: {e}")
|
||||
|
||||
@@ -183,9 +183,10 @@ async def build_cross_context_s4u(
|
||||
blacklisted_streams.add(stream_id)
|
||||
except ValueError:
|
||||
logger.warning(f"无效的S4U黑名单格式: {chat_str}")
|
||||
for stream_id in chat_manager.streams:
|
||||
if stream_id != chat_stream.stream_id and stream_id not in blacklisted_streams:
|
||||
streams_to_scan.append(stream_id)
|
||||
streams_to_scan.extend(
|
||||
stream_id for stream_id in chat_manager.streams
|
||||
if stream_id != chat_stream.stream_id and stream_id not in blacklisted_streams
|
||||
)
|
||||
|
||||
logger.debug(f"[S4U] Found {len(streams_to_scan)} group streams to scan.")
|
||||
|
||||
|
||||
@@ -47,7 +47,7 @@ class ScoringAPI:
|
||||
return await relationship_service.get_user_relationship_data(user_id)
|
||||
|
||||
@staticmethod
|
||||
async def update_user_relationship(user_id: str, relationship_score: float, relationship_text: str = None, user_name: str = None):
|
||||
async def update_user_relationship(user_id: str, relationship_score: float, relationship_text: str | None = None, user_name: str | None = None):
|
||||
"""
|
||||
更新用户关系数据
|
||||
|
||||
@@ -71,7 +71,7 @@ class ScoringAPI:
|
||||
await interest_service.initialize_smart_interests(personality_description, personality_id)
|
||||
|
||||
@staticmethod
|
||||
async def calculate_interest_match(content: str, keywords: list[str] = None):
|
||||
async def calculate_interest_match(content: str, keywords: list[str] | None = None):
|
||||
"""
|
||||
计算内容与兴趣的匹配度
|
||||
|
||||
@@ -98,7 +98,7 @@ class ScoringAPI:
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def clear_caches(user_id: str = None):
|
||||
def clear_caches(user_id: str | None = None):
|
||||
"""
|
||||
清理缓存
|
||||
|
||||
|
||||
@@ -9,7 +9,7 @@
|
||||
import json
|
||||
import os
|
||||
import threading
|
||||
from typing import Any
|
||||
from typing import Any, ClassVar
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
@@ -26,7 +26,7 @@ class PluginStorageManager:
|
||||
哼,现在它和API住在一起了,希望它们能和睦相处。
|
||||
"""
|
||||
|
||||
_instances: dict[str, "PluginStorage"] = {}
|
||||
_instances: ClassVar[dict[str, "PluginStorage"] ] = {}
|
||||
_lock = threading.Lock()
|
||||
_base_path = os.path.join("data", "plugin_data")
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@ import asyncio
|
||||
import random
|
||||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, ClassVar
|
||||
|
||||
from src.chat.message_receive.chat_stream import ChatStream
|
||||
from src.common.data_models.database_data_model import DatabaseMessages
|
||||
@@ -80,7 +80,7 @@ class BaseAction(ABC):
|
||||
"""是否为二步Action。如果为True,Action将分两步执行:第一步选择操作,第二步执行具体操作"""
|
||||
step_one_description: str = ""
|
||||
"""第一步的描述,用于向LLM展示Action的基本功能"""
|
||||
sub_actions: list[tuple[str, str, dict[str, str]]] = []
|
||||
sub_actions: ClassVar[list[tuple[str, str, dict[str, str]]] ] = []
|
||||
"""子Action列表,格式为[(子Action名, 子Action描述, 子Action参数)]。仅在二步Action中使用"""
|
||||
|
||||
def __init__(
|
||||
@@ -110,7 +110,7 @@ class BaseAction(ABC):
|
||||
**kwargs: 其他参数
|
||||
"""
|
||||
if plugin_config is None:
|
||||
plugin_config = {}
|
||||
plugin_config: ClassVar = {}
|
||||
self.action_data = action_data
|
||||
self.reasoning = reasoning
|
||||
self.cycle_timers = cycle_timers
|
||||
@@ -492,7 +492,7 @@ class BaseAction(ABC):
|
||||
|
||||
plugin_config = component_registry.get_plugin_config(component_info.plugin_name)
|
||||
# 3. 实例化被调用的Action
|
||||
action_params = {
|
||||
action_params: ClassVar = {
|
||||
"action_data": called_action_data,
|
||||
"reasoning": f"Called by {self.action_name}",
|
||||
"cycle_timers": self.cycle_timers,
|
||||
@@ -745,7 +745,7 @@ class BaseAction(ABC):
|
||||
if not case_sensitive:
|
||||
search_text = search_text.lower()
|
||||
|
||||
matched_keywords = []
|
||||
matched_keywords: ClassVar = []
|
||||
for keyword in keywords:
|
||||
check_keyword = keyword if case_sensitive else keyword.lower()
|
||||
if check_keyword in search_text:
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, ClassVar
|
||||
|
||||
from src.common.data_models.message_manager_data_model import StreamContext
|
||||
from src.plugin_system.base.component_types import ChatterInfo, ComponentType
|
||||
@@ -15,7 +15,7 @@ class BaseChatter(ABC):
|
||||
"""Chatter组件的名称"""
|
||||
chatter_description: str = ""
|
||||
"""Chatter组件的描述"""
|
||||
chat_types: list[ChatType] = [ChatType.PRIVATE, ChatType.GROUP]
|
||||
chat_types: ClassVar[list[ChatType]] = [ChatType.PRIVATE, ChatType.GROUP]
|
||||
|
||||
def __init__(self, stream_id: str, action_manager: "ChatterActionManager"):
|
||||
"""
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import ClassVar
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
@@ -21,7 +22,7 @@ class BaseEventHandler(ABC):
|
||||
"""处理器权重,越大权重越高"""
|
||||
intercept_message: bool = False
|
||||
"""是否拦截消息,默认为否"""
|
||||
init_subscribe: list[EventType | str] = [EventType.UNKNOWN]
|
||||
init_subscribe: ClassVar[list[EventType | str]] = [EventType.UNKNOWN]
|
||||
"""初始化时订阅的事件名称"""
|
||||
plugin_name = None
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any
|
||||
from typing import Any, ClassVar
|
||||
|
||||
from src.chat.utils.prompt_params import PromptParameters
|
||||
from src.common.logger import get_logger
|
||||
@@ -27,7 +27,7 @@ class BasePrompt(ABC):
|
||||
# 定义此组件希望如何注入到核心Prompt中
|
||||
# 这是一个 InjectionRule 对象的列表,可以实现复杂的注入逻辑
|
||||
# 例如: [InjectionRule(target_prompt="planner_prompt", injection_type=InjectionType.APPEND, priority=50)]
|
||||
injection_rules: list[InjectionRule] = []
|
||||
injection_rules: ClassVar[list[InjectionRule] ] = []
|
||||
"""定义注入规则的列表"""
|
||||
|
||||
# 旧的注入点定义,用于向后兼容。如果定义了这个,它将被自动转换为 injection_rules。
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any
|
||||
from typing import Any, ClassVar
|
||||
|
||||
from rich.traceback import install
|
||||
|
||||
@@ -18,7 +18,7 @@ class BaseTool(ABC):
|
||||
"""工具的名称"""
|
||||
description: str = ""
|
||||
"""工具的描述"""
|
||||
parameters: list[tuple[str, ToolParamType, str, bool, list[str] | None]] = []
|
||||
parameters: ClassVar[list[tuple[str, ToolParamType, str, bool, list[str] | None]] ] = []
|
||||
"""工具的参数定义,为[("param_name", param_type, "description", required, enum_values)]格式
|
||||
param_name: 参数名称
|
||||
param_type: 参数类型
|
||||
@@ -44,7 +44,7 @@ class BaseTool(ABC):
|
||||
"""是否为二步工具。如果为True,工具将分两步调用:第一步展示工具信息,第二步执行具体操作"""
|
||||
step_one_description: str = ""
|
||||
"""第一步的描述,用于向LLM展示工具的基本功能"""
|
||||
sub_tools: list[tuple[str, str, list[tuple[str, ToolParamType, str, bool, list[str] | None]]]] = []
|
||||
sub_tools: ClassVar[list[tuple[str, str, list[tuple[str, ToolParamType, str, bool, list[str] | None]]]] ] = []
|
||||
"""子工具列表,格式为[(子工具名, 子工具描述, 子工具参数)]。仅在二步工具中使用"""
|
||||
|
||||
def __init__(self, plugin_config: dict | None = None, chat_stream: Any = None):
|
||||
@@ -112,7 +112,7 @@ class BaseTool(ABC):
|
||||
if not cls.is_two_step_tool:
|
||||
return []
|
||||
|
||||
definitions = []
|
||||
definitions: ClassVar = []
|
||||
for sub_name, sub_desc, sub_params in cls.sub_tools:
|
||||
definitions.append({"name": f"{cls.name}_{sub_name}", "description": sub_desc, "parameters": sub_params})
|
||||
return definitions
|
||||
|
||||
@@ -3,7 +3,7 @@ import os
|
||||
import shutil
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from typing import Any, ClassVar
|
||||
|
||||
import toml
|
||||
|
||||
@@ -30,11 +30,11 @@ class PluginBase(ABC):
|
||||
config_file_name: str
|
||||
enable_plugin: bool = True
|
||||
|
||||
config_schema: dict[str, dict[str, ConfigField] | str] = {}
|
||||
config_schema: ClassVar[dict[str, dict[str, ConfigField] | str] ] = {}
|
||||
|
||||
permission_nodes: list["PermissionNodeField"] = []
|
||||
permission_nodes: ClassVar[list["PermissionNodeField"] ] = []
|
||||
|
||||
config_section_descriptions: dict[str, str] = {}
|
||||
config_section_descriptions: ClassVar[dict[str, str] ] = {}
|
||||
|
||||
def __init__(self, plugin_dir: str, metadata: PluginMetadata):
|
||||
"""初始化插件
|
||||
@@ -206,12 +206,12 @@ class PluginBase(ABC):
|
||||
if not self.config_schema:
|
||||
return {}
|
||||
|
||||
config_data = {}
|
||||
config_data: ClassVar = {}
|
||||
|
||||
# 遍历每个配置节
|
||||
for section, fields in self.config_schema.items():
|
||||
if isinstance(fields, dict):
|
||||
section_data = {}
|
||||
section_data: ClassVar = {}
|
||||
|
||||
# 遍历节内的字段
|
||||
for field_name, field in fields.items():
|
||||
@@ -331,7 +331,7 @@ class PluginBase(ABC):
|
||||
|
||||
try:
|
||||
with open(user_config_path, encoding="utf-8") as f:
|
||||
user_config = toml.load(f) or {}
|
||||
user_config: ClassVar = toml.load(f) or {}
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 加载用户配置文件 {user_config_path} 失败: {e}", exc_info=True)
|
||||
self.config = self._generate_config_from_schema() # 加载失败时使用默认 schema
|
||||
|
||||
@@ -5,7 +5,7 @@
|
||||
|
||||
import re
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, ClassVar
|
||||
|
||||
from src.common.data_models.database_data_model import DatabaseMessages
|
||||
from src.common.logger import get_logger
|
||||
@@ -42,7 +42,7 @@ class PlusCommand(ABC):
|
||||
command_description: str = ""
|
||||
"""命令描述"""
|
||||
|
||||
command_aliases: list[str] = []
|
||||
command_aliases: ClassVar[list[str] ] = []
|
||||
"""命令别名列表,如 ['say', 'repeat']"""
|
||||
|
||||
priority: int = 0
|
||||
@@ -435,7 +435,3 @@ def create_plus_command_adapter(plus_command_class):
|
||||
|
||||
return AdapterClass
|
||||
|
||||
|
||||
# 兼容旧的命名
|
||||
PlusCommandAdapter = create_plus_command_adapter
|
||||
|
||||
|
||||
@@ -87,8 +87,8 @@ class ComponentRegistry:
|
||||
self._tool_registry: dict[str, type["BaseTool"]] = {} # 工具名 -> 工具类
|
||||
self._llm_available_tools: dict[str, type["BaseTool"]] = {} # llm可用的工具名 -> 工具类
|
||||
|
||||
# MCP 工具注册表(运行时动态加载)
|
||||
self._mcp_tools: list["BaseTool"] = [] # MCP 工具适配器实例列表
|
||||
# MCP 工具注册表(运行时动态加载)
|
||||
self._mcp_tools: list[Any] = [] # MCP 工具适配器实例列表
|
||||
self._mcp_tools_loaded = False # MCP 工具是否已加载
|
||||
|
||||
# EventHandler特定注册表
|
||||
|
||||
@@ -7,7 +7,6 @@ from threading import Lock
|
||||
from typing import Any, Optional
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.plugin_system import BaseEventHandler
|
||||
from src.plugin_system.base.base_event import BaseEvent, HandlerResultsCollection
|
||||
from src.plugin_system.base.base_events_handler import BaseEventHandler
|
||||
from src.plugin_system.base.component_types import EventType
|
||||
@@ -176,10 +175,10 @@ class EventManager:
|
||||
|
||||
# 处理init_subscribe,缓存失败的订阅
|
||||
if self._event_handlers[handler_name].init_subscribe:
|
||||
failed_subscriptions = []
|
||||
for event_name in self._event_handlers[handler_name].init_subscribe:
|
||||
if not self.subscribe_handler_to_event(handler_name, event_name):
|
||||
failed_subscriptions.append(event_name)
|
||||
failed_subscriptions = [
|
||||
event_name for event_name in self._event_handlers[handler_name].init_subscribe
|
||||
if not self.subscribe_handler_to_event(handler_name, event_name)
|
||||
]
|
||||
|
||||
# 缓存失败的订阅
|
||||
if failed_subscriptions:
|
||||
|
||||
@@ -4,7 +4,7 @@ MCP Tool Adapter
|
||||
将 MCP 工具适配为 BaseTool,使其能够被插件系统识别和调用
|
||||
"""
|
||||
|
||||
from typing import Any, ClassVar
|
||||
from typing import Any
|
||||
|
||||
import mcp.types
|
||||
|
||||
@@ -27,9 +27,6 @@ class MCPToolAdapter(BaseTool):
|
||||
3. 参与工具缓存机制
|
||||
"""
|
||||
|
||||
# 类级别默认值,使用 ClassVar 标注
|
||||
available_for_llm: ClassVar[bool] = True
|
||||
|
||||
def __init__(self, server_name: str, mcp_tool: mcp.types.Tool, plugin_config: dict | None = None):
|
||||
"""
|
||||
初始化 MCP 工具适配器
|
||||
@@ -47,6 +44,7 @@ class MCPToolAdapter(BaseTool):
|
||||
# 设置实例属性
|
||||
self.name = f"mcp_{server_name}_{mcp_tool.name}"
|
||||
self.description = mcp_tool.description or f"MCP tool from {server_name}"
|
||||
self.available_for_llm = True # MCP 工具默认可供 LLM 使用
|
||||
|
||||
# 转换参数定义
|
||||
self.parameters = self._convert_parameters(mcp_tool.inputSchema)
|
||||
|
||||
@@ -456,8 +456,7 @@ class PermissionManager(IPermissionManager):
|
||||
)
|
||||
granted_users = result.scalars().all()
|
||||
|
||||
for user_perm in granted_users:
|
||||
users.append((user_perm.platform, user_perm.user_id))
|
||||
users.extend((user_perm.platform, user_perm.user_id) for user_perm in granted_users)
|
||||
|
||||
# 如果是默认授权的权限节点,还需要考虑没有明确设置的用户
|
||||
# 但这里我们只返回明确授权的用户,避免返回所有用户
|
||||
|
||||
@@ -94,7 +94,6 @@ class PluginManager:
|
||||
if not plugin_class:
|
||||
logger.error(f"插件 {plugin_name} 的插件类未注册或不存在")
|
||||
return False, 1
|
||||
init_module = None # 预先定义,避免后续条件加载导致未绑定
|
||||
try:
|
||||
# 使用记录的插件目录路径
|
||||
plugin_dir = self.plugin_paths.get(plugin_name)
|
||||
|
||||
@@ -7,7 +7,7 @@ import asyncio
|
||||
import time
|
||||
import traceback
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
from typing import Any, ClassVar
|
||||
|
||||
from src.chat.express.expression_learner import expression_learner_manager
|
||||
from src.chat.planner_actions.action_manager import ChatterActionManager
|
||||
@@ -29,7 +29,7 @@ class AffinityChatter(BaseChatter):
|
||||
|
||||
chatter_name: str = "AffinityChatter"
|
||||
chatter_description: str = "基于亲和力模型的智能聊天处理器,支持多种聊天类型"
|
||||
chat_types: list[ChatType] = [ChatType.ALL] # 支持所有聊天类型
|
||||
chat_types: ClassVar[list[ChatType]] = [ChatType.ALL] # 支持所有聊天类型
|
||||
|
||||
def __init__(self, stream_id: str, action_manager: ChatterActionManager):
|
||||
"""
|
||||
@@ -68,7 +68,7 @@ class AffinityChatter(BaseChatter):
|
||||
try:
|
||||
# 触发表达学习
|
||||
learner = await expression_learner_manager.get_expression_learner(self.stream_id)
|
||||
asyncio.create_task(learner.trigger_learning_for_chat())
|
||||
asyncio.create_task(learner.trigger_learning_for_chat()) # noqa: RUF006
|
||||
|
||||
unread_messages = context.get_unread_messages()
|
||||
|
||||
@@ -87,7 +87,7 @@ class AffinityChatter(BaseChatter):
|
||||
self.stats["successful_executions"] += 1
|
||||
self.last_activity_time = time.time()
|
||||
|
||||
result = {
|
||||
result: ClassVar = {
|
||||
"success": True,
|
||||
"stream_id": self.stream_id,
|
||||
"plan_created": True,
|
||||
|
||||
@@ -5,7 +5,7 @@
|
||||
"""
|
||||
|
||||
import json
|
||||
from typing import Any
|
||||
from typing import Any, ClassVar
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
@@ -29,7 +29,7 @@ class ChatStreamImpressionTool(BaseTool):
|
||||
|
||||
name = "update_chat_stream_impression"
|
||||
description = "当你通过观察聊天记录对当前聊天环境(群聊或私聊)产生了整体印象或认识时使用此工具,更新对这个聊天流的看法。包括:环境氛围、聊天风格、常见话题、你的兴趣程度。调用时机:当你发现这个聊天环境有明显的氛围特点(如很活跃、很专业、很闲聊)、群成员经常讨论某类话题、或者你对这个环境的感受发生变化时。注意:这是对整个聊天环境的印象,而非对单个用户。"
|
||||
parameters = [
|
||||
parameters: ClassVar = [
|
||||
(
|
||||
"impression_description",
|
||||
ToolParamType.STRING,
|
||||
@@ -73,7 +73,7 @@ class ChatStreamImpressionTool(BaseTool):
|
||||
)
|
||||
except AttributeError:
|
||||
# 降级处理
|
||||
available_models = [
|
||||
available_models: ClassVar = [
|
||||
attr
|
||||
for attr in dir(model_config.model_task_config)
|
||||
if not attr.startswith("_") and attr != "model_dump"
|
||||
@@ -153,7 +153,7 @@ class ChatStreamImpressionTool(BaseTool):
|
||||
await self._update_stream_impression_in_db(stream_id, final_impression)
|
||||
|
||||
# 构建返回信息
|
||||
updates = []
|
||||
updates: ClassVar = []
|
||||
if final_impression.get("stream_impression_text"):
|
||||
updates.append(f"印象: {final_impression['stream_impression_text'][:50]}...")
|
||||
if final_impression.get("stream_chat_style"):
|
||||
|
||||
@@ -117,9 +117,11 @@ class ChatterPlanFilter:
|
||||
elif isinstance(actions_obj, list):
|
||||
actions_to_process_for_log.extend(actions_obj)
|
||||
|
||||
for single_action in actions_to_process_for_log:
|
||||
if isinstance(single_action, dict):
|
||||
action_types_to_log.append(single_action.get("action_type", "no_action"))
|
||||
action_types_to_log = [
|
||||
single_action.get("action_type", "no_action")
|
||||
for single_action in actions_to_process_for_log
|
||||
if isinstance(single_action, dict)
|
||||
]
|
||||
|
||||
if thinking != "未提供思考过程" and action_types_to_log:
|
||||
await self._add_decision_to_history(plan, thinking, ", ".join(action_types_to_log))
|
||||
|
||||
@@ -118,7 +118,6 @@ class ChatterActionPlanner:
|
||||
# 2. 使用新的兴趣度管理系统进行评分
|
||||
max_message_interest = 0.0
|
||||
reply_not_available = True
|
||||
interest_updates: list[dict[str, Any]] = []
|
||||
aggregate_should_act = False
|
||||
|
||||
if unread_messages:
|
||||
|
||||
@@ -2,6 +2,8 @@
|
||||
亲和力聊天处理器插件(包含兴趣计算器功能)
|
||||
"""
|
||||
|
||||
from typing import ClassVar
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.plugin_system.apis.plugin_register_api import register_plugin
|
||||
from src.plugin_system.base.base_plugin import BasePlugin
|
||||
@@ -21,12 +23,12 @@ class AffinityChatterPlugin(BasePlugin):
|
||||
|
||||
plugin_name: str = "affinity_chatter"
|
||||
enable_plugin: bool = True
|
||||
dependencies: list[str] = []
|
||||
python_dependencies: list[str] = []
|
||||
dependencies: ClassVar[list[str] ] = []
|
||||
python_dependencies: ClassVar[list[str] ] = []
|
||||
config_file_name: str = ""
|
||||
|
||||
# 简单的 config_schema 占位(如果将来需要配置可扩展)
|
||||
config_schema = {}
|
||||
config_schema: ClassVar = {}
|
||||
|
||||
def get_plugin_components(self) -> list[tuple[ComponentInfo, type]]:
|
||||
"""返回插件包含的组件列表
|
||||
@@ -34,7 +36,7 @@ class AffinityChatterPlugin(BasePlugin):
|
||||
这里采用延迟导入以避免循环依赖和启动顺序问题。
|
||||
如果导入失败则返回空列表以让注册过程继续而不崩溃。
|
||||
"""
|
||||
components = []
|
||||
components: ClassVar = []
|
||||
|
||||
try:
|
||||
# 延迟导入 AffinityChatter
|
||||
|
||||
@@ -3,6 +3,9 @@
|
||||
监听bot的reply事件,在reply后重置对应聊天流的主动思考定时任务
|
||||
"""
|
||||
|
||||
|
||||
from typing import ClassVar
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.plugin_system import BaseEventHandler, EventType
|
||||
from src.plugin_system.base.base_event import HandlerResult
|
||||
@@ -23,7 +26,7 @@ class ProactiveThinkingReplyHandler(BaseEventHandler):
|
||||
|
||||
handler_name: str = "proactive_thinking_reply_handler"
|
||||
handler_description: str = "监听reply事件,重置主动思考定时任务"
|
||||
init_subscribe: list[EventType | str] = [EventType.AFTER_SEND]
|
||||
init_subscribe: ClassVar[list[EventType | str]] = [EventType.AFTER_SEND]
|
||||
|
||||
async def execute(self, kwargs: dict | None) -> HandlerResult:
|
||||
"""处理reply事件
|
||||
@@ -90,7 +93,7 @@ class ProactiveThinkingMessageHandler(BaseEventHandler):
|
||||
|
||||
handler_name: str = "proactive_thinking_message_handler"
|
||||
handler_description: str = "监听消息事件,为新聊天流创建主动思考任务"
|
||||
init_subscribe: list[EventType | str] = [EventType.ON_MESSAGE]
|
||||
init_subscribe: ClassVar[list[EventType | str]] = [EventType.ON_MESSAGE]
|
||||
|
||||
async def execute(self, kwargs: dict | None) -> HandlerResult:
|
||||
"""处理消息事件
|
||||
|
||||
@@ -351,6 +351,76 @@ class ProactiveThinkingPlanner:
|
||||
logger.error(f"决策过程失败: {e}", exc_info=True)
|
||||
return None
|
||||
|
||||
def _build_decision_prompt(self, context: dict[str, Any]) -> str:
|
||||
"""构建决策提示词"""
|
||||
# 构建上次决策信息
|
||||
last_decision_text = ""
|
||||
if context.get("last_decision"):
|
||||
last_dec = context["last_decision"]
|
||||
last_action = last_dec.get("action", "未知")
|
||||
last_reasoning = last_dec.get("reasoning", "无")
|
||||
last_topic = last_dec.get("topic")
|
||||
last_time = last_dec.get("timestamp", "未知")
|
||||
|
||||
last_decision_text = f"""
|
||||
【上次主动思考的决策】
|
||||
- 时间: {last_time}
|
||||
- 决策: {last_action}
|
||||
- 理由: {last_reasoning}"""
|
||||
if last_topic:
|
||||
last_decision_text += f"\n- 话题: {last_topic}"
|
||||
|
||||
return f"""你的人设是:
|
||||
{context['bot_personality']}
|
||||
|
||||
现在是 {context['current_time']},你正在考虑是否要在与 "{context['stream_name']}" 的对话中主动说些什么。
|
||||
|
||||
【你当前的心情】
|
||||
{context.get("current_mood", "感觉很平静")}
|
||||
|
||||
【聊天环境信息】
|
||||
- 整体印象: {context["stream_impression"]}
|
||||
- 聊天风格: {context["chat_style"]}
|
||||
- 常见话题: {context["topic_keywords"] or "暂无"}
|
||||
- 你的兴趣程度: {context["interest_score"]:.2f}/1.0
|
||||
{last_decision_text}
|
||||
|
||||
【最近的聊天记录】
|
||||
{context["recent_chat_history"]}
|
||||
|
||||
请根据以上信息,决定你现在应该做什么:
|
||||
|
||||
**选项1:什么都不做 (do_nothing)**
|
||||
- 适用场景:气氛不适合说话、最近对话很活跃、没什么特别想说的、或者此时说话会显得突兀。
|
||||
- 心情影响:如果心情不好(如生气、难过),可能更倾向于保持沉默。
|
||||
|
||||
**选项2:简单冒个泡 (simple_bubble)**
|
||||
- 适用场景:对话有些冷清,你想缓和气氛或开启新的互动。
|
||||
- 方式:说一句轻松随意的话,旨在建立或维持连接。
|
||||
- 心情影响:心情会影响你冒泡的方式和内容。
|
||||
|
||||
**选项3:发起一次有目的的互动 (throw_topic)**
|
||||
- 适用场景:你想延续对话、表达关心、或深入讨论某个具体话题。
|
||||
- **【互动类型1:延续约定或提醒】(最高优先级)**:检查最近的聊天记录,是否存在可以延续的互动。例如,如果昨晚的最后一条消息是“晚安”,现在是早上,一个“早安”的回应是绝佳的选择。如果之前提到过某个约定(如“待会聊”),现在可以主动跟进。
|
||||
- **【互动类型2:展现真诚的关心】(次高优先级)**:如果不存在可延续的约定,请仔细阅读聊天记录,寻找对方提及的个人状况(如天气、出行、身体、情绪、工作学习等),并主动表达关心。
|
||||
- **【互动类型3:开启新话题】**:当以上两点都不适用时,可以考虑开启一个你感兴趣的新话题。
|
||||
- 心情影响:心情会影响你想发起互动的方式和内容。
|
||||
|
||||
请以JSON格式回复你的决策:
|
||||
{{
|
||||
"action": "do_nothing" | "simple_bubble" | "throw_topic",
|
||||
"reasoning": "你的决策理由(请结合你的心情、聊天环境和对话历史进行分析)",
|
||||
"topic": "(仅当action=throw_topic时填写)你的互动意图(如:回应晚安并说早安、关心对方的考试情况、讨论新游戏)"
|
||||
}}
|
||||
|
||||
注意:
|
||||
1. 兴趣度较低(<0.4)时或者最近聊天很活跃(不到1小时),倾向于 `do_nothing` 或 `simple_bubble`。
|
||||
2. 你的心情会影响你的行动倾向和表达方式。
|
||||
3. 参考上次决策,避免重复,并可根据上次的互动效果调整策略。
|
||||
4. 只有在真的有感而发时才选择 `throw_topic`。
|
||||
5. 保持你的人设,确保行为一致性。
|
||||
"""
|
||||
|
||||
async def generate_reply(
|
||||
self, context: dict[str, Any], action: Literal["simple_bubble", "throw_topic"], topic: str | None = None
|
||||
) -> str | None:
|
||||
|
||||
@@ -5,7 +5,7 @@
|
||||
"""
|
||||
|
||||
import time
|
||||
from typing import Any
|
||||
from typing import Any, ClassVar
|
||||
|
||||
import orjson
|
||||
from sqlalchemy import select
|
||||
@@ -30,7 +30,7 @@ class UserProfileTool(BaseTool):
|
||||
|
||||
name = "update_user_profile"
|
||||
description = "当你通过聊天记录对某个用户产生了新的认识或印象时使用此工具,更新该用户的画像信息。包括:用户别名、你对TA的主观印象、TA的偏好兴趣、你对TA的好感程度。调用时机:当你发现用户透露了新的个人信息、展现了性格特点、表达了兴趣偏好,或者你们的互动让你对TA的看法发生变化时。"
|
||||
parameters = [
|
||||
parameters: ClassVar = [
|
||||
("target_user_id", ToolParamType.STRING, "目标用户的ID(必须)", True, None),
|
||||
("user_aliases", ToolParamType.STRING, "该用户的昵称或别名,如果发现用户自称或被他人称呼的其他名字时填写,多个别名用逗号分隔(可选)", False, None),
|
||||
("impression_description", ToolParamType.STRING, "你对该用户的整体印象和性格感受,例如'这个用户很幽默开朗'、'TA对技术很有热情'等。当你通过对话了解到用户的性格、态度、行为特点时填写(可选)", False, None),
|
||||
@@ -51,7 +51,7 @@ class UserProfileTool(BaseTool):
|
||||
)
|
||||
except AttributeError:
|
||||
# 降级处理
|
||||
available_models = [
|
||||
available_models: ClassVar = [
|
||||
attr for attr in dir(model_config.model_task_config)
|
||||
if not attr.startswith("_") and attr != "model_dump"
|
||||
]
|
||||
@@ -131,7 +131,7 @@ class UserProfileTool(BaseTool):
|
||||
await self._update_user_profile_in_db(target_user_id, final_profile)
|
||||
|
||||
# 构建返回信息
|
||||
updates = []
|
||||
updates: ClassVar = []
|
||||
if final_profile.get("user_aliases"):
|
||||
updates.append(f"别名: {final_profile['user_aliases']}")
|
||||
if final_profile.get("relationship_text"):
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import random
|
||||
import re
|
||||
from typing import ClassVar
|
||||
|
||||
from src.chat.emoji_system.emoji_history import add_emoji_to_history, get_recent_emojis
|
||||
from src.chat.emoji_system.emoji_manager import MaiEmoji, get_emoji_manager
|
||||
@@ -75,17 +76,17 @@ class EmojiAction(BaseAction):
|
||||
"""
|
||||
|
||||
# 动作参数定义
|
||||
action_parameters = {}
|
||||
action_parameters: ClassVar = {}
|
||||
|
||||
# 动作使用场景
|
||||
action_require = [
|
||||
action_require: ClassVar = [
|
||||
"发送表情包辅助表达情绪",
|
||||
"表达情绪时可以选择使用",
|
||||
"不要连续发送,如果你已经发过[表情包],就不要选择此动作",
|
||||
]
|
||||
|
||||
# 关联类型
|
||||
associated_types = ["emoji"]
|
||||
associated_types: ClassVar[list[str]] = ["emoji"]
|
||||
|
||||
async def execute(self) -> tuple[bool, str]:
|
||||
"""执行表情动作"""
|
||||
@@ -119,8 +120,8 @@ class EmojiAction(BaseAction):
|
||||
logger.error(f"{self.log_prefix} 获取或处理表情发送历史时出错: {e}")
|
||||
|
||||
# 4. 准备情感数据和后备列表
|
||||
emotion_map = {}
|
||||
all_emojis_data = []
|
||||
emotion_map: ClassVar = {}
|
||||
all_emojis_data: ClassVar = []
|
||||
|
||||
for emoji in all_emojis_obj:
|
||||
b64 = image_path_to_base64(emoji.full_path)
|
||||
|
||||
@@ -6,6 +6,8 @@
|
||||
"""
|
||||
|
||||
# 导入依赖的系统组件
|
||||
from typing import ClassVar
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
# 导入新插件系统
|
||||
@@ -34,18 +36,18 @@ class CoreActionsPlugin(BasePlugin):
|
||||
# 插件基本信息
|
||||
plugin_name: str = "core_actions" # 内部标识符
|
||||
enable_plugin: bool = True
|
||||
dependencies: list[str] = [] # 插件依赖列表
|
||||
python_dependencies: list[str] = [] # Python包依赖列表
|
||||
dependencies: ClassVar[list[str]] = [] # 插件依赖列表
|
||||
python_dependencies: ClassVar[list[str]] = [] # Python包依赖列表
|
||||
config_file_name: str = "config.toml"
|
||||
|
||||
# 配置节描述
|
||||
config_section_descriptions = {
|
||||
config_section_descriptions: ClassVar = {
|
||||
"plugin": "插件启用配置",
|
||||
"components": "核心组件启用配置",
|
||||
}
|
||||
|
||||
# 配置Schema定义
|
||||
config_schema: dict = {
|
||||
config_schema: ClassVar[dict] = {
|
||||
"plugin": {
|
||||
"enabled": ConfigField(type=bool, default=True, description="是否启用插件"),
|
||||
"config_version": ConfigField(type=str, default="0.6.0", description="配置文件版本"),
|
||||
@@ -63,7 +65,7 @@ class CoreActionsPlugin(BasePlugin):
|
||||
"""返回插件包含的组件列表"""
|
||||
|
||||
# --- 根据配置注册组件 ---
|
||||
components = []
|
||||
components: ClassVar = []
|
||||
if self.get_config("components.enable_emoji", True):
|
||||
components.append((EmojiAction.get_action_info(), EmojiAction))
|
||||
if self.get_config("components.enable_anti_injector_manager", True):
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import Any
|
||||
from typing import Any, ClassVar
|
||||
|
||||
from src.chat.knowledge.knowledge_lib import qa_manager
|
||||
from src.common.logger import get_logger
|
||||
@@ -13,7 +13,7 @@ class SearchKnowledgeFromLPMMTool(BaseTool):
|
||||
|
||||
name = "lpmm_search_knowledge"
|
||||
description = "从知识库中搜索相关信息,如果你需要知识,就使用这个工具"
|
||||
parameters = [
|
||||
parameters: ClassVar = [
|
||||
("query", ToolParamType.STRING, "搜索查询关键词", True, None),
|
||||
("threshold", ToolParamType.FLOAT, "相似度阈值,0.0到1.0之间", False, None),
|
||||
]
|
||||
@@ -44,7 +44,7 @@ class SearchKnowledgeFromLPMMTool(BaseTool):
|
||||
logger.debug(f"知识库查询结果: {knowledge_info}")
|
||||
|
||||
if knowledge_info and knowledge_info.get("knowledge_items"):
|
||||
knowledge_parts = []
|
||||
knowledge_parts: ClassVar = []
|
||||
for i, item in enumerate(knowledge_info["knowledge_items"]):
|
||||
knowledge_parts.append(f"- {item.get('content', 'N/A')}")
|
||||
|
||||
|
||||
@@ -2,6 +2,8 @@
|
||||
阅读说说动作组件
|
||||
"""
|
||||
|
||||
from typing import ClassVar
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.plugin_system import ActionActivationType, BaseAction, ChatMode
|
||||
from src.plugin_system.apis import generator_api
|
||||
@@ -21,9 +23,9 @@ class ReadFeedAction(BaseAction):
|
||||
action_description: str = "读取好友的最新动态并进行评论点赞"
|
||||
activation_type: ActionActivationType = ActionActivationType.KEYWORD
|
||||
mode_enable: ChatMode = ChatMode.ALL
|
||||
activation_keywords: list = ["看说说", "看空间", "看动态", "刷空间"]
|
||||
activation_keywords: ClassVar[list] = ["看说说", "看空间", "看动态", "刷空间"]
|
||||
|
||||
action_parameters = {
|
||||
action_parameters: ClassVar[dict] = {
|
||||
"target_name": "需要阅读动态的好友的昵称",
|
||||
"user_name": "请求你阅读动态的好友的昵称",
|
||||
}
|
||||
|
||||
@@ -2,6 +2,8 @@
|
||||
发送说说动作组件
|
||||
"""
|
||||
|
||||
from typing import ClassVar
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.plugin_system import ActionActivationType, BaseAction, ChatMode
|
||||
from src.plugin_system.apis import generator_api
|
||||
@@ -21,9 +23,9 @@ class SendFeedAction(BaseAction):
|
||||
action_description: str = "发送一条关于特定主题的说说"
|
||||
activation_type: ActionActivationType = ActionActivationType.KEYWORD
|
||||
mode_enable: ChatMode = ChatMode.ALL
|
||||
activation_keywords: list = ["发说说", "发空间", "发动态"]
|
||||
activation_keywords: ClassVar[list] = ["发说说", "发空间", "发动态"]
|
||||
|
||||
action_parameters = {
|
||||
action_parameters: ClassVar[dict] = {
|
||||
"topic": "用户想要发送的说说主题",
|
||||
"user_name": "请求你发说说的好友的昵称",
|
||||
}
|
||||
|
||||
@@ -2,6 +2,9 @@
|
||||
发送说说命令 await self.send_text(f"收到!正在为你生成关于"{topic or '随机'}"的说说,请稍候...【热重载测试成功】")件
|
||||
"""
|
||||
|
||||
|
||||
from typing import ClassVar
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.plugin_system.base.command_args import CommandArgs
|
||||
from src.plugin_system.base.plus_command import PlusCommand
|
||||
@@ -20,7 +23,7 @@ class SendFeedCommand(PlusCommand):
|
||||
|
||||
command_name: str = "send_feed"
|
||||
command_description: str = "发一条QQ空间说说"
|
||||
command_aliases = ["发空间"]
|
||||
command_aliases: ClassVar[list[str]] = ["发空间"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
@@ -4,6 +4,7 @@ MaiZone(麦麦空间)- 重构版
|
||||
|
||||
import asyncio
|
||||
from pathlib import Path
|
||||
from typing import ClassVar
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.plugin_system import BasePlugin, ComponentInfo, register_plugin
|
||||
@@ -33,10 +34,10 @@ class MaiZoneRefactoredPlugin(BasePlugin):
|
||||
plugin_description: str = "重构版的MaiZone插件"
|
||||
config_file_name: str = "config.toml"
|
||||
enable_plugin: bool = True
|
||||
dependencies: list[str] = []
|
||||
python_dependencies: list[str] = []
|
||||
dependencies: ClassVar[list[str] ] = []
|
||||
python_dependencies: ClassVar[list[str] ] = []
|
||||
|
||||
config_schema: dict = {
|
||||
config_schema: ClassVar[dict] = {
|
||||
"plugin": {"enable": ConfigField(type=bool, default=True, description="是否启用插件")},
|
||||
"models": {
|
||||
"text_model": ConfigField(type=str, default="maizone", description="生成文本的模型名称"),
|
||||
@@ -83,7 +84,7 @@ class MaiZoneRefactoredPlugin(BasePlugin):
|
||||
},
|
||||
}
|
||||
|
||||
permission_nodes: list[PermissionNodeField] = [
|
||||
permission_nodes: ClassVar[list[PermissionNodeField]] = [
|
||||
PermissionNodeField(node_name="send_feed", description="是否可以使用机器人发送QQ空间说说"),
|
||||
PermissionNodeField(node_name="read_feed", description="是否可以使用机器人读取QQ空间说说"),
|
||||
]
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
from typing import ClassVar
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.plugin_system import BaseEventHandler
|
||||
from src.plugin_system.base.base_event import HandlerResult
|
||||
|
||||
from .src.send_handler import send_handler
|
||||
from .event_types import NapcatEvent
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from .src.send_handler import send_handler
|
||||
|
||||
logger = get_logger("napcat_adapter")
|
||||
|
||||
@@ -14,7 +15,7 @@ class SetProfileHandler(BaseEventHandler):
|
||||
handler_description: str = "设置账号信息"
|
||||
weight: int = 100
|
||||
intercept_message: bool = False
|
||||
init_subscribe = [NapcatEvent.ACCOUNT.SET_PROFILE]
|
||||
init_subscribe: ClassVar[list] = [NapcatEvent.ACCOUNT.SET_PROFILE]
|
||||
|
||||
async def execute(self, params: dict):
|
||||
raw = params.get("raw", {})
|
||||
@@ -49,7 +50,7 @@ class GetOnlineClientsHandler(BaseEventHandler):
|
||||
handler_description: str = "获取当前账号在线客户端列表"
|
||||
weight: int = 100
|
||||
intercept_message: bool = False
|
||||
init_subscribe = [NapcatEvent.ACCOUNT.GET_ONLINE_CLIENTS]
|
||||
init_subscribe: ClassVar[list] = [NapcatEvent.ACCOUNT.GET_ONLINE_CLIENTS]
|
||||
|
||||
async def execute(self, params: dict):
|
||||
raw = params.get("raw", {})
|
||||
@@ -72,7 +73,7 @@ class SetOnlineStatusHandler(BaseEventHandler):
|
||||
handler_description: str = "设置在线状态"
|
||||
weight: int = 100
|
||||
intercept_message: bool = False
|
||||
init_subscribe = [NapcatEvent.ACCOUNT.SET_ONLINE_STATUS]
|
||||
init_subscribe: ClassVar[list] = [NapcatEvent.ACCOUNT.SET_ONLINE_STATUS]
|
||||
|
||||
async def execute(self, params: dict):
|
||||
raw = params.get("raw", {})
|
||||
@@ -103,7 +104,7 @@ class GetFriendsWithCategoryHandler(BaseEventHandler):
|
||||
handler_description: str = "获取好友分组列表"
|
||||
weight: int = 100
|
||||
intercept_message: bool = False
|
||||
init_subscribe = [NapcatEvent.ACCOUNT.GET_FRIENDS_WITH_CATEGORY]
|
||||
init_subscribe: ClassVar[list] = [NapcatEvent.ACCOUNT.GET_FRIENDS_WITH_CATEGORY]
|
||||
|
||||
async def execute(self, params: dict):
|
||||
payload = {}
|
||||
@@ -120,7 +121,7 @@ class SetAvatarHandler(BaseEventHandler):
|
||||
handler_description: str = "设置头像"
|
||||
weight: int = 100
|
||||
intercept_message: bool = False
|
||||
init_subscribe = [NapcatEvent.ACCOUNT.SET_AVATAR]
|
||||
init_subscribe: ClassVar[list] = [NapcatEvent.ACCOUNT.SET_AVATAR]
|
||||
|
||||
async def execute(self, params: dict):
|
||||
raw = params.get("raw", {})
|
||||
@@ -147,7 +148,7 @@ class SendLikeHandler(BaseEventHandler):
|
||||
handler_description: str = "点赞"
|
||||
weight: int = 100
|
||||
intercept_message: bool = False
|
||||
init_subscribe = [NapcatEvent.ACCOUNT.SEND_LIKE]
|
||||
init_subscribe: ClassVar[list] = [NapcatEvent.ACCOUNT.SEND_LIKE]
|
||||
|
||||
async def execute(self, params: dict):
|
||||
raw = params.get("raw", {})
|
||||
@@ -176,7 +177,7 @@ class SetFriendAddRequestHandler(BaseEventHandler):
|
||||
handler_description: str = "处理好友请求"
|
||||
weight: int = 100
|
||||
intercept_message: bool = False
|
||||
init_subscribe = [NapcatEvent.ACCOUNT.SET_FRIEND_ADD_REQUEST]
|
||||
init_subscribe: ClassVar[list] = [NapcatEvent.ACCOUNT.SET_FRIEND_ADD_REQUEST]
|
||||
|
||||
async def execute(self, params: dict):
|
||||
raw = params.get("raw", {})
|
||||
@@ -207,7 +208,7 @@ class SetSelfLongnickHandler(BaseEventHandler):
|
||||
handler_description: str = "设置个性签名"
|
||||
weight: int = 100
|
||||
intercept_message: bool = False
|
||||
init_subscribe = [NapcatEvent.ACCOUNT.SET_SELF_LONGNICK]
|
||||
init_subscribe: ClassVar[list] = [NapcatEvent.ACCOUNT.SET_SELF_LONGNICK]
|
||||
|
||||
async def execute(self, params: dict):
|
||||
raw = params.get("raw", {})
|
||||
@@ -240,7 +241,7 @@ class GetLoginInfoHandler(BaseEventHandler):
|
||||
handler_description: str = "获取登录号信息"
|
||||
weight: int = 100
|
||||
intercept_message: bool = False
|
||||
init_subscribe = [NapcatEvent.ACCOUNT.GET_LOGIN_INFO]
|
||||
init_subscribe: ClassVar[list] = [NapcatEvent.ACCOUNT.GET_LOGIN_INFO]
|
||||
|
||||
async def execute(self, params: dict):
|
||||
payload = {}
|
||||
@@ -257,7 +258,7 @@ class GetRecentContactHandler(BaseEventHandler):
|
||||
handler_description: str = "最近消息列表"
|
||||
weight: int = 100
|
||||
intercept_message: bool = False
|
||||
init_subscribe = [NapcatEvent.ACCOUNT.GET_RECENT_CONTACT]
|
||||
init_subscribe: ClassVar[list] = [NapcatEvent.ACCOUNT.GET_RECENT_CONTACT]
|
||||
|
||||
async def execute(self, params: dict):
|
||||
raw = params.get("raw", {})
|
||||
@@ -280,7 +281,7 @@ class GetStrangerInfoHandler(BaseEventHandler):
|
||||
handler_description: str = "获取(指定)账号信息"
|
||||
weight: int = 100
|
||||
intercept_message: bool = False
|
||||
init_subscribe = [NapcatEvent.ACCOUNT.GET_STRANGER_INFO]
|
||||
init_subscribe: ClassVar[list] = [NapcatEvent.ACCOUNT.GET_STRANGER_INFO]
|
||||
|
||||
async def execute(self, params: dict):
|
||||
raw = params.get("raw", {})
|
||||
@@ -307,7 +308,7 @@ class GetFriendListHandler(BaseEventHandler):
|
||||
handler_description: str = "获取好友列表"
|
||||
weight: int = 100
|
||||
intercept_message: bool = False
|
||||
init_subscribe = [NapcatEvent.ACCOUNT.GET_FRIEND_LIST]
|
||||
init_subscribe: ClassVar[list] = [NapcatEvent.ACCOUNT.GET_FRIEND_LIST]
|
||||
|
||||
async def execute(self, params: dict):
|
||||
raw = params.get("raw", {})
|
||||
@@ -330,7 +331,7 @@ class GetProfileLikeHandler(BaseEventHandler):
|
||||
handler_description: str = "获取点赞列表"
|
||||
weight: int = 100
|
||||
intercept_message: bool = False
|
||||
init_subscribe = [NapcatEvent.ACCOUNT.GET_PROFILE_LIKE]
|
||||
init_subscribe: ClassVar[list] = [NapcatEvent.ACCOUNT.GET_PROFILE_LIKE]
|
||||
|
||||
async def execute(self, params: dict):
|
||||
raw = params.get("raw", {})
|
||||
@@ -360,7 +361,7 @@ class DeleteFriendHandler(BaseEventHandler):
|
||||
handler_description: str = "删除好友"
|
||||
weight: int = 100
|
||||
intercept_message: bool = False
|
||||
init_subscribe = [NapcatEvent.ACCOUNT.DELETE_FRIEND]
|
||||
init_subscribe: ClassVar[list] = [NapcatEvent.ACCOUNT.DELETE_FRIEND]
|
||||
|
||||
async def execute(self, params: dict):
|
||||
raw = params.get("raw", {})
|
||||
@@ -395,7 +396,7 @@ class GetUserStatusHandler(BaseEventHandler):
|
||||
handler_description: str = "获取(指定)用户状态"
|
||||
weight: int = 100
|
||||
intercept_message: bool = False
|
||||
init_subscribe = [NapcatEvent.ACCOUNT.GET_USER_STATUS]
|
||||
init_subscribe: ClassVar[list] = [NapcatEvent.ACCOUNT.GET_USER_STATUS]
|
||||
|
||||
async def execute(self, params: dict):
|
||||
raw = params.get("raw", {})
|
||||
@@ -422,7 +423,7 @@ class GetStatusHandler(BaseEventHandler):
|
||||
handler_description: str = "获取状态"
|
||||
weight: int = 100
|
||||
intercept_message: bool = False
|
||||
init_subscribe = [NapcatEvent.ACCOUNT.GET_STATUS]
|
||||
init_subscribe: ClassVar[list] = [NapcatEvent.ACCOUNT.GET_STATUS]
|
||||
|
||||
async def execute(self, params: dict):
|
||||
payload = {}
|
||||
@@ -439,7 +440,7 @@ class GetMiniAppArkHandler(BaseEventHandler):
|
||||
handler_description: str = "获取小程序卡片"
|
||||
weight: int = 100
|
||||
intercept_message: bool = False
|
||||
init_subscribe = [NapcatEvent.ACCOUNT.GET_MINI_APP_ARK]
|
||||
init_subscribe: ClassVar[list] = [NapcatEvent.ACCOUNT.GET_MINI_APP_ARK]
|
||||
|
||||
async def execute(self, params: dict):
|
||||
raw = params.get("raw", {})
|
||||
@@ -486,7 +487,7 @@ class SetDiyOnlineStatusHandler(BaseEventHandler):
|
||||
handler_description: str = "设置自定义在线状态"
|
||||
weight: int = 100
|
||||
intercept_message: bool = False
|
||||
init_subscribe = [NapcatEvent.ACCOUNT.SET_DIY_ONLINE_STATUS]
|
||||
init_subscribe: ClassVar[list] = [NapcatEvent.ACCOUNT.SET_DIY_ONLINE_STATUS]
|
||||
|
||||
async def execute(self, params: dict):
|
||||
raw = params.get("raw", {})
|
||||
@@ -518,7 +519,7 @@ class SendPrivateMsgHandler(BaseEventHandler):
|
||||
handler_description: str = "发送私聊消息"
|
||||
weight: int = 100
|
||||
intercept_message: bool = False
|
||||
init_subscribe = [NapcatEvent.MESSAGE.SEND_PRIVATE_MSG]
|
||||
init_subscribe: ClassVar[list] = [NapcatEvent.MESSAGE.SEND_PRIVATE_MSG]
|
||||
|
||||
async def execute(self, params: dict):
|
||||
raw = params.get("raw", {})
|
||||
@@ -547,7 +548,7 @@ class SendPokeHandler(BaseEventHandler):
|
||||
handler_description: str = "发送戳一戳"
|
||||
weight: int = 100
|
||||
intercept_message: bool = False
|
||||
init_subscribe = [NapcatEvent.MESSAGE.SEND_POKE]
|
||||
init_subscribe: ClassVar[list] = [NapcatEvent.MESSAGE.SEND_POKE]
|
||||
|
||||
async def execute(self, params: dict):
|
||||
raw = params.get("raw", {})
|
||||
@@ -579,7 +580,7 @@ class DeleteMsgHandler(BaseEventHandler):
|
||||
handler_description: str = "撤回消息"
|
||||
weight: int = 100
|
||||
intercept_message: bool = False
|
||||
init_subscribe = [NapcatEvent.MESSAGE.DELETE_MSG]
|
||||
init_subscribe: ClassVar[list] = [NapcatEvent.MESSAGE.DELETE_MSG]
|
||||
|
||||
async def execute(self, params: dict):
|
||||
raw = params.get("raw", {})
|
||||
@@ -606,7 +607,7 @@ class GetGroupMsgHistoryHandler(BaseEventHandler):
|
||||
handler_description: str = "获取群历史消息"
|
||||
weight: int = 100
|
||||
intercept_message: bool = False
|
||||
init_subscribe = [NapcatEvent.MESSAGE.GET_GROUP_MSG_HISTORY]
|
||||
init_subscribe: ClassVar[list] = [NapcatEvent.MESSAGE.GET_GROUP_MSG_HISTORY]
|
||||
|
||||
async def execute(self, params: dict):
|
||||
raw = params.get("raw", {})
|
||||
@@ -644,7 +645,7 @@ class GetMsgHandler(BaseEventHandler):
|
||||
handler_description: str = "获取消息详情"
|
||||
weight: int = 100
|
||||
intercept_message: bool = False
|
||||
init_subscribe = [NapcatEvent.MESSAGE.GET_MSG]
|
||||
init_subscribe: ClassVar[list] = [NapcatEvent.MESSAGE.GET_MSG]
|
||||
|
||||
async def execute(self, params: dict):
|
||||
raw = params.get("raw", {})
|
||||
@@ -671,7 +672,7 @@ class GetForwardMsgHandler(BaseEventHandler):
|
||||
handler_description: str = "获取合并转发消息"
|
||||
weight: int = 100
|
||||
intercept_message: bool = False
|
||||
init_subscribe = [NapcatEvent.MESSAGE.GET_FORWARD_MSG]
|
||||
init_subscribe: ClassVar[list] = [NapcatEvent.MESSAGE.GET_FORWARD_MSG]
|
||||
|
||||
async def execute(self, params: dict):
|
||||
raw = params.get("raw", {})
|
||||
@@ -698,7 +699,7 @@ class SetMsgEmojiLikeHandler(BaseEventHandler):
|
||||
handler_description: str = "贴表情"
|
||||
weight: int = 100
|
||||
intercept_message: bool = False
|
||||
init_subscribe = [NapcatEvent.MESSAGE.SET_MSG_EMOJI_LIKE]
|
||||
init_subscribe: ClassVar[list] = [NapcatEvent.MESSAGE.SET_MSG_EMOJI_LIKE]
|
||||
|
||||
async def execute(self, params: dict):
|
||||
raw = params.get("raw", {})
|
||||
@@ -729,7 +730,7 @@ class GetFriendMsgHistoryHandler(BaseEventHandler):
|
||||
handler_description: str = "获取好友历史消息"
|
||||
weight: int = 100
|
||||
intercept_message: bool = False
|
||||
init_subscribe = [NapcatEvent.MESSAGE.GET_FRIEND_MSG_HISTORY]
|
||||
init_subscribe: ClassVar[list] = [NapcatEvent.MESSAGE.GET_FRIEND_MSG_HISTORY]
|
||||
|
||||
async def execute(self, params: dict):
|
||||
raw = params.get("raw", {})
|
||||
@@ -767,7 +768,7 @@ class FetchEmojiLikeHandler(BaseEventHandler):
|
||||
handler_description: str = "获取贴表情详情"
|
||||
weight: int = 100
|
||||
intercept_message: bool = False
|
||||
init_subscribe = [NapcatEvent.MESSAGE.FETCH_EMOJI_LIKE]
|
||||
init_subscribe: ClassVar[list] = [NapcatEvent.MESSAGE.FETCH_EMOJI_LIKE]
|
||||
|
||||
async def execute(self, params: dict):
|
||||
raw = params.get("raw", {})
|
||||
@@ -805,7 +806,7 @@ class SendForwardMsgHandler(BaseEventHandler):
|
||||
handler_description: str = "发送合并转发消息"
|
||||
weight: int = 100
|
||||
intercept_message: bool = False
|
||||
init_subscribe = [NapcatEvent.MESSAGE.SEND_FORWARD_MSG]
|
||||
init_subscribe: ClassVar[list] = [NapcatEvent.MESSAGE.SEND_FORWARD_MSG]
|
||||
|
||||
async def execute(self, params: dict):
|
||||
raw = params.get("raw", {})
|
||||
@@ -849,7 +850,7 @@ class SendGroupAiRecordHandler(BaseEventHandler):
|
||||
handler_description: str = "发送群AI语音"
|
||||
weight: int = 100
|
||||
intercept_message: bool = False
|
||||
init_subscribe = [NapcatEvent.MESSAGE.SEND_GROUP_AI_RECORD]
|
||||
init_subscribe: ClassVar[list] = [NapcatEvent.MESSAGE.SEND_GROUP_AI_RECORD]
|
||||
|
||||
async def execute(self, params: dict):
|
||||
raw = params.get("raw", {})
|
||||
@@ -881,7 +882,7 @@ class GetGroupInfoHandler(BaseEventHandler):
|
||||
handler_description: str = "获取群信息"
|
||||
weight: int = 100
|
||||
intercept_message: bool = False
|
||||
init_subscribe = [NapcatEvent.GROUP.GET_GROUP_INFO]
|
||||
init_subscribe: ClassVar[list] = [NapcatEvent.GROUP.GET_GROUP_INFO]
|
||||
|
||||
async def execute(self, params: dict):
|
||||
raw = params.get("raw", {})
|
||||
@@ -908,7 +909,7 @@ class SetGroupAddOptionHandler(BaseEventHandler):
|
||||
handler_description: str = "设置群添加选项"
|
||||
weight: int = 100
|
||||
intercept_message: bool = False
|
||||
init_subscribe = [NapcatEvent.GROUP.SET_GROUP_ADD_OPTION]
|
||||
init_subscribe: ClassVar[list] = [NapcatEvent.GROUP.SET_GROUP_ADD_OPTION]
|
||||
|
||||
async def execute(self, params: dict):
|
||||
raw = params.get("raw", {})
|
||||
@@ -946,7 +947,7 @@ class SetGroupKickMembersHandler(BaseEventHandler):
|
||||
handler_description: str = "批量踢出群成员"
|
||||
weight: int = 100
|
||||
intercept_message: bool = False
|
||||
init_subscribe = [NapcatEvent.GROUP.SET_GROUP_KICK_MEMBERS]
|
||||
init_subscribe: ClassVar[list] = [NapcatEvent.GROUP.SET_GROUP_KICK_MEMBERS]
|
||||
|
||||
async def execute(self, params: dict):
|
||||
raw = params.get("raw", {})
|
||||
@@ -977,7 +978,7 @@ class SetGroupRemarkHandler(BaseEventHandler):
|
||||
handler_description: str = "设置群备注"
|
||||
weight: int = 100
|
||||
intercept_message: bool = False
|
||||
init_subscribe = [NapcatEvent.GROUP.SET_GROUP_REMARK]
|
||||
init_subscribe: ClassVar[list] = [NapcatEvent.GROUP.SET_GROUP_REMARK]
|
||||
|
||||
async def execute(self, params: dict):
|
||||
raw = params.get("raw", {})
|
||||
@@ -1006,7 +1007,7 @@ class SetGroupKickHandler(BaseEventHandler):
|
||||
handler_description: str = "群踢人"
|
||||
weight: int = 100
|
||||
intercept_message: bool = False
|
||||
init_subscribe = [NapcatEvent.GROUP.SET_GROUP_KICK]
|
||||
init_subscribe: ClassVar[list] = [NapcatEvent.GROUP.SET_GROUP_KICK]
|
||||
|
||||
async def execute(self, params: dict):
|
||||
raw = params.get("raw", {})
|
||||
@@ -1037,7 +1038,7 @@ class GetGroupSystemMsgHandler(BaseEventHandler):
|
||||
handler_description: str = "获取群系统消息"
|
||||
weight: int = 100
|
||||
intercept_message: bool = False
|
||||
init_subscribe = [NapcatEvent.GROUP.GET_GROUP_SYSTEM_MSG]
|
||||
init_subscribe: ClassVar[list] = [NapcatEvent.GROUP.GET_GROUP_SYSTEM_MSG]
|
||||
|
||||
async def execute(self, params: dict):
|
||||
raw = params.get("raw", {})
|
||||
@@ -1064,7 +1065,7 @@ class SetGroupBanHandler(BaseEventHandler):
|
||||
handler_description: str = "群禁言"
|
||||
weight: int = 100
|
||||
intercept_message: bool = False
|
||||
init_subscribe = [NapcatEvent.GROUP.SET_GROUP_BAN]
|
||||
init_subscribe: ClassVar[list] = [NapcatEvent.GROUP.SET_GROUP_BAN]
|
||||
|
||||
async def execute(self, params: dict):
|
||||
raw = params.get("raw", {})
|
||||
@@ -1095,7 +1096,7 @@ class GetEssenceMsgListHandler(BaseEventHandler):
|
||||
handler_description: str = "获取群精华消息"
|
||||
weight: int = 100
|
||||
intercept_message: bool = False
|
||||
init_subscribe = [NapcatEvent.GROUP.GET_ESSENCE_MSG_LIST]
|
||||
init_subscribe: ClassVar[list] = [NapcatEvent.GROUP.GET_ESSENCE_MSG_LIST]
|
||||
|
||||
async def execute(self, params: dict):
|
||||
raw = params.get("raw", {})
|
||||
@@ -1122,7 +1123,7 @@ class SetGroupWholeBanHandler(BaseEventHandler):
|
||||
handler_description: str = "全体禁言"
|
||||
weight: int = 100
|
||||
intercept_message: bool = False
|
||||
init_subscribe = [NapcatEvent.GROUP.SET_GROUP_WHOLE_BAN]
|
||||
init_subscribe: ClassVar[list] = [NapcatEvent.GROUP.SET_GROUP_WHOLE_BAN]
|
||||
|
||||
async def execute(self, params: dict):
|
||||
raw = params.get("raw", {})
|
||||
@@ -1151,7 +1152,7 @@ class SetGroupPortraitHandler(BaseEventHandler):
|
||||
handler_description: str = "设置群头像"
|
||||
weight: int = 100
|
||||
intercept_message: bool = False
|
||||
init_subscribe = [NapcatEvent.GROUP.SET_GROUP_PORTRAINT]
|
||||
init_subscribe: ClassVar[list] = [NapcatEvent.GROUP.SET_GROUP_PORTRAINT]
|
||||
|
||||
async def execute(self, params: dict):
|
||||
raw = params.get("raw", {})
|
||||
@@ -1180,7 +1181,7 @@ class SetGroupAdminHandler(BaseEventHandler):
|
||||
handler_description: str = "设置群管理"
|
||||
weight: int = 100
|
||||
intercept_message: bool = False
|
||||
init_subscribe = [NapcatEvent.GROUP.SET_GROUP_ADMIN]
|
||||
init_subscribe: ClassVar[list] = [NapcatEvent.GROUP.SET_GROUP_ADMIN]
|
||||
|
||||
async def execute(self, params: dict):
|
||||
raw = params.get("raw", {})
|
||||
@@ -1211,7 +1212,7 @@ class SetGroupCardHandler(BaseEventHandler):
|
||||
handler_description: str = "设置群成员名片"
|
||||
weight: int = 100
|
||||
intercept_message: bool = False
|
||||
init_subscribe = [NapcatEvent.GROUP.SET_GROUP_CARD]
|
||||
init_subscribe: ClassVar[list] = [NapcatEvent.GROUP.SET_GROUP_CARD]
|
||||
|
||||
async def execute(self, params: dict):
|
||||
raw = params.get("raw", {})
|
||||
@@ -1245,7 +1246,7 @@ class SetEssenceMsgHandler(BaseEventHandler):
|
||||
handler_description: str = "设置群精华消息"
|
||||
weight: int = 100
|
||||
intercept_message: bool = False
|
||||
init_subscribe = [NapcatEvent.GROUP.SET_ESSENCE_MSG]
|
||||
init_subscribe: ClassVar[list] = [NapcatEvent.GROUP.SET_ESSENCE_MSG]
|
||||
|
||||
async def execute(self, params: dict):
|
||||
raw = params.get("raw", {})
|
||||
@@ -1272,7 +1273,7 @@ class SetGroupNameHandler(BaseEventHandler):
|
||||
handler_description: str = "设置群名"
|
||||
weight: int = 100
|
||||
intercept_message: bool = False
|
||||
init_subscribe = [NapcatEvent.GROUP.SET_GROUP_NAME]
|
||||
init_subscribe: ClassVar[list] = [NapcatEvent.GROUP.SET_GROUP_NAME]
|
||||
|
||||
async def execute(self, params: dict):
|
||||
raw = params.get("raw", {})
|
||||
@@ -1301,7 +1302,7 @@ class DeleteEssenceMsgHandler(BaseEventHandler):
|
||||
handler_description: str = "删除群精华消息"
|
||||
weight: int = 100
|
||||
intercept_message: bool = False
|
||||
init_subscribe = [NapcatEvent.GROUP.DELETE_ESSENCE_MSG]
|
||||
init_subscribe: ClassVar[list] = [NapcatEvent.GROUP.DELETE_ESSENCE_MSG]
|
||||
|
||||
async def execute(self, params: dict):
|
||||
raw = params.get("raw", {})
|
||||
@@ -1328,7 +1329,7 @@ class SetGroupLeaveHandler(BaseEventHandler):
|
||||
handler_description: str = "退群"
|
||||
weight: int = 100
|
||||
intercept_message: bool = False
|
||||
init_subscribe = [NapcatEvent.GROUP.SET_GROUP_LEAVE]
|
||||
init_subscribe: ClassVar[list] = [NapcatEvent.GROUP.SET_GROUP_LEAVE]
|
||||
|
||||
async def execute(self, params: dict):
|
||||
raw = params.get("raw", {})
|
||||
@@ -1355,7 +1356,7 @@ class SendGroupNoticeHandler(BaseEventHandler):
|
||||
handler_description: str = "发送群公告"
|
||||
weight: int = 100
|
||||
intercept_message: bool = False
|
||||
init_subscribe = [NapcatEvent.GROUP.SEND_GROUP_NOTICE]
|
||||
init_subscribe: ClassVar[list] = [NapcatEvent.GROUP.SEND_GROUP_NOTICE]
|
||||
|
||||
async def execute(self, params: dict):
|
||||
raw = params.get("raw", {})
|
||||
@@ -1389,7 +1390,7 @@ class SetGroupSpecialTitleHandler(BaseEventHandler):
|
||||
handler_description: str = "设置群头衔"
|
||||
weight: int = 100
|
||||
intercept_message: bool = False
|
||||
init_subscribe = [NapcatEvent.GROUP.SET_GROUP_SPECIAL_TITLE]
|
||||
init_subscribe: ClassVar[list] = [NapcatEvent.GROUP.SET_GROUP_SPECIAL_TITLE]
|
||||
|
||||
async def execute(self, params: dict):
|
||||
raw = params.get("raw", {})
|
||||
@@ -1423,7 +1424,7 @@ class GetGroupNoticeHandler(BaseEventHandler):
|
||||
handler_description: str = "获取群公告"
|
||||
weight: int = 100
|
||||
intercept_message: bool = False
|
||||
init_subscribe = [NapcatEvent.GROUP.GET_GROUP_NOTICE]
|
||||
init_subscribe: ClassVar[list] = [NapcatEvent.GROUP.GET_GROUP_NOTICE]
|
||||
|
||||
async def execute(self, params: dict):
|
||||
raw = params.get("raw", {})
|
||||
@@ -1450,7 +1451,7 @@ class SetGroupAddRequestHandler(BaseEventHandler):
|
||||
handler_description: str = "处理加群请求"
|
||||
weight: int = 100
|
||||
intercept_message: bool = False
|
||||
init_subscribe = [NapcatEvent.GROUP.SET_GROUP_ADD_REQUEST]
|
||||
init_subscribe: ClassVar[list] = [NapcatEvent.GROUP.SET_GROUP_ADD_REQUEST]
|
||||
|
||||
async def execute(self, params: dict):
|
||||
raw = params.get("raw", {})
|
||||
@@ -1484,7 +1485,7 @@ class GetGroupListHandler(BaseEventHandler):
|
||||
handler_description: str = "获取群列表"
|
||||
weight: int = 100
|
||||
intercept_message: bool = False
|
||||
init_subscribe = [NapcatEvent.GROUP.GET_GROUP_LIST]
|
||||
init_subscribe: ClassVar[list] = [NapcatEvent.GROUP.GET_GROUP_LIST]
|
||||
|
||||
async def execute(self, params: dict):
|
||||
raw = params.get("raw", {})
|
||||
@@ -1507,7 +1508,7 @@ class DeleteGroupNoticeHandler(BaseEventHandler):
|
||||
handler_description: str = "删除群公告"
|
||||
weight: int = 100
|
||||
intercept_message: bool = False
|
||||
init_subscribe = [NapcatEvent.GROUP.DELETE_GROUP_NOTICE]
|
||||
init_subscribe: ClassVar[list] = [NapcatEvent.GROUP.DELETE_GROUP_NOTICE]
|
||||
|
||||
async def execute(self, params: dict):
|
||||
raw = params.get("raw", {})
|
||||
@@ -1536,7 +1537,7 @@ class GetGroupMemberInfoHandler(BaseEventHandler):
|
||||
handler_description: str = "获取群成员信息"
|
||||
weight: int = 100
|
||||
intercept_message: bool = False
|
||||
init_subscribe = [NapcatEvent.GROUP.GET_GROUP_MEMBER_INFO]
|
||||
init_subscribe: ClassVar[list] = [NapcatEvent.GROUP.GET_GROUP_MEMBER_INFO]
|
||||
|
||||
async def execute(self, params: dict):
|
||||
raw = params.get("raw", {})
|
||||
@@ -1567,7 +1568,7 @@ class GetGroupMemberListHandler(BaseEventHandler):
|
||||
handler_description: str = "获取群成员列表"
|
||||
weight: int = 100
|
||||
intercept_message: bool = False
|
||||
init_subscribe = [NapcatEvent.GROUP.GET_GROUP_MEMBER_LIST]
|
||||
init_subscribe: ClassVar[list] = [NapcatEvent.GROUP.GET_GROUP_MEMBER_LIST]
|
||||
|
||||
async def execute(self, params: dict):
|
||||
raw = params.get("raw", {})
|
||||
@@ -1596,7 +1597,7 @@ class GetGroupHonorInfoHandler(BaseEventHandler):
|
||||
handler_description: str = "获取群荣誉"
|
||||
weight: int = 100
|
||||
intercept_message: bool = False
|
||||
init_subscribe = [NapcatEvent.GROUP.GET_GROUP_HONOR_INFO]
|
||||
init_subscribe: ClassVar[list] = [NapcatEvent.GROUP.GET_GROUP_HONOR_INFO]
|
||||
|
||||
async def execute(self, params: dict):
|
||||
raw = params.get("raw", {})
|
||||
@@ -1628,7 +1629,7 @@ class GetGroupInfoExHandler(BaseEventHandler):
|
||||
handler_description: str = "获取群信息ex"
|
||||
weight: int = 100
|
||||
intercept_message: bool = False
|
||||
init_subscribe = [NapcatEvent.GROUP.GET_GROUP_INFO_EX]
|
||||
init_subscribe: ClassVar[list] = [NapcatEvent.GROUP.GET_GROUP_INFO_EX]
|
||||
|
||||
async def execute(self, params: dict):
|
||||
raw = params.get("raw", {})
|
||||
@@ -1655,7 +1656,7 @@ class GetGroupAtAllRemainHandler(BaseEventHandler):
|
||||
handler_description: str = "获取群 @全体成员 剩余次数"
|
||||
weight: int = 100
|
||||
intercept_message: bool = False
|
||||
init_subscribe = [NapcatEvent.GROUP.GET_GROUP_AT_ALL_REMAIN]
|
||||
init_subscribe: ClassVar[list] = [NapcatEvent.GROUP.GET_GROUP_AT_ALL_REMAIN]
|
||||
|
||||
async def execute(self, params: dict):
|
||||
raw = params.get("raw", {})
|
||||
@@ -1682,7 +1683,7 @@ class GetGroupShutListHandler(BaseEventHandler):
|
||||
handler_description: str = "获取群禁言列表"
|
||||
weight: int = 100
|
||||
intercept_message: bool = False
|
||||
init_subscribe = [NapcatEvent.GROUP.GET_GROUP_SHUT_LIST]
|
||||
init_subscribe: ClassVar[list] = [NapcatEvent.GROUP.GET_GROUP_SHUT_LIST]
|
||||
|
||||
async def execute(self, params: dict):
|
||||
raw = params.get("raw", {})
|
||||
@@ -1709,7 +1710,7 @@ class GetGroupIgnoredNotifiesHandler(BaseEventHandler):
|
||||
handler_description: str = "获取群过滤系统消息"
|
||||
weight: int = 100
|
||||
intercept_message: bool = False
|
||||
init_subscribe = [NapcatEvent.GROUP.GET_GROUP_IGNORED_NOTIFIES]
|
||||
init_subscribe: ClassVar[list] = [NapcatEvent.GROUP.GET_GROUP_IGNORED_NOTIFIES]
|
||||
|
||||
async def execute(self, params: dict):
|
||||
payload = {}
|
||||
@@ -1726,7 +1727,7 @@ class SetGroupSignHandler(BaseEventHandler):
|
||||
handler_description: str = "群打卡"
|
||||
weight: int = 100
|
||||
intercept_message: bool = False
|
||||
init_subscribe = [NapcatEvent.GROUP.SET_GROUP_SIGN]
|
||||
init_subscribe: ClassVar[list] = [NapcatEvent.GROUP.SET_GROUP_SIGN]
|
||||
|
||||
async def execute(self, params: dict):
|
||||
raw = params.get("raw", {})
|
||||
@@ -1754,7 +1755,7 @@ class SetInputStatusHandler(BaseEventHandler):
|
||||
handler_description: str = "设置输入状态"
|
||||
weight: int = 100
|
||||
intercept_message: bool = False
|
||||
init_subscribe = [NapcatEvent.PERSONAL.SET_INPUT_STATUS]
|
||||
init_subscribe: ClassVar[list] = [NapcatEvent.PERSONAL.SET_INPUT_STATUS]
|
||||
|
||||
async def execute(self, params: dict):
|
||||
raw = params.get("raw", {})
|
||||
|
||||
@@ -1,25 +1,24 @@
|
||||
import asyncio
|
||||
import json
|
||||
import inspect
|
||||
import json
|
||||
from typing import ClassVar, List
|
||||
|
||||
import websockets as Server
|
||||
from . import event_types, CONSTS, event_handlers
|
||||
|
||||
from typing import List
|
||||
|
||||
from src.plugin_system import BasePlugin, BaseEventHandler, register_plugin, EventType, ConfigField
|
||||
from src.plugin_system.core.event_manager import event_manager
|
||||
from src.plugin_system.apis import config_api
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.plugin_system import BaseEventHandler, BasePlugin, ConfigField, EventType, register_plugin
|
||||
from src.plugin_system.apis import config_api
|
||||
from src.plugin_system.core.event_manager import event_manager
|
||||
|
||||
from . import CONSTS, event_handlers, event_types
|
||||
from .src.message_chunker import chunker, reassembler
|
||||
from .src.mmc_com_layer import mmc_start_com, mmc_stop_com, router
|
||||
from .src.recv_handler.message_handler import message_handler
|
||||
from .src.recv_handler.message_sending import message_send_instance
|
||||
from .src.recv_handler.meta_event_handler import meta_event_handler
|
||||
from .src.recv_handler.notice_handler import notice_handler
|
||||
from .src.recv_handler.message_sending import message_send_instance
|
||||
from .src.response_pool import check_timeout_response, put_response
|
||||
from .src.send_handler import send_handler
|
||||
from .src.mmc_com_layer import mmc_start_com, router, mmc_stop_com
|
||||
from .src.response_pool import put_response, check_timeout_response
|
||||
from .src.websocket_manager import websocket_manager
|
||||
|
||||
logger = get_logger("napcat_adapter")
|
||||
@@ -219,7 +218,7 @@ class LauchNapcatAdapterHandler(BaseEventHandler):
|
||||
handler_description: str = "自动启动napcat adapter"
|
||||
weight: int = 100
|
||||
intercept_message: bool = False
|
||||
init_subscribe = [EventType.ON_START]
|
||||
init_subscribe: ClassVar[list] = [EventType.ON_START]
|
||||
|
||||
async def execute(self, kwargs):
|
||||
# 启动消息重组器的清理任务
|
||||
@@ -267,7 +266,7 @@ class StopNapcatAdapterHandler(BaseEventHandler):
|
||||
handler_description: str = "关闭napcat adapter"
|
||||
weight: int = 100
|
||||
intercept_message: bool = False
|
||||
init_subscribe = [EventType.ON_STOP]
|
||||
init_subscribe: ClassVar[list] = [EventType.ON_STOP]
|
||||
|
||||
async def execute(self, kwargs):
|
||||
await graceful_shutdown()
|
||||
@@ -277,8 +276,8 @@ class StopNapcatAdapterHandler(BaseEventHandler):
|
||||
@register_plugin
|
||||
class NapcatAdapterPlugin(BasePlugin):
|
||||
plugin_name = CONSTS.PLUGIN_NAME
|
||||
dependencies: List[str] = [] # 插件依赖列表
|
||||
python_dependencies: List[str] = [] # Python包依赖列表
|
||||
dependencies: ClassVar[List[str]] = [] # 插件依赖列表
|
||||
python_dependencies: ClassVar[List[str]] = [] # Python包依赖列表
|
||||
config_file_name: str = "config.toml" # 配置文件名
|
||||
|
||||
@property
|
||||
@@ -291,10 +290,10 @@ class NapcatAdapterPlugin(BasePlugin):
|
||||
return False
|
||||
|
||||
# 配置节描述
|
||||
config_section_descriptions = {"plugin": "插件基本信息"}
|
||||
config_section_descriptions: ClassVar[dict] = {"plugin": "插件基本信息"}
|
||||
|
||||
# 配置Schema定义
|
||||
config_schema: dict = {
|
||||
config_schema: ClassVar[dict] = {
|
||||
"plugin": {
|
||||
"name": ConfigField(type=str, default="napcat_adapter_plugin", description="插件名称"),
|
||||
"version": ConfigField(type=str, default="1.1.0", description="插件版本"),
|
||||
@@ -389,7 +388,7 @@ class NapcatAdapterPlugin(BasePlugin):
|
||||
}
|
||||
|
||||
# 配置节描述
|
||||
config_section_descriptions = {
|
||||
config_section_descriptions: ClassVar[dict] = {
|
||||
"plugin": "插件基本信息",
|
||||
"inner": "内部配置信息(请勿修改)",
|
||||
"nickname": "昵称配置(目前未使用)",
|
||||
@@ -421,9 +420,11 @@ class NapcatAdapterPlugin(BasePlugin):
|
||||
components = []
|
||||
components.append((LauchNapcatAdapterHandler.get_handler_info(), LauchNapcatAdapterHandler))
|
||||
components.append((StopNapcatAdapterHandler.get_handler_info(), StopNapcatAdapterHandler))
|
||||
for handler in get_classes_in_module(event_handlers):
|
||||
if issubclass(handler, BaseEventHandler):
|
||||
components.append((handler.get_handler_info(), handler))
|
||||
components.extend(
|
||||
(handler.get_handler_info(), handler)
|
||||
for handler in get_classes_in_module(event_handlers)
|
||||
if issubclass(handler, BaseEventHandler)
|
||||
)
|
||||
return components
|
||||
|
||||
async def on_plugin_loaded(self):
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
from enum import Enum
|
||||
import tomlkit
|
||||
import os
|
||||
from enum import Enum
|
||||
|
||||
import tomlkit
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("napcat_adapter")
|
||||
|
||||
@@ -13,9 +13,9 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, List, Sequence
|
||||
from typing import List, Optional, Sequence
|
||||
|
||||
from sqlalchemy import Column, Integer, BigInteger, UniqueConstraint, select, Index
|
||||
from sqlalchemy import BigInteger, Column, Index, Integer, UniqueConstraint, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from src.common.database.sqlalchemy_models import Base, get_db_session
|
||||
|
||||
@@ -4,14 +4,14 @@
|
||||
仅在 Ada -> MMC 方向进行切片,其他方向(MMC -> Ada,Ada <-> Napcat)不切片
|
||||
"""
|
||||
|
||||
import json
|
||||
import uuid
|
||||
import asyncio
|
||||
import json
|
||||
import time
|
||||
from typing import List, Dict, Any, Optional, Union
|
||||
from src.plugin_system.apis import config_api
|
||||
import uuid
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.plugin_system.apis import config_api
|
||||
|
||||
logger = get_logger("napcat_adapter")
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from maim_message import Router, RouteConfig, TargetConfig
|
||||
from maim_message import RouteConfig, Router, TargetConfig
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.common.server import get_global_server
|
||||
|
||||
@@ -1,45 +1,43 @@
|
||||
from ...event_types import NapcatEvent
|
||||
from src.plugin_system.core.event_manager import event_manager
|
||||
import base64
|
||||
import json
|
||||
import time
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import websockets as Server
|
||||
from maim_message import (
|
||||
BaseMessageInfo,
|
||||
FormatInfo,
|
||||
GroupInfo,
|
||||
MessageBase,
|
||||
Seg,
|
||||
TemplateInfo,
|
||||
UserInfo,
|
||||
)
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from ...CONSTS import PLUGIN_NAME
|
||||
|
||||
logger = get_logger("napcat_adapter")
|
||||
|
||||
from src.plugin_system.apis import config_api
|
||||
from src.plugin_system.core.event_manager import event_manager
|
||||
|
||||
from ...CONSTS import PLUGIN_NAME
|
||||
from ...event_types import NapcatEvent
|
||||
from ..response_pool import get_response
|
||||
from ..utils import (
|
||||
get_group_info,
|
||||
get_member_info,
|
||||
get_image_base64,
|
||||
get_member_info,
|
||||
get_message_detail,
|
||||
get_record_detail,
|
||||
get_self_info,
|
||||
get_message_detail,
|
||||
)
|
||||
from .qq_emoji_list import qq_face
|
||||
from .message_sending import message_send_instance
|
||||
from . import RealMessageType, MessageType, ACCEPT_FORMAT
|
||||
from ..video_handler import get_video_downloader
|
||||
from ..websocket_manager import websocket_manager
|
||||
from . import ACCEPT_FORMAT, MessageType, RealMessageType
|
||||
from .message_sending import message_send_instance
|
||||
from .qq_emoji_list import qq_face
|
||||
|
||||
import time
|
||||
import json
|
||||
import websockets as Server
|
||||
import base64
|
||||
from pathlib import Path
|
||||
from typing import List, Tuple, Optional, Dict, Any
|
||||
import uuid
|
||||
|
||||
from maim_message import (
|
||||
UserInfo,
|
||||
GroupInfo,
|
||||
Seg,
|
||||
BaseMessageInfo,
|
||||
MessageBase,
|
||||
TemplateInfo,
|
||||
FormatInfo,
|
||||
)
|
||||
|
||||
|
||||
from ..response_pool import get_response
|
||||
logger = get_logger("napcat_adapter")
|
||||
|
||||
|
||||
class MessageHandler:
|
||||
|
||||
@@ -1,11 +1,13 @@
|
||||
import asyncio
|
||||
|
||||
from maim_message import MessageBase, Router
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from ..message_chunker import chunker
|
||||
from src.plugin_system.apis import config_api
|
||||
|
||||
from ..message_chunker import chunker
|
||||
|
||||
logger = get_logger("napcat_adapter")
|
||||
from maim_message import MessageBase, Router
|
||||
|
||||
|
||||
class MessageSending:
|
||||
|
||||
@@ -1,12 +1,13 @@
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("napcat_adapter")
|
||||
from src.plugin_system.apis import config_api
|
||||
import time
|
||||
import asyncio
|
||||
import time
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.plugin_system.apis import config_api
|
||||
|
||||
from . import MetaEventType
|
||||
|
||||
logger = get_logger("napcat_adapter")
|
||||
|
||||
|
||||
class MetaEventHandler:
|
||||
"""
|
||||
|
||||
@@ -1,21 +1,16 @@
|
||||
import time
|
||||
import json
|
||||
import asyncio
|
||||
import json
|
||||
import time
|
||||
from typing import ClassVar, Optional, Tuple
|
||||
|
||||
import websockets as Server
|
||||
from typing import Tuple, Optional
|
||||
from maim_message import BaseMessageInfo, FormatInfo, GroupInfo, MessageBase, Seg, UserInfo
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("napcat_adapter")
|
||||
|
||||
from src.plugin_system.apis import config_api
|
||||
from ..database import BanUser, napcat_db, is_identical
|
||||
from . import NoticeType, ACCEPT_FORMAT
|
||||
from .message_sending import message_send_instance
|
||||
from .message_handler import message_handler
|
||||
from maim_message import FormatInfo, UserInfo, GroupInfo, Seg, BaseMessageInfo, MessageBase
|
||||
from ..websocket_manager import websocket_manager
|
||||
|
||||
from ...CONSTS import PLUGIN_NAME, QQ_FACE
|
||||
from ..database import BanUser, is_identical, napcat_db
|
||||
from ..utils import (
|
||||
get_group_info,
|
||||
get_member_info,
|
||||
@@ -23,16 +18,20 @@ from ..utils import (
|
||||
get_stranger_info,
|
||||
read_ban_list,
|
||||
)
|
||||
from ..websocket_manager import websocket_manager
|
||||
from . import ACCEPT_FORMAT, NoticeType
|
||||
from .message_handler import message_handler
|
||||
from .message_sending import message_send_instance
|
||||
|
||||
from ...CONSTS import PLUGIN_NAME, QQ_FACE
|
||||
logger = get_logger("napcat_adapter")
|
||||
|
||||
notice_queue: asyncio.Queue[MessageBase] = asyncio.Queue(maxsize=100)
|
||||
unsuccessful_notice_queue: asyncio.Queue[MessageBase] = asyncio.Queue(maxsize=3)
|
||||
|
||||
|
||||
class NoticeHandler:
|
||||
banned_list: list[BanUser] = [] # 当前仍在禁言中的用户列表
|
||||
lifted_list: list[BanUser] = [] # 已经自然解除禁言
|
||||
banned_list: ClassVar[list[BanUser]] = [] # 当前仍在禁言中的用户列表
|
||||
lifted_list: ClassVar[list[BanUser]] = [] # 已经自然解除禁言
|
||||
|
||||
def __init__(self):
|
||||
self.server_connection: Server.ServerConnection | None = None
|
||||
@@ -131,6 +130,7 @@ class NoticeHandler:
|
||||
logger.warning("戳一戳消息被禁用,取消戳一戳处理")
|
||||
case NoticeType.Notify.input_status:
|
||||
from src.plugin_system.core.event_manager import event_manager
|
||||
|
||||
from ...event_types import NapcatEvent
|
||||
|
||||
await event_manager.trigger_event(NapcatEvent.ON_RECEIVED.FRIEND_INPUT, permission_group=PLUGIN_NAME)
|
||||
@@ -357,6 +357,7 @@ class NoticeHandler:
|
||||
logger.debug("无法获取表情回复对方的用户昵称")
|
||||
|
||||
from src.plugin_system.core.event_manager import event_manager
|
||||
|
||||
from ...event_types import NapcatEvent
|
||||
|
||||
target_message = await event_manager.trigger_event(NapcatEvent.MESSAGE.GET_MSG,message_id=raw_message.get("message_id",""))
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import asyncio
|
||||
import time
|
||||
from typing import Dict
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.plugin_system.apis import config_api
|
||||
|
||||
|
||||
@@ -1,26 +1,28 @@
|
||||
import json
|
||||
import time
|
||||
import random
|
||||
import websockets as Server
|
||||
import time
|
||||
import uuid
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
|
||||
import websockets as Server
|
||||
from maim_message import (
|
||||
UserInfo,
|
||||
GroupInfo,
|
||||
Seg,
|
||||
BaseMessageInfo,
|
||||
GroupInfo,
|
||||
MessageBase,
|
||||
Seg,
|
||||
UserInfo,
|
||||
)
|
||||
from typing import Dict, Any, Tuple, Optional
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.plugin_system.apis import config_api
|
||||
|
||||
from . import CommandType
|
||||
from .recv_handler.message_sending import message_send_instance
|
||||
from .response_pool import get_response
|
||||
from src.common.logger import get_logger
|
||||
from .utils import convert_image_to_gif, get_image_format
|
||||
from .websocket_manager import websocket_manager
|
||||
|
||||
logger = get_logger("napcat_adapter")
|
||||
from .utils import get_image_format, convert_image_to_gif
|
||||
from .recv_handler.message_sending import message_send_instance
|
||||
from .websocket_manager import websocket_manager
|
||||
|
||||
|
||||
class SendHandler:
|
||||
@@ -546,7 +548,7 @@ class SendHandler:
|
||||
set_like = bool(args["set"])
|
||||
except (KeyError, ValueError) as e:
|
||||
logger.error(f"处理表情回应命令时发生错误: {e}, 原始参数: {args}")
|
||||
raise ValueError(f"缺少必需参数或参数类型错误: {e}")
|
||||
raise ValueError(f"缺少必需参数或参数类型错误: {e}") from e
|
||||
|
||||
return (
|
||||
CommandType.SET_EMOJI_LIKE.value,
|
||||
@@ -566,8 +568,8 @@ class SendHandler:
|
||||
try:
|
||||
user_id: int = int(args["qq_id"])
|
||||
times: int = int(args["times"])
|
||||
except (KeyError, ValueError):
|
||||
raise ValueError("缺少必需参数: qq_id 或 times")
|
||||
except (KeyError, ValueError) as e:
|
||||
raise ValueError("缺少必需参数: qq_id 或 times") from e
|
||||
|
||||
return (
|
||||
CommandType.SEND_LIKE.value,
|
||||
|
||||
@@ -1,19 +1,20 @@
|
||||
import websockets as Server
|
||||
import json
|
||||
import base64
|
||||
import uuid
|
||||
import urllib3
|
||||
import ssl
|
||||
import io
|
||||
import json
|
||||
import ssl
|
||||
import uuid
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import urllib3
|
||||
import websockets as Server
|
||||
from PIL import Image
|
||||
|
||||
from .database import BanUser, napcat_db
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("napcat_adapter")
|
||||
from .database import BanUser, napcat_db
|
||||
from .response_pool import get_response
|
||||
|
||||
from PIL import Image
|
||||
from typing import Union, List, Tuple, Optional
|
||||
logger = get_logger("napcat_adapter")
|
||||
|
||||
|
||||
class SSLAdapter(urllib3.PoolManager):
|
||||
|
||||
@@ -5,10 +5,12 @@
|
||||
用于从QQ消息中下载视频并转发给Bot进行分析
|
||||
"""
|
||||
|
||||
import aiohttp
|
||||
import asyncio
|
||||
from pathlib import Path
|
||||
from typing import Optional, Dict, Any
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import aiohttp
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("video_handler")
|
||||
@@ -34,20 +36,20 @@ class VideoDownloader:
|
||||
if any(keyword in url_lower for keyword in video_keywords):
|
||||
return True
|
||||
|
||||
# 检查文件扩展名(传统方法)
|
||||
# 检查文件扩展名(传统方法)
|
||||
path = Path(url.split("?")[0]) # 移除查询参数
|
||||
if path.suffix.lower() in self.supported_formats:
|
||||
return True
|
||||
|
||||
# 对于QQ等特殊平台,URL可能没有扩展名
|
||||
# 我们允许这些URL通过,稍后通过HTTP头Content-Type验证
|
||||
# 对于QQ等特殊平台,URL可能没有扩展名
|
||||
# 我们允许这些URL通过,稍后通过HTTP头Content-Type验证
|
||||
qq_domains = ["qpic.cn", "gtimg.cn", "qq.com", "tencent.com"]
|
||||
if any(domain in url_lower for domain in qq_domains):
|
||||
return True
|
||||
|
||||
return False
|
||||
except:
|
||||
# 如果解析失败,默认允许尝试下载(稍后验证)
|
||||
except Exception:
|
||||
# 如果解析失败,默认允许尝试下载(稍后验证)
|
||||
return True
|
||||
|
||||
def check_file_size(self, content_length: Optional[str]) -> bool:
|
||||
@@ -59,7 +61,7 @@ class VideoDownloader:
|
||||
size_bytes = int(content_length)
|
||||
size_mb = size_bytes / (1024 * 1024)
|
||||
return size_mb <= self.max_size_mb
|
||||
except:
|
||||
except Exception:
|
||||
return True
|
||||
|
||||
async def download_video(self, url: str, filename: Optional[str] = None) -> Dict[str, Any]:
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
import asyncio
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
import websockets as Server
|
||||
from typing import Optional, Callable, Any
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.plugin_system.apis import config_api
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@
|
||||
"""
|
||||
|
||||
import re
|
||||
from typing import ClassVar
|
||||
|
||||
from src.plugin_system.apis.logging_api import get_logger
|
||||
from src.plugin_system.apis.permission_api import permission_api
|
||||
@@ -29,7 +30,7 @@ class PermissionCommand(PlusCommand):
|
||||
|
||||
command_name = "permission"
|
||||
command_description = "权限管理命令,支持授权、撤销、查询等功能"
|
||||
command_aliases = ["perm", "权限"]
|
||||
command_aliases: ClassVar[list[str]] = ["perm", "权限"]
|
||||
priority = 10
|
||||
chat_type_allow = ChatType.ALL
|
||||
intercept_message = True
|
||||
@@ -37,7 +38,7 @@ class PermissionCommand(PlusCommand):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
permission_nodes: list[PermissionNodeField] = [
|
||||
permission_nodes: ClassVar[list[PermissionNodeField]] = [
|
||||
PermissionNodeField(
|
||||
node_name="manage",
|
||||
description="权限管理:可以授权和撤销其他用户的权限",
|
||||
@@ -382,10 +383,10 @@ class PermissionCommand(PlusCommand):
|
||||
class PermissionManagerPlugin(BasePlugin):
|
||||
plugin_name: str = "permission_manager_plugin"
|
||||
enable_plugin: bool = True
|
||||
dependencies: list[str] = []
|
||||
python_dependencies: list[str] = []
|
||||
dependencies: ClassVar[list[str]] = []
|
||||
python_dependencies: ClassVar[list[str]] = []
|
||||
config_file_name: str = "config.toml"
|
||||
config_schema: dict = {
|
||||
config_schema: ClassVar[dict] = {
|
||||
"plugin": {
|
||||
"enabled": ConfigField(bool, default=True, description="是否启用插件"),
|
||||
"config_version": ConfigField(type=str, default="1.1.0", description="配置文件版本"),
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import asyncio
|
||||
from typing import ClassVar
|
||||
|
||||
from src.plugin_system import (
|
||||
BasePlugin,
|
||||
@@ -21,7 +22,7 @@ class ManagementCommand(PlusCommand):
|
||||
|
||||
command_name = "pm"
|
||||
command_description = "插件管理命令,支持插件和组件的管理操作"
|
||||
command_aliases = ["pluginmanage", "插件管理"]
|
||||
command_aliases: ClassVar[list[str]] = ["pluginmanage", "插件管理"]
|
||||
priority = 10
|
||||
chat_type_allow = ChatType.ALL
|
||||
intercept_message = True
|
||||
@@ -273,6 +274,7 @@ class ManagementCommand(PlusCommand):
|
||||
def _fetch_all_registered_components() -> list[ComponentInfo]:
|
||||
all_plugin_info = component_manage_api.get_all_plugin_info()
|
||||
if not all_plugin_info:
|
||||
|
||||
return []
|
||||
|
||||
components_info: list[ComponentInfo] = []
|
||||
@@ -486,10 +488,10 @@ class ManagementCommand(PlusCommand):
|
||||
class PluginManagementPlugin(BasePlugin):
|
||||
plugin_name: str = "plugin_management_plugin"
|
||||
enable_plugin: bool = True
|
||||
dependencies: list[str] = []
|
||||
python_dependencies: list[str] = []
|
||||
dependencies: ClassVar[list[str]] = []
|
||||
python_dependencies: ClassVar[list[str]] = []
|
||||
config_file_name: str = "config.toml"
|
||||
config_schema: dict = {
|
||||
config_schema: ClassVar[dict] = {
|
||||
"plugin": {
|
||||
"enabled": ConfigField(bool, default=False, description="是否启用插件"),
|
||||
"config_version": ConfigField(type=str, default="1.1.0", description="配置文件版本"),
|
||||
|
||||
@@ -152,12 +152,12 @@ class PokeAction(BaseAction):
|
||||
parallel_action = True
|
||||
|
||||
# === 功能描述(必须填写)===
|
||||
action_parameters = {
|
||||
action_parameters: ClassVar[dict] = {
|
||||
"user_name": "需要戳一戳的用户的名字 (可选)",
|
||||
"user_id": "需要戳一戳的用户的ID (可选,优先级更高)",
|
||||
"times": "需要戳一戳的次数 (默认为 1)",
|
||||
}
|
||||
action_require = ["当需要戳某个用户时使用", "当你想提醒特定用户时使用"]
|
||||
action_require: ClassVar[list] = ["当需要戳某个用户时使用", "当你想提醒特定用户时使用"]
|
||||
llm_judge_prompt = """
|
||||
判定是否需要使用戳一戳动作的条件:
|
||||
1. **互动时机**: 这是一个有趣的互动方式,可以在想提醒某人,或者单纯想开个玩笑时使用。
|
||||
@@ -167,7 +167,7 @@ class PokeAction(BaseAction):
|
||||
|
||||
请根据上述规则,回答“是”或“否”。
|
||||
"""
|
||||
associated_types = ["text"]
|
||||
associated_types: ClassVar[list[str]] = ["text"]
|
||||
|
||||
async def execute(self) -> tuple[bool, str]:
|
||||
"""执行戳一戳的动作"""
|
||||
@@ -225,10 +225,10 @@ class SetEmojiLikeAction(BaseAction):
|
||||
parallel_action = True
|
||||
|
||||
# === 功能描述(必须填写)===
|
||||
action_parameters = {
|
||||
action_parameters: ClassVar[dict] = {
|
||||
"set": "是否设置回应 (True/False)",
|
||||
}
|
||||
action_require = [
|
||||
action_require: ClassVar[list] = [
|
||||
"当需要对一个已存在消息进行‘贴表情’回应时使用",
|
||||
"这是一个对旧消息的操作,而不是发送新消息",
|
||||
]
|
||||
@@ -240,10 +240,10 @@ class SetEmojiLikeAction(BaseAction):
|
||||
|
||||
请回答"是"或"否"。
|
||||
"""
|
||||
associated_types = ["text"]
|
||||
associated_types: ClassVar[list[str]] = ["text"]
|
||||
|
||||
# 重新启用完整的表情库
|
||||
emoji_options = []
|
||||
emoji_options: ClassVar[list] = []
|
||||
for name in qq_face.values():
|
||||
match = re.search(r"\[表情:(.+?)\]", name)
|
||||
if match:
|
||||
@@ -359,14 +359,14 @@ class RemindAction(BaseAction):
|
||||
action_name = "set_reminder"
|
||||
action_description = "根据用户的对话内容,智能地设置一个未来的提醒事项。"
|
||||
activation_type = ActionActivationType.KEYWORD
|
||||
activation_keywords = ["提醒", "叫我", "记得", "别忘了"]
|
||||
activation_keywords: ClassVar[list[str]] = ["提醒", "叫我", "记得", "别忘了"]
|
||||
chat_type_allow = ChatType.ALL
|
||||
parallel_action = True
|
||||
|
||||
# === LLM 判断与参数提取 ===
|
||||
llm_judge_prompt = ""
|
||||
action_parameters = {}
|
||||
action_require = [
|
||||
action_parameters: ClassVar[dict] = {}
|
||||
action_require: ClassVar[list] = [
|
||||
"当用户请求在未来的某个时间点提醒他/她或别人某件事时使用",
|
||||
"适用于包含明确时间信息和事件描述的对话",
|
||||
"例如:'10分钟后提醒我收快递'、'明天早上九点喊一下李四参加晨会'",
|
||||
@@ -545,12 +545,12 @@ class SetEmojiLikePlugin(BasePlugin):
|
||||
# 插件基本信息
|
||||
plugin_name: str = "social_toolkit_plugin" # 内部标识符
|
||||
enable_plugin: bool = True
|
||||
dependencies: list[str] = [] # 插件依赖列表
|
||||
python_dependencies: list[str] = [] # Python包依赖列表,现在使用内置API
|
||||
dependencies: ClassVar[list[str]] = [] # 插件依赖列表
|
||||
python_dependencies: ClassVar[list[str]] = [] # Python包依赖列表,现在使用内置API
|
||||
config_file_name: str = "config.toml" # 配置文件名
|
||||
|
||||
# 配置节描述
|
||||
config_section_descriptions = {"plugin": "插件基本信息", "components": "插件组件"}
|
||||
config_section_descriptions: ClassVar[dict] = {"plugin": "插件基本信息", "components": "插件组件"}
|
||||
|
||||
# 配置Schema定义
|
||||
config_schema: ClassVar[dict] = {
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import asyncio
|
||||
from typing import ClassVar
|
||||
|
||||
import whisper
|
||||
|
||||
@@ -19,7 +20,7 @@ class LocalASRTool(BaseTool):
|
||||
"""
|
||||
tool_name = "local_asr"
|
||||
tool_description = "将本地音频文件路径转换为文字。"
|
||||
tool_parameters = [
|
||||
tool_parameters: ClassVar[list] = [
|
||||
{"name": "audio_path", "type": "string", "description": "需要识别的音频文件路径", "required": True}
|
||||
]
|
||||
|
||||
@@ -50,6 +51,7 @@ class LocalASRTool(BaseTool):
|
||||
async def execute(self, function_args: dict) -> str:
|
||||
audio_path = function_args.get("audio_path")
|
||||
if not audio_path:
|
||||
|
||||
return "错误:缺少 audio_path 参数。"
|
||||
|
||||
global _whisper_model
|
||||
@@ -78,7 +80,7 @@ class LocalASRTool(BaseTool):
|
||||
class STTWhisperPlugin(BasePlugin):
|
||||
plugin_name = "stt_whisper_plugin"
|
||||
config_file_name = "config.toml"
|
||||
python_dependencies = ["openai-whisper"]
|
||||
python_dependencies: ClassVar[list[str]] = ["openai-whisper"]
|
||||
|
||||
async def on_plugin_loaded(self):
|
||||
"""
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
from typing import ClassVar
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.plugin_system.apis.plugin_register_api import register_plugin
|
||||
from src.plugin_system.base.base_action import ActionActivationType, BaseAction, ChatMode
|
||||
@@ -22,16 +24,16 @@ class TTSAction(BaseAction):
|
||||
action_description = "将文本转换为语音进行播放,适用于需要语音输出的场景"
|
||||
|
||||
# 关键词配置 - Normal模式下使用关键词触发
|
||||
activation_keywords = ["语音", "tts", "播报", "读出来", "语音播放", "听", "朗读"]
|
||||
activation_keywords: ClassVar[list[str]] = ["语音", "tts", "播报", "读出来", "语音播放", "听", "朗读"]
|
||||
keyword_case_sensitive = False
|
||||
|
||||
# 动作参数定义
|
||||
action_parameters = {
|
||||
action_parameters: ClassVar[dict] = {
|
||||
"text": "需要转换为语音的文本内容,必填,内容应当适合语音播报,语句流畅、清晰",
|
||||
}
|
||||
|
||||
# 动作使用场景
|
||||
action_require = [
|
||||
action_require: ClassVar[list] = [
|
||||
"当需要发送语音信息时使用",
|
||||
"当用户要求你说话时使用",
|
||||
"当用户要求听你声音时使用",
|
||||
@@ -41,7 +43,7 @@ class TTSAction(BaseAction):
|
||||
]
|
||||
|
||||
# 关联类型
|
||||
associated_types = ["tts_text"]
|
||||
associated_types: ClassVar[list[str]] = ["tts_text"]
|
||||
|
||||
async def execute(self) -> tuple[bool, str]:
|
||||
"""处理TTS文本转语音动作"""
|
||||
@@ -111,19 +113,19 @@ class TTSPlugin(BasePlugin):
|
||||
# 插件基本信息
|
||||
plugin_name: str = "tts_plugin" # 内部标识符
|
||||
enable_plugin: bool = True
|
||||
dependencies: list[str] = [] # 插件依赖列表
|
||||
python_dependencies: list[str] = [] # Python包依赖列表
|
||||
dependencies: ClassVar[list[str]] = [] # 插件依赖列表
|
||||
python_dependencies: ClassVar[list[str]] = [] # Python包依赖列表
|
||||
config_file_name: str = "config.toml"
|
||||
|
||||
# 配置节描述
|
||||
config_section_descriptions = {
|
||||
config_section_descriptions: ClassVar[dict] = {
|
||||
"plugin": "插件基本信息配置",
|
||||
"components": "组件启用控制",
|
||||
"logging": "日志记录相关配置",
|
||||
}
|
||||
|
||||
# 配置Schema定义
|
||||
config_schema: dict = {
|
||||
config_schema: ClassVar[dict] = {
|
||||
"plugin": {
|
||||
"name": ConfigField(type=str, default="tts_plugin", description="插件名称", required=True),
|
||||
"version": ConfigField(type=str, default="0.1.0", description="插件版本号"),
|
||||
|
||||
@@ -3,6 +3,7 @@ TTS 语音合成 Action
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
from typing import ClassVar
|
||||
|
||||
import toml
|
||||
|
||||
@@ -32,6 +33,7 @@ def _get_available_styles() -> list[str]:
|
||||
|
||||
styles_config = config.get("tts_styles", [])
|
||||
if not isinstance(styles_config, list):
|
||||
|
||||
return ["default"]
|
||||
|
||||
# 使用显式循环和类型检查来提取 style_name,以确保 Pylance 类型检查通过
|
||||
@@ -65,7 +67,7 @@ class TTSVoiceAction(BaseAction):
|
||||
mode_enable = ChatMode.ALL
|
||||
parallel_action = False
|
||||
|
||||
action_parameters = {
|
||||
action_parameters: ClassVar[dict] = {
|
||||
"text": {
|
||||
"type": "string",
|
||||
"description": "需要转换为语音并发送的完整、自然、适合口语的文本内容。",
|
||||
@@ -97,7 +99,7 @@ class TTSVoiceAction(BaseAction):
|
||||
}
|
||||
}
|
||||
|
||||
action_require = [
|
||||
action_require: ClassVar[list] = [
|
||||
"在调用此动作时,你必须在 'text' 参数中提供要合成语音的完整回复内容。这是强制性的。",
|
||||
"当用户明确请求使用语音进行回复时,例如‘发个语音听听’、‘用语音说’等。",
|
||||
"当对话内容适合用语音表达,例如讲故事、念诗、撒嬌或进行角色扮演时。",
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
"""
|
||||
TTS 语音合成命令
|
||||
"""
|
||||
from typing import ClassVar
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.plugin_system.base.command_args import CommandArgs
|
||||
from src.plugin_system.base.plus_command import PlusCommand
|
||||
@@ -18,7 +20,7 @@ class TTSVoiceCommand(PlusCommand):
|
||||
|
||||
command_name: str = "tts"
|
||||
command_description: str = "使用GPT-SoVITS将文本转换为语音并发送"
|
||||
command_aliases = ["语音合成", "说"]
|
||||
command_aliases: ClassVar[list[str]] = ["语音合成", "说"]
|
||||
command_usage = "/tts <要说的文本> [风格]"
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
TTS Voice 插件 - 重构版
|
||||
"""
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from typing import Any, ClassVar
|
||||
|
||||
import toml
|
||||
|
||||
@@ -29,15 +29,15 @@ class TTSVoicePlugin(BasePlugin):
|
||||
plugin_version = "3.1.2"
|
||||
plugin_author = "Kilo Code & 靚仔"
|
||||
config_file_name = "config.toml"
|
||||
dependencies = []
|
||||
dependencies: ClassVar[list[str]] = []
|
||||
|
||||
permission_nodes: list[PermissionNodeField] = [
|
||||
permission_nodes: ClassVar[list[PermissionNodeField]] = [
|
||||
PermissionNodeField(node_name="command.use", description="是否可以使用 /tts 命令"),
|
||||
]
|
||||
|
||||
config_schema = {}
|
||||
config_schema: ClassVar[dict] = {}
|
||||
|
||||
config_section_descriptions = {
|
||||
config_section_descriptions: ClassVar[dict] = {
|
||||
"plugin": "插件基本配置",
|
||||
"components": "组件启用控制",
|
||||
"tts": "TTS语音合成基础配置",
|
||||
|
||||
@@ -67,10 +67,14 @@ class TTSService:
|
||||
logger.warning("TTS 'default' style is missing 'refer_wav_path'.")
|
||||
|
||||
for style_cfg in tts_styles_config:
|
||||
if not isinstance(style_cfg, dict): continue
|
||||
if not isinstance(style_cfg, dict):
|
||||
|
||||
continue
|
||||
|
||||
style_name = style_cfg.get("style_name")
|
||||
if not style_name: continue
|
||||
if not style_name:
|
||||
|
||||
continue
|
||||
|
||||
styles[style_name] = {
|
||||
"url": global_server,
|
||||
@@ -158,7 +162,9 @@ class TTSService:
|
||||
|
||||
# --- 步骤一:像稳定版一样,先切换模型 ---
|
||||
async def switch_model_weights(weights_path: str | None, weight_type: str):
|
||||
if not weights_path: return
|
||||
if not weights_path:
|
||||
|
||||
return
|
||||
api_endpoint = f"/set_{weight_type}_weights"
|
||||
switch_url = f"{base_url}{api_endpoint}"
|
||||
try:
|
||||
@@ -220,6 +226,7 @@ class TTSService:
|
||||
try:
|
||||
effects_config = self.get_config("spatial_effects", {})
|
||||
if not effects_config.get("enabled", False):
|
||||
|
||||
return audio_data
|
||||
|
||||
# 获取插件目录和IR文件路径
|
||||
@@ -251,6 +258,8 @@ class TTSService:
|
||||
logger.warning(f"卷积混响已启用,但IR文件不存在 ({ir_path}),跳过该效果。")
|
||||
|
||||
if not effects:
|
||||
|
||||
|
||||
return audio_data
|
||||
|
||||
# 将原始音频数据加载到内存中的 AudioFile 对象
|
||||
@@ -293,7 +302,9 @@ class TTSService:
|
||||
|
||||
server_config = self.tts_styles[style]
|
||||
clean_text = self._clean_text_for_tts(text)
|
||||
if not clean_text: return None
|
||||
if not clean_text:
|
||||
|
||||
return None
|
||||
|
||||
# 语言决策流程:
|
||||
# 1. 优先使用决策模型直接指定的 language_hint (最高优先级)
|
||||
|
||||
@@ -42,6 +42,7 @@ class TavilySearchEngine(BaseSearchEngine):
|
||||
async def search(self, args: dict[str, Any]) -> list[dict[str, Any]]:
|
||||
"""执行Tavily搜索"""
|
||||
if not self.is_available():
|
||||
|
||||
return []
|
||||
|
||||
query = args["query"]
|
||||
@@ -76,14 +77,14 @@ class TavilySearchEngine(BaseSearchEngine):
|
||||
|
||||
results = []
|
||||
if search_response and "results" in search_response:
|
||||
for res in search_response["results"]:
|
||||
results.append(
|
||||
results.extend(
|
||||
{
|
||||
"title": res.get("title", "无标题"),
|
||||
"url": res.get("url", ""),
|
||||
"snippet": res.get("content", "")[:300] + "..." if res.get("content") else "无摘要",
|
||||
"provider": "Tavily",
|
||||
}
|
||||
for res in search_response["results"]
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
@@ -4,6 +4,8 @@ Web Search Tool Plugin
|
||||
一个功能强大的网络搜索和URL解析插件,支持多种搜索引擎和解析策略。
|
||||
"""
|
||||
|
||||
from typing import ClassVar
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.plugin_system import BasePlugin, ComponentInfo, ConfigField, register_plugin
|
||||
from src.plugin_system.apis import config_api
|
||||
@@ -30,7 +32,7 @@ class WEBSEARCHPLUGIN(BasePlugin):
|
||||
# 插件基本信息
|
||||
plugin_name: str = "web_search_tool" # 内部标识符
|
||||
enable_plugin: bool = True
|
||||
dependencies: list[str] = [] # 插件依赖列表
|
||||
dependencies: ClassVar[list[str]] = [] # 插件依赖列表
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
"""初始化插件,立即加载所有搜索引擎"""
|
||||
@@ -77,11 +79,11 @@ class WEBSEARCHPLUGIN(BasePlugin):
|
||||
config_file_name: str = "config.toml" # 配置文件名
|
||||
|
||||
# 配置节描述
|
||||
config_section_descriptions = {"plugin": "插件基本信息", "proxy": "链接本地解析代理配置"}
|
||||
config_section_descriptions: ClassVar[dict] = {"plugin": "插件基本信息", "proxy": "链接本地解析代理配置"}
|
||||
|
||||
# 配置Schema定义
|
||||
# 注意:EXA配置和组件设置已迁移到主配置文件(bot_config.toml)的[exa]和[web_search]部分
|
||||
config_schema: dict = {
|
||||
config_schema: ClassVar[dict] = {
|
||||
"plugin": {
|
||||
"name": ConfigField(type=str, default="WEB_SEARCH_PLUGIN", description="插件名称"),
|
||||
"version": ConfigField(type=str, default="1.0.0", description="插件版本"),
|
||||
|
||||
@@ -4,7 +4,7 @@ URL parser tool implementation
|
||||
|
||||
import asyncio
|
||||
import functools
|
||||
from typing import Any
|
||||
from typing import Any, ClassVar
|
||||
|
||||
import httpx
|
||||
from bs4 import BeautifulSoup
|
||||
@@ -30,7 +30,7 @@ class URLParserTool(BaseTool):
|
||||
name: str = "parse_url"
|
||||
description: str = "当需要理解一个或多个特定网页链接的内容时,使用此工具。例如:'这些网页讲了什么?[https://example.com, https://example2.com]' 或 '帮我总结一下这些文章'"
|
||||
available_for_llm: bool = True
|
||||
parameters = [
|
||||
parameters: ClassVar[list] = [
|
||||
("urls", ToolParamType.STRING, "要理解的网站", True, None),
|
||||
]
|
||||
|
||||
@@ -93,6 +93,8 @@ class URLParserTool(BaseTool):
|
||||
text = soup.get_text(strip=True)
|
||||
|
||||
if not text:
|
||||
|
||||
|
||||
return {"error": "无法从页面提取有效文本内容。"}
|
||||
|
||||
summary_prompt = f"请根据以下网页内容,生成一段不超过300字的中文摘要,保留核心信息和关键点:\n\n---\n\n标题: {title}\n\n内容:\n{text[:4000]}\n\n---\n\n摘要:"
|
||||
@@ -144,16 +146,19 @@ class URLParserTool(BaseTool):
|
||||
|
||||
urls_input = function_args.get("urls")
|
||||
if not urls_input:
|
||||
|
||||
return {"error": "URL列表不能为空。"}
|
||||
|
||||
# 处理URL输入,确保是列表格式
|
||||
urls = parse_urls_from_input(urls_input)
|
||||
if not urls:
|
||||
|
||||
return {"error": "提供的字符串中未找到有效的URL。"}
|
||||
|
||||
# 验证URL格式
|
||||
valid_urls = validate_urls(urls)
|
||||
if not valid_urls:
|
||||
|
||||
return {"error": "未找到有效的URL。"}
|
||||
|
||||
urls = valid_urls
|
||||
@@ -226,6 +231,8 @@ class URLParserTool(BaseTool):
|
||||
successful_results.append(res)
|
||||
|
||||
if not successful_results:
|
||||
|
||||
|
||||
return {"error": "无法从所有给定的URL获取内容。", "details": error_messages}
|
||||
|
||||
formatted_content = format_url_parse_results(successful_results)
|
||||
|
||||
@@ -3,7 +3,7 @@ Web search tool implementation
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from typing import Any
|
||||
from typing import Any, ClassVar
|
||||
|
||||
from src.common.cache_manager import tool_cache
|
||||
from src.common.logger import get_logger
|
||||
@@ -31,7 +31,7 @@ class WebSurfingTool(BaseTool):
|
||||
"用于执行网络搜索。当用户明确要求搜索,或者需要获取关于公司、产品、事件的最新信息、新闻或动态时,必须使用此工具"
|
||||
)
|
||||
available_for_llm: bool = True
|
||||
parameters = [
|
||||
parameters: ClassVar[list] = [
|
||||
("query", ToolParamType.STRING, "要搜索的关键词或问题。", True, None),
|
||||
("num_results", ToolParamType.INTEGER, "期望每个搜索引擎返回的搜索结果数量,默认为5。", False, None),
|
||||
(
|
||||
@@ -58,6 +58,7 @@ class WebSurfingTool(BaseTool):
|
||||
async def execute(self, function_args: dict[str, Any]) -> dict[str, Any]:
|
||||
query = function_args.get("query")
|
||||
if not query:
|
||||
|
||||
return {"error": "搜索查询不能为空。"}
|
||||
|
||||
# 获取当前文件路径用于缓存键
|
||||
@@ -105,6 +106,8 @@ class WebSurfingTool(BaseTool):
|
||||
search_tasks.append(engine.search(custom_args))
|
||||
|
||||
if not search_tasks:
|
||||
|
||||
|
||||
return {"error": "没有可用的搜索引擎。"}
|
||||
|
||||
try:
|
||||
@@ -137,6 +140,7 @@ class WebSurfingTool(BaseTool):
|
||||
for engine_name in enabled_engines:
|
||||
engine = self.engines.get(engine_name)
|
||||
if not engine or not engine.is_available():
|
||||
|
||||
continue
|
||||
|
||||
try:
|
||||
@@ -163,6 +167,7 @@ class WebSurfingTool(BaseTool):
|
||||
for engine_name in enabled_engines:
|
||||
engine = self.engines.get(engine_name)
|
||||
if not engine or not engine.is_available():
|
||||
|
||||
continue
|
||||
|
||||
try:
|
||||
|
||||
@@ -33,10 +33,10 @@ class APIKeyManager(Generic[T]):
|
||||
|
||||
if api_keys:
|
||||
# 过滤有效的API密钥,排除None、空字符串、"None"字符串等
|
||||
valid_keys = []
|
||||
for key in api_keys:
|
||||
if isinstance(key, str) and key.strip() and key.strip().lower() not in ("none", "null", ""):
|
||||
valid_keys.append(key.strip())
|
||||
valid_keys = [
|
||||
key.strip() for key in api_keys
|
||||
if isinstance(key, str) and key.strip() and key.strip().lower() not in ("none", "null", "")
|
||||
]
|
||||
|
||||
if valid_keys:
|
||||
try:
|
||||
@@ -59,6 +59,7 @@ class APIKeyManager(Generic[T]):
|
||||
def get_next_client(self) -> T | None:
|
||||
"""获取下一个客户端(轮询)"""
|
||||
if not self.is_available():
|
||||
|
||||
return None
|
||||
return next(self.client_cycle)
|
||||
|
||||
|
||||
@@ -32,8 +32,4 @@ def validate_urls(urls: list[str]) -> list[str]:
|
||||
"""
|
||||
验证URL格式,返回有效的URL列表
|
||||
"""
|
||||
valid_urls = []
|
||||
for url in urls:
|
||||
if url.startswith(("http://", "https://")):
|
||||
valid_urls.append(url)
|
||||
return valid_urls
|
||||
return [url for url in urls if url.startswith(("http://", "https://"))]
|
||||
|
||||
@@ -102,7 +102,7 @@ class UILogHandler(logging.Handler):
|
||||
emoji_map = {"info": "📝", "warning": "⚠️", "error": "❌", "debug": "🔍"}
|
||||
formatted_msg = f"{emoji_map.get(ui_level, '📝')} {msg}"
|
||||
|
||||
success = self._send_log_with_retry(formatted_msg, ui_level)
|
||||
self._send_log_with_retry(formatted_msg, ui_level)
|
||||
# 可选:记录发送状态
|
||||
# if not success:
|
||||
# print(f"[UI日志适配器] 日志发送失败: {ui_level} - {formatted_msg[:50]}...")
|
||||
|
||||
Reference in New Issue
Block a user