Merge branch 'dev' of https://github.com/MoFox-Studio/MoFox_Bot into dev
This commit is contained in:
@@ -85,6 +85,7 @@
|
||||
- 🧠 **拓展记忆系统** - 支持瞬时记忆等多种记忆
|
||||
- 🎪 **完善的 Event** - 支持动态事件注册和处理器订阅,并实现了聚合结果管理
|
||||
- 🔍 **内嵌魔改插件** - 内置联网搜索等诸多功能,等你来探索
|
||||
- 🔌 **MCP 协议支持** - 集成 Model Context Protocol,支持外部工具服务器连接(仅 Streamable HTTP)
|
||||
- 🌟 **还有更多** - 请参阅详细修改 [commits](https://github.com/MoFox-Studio/MoFox_Bot/commits)
|
||||
|
||||
</td>
|
||||
|
||||
@@ -19,8 +19,8 @@ from src.plugin_system import (
|
||||
ToolParamType,
|
||||
register_plugin,
|
||||
)
|
||||
from src.plugin_system.base.component_types import InjectionRule,InjectionType
|
||||
from src.plugin_system.base.base_event import HandlerResult
|
||||
from src.plugin_system.base.component_types import InjectionRule, InjectionType
|
||||
|
||||
logger = get_logger("hello_world_plugin")
|
||||
|
||||
|
||||
@@ -74,6 +74,7 @@ dependencies = [
|
||||
"aiosqlite>=0.21.0",
|
||||
"inkfox>=0.1.1",
|
||||
"rjieba>=0.1.13",
|
||||
"fastmcp>=2.13.0",
|
||||
]
|
||||
|
||||
[[tool.uv.index]]
|
||||
@@ -110,6 +111,7 @@ ignore = [
|
||||
"RUF001", # ambiguous-unicode-character-string
|
||||
"RUF002", # ambiguous-unicode-character-docstring
|
||||
"RUF003", # ambiguous-unicode-character-comment
|
||||
"PERF203", # try-except-in-loop (我们需要单独处理每个项的错误)
|
||||
]
|
||||
|
||||
|
||||
|
||||
@@ -9,6 +9,7 @@ customtkinter
|
||||
dotenv
|
||||
faiss-cpu
|
||||
fastapi
|
||||
fastmcp
|
||||
rjieba
|
||||
jsonlines
|
||||
maim_message
|
||||
|
||||
59
scripts/debug_mcp_tools.py
Normal file
59
scripts/debug_mcp_tools.py
Normal file
@@ -0,0 +1,59 @@
|
||||
"""
|
||||
调试 MCP 工具列表获取
|
||||
|
||||
直接测试 MCP 客户端是否能获取工具
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
|
||||
from fastmcp.client import Client, StreamableHttpTransport
|
||||
|
||||
|
||||
async def test_direct_connection():
|
||||
"""直接连接 MCP 服务器并获取工具列表"""
|
||||
print("=" * 60)
|
||||
print("直接测试 MCP 服务器连接")
|
||||
print("=" * 60)
|
||||
|
||||
url = "http://localhost:8000/mcp"
|
||||
print(f"\n连接到: {url}")
|
||||
|
||||
try:
|
||||
# 创建传输层
|
||||
transport = StreamableHttpTransport(url)
|
||||
print("✓ 传输层创建成功")
|
||||
|
||||
# 创建客户端
|
||||
async with Client(transport) as client:
|
||||
print("✓ 客户端连接成功")
|
||||
|
||||
# 获取工具列表
|
||||
print("\n正在获取工具列表...")
|
||||
tools_result = await client.list_tools()
|
||||
|
||||
print(f"\n获取结果类型: {type(tools_result)}")
|
||||
print(f"结果内容: {tools_result}")
|
||||
|
||||
# 检查是否有 tools 属性
|
||||
if hasattr(tools_result, "tools"):
|
||||
tools = tools_result.tools
|
||||
print(f"\n✓ 找到 tools 属性,包含 {len(tools)} 个工具")
|
||||
|
||||
for i, tool in enumerate(tools, 1):
|
||||
print(f"\n工具 {i}:")
|
||||
print(f" 名称: {tool.name}")
|
||||
print(f" 描述: {tool.description}")
|
||||
if hasattr(tool, "inputSchema"):
|
||||
print(f" 参数 Schema: {tool.inputSchema}")
|
||||
else:
|
||||
print("\n✗ 结果中没有 tools 属性")
|
||||
print(f"可用属性: {dir(tools_result)}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"\n✗ 连接失败: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(test_direct_connection())
|
||||
142
scripts/simple_mcp_server.py
Normal file
142
scripts/simple_mcp_server.py
Normal file
@@ -0,0 +1,142 @@
|
||||
"""
|
||||
简单的 MCP 测试服务器
|
||||
|
||||
使用 fastmcp 创建一个简单的 MCP 服务器供测试使用。
|
||||
|
||||
安装依赖:
|
||||
pip install fastmcp uvicorn
|
||||
|
||||
运行服务器:
|
||||
python scripts/simple_mcp_server.py
|
||||
|
||||
服务器将在 http://localhost:8000/mcp 提供 MCP 服务
|
||||
"""
|
||||
|
||||
from fastmcp import FastMCP
|
||||
|
||||
# 创建 MCP 服务器实例
|
||||
mcp = FastMCP("Demo Server")
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
def add(a: int, b: int) -> int:
|
||||
"""将两个数字相加
|
||||
|
||||
Args:
|
||||
a: 第一个数字
|
||||
b: 第二个数字
|
||||
|
||||
Returns:
|
||||
两个数字的和
|
||||
"""
|
||||
return a + b
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
def multiply(a: float, b: float) -> float:
|
||||
"""将两个数字相乘
|
||||
|
||||
Args:
|
||||
a: 第一个数字
|
||||
b: 第二个数字
|
||||
|
||||
Returns:
|
||||
两个数字的乘积
|
||||
"""
|
||||
return a * b
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
def get_weather(city: str) -> str:
|
||||
"""获取指定城市的天气信息(模拟)
|
||||
|
||||
Args:
|
||||
city: 城市名称
|
||||
|
||||
Returns:
|
||||
天气信息字符串
|
||||
"""
|
||||
# 这是一个模拟实现
|
||||
weather_data = {
|
||||
"beijing": "北京今天晴朗,温度 20°C",
|
||||
"shanghai": "上海今天多云,温度 18°C",
|
||||
"guangzhou": "广州今天有雨,温度 25°C",
|
||||
}
|
||||
|
||||
city_lower = city.lower()
|
||||
return weather_data.get(
|
||||
city_lower,
|
||||
f"{city}的天气信息暂不可用"
|
||||
)
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
def echo(message: str, repeat: int = 1) -> str:
|
||||
"""重复输出一条消息
|
||||
|
||||
Args:
|
||||
message: 要重复的消息
|
||||
repeat: 重复次数,默认为 1
|
||||
|
||||
Returns:
|
||||
重复后的消息
|
||||
"""
|
||||
return (message + "\n") * repeat
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
def check_prime(number: int) -> bool:
|
||||
"""检查一个数字是否为质数
|
||||
|
||||
Args:
|
||||
number: 要检查的数字
|
||||
|
||||
Returns:
|
||||
如果是质数返回 True,否则返回 False
|
||||
"""
|
||||
if number < 2:
|
||||
return False
|
||||
|
||||
for i in range(2, int(number ** 0.5) + 1):
|
||||
if number % i == 0:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
print("=" * 60)
|
||||
print("简单 MCP 测试服务器")
|
||||
print("=" * 60)
|
||||
print("\n服务器配置:")
|
||||
print(" - 传输协议: Streamable HTTP")
|
||||
print(" - 地址: http://localhost:8000/mcp")
|
||||
print("\n可用工具:")
|
||||
print(" 1. add(a: int, b: int) -> int")
|
||||
print(" 2. multiply(a: float, b: float) -> float")
|
||||
print(" 3. get_weather(city: str) -> str")
|
||||
print(" 4. echo(message: str, repeat: int = 1) -> str")
|
||||
print(" 5. check_prime(number: int) -> bool")
|
||||
print("\n配置示例 (config/mcp.json):")
|
||||
print("""
|
||||
{
|
||||
"$schema": "./mcp.schema.json",
|
||||
"mcpServers": {
|
||||
"demo_server": {
|
||||
"enabled": true,
|
||||
"transport": {
|
||||
"type": "streamable-http",
|
||||
"url": "http://localhost:8000/mcp"
|
||||
},
|
||||
"timeout": 30
|
||||
}
|
||||
}
|
||||
}
|
||||
""")
|
||||
print("=" * 60)
|
||||
print("\n正在启动服务器...")
|
||||
print("请参考 fastmcp 官方文档了解如何运行此服务器。")
|
||||
print("文档: https://github.com/jlowin/fastmcp")
|
||||
print("\n基本命令:")
|
||||
print(" fastmcp run simple_mcp_server:mcp")
|
||||
print("=" * 60)
|
||||
15
scripts/test/demo_mcp_server.py
Normal file
15
scripts/test/demo_mcp_server.py
Normal file
@@ -0,0 +1,15 @@
|
||||
from fastmcp import FastMCP
|
||||
|
||||
app = FastMCP(
|
||||
name="Demo MCP Server",
|
||||
streamable_http_path="/mcp"
|
||||
)
|
||||
|
||||
@app.tool()
|
||||
async def echo_tool(input: str) -> str:
|
||||
"""一个简单的回声工具"""
|
||||
return f"Echo: {input}"
|
||||
|
||||
if __name__ == "__main__":
|
||||
app.run(host="0.0.0.0", port=8000, transport="streamable-http"
|
||||
)
|
||||
190
scripts/test_mcp_integration.py
Normal file
190
scripts/test_mcp_integration.py
Normal file
@@ -0,0 +1,190 @@
|
||||
"""
|
||||
MCP 集成测试脚本
|
||||
|
||||
测试 MCP 客户端连接、工具列表获取和工具调用功能
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
# 添加项目根目录到 Python 路径
|
||||
project_root = Path(__file__).parent.parent
|
||||
sys.path.insert(0, str(project_root))
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.plugin_system.core.component_registry import ComponentRegistry
|
||||
from src.plugin_system.core.mcp_client_manager import MCPClientManager
|
||||
|
||||
logger = get_logger("test_mcp_integration")
|
||||
|
||||
|
||||
async def test_mcp_client_manager():
|
||||
"""测试 MCPClientManager 基本功能"""
|
||||
print("\n" + "="*60)
|
||||
print("测试 1: MCPClientManager 连接和工具列表")
|
||||
print("="*60)
|
||||
|
||||
try:
|
||||
# 初始化 MCP 客户端管理器
|
||||
manager = MCPClientManager()
|
||||
await manager.initialize()
|
||||
|
||||
print("\n✓ MCP 客户端管理器初始化成功")
|
||||
print(f"已连接服务器数量: {len(manager.clients)}")
|
||||
|
||||
# 获取所有工具
|
||||
tools = await manager.get_all_tools()
|
||||
print(f"\n获取到 {len(tools)} 个 MCP 工具:")
|
||||
|
||||
for tool in tools:
|
||||
print(f"\n 工具: {tool}")
|
||||
# 注意: 这里 tool 是字符串形式的工具名称
|
||||
# 如果需要工具详情,需要从其他地方获取
|
||||
|
||||
return manager, tools
|
||||
|
||||
except Exception as e:
|
||||
print(f"\n✗ 测试失败: {e}")
|
||||
logger.exception("MCPClientManager 测试失败")
|
||||
return None, []
|
||||
|
||||
|
||||
async def test_tool_call(manager: MCPClientManager, tools):
|
||||
"""测试工具调用"""
|
||||
print("\n" + "="*60)
|
||||
print("测试 2: MCP 工具调用")
|
||||
print("="*60)
|
||||
|
||||
if not tools:
|
||||
print("\n⚠ 没有可用的工具进行测试")
|
||||
return
|
||||
|
||||
try:
|
||||
# 工具列表测试已在第一个测试中完成
|
||||
print("\n✓ 工具列表获取成功")
|
||||
print(f"可用工具数量: {len(tools)}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"\n✗ 工具调用测试失败: {e}")
|
||||
logger.exception("工具调用测试失败")
|
||||
|
||||
|
||||
async def test_component_registry_integration():
|
||||
"""测试 ComponentRegistry 集成"""
|
||||
print("\n" + "="*60)
|
||||
print("测试 3: ComponentRegistry MCP 工具集成")
|
||||
print("="*60)
|
||||
|
||||
try:
|
||||
registry = ComponentRegistry()
|
||||
|
||||
# 加载 MCP 工具
|
||||
await registry.load_mcp_tools()
|
||||
|
||||
# 获取 MCP 工具
|
||||
mcp_tools = registry.get_mcp_tools()
|
||||
print(f"\n✓ ComponentRegistry 加载了 {len(mcp_tools)} 个 MCP 工具")
|
||||
|
||||
for tool in mcp_tools:
|
||||
print(f"\n 工具: {tool.name}")
|
||||
print(f" 描述: {tool.description}")
|
||||
print(f" 参数数量: {len(tool.parameters)}")
|
||||
|
||||
# 测试 is_mcp_tool 方法
|
||||
is_mcp = registry.is_mcp_tool(tool.name)
|
||||
print(f" is_mcp_tool 检测: {'✓' if is_mcp else '✗'}")
|
||||
|
||||
return mcp_tools
|
||||
|
||||
except Exception as e:
|
||||
print(f"\n✗ ComponentRegistry 集成测试失败: {e}")
|
||||
logger.exception("ComponentRegistry 集成测试失败")
|
||||
return []
|
||||
|
||||
|
||||
async def test_tool_execution(mcp_tools):
|
||||
"""测试通过适配器执行工具"""
|
||||
print("\n" + "="*60)
|
||||
print("测试 4: MCPToolAdapter 工具执行")
|
||||
print("="*60)
|
||||
|
||||
if not mcp_tools:
|
||||
print("\n⚠ 没有可用的 MCP 工具进行测试")
|
||||
return
|
||||
|
||||
try:
|
||||
# 选择第一个工具测试
|
||||
test_tool = mcp_tools[0]
|
||||
print(f"\n测试工具: {test_tool.name}")
|
||||
|
||||
# 构建测试参数
|
||||
test_args = {}
|
||||
for param_name, param_type, param_desc, is_required, enum_values in test_tool.parameters:
|
||||
if is_required:
|
||||
# 根据类型提供默认值
|
||||
from src.llm_models.payload_content.tool_option import ToolParamType
|
||||
|
||||
if param_type == ToolParamType.STRING:
|
||||
test_args[param_name] = "test_value"
|
||||
elif param_type == ToolParamType.INTEGER:
|
||||
test_args[param_name] = 1
|
||||
elif param_type == ToolParamType.FLOAT:
|
||||
test_args[param_name] = 1.0
|
||||
elif param_type == ToolParamType.BOOLEAN:
|
||||
test_args[param_name] = True
|
||||
|
||||
print(f"测试参数: {test_args}")
|
||||
|
||||
# 执行工具
|
||||
result = await test_tool.execute(test_args)
|
||||
|
||||
if result:
|
||||
print("\n✓ 工具执行成功")
|
||||
print(f"结果类型: {result.get('type')}")
|
||||
print(f"结果内容: {result.get('content', '')[:200]}...") # 只显示前200字符
|
||||
else:
|
||||
print("\n✗ 工具执行失败,返回 None")
|
||||
|
||||
except Exception as e:
|
||||
print(f"\n✗ 工具执行测试失败: {e}")
|
||||
logger.exception("工具执行测试失败")
|
||||
|
||||
|
||||
async def main():
|
||||
"""主测试流程"""
|
||||
print("\n" + "="*60)
|
||||
print("MCP 集成测试")
|
||||
print("="*60)
|
||||
|
||||
try:
|
||||
# 测试 1: MCPClientManager 基本功能
|
||||
manager, tools = await test_mcp_client_manager()
|
||||
|
||||
if manager:
|
||||
# 测试 2: 工具调用
|
||||
await test_tool_call(manager, tools)
|
||||
|
||||
# 测试 3: ComponentRegistry 集成
|
||||
mcp_tools = await test_component_registry_integration()
|
||||
|
||||
# 测试 4: 工具执行
|
||||
await test_tool_execution(mcp_tools)
|
||||
|
||||
# 关闭连接
|
||||
await manager.close()
|
||||
print("\n✓ MCP 客户端连接已关闭")
|
||||
|
||||
print("\n" + "="*60)
|
||||
print("测试完成")
|
||||
print("="*60 + "\n")
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("\n\n测试被用户中断")
|
||||
except Exception as e:
|
||||
print(f"\n测试过程中发生错误: {e}")
|
||||
logger.exception("测试失败")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
@@ -219,4 +219,4 @@ async def get_llm_stats(
|
||||
raise e
|
||||
except Exception as e:
|
||||
logger.error(f"获取LLM统计信息失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@@ -878,7 +878,7 @@ class MemorySystem:
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"检索瞬时记忆失败: {e}", exc_info=True)
|
||||
|
||||
|
||||
# 最终截断
|
||||
final_memories = final_memories[:effective_limit]
|
||||
|
||||
|
||||
@@ -72,4 +72,4 @@ class MessageCollectionProcessor:
|
||||
"active_buffers": len(self.message_buffers),
|
||||
"total_buffered_messages": total_buffered_messages,
|
||||
"buffer_capacity_per_chat": self.buffer_size,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,7 +3,6 @@
|
||||
专用于存储和检索消息集合,以提供即时上下文。
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
@@ -125,7 +124,7 @@ class MessageCollectionStorage:
|
||||
if results and results.get("ids") and results["ids"][0]:
|
||||
for metadata in results["metadatas"][0]:
|
||||
collections.append(MessageCollection.from_dict(metadata))
|
||||
|
||||
|
||||
return collections
|
||||
except Exception as e:
|
||||
logger.error(f"检索相关消息集合失败: {e}", exc_info=True)
|
||||
@@ -163,7 +162,7 @@ class MessageCollectionStorage:
|
||||
|
||||
# 格式化消息集合为 prompt 上下文
|
||||
final_context = "\n\n---\n\n".join(context_parts) + "\n\n---"
|
||||
|
||||
|
||||
logger.info(f"🔗 为查询 '{query_text[:50]}...' 在聊天 '{chat_id}' 中找到 {len(collections)} 个相关消息集合上下文")
|
||||
return f"\n{final_context}\n"
|
||||
|
||||
@@ -192,4 +191,4 @@ class MessageCollectionStorage:
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"获取消息集合存储统计失败: {e}")
|
||||
return {}
|
||||
return {}
|
||||
|
||||
@@ -659,6 +659,41 @@ class ChatBot:
|
||||
group_name = getattr(group_info, "group_name", None)
|
||||
group_platform = getattr(group_info, "platform", None)
|
||||
|
||||
# 准备 additional_config,将 format_info 嵌入其中
|
||||
additional_config_str = None
|
||||
try:
|
||||
import orjson
|
||||
|
||||
additional_config_data = {}
|
||||
|
||||
# 首先获取adapter传递的additional_config
|
||||
if hasattr(message_info, 'additional_config') and message_info.additional_config:
|
||||
if isinstance(message_info.additional_config, dict):
|
||||
additional_config_data = message_info.additional_config.copy()
|
||||
elif isinstance(message_info.additional_config, str):
|
||||
try:
|
||||
additional_config_data = orjson.loads(message_info.additional_config)
|
||||
except Exception as e:
|
||||
logger.warning(f"无法解析 additional_config JSON: {e}")
|
||||
additional_config_data = {}
|
||||
|
||||
# 然后添加format_info到additional_config中
|
||||
if hasattr(message_info, 'format_info') and message_info.format_info:
|
||||
try:
|
||||
format_info_dict = message_info.format_info.to_dict()
|
||||
additional_config_data["format_info"] = format_info_dict
|
||||
logger.debug(f"[bot.py] 嵌入 format_info 到 additional_config: {format_info_dict}")
|
||||
except Exception as e:
|
||||
logger.warning(f"将 format_info 转换为字典失败: {e}")
|
||||
else:
|
||||
logger.warning(f"[bot.py] [问题] 消息缺少 format_info: message_id={message_id}")
|
||||
|
||||
# 序列化为JSON字符串
|
||||
if additional_config_data:
|
||||
additional_config_str = orjson.dumps(additional_config_data).decode("utf-8")
|
||||
except Exception as e:
|
||||
logger.error(f"准备 additional_config 失败: {e}")
|
||||
|
||||
# 创建数据库消息对象
|
||||
db_message = DatabaseMessages(
|
||||
message_id=message_id,
|
||||
@@ -674,6 +709,7 @@ class ChatBot:
|
||||
is_notify=bool(message.is_notify),
|
||||
is_public_notice=bool(message.is_public_notice),
|
||||
notice_type=message.notice_type,
|
||||
additional_config=additional_config_str,
|
||||
user_id=user_id,
|
||||
user_nickname=user_nickname,
|
||||
user_cardname=user_cardname,
|
||||
|
||||
@@ -156,6 +156,13 @@ class ChatStream:
|
||||
|
||||
return instance
|
||||
|
||||
def get_raw_id(self) -> str:
|
||||
"""获取原始的、未哈希的聊天流ID字符串"""
|
||||
if self.group_info:
|
||||
return f"{self.platform}:{self.group_info.group_id}:group"
|
||||
else:
|
||||
return f"{self.platform}:{self.user_info.user_id}:private"
|
||||
|
||||
def update_active_time(self):
|
||||
"""更新最后活跃时间"""
|
||||
self.last_active_time = time.time()
|
||||
@@ -213,8 +220,8 @@ class ChatStream:
|
||||
priority_info=json.dumps(getattr(message, "priority_info", None))
|
||||
if getattr(message, "priority_info", None)
|
||||
else None,
|
||||
# 额外配置
|
||||
additional_config=getattr(message_info, "additional_config", None),
|
||||
# 额外配置 - 需要将 format_info 嵌入到 additional_config 中
|
||||
additional_config=self._prepare_additional_config(message_info),
|
||||
# 用户信息
|
||||
user_id=str(getattr(user_info, "user_id", "")),
|
||||
user_nickname=getattr(user_info, "user_nickname", ""),
|
||||
@@ -253,8 +260,59 @@ class ChatStream:
|
||||
f"interest_value: {db_message.interest_value}"
|
||||
)
|
||||
|
||||
def _prepare_additional_config(self, message_info) -> str | None:
|
||||
"""
|
||||
准备 additional_config,将 format_info 嵌入其中
|
||||
|
||||
这个方法模仿 storage.py 中的逻辑,确保 DatabaseMessages 中的 additional_config
|
||||
包含 format_info,使得 action_modifier 能够正确获取适配器支持的消息类型
|
||||
|
||||
Args:
|
||||
message_info: BaseMessageInfo 对象
|
||||
|
||||
Returns:
|
||||
str | None: JSON 字符串格式的 additional_config,如果为空则返回 None
|
||||
"""
|
||||
import orjson
|
||||
|
||||
# 首先获取adapter传递的additional_config
|
||||
additional_config_data = {}
|
||||
if hasattr(message_info, 'additional_config') and message_info.additional_config:
|
||||
if isinstance(message_info.additional_config, dict):
|
||||
additional_config_data = message_info.additional_config.copy()
|
||||
elif isinstance(message_info.additional_config, str):
|
||||
# 如果是字符串,尝试解析
|
||||
try:
|
||||
additional_config_data = orjson.loads(message_info.additional_config)
|
||||
except Exception as e:
|
||||
logger.warning(f"无法解析 additional_config JSON: {e}")
|
||||
additional_config_data = {}
|
||||
|
||||
# 然后添加format_info到additional_config中
|
||||
if hasattr(message_info, 'format_info') and message_info.format_info:
|
||||
try:
|
||||
format_info_dict = message_info.format_info.to_dict()
|
||||
additional_config_data["format_info"] = format_info_dict
|
||||
logger.debug(f"嵌入 format_info 到 additional_config: {format_info_dict}")
|
||||
except Exception as e:
|
||||
logger.warning(f"将 format_info 转换为字典失败: {e}")
|
||||
else:
|
||||
logger.warning(f"[问题] 消息缺少 format_info: message_id={getattr(message_info, 'message_id', 'unknown')}")
|
||||
logger.warning("[问题] 这可能导致 Action 无法正确检查适配器支持的类型")
|
||||
|
||||
# 序列化为JSON字符串
|
||||
if additional_config_data:
|
||||
try:
|
||||
return orjson.dumps(additional_config_data).decode("utf-8")
|
||||
except Exception as e:
|
||||
logger.error(f"序列化 additional_config 失败: {e}")
|
||||
return None
|
||||
return None
|
||||
|
||||
def _safe_get_actions(self, message: "MessageRecv") -> list | None:
|
||||
"""安全获取消息的actions字段"""
|
||||
import json
|
||||
|
||||
try:
|
||||
actions = getattr(message, "actions", None)
|
||||
if actions is None:
|
||||
@@ -263,8 +321,6 @@ class ChatStream:
|
||||
# 如果是字符串,尝试解析为JSON
|
||||
if isinstance(actions, str):
|
||||
try:
|
||||
import json
|
||||
|
||||
actions = json.loads(actions)
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(f"无法解析actions JSON字符串: {actions}")
|
||||
|
||||
@@ -9,10 +9,10 @@ from maim_message import BaseMessageInfo, MessageBase, Seg, UserInfo
|
||||
from rich.traceback import install
|
||||
|
||||
from src.chat.message_receive.chat_stream import ChatStream
|
||||
from src.chat.utils.self_voice_cache import consume_self_voice_text
|
||||
from src.chat.utils.utils_image import get_image_manager
|
||||
from src.chat.utils.utils_video import get_video_analyzer, is_video_analysis_available
|
||||
from src.chat.utils.utils_voice import get_voice_text
|
||||
from src.chat.utils.self_voice_cache import consume_self_voice_text
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
|
||||
@@ -212,7 +212,7 @@ class MessageRecv(Message):
|
||||
return f"[语音:{cached_text}]"
|
||||
else:
|
||||
logger.warning("机器人自身语音消息缓存未命中,将回退到标准语音识别。")
|
||||
|
||||
|
||||
# 标准语音识别流程 (也作为缓存未命中的后备方案)
|
||||
if isinstance(segment.data, str):
|
||||
return await get_voice_text(segment.data)
|
||||
@@ -370,7 +370,7 @@ class MessageRecvS4U(MessageRecv):
|
||||
self.is_picid = False
|
||||
self.is_emoji = False
|
||||
self.is_voice = True
|
||||
|
||||
|
||||
# 检查消息是否由机器人自己发送
|
||||
# 检查消息是否由机器人自己发送
|
||||
if self.message_info and self.message_info.user_info and str(self.message_info.user_info.user_id) == str(global_config.bot.qq_account):
|
||||
|
||||
@@ -1,487 +0,0 @@
|
||||
"""
|
||||
优化版聊天流 - 实现写时复制机制
|
||||
避免不必要的深拷贝开销,提升多流并发性能
|
||||
"""
|
||||
|
||||
import time
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from maim_message import GroupInfo, UserInfo
|
||||
from rich.traceback import install
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .message import MessageRecv
|
||||
|
||||
install(extra_lines=3)
|
||||
|
||||
logger = get_logger("optimized_chat_stream")
|
||||
|
||||
|
||||
class SharedContext:
|
||||
"""共享上下文数据 - 只读数据结构"""
|
||||
|
||||
def __init__(self, stream_id: str, platform: str, user_info: UserInfo, group_info: GroupInfo | None = None):
|
||||
self.stream_id = stream_id
|
||||
self.platform = platform
|
||||
self.user_info = user_info
|
||||
self.group_info = group_info
|
||||
self.create_time = time.time()
|
||||
self._frozen = True
|
||||
|
||||
def __setattr__(self, name, value):
|
||||
if hasattr(self, "_frozen") and self._frozen and name not in ["_frozen"]:
|
||||
raise AttributeError(f"SharedContext is frozen, cannot modify {name}")
|
||||
super().__setattr__(name, value)
|
||||
|
||||
|
||||
class LocalChanges:
|
||||
"""本地修改跟踪器"""
|
||||
|
||||
def __init__(self):
|
||||
self._changes: dict[str, Any] = {}
|
||||
self._dirty = False
|
||||
|
||||
def set_change(self, key: str, value: Any):
|
||||
"""设置修改项"""
|
||||
self._changes[key] = value
|
||||
self._dirty = True
|
||||
|
||||
def get_change(self, key: str, default: Any = None) -> Any:
|
||||
"""获取修改项"""
|
||||
return self._changes.get(key, default)
|
||||
|
||||
def has_changes(self) -> bool:
|
||||
"""是否有修改"""
|
||||
return self._dirty
|
||||
|
||||
def get_changes(self) -> dict[str, Any]:
|
||||
"""获取所有修改"""
|
||||
return self._changes.copy()
|
||||
|
||||
def clear_changes(self):
|
||||
"""清除修改记录"""
|
||||
self._changes.clear()
|
||||
self._dirty = False
|
||||
|
||||
|
||||
class OptimizedChatStream:
|
||||
"""优化版聊天流 - 使用写时复制机制"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
stream_id: str,
|
||||
platform: str,
|
||||
user_info: UserInfo,
|
||||
group_info: GroupInfo | None = None,
|
||||
data: dict | None = None,
|
||||
):
|
||||
# 共享的只读数据
|
||||
self._shared_context = SharedContext(
|
||||
stream_id=stream_id, platform=platform, user_info=user_info, group_info=group_info
|
||||
)
|
||||
|
||||
# 本地修改数据
|
||||
self._local_changes = LocalChanges()
|
||||
|
||||
# 写时复制标志
|
||||
self._copy_on_write = False
|
||||
|
||||
# 基础参数
|
||||
self.base_interest_energy = data.get("base_interest_energy", 0.5) if data else 0.5
|
||||
self._focus_energy = data.get("focus_energy", 0.5) if data else 0.5
|
||||
self.no_reply_consecutive = 0
|
||||
|
||||
# 创建StreamContext(延迟创建)
|
||||
self._stream_context = None
|
||||
self._context_manager = None
|
||||
|
||||
# 更新活跃时间
|
||||
self.update_active_time()
|
||||
|
||||
# 保存标志
|
||||
self.saved = False
|
||||
|
||||
@property
|
||||
def stream_id(self) -> str:
|
||||
return self._shared_context.stream_id
|
||||
|
||||
@property
|
||||
def platform(self) -> str:
|
||||
return self._shared_context.platform
|
||||
|
||||
@property
|
||||
def user_info(self) -> UserInfo:
|
||||
return self._shared_context.user_info
|
||||
|
||||
@user_info.setter
|
||||
def user_info(self, value: UserInfo):
|
||||
"""修改用户信息时触发写时复制"""
|
||||
self._ensure_copy_on_write()
|
||||
# 由于SharedContext是frozen的,我们需要在本地修改中记录
|
||||
self._local_changes.set_change("user_info", value)
|
||||
|
||||
@property
|
||||
def group_info(self) -> GroupInfo | None:
|
||||
if self._local_changes.has_changes() and "group_info" in self._local_changes._changes:
|
||||
return self._local_changes.get_change("group_info")
|
||||
return self._shared_context.group_info
|
||||
|
||||
@group_info.setter
|
||||
def group_info(self, value: GroupInfo | None):
|
||||
"""修改群组信息时触发写时复制"""
|
||||
self._ensure_copy_on_write()
|
||||
self._local_changes.set_change("group_info", value)
|
||||
|
||||
@property
|
||||
def create_time(self) -> float:
|
||||
if self._local_changes.has_changes() and "create_time" in self._local_changes._changes:
|
||||
return self._local_changes.get_change("create_time")
|
||||
return self._shared_context.create_time
|
||||
|
||||
@property
|
||||
def last_active_time(self) -> float:
|
||||
return self._local_changes.get_change("last_active_time", self.create_time)
|
||||
|
||||
@last_active_time.setter
|
||||
def last_active_time(self, value: float):
|
||||
self._local_changes.set_change("last_active_time", value)
|
||||
self.saved = False
|
||||
|
||||
@property
|
||||
def sleep_pressure(self) -> float:
|
||||
return self._local_changes.get_change("sleep_pressure", 0.0)
|
||||
|
||||
@sleep_pressure.setter
|
||||
def sleep_pressure(self, value: float):
|
||||
self._local_changes.set_change("sleep_pressure", value)
|
||||
self.saved = False
|
||||
|
||||
def _ensure_copy_on_write(self):
|
||||
"""确保写时复制机制生效"""
|
||||
if not self._copy_on_write:
|
||||
self._copy_on_write = True
|
||||
# 深拷贝共享上下文到本地
|
||||
logger.debug(f"触发写时复制: {self.stream_id}")
|
||||
|
||||
def _get_effective_user_info(self) -> UserInfo:
|
||||
"""获取有效的用户信息"""
|
||||
if self._local_changes.has_changes() and "user_info" in self._local_changes._changes:
|
||||
return self._local_changes.get_change("user_info")
|
||||
return self._shared_context.user_info
|
||||
|
||||
def _get_effective_group_info(self) -> GroupInfo | None:
|
||||
"""获取有效的群组信息"""
|
||||
if self._local_changes.has_changes() and "group_info" in self._local_changes._changes:
|
||||
return self._local_changes.get_change("group_info")
|
||||
return self._shared_context.group_info
|
||||
|
||||
def update_active_time(self):
|
||||
"""更新最后活跃时间"""
|
||||
self.last_active_time = time.time()
|
||||
|
||||
def set_context(self, message: "MessageRecv"):
|
||||
"""设置聊天消息上下文"""
|
||||
# 确保stream_context存在
|
||||
if self._stream_context is None:
|
||||
self._ensure_copy_on_write()
|
||||
self._create_stream_context()
|
||||
|
||||
# 将MessageRecv转换为DatabaseMessages并设置到stream_context
|
||||
import json
|
||||
|
||||
from src.common.data_models.database_data_model import DatabaseMessages
|
||||
|
||||
message_info = getattr(message, "message_info", {})
|
||||
user_info = getattr(message_info, "user_info", {})
|
||||
group_info = getattr(message_info, "group_info", {})
|
||||
|
||||
reply_to = None
|
||||
if hasattr(message, "message_segment") and message.message_segment:
|
||||
reply_to = self._extract_reply_from_segment(message.message_segment)
|
||||
|
||||
db_message = DatabaseMessages(
|
||||
message_id=getattr(message, "message_id", ""),
|
||||
time=getattr(message, "time", time.time()),
|
||||
chat_id=self._generate_chat_id(message_info),
|
||||
reply_to=reply_to,
|
||||
interest_value=getattr(message, "interest_value", 0.0),
|
||||
key_words=json.dumps(getattr(message, "key_words", []), ensure_ascii=False)
|
||||
if getattr(message, "key_words", None)
|
||||
else None,
|
||||
key_words_lite=json.dumps(getattr(message, "key_words_lite", []), ensure_ascii=False)
|
||||
if getattr(message, "key_words_lite", None)
|
||||
else None,
|
||||
is_mentioned=getattr(message, "is_mentioned", None),
|
||||
is_at=getattr(message, "is_at", False),
|
||||
is_emoji=getattr(message, "is_emoji", False),
|
||||
is_picid=getattr(message, "is_picid", False),
|
||||
is_voice=getattr(message, "is_voice", False),
|
||||
is_video=getattr(message, "is_video", False),
|
||||
is_command=getattr(message, "is_command", False),
|
||||
is_notify=getattr(message, "is_notify", False),
|
||||
is_public_notice=getattr(message, "is_public_notice", False),
|
||||
notice_type=getattr(message, "notice_type", None),
|
||||
processed_plain_text=getattr(message, "processed_plain_text", ""),
|
||||
display_message=getattr(message, "processed_plain_text", ""),
|
||||
priority_mode=getattr(message, "priority_mode", None),
|
||||
priority_info=json.dumps(getattr(message, "priority_info", None))
|
||||
if getattr(message, "priority_info", None)
|
||||
else None,
|
||||
additional_config=getattr(message_info, "additional_config", None),
|
||||
user_id=str(getattr(user_info, "user_id", "")),
|
||||
user_nickname=getattr(user_info, "user_nickname", ""),
|
||||
user_cardname=getattr(user_info, "user_cardname", None),
|
||||
user_platform=getattr(user_info, "platform", ""),
|
||||
chat_info_group_id=getattr(group_info, "group_id", None),
|
||||
chat_info_group_name=getattr(group_info, "group_name", None),
|
||||
chat_info_group_platform=getattr(group_info, "platform", None),
|
||||
chat_info_user_id=str(getattr(user_info, "user_id", "")),
|
||||
chat_info_user_nickname=getattr(user_info, "user_nickname", ""),
|
||||
chat_info_user_cardname=getattr(user_info, "user_cardname", None),
|
||||
chat_info_user_platform=getattr(user_info, "platform", ""),
|
||||
chat_info_stream_id=self.stream_id,
|
||||
chat_info_platform=self.platform,
|
||||
chat_info_create_time=self.create_time,
|
||||
chat_info_last_active_time=self.last_active_time,
|
||||
actions=self._safe_get_actions(message),
|
||||
should_reply=getattr(message, "should_reply", False),
|
||||
)
|
||||
|
||||
self._stream_context.set_current_message(db_message)
|
||||
self._stream_context.priority_mode = getattr(message, "priority_mode", None)
|
||||
self._stream_context.priority_info = getattr(message, "priority_info", None)
|
||||
|
||||
logger.debug(
|
||||
f"消息数据转移完成 - message_id: {db_message.message_id}, "
|
||||
f"chat_id: {db_message.chat_id}, "
|
||||
f"interest_value: {db_message.interest_value}"
|
||||
)
|
||||
|
||||
def _create_stream_context(self):
|
||||
"""创建StreamContext"""
|
||||
from src.common.data_models.message_manager_data_model import StreamContext
|
||||
from src.plugin_system.base.component_types import ChatMode, ChatType
|
||||
|
||||
self._stream_context = StreamContext(
|
||||
stream_id=self.stream_id,
|
||||
chat_type=ChatType.GROUP if self.group_info else ChatType.PRIVATE,
|
||||
chat_mode=ChatMode.NORMAL,
|
||||
)
|
||||
|
||||
# 创建单流上下文管理器
|
||||
from src.chat.message_manager.context_manager import SingleStreamContextManager
|
||||
|
||||
self._context_manager = SingleStreamContextManager(stream_id=self.stream_id, context=self._stream_context)
|
||||
|
||||
@property
|
||||
def stream_context(self):
|
||||
"""获取StreamContext"""
|
||||
if self._stream_context is None:
|
||||
self._ensure_copy_on_write()
|
||||
self._create_stream_context()
|
||||
return self._stream_context
|
||||
|
||||
@property
|
||||
def context_manager(self):
|
||||
"""获取ContextManager"""
|
||||
if self._context_manager is None:
|
||||
self._ensure_copy_on_write()
|
||||
self._create_stream_context()
|
||||
return self._context_manager
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""转换为字典格式 - 考虑本地修改"""
|
||||
user_info = self._get_effective_user_info()
|
||||
group_info = self._get_effective_group_info()
|
||||
|
||||
return {
|
||||
"stream_id": self.stream_id,
|
||||
"platform": self.platform,
|
||||
"user_info": user_info.to_dict() if user_info else None,
|
||||
"group_info": group_info.to_dict() if group_info else None,
|
||||
"create_time": self.create_time,
|
||||
"last_active_time": self.last_active_time,
|
||||
"sleep_pressure": self.sleep_pressure,
|
||||
"focus_energy": self.focus_energy,
|
||||
"base_interest_energy": self.base_interest_energy,
|
||||
"stream_context_chat_type": self.stream_context.chat_type.value,
|
||||
"stream_context_chat_mode": self.stream_context.chat_mode.value,
|
||||
"interruption_count": self.stream_context.interruption_count,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict) -> "OptimizedChatStream":
|
||||
"""从字典创建实例"""
|
||||
user_info = UserInfo.from_dict(data.get("user_info", {})) if data.get("user_info") else None
|
||||
group_info = GroupInfo.from_dict(data.get("group_info", {})) if data.get("group_info") else None
|
||||
|
||||
instance = cls(
|
||||
stream_id=data["stream_id"],
|
||||
platform=data["platform"],
|
||||
user_info=user_info, # type: ignore
|
||||
group_info=group_info,
|
||||
data=data,
|
||||
)
|
||||
|
||||
# 恢复stream_context信息
|
||||
if "stream_context_chat_type" in data:
|
||||
from src.plugin_system.base.component_types import ChatMode, ChatType
|
||||
|
||||
instance.stream_context.chat_type = ChatType(data["stream_context_chat_type"])
|
||||
if "stream_context_chat_mode" in data:
|
||||
from src.plugin_system.base.component_types import ChatMode, ChatType
|
||||
|
||||
instance.stream_context.chat_mode = ChatMode(data["stream_context_chat_mode"])
|
||||
|
||||
# 恢复interruption_count信息
|
||||
if "interruption_count" in data:
|
||||
instance.stream_context.interruption_count = data["interruption_count"]
|
||||
|
||||
return instance
|
||||
|
||||
def _safe_get_actions(self, message: "MessageRecv") -> list | None:
|
||||
"""安全获取消息的actions字段"""
|
||||
try:
|
||||
actions = getattr(message, "actions", None)
|
||||
if actions is None:
|
||||
return None
|
||||
|
||||
if isinstance(actions, str):
|
||||
try:
|
||||
import json
|
||||
|
||||
actions = json.loads(actions)
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(f"无法解析actions JSON字符串: {actions}")
|
||||
return None
|
||||
|
||||
if isinstance(actions, list):
|
||||
filtered_actions = [action for action in actions if action is not None and isinstance(action, str)]
|
||||
return filtered_actions if filtered_actions else None
|
||||
else:
|
||||
logger.warning(f"actions字段类型不支持: {type(actions)}")
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"获取actions字段失败: {e}")
|
||||
return None
|
||||
|
||||
def _extract_reply_from_segment(self, segment) -> str | None:
|
||||
"""从消息段中提取reply_to信息"""
|
||||
try:
|
||||
if hasattr(segment, "type") and segment.type == "seglist":
|
||||
if hasattr(segment, "data") and segment.data:
|
||||
for seg in segment.data:
|
||||
reply_id = self._extract_reply_from_segment(seg)
|
||||
if reply_id:
|
||||
return reply_id
|
||||
elif hasattr(segment, "type") and segment.type == "reply":
|
||||
return str(segment.data) if segment.data else None
|
||||
except Exception as e:
|
||||
logger.warning(f"提取reply_to信息失败: {e}")
|
||||
return None
|
||||
|
||||
def _generate_chat_id(self, message_info) -> str:
|
||||
"""生成chat_id,基于群组或用户信息"""
|
||||
try:
|
||||
group_info = getattr(message_info, "group_info", None)
|
||||
user_info = getattr(message_info, "user_info", None)
|
||||
|
||||
if group_info and hasattr(group_info, "group_id") and group_info.group_id:
|
||||
return f"{self.platform}_{group_info.group_id}"
|
||||
elif user_info and hasattr(user_info, "user_id") and user_info.user_id:
|
||||
return f"{self.platform}_{user_info.user_id}_private"
|
||||
else:
|
||||
return self.stream_id
|
||||
except Exception as e:
|
||||
logger.warning(f"生成chat_id失败: {e}")
|
||||
return self.stream_id
|
||||
|
||||
@property
|
||||
def focus_energy(self) -> float:
|
||||
"""获取缓存的focus_energy值"""
|
||||
return self._focus_energy
|
||||
|
||||
async def calculate_focus_energy(self) -> float:
|
||||
"""异步计算focus_energy"""
|
||||
try:
|
||||
all_messages = self.context_manager.get_messages(limit=global_config.chat.max_context_size)
|
||||
|
||||
user_id = None
|
||||
effective_user_info = self._get_effective_user_info()
|
||||
if effective_user_info and hasattr(effective_user_info, "user_id"):
|
||||
user_id = str(effective_user_info.user_id)
|
||||
|
||||
from src.chat.energy_system import energy_manager
|
||||
|
||||
energy = await energy_manager.calculate_focus_energy(
|
||||
stream_id=self.stream_id, messages=all_messages, user_id=user_id
|
||||
)
|
||||
|
||||
self._focus_energy = energy
|
||||
|
||||
logger.debug(f"聊天流 {self.stream_id} 能量: {energy:.3f}")
|
||||
return energy
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取focus_energy失败: {e}", exc_info=True)
|
||||
return self._focus_energy
|
||||
|
||||
@focus_energy.setter
|
||||
def focus_energy(self, value: float):
|
||||
"""设置focus_energy值"""
|
||||
self._focus_energy = max(0.0, min(1.0, value))
|
||||
|
||||
async def _get_user_relationship_score(self) -> float:
|
||||
"""获取用户关系分"""
|
||||
try:
|
||||
from src.plugin_system.apis.scoring_api import scoring_api
|
||||
|
||||
effective_user_info = self._get_effective_user_info()
|
||||
if effective_user_info and hasattr(effective_user_info, "user_id"):
|
||||
user_id = str(effective_user_info.user_id)
|
||||
relationship_score = await scoring_api.get_user_relationship_score(user_id)
|
||||
logger.debug(f"OptimizedChatStream {self.stream_id}: 用户关系分 = {relationship_score:.3f}")
|
||||
return relationship_score
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"OptimizedChatStream {self.stream_id}: 关系分计算失败: {e}")
|
||||
|
||||
return 0.3
|
||||
|
||||
def create_snapshot(self) -> "OptimizedChatStream":
|
||||
"""创建当前状态的快照(用于缓存)"""
|
||||
# 创建一个新的实例,共享相同的上下文
|
||||
snapshot = OptimizedChatStream(
|
||||
stream_id=self.stream_id,
|
||||
platform=self.platform,
|
||||
user_info=self._get_effective_user_info(),
|
||||
group_info=self._get_effective_group_info(),
|
||||
)
|
||||
|
||||
# 复制本地修改(但不触发写时复制)
|
||||
snapshot._local_changes._changes = self._local_changes.get_changes()
|
||||
snapshot._local_changes._dirty = self._local_changes._dirty
|
||||
snapshot._focus_energy = self._focus_energy
|
||||
snapshot.base_interest_energy = self.base_interest_energy
|
||||
snapshot.no_reply_consecutive = self.no_reply_consecutive
|
||||
snapshot.saved = self.saved
|
||||
|
||||
return snapshot
|
||||
|
||||
|
||||
# 为了向后兼容,创建一个工厂函数
|
||||
def create_optimized_chat_stream(
|
||||
stream_id: str,
|
||||
platform: str,
|
||||
user_info: UserInfo,
|
||||
group_info: GroupInfo | None = None,
|
||||
data: dict | None = None,
|
||||
) -> OptimizedChatStream:
|
||||
"""创建优化版聊天流实例"""
|
||||
return OptimizedChatStream(
|
||||
stream_id=stream_id, platform=platform, user_info=user_info, group_info=group_info, data=data
|
||||
)
|
||||
@@ -796,44 +796,63 @@ class DefaultReplyer:
|
||||
async def build_keywords_reaction_prompt(self, target: str | None) -> str:
|
||||
"""构建关键词反应提示
|
||||
|
||||
该方法根据配置的关键词和正则表达式规则,
|
||||
检查目标消息内容是否触发了任何反应。
|
||||
如果匹配成功,它会生成一个包含所有触发反应的提示字符串,
|
||||
用于指导LLM的回复。
|
||||
|
||||
Args:
|
||||
target: 目标消息内容
|
||||
|
||||
Returns:
|
||||
str: 关键词反应提示字符串
|
||||
str: 关键词反应提示字符串,如果没有触发任何反应则为空字符串
|
||||
"""
|
||||
# 关键词检测与反应
|
||||
keywords_reaction_prompt = ""
|
||||
if target is None:
|
||||
return ""
|
||||
|
||||
reaction_prompt = ""
|
||||
try:
|
||||
# 添加None检查,防止NoneType错误
|
||||
if target is None:
|
||||
return keywords_reaction_prompt
|
||||
current_chat_stream_id_str = self.chat_stream.get_raw_id()
|
||||
# 2. 筛选适用的规则(全局规则 + 特定于当前聊天的规则)
|
||||
applicable_rules = []
|
||||
for rule in global_config.reaction.rules:
|
||||
if rule.chat_stream_id == "" or rule.chat_stream_id == current_chat_stream_id_str:
|
||||
applicable_rules.append(rule) # noqa: PERF401
|
||||
|
||||
# 处理关键词规则
|
||||
for rule in global_config.keyword_reaction.keyword_rules:
|
||||
if any(keyword in target for keyword in rule.keywords):
|
||||
logger.info(f"检测到关键词规则:{rule.keywords},触发反应:{rule.reaction}")
|
||||
keywords_reaction_prompt += f"{rule.reaction},"
|
||||
# 3. 遍历适用规则并执行匹配
|
||||
for rule in applicable_rules:
|
||||
matched = False
|
||||
if rule.rule_type == "keyword":
|
||||
if any(keyword in target for keyword in rule.patterns):
|
||||
logger.info(f"检测到关键词规则:{rule.patterns},触发反应:{rule.reaction}")
|
||||
reaction_prompt += f"{rule.reaction},"
|
||||
matched = True
|
||||
|
||||
elif rule.rule_type == "regex":
|
||||
for pattern_str in rule.patterns:
|
||||
try:
|
||||
pattern = re.compile(pattern_str)
|
||||
if result := pattern.search(target):
|
||||
reaction = rule.reaction
|
||||
# 替换命名捕获组
|
||||
for name, content in result.groupdict().items():
|
||||
reaction = reaction.replace(f"[{name}]", content)
|
||||
logger.info(f"匹配到正则表达式:{pattern_str},触发反应:{reaction}")
|
||||
reaction_prompt += f"{reaction},"
|
||||
matched = True
|
||||
break # 一个正则规则里只要有一个 pattern 匹配成功即可
|
||||
except re.error as e:
|
||||
logger.error(f"正则表达式编译错误: {pattern_str}, 错误信息: {e!s}")
|
||||
continue
|
||||
|
||||
if matched:
|
||||
# 如果需要每条消息只触发一个反应规则,可以在这里 break
|
||||
pass
|
||||
|
||||
# 处理正则表达式规则
|
||||
for rule in global_config.keyword_reaction.regex_rules:
|
||||
for pattern_str in rule.regex:
|
||||
try:
|
||||
pattern = re.compile(pattern_str)
|
||||
if result := pattern.search(target):
|
||||
reaction = rule.reaction
|
||||
for name, content in result.groupdict().items():
|
||||
reaction = reaction.replace(f"[{name}]", content)
|
||||
logger.info(f"匹配到正则表达式:{pattern_str},触发反应:{reaction}")
|
||||
keywords_reaction_prompt += f"{reaction},"
|
||||
break
|
||||
except re.error as e:
|
||||
logger.error(f"正则表达式编译错误: {pattern_str}, 错误信息: {e!s}")
|
||||
continue
|
||||
except Exception as e:
|
||||
logger.error(f"关键词检测与反应时发生异常: {e!s}", exc_info=True)
|
||||
|
||||
return keywords_reaction_prompt
|
||||
return reaction_prompt
|
||||
|
||||
async def build_notice_block(self, chat_id: str) -> str:
|
||||
"""构建notice信息块
|
||||
|
||||
@@ -303,7 +303,7 @@ class Prompt:
|
||||
|
||||
@staticmethod
|
||||
def _process_escaped_braces(template) -> str:
|
||||
"""预处理模板,将 `\{` 和 `\}` 替换为临时标记."""
|
||||
r"""预处理模板,将 `\{` 和 `\}` 替换为临时标记."""
|
||||
if isinstance(template, list):
|
||||
template = "\n".join(str(item) for item in template)
|
||||
elif not isinstance(template, str):
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import asyncio
|
||||
import re
|
||||
from typing import Type
|
||||
|
||||
from src.chat.utils.prompt_params import PromptParameters
|
||||
from src.common.logger import get_logger
|
||||
@@ -21,7 +20,7 @@ class PromptComponentManager:
|
||||
3. 提供一个接口,以便在构建核心Prompt时,能够获取并执行所有相关的组件。
|
||||
"""
|
||||
|
||||
def _get_rules_for(self, target_prompt_name: str) -> list[tuple[InjectionRule, Type[BasePrompt]]]:
|
||||
def _get_rules_for(self, target_prompt_name: str) -> list[tuple[InjectionRule, type[BasePrompt]]]:
|
||||
"""
|
||||
获取指定目标Prompt的所有注入规则及其关联的组件类。
|
||||
|
||||
|
||||
@@ -6,15 +6,14 @@
|
||||
避免不必要的自我语音识别。
|
||||
"""
|
||||
import hashlib
|
||||
from typing import Dict
|
||||
|
||||
# 一个简单的内存缓存,用于将机器人自己发送的语音消息映射到其原始文本。
|
||||
# 键是语音base64内容的SHA256哈希值。
|
||||
_self_voice_cache: Dict[str, str] = {}
|
||||
_self_voice_cache: dict[str, str] = {}
|
||||
|
||||
def get_voice_key(base64_content: str) -> str:
|
||||
"""为语音内容生成一个一致的键。"""
|
||||
return hashlib.sha256(base64_content.encode('utf-8')).hexdigest()
|
||||
return hashlib.sha256(base64_content.encode("utf-8")).hexdigest()
|
||||
|
||||
def register_self_voice(base64_content: str, text: str):
|
||||
"""
|
||||
@@ -39,4 +38,4 @@ def consume_self_voice_text(base64_content: str) -> str | None:
|
||||
str | None: 如果找到,则返回原始文本,否则返回None。
|
||||
"""
|
||||
key = get_voice_key(base64_content)
|
||||
return _self_voice_cache.pop(key, None)
|
||||
return _self_voice_cache.pop(key, None)
|
||||
|
||||
@@ -430,19 +430,13 @@ def process_llm_response(text: str, enable_splitter: bool = True, enable_chinese
|
||||
if global_config.response_splitter.enable and enable_splitter:
|
||||
logger.info(f"回复分割器已启用,模式: {global_config.response_splitter.split_mode}。")
|
||||
|
||||
split_mode = global_config.response_splitter.split_mode
|
||||
|
||||
if split_mode == "llm" and "[SPLIT]" in cleaned_text:
|
||||
if "[SPLIT]" in cleaned_text:
|
||||
logger.debug("检测到 [SPLIT] 标记,使用 LLM 自定义分割。")
|
||||
split_sentences_raw = cleaned_text.split("[SPLIT]")
|
||||
split_sentences = [s.strip() for s in split_sentences_raw if s.strip()]
|
||||
else:
|
||||
if split_mode == "llm":
|
||||
logger.debug("未检测到 [SPLIT] 标记,本次不进行分割。")
|
||||
split_sentences = [cleaned_text]
|
||||
else: # mode == "punctuation"
|
||||
logger.debug("使用基于标点的传统模式进行分割。")
|
||||
split_sentences = split_into_sentences_w_remove_punctuation(cleaned_text)
|
||||
logger.debug("使用基于标点的传统模式进行分割。")
|
||||
split_sentences = split_into_sentences_w_remove_punctuation(cleaned_text)
|
||||
else:
|
||||
logger.debug("回复分割器已禁用。")
|
||||
split_sentences = [cleaned_text]
|
||||
|
||||
@@ -19,10 +19,11 @@ async def get_voice_text(voice_base64: str) -> str:
|
||||
|
||||
# 如果选择本地识别
|
||||
if asr_provider == "local":
|
||||
from src.plugin_system.apis import tool_api
|
||||
import tempfile
|
||||
import base64
|
||||
import os
|
||||
import tempfile
|
||||
|
||||
from src.plugin_system.apis import tool_api
|
||||
|
||||
local_asr_tool = tool_api.get_tool_instance("local_asr")
|
||||
if not local_asr_tool:
|
||||
@@ -39,8 +40,8 @@ async def get_voice_text(voice_base64: str) -> str:
|
||||
text = await local_asr_tool.execute(function_args={"audio_path": audio_path})
|
||||
if "失败" in text or "出错" in text or "错误" in text:
|
||||
logger.warning(f"本地语音识别失败: {text}")
|
||||
return f"[语音(本地识别失败)]"
|
||||
|
||||
return "[语音(本地识别失败)]"
|
||||
|
||||
logger.info(f"本地语音识别成功: {text}")
|
||||
return f"[语音] {text}"
|
||||
|
||||
|
||||
@@ -208,22 +208,28 @@ class StreamContext(BaseDataModel):
|
||||
bool: 如果消息支持所有指定的类型则返回True,否则返回False
|
||||
"""
|
||||
if not self.current_message:
|
||||
logger.warning("[问题] StreamContext.check_types: current_message 为 None")
|
||||
return False
|
||||
|
||||
if not types:
|
||||
# 如果没有指定类型要求,默认为支持
|
||||
return True
|
||||
|
||||
logger.debug(f"[check_types] 检查消息是否支持类型: {types}")
|
||||
|
||||
# 优先从additional_config中获取format_info
|
||||
if hasattr(self.current_message, "additional_config") and self.current_message.additional_config:
|
||||
try:
|
||||
import orjson
|
||||
|
||||
logger.debug(f"[check_types] additional_config 类型: {type(self.current_message.additional_config)}")
|
||||
config = orjson.loads(self.current_message.additional_config)
|
||||
logger.debug(f"[check_types] 解析后的 config 键: {config.keys() if isinstance(config, dict) else 'N/A'}")
|
||||
|
||||
# 检查format_info结构
|
||||
if "format_info" in config:
|
||||
format_info = config["format_info"]
|
||||
logger.debug(f"[check_types] 找到 format_info: {format_info}")
|
||||
|
||||
# 方法1: 直接检查accept_format字段
|
||||
if "accept_format" in format_info:
|
||||
@@ -240,8 +246,9 @@ class StreamContext(BaseDataModel):
|
||||
# 检查所有请求的类型是否都被支持
|
||||
for requested_type in types:
|
||||
if requested_type not in accept_format:
|
||||
logger.debug(f"消息不支持类型 '{requested_type}',支持的类型: {accept_format}")
|
||||
logger.debug(f"[check_types] 消息不支持类型 '{requested_type}',支持的类型: {accept_format}")
|
||||
return False
|
||||
logger.debug(f"[check_types] ✅ 消息支持所有请求的类型 (来自 accept_format)")
|
||||
return True
|
||||
|
||||
# 方法2: 检查content_format字段(向后兼容)
|
||||
@@ -258,22 +265,30 @@ class StreamContext(BaseDataModel):
|
||||
# 检查所有请求的类型是否都被支持
|
||||
for requested_type in types:
|
||||
if requested_type not in content_format:
|
||||
logger.debug(f"消息不支持类型 '{requested_type}',支持的内容格式: {content_format}")
|
||||
logger.debug(f"[check_types] 消息不支持类型 '{requested_type}',支持的内容格式: {content_format}")
|
||||
return False
|
||||
logger.debug(f"[check_types] ✅ 消息支持所有请求的类型 (来自 content_format)")
|
||||
return True
|
||||
else:
|
||||
logger.warning("[check_types] [问题] additional_config 中没有 format_info 字段")
|
||||
|
||||
except (orjson.JSONDecodeError, AttributeError, TypeError) as e:
|
||||
logger.debug(f"解析消息格式信息失败: {e}")
|
||||
logger.warning(f"[check_types] [问题] 解析消息格式信息失败: {e}")
|
||||
else:
|
||||
logger.warning("[check_types] [问题] current_message 没有 additional_config 或为空")
|
||||
|
||||
# 备用方案:如果无法从additional_config获取格式信息,使用默认支持的类型
|
||||
# 大多数消息至少支持text类型
|
||||
logger.debug("[check_types] 使用备用方案:默认支持类型检查")
|
||||
default_supported_types = ["text", "emoji"]
|
||||
for requested_type in types:
|
||||
if requested_type not in default_supported_types:
|
||||
logger.debug(f"使用默认类型检查,消息可能不支持类型 '{requested_type}'")
|
||||
logger.debug(f"[check_types] 使用默认类型检查,消息可能不支持类型 '{requested_type}'")
|
||||
# 对于非基础类型,返回False以避免错误
|
||||
if requested_type not in ["text", "emoji", "reply"]:
|
||||
logger.warning(f"[check_types] ❌ 备用方案拒绝类型 '{requested_type}'")
|
||||
return False
|
||||
logger.debug("[check_types] ✅ 备用方案通过所有类型检查")
|
||||
return True
|
||||
|
||||
def get_priority_mode(self) -> str | None:
|
||||
|
||||
@@ -26,7 +26,7 @@ from src.config.official_configs import (
|
||||
EmojiConfig,
|
||||
ExperimentalConfig,
|
||||
ExpressionConfig,
|
||||
KeywordReactionConfig,
|
||||
ReactionConfig,
|
||||
LPMMKnowledgeConfig,
|
||||
MaimMessageConfig,
|
||||
MemoryConfig,
|
||||
@@ -384,7 +384,7 @@ class Config(ValidatedConfigBase):
|
||||
expression: ExpressionConfig = Field(..., description="表达配置")
|
||||
memory: MemoryConfig = Field(..., description="记忆配置")
|
||||
mood: MoodConfig = Field(..., description="情绪配置")
|
||||
keyword_reaction: KeywordReactionConfig = Field(..., description="关键词反应配置")
|
||||
reaction: ReactionConfig = Field(default_factory=ReactionConfig, description="反应规则配置")
|
||||
chinese_typo: ChineseTypoConfig = Field(..., description="中文错别字配置")
|
||||
response_post_process: ResponsePostProcessConfig = Field(..., description="响应后处理配置")
|
||||
response_splitter: ResponseSplitterConfig = Field(..., description="响应分割配置")
|
||||
|
||||
@@ -401,32 +401,31 @@ class MoodConfig(ValidatedConfigBase):
|
||||
mood_update_threshold: float = Field(default=1.0, description="情绪更新阈值")
|
||||
|
||||
|
||||
class KeywordRuleConfig(ValidatedConfigBase):
|
||||
"""关键词规则配置类"""
|
||||
class ReactionRuleConfig(ValidatedConfigBase):
|
||||
"""反应规则配置类"""
|
||||
|
||||
keywords: list[str] = Field(default_factory=lambda: [], description="关键词列表")
|
||||
regex: list[str] = Field(default_factory=lambda: [], description="正则表达式列表")
|
||||
reaction: str = Field(default="", description="反应内容")
|
||||
chat_stream_id: str = Field(default="", description='聊天流ID,格式为 "platform:id:type",空字符串表示全局')
|
||||
rule_type: Literal["keyword", "regex"] = Field(..., description='规则类型,必须是 "keyword" 或 "regex"')
|
||||
patterns: list[str] = Field(..., description="关键词或正则表达式列表")
|
||||
reaction: str = Field(..., description="触发后的回复内容")
|
||||
|
||||
def __post_init__(self):
|
||||
import re
|
||||
|
||||
if not self.keywords and not self.regex:
|
||||
raise ValueError("关键词规则必须至少包含keywords或regex中的一个")
|
||||
if not self.reaction:
|
||||
raise ValueError("关键词规则必须包含reaction")
|
||||
for pattern in self.regex:
|
||||
try:
|
||||
re.compile(pattern)
|
||||
except re.error as e:
|
||||
raise ValueError(f"无效的正则表达式 '{pattern}': {e!s}") from e
|
||||
if not self.patterns:
|
||||
raise ValueError("patterns 列表不能为空")
|
||||
if self.rule_type == "regex":
|
||||
for pattern in self.patterns:
|
||||
try:
|
||||
re.compile(pattern)
|
||||
except re.error as e:
|
||||
raise ValueError(f"无效的正则表达式 '{pattern}': {e!s}") from e
|
||||
|
||||
|
||||
class KeywordReactionConfig(ValidatedConfigBase):
|
||||
"""关键词配置类"""
|
||||
class ReactionConfig(ValidatedConfigBase):
|
||||
"""反应规则系统配置"""
|
||||
|
||||
keyword_rules: list[KeywordRuleConfig] = Field(default_factory=lambda: [], description="关键词规则列表")
|
||||
regex_rules: list[KeywordRuleConfig] = Field(default_factory=lambda: [], description="正则表达式规则列表")
|
||||
rules: list[ReactionRuleConfig] = Field(default_factory=list, description="反应规则列表")
|
||||
|
||||
|
||||
class CustomPromptConfig(ValidatedConfigBase):
|
||||
|
||||
@@ -4,9 +4,10 @@ import signal
|
||||
import sys
|
||||
import time
|
||||
import traceback
|
||||
from collections.abc import Callable, Coroutine
|
||||
from functools import partial
|
||||
from random import choices
|
||||
from typing import Any, Callable, Coroutine
|
||||
from typing import Any
|
||||
|
||||
from maim_message import MessageServer
|
||||
from rich.traceback import install
|
||||
@@ -24,8 +25,8 @@ from src.config.config import global_config
|
||||
from src.individuality.individuality import Individuality, get_individuality
|
||||
from src.manager.async_task_manager import async_task_manager
|
||||
from src.mood.mood_manager import mood_manager
|
||||
from src.plugin_system.base.component_types import EventType
|
||||
from src.plugin_system.base.base_interest_calculator import BaseInterestCalculator
|
||||
from src.plugin_system.base.component_types import EventType
|
||||
from src.plugin_system.core.event_manager import event_manager
|
||||
from src.plugin_system.core.plugin_manager import plugin_manager
|
||||
from src.schedule.monthly_plan_manager import monthly_plan_manager
|
||||
@@ -418,7 +419,7 @@ MoFox_Bot(第三方修改版)
|
||||
|
||||
# 处理所有缓存的事件订阅(插件加载完成后)
|
||||
event_manager.process_all_pending_subscriptions()
|
||||
|
||||
|
||||
# 初始化表情管理器
|
||||
get_emoji_manager().initialize()
|
||||
logger.info("表情包管理器初始化成功")
|
||||
|
||||
0
src/mcp_integration/__init__.py
Normal file
0
src/mcp_integration/__init__.py
Normal file
1
src/mcp_integration/client_manager.py
Normal file
1
src/mcp_integration/client_manager.py
Normal file
@@ -0,0 +1 @@
|
||||
|
||||
0
src/mcp_integration/config_loader.py
Normal file
0
src/mcp_integration/config_loader.py
Normal file
0
src/mcp_integration/tool_wrapper.py
Normal file
0
src/mcp_integration/tool_wrapper.py
Normal file
@@ -1,4 +1,5 @@
|
||||
from typing import Any
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.plugin_system.base.base_tool import BaseTool
|
||||
from src.plugin_system.base.component_types import ComponentType
|
||||
@@ -22,7 +23,7 @@ def get_tool_instance(tool_name: str) -> BaseTool | None:
|
||||
|
||||
|
||||
def get_llm_available_tool_definitions() -> list[dict[str, Any]]:
|
||||
"""获取LLM可用的工具定义列表
|
||||
"""获取LLM可用的工具定义列表(包括 MCP 工具)
|
||||
|
||||
Returns:
|
||||
list[dict[str, Any]]: 工具定义列表
|
||||
@@ -31,6 +32,8 @@ def get_llm_available_tool_definitions() -> list[dict[str, Any]]:
|
||||
|
||||
llm_available_tools = component_registry.get_llm_available_tools()
|
||||
tool_definitions = []
|
||||
|
||||
# 获取常规工具定义
|
||||
for tool_name, tool_class in llm_available_tools.items():
|
||||
try:
|
||||
# 调用类方法 get_tool_definition 获取定义
|
||||
@@ -38,5 +41,18 @@ def get_llm_available_tool_definitions() -> list[dict[str, Any]]:
|
||||
tool_definitions.append(definition)
|
||||
except Exception as e:
|
||||
logger.error(f"获取工具 {tool_name} 的定义失败: {e}")
|
||||
|
||||
# 获取 MCP 工具定义
|
||||
try:
|
||||
mcp_tools = component_registry.get_mcp_tools()
|
||||
for mcp_tool in mcp_tools:
|
||||
try:
|
||||
definition = mcp_tool.get_tool_definition()
|
||||
tool_definitions.append(definition)
|
||||
except Exception as e:
|
||||
logger.error(f"获取 MCP 工具 {mcp_tool.name} 的定义失败: {e}")
|
||||
except Exception as e:
|
||||
logger.debug(f"获取 MCP 工具列表失败(可能未启用): {e}")
|
||||
|
||||
return tool_definitions
|
||||
|
||||
|
||||
@@ -87,6 +87,10 @@ class ComponentRegistry:
|
||||
self._tool_registry: dict[str, type["BaseTool"]] = {} # 工具名 -> 工具类
|
||||
self._llm_available_tools: dict[str, type["BaseTool"]] = {} # llm可用的工具名 -> 工具类
|
||||
|
||||
# MCP 工具注册表(运行时动态加载)
|
||||
self._mcp_tools: list["BaseTool"] = [] # MCP 工具适配器实例列表
|
||||
self._mcp_tools_loaded = False # MCP 工具是否已加载
|
||||
|
||||
# EventHandler特定注册表
|
||||
self._event_handler_registry: dict[str, type["BaseEventHandler"]] = {}
|
||||
"""event_handler名 -> event_handler类"""
|
||||
@@ -891,6 +895,7 @@ class ComponentRegistry:
|
||||
"action_components": action_components,
|
||||
"command_components": command_components,
|
||||
"tool_components": tool_components,
|
||||
"mcp_tools": len(self._mcp_tools),
|
||||
"event_handlers": events_handlers,
|
||||
"plus_command_components": plus_command_components,
|
||||
"chatter_components": chatter_components,
|
||||
@@ -904,6 +909,34 @@ class ComponentRegistry:
|
||||
"enabled_plugins": len([p for p in self._plugins.values() if p.enabled]),
|
||||
}
|
||||
|
||||
# === MCP 工具相关方法 ===
|
||||
|
||||
async def load_mcp_tools(self) -> None:
|
||||
"""加载 MCP 工具(异步方法)"""
|
||||
if self._mcp_tools_loaded:
|
||||
logger.debug("MCP 工具已加载,跳过")
|
||||
return
|
||||
|
||||
try:
|
||||
from .mcp_tool_adapter import load_mcp_tools_as_adapters
|
||||
|
||||
logger.info("开始加载 MCP 工具...")
|
||||
self._mcp_tools = await load_mcp_tools_as_adapters()
|
||||
self._mcp_tools_loaded = True
|
||||
logger.info(f"MCP 工具加载完成,共 {len(self._mcp_tools)} 个工具")
|
||||
except Exception as e:
|
||||
logger.error(f"加载 MCP 工具失败: {e}")
|
||||
self._mcp_tools = []
|
||||
self._mcp_tools_loaded = True # 标记为已尝试加载,避免重复尝试
|
||||
|
||||
def get_mcp_tools(self) -> list["BaseTool"]:
|
||||
"""获取所有 MCP 工具适配器实例"""
|
||||
return self._mcp_tools.copy()
|
||||
|
||||
def is_mcp_tool(self, tool_name: str) -> bool:
|
||||
"""检查工具名是否为 MCP 工具"""
|
||||
return tool_name.startswith("mcp_")
|
||||
|
||||
# === 组件移除相关 ===
|
||||
|
||||
async def unregister_plugin(self, plugin_name: str) -> bool:
|
||||
|
||||
266
src/plugin_system/core/mcp_client_manager.py
Normal file
266
src/plugin_system/core/mcp_client_manager.py
Normal file
@@ -0,0 +1,266 @@
|
||||
"""
|
||||
MCP Client Manager
|
||||
|
||||
管理多个 MCP (Model Context Protocol) 客户端连接,支持动态加载和工具注册
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import mcp.types
|
||||
from fastmcp.client import Client, StreamableHttpTransport
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("mcp_client_manager")
|
||||
|
||||
|
||||
class MCPServerConfig:
|
||||
"""单个 MCP 服务器的配置"""
|
||||
|
||||
def __init__(self, name: str, config: dict[str, Any]):
|
||||
self.name = name
|
||||
self.description = config.get("description", "")
|
||||
self.enabled = config.get("enabled", True)
|
||||
self.transport_config = config["transport"]
|
||||
self.auth_config = config.get("auth")
|
||||
self.timeout = config.get("timeout", 30)
|
||||
self.retry_config = config.get("retry", {"max_retries": 3, "retry_delay": 1})
|
||||
|
||||
def __repr__(self):
|
||||
return f"<MCPServerConfig {self.name} (enabled={self.enabled})>"
|
||||
|
||||
|
||||
class MCPClientManager:
|
||||
"""
|
||||
MCP 客户端管理器
|
||||
|
||||
负责:
|
||||
1. 从配置文件加载 MCP 服务器配置
|
||||
2. 建立和维护与 MCP 服务器的连接
|
||||
3. 获取可用的工具列表
|
||||
4. 执行工具调用
|
||||
"""
|
||||
|
||||
def __init__(self, config_path: str | Path | None = None):
|
||||
"""
|
||||
初始化 MCP 客户端管理器
|
||||
|
||||
Args:
|
||||
config_path: mcp.json 配置文件路径,默认为 config/mcp.json
|
||||
"""
|
||||
if config_path is None:
|
||||
# 默认配置路径
|
||||
|
||||
config_path = Path(__file__).parent.parent.parent.parent / "config" / "mcp.json"
|
||||
|
||||
self.config_path = Path(config_path)
|
||||
self.servers: dict[str, MCPServerConfig] = {}
|
||||
self.clients: dict[str, Client] = {}
|
||||
self._initialized = False
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
logger.info(f"MCP 客户端管理器初始化,配置文件: {self.config_path}")
|
||||
|
||||
def load_config(self) -> dict[str, MCPServerConfig]:
|
||||
"""
|
||||
从配置文件加载 MCP 服务器配置
|
||||
|
||||
Returns:
|
||||
Dict[str, MCPServerConfig]: 服务器名称 -> 配置对象
|
||||
"""
|
||||
if not self.config_path.exists():
|
||||
logger.warning(f"MCP 配置文件不存在: {self.config_path}")
|
||||
return {}
|
||||
|
||||
try:
|
||||
with open(self.config_path, encoding="utf-8") as f:
|
||||
config_data = json.load(f)
|
||||
|
||||
servers = {}
|
||||
mcp_servers = config_data.get("mcpServers", {})
|
||||
|
||||
for server_name, server_config in mcp_servers.items():
|
||||
try:
|
||||
server = MCPServerConfig(server_name, server_config)
|
||||
servers[server_name] = server
|
||||
logger.debug(f"加载 MCP 服务器配置: {server}")
|
||||
except Exception as e:
|
||||
logger.error(f"加载服务器配置 '{server_name}' 失败: {e}")
|
||||
continue
|
||||
|
||||
logger.info(f"成功加载 {len(servers)} 个 MCP 服务器配置")
|
||||
return servers
|
||||
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(f"解析 MCP 配置文件失败: {e}")
|
||||
return {}
|
||||
except Exception as e:
|
||||
logger.error(f"读取 MCP 配置文件失败: {e}")
|
||||
return {}
|
||||
|
||||
async def initialize(self) -> None:
|
||||
"""
|
||||
初始化所有启用的 MCP 客户端连接
|
||||
|
||||
这个方法会:
|
||||
1. 加载配置文件
|
||||
2. 为每个启用的服务器创建客户端
|
||||
3. 建立连接并验证
|
||||
"""
|
||||
async with self._lock:
|
||||
if self._initialized:
|
||||
logger.debug("MCP 客户端管理器已初始化,跳过")
|
||||
return
|
||||
|
||||
logger.info("开始初始化 MCP 客户端连接...")
|
||||
|
||||
# 加载配置
|
||||
self.servers = self.load_config()
|
||||
|
||||
if not self.servers:
|
||||
logger.warning("没有找到任何 MCP 服务器配置")
|
||||
self._initialized = True
|
||||
return
|
||||
|
||||
# 为每个启用的服务器创建客户端
|
||||
for server_name, server_config in self.servers.items():
|
||||
if not server_config.enabled:
|
||||
logger.debug(f"服务器 '{server_name}' 未启用,跳过")
|
||||
continue
|
||||
|
||||
try:
|
||||
client = await self._create_client(server_config)
|
||||
self.clients[server_name] = client
|
||||
logger.info(f"✅ MCP 服务器 '{server_name}' 连接成功")
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 连接 MCP 服务器 '{server_name}' 失败: {e}")
|
||||
continue
|
||||
|
||||
self._initialized = True
|
||||
logger.info(f"MCP 客户端管理器初始化完成,成功连接 {len(self.clients)}/{len(self.servers)} 个服务器")
|
||||
|
||||
async def _create_client(self, server_config: MCPServerConfig) -> Client:
|
||||
"""
|
||||
根据配置创建 MCP 客户端
|
||||
|
||||
Args:
|
||||
server_config: 服务器配置
|
||||
|
||||
Returns:
|
||||
Client: 已连接的 MCP 客户端
|
||||
"""
|
||||
transport_type = server_config.transport_config.get("type", "streamable-http")
|
||||
|
||||
if transport_type == "streamable-http":
|
||||
url = server_config.transport_config["url"]
|
||||
transport = StreamableHttpTransport(url)
|
||||
|
||||
# 设置认证(如果有)
|
||||
if server_config.auth_config:
|
||||
auth_type = server_config.auth_config.get("type")
|
||||
if auth_type == "bearer":
|
||||
from fastmcp.client.auth import BearerAuth
|
||||
|
||||
token = server_config.auth_config.get("token", "")
|
||||
transport._set_auth(BearerAuth(token))
|
||||
|
||||
client = Client(transport, timeout=server_config.timeout)
|
||||
|
||||
elif transport_type == "sse":
|
||||
from fastmcp.client import SSETransport
|
||||
|
||||
url = server_config.transport_config["url"]
|
||||
client = Client(SSETransport(url), timeout=server_config.timeout)
|
||||
|
||||
else:
|
||||
raise ValueError(f"不支持的传输类型: {transport_type}")
|
||||
|
||||
# 进入客户端上下文(建立连接)
|
||||
await client.__aenter__()
|
||||
|
||||
return client
|
||||
|
||||
async def get_all_tools(self) -> dict[str, list[mcp.types.Tool]]:
|
||||
"""
|
||||
获取所有 MCP 服务器提供的工具列表
|
||||
|
||||
Returns:
|
||||
Dict[str, List[mcp.types.Tool]]: 服务器名称 -> 工具列表
|
||||
"""
|
||||
if not self._initialized:
|
||||
await self.initialize()
|
||||
|
||||
all_tools = {}
|
||||
|
||||
for server_name, client in self.clients.items():
|
||||
try:
|
||||
# fastmcp 的 list_tools() 直接返回 List[Tool],不是包含 tools 属性的对象
|
||||
tools = await client.list_tools()
|
||||
all_tools[server_name] = tools
|
||||
logger.debug(f"从服务器 '{server_name}' 获取到 {len(tools)} 个工具")
|
||||
except Exception as e:
|
||||
logger.error(f"从服务器 '{server_name}' 获取工具列表失败: {e}")
|
||||
all_tools[server_name] = []
|
||||
|
||||
return all_tools
|
||||
|
||||
async def call_tool(
|
||||
self, server_name: str, tool_name: str, arguments: dict[str, Any] | None = None
|
||||
) -> Any:
|
||||
"""
|
||||
调用指定 MCP 服务器的工具
|
||||
|
||||
Args:
|
||||
server_name: 服务器名称
|
||||
tool_name: 工具名称
|
||||
arguments: 工具参数
|
||||
|
||||
Returns:
|
||||
Any: 工具执行结果(CallToolResult 的兼容类型)
|
||||
"""
|
||||
if not self._initialized:
|
||||
await self.initialize()
|
||||
|
||||
if server_name not in self.clients:
|
||||
raise ValueError(f"MCP 服务器 '{server_name}' 未连接")
|
||||
|
||||
client = self.clients[server_name]
|
||||
|
||||
try:
|
||||
logger.debug(f"调用 MCP 工具: {server_name}.{tool_name} | 参数: {arguments}")
|
||||
result = await client.call_tool(tool_name, arguments or {})
|
||||
logger.debug(f"MCP 工具调用成功: {server_name}.{tool_name}")
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"MCP 工具调用失败: {server_name}.{tool_name} | 错误: {e}")
|
||||
raise
|
||||
|
||||
async def close(self) -> None:
|
||||
"""关闭所有 MCP 客户端连接"""
|
||||
async with self._lock:
|
||||
if not self._initialized:
|
||||
return
|
||||
|
||||
logger.info("关闭所有 MCP 客户端连接...")
|
||||
|
||||
for server_name, client in self.clients.items():
|
||||
try:
|
||||
await client.__aexit__(None, None, None)
|
||||
logger.debug(f"已关闭 MCP 服务器 '{server_name}' 的连接")
|
||||
except Exception as e:
|
||||
logger.error(f"关闭服务器 '{server_name}' 连接失败: {e}")
|
||||
|
||||
self.clients.clear()
|
||||
self._initialized = False
|
||||
logger.info("所有 MCP 客户端连接已关闭")
|
||||
|
||||
def __repr__(self):
|
||||
return f"<MCPClientManager servers={len(self.servers)} clients={len(self.clients)}>"
|
||||
|
||||
|
||||
# 全局单例
|
||||
mcp_client_manager = MCPClientManager()
|
||||
249
src/plugin_system/core/mcp_tool_adapter.py
Normal file
249
src/plugin_system/core/mcp_tool_adapter.py
Normal file
@@ -0,0 +1,249 @@
|
||||
"""
|
||||
MCP Tool Adapter
|
||||
|
||||
将 MCP 工具适配为 BaseTool,使其能够被插件系统识别和调用
|
||||
"""
|
||||
|
||||
from typing import Any, ClassVar
|
||||
|
||||
import mcp.types
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.plugin_system.base.base_tool import BaseTool
|
||||
from src.plugin_system.base.component_types import ToolParamType
|
||||
|
||||
from .mcp_client_manager import mcp_client_manager
|
||||
|
||||
logger = get_logger("mcp_tool_adapter")
|
||||
|
||||
|
||||
class MCPToolAdapter(BaseTool):
|
||||
"""
|
||||
MCP 工具适配器
|
||||
|
||||
将 MCP 协议的工具适配为 BaseTool,使其能够:
|
||||
1. 被插件系统识别和注册
|
||||
2. 被 LLM 调用
|
||||
3. 参与工具缓存机制
|
||||
"""
|
||||
|
||||
# 类级别默认值,使用 ClassVar 标注
|
||||
available_for_llm: ClassVar[bool] = True
|
||||
|
||||
def __init__(self, server_name: str, mcp_tool: mcp.types.Tool, plugin_config: dict | None = None):
|
||||
"""
|
||||
初始化 MCP 工具适配器
|
||||
|
||||
Args:
|
||||
server_name: MCP 服务器名称
|
||||
mcp_tool: MCP 工具对象
|
||||
plugin_config: 插件配置(可选)
|
||||
"""
|
||||
super().__init__(plugin_config)
|
||||
|
||||
self.server_name = server_name
|
||||
self.mcp_tool = mcp_tool
|
||||
|
||||
# 设置实例属性
|
||||
self.name = f"mcp_{server_name}_{mcp_tool.name}"
|
||||
self.description = mcp_tool.description or f"MCP tool from {server_name}"
|
||||
|
||||
# 转换参数定义
|
||||
self.parameters = self._convert_parameters(mcp_tool.inputSchema)
|
||||
|
||||
logger.debug(f"创建 MCP 工具适配器: {self.name}")
|
||||
|
||||
def _convert_parameters(
|
||||
self, input_schema: dict[str, Any] | None
|
||||
) -> list[tuple[str, ToolParamType, str, bool, list[str] | None]]:
|
||||
"""
|
||||
将 MCP 工具的 JSON Schema 参数转换为 BaseTool 参数格式
|
||||
|
||||
Args:
|
||||
input_schema: MCP 工具的 inputSchema (JSON Schema)
|
||||
|
||||
Returns:
|
||||
List[Tuple]: BaseTool 参数格式列表
|
||||
"""
|
||||
if not input_schema:
|
||||
return []
|
||||
|
||||
parameters = []
|
||||
|
||||
# JSON Schema 通常有 properties 和 required 字段
|
||||
properties = input_schema.get("properties", {})
|
||||
required_fields = input_schema.get("required", [])
|
||||
|
||||
for param_name, param_def in properties.items():
|
||||
# 获取参数类型
|
||||
param_type_str = param_def.get("type", "string")
|
||||
param_type = self._map_json_type_to_tool_param_type(param_type_str)
|
||||
|
||||
# 获取参数描述
|
||||
param_desc = param_def.get("description", f"Parameter {param_name}")
|
||||
|
||||
# 判断是否必填
|
||||
is_required = param_name in required_fields
|
||||
|
||||
# 获取枚举值(如果有)
|
||||
enum_values = param_def.get("enum")
|
||||
|
||||
parameters.append((param_name, param_type, param_desc, is_required, enum_values))
|
||||
|
||||
return parameters
|
||||
|
||||
@staticmethod
|
||||
def _map_json_type_to_tool_param_type(json_type: str) -> ToolParamType:
|
||||
"""
|
||||
将 JSON Schema 类型映射到 ToolParamType
|
||||
|
||||
Args:
|
||||
json_type: JSON Schema 类型字符串
|
||||
|
||||
Returns:
|
||||
ToolParamType: 对应的工具参数类型
|
||||
"""
|
||||
type_mapping = {
|
||||
"string": ToolParamType.STRING,
|
||||
"integer": ToolParamType.INTEGER,
|
||||
"number": ToolParamType.FLOAT,
|
||||
"boolean": ToolParamType.BOOLEAN,
|
||||
}
|
||||
return type_mapping.get(json_type, ToolParamType.STRING)
|
||||
|
||||
async def execute(self, function_args: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
执行 MCP 工具调用
|
||||
|
||||
Args:
|
||||
function_args: 工具调用参数
|
||||
|
||||
Returns:
|
||||
Dict: 工具执行结果
|
||||
"""
|
||||
try:
|
||||
logger.debug(f"执行 MCP 工具: {self.name} | 服务器: {self.server_name} | 参数: {function_args}")
|
||||
|
||||
# 移除 llm_called 标记(这是内部使用的)
|
||||
clean_args = {k: v for k, v in function_args.items() if k != "llm_called"}
|
||||
|
||||
# 调用 MCP 客户端管理器执行工具
|
||||
result = await mcp_client_manager.call_tool(
|
||||
server_name=self.server_name, tool_name=self.mcp_tool.name, arguments=clean_args
|
||||
)
|
||||
|
||||
# 解析结果
|
||||
return self._format_result(result)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"MCP 工具执行失败: {self.name} | 错误: {e}")
|
||||
return {
|
||||
"type": "error",
|
||||
"content": f"MCP 工具调用失败: {e!s}",
|
||||
"id": self.name,
|
||||
}
|
||||
|
||||
def _format_result(self, result: mcp.types.CallToolResult) -> dict[str, Any]:
|
||||
"""
|
||||
格式化 MCP 工具执行结果为标准格式
|
||||
|
||||
Args:
|
||||
result: MCP CallToolResult 对象
|
||||
|
||||
Returns:
|
||||
Dict: 标准化的工具执行结果
|
||||
"""
|
||||
# MCP 结果包含 content 列表
|
||||
if not result.content:
|
||||
return {
|
||||
"type": "mcp_result",
|
||||
"content": "",
|
||||
"id": self.name,
|
||||
}
|
||||
|
||||
# 提取所有内容
|
||||
content_parts = []
|
||||
for content_item in result.content:
|
||||
# 根据内容类型提取文本
|
||||
content_type = getattr(content_item, "type", None)
|
||||
|
||||
if content_type == "text":
|
||||
# TextContent 类型
|
||||
text = getattr(content_item, "text", "")
|
||||
content_parts.append(text)
|
||||
elif content_type == "image":
|
||||
# ImageContent 类型
|
||||
data = getattr(content_item, "data", b"")
|
||||
content_parts.append(f"[Image data: {len(data)} bytes]")
|
||||
elif content_type == "audio":
|
||||
# AudioContent 类型
|
||||
data = getattr(content_item, "data", b"")
|
||||
content_parts.append(f"[Audio data: {len(data)} bytes]")
|
||||
else:
|
||||
# 尝试提取 text 或 data 属性
|
||||
text = getattr(content_item, "text", None)
|
||||
if text is not None:
|
||||
content_parts.append(text)
|
||||
else:
|
||||
data = getattr(content_item, "data", None)
|
||||
if data is not None:
|
||||
data_len = len(data) if hasattr(data, "__len__") else "unknown"
|
||||
content_parts.append(f"[Binary data: {data_len} bytes]")
|
||||
else:
|
||||
content_parts.append(str(content_item))
|
||||
|
||||
return {
|
||||
"type": "mcp_result",
|
||||
"content": "\n".join(content_parts),
|
||||
"id": self.name,
|
||||
"is_error": getattr(result, "isError", False),
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_mcp_tool(cls, server_name: str, mcp_tool: mcp.types.Tool) -> "MCPToolAdapter":
|
||||
"""
|
||||
从 MCP 工具对象创建适配器实例
|
||||
|
||||
Args:
|
||||
server_name: MCP 服务器名称
|
||||
mcp_tool: MCP 工具对象
|
||||
|
||||
Returns:
|
||||
MCPToolAdapter: 工具适配器实例
|
||||
"""
|
||||
return cls(server_name, mcp_tool)
|
||||
|
||||
|
||||
async def load_mcp_tools_as_adapters() -> list[MCPToolAdapter]:
|
||||
"""
|
||||
加载所有 MCP 工具并转换为适配器
|
||||
|
||||
Returns:
|
||||
List[MCPToolAdapter]: 工具适配器列表
|
||||
"""
|
||||
logger.info("开始加载 MCP 工具...")
|
||||
|
||||
# 初始化 MCP 客户端管理器
|
||||
await mcp_client_manager.initialize()
|
||||
|
||||
# 获取所有工具
|
||||
all_tools_dict = await mcp_client_manager.get_all_tools()
|
||||
|
||||
adapters = []
|
||||
total_tools = 0
|
||||
|
||||
for server_name, tools in all_tools_dict.items():
|
||||
logger.debug(f"处理服务器 '{server_name}' 的 {len(tools)} 个工具")
|
||||
total_tools += len(tools)
|
||||
|
||||
for mcp_tool in tools:
|
||||
try:
|
||||
adapter = MCPToolAdapter.from_mcp_tool(server_name, mcp_tool)
|
||||
adapters.append(adapter)
|
||||
logger.debug(f" ✓ 加载工具: {adapter.name}")
|
||||
except Exception as e:
|
||||
logger.error(f" ✗ 创建工具适配器失败: {mcp_tool.name} | 错误: {e}")
|
||||
continue
|
||||
|
||||
logger.info(f"MCP 工具加载完成: 成功 {len(adapters)}/{total_tools} 个")
|
||||
return adapters
|
||||
@@ -145,8 +145,7 @@ class ToolExecutor:
|
||||
pending_step_two = getattr(self, "_pending_step_two_tools", {})
|
||||
if pending_step_two:
|
||||
# 添加第二步工具定义
|
||||
for step_two_def in pending_step_two.values():
|
||||
tool_definitions.append(step_two_def)
|
||||
tool_definitions.extend(list(pending_step_two.values()))
|
||||
|
||||
return tool_definitions
|
||||
|
||||
@@ -286,10 +285,33 @@ class ToolExecutor:
|
||||
logger.info(
|
||||
f"{self.log_prefix} 正在执行工具: [bold green]{function_name}[/bold green] | 参数: {function_args}"
|
||||
)
|
||||
|
||||
|
||||
# 检查是否是MCP工具
|
||||
pass
|
||||
|
||||
from src.plugin_system.core import component_registry
|
||||
|
||||
if component_registry.is_mcp_tool(function_name):
|
||||
logger.debug(f"{self.log_prefix}识别到 MCP 工具: {function_name}")
|
||||
# 找到对应的 MCP 工具实例
|
||||
mcp_tools = component_registry.get_mcp_tools()
|
||||
mcp_tool = next((t for t in mcp_tools if t.name == function_name), None)
|
||||
|
||||
if mcp_tool:
|
||||
logger.debug(f"{self.log_prefix}执行 MCP 工具 {function_name}")
|
||||
result = await mcp_tool.execute(function_args)
|
||||
|
||||
if result:
|
||||
logger.debug(f"{self.log_prefix}MCP 工具 {function_name} 执行成功")
|
||||
return {
|
||||
"tool_call_id": tool_call.call_id,
|
||||
"role": "tool",
|
||||
"name": function_name,
|
||||
"type": "function",
|
||||
"content": result.get("content", ""),
|
||||
}
|
||||
else:
|
||||
logger.warning(f"{self.log_prefix}未找到 MCP 工具: {function_name}")
|
||||
return None
|
||||
|
||||
function_args["llm_called"] = True # 标记为LLM调用
|
||||
|
||||
# 检查是否是二步工具的第二步调用
|
||||
|
||||
@@ -175,10 +175,10 @@ class NoticeHandler:
|
||||
logger.warning("notice处理失败或不支持")
|
||||
return None
|
||||
|
||||
group_info: GroupInfo = None
|
||||
group_info: GroupInfo | None = None
|
||||
if group_id:
|
||||
fetched_group_info = await get_group_info(self.get_server_connection(), group_id)
|
||||
group_name: str = None
|
||||
group_name: str | None = None
|
||||
if fetched_group_info:
|
||||
group_name = fetched_group_info.get("group_name")
|
||||
else:
|
||||
|
||||
@@ -6,4 +6,4 @@ __plugin_meta__ = PluginMetadata(
|
||||
usage="在 bot_config.toml 中将 asr_provider 设置为 'local' 即可启用",
|
||||
version="0.1.0",
|
||||
author="Elysia",
|
||||
)
|
||||
)
|
||||
|
||||
@@ -1,9 +1,4 @@
|
||||
import asyncio
|
||||
import os
|
||||
import tempfile
|
||||
from typing import Any
|
||||
from pathlib import Path
|
||||
import toml
|
||||
|
||||
import whisper
|
||||
|
||||
@@ -40,7 +35,7 @@ class LocalASRTool(BaseTool):
|
||||
model_size = plugin_config.get("whisper", {}).get("model_size", "tiny")
|
||||
device = plugin_config.get("whisper", {}).get("device", "cpu")
|
||||
logger.info(f"正在预加载 Whisper ASR 模型: {model_size} ({device})")
|
||||
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
_whisper_model = await loop.run_in_executor(
|
||||
None, whisper.load_model, model_size, device
|
||||
@@ -61,10 +56,10 @@ class LocalASRTool(BaseTool):
|
||||
# 增强的等待逻辑:只要模型还没准备好,就一直等待后台加载任务完成
|
||||
while _is_loading:
|
||||
await asyncio.sleep(0.2)
|
||||
|
||||
|
||||
if _whisper_model is None:
|
||||
return "Whisper 模型加载失败,无法识别语音。"
|
||||
|
||||
|
||||
try:
|
||||
logger.info(f"开始使用 Whisper 识别音频: {audio_path}")
|
||||
loop = asyncio.get_running_loop()
|
||||
@@ -110,6 +105,6 @@ class STTWhisperPlugin(BasePlugin):
|
||||
), LocalASRTool)]
|
||||
except Exception as e:
|
||||
logger.error(f"检查 ASR provider 配置时出错: {e}")
|
||||
|
||||
|
||||
logger.debug("ASR provider is not 'local', whisper plugin's tool is disabled.")
|
||||
return []
|
||||
|
||||
@@ -6,9 +6,9 @@ from pathlib import Path
|
||||
|
||||
import toml
|
||||
|
||||
from src.chat.utils.self_voice_cache import register_self_voice
|
||||
from src.common.logger import get_logger
|
||||
from src.plugin_system.base.base_action import BaseAction, ChatMode
|
||||
from src.chat.utils.self_voice_cache import register_self_voice
|
||||
|
||||
from ..services.manager import get_service
|
||||
|
||||
|
||||
@@ -15,7 +15,7 @@ __plugin_meta__ = PluginMetadata(
|
||||
"is_built_in": True,
|
||||
},
|
||||
# Python包依赖列表
|
||||
python_dependencies = [ # noqa: RUF012
|
||||
python_dependencies = [
|
||||
PythonDependency(package_name="asyncddgs", description="异步DuckDuckGo搜索库", optional=False),
|
||||
PythonDependency(
|
||||
package_name="exa_py",
|
||||
|
||||
@@ -279,12 +279,6 @@ max_frequency_bonus = 10.0 # 最大激活频率奖励天数
|
||||
# 休眠机制
|
||||
dormant_threshold_days = 90 # 休眠状态判定天数(超过此天数未访问的记忆进入休眠状态)
|
||||
|
||||
# 统一存储配置 (已弃用 - 请使用Vector DB配置)
|
||||
# DEPRECATED: unified_storage_path = "data/unified_memory"
|
||||
# DEPRECATED: unified_storage_cache_limit = 10000
|
||||
# DEPRECATED: unified_storage_auto_save_interval = 50
|
||||
# DEPRECATED: unified_storage_enable_compression = true
|
||||
|
||||
# Vector DB存储配置 (新增 - 替代JSON存储)
|
||||
enable_vector_memory_storage = true # 启用Vector DB存储
|
||||
enable_llm_instant_memory = true # 启用基于LLM的瞬时记忆
|
||||
@@ -336,22 +330,36 @@ qa_ppr_damping = 0.8 # PPR阻尼系数
|
||||
qa_res_top_k = 3 # 最终提供的文段TopK
|
||||
embedding_dimension = 1024 # 嵌入向量维度,应该与模型的输出维度一致
|
||||
|
||||
# keyword_rules 用于设置关键词触发的额外回复知识
|
||||
# 添加新规则方法:在 keyword_rules 数组中增加一项,格式如下:
|
||||
# { keywords = ["关键词1", "关键词2"], reaction = "触发这些关键词时的回复内容" }
|
||||
# 例如,添加一个新规则:当检测到“你好”或“hello”时回复“你好,有什么可以帮你?”
|
||||
# { keywords = ["你好", "hello"], reaction = "你好,有什么可以帮你?" }
|
||||
[keyword_reaction]
|
||||
keyword_rules = [
|
||||
{ keywords = ["人机", "bot", "机器", "入机", "robot", "机器人", "ai", "AI"], reaction = "有人说你是人机或者机器人,否定这一事实,攻击或反问他。注意,只有认为你是机器人才回复,否则不要否认" },
|
||||
{ keywords = ["测试关键词回复", "test"], reaction = "回答测试成功" },
|
||||
#{ keywords = ["你好", "hello"], reaction = "你好,有什么可以帮你?" }
|
||||
# 在此处添加更多规则,格式同上
|
||||
]
|
||||
# --- 反应规则系统 ---
|
||||
# 在这里,您可以定义一系列基于关键词或正则表达式的自动回复规则。
|
||||
# 每条规则都是一个独立的 [[reaction.rules]] 块。
|
||||
|
||||
regex_rules = [
|
||||
{ regex = ["^(?P<n>\\S{1,20})是这样的$"], reaction = "请按照以下模板造句:[n]是这样的,xx只要xx就可以,可是[n]要考虑的事情就很多了,比如什么时候xx,什么时候xx,什么时候xx。(请自由发挥替换xx部分,只需保持句式结构,同时表达一种将[n]过度重视的反讽意味)" }
|
||||
]
|
||||
# chat_stream_id: 聊天流ID (格式 "platform:id:type")。
|
||||
# 用于指定此规则仅在哪个聊天中生效。
|
||||
# 如果留空 (""),则为全局规则,对所有聊天生效。
|
||||
# rule_type: 规则类型,必须是 "keyword" 或 "regex"。
|
||||
# "keyword": 表示本条规则使用关键词匹配。
|
||||
# "regex": 表示本条规则使用正则表达式匹配。
|
||||
# patterns: 一个字符串列表,根据 rule_type 的不同,这里填写关键词或正则表达式。
|
||||
# reaction: 触发规则后,机器人发送的回复内容。
|
||||
|
||||
[[reaction.rules]]
|
||||
chat_stream_id = ""
|
||||
rule_type = "keyword"
|
||||
patterns = ["人机", "bot", "机器", "入机", "robot", "机器人", "ai", "AI"]
|
||||
reaction = "有人说你是人机或者机器人,否定这一事实,攻击或反问他。注意,只有认为你是机器人才回复,否则不要否认"
|
||||
|
||||
[[reaction.rules]]
|
||||
chat_stream_id = ""
|
||||
rule_type = "keyword"
|
||||
patterns = ["测试关键词回复", "test"]
|
||||
reaction = "回答测试成功"
|
||||
|
||||
[[reaction.rules]]
|
||||
chat_stream_id = ""
|
||||
rule_type = "regex"
|
||||
patterns = ["^(?P<n>\\S{1,20})是这样的$"]
|
||||
reaction = "请按照以下模板造句:[n]是这样的,xx只要xx就可以,可是[n]要考虑的事情就很多了,比如什么时候xx,什么时候xx,什么时候xx。(请自由发挥替换xx部分,只需保持句式结构,同时表达一种将[n]过度重视的反讽意味)"
|
||||
|
||||
# 可以自定义部分提示词
|
||||
[custom_prompt]
|
||||
|
||||
Reference in New Issue
Block a user