重构消息处理并用DatabaseMessages替换MessageRecv
-更新PlusCommand以使用DatabaseMessages而不是MessageRecv。 -将消息处理逻辑重构到一个新模块message_processor.py中,以处理消息段并从消息字典中创建DatabaseMessages。 -删除了已弃用的MessageRecv类及其相关逻辑。 -调整了各种插件以适应新的DatabaseMessages结构。 -增强了消息处理功能中的错误处理和日志记录。
This commit is contained in:
@@ -1,303 +0,0 @@
|
|||||||
"""
|
|
||||||
关系追踪工具集成测试脚本
|
|
||||||
|
|
||||||
注意:此脚本需要在完整的应用环境中运行
|
|
||||||
建议通过 bot.py 启动后在交互式环境中测试
|
|
||||||
"""
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
|
|
||||||
|
|
||||||
async def test_user_profile_tool():
|
|
||||||
"""测试用户画像工具"""
|
|
||||||
print("\n" + "=" * 80)
|
|
||||||
print("测试 UserProfileTool")
|
|
||||||
print("=" * 80)
|
|
||||||
|
|
||||||
from src.plugins.built_in.affinity_flow_chatter.user_profile_tool import UserProfileTool
|
|
||||||
from src.common.database.sqlalchemy_database_api import db_query
|
|
||||||
from src.common.database.sqlalchemy_models import UserRelationships
|
|
||||||
|
|
||||||
tool = UserProfileTool()
|
|
||||||
print(f"✅ 工具名称: {tool.name}")
|
|
||||||
print(f" 工具描述: {tool.description}")
|
|
||||||
|
|
||||||
# 执行工具
|
|
||||||
test_user_id = "integration_test_user_001"
|
|
||||||
result = await tool.execute({
|
|
||||||
"target_user_id": test_user_id,
|
|
||||||
"user_aliases": "测试小明,TestMing,小明君",
|
|
||||||
"impression_description": "这是一个集成测试用户,性格开朗活泼,喜欢技术讨论,对AI和编程特别感兴趣。经常提出有深度的问题。",
|
|
||||||
"preference_keywords": "AI,Python,深度学习,游戏开发,科幻小说",
|
|
||||||
"affection_score": 0.85
|
|
||||||
})
|
|
||||||
|
|
||||||
print(f"\n✅ 工具执行结果:")
|
|
||||||
print(f" 类型: {result.get('type')}")
|
|
||||||
print(f" 内容: {result.get('content')}")
|
|
||||||
|
|
||||||
# 验证数据库
|
|
||||||
db_data = await db_query(
|
|
||||||
UserRelationships,
|
|
||||||
filters={"user_id": test_user_id},
|
|
||||||
limit=1
|
|
||||||
)
|
|
||||||
|
|
||||||
if db_data:
|
|
||||||
data = db_data[0]
|
|
||||||
print(f"\n✅ 数据库验证:")
|
|
||||||
print(f" user_id: {data.get('user_id')}")
|
|
||||||
print(f" user_aliases: {data.get('user_aliases')}")
|
|
||||||
print(f" relationship_text: {data.get('relationship_text', '')[:80]}...")
|
|
||||||
print(f" preference_keywords: {data.get('preference_keywords')}")
|
|
||||||
print(f" relationship_score: {data.get('relationship_score')}")
|
|
||||||
return True
|
|
||||||
else:
|
|
||||||
print(f"\n❌ 数据库中未找到数据")
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
async def test_chat_stream_impression_tool():
|
|
||||||
"""测试聊天流印象工具"""
|
|
||||||
print("\n" + "=" * 80)
|
|
||||||
print("测试 ChatStreamImpressionTool")
|
|
||||||
print("=" * 80)
|
|
||||||
|
|
||||||
from src.plugins.built_in.affinity_flow_chatter.chat_stream_impression_tool import ChatStreamImpressionTool
|
|
||||||
from src.common.database.sqlalchemy_database_api import db_query
|
|
||||||
from src.common.database.sqlalchemy_models import ChatStreams, get_db_session
|
|
||||||
|
|
||||||
# 准备测试数据:先创建一条 ChatStreams 记录
|
|
||||||
test_stream_id = "integration_test_stream_001"
|
|
||||||
print(f"🔧 准备测试数据:创建聊天流记录 {test_stream_id}")
|
|
||||||
|
|
||||||
import time
|
|
||||||
current_time = time.time()
|
|
||||||
|
|
||||||
async with get_db_session() as session:
|
|
||||||
new_stream = ChatStreams(
|
|
||||||
stream_id=test_stream_id,
|
|
||||||
create_time=current_time,
|
|
||||||
last_active_time=current_time,
|
|
||||||
platform="QQ",
|
|
||||||
user_platform="QQ",
|
|
||||||
user_id="test_user_123",
|
|
||||||
user_nickname="测试用户",
|
|
||||||
group_name="测试技术交流群",
|
|
||||||
group_platform="QQ",
|
|
||||||
group_id="test_group_456",
|
|
||||||
stream_impression_text="", # 初始为空
|
|
||||||
stream_chat_style="",
|
|
||||||
stream_topic_keywords="",
|
|
||||||
stream_interest_score=0.5
|
|
||||||
)
|
|
||||||
session.add(new_stream)
|
|
||||||
await session.commit()
|
|
||||||
print(f"✅ 测试聊天流记录已创建")
|
|
||||||
|
|
||||||
tool = ChatStreamImpressionTool()
|
|
||||||
print(f"✅ 工具名称: {tool.name}")
|
|
||||||
print(f" 工具描述: {tool.description}")
|
|
||||||
|
|
||||||
# 执行工具
|
|
||||||
result = await tool.execute({
|
|
||||||
"stream_id": test_stream_id,
|
|
||||||
"impression_description": "这是一个技术交流群,成员主要是程序员和AI爱好者。大家经常分享最新的技术文章,讨论编程问题,氛围友好且专业。",
|
|
||||||
"chat_style": "专业技术交流,活跃讨论,互帮互助,知识分享",
|
|
||||||
"topic_keywords": "Python开发,机器学习,AI应用,Web后端,数据分析,开源项目",
|
|
||||||
"interest_score": 0.90
|
|
||||||
})
|
|
||||||
|
|
||||||
print(f"\n✅ 工具执行结果:")
|
|
||||||
print(f" 类型: {result.get('type')}")
|
|
||||||
print(f" 内容: {result.get('content')}")
|
|
||||||
|
|
||||||
# 验证数据库
|
|
||||||
db_data = await db_query(
|
|
||||||
ChatStreams,
|
|
||||||
filters={"stream_id": test_stream_id},
|
|
||||||
limit=1
|
|
||||||
)
|
|
||||||
|
|
||||||
if db_data:
|
|
||||||
data = db_data[0]
|
|
||||||
print(f"\n✅ 数据库验证:")
|
|
||||||
print(f" stream_id: {data.get('stream_id')}")
|
|
||||||
print(f" stream_impression_text: {data.get('stream_impression_text', '')[:80]}...")
|
|
||||||
print(f" stream_chat_style: {data.get('stream_chat_style')}")
|
|
||||||
print(f" stream_topic_keywords: {data.get('stream_topic_keywords')}")
|
|
||||||
print(f" stream_interest_score: {data.get('stream_interest_score')}")
|
|
||||||
return True
|
|
||||||
else:
|
|
||||||
print(f"\n❌ 数据库中未找到数据")
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
async def test_relationship_info_build():
|
|
||||||
"""测试关系信息构建"""
|
|
||||||
print("\n" + "=" * 80)
|
|
||||||
print("测试关系信息构建(提示词集成)")
|
|
||||||
print("=" * 80)
|
|
||||||
|
|
||||||
from src.person_info.relationship_fetcher import relationship_fetcher_manager
|
|
||||||
|
|
||||||
test_stream_id = "integration_test_stream_001"
|
|
||||||
test_person_id = "test_person_999" # 使用一个可能不存在的ID来测试
|
|
||||||
|
|
||||||
fetcher = relationship_fetcher_manager.get_fetcher(test_stream_id)
|
|
||||||
print(f"✅ RelationshipFetcher 已创建")
|
|
||||||
|
|
||||||
# 测试聊天流印象构建
|
|
||||||
print(f"\n🔍 构建聊天流印象...")
|
|
||||||
stream_info = await fetcher.build_chat_stream_impression(test_stream_id)
|
|
||||||
|
|
||||||
if stream_info:
|
|
||||||
print(f"✅ 聊天流印象构建成功")
|
|
||||||
print(f"\n{'=' * 80}")
|
|
||||||
print(stream_info)
|
|
||||||
print(f"{'=' * 80}")
|
|
||||||
else:
|
|
||||||
print(f"⚠️ 聊天流印象为空(可能测试数据不存在)")
|
|
||||||
|
|
||||||
return True
|
|
||||||
|
|
||||||
|
|
||||||
async def cleanup_test_data():
|
|
||||||
"""清理测试数据"""
|
|
||||||
print("\n" + "=" * 80)
|
|
||||||
print("清理测试数据")
|
|
||||||
print("=" * 80)
|
|
||||||
|
|
||||||
from src.common.database.sqlalchemy_database_api import db_query
|
|
||||||
from src.common.database.sqlalchemy_models import UserRelationships, ChatStreams
|
|
||||||
|
|
||||||
try:
|
|
||||||
# 清理用户数据
|
|
||||||
await db_query(
|
|
||||||
UserRelationships,
|
|
||||||
query_type="delete",
|
|
||||||
filters={"user_id": "integration_test_user_001"}
|
|
||||||
)
|
|
||||||
print("✅ 用户测试数据已清理")
|
|
||||||
|
|
||||||
# 清理聊天流数据
|
|
||||||
await db_query(
|
|
||||||
ChatStreams,
|
|
||||||
query_type="delete",
|
|
||||||
filters={"stream_id": "integration_test_stream_001"}
|
|
||||||
)
|
|
||||||
print("✅ 聊天流测试数据已清理")
|
|
||||||
|
|
||||||
return True
|
|
||||||
except Exception as e:
|
|
||||||
print(f"⚠️ 清理失败: {e}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
async def run_all_tests():
|
|
||||||
"""运行所有测试"""
|
|
||||||
print("\n" + "=" * 80)
|
|
||||||
print("关系追踪工具集成测试")
|
|
||||||
print("=" * 80)
|
|
||||||
|
|
||||||
results = {}
|
|
||||||
|
|
||||||
# 测试1
|
|
||||||
try:
|
|
||||||
results["UserProfileTool"] = await test_user_profile_tool()
|
|
||||||
except Exception as e:
|
|
||||||
print(f"\n❌ UserProfileTool 测试失败: {e}")
|
|
||||||
import traceback
|
|
||||||
traceback.print_exc()
|
|
||||||
results["UserProfileTool"] = False
|
|
||||||
|
|
||||||
# 测试2
|
|
||||||
try:
|
|
||||||
results["ChatStreamImpressionTool"] = await test_chat_stream_impression_tool()
|
|
||||||
except Exception as e:
|
|
||||||
print(f"\n❌ ChatStreamImpressionTool 测试失败: {e}")
|
|
||||||
import traceback
|
|
||||||
traceback.print_exc()
|
|
||||||
results["ChatStreamImpressionTool"] = False
|
|
||||||
|
|
||||||
# 测试3
|
|
||||||
try:
|
|
||||||
results["RelationshipFetcher"] = await test_relationship_info_build()
|
|
||||||
except Exception as e:
|
|
||||||
print(f"\n❌ RelationshipFetcher 测试失败: {e}")
|
|
||||||
import traceback
|
|
||||||
traceback.print_exc()
|
|
||||||
results["RelationshipFetcher"] = False
|
|
||||||
|
|
||||||
# 清理
|
|
||||||
try:
|
|
||||||
await cleanup_test_data()
|
|
||||||
except Exception as e:
|
|
||||||
print(f"\n⚠️ 清理测试数据失败: {e}")
|
|
||||||
|
|
||||||
# 总结
|
|
||||||
print("\n" + "=" * 80)
|
|
||||||
print("测试总结")
|
|
||||||
print("=" * 80)
|
|
||||||
|
|
||||||
passed = sum(1 for r in results.values() if r)
|
|
||||||
total = len(results)
|
|
||||||
|
|
||||||
for test_name, result in results.items():
|
|
||||||
status = "✅ 通过" if result else "❌ 失败"
|
|
||||||
print(f"{status} - {test_name}")
|
|
||||||
|
|
||||||
print(f"\n总计: {passed}/{total} 测试通过")
|
|
||||||
|
|
||||||
if passed == total:
|
|
||||||
print("\n🎉 所有测试通过!")
|
|
||||||
else:
|
|
||||||
print(f"\n⚠️ {total - passed} 个测试失败")
|
|
||||||
|
|
||||||
return passed == total
|
|
||||||
|
|
||||||
|
|
||||||
# 使用说明
|
|
||||||
print("""
|
|
||||||
============================================================================
|
|
||||||
关系追踪工具集成测试脚本
|
|
||||||
============================================================================
|
|
||||||
|
|
||||||
此脚本需要在完整的应用环境中运行。
|
|
||||||
|
|
||||||
使用方法1: 在 bot.py 中添加测试调用
|
|
||||||
-----------------------------------
|
|
||||||
在 bot.py 的 main() 函数中添加:
|
|
||||||
|
|
||||||
# 测试关系追踪工具
|
|
||||||
from tests.integration_test_relationship_tools import run_all_tests
|
|
||||||
await run_all_tests()
|
|
||||||
|
|
||||||
使用方法2: 在 Python REPL 中运行
|
|
||||||
-----------------------------------
|
|
||||||
启动 bot.py 后,在 Python 调试控制台中执行:
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
from tests.integration_test_relationship_tools import run_all_tests
|
|
||||||
asyncio.create_task(run_all_tests())
|
|
||||||
|
|
||||||
使用方法3: 直接在此文件底部运行
|
|
||||||
-----------------------------------
|
|
||||||
取消注释下面的代码,然后确保已启动应用环境
|
|
||||||
============================================================================
|
|
||||||
""")
|
|
||||||
|
|
||||||
|
|
||||||
# 如果需要直接运行(需要应用环境已启动)
|
|
||||||
if __name__ == "__main__":
|
|
||||||
print("\n⚠️ 警告: 直接运行此脚本可能会失败,因为缺少应用环境")
|
|
||||||
print("建议在 bot.py 启动后的环境中运行\n")
|
|
||||||
|
|
||||||
try:
|
|
||||||
asyncio.run(run_all_tests())
|
|
||||||
except Exception as e:
|
|
||||||
print(f"\n❌ 测试失败: {e}")
|
|
||||||
print("\n建议:")
|
|
||||||
print("1. 确保已启动 bot.py")
|
|
||||||
print("2. 在 Python 调试控制台中运行测试")
|
|
||||||
print("3. 或在 bot.py 中添加测试调用")
|
|
||||||
@@ -27,6 +27,6 @@
|
|||||||
"venvPath": ".",
|
"venvPath": ".",
|
||||||
"venv": ".venv",
|
"venv": ".venv",
|
||||||
"executionEnvironments": [
|
"executionEnvironments": [
|
||||||
{"root": "src"}
|
{"root": "."}
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -6,7 +6,7 @@
|
|||||||
|
|
||||||
import re
|
import re
|
||||||
|
|
||||||
from src.chat.message_receive.message import MessageRecv
|
from src.common.data_models.database_data_model import DatabaseMessages
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
|
|
||||||
logger = get_logger("anti_injector.message_processor")
|
logger = get_logger("anti_injector.message_processor")
|
||||||
@@ -15,7 +15,7 @@ logger = get_logger("anti_injector.message_processor")
|
|||||||
class MessageProcessor:
|
class MessageProcessor:
|
||||||
"""消息内容处理器"""
|
"""消息内容处理器"""
|
||||||
|
|
||||||
def extract_text_content(self, message: MessageRecv) -> str:
|
def extract_text_content(self, message: DatabaseMessages) -> str:
|
||||||
"""提取消息中的文本内容,过滤掉引用的历史内容
|
"""提取消息中的文本内容,过滤掉引用的历史内容
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -64,7 +64,7 @@ class MessageProcessor:
|
|||||||
return new_content
|
return new_content
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def check_whitelist(message: MessageRecv, whitelist: list) -> tuple | None:
|
def check_whitelist(message: DatabaseMessages, whitelist: list) -> tuple | None:
|
||||||
"""检查用户白名单
|
"""检查用户白名单
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -74,8 +74,8 @@ class MessageProcessor:
|
|||||||
Returns:
|
Returns:
|
||||||
如果在白名单中返回结果元组,否则返回None
|
如果在白名单中返回结果元组,否则返回None
|
||||||
"""
|
"""
|
||||||
user_id = message.message_info.user_info.user_id
|
user_id = message.user_info.user_id
|
||||||
platform = message.message_info.platform
|
platform = message.chat_info.platform
|
||||||
|
|
||||||
# 检查用户白名单:格式为 [[platform, user_id], ...]
|
# 检查用户白名单:格式为 [[platform, user_id], ...]
|
||||||
for whitelist_entry in whitelist:
|
for whitelist_entry in whitelist:
|
||||||
|
|||||||
@@ -29,7 +29,6 @@ class SingleStreamContextManager:
|
|||||||
|
|
||||||
# 配置参数
|
# 配置参数
|
||||||
self.max_context_size = max_context_size or getattr(global_config.chat, "max_context_size", 100)
|
self.max_context_size = max_context_size or getattr(global_config.chat, "max_context_size", 100)
|
||||||
self.context_ttl = getattr(global_config.chat, "context_ttl", 24 * 3600) # 24小时
|
|
||||||
|
|
||||||
# 元数据
|
# 元数据
|
||||||
self.created_time = time.time()
|
self.created_time = time.time()
|
||||||
@@ -93,27 +92,24 @@ class SingleStreamContextManager:
|
|||||||
return True
|
return True
|
||||||
else:
|
else:
|
||||||
logger.warning(f"消息缓存系统添加失败,回退到直接添加: {self.stream_id}")
|
logger.warning(f"消息缓存系统添加失败,回退到直接添加: {self.stream_id}")
|
||||||
|
|
||||||
except ImportError:
|
|
||||||
logger.debug("MessageManager不可用,使用直接添加模式")
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"消息缓存系统异常,回退到直接添加: {self.stream_id}, error={e}")
|
logger.warning(f"消息缓存系统异常,回退到直接添加: {self.stream_id}, error={e}")
|
||||||
|
|
||||||
# 回退方案:直接添加到未读消息
|
# 回退方案:直接添加到未读消息
|
||||||
message.is_read = False
|
message.is_read = False
|
||||||
self.context.unread_messages.append(message)
|
self.context.unread_messages.append(message)
|
||||||
|
|
||||||
# 自动检测和更新chat type
|
# 自动检测和更新chat type
|
||||||
self._detect_chat_type(message)
|
self._detect_chat_type(message)
|
||||||
|
|
||||||
# 在上下文管理器中计算兴趣值
|
# 在上下文管理器中计算兴趣值
|
||||||
await self._calculate_message_interest(message)
|
await self._calculate_message_interest(message)
|
||||||
self.total_messages += 1
|
self.total_messages += 1
|
||||||
self.last_access_time = time.time()
|
self.last_access_time = time.time()
|
||||||
# 启动流的循环任务(如果还未启动)
|
# 启动流的循环任务(如果还未启动)
|
||||||
asyncio.create_task(stream_loop_manager.start_stream_loop(self.stream_id))
|
asyncio.create_task(stream_loop_manager.start_stream_loop(self.stream_id))
|
||||||
logger.debug(f"添加消息{message.processed_plain_text}到单流上下文: {self.stream_id}")
|
logger.debug(f"添加消息{message.processed_plain_text}到单流上下文: {self.stream_id}")
|
||||||
return True
|
return True
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"添加消息到单流上下文失败 {self.stream_id}: {e}", exc_info=True)
|
logger.error(f"添加消息到单流上下文失败 {self.stream_id}: {e}", exc_info=True)
|
||||||
return False
|
return False
|
||||||
|
|||||||
@@ -71,14 +71,6 @@ class MessageManager:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"启动批量数据库写入器失败: {e}")
|
logger.error(f"启动批量数据库写入器失败: {e}")
|
||||||
|
|
||||||
# 启动流缓存管理器
|
|
||||||
try:
|
|
||||||
from src.chat.message_manager.stream_cache_manager import init_stream_cache_manager
|
|
||||||
|
|
||||||
await init_stream_cache_manager()
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"启动流缓存管理器失败: {e}")
|
|
||||||
|
|
||||||
# 启动消息缓存系统(内置)
|
# 启动消息缓存系统(内置)
|
||||||
logger.info("📦 消息缓存系统已启动")
|
logger.info("📦 消息缓存系统已启动")
|
||||||
|
|
||||||
@@ -116,15 +108,6 @@ class MessageManager:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"停止批量数据库写入器失败: {e}")
|
logger.error(f"停止批量数据库写入器失败: {e}")
|
||||||
|
|
||||||
# 停止流缓存管理器
|
|
||||||
try:
|
|
||||||
from src.chat.message_manager.stream_cache_manager import shutdown_stream_cache_manager
|
|
||||||
|
|
||||||
await shutdown_stream_cache_manager()
|
|
||||||
logger.info("🗄️ 流缓存管理器已停止")
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"停止流缓存管理器失败: {e}")
|
|
||||||
|
|
||||||
# 停止消息缓存系统(内置)
|
# 停止消息缓存系统(内置)
|
||||||
self.message_caches.clear()
|
self.message_caches.clear()
|
||||||
self.stream_processing_status.clear()
|
self.stream_processing_status.clear()
|
||||||
@@ -152,7 +135,7 @@ class MessageManager:
|
|||||||
# 检查是否为notice消息
|
# 检查是否为notice消息
|
||||||
if self._is_notice_message(message):
|
if self._is_notice_message(message):
|
||||||
# Notice消息处理 - 添加到全局管理器
|
# Notice消息处理 - 添加到全局管理器
|
||||||
logger.info(f"📢 检测到notice消息: message_id={message.message_id}, is_notify={message.is_notify}, notice_type={getattr(message, 'notice_type', None)}")
|
logger.info(f"📢 检测到notice消息: notice_type={getattr(message, 'notice_type', None)}")
|
||||||
await self._handle_notice_message(stream_id, message)
|
await self._handle_notice_message(stream_id, message)
|
||||||
|
|
||||||
# 根据配置决定是否继续处理(触发聊天流程)
|
# 根据配置决定是否继续处理(触发聊天流程)
|
||||||
@@ -206,39 +189,6 @@ class MessageManager:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"更新消息 {message_id} 时发生错误: {e}")
|
logger.error(f"更新消息 {message_id} 时发生错误: {e}")
|
||||||
|
|
||||||
async def bulk_update_messages(self, stream_id: str, updates: list[dict[str, Any]]) -> int:
|
|
||||||
"""批量更新消息信息,降低更新频率"""
|
|
||||||
if not updates:
|
|
||||||
return 0
|
|
||||||
|
|
||||||
try:
|
|
||||||
chat_manager = get_chat_manager()
|
|
||||||
chat_stream = await chat_manager.get_stream(stream_id)
|
|
||||||
if not chat_stream:
|
|
||||||
logger.warning(f"MessageManager.bulk_update_messages: 聊天流 {stream_id} 不存在")
|
|
||||||
return 0
|
|
||||||
|
|
||||||
updated_count = 0
|
|
||||||
for item in updates:
|
|
||||||
message_id = item.get("message_id")
|
|
||||||
if not message_id:
|
|
||||||
continue
|
|
||||||
|
|
||||||
payload = {key: value for key, value in item.items() if key != "message_id" and value is not None}
|
|
||||||
|
|
||||||
if not payload:
|
|
||||||
continue
|
|
||||||
|
|
||||||
success = await chat_stream.context_manager.update_message(message_id, payload)
|
|
||||||
if success:
|
|
||||||
updated_count += 1
|
|
||||||
|
|
||||||
if updated_count:
|
|
||||||
logger.debug(f"批量更新消息 {updated_count} 条 (stream={stream_id})")
|
|
||||||
return updated_count
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"批量更新聊天流 {stream_id} 消息失败: {e}")
|
|
||||||
return 0
|
|
||||||
|
|
||||||
async def add_action(self, stream_id: str, message_id: str, action: str):
|
async def add_action(self, stream_id: str, message_id: str, action: str):
|
||||||
"""添加动作到消息"""
|
"""添加动作到消息"""
|
||||||
@@ -266,7 +216,7 @@ class MessageManager:
|
|||||||
logger.warning(f"停用流失败: 聊天流 {stream_id} 不存在")
|
logger.warning(f"停用流失败: 聊天流 {stream_id} 不存在")
|
||||||
return
|
return
|
||||||
|
|
||||||
context = chat_stream.stream_context
|
context = chat_stream.context_manager.context
|
||||||
context.is_active = False
|
context.is_active = False
|
||||||
|
|
||||||
# 取消处理任务
|
# 取消处理任务
|
||||||
@@ -288,7 +238,7 @@ class MessageManager:
|
|||||||
logger.warning(f"激活流失败: 聊天流 {stream_id} 不存在")
|
logger.warning(f"激活流失败: 聊天流 {stream_id} 不存在")
|
||||||
return
|
return
|
||||||
|
|
||||||
context = chat_stream.stream_context
|
context = chat_stream.context_manager.context
|
||||||
context.is_active = True
|
context.is_active = True
|
||||||
logger.info(f"激活聊天流: {stream_id}")
|
logger.info(f"激活聊天流: {stream_id}")
|
||||||
|
|
||||||
@@ -304,7 +254,7 @@ class MessageManager:
|
|||||||
if not chat_stream:
|
if not chat_stream:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
context = chat_stream.stream_context
|
context = chat_stream.context_manager.context
|
||||||
unread_count = len(chat_stream.context_manager.get_unread_messages())
|
unread_count = len(chat_stream.context_manager.get_unread_messages())
|
||||||
|
|
||||||
return StreamStats(
|
return StreamStats(
|
||||||
@@ -447,7 +397,7 @@ class MessageManager:
|
|||||||
await asyncio.sleep(0.1)
|
await asyncio.sleep(0.1)
|
||||||
|
|
||||||
# 获取当前的stream context
|
# 获取当前的stream context
|
||||||
context = chat_stream.stream_context
|
context = chat_stream.context_manager.context
|
||||||
|
|
||||||
# 确保有未读消息需要处理
|
# 确保有未读消息需要处理
|
||||||
unread_messages = context.get_unread_messages()
|
unread_messages = context.get_unread_messages()
|
||||||
|
|||||||
@@ -1,377 +0,0 @@
|
|||||||
"""
|
|
||||||
流缓存管理器 - 使用优化版聊天流和智能缓存策略
|
|
||||||
提供分层缓存和自动清理功能
|
|
||||||
"""
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import time
|
|
||||||
from collections import OrderedDict
|
|
||||||
from dataclasses import dataclass
|
|
||||||
|
|
||||||
from maim_message import GroupInfo, UserInfo
|
|
||||||
|
|
||||||
from src.chat.message_receive.optimized_chat_stream import OptimizedChatStream, create_optimized_chat_stream
|
|
||||||
from src.common.logger import get_logger
|
|
||||||
|
|
||||||
logger = get_logger("stream_cache_manager")
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class StreamCacheStats:
|
|
||||||
"""缓存统计信息"""
|
|
||||||
|
|
||||||
hot_cache_size: int = 0
|
|
||||||
warm_storage_size: int = 0
|
|
||||||
cold_storage_size: int = 0
|
|
||||||
total_memory_usage: int = 0 # 估算的内存使用(字节)
|
|
||||||
cache_hits: int = 0
|
|
||||||
cache_misses: int = 0
|
|
||||||
evictions: int = 0
|
|
||||||
last_cleanup_time: float = 0
|
|
||||||
|
|
||||||
|
|
||||||
class TieredStreamCache:
|
|
||||||
"""分层流缓存管理器"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
max_hot_size: int = 100,
|
|
||||||
max_warm_size: int = 500,
|
|
||||||
max_cold_size: int = 2000,
|
|
||||||
cleanup_interval: float = 300.0, # 5分钟清理一次
|
|
||||||
hot_timeout: float = 1800.0, # 30分钟未访问降级到warm
|
|
||||||
warm_timeout: float = 7200.0, # 2小时未访问降级到cold
|
|
||||||
cold_timeout: float = 86400.0, # 24小时未访问删除
|
|
||||||
):
|
|
||||||
self.max_hot_size = max_hot_size
|
|
||||||
self.max_warm_size = max_warm_size
|
|
||||||
self.max_cold_size = max_cold_size
|
|
||||||
self.cleanup_interval = cleanup_interval
|
|
||||||
self.hot_timeout = hot_timeout
|
|
||||||
self.warm_timeout = warm_timeout
|
|
||||||
self.cold_timeout = cold_timeout
|
|
||||||
|
|
||||||
# 三层缓存存储
|
|
||||||
self.hot_cache: OrderedDict[str, OptimizedChatStream] = OrderedDict() # 热数据(LRU)
|
|
||||||
self.warm_storage: dict[str, tuple[OptimizedChatStream, float]] = {} # 温数据(最后访问时间)
|
|
||||||
self.cold_storage: dict[str, tuple[OptimizedChatStream, float]] = {} # 冷数据(最后访问时间)
|
|
||||||
|
|
||||||
# 统计信息
|
|
||||||
self.stats = StreamCacheStats()
|
|
||||||
|
|
||||||
# 清理任务
|
|
||||||
self.cleanup_task: asyncio.Task | None = None
|
|
||||||
self.is_running = False
|
|
||||||
|
|
||||||
logger.info(f"分层流缓存管理器初始化完成 (hot:{max_hot_size}, warm:{max_warm_size}, cold:{max_cold_size})")
|
|
||||||
|
|
||||||
async def start(self):
|
|
||||||
"""启动缓存管理器"""
|
|
||||||
if self.is_running:
|
|
||||||
logger.warning("缓存管理器已经在运行")
|
|
||||||
return
|
|
||||||
|
|
||||||
self.is_running = True
|
|
||||||
self.cleanup_task = asyncio.create_task(self._cleanup_loop(), name="stream_cache_cleanup")
|
|
||||||
|
|
||||||
async def stop(self):
|
|
||||||
"""停止缓存管理器"""
|
|
||||||
if not self.is_running:
|
|
||||||
return
|
|
||||||
|
|
||||||
self.is_running = False
|
|
||||||
|
|
||||||
if self.cleanup_task and not self.cleanup_task.done():
|
|
||||||
self.cleanup_task.cancel()
|
|
||||||
try:
|
|
||||||
await asyncio.wait_for(self.cleanup_task, timeout=10.0)
|
|
||||||
except asyncio.TimeoutError:
|
|
||||||
logger.warning("缓存清理任务停止超时")
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"停止缓存清理任务时出错: {e}")
|
|
||||||
|
|
||||||
logger.info("分层流缓存管理器已停止")
|
|
||||||
|
|
||||||
async def get_or_create_stream(
|
|
||||||
self,
|
|
||||||
stream_id: str,
|
|
||||||
platform: str,
|
|
||||||
user_info: UserInfo,
|
|
||||||
group_info: GroupInfo | None = None,
|
|
||||||
data: dict | None = None,
|
|
||||||
) -> OptimizedChatStream:
|
|
||||||
"""获取或创建流 - 优化版本"""
|
|
||||||
current_time = time.time()
|
|
||||||
|
|
||||||
# 1. 检查热缓存
|
|
||||||
if stream_id in self.hot_cache:
|
|
||||||
stream = self.hot_cache[stream_id]
|
|
||||||
# 移动到末尾(LRU更新)
|
|
||||||
self.hot_cache.move_to_end(stream_id)
|
|
||||||
self.stats.cache_hits += 1
|
|
||||||
logger.debug(f"热缓存命中: {stream_id}")
|
|
||||||
return stream.create_snapshot()
|
|
||||||
|
|
||||||
# 2. 检查温存储
|
|
||||||
if stream_id in self.warm_storage:
|
|
||||||
stream, last_access = self.warm_storage[stream_id]
|
|
||||||
self.warm_storage[stream_id] = (stream, current_time)
|
|
||||||
self.stats.cache_hits += 1
|
|
||||||
logger.debug(f"温缓存命中: {stream_id}")
|
|
||||||
# 提升到热缓存
|
|
||||||
await self._promote_to_hot(stream_id, stream)
|
|
||||||
return stream.create_snapshot()
|
|
||||||
|
|
||||||
# 3. 检查冷存储
|
|
||||||
if stream_id in self.cold_storage:
|
|
||||||
stream, last_access = self.cold_storage[stream_id]
|
|
||||||
self.cold_storage[stream_id] = (stream, current_time)
|
|
||||||
self.stats.cache_hits += 1
|
|
||||||
logger.debug(f"冷缓存命中: {stream_id}")
|
|
||||||
# 提升到温缓存
|
|
||||||
await self._promote_to_warm(stream_id, stream)
|
|
||||||
return stream.create_snapshot()
|
|
||||||
|
|
||||||
# 4. 缓存未命中,创建新流
|
|
||||||
self.stats.cache_misses += 1
|
|
||||||
stream = create_optimized_chat_stream(
|
|
||||||
stream_id=stream_id, platform=platform, user_info=user_info, group_info=group_info, data=data
|
|
||||||
)
|
|
||||||
logger.debug(f"缓存未命中,创建新流: {stream_id}")
|
|
||||||
|
|
||||||
# 添加到热缓存
|
|
||||||
await self._add_to_hot(stream_id, stream)
|
|
||||||
|
|
||||||
return stream
|
|
||||||
|
|
||||||
async def _add_to_hot(self, stream_id: str, stream: OptimizedChatStream):
|
|
||||||
"""添加到热缓存"""
|
|
||||||
# 检查是否需要驱逐
|
|
||||||
if len(self.hot_cache) >= self.max_hot_size:
|
|
||||||
await self._evict_from_hot()
|
|
||||||
|
|
||||||
self.hot_cache[stream_id] = stream
|
|
||||||
self.stats.hot_cache_size = len(self.hot_cache)
|
|
||||||
|
|
||||||
async def _promote_to_hot(self, stream_id: str, stream: OptimizedChatStream):
|
|
||||||
"""提升到热缓存"""
|
|
||||||
# 从温存储中移除
|
|
||||||
if stream_id in self.warm_storage:
|
|
||||||
del self.warm_storage[stream_id]
|
|
||||||
self.stats.warm_storage_size = len(self.warm_storage)
|
|
||||||
|
|
||||||
# 添加到热缓存
|
|
||||||
await self._add_to_hot(stream_id, stream)
|
|
||||||
logger.debug(f"流 {stream_id} 提升到热缓存")
|
|
||||||
|
|
||||||
async def _promote_to_warm(self, stream_id: str, stream: OptimizedChatStream):
|
|
||||||
"""提升到温缓存"""
|
|
||||||
# 从冷存储中移除
|
|
||||||
if stream_id in self.cold_storage:
|
|
||||||
del self.cold_storage[stream_id]
|
|
||||||
self.stats.cold_storage_size = len(self.cold_storage)
|
|
||||||
|
|
||||||
# 添加到温存储
|
|
||||||
if len(self.warm_storage) >= self.max_warm_size:
|
|
||||||
await self._evict_from_warm()
|
|
||||||
|
|
||||||
current_time = time.time()
|
|
||||||
self.warm_storage[stream_id] = (stream, current_time)
|
|
||||||
self.stats.warm_storage_size = len(self.warm_storage)
|
|
||||||
logger.debug(f"流 {stream_id} 提升到温缓存")
|
|
||||||
|
|
||||||
async def _evict_from_hot(self):
|
|
||||||
"""从热缓存驱逐最久未使用的流"""
|
|
||||||
if not self.hot_cache:
|
|
||||||
return
|
|
||||||
|
|
||||||
# LRU驱逐
|
|
||||||
stream_id, stream = self.hot_cache.popitem(last=False)
|
|
||||||
self.stats.evictions += 1
|
|
||||||
logger.debug(f"从热缓存驱逐: {stream_id}")
|
|
||||||
|
|
||||||
# 移动到温存储
|
|
||||||
if len(self.warm_storage) < self.max_warm_size:
|
|
||||||
current_time = time.time()
|
|
||||||
self.warm_storage[stream_id] = (stream, current_time)
|
|
||||||
self.stats.warm_storage_size = len(self.warm_storage)
|
|
||||||
else:
|
|
||||||
# 温存储也满了,直接删除
|
|
||||||
logger.debug(f"温存储已满,删除流: {stream_id}")
|
|
||||||
|
|
||||||
self.stats.hot_cache_size = len(self.hot_cache)
|
|
||||||
|
|
||||||
async def _evict_from_warm(self):
|
|
||||||
"""从温存储驱逐最久未使用的流"""
|
|
||||||
if not self.warm_storage:
|
|
||||||
return
|
|
||||||
|
|
||||||
# 找到最久未访问的流
|
|
||||||
oldest_stream_id = min(self.warm_storage.keys(), key=lambda k: self.warm_storage[k][1])
|
|
||||||
stream, last_access = self.warm_storage.pop(oldest_stream_id)
|
|
||||||
self.stats.evictions += 1
|
|
||||||
logger.debug(f"从温存储驱逐: {oldest_stream_id}")
|
|
||||||
|
|
||||||
# 移动到冷存储
|
|
||||||
if len(self.cold_storage) < self.max_cold_size:
|
|
||||||
current_time = time.time()
|
|
||||||
self.cold_storage[oldest_stream_id] = (stream, current_time)
|
|
||||||
self.stats.cold_storage_size = len(self.cold_storage)
|
|
||||||
else:
|
|
||||||
# 冷存储也满了,直接删除
|
|
||||||
logger.debug(f"冷存储已满,删除流: {oldest_stream_id}")
|
|
||||||
|
|
||||||
self.stats.warm_storage_size = len(self.warm_storage)
|
|
||||||
|
|
||||||
async def _cleanup_loop(self):
|
|
||||||
"""清理循环"""
|
|
||||||
logger.info("流缓存清理循环启动")
|
|
||||||
|
|
||||||
while self.is_running:
|
|
||||||
try:
|
|
||||||
await asyncio.sleep(self.cleanup_interval)
|
|
||||||
await self._perform_cleanup()
|
|
||||||
except asyncio.CancelledError:
|
|
||||||
logger.info("流缓存清理循环被取消")
|
|
||||||
break
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"流缓存清理出错: {e}")
|
|
||||||
|
|
||||||
logger.info("流缓存清理循环结束")
|
|
||||||
|
|
||||||
async def _perform_cleanup(self):
|
|
||||||
"""执行清理操作"""
|
|
||||||
current_time = time.time()
|
|
||||||
cleanup_stats = {
|
|
||||||
"hot_to_warm": 0,
|
|
||||||
"warm_to_cold": 0,
|
|
||||||
"cold_removed": 0,
|
|
||||||
}
|
|
||||||
|
|
||||||
# 1. 检查热缓存超时
|
|
||||||
hot_to_demote = []
|
|
||||||
for stream_id, stream in self.hot_cache.items():
|
|
||||||
# 获取最后访问时间(简化:使用创建时间作为近似)
|
|
||||||
last_access = getattr(stream, "last_active_time", stream.create_time)
|
|
||||||
if current_time - last_access > self.hot_timeout:
|
|
||||||
hot_to_demote.append(stream_id)
|
|
||||||
|
|
||||||
for stream_id in hot_to_demote:
|
|
||||||
stream = self.hot_cache.pop(stream_id)
|
|
||||||
current_time_local = time.time()
|
|
||||||
self.warm_storage[stream_id] = (stream, current_time_local)
|
|
||||||
cleanup_stats["hot_to_warm"] += 1
|
|
||||||
|
|
||||||
# 2. 检查温存储超时
|
|
||||||
warm_to_demote = []
|
|
||||||
for stream_id, (stream, last_access) in self.warm_storage.items():
|
|
||||||
if current_time - last_access > self.warm_timeout:
|
|
||||||
warm_to_demote.append(stream_id)
|
|
||||||
|
|
||||||
for stream_id in warm_to_demote:
|
|
||||||
stream, last_access = self.warm_storage.pop(stream_id)
|
|
||||||
self.cold_storage[stream_id] = (stream, last_access)
|
|
||||||
cleanup_stats["warm_to_cold"] += 1
|
|
||||||
|
|
||||||
# 3. 检查冷存储超时
|
|
||||||
cold_to_remove = []
|
|
||||||
for stream_id, (stream, last_access) in self.cold_storage.items():
|
|
||||||
if current_time - last_access > self.cold_timeout:
|
|
||||||
cold_to_remove.append(stream_id)
|
|
||||||
|
|
||||||
for stream_id in cold_to_remove:
|
|
||||||
self.cold_storage.pop(stream_id)
|
|
||||||
cleanup_stats["cold_removed"] += 1
|
|
||||||
|
|
||||||
# 更新统计信息
|
|
||||||
self.stats.hot_cache_size = len(self.hot_cache)
|
|
||||||
self.stats.warm_storage_size = len(self.warm_storage)
|
|
||||||
self.stats.cold_storage_size = len(self.cold_storage)
|
|
||||||
self.stats.last_cleanup_time = current_time
|
|
||||||
|
|
||||||
# 估算内存使用(粗略估计)
|
|
||||||
self.stats.total_memory_usage = (
|
|
||||||
len(self.hot_cache) * 1024 # 每个热流约1KB
|
|
||||||
+ len(self.warm_storage) * 512 # 每个温流约512B
|
|
||||||
+ len(self.cold_storage) * 256 # 每个冷流约256B
|
|
||||||
)
|
|
||||||
|
|
||||||
if sum(cleanup_stats.values()) > 0:
|
|
||||||
logger.info(
|
|
||||||
f"缓存清理完成: {cleanup_stats['hot_to_warm']}热→温, "
|
|
||||||
f"{cleanup_stats['warm_to_cold']}温→冷, "
|
|
||||||
f"{cleanup_stats['cold_removed']}冷删除"
|
|
||||||
)
|
|
||||||
|
|
||||||
def get_stats(self) -> StreamCacheStats:
|
|
||||||
"""获取缓存统计信息"""
|
|
||||||
# 计算命中率
|
|
||||||
total_requests = self.stats.cache_hits + self.stats.cache_misses
|
|
||||||
hit_rate = self.stats.cache_hits / total_requests if total_requests > 0 else 0
|
|
||||||
|
|
||||||
stats_copy = StreamCacheStats(
|
|
||||||
hot_cache_size=self.stats.hot_cache_size,
|
|
||||||
warm_storage_size=self.stats.warm_storage_size,
|
|
||||||
cold_storage_size=self.stats.cold_storage_size,
|
|
||||||
total_memory_usage=self.stats.total_memory_usage,
|
|
||||||
cache_hits=self.stats.cache_hits,
|
|
||||||
cache_misses=self.stats.cache_misses,
|
|
||||||
evictions=self.stats.evictions,
|
|
||||||
last_cleanup_time=self.stats.last_cleanup_time,
|
|
||||||
)
|
|
||||||
|
|
||||||
# 添加命中率信息
|
|
||||||
stats_copy.hit_rate = hit_rate
|
|
||||||
|
|
||||||
return stats_copy
|
|
||||||
|
|
||||||
def clear_cache(self):
|
|
||||||
"""清空所有缓存"""
|
|
||||||
self.hot_cache.clear()
|
|
||||||
self.warm_storage.clear()
|
|
||||||
self.cold_storage.clear()
|
|
||||||
|
|
||||||
self.stats.hot_cache_size = 0
|
|
||||||
self.stats.warm_storage_size = 0
|
|
||||||
self.stats.cold_storage_size = 0
|
|
||||||
self.stats.total_memory_usage = 0
|
|
||||||
|
|
||||||
logger.info("所有缓存已清空")
|
|
||||||
|
|
||||||
async def get_stream_snapshot(self, stream_id: str) -> OptimizedChatStream | None:
|
|
||||||
"""获取流的快照(不修改缓存状态)"""
|
|
||||||
if stream_id in self.hot_cache:
|
|
||||||
return self.hot_cache[stream_id].create_snapshot()
|
|
||||||
elif stream_id in self.warm_storage:
|
|
||||||
return self.warm_storage[stream_id][0].create_snapshot()
|
|
||||||
elif stream_id in self.cold_storage:
|
|
||||||
return self.cold_storage[stream_id][0].create_snapshot()
|
|
||||||
return None
|
|
||||||
|
|
||||||
def get_cached_stream_ids(self) -> set[str]:
|
|
||||||
"""获取所有缓存的流ID"""
|
|
||||||
return set(self.hot_cache.keys()) | set(self.warm_storage.keys()) | set(self.cold_storage.keys())
|
|
||||||
|
|
||||||
|
|
||||||
# 全局缓存管理器实例
|
|
||||||
_cache_manager: TieredStreamCache | None = None
|
|
||||||
|
|
||||||
|
|
||||||
def get_stream_cache_manager() -> TieredStreamCache:
|
|
||||||
"""获取流缓存管理器实例"""
|
|
||||||
global _cache_manager
|
|
||||||
if _cache_manager is None:
|
|
||||||
_cache_manager = TieredStreamCache()
|
|
||||||
return _cache_manager
|
|
||||||
|
|
||||||
|
|
||||||
async def init_stream_cache_manager():
|
|
||||||
"""初始化流缓存管理器"""
|
|
||||||
manager = get_stream_cache_manager()
|
|
||||||
await manager.start()
|
|
||||||
|
|
||||||
|
|
||||||
async def shutdown_stream_cache_manager():
|
|
||||||
"""关闭流缓存管理器"""
|
|
||||||
manager = get_stream_cache_manager()
|
|
||||||
await manager.stop()
|
|
||||||
@@ -9,10 +9,10 @@ from maim_message import UserInfo
|
|||||||
from src.chat.antipromptinjector import initialize_anti_injector
|
from src.chat.antipromptinjector import initialize_anti_injector
|
||||||
from src.chat.message_manager import message_manager
|
from src.chat.message_manager import message_manager
|
||||||
from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager
|
from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager
|
||||||
from src.chat.message_receive.message import MessageRecv
|
|
||||||
from src.chat.message_receive.storage import MessageStorage
|
from src.chat.message_receive.storage import MessageStorage
|
||||||
from src.chat.utils.prompt import create_prompt_async, global_prompt_manager
|
from src.chat.utils.prompt import create_prompt_async, global_prompt_manager
|
||||||
from src.chat.utils.utils import is_mentioned_bot_in_message
|
from src.chat.utils.utils import is_mentioned_bot_in_message
|
||||||
|
from src.common.data_models.database_data_model import DatabaseMessages
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.config.config import global_config
|
from src.config.config import global_config
|
||||||
from src.mood.mood_manager import mood_manager # 导入情绪管理器
|
from src.mood.mood_manager import mood_manager # 导入情绪管理器
|
||||||
@@ -105,10 +105,10 @@ class ChatBot:
|
|||||||
|
|
||||||
self._started = True
|
self._started = True
|
||||||
|
|
||||||
async def _process_plus_commands(self, message: MessageRecv):
|
async def _process_plus_commands(self, message: DatabaseMessages, chat: ChatStream):
|
||||||
"""独立处理PlusCommand系统"""
|
"""独立处理PlusCommand系统"""
|
||||||
try:
|
try:
|
||||||
text = message.processed_plain_text
|
text = message.processed_plain_text or ""
|
||||||
|
|
||||||
# 获取配置的命令前缀
|
# 获取配置的命令前缀
|
||||||
from src.config.config import global_config
|
from src.config.config import global_config
|
||||||
@@ -166,10 +166,10 @@ class ChatBot:
|
|||||||
|
|
||||||
# 检查命令是否被禁用
|
# 检查命令是否被禁用
|
||||||
if (
|
if (
|
||||||
message.chat_stream
|
chat
|
||||||
and message.chat_stream.stream_id
|
and chat.stream_id
|
||||||
and plus_command_name
|
and plus_command_name
|
||||||
in global_announcement_manager.get_disabled_chat_commands(message.chat_stream.stream_id)
|
in global_announcement_manager.get_disabled_chat_commands(chat.stream_id)
|
||||||
):
|
):
|
||||||
logger.info("用户禁用的PlusCommand,跳过处理")
|
logger.info("用户禁用的PlusCommand,跳过处理")
|
||||||
return False, None, True
|
return False, None, True
|
||||||
@@ -181,11 +181,14 @@ class ChatBot:
|
|||||||
|
|
||||||
# 创建PlusCommand实例
|
# 创建PlusCommand实例
|
||||||
plus_command_instance = plus_command_class(message, plugin_config)
|
plus_command_instance = plus_command_class(message, plugin_config)
|
||||||
|
|
||||||
|
# 为插件实例设置 chat_stream 运行时属性
|
||||||
|
setattr(plus_command_instance, "chat_stream", chat)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 检查聊天类型限制
|
# 检查聊天类型限制
|
||||||
if not plus_command_instance.is_chat_type_allowed():
|
if not plus_command_instance.is_chat_type_allowed():
|
||||||
is_group = message.message_info.group_info
|
is_group = chat.group_info is not None
|
||||||
logger.info(
|
logger.info(
|
||||||
f"PlusCommand {plus_command_class.__name__} 不支持当前聊天类型: {'群聊' if is_group else '私聊'}"
|
f"PlusCommand {plus_command_class.__name__} 不支持当前聊天类型: {'群聊' if is_group else '私聊'}"
|
||||||
)
|
)
|
||||||
@@ -225,11 +228,11 @@ class ChatBot:
|
|||||||
logger.error(f"处理PlusCommand时出错: {e}")
|
logger.error(f"处理PlusCommand时出错: {e}")
|
||||||
return False, None, True # 出错时继续处理消息
|
return False, None, True # 出错时继续处理消息
|
||||||
|
|
||||||
async def _process_commands_with_new_system(self, message: MessageRecv):
|
async def _process_commands_with_new_system(self, message: DatabaseMessages, chat: ChatStream):
|
||||||
# sourcery skip: use-named-expression
|
# sourcery skip: use-named-expression
|
||||||
"""使用新插件系统处理命令"""
|
"""使用新插件系统处理命令"""
|
||||||
try:
|
try:
|
||||||
text = message.processed_plain_text
|
text = message.processed_plain_text or ""
|
||||||
|
|
||||||
# 使用新的组件注册中心查找命令
|
# 使用新的组件注册中心查找命令
|
||||||
command_result = component_registry.find_command_by_text(text)
|
command_result = component_registry.find_command_by_text(text)
|
||||||
@@ -238,10 +241,10 @@ class ChatBot:
|
|||||||
plugin_name = command_info.plugin_name
|
plugin_name = command_info.plugin_name
|
||||||
command_name = command_info.name
|
command_name = command_info.name
|
||||||
if (
|
if (
|
||||||
message.chat_stream
|
chat
|
||||||
and message.chat_stream.stream_id
|
and chat.stream_id
|
||||||
and command_name
|
and command_name
|
||||||
in global_announcement_manager.get_disabled_chat_commands(message.chat_stream.stream_id)
|
in global_announcement_manager.get_disabled_chat_commands(chat.stream_id)
|
||||||
):
|
):
|
||||||
logger.info("用户禁用的命令,跳过处理")
|
logger.info("用户禁用的命令,跳过处理")
|
||||||
return False, None, True
|
return False, None, True
|
||||||
@@ -254,11 +257,14 @@ class ChatBot:
|
|||||||
# 创建命令实例
|
# 创建命令实例
|
||||||
command_instance: BaseCommand = command_class(message, plugin_config)
|
command_instance: BaseCommand = command_class(message, plugin_config)
|
||||||
command_instance.set_matched_groups(matched_groups)
|
command_instance.set_matched_groups(matched_groups)
|
||||||
|
|
||||||
|
# 为插件实例设置 chat_stream 运行时属性
|
||||||
|
setattr(command_instance, "chat_stream", chat)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 检查聊天类型限制
|
# 检查聊天类型限制
|
||||||
if not command_instance.is_chat_type_allowed():
|
if not command_instance.is_chat_type_allowed():
|
||||||
is_group = message.message_info.group_info
|
is_group = chat.group_info is not None
|
||||||
logger.info(
|
logger.info(
|
||||||
f"命令 {command_class.__name__} 不支持当前聊天类型: {'群聊' if is_group else '私聊'}"
|
f"命令 {command_class.__name__} 不支持当前聊天类型: {'群聊' if is_group else '私聊'}"
|
||||||
)
|
)
|
||||||
@@ -295,13 +301,20 @@ class ChatBot:
|
|||||||
logger.error(f"处理命令时出错: {e}")
|
logger.error(f"处理命令时出错: {e}")
|
||||||
return False, None, True # 出错时继续处理消息
|
return False, None, True # 出错时继续处理消息
|
||||||
|
|
||||||
async def handle_notice_message(self, message: MessageRecv):
|
async def handle_notice_message(self, message: DatabaseMessages):
|
||||||
"""处理notice消息
|
"""处理notice消息
|
||||||
|
|
||||||
notice消息是系统事件通知(如禁言、戳一戳等),具有以下特点:
|
notice消息是系统事件通知(如禁言、戳一戳等),具有以下特点:
|
||||||
1. 默认不触发聊天流程,只记录
|
1. 默认不触发聊天流程,只记录
|
||||||
2. 可通过配置开启触发聊天流程
|
2. 可通过配置开启触发聊天流程
|
||||||
3. 会在提示词中展示
|
3. 会在提示词中展示
|
||||||
|
|
||||||
|
Args:
|
||||||
|
message: DatabaseMessages 对象
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: True表示notice已完整处理(需要存储并终止后续流程)
|
||||||
|
False表示不是notice或notice需要继续处理(触发聊天流程)
|
||||||
"""
|
"""
|
||||||
# 检查是否是notice消息
|
# 检查是否是notice消息
|
||||||
if message.is_notify:
|
if message.is_notify:
|
||||||
@@ -309,53 +322,42 @@ class ChatBot:
|
|||||||
|
|
||||||
# 根据配置决定是否触发聊天流程
|
# 根据配置决定是否触发聊天流程
|
||||||
if not global_config.notice.enable_notice_trigger_chat:
|
if not global_config.notice.enable_notice_trigger_chat:
|
||||||
logger.debug("notice消息不触发聊天流程(配置已关闭)")
|
logger.debug("notice消息不触发聊天流程(配置已关闭),将存储后终止")
|
||||||
return True # 返回True表示已处理,不继续后续流程
|
return True # 返回True:需要在调用处存储并终止
|
||||||
else:
|
else:
|
||||||
logger.debug("notice消息触发聊天流程(配置已开启)")
|
logger.debug("notice消息触发聊天流程(配置已开启),继续处理")
|
||||||
return False # 返回False表示继续处理,触发聊天流程
|
return False # 返回False:继续正常流程,作为普通消息处理
|
||||||
|
|
||||||
# 兼容旧的notice判断方式
|
# 兼容旧的notice判断方式
|
||||||
if message.message_info.message_id == "notice":
|
if message.message_id == "notice":
|
||||||
message.is_notify = True
|
# 为 DatabaseMessages 设置 is_notify 运行时属性
|
||||||
|
from src.chat.message_receive.message_processor import set_db_message_runtime_attr
|
||||||
|
set_db_message_runtime_attr(message, "is_notify", True)
|
||||||
logger.info("旧格式notice消息")
|
logger.info("旧格式notice消息")
|
||||||
|
|
||||||
# 同样根据配置决定
|
# 同样根据配置决定
|
||||||
if not global_config.notice.enable_notice_trigger_chat:
|
if not global_config.notice.enable_notice_trigger_chat:
|
||||||
return True
|
logger.debug("旧格式notice消息不触发聊天流程,将存储后终止")
|
||||||
|
return True # 需要存储并终止
|
||||||
else:
|
else:
|
||||||
return False
|
logger.debug("旧格式notice消息触发聊天流程,继续处理")
|
||||||
|
return False # 继续正常流程
|
||||||
|
|
||||||
# 处理适配器响应消息
|
# DatabaseMessages 不再有 message_segment,适配器响应处理已在消息处理阶段完成
|
||||||
if hasattr(message, "message_segment") and message.message_segment:
|
# 这里保留逻辑以防万一,但实际上不会再执行到
|
||||||
if message.message_segment.type == "adapter_response":
|
return False # 不是notice消息,继续正常流程
|
||||||
await self.handle_adapter_response(message)
|
|
||||||
return True
|
|
||||||
elif message.message_segment.type == "adapter_command":
|
|
||||||
# 适配器命令消息不需要进一步处理
|
|
||||||
logger.debug("收到适配器命令消息,跳过后续处理")
|
|
||||||
return True
|
|
||||||
|
|
||||||
return False
|
async def handle_adapter_response(self, message: DatabaseMessages):
|
||||||
|
"""处理适配器命令响应
|
||||||
async def handle_adapter_response(self, message: MessageRecv):
|
|
||||||
"""处理适配器命令响应"""
|
注意: 此方法目前未被调用,但保留以备将来使用
|
||||||
|
"""
|
||||||
try:
|
try:
|
||||||
from src.plugin_system.apis.send_api import put_adapter_response
|
from src.plugin_system.apis.send_api import put_adapter_response
|
||||||
|
|
||||||
seg_data = message.message_segment.data
|
# DatabaseMessages 使用 message_segments 字段存储消息段
|
||||||
if isinstance(seg_data, dict):
|
# 注意: 这可能需要根据实际使用情况进行调整
|
||||||
request_id = seg_data.get("request_id")
|
logger.warning("handle_adapter_response 方法被调用,但目前未实现对 DatabaseMessages 的支持")
|
||||||
response_data = seg_data.get("response")
|
|
||||||
else:
|
|
||||||
request_id = None
|
|
||||||
response_data = None
|
|
||||||
|
|
||||||
if request_id and response_data:
|
|
||||||
logger.debug(f"收到适配器响应: request_id={request_id}")
|
|
||||||
put_adapter_response(request_id, response_data)
|
|
||||||
else:
|
|
||||||
logger.warning("适配器响应消息格式不正确")
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"处理适配器响应时出错: {e}")
|
logger.error(f"处理适配器响应时出错: {e}")
|
||||||
@@ -381,9 +383,6 @@ class ChatBot:
|
|||||||
await self._ensure_started()
|
await self._ensure_started()
|
||||||
|
|
||||||
# 控制握手等消息可能缺少 message_info,这里直接跳过避免 KeyError
|
# 控制握手等消息可能缺少 message_info,这里直接跳过避免 KeyError
|
||||||
if not isinstance(message_data, dict):
|
|
||||||
logger.warning(f"收到无法解析的消息类型: {type(message_data)},已跳过")
|
|
||||||
return
|
|
||||||
message_info = message_data.get("message_info")
|
message_info = message_data.get("message_info")
|
||||||
if not isinstance(message_info, dict):
|
if not isinstance(message_info, dict):
|
||||||
logger.debug(
|
logger.debug(
|
||||||
@@ -392,8 +391,6 @@ class ChatBot:
|
|||||||
)
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
platform = message_info.get("platform")
|
|
||||||
|
|
||||||
if message_info.get("group_info") is not None:
|
if message_info.get("group_info") is not None:
|
||||||
message_info["group_info"]["group_id"] = str(
|
message_info["group_info"]["group_id"] = str(
|
||||||
message_info["group_info"]["group_id"]
|
message_info["group_info"]["group_id"]
|
||||||
@@ -404,74 +401,94 @@ class ChatBot:
|
|||||||
)
|
)
|
||||||
# print(message_data)
|
# print(message_data)
|
||||||
# logger.debug(str(message_data))
|
# logger.debug(str(message_data))
|
||||||
message = MessageRecv(message_data)
|
|
||||||
|
# 先提取基础信息检查是否是自身消息上报
|
||||||
group_info = message.message_info.group_info
|
from maim_message import BaseMessageInfo
|
||||||
user_info = message.message_info.user_info
|
temp_message_info = BaseMessageInfo.from_dict(message_data.get("message_info", {}))
|
||||||
if message.message_info.additional_config:
|
if temp_message_info.additional_config:
|
||||||
sent_message = message.message_info.additional_config.get("echo", False)
|
sent_message = temp_message_info.additional_config.get("echo", False)
|
||||||
if sent_message: # 这一段只是为了在一切处理前劫持上报的自身消息,用于更新message_id,需要ada支持上报事件,实际测试中不会对正常使用造成任何问题
|
if sent_message: # 这一段只是为了在一切处理前劫持上报的自身消息,用于更新message_id,需要ada支持上报事件,实际测试中不会对正常使用造成任何问题
|
||||||
await MessageStorage.update_message(message)
|
# 直接使用消息字典更新,不再需要创建 MessageRecv
|
||||||
|
await MessageStorage.update_message(message_data)
|
||||||
return
|
return
|
||||||
|
|
||||||
|
group_info = temp_message_info.group_info
|
||||||
|
user_info = temp_message_info.user_info
|
||||||
|
|
||||||
get_chat_manager().register_message(message)
|
# 获取或创建聊天流
|
||||||
|
|
||||||
chat = await get_chat_manager().get_or_create_stream(
|
chat = await get_chat_manager().get_or_create_stream(
|
||||||
platform=message.message_info.platform, # type: ignore
|
platform=temp_message_info.platform, # type: ignore
|
||||||
user_info=user_info, # type: ignore
|
user_info=user_info, # type: ignore
|
||||||
group_info=group_info,
|
group_info=group_info,
|
||||||
)
|
)
|
||||||
|
|
||||||
message.update_chat_stream(chat)
|
# 使用新的消息处理器直接生成 DatabaseMessages
|
||||||
|
from src.chat.message_receive.message_processor import process_message_from_dict
|
||||||
# 处理消息内容,生成纯文本
|
message = await process_message_from_dict(
|
||||||
await message.process()
|
message_dict=message_data,
|
||||||
|
stream_id=chat.stream_id,
|
||||||
|
platform=chat.platform
|
||||||
|
)
|
||||||
|
|
||||||
|
# 填充聊天流时间信息
|
||||||
|
message.chat_info.create_time = chat.create_time
|
||||||
|
message.chat_info.last_active_time = chat.last_active_time
|
||||||
|
|
||||||
|
# 注册消息到聊天管理器
|
||||||
|
get_chat_manager().register_message(message)
|
||||||
|
|
||||||
|
# 检测是否提及机器人
|
||||||
message.is_mentioned, _ = is_mentioned_bot_in_message(message)
|
message.is_mentioned, _ = is_mentioned_bot_in_message(message)
|
||||||
|
|
||||||
# 在这里打印[所见]日志,确保在所有处理和过滤之前记录
|
# 在这里打印[所见]日志,确保在所有处理和过滤之前记录
|
||||||
chat_name = chat.group_info.group_name if chat.group_info else "私聊"
|
chat_name = chat.group_info.group_name if chat.group_info else "私聊"
|
||||||
if message.message_info.user_info:
|
user_nickname = message.user_info.user_nickname if message.user_info else "未知用户"
|
||||||
logger.info(
|
logger.info(
|
||||||
f"[{chat_name}]{message.message_info.user_info.user_nickname}:{message.processed_plain_text}\u001b[0m"
|
f"[{chat_name}]{user_nickname}:{message.processed_plain_text}\u001b[0m"
|
||||||
)
|
)
|
||||||
|
|
||||||
# 在此添加硬编码过滤,防止回复图片处理失败的消息
|
# 在此添加硬编码过滤,防止回复图片处理失败的消息
|
||||||
failure_keywords = ["[表情包(描述生成失败)]", "[图片(描述生成失败)]"]
|
failure_keywords = ["[表情包(描述生成失败)]", "[图片(描述生成失败)]"]
|
||||||
if any(keyword in message.processed_plain_text for keyword in failure_keywords):
|
processed_text = message.processed_plain_text or ""
|
||||||
logger.info(f"[硬编码过滤] 检测到媒体内容处理失败({message.processed_plain_text}),消息被静默处理。")
|
if any(keyword in processed_text for keyword in failure_keywords):
|
||||||
|
logger.info(f"[硬编码过滤] 检测到媒体内容处理失败({processed_text}),消息被静默处理。")
|
||||||
return
|
return
|
||||||
|
|
||||||
# 处理notice消息
|
# 处理notice消息
|
||||||
|
# notice_handled=True: 表示notice不触发聊天,需要在此存储并终止
|
||||||
|
# notice_handled=False: 表示notice触发聊天或不是notice,继续正常流程
|
||||||
notice_handled = await self.handle_notice_message(message)
|
notice_handled = await self.handle_notice_message(message)
|
||||||
if notice_handled:
|
if notice_handled:
|
||||||
# notice消息已处理,使用统一的转换方法
|
# notice消息不触发聊天流程,在此进行存储和记录后终止
|
||||||
try:
|
try:
|
||||||
# 直接转换为 DatabaseMessages
|
# message 已经是 DatabaseMessages,直接使用
|
||||||
db_message = message.to_database_message()
|
|
||||||
|
|
||||||
# 添加到message_manager(这会将notice添加到全局notice管理器)
|
# 添加到message_manager(这会将notice添加到全局notice管理器)
|
||||||
await message_manager.add_message(message.chat_stream.stream_id, db_message)
|
await message_manager.add_message(chat.stream_id, message)
|
||||||
logger.info(f"✅ Notice消息已添加到message_manager: type={message.notice_type}, stream={message.chat_stream.stream_id}")
|
logger.info(f"✅ Notice消息已添加到message_manager: type={message.notice_type}, stream={chat.stream_id}")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Notice消息添加到message_manager失败: {e}", exc_info=True)
|
logger.error(f"Notice消息添加到message_manager失败: {e}", exc_info=True)
|
||||||
|
|
||||||
# 存储后直接返回
|
# 存储notice消息到数据库(需要更新 storage.py 支持 DatabaseMessages)
|
||||||
await MessageStorage.store_message(message, chat)
|
# 暂时跳过存储,等待更新 storage.py
|
||||||
logger.debug("notice消息已存储,跳过后续处理")
|
logger.debug("notice消息已添加到message_manager(存储功能待更新)")
|
||||||
return
|
return
|
||||||
|
|
||||||
|
# 如果notice_handled=False,则继续执行后续流程
|
||||||
|
# 对于启用触发聊天的notice,会在后续的正常流程中被存储和处理
|
||||||
|
|
||||||
# 过滤检查
|
# 过滤检查
|
||||||
|
# DatabaseMessages 使用 display_message 作为原始消息表示
|
||||||
|
raw_text = message.display_message or message.processed_plain_text or ""
|
||||||
if _check_ban_words(message.processed_plain_text, chat, user_info) or _check_ban_regex( # type: ignore
|
if _check_ban_words(message.processed_plain_text, chat, user_info) or _check_ban_regex( # type: ignore
|
||||||
message.raw_message, # type: ignore
|
raw_text,
|
||||||
chat,
|
chat,
|
||||||
user_info, # type: ignore
|
user_info, # type: ignore
|
||||||
):
|
):
|
||||||
return
|
return
|
||||||
|
|
||||||
# 命令处理 - 首先尝试PlusCommand独立处理
|
# 命令处理 - 首先尝试PlusCommand独立处理
|
||||||
is_plus_command, plus_cmd_result, plus_continue_process = await self._process_plus_commands(message)
|
is_plus_command, plus_cmd_result, plus_continue_process = await self._process_plus_commands(message, chat)
|
||||||
|
|
||||||
# 如果是PlusCommand且不需要继续处理,则直接返回
|
# 如果是PlusCommand且不需要继续处理,则直接返回
|
||||||
if is_plus_command and not plus_continue_process:
|
if is_plus_command and not plus_continue_process:
|
||||||
@@ -481,7 +498,7 @@ class ChatBot:
|
|||||||
|
|
||||||
# 如果不是PlusCommand,尝试传统的BaseCommand处理
|
# 如果不是PlusCommand,尝试传统的BaseCommand处理
|
||||||
if not is_plus_command:
|
if not is_plus_command:
|
||||||
is_command, cmd_result, continue_process = await self._process_commands_with_new_system(message)
|
is_command, cmd_result, continue_process = await self._process_commands_with_new_system(message, chat)
|
||||||
|
|
||||||
# 如果是命令且不需要继续处理,则直接返回
|
# 如果是命令且不需要继续处理,则直接返回
|
||||||
if is_command and not continue_process:
|
if is_command and not continue_process:
|
||||||
@@ -493,24 +510,14 @@ class ChatBot:
|
|||||||
if result and not result.all_continue_process():
|
if result and not result.all_continue_process():
|
||||||
raise UserWarning(f"插件{result.get_summary().get('stopped_handlers', '')}于消息到达时取消了消息处理")
|
raise UserWarning(f"插件{result.get_summary().get('stopped_handlers', '')}于消息到达时取消了消息处理")
|
||||||
|
|
||||||
# TODO:暂不可用
|
# TODO:暂不可用 - DatabaseMessages 不再有 message_info.template_info
|
||||||
# 确认从接口发来的message是否有自定义的prompt模板信息
|
# 确认从接口发来的message是否有自定义的prompt模板信息
|
||||||
if message.message_info.template_info and not message.message_info.template_info.template_default:
|
# 这个功能需要在 adapter 层通过 additional_config 传递
|
||||||
template_group_name: str | None = message.message_info.template_info.template_name # type: ignore
|
template_group_name = None
|
||||||
template_items = message.message_info.template_info.template_items
|
|
||||||
async with global_prompt_manager.async_message_scope(template_group_name):
|
|
||||||
if isinstance(template_items, dict):
|
|
||||||
for k in template_items.keys():
|
|
||||||
await create_prompt_async(template_items[k], k)
|
|
||||||
logger.debug(f"注册{template_items[k]},{k}")
|
|
||||||
else:
|
|
||||||
template_group_name = None
|
|
||||||
|
|
||||||
async def preprocess():
|
async def preprocess():
|
||||||
# 使用统一的转换方法创建数据库消息对象
|
# message 已经是 DatabaseMessages,直接使用
|
||||||
db_message = message.to_database_message()
|
group_info = chat.group_info
|
||||||
|
|
||||||
group_info = getattr(message.chat_stream, "group_info", None)
|
|
||||||
|
|
||||||
# 先交给消息管理器处理,计算兴趣度等衍生数据
|
# 先交给消息管理器处理,计算兴趣度等衍生数据
|
||||||
try:
|
try:
|
||||||
@@ -527,31 +534,15 @@ class ChatBot:
|
|||||||
should_process_in_manager = False
|
should_process_in_manager = False
|
||||||
|
|
||||||
if should_process_in_manager:
|
if should_process_in_manager:
|
||||||
await message_manager.add_message(message.chat_stream.stream_id, db_message)
|
await message_manager.add_message(chat.stream_id, message)
|
||||||
logger.debug(f"消息已添加到消息管理器: {message.chat_stream.stream_id}")
|
logger.debug(f"消息已添加到消息管理器: {chat.stream_id}")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"消息添加到消息管理器失败: {e}")
|
logger.error(f"消息添加到消息管理器失败: {e}")
|
||||||
|
|
||||||
# 将兴趣度结果同步回原始消息,便于后续流程使用
|
|
||||||
message.interest_value = getattr(db_message, "interest_value", getattr(message, "interest_value", 0.0))
|
|
||||||
setattr(
|
|
||||||
message,
|
|
||||||
"should_reply",
|
|
||||||
getattr(db_message, "should_reply", getattr(message, "should_reply", False)),
|
|
||||||
)
|
|
||||||
setattr(message, "should_act", getattr(db_message, "should_act", getattr(message, "should_act", False)))
|
|
||||||
|
|
||||||
# 存储消息到数据库,只进行一次写入
|
# 存储消息到数据库,只进行一次写入
|
||||||
try:
|
try:
|
||||||
await MessageStorage.store_message(message, message.chat_stream)
|
await MessageStorage.store_message(message, chat)
|
||||||
logger.debug(
|
|
||||||
"消息已存储到数据库: %s (interest=%.3f, should_reply=%s, should_act=%s)",
|
|
||||||
message.message_info.message_id,
|
|
||||||
getattr(message, "interest_value", -1.0),
|
|
||||||
getattr(message, "should_reply", None),
|
|
||||||
getattr(message, "should_act", None),
|
|
||||||
)
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"存储消息到数据库失败: {e}")
|
logger.error(f"存储消息到数据库失败: {e}")
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
@@ -560,13 +551,13 @@ class ChatBot:
|
|||||||
try:
|
try:
|
||||||
if global_config.mood.enable_mood:
|
if global_config.mood.enable_mood:
|
||||||
# 获取兴趣度用于情绪更新
|
# 获取兴趣度用于情绪更新
|
||||||
interest_rate = getattr(message, "interest_value", 0.0)
|
interest_rate = message.interest_value
|
||||||
if interest_rate is None:
|
if interest_rate is None:
|
||||||
interest_rate = 0.0
|
interest_rate = 0.0
|
||||||
logger.debug(f"开始更新情绪状态,兴趣度: {interest_rate:.2f}")
|
logger.debug(f"开始更新情绪状态,兴趣度: {interest_rate:.2f}")
|
||||||
|
|
||||||
# 获取当前聊天的情绪对象并更新情绪状态
|
# 获取当前聊天的情绪对象并更新情绪状态
|
||||||
chat_mood = mood_manager.get_mood_by_chat_id(message.chat_stream.stream_id)
|
chat_mood = mood_manager.get_mood_by_chat_id(chat.stream_id)
|
||||||
await chat_mood.update_mood_by_message(message, interest_rate)
|
await chat_mood.update_mood_by_message(message, interest_rate)
|
||||||
logger.debug("情绪状态更新完成")
|
logger.debug("情绪状态更新完成")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|||||||
@@ -12,13 +12,10 @@ from sqlalchemy.dialects.sqlite import insert as sqlite_insert
|
|||||||
|
|
||||||
from src.common.database.sqlalchemy_database_api import get_db_session
|
from src.common.database.sqlalchemy_database_api import get_db_session
|
||||||
from src.common.database.sqlalchemy_models import ChatStreams # 新增导入
|
from src.common.database.sqlalchemy_models import ChatStreams # 新增导入
|
||||||
|
from src.common.data_models.database_data_model import DatabaseMessages
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.config.config import global_config # 新增导入
|
from src.config.config import global_config # 新增导入
|
||||||
|
|
||||||
# 避免循环导入,使用TYPE_CHECKING进行类型提示
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from .message import MessageRecv
|
|
||||||
|
|
||||||
|
|
||||||
install(extra_lines=3)
|
install(extra_lines=3)
|
||||||
|
|
||||||
@@ -33,7 +30,7 @@ class ChatStream:
|
|||||||
self,
|
self,
|
||||||
stream_id: str,
|
stream_id: str,
|
||||||
platform: str,
|
platform: str,
|
||||||
user_info: UserInfo,
|
user_info: UserInfo | None = None,
|
||||||
group_info: GroupInfo | None = None,
|
group_info: GroupInfo | None = None,
|
||||||
data: dict | None = None,
|
data: dict | None = None,
|
||||||
):
|
):
|
||||||
@@ -46,20 +43,18 @@ class ChatStream:
|
|||||||
self.sleep_pressure = data.get("sleep_pressure", 0.0) if data else 0.0
|
self.sleep_pressure = data.get("sleep_pressure", 0.0) if data else 0.0
|
||||||
self.saved = False
|
self.saved = False
|
||||||
|
|
||||||
# 使用StreamContext替代ChatMessageContext
|
# 创建单流上下文管理器(包含StreamContext)
|
||||||
|
from src.chat.message_manager.context_manager import SingleStreamContextManager
|
||||||
from src.common.data_models.message_manager_data_model import StreamContext
|
from src.common.data_models.message_manager_data_model import StreamContext
|
||||||
from src.plugin_system.base.component_types import ChatMode, ChatType
|
from src.plugin_system.base.component_types import ChatMode, ChatType
|
||||||
|
|
||||||
# 创建StreamContext
|
|
||||||
self.stream_context: StreamContext = StreamContext(
|
|
||||||
stream_id=stream_id, chat_type=ChatType.GROUP if group_info else ChatType.PRIVATE, chat_mode=ChatMode.NORMAL
|
|
||||||
)
|
|
||||||
|
|
||||||
# 创建单流上下文管理器
|
|
||||||
from src.chat.message_manager.context_manager import SingleStreamContextManager
|
|
||||||
|
|
||||||
self.context_manager: SingleStreamContextManager = SingleStreamContextManager(
|
self.context_manager: SingleStreamContextManager = SingleStreamContextManager(
|
||||||
stream_id=stream_id, context=self.stream_context
|
stream_id=stream_id,
|
||||||
|
context=StreamContext(
|
||||||
|
stream_id=stream_id,
|
||||||
|
chat_type=ChatType.GROUP if group_info else ChatType.PRIVATE,
|
||||||
|
chat_mode=ChatMode.NORMAL,
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
# 基础参数
|
# 基础参数
|
||||||
@@ -88,13 +83,12 @@ class ChatStream:
|
|||||||
new_stream._focus_energy = self._focus_energy
|
new_stream._focus_energy = self._focus_energy
|
||||||
new_stream.no_reply_consecutive = self.no_reply_consecutive
|
new_stream.no_reply_consecutive = self.no_reply_consecutive
|
||||||
|
|
||||||
# 复制 stream_context,但跳过 processing_task
|
# 复制 context_manager(包含 stream_context)
|
||||||
new_stream.stream_context = copy.deepcopy(self.stream_context, memo)
|
|
||||||
if hasattr(new_stream.stream_context, "processing_task"):
|
|
||||||
new_stream.stream_context.processing_task = None
|
|
||||||
|
|
||||||
# 复制 context_manager
|
|
||||||
new_stream.context_manager = copy.deepcopy(self.context_manager, memo)
|
new_stream.context_manager = copy.deepcopy(self.context_manager, memo)
|
||||||
|
|
||||||
|
# 清理 processing_task(如果存在)
|
||||||
|
if hasattr(new_stream.context_manager.context, "processing_task"):
|
||||||
|
new_stream.context_manager.context.processing_task = None
|
||||||
|
|
||||||
return new_stream
|
return new_stream
|
||||||
|
|
||||||
@@ -111,11 +105,11 @@ class ChatStream:
|
|||||||
"focus_energy": self.focus_energy,
|
"focus_energy": self.focus_energy,
|
||||||
# 基础兴趣度
|
# 基础兴趣度
|
||||||
"base_interest_energy": self.base_interest_energy,
|
"base_interest_energy": self.base_interest_energy,
|
||||||
# stream_context基本信息
|
# stream_context基本信息(通过context_manager访问)
|
||||||
"stream_context_chat_type": self.stream_context.chat_type.value,
|
"stream_context_chat_type": self.context_manager.context.chat_type.value,
|
||||||
"stream_context_chat_mode": self.stream_context.chat_mode.value,
|
"stream_context_chat_mode": self.context_manager.context.chat_mode.value,
|
||||||
# 统计信息
|
# 统计信息
|
||||||
"interruption_count": self.stream_context.interruption_count,
|
"interruption_count": self.context_manager.context.interruption_count,
|
||||||
}
|
}
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@@ -132,27 +126,19 @@ class ChatStream:
|
|||||||
data=data,
|
data=data,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 恢复stream_context信息
|
# 恢复stream_context信息(通过context_manager访问)
|
||||||
if "stream_context_chat_type" in data:
|
if "stream_context_chat_type" in data:
|
||||||
from src.plugin_system.base.component_types import ChatMode, ChatType
|
from src.plugin_system.base.component_types import ChatMode, ChatType
|
||||||
|
|
||||||
instance.stream_context.chat_type = ChatType(data["stream_context_chat_type"])
|
instance.context_manager.context.chat_type = ChatType(data["stream_context_chat_type"])
|
||||||
if "stream_context_chat_mode" in data:
|
if "stream_context_chat_mode" in data:
|
||||||
from src.plugin_system.base.component_types import ChatMode, ChatType
|
from src.plugin_system.base.component_types import ChatMode, ChatType
|
||||||
|
|
||||||
instance.stream_context.chat_mode = ChatMode(data["stream_context_chat_mode"])
|
instance.context_manager.context.chat_mode = ChatMode(data["stream_context_chat_mode"])
|
||||||
|
|
||||||
# 恢复interruption_count信息
|
# 恢复interruption_count信息
|
||||||
if "interruption_count" in data:
|
if "interruption_count" in data:
|
||||||
instance.stream_context.interruption_count = data["interruption_count"]
|
instance.context_manager.context.interruption_count = data["interruption_count"]
|
||||||
|
|
||||||
# 确保 context_manager 已初始化
|
|
||||||
if not hasattr(instance, "context_manager"):
|
|
||||||
from src.chat.message_manager.context_manager import SingleStreamContextManager
|
|
||||||
|
|
||||||
instance.context_manager = SingleStreamContextManager(
|
|
||||||
stream_id=instance.stream_id, context=instance.stream_context
|
|
||||||
)
|
|
||||||
|
|
||||||
return instance
|
return instance
|
||||||
|
|
||||||
@@ -160,156 +146,44 @@ class ChatStream:
|
|||||||
"""获取原始的、未哈希的聊天流ID字符串"""
|
"""获取原始的、未哈希的聊天流ID字符串"""
|
||||||
if self.group_info:
|
if self.group_info:
|
||||||
return f"{self.platform}:{self.group_info.group_id}:group"
|
return f"{self.platform}:{self.group_info.group_id}:group"
|
||||||
else:
|
elif self.user_info:
|
||||||
return f"{self.platform}:{self.user_info.user_id}:private"
|
return f"{self.platform}:{self.user_info.user_id}:private"
|
||||||
|
else:
|
||||||
|
return f"{self.platform}:unknown:private"
|
||||||
|
|
||||||
def update_active_time(self):
|
def update_active_time(self):
|
||||||
"""更新最后活跃时间"""
|
"""更新最后活跃时间"""
|
||||||
self.last_active_time = time.time()
|
self.last_active_time = time.time()
|
||||||
self.saved = False
|
self.saved = False
|
||||||
|
|
||||||
async def set_context(self, message: "MessageRecv"):
|
async def set_context(self, message: DatabaseMessages):
|
||||||
"""设置聊天消息上下文"""
|
"""设置聊天消息上下文
|
||||||
# 将MessageRecv转换为DatabaseMessages并设置到stream_context
|
|
||||||
import json
|
|
||||||
|
|
||||||
from src.common.data_models.database_data_model import DatabaseMessages
|
|
||||||
|
|
||||||
# 安全获取message_info中的数据
|
|
||||||
message_info = getattr(message, "message_info", {})
|
|
||||||
user_info = getattr(message_info, "user_info", {})
|
|
||||||
group_info = getattr(message_info, "group_info", {})
|
|
||||||
|
|
||||||
# 提取reply_to信息(从message_segment中查找reply类型的段)
|
|
||||||
reply_to = None
|
|
||||||
if hasattr(message, "message_segment") and message.message_segment:
|
|
||||||
reply_to = self._extract_reply_from_segment(message.message_segment)
|
|
||||||
|
|
||||||
# 完整的数据转移逻辑
|
|
||||||
db_message = DatabaseMessages(
|
|
||||||
# 基础消息信息
|
|
||||||
message_id=getattr(message, "message_id", ""),
|
|
||||||
time=getattr(message, "time", time.time()),
|
|
||||||
chat_id=self._generate_chat_id(message_info),
|
|
||||||
reply_to=reply_to,
|
|
||||||
# 兴趣度相关
|
|
||||||
interest_value=getattr(message, "interest_value", 0.0),
|
|
||||||
# 关键词
|
|
||||||
key_words=json.dumps(getattr(message, "key_words", []), ensure_ascii=False)
|
|
||||||
if getattr(message, "key_words", None)
|
|
||||||
else None,
|
|
||||||
key_words_lite=json.dumps(getattr(message, "key_words_lite", []), ensure_ascii=False)
|
|
||||||
if getattr(message, "key_words_lite", None)
|
|
||||||
else None,
|
|
||||||
# 消息状态标记
|
|
||||||
is_mentioned=getattr(message, "is_mentioned", None),
|
|
||||||
is_at=getattr(message, "is_at", False),
|
|
||||||
is_emoji=getattr(message, "is_emoji", False),
|
|
||||||
is_picid=getattr(message, "is_picid", False),
|
|
||||||
is_voice=getattr(message, "is_voice", False),
|
|
||||||
is_video=getattr(message, "is_video", False),
|
|
||||||
is_command=getattr(message, "is_command", False),
|
|
||||||
is_notify=getattr(message, "is_notify", False),
|
|
||||||
is_public_notice=getattr(message, "is_public_notice", False),
|
|
||||||
notice_type=getattr(message, "notice_type", None),
|
|
||||||
# 消息内容
|
|
||||||
processed_plain_text=getattr(message, "processed_plain_text", ""),
|
|
||||||
display_message=getattr(message, "processed_plain_text", ""), # 默认使用processed_plain_text
|
|
||||||
# 优先级信息
|
|
||||||
priority_mode=getattr(message, "priority_mode", None),
|
|
||||||
priority_info=json.dumps(getattr(message, "priority_info", None))
|
|
||||||
if getattr(message, "priority_info", None)
|
|
||||||
else None,
|
|
||||||
# 额外配置 - 需要将 format_info 嵌入到 additional_config 中
|
|
||||||
additional_config=self._prepare_additional_config(message_info),
|
|
||||||
# 用户信息
|
|
||||||
user_id=str(getattr(user_info, "user_id", "")),
|
|
||||||
user_nickname=getattr(user_info, "user_nickname", ""),
|
|
||||||
user_cardname=getattr(user_info, "user_cardname", None),
|
|
||||||
user_platform=getattr(user_info, "platform", ""),
|
|
||||||
# 群组信息
|
|
||||||
chat_info_group_id=getattr(group_info, "group_id", None),
|
|
||||||
chat_info_group_name=getattr(group_info, "group_name", None),
|
|
||||||
chat_info_group_platform=getattr(group_info, "platform", None),
|
|
||||||
# 聊天流信息
|
|
||||||
chat_info_user_id=str(getattr(user_info, "user_id", "")),
|
|
||||||
chat_info_user_nickname=getattr(user_info, "user_nickname", ""),
|
|
||||||
chat_info_user_cardname=getattr(user_info, "user_cardname", None),
|
|
||||||
chat_info_user_platform=getattr(user_info, "platform", ""),
|
|
||||||
chat_info_stream_id=self.stream_id,
|
|
||||||
chat_info_platform=self.platform,
|
|
||||||
chat_info_create_time=self.create_time,
|
|
||||||
chat_info_last_active_time=self.last_active_time,
|
|
||||||
# 新增兴趣度系统字段 - 添加安全处理
|
|
||||||
actions=self._safe_get_actions(message),
|
|
||||||
should_reply=getattr(message, "should_reply", False),
|
|
||||||
should_act=getattr(message, "should_act", False),
|
|
||||||
)
|
|
||||||
|
|
||||||
self.stream_context.set_current_message(db_message)
|
|
||||||
self.stream_context.priority_mode = getattr(message, "priority_mode", None)
|
|
||||||
self.stream_context.priority_info = getattr(message, "priority_info", None)
|
|
||||||
|
|
||||||
# 调试日志:记录数据转移情况
|
|
||||||
logger.debug(
|
|
||||||
f"消息数据转移完成 - message_id: {db_message.message_id}, "
|
|
||||||
f"chat_id: {db_message.chat_id}, "
|
|
||||||
f"is_mentioned: {db_message.is_mentioned}, "
|
|
||||||
f"is_emoji: {db_message.is_emoji}, "
|
|
||||||
f"is_picid: {db_message.is_picid}, "
|
|
||||||
f"interest_value: {db_message.interest_value}"
|
|
||||||
)
|
|
||||||
|
|
||||||
def _prepare_additional_config(self, message_info) -> str | None:
|
|
||||||
"""
|
|
||||||
准备 additional_config,将 format_info 嵌入其中
|
|
||||||
|
|
||||||
这个方法模仿 storage.py 中的逻辑,确保 DatabaseMessages 中的 additional_config
|
|
||||||
包含 format_info,使得 action_modifier 能够正确获取适配器支持的消息类型
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
message_info: BaseMessageInfo 对象
|
message: DatabaseMessages 对象,直接使用不需要转换
|
||||||
|
|
||||||
Returns:
|
|
||||||
str | None: JSON 字符串格式的 additional_config,如果为空则返回 None
|
|
||||||
"""
|
"""
|
||||||
import orjson
|
# 直接使用传入的 DatabaseMessages,设置到上下文中
|
||||||
|
self.context_manager.context.set_current_message(message)
|
||||||
# 首先获取adapter传递的additional_config
|
|
||||||
additional_config_data = {}
|
|
||||||
if hasattr(message_info, 'additional_config') and message_info.additional_config:
|
|
||||||
if isinstance(message_info.additional_config, dict):
|
|
||||||
additional_config_data = message_info.additional_config.copy()
|
|
||||||
elif isinstance(message_info.additional_config, str):
|
|
||||||
# 如果是字符串,尝试解析
|
|
||||||
try:
|
|
||||||
additional_config_data = orjson.loads(message_info.additional_config)
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"无法解析 additional_config JSON: {e}")
|
|
||||||
additional_config_data = {}
|
|
||||||
|
|
||||||
# 然后添加format_info到additional_config中
|
# 设置优先级信息(如果存在)
|
||||||
if hasattr(message_info, 'format_info') and message_info.format_info:
|
priority_mode = getattr(message, "priority_mode", None)
|
||||||
try:
|
priority_info = getattr(message, "priority_info", None)
|
||||||
format_info_dict = message_info.format_info.to_dict()
|
if priority_mode:
|
||||||
additional_config_data["format_info"] = format_info_dict
|
self.context_manager.context.priority_mode = priority_mode
|
||||||
logger.debug(f"嵌入 format_info 到 additional_config: {format_info_dict}")
|
if priority_info:
|
||||||
except Exception as e:
|
self.context_manager.context.priority_info = priority_info
|
||||||
logger.warning(f"将 format_info 转换为字典失败: {e}")
|
|
||||||
else:
|
|
||||||
logger.warning(f"[问题] 消息缺少 format_info: message_id={getattr(message_info, 'message_id', 'unknown')}")
|
|
||||||
logger.warning("[问题] 这可能导致 Action 无法正确检查适配器支持的类型")
|
|
||||||
|
|
||||||
# 序列化为JSON字符串
|
|
||||||
if additional_config_data:
|
|
||||||
try:
|
|
||||||
return orjson.dumps(additional_config_data).decode("utf-8")
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"序列化 additional_config 失败: {e}")
|
|
||||||
return None
|
|
||||||
return None
|
|
||||||
|
|
||||||
def _safe_get_actions(self, message: "MessageRecv") -> list | None:
|
# 调试日志
|
||||||
|
logger.debug(
|
||||||
|
f"消息上下文已设置 - message_id: {message.message_id}, "
|
||||||
|
f"chat_id: {message.chat_id}, "
|
||||||
|
f"is_mentioned: {message.is_mentioned}, "
|
||||||
|
f"is_emoji: {message.is_emoji}, "
|
||||||
|
f"is_picid: {message.is_picid}, "
|
||||||
|
f"interest_value: {message.interest_value}"
|
||||||
|
)
|
||||||
|
|
||||||
|
def _safe_get_actions(self, message: DatabaseMessages) -> list | None:
|
||||||
"""安全获取消息的actions字段"""
|
"""安全获取消息的actions字段"""
|
||||||
import json
|
import json
|
||||||
|
|
||||||
@@ -380,23 +254,6 @@ class ChatStream:
|
|||||||
if hasattr(db_message, "should_act"):
|
if hasattr(db_message, "should_act"):
|
||||||
db_message.should_act = False
|
db_message.should_act = False
|
||||||
|
|
||||||
def _extract_reply_from_segment(self, segment) -> str | None:
|
|
||||||
"""从消息段中提取reply_to信息"""
|
|
||||||
try:
|
|
||||||
if hasattr(segment, "type") and segment.type == "seglist":
|
|
||||||
# 递归搜索seglist中的reply段
|
|
||||||
if hasattr(segment, "data") and segment.data:
|
|
||||||
for seg in segment.data:
|
|
||||||
reply_id = self._extract_reply_from_segment(seg)
|
|
||||||
if reply_id:
|
|
||||||
return reply_id
|
|
||||||
elif hasattr(segment, "type") and segment.type == "reply":
|
|
||||||
# 找到reply段,返回message_id
|
|
||||||
return str(segment.data) if segment.data else None
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"提取reply_to信息失败: {e}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
def _generate_chat_id(self, message_info) -> str:
|
def _generate_chat_id(self, message_info) -> str:
|
||||||
"""生成chat_id,基于群组或用户信息"""
|
"""生成chat_id,基于群组或用户信息"""
|
||||||
try:
|
try:
|
||||||
@@ -493,8 +350,10 @@ class ChatManager:
|
|||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
if not self._initialized:
|
if not self._initialized:
|
||||||
|
from src.common.data_models.database_data_model import DatabaseMessages
|
||||||
|
|
||||||
self.streams: dict[str, ChatStream] = {} # stream_id -> ChatStream
|
self.streams: dict[str, ChatStream] = {} # stream_id -> ChatStream
|
||||||
self.last_messages: dict[str, "MessageRecv"] = {} # stream_id -> last_message
|
self.last_messages: dict[str, DatabaseMessages] = {} # stream_id -> last_message
|
||||||
# try:
|
# try:
|
||||||
# async with get_db_session() as session:
|
# async with get_db_session() as session:
|
||||||
# db.connect(reuse_if_open=True)
|
# db.connect(reuse_if_open=True)
|
||||||
@@ -528,12 +387,30 @@ class ChatManager:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"聊天流自动保存失败: {e!s}")
|
logger.error(f"聊天流自动保存失败: {e!s}")
|
||||||
|
|
||||||
def register_message(self, message: "MessageRecv"):
|
def register_message(self, message: DatabaseMessages):
|
||||||
"""注册消息到聊天流"""
|
"""注册消息到聊天流"""
|
||||||
|
# 从 DatabaseMessages 提取平台和用户/群组信息
|
||||||
|
from maim_message import UserInfo, GroupInfo
|
||||||
|
|
||||||
|
user_info = UserInfo(
|
||||||
|
platform=message.user_info.platform,
|
||||||
|
user_id=message.user_info.user_id,
|
||||||
|
user_nickname=message.user_info.user_nickname,
|
||||||
|
user_cardname=message.user_info.user_cardname or ""
|
||||||
|
)
|
||||||
|
|
||||||
|
group_info = None
|
||||||
|
if message.group_info:
|
||||||
|
group_info = GroupInfo(
|
||||||
|
platform=message.group_info.group_platform or "",
|
||||||
|
group_id=message.group_info.group_id,
|
||||||
|
group_name=message.group_info.group_name
|
||||||
|
)
|
||||||
|
|
||||||
stream_id = self._generate_stream_id(
|
stream_id = self._generate_stream_id(
|
||||||
message.message_info.platform, # type: ignore
|
message.chat_info.platform,
|
||||||
message.message_info.user_info,
|
user_info,
|
||||||
message.message_info.group_info,
|
group_info,
|
||||||
)
|
)
|
||||||
self.last_messages[stream_id] = message
|
self.last_messages[stream_id] = message
|
||||||
# logger.debug(f"注册消息到聊天流: {stream_id}")
|
# logger.debug(f"注册消息到聊天流: {stream_id}")
|
||||||
@@ -578,32 +455,6 @@ class ChatManager:
|
|||||||
try:
|
try:
|
||||||
stream_id = self._generate_stream_id(platform, user_info, group_info)
|
stream_id = self._generate_stream_id(platform, user_info, group_info)
|
||||||
|
|
||||||
# 优先使用缓存管理器(优化版本)
|
|
||||||
try:
|
|
||||||
from src.chat.message_manager.stream_cache_manager import get_stream_cache_manager
|
|
||||||
|
|
||||||
cache_manager = get_stream_cache_manager()
|
|
||||||
|
|
||||||
if cache_manager.is_running:
|
|
||||||
optimized_stream = await cache_manager.get_or_create_stream(
|
|
||||||
stream_id=stream_id, platform=platform, user_info=user_info, group_info=group_info
|
|
||||||
)
|
|
||||||
|
|
||||||
# 设置消息上下文
|
|
||||||
from .message import MessageRecv
|
|
||||||
|
|
||||||
if stream_id in self.last_messages and isinstance(self.last_messages[stream_id], MessageRecv):
|
|
||||||
optimized_stream.set_context(self.last_messages[stream_id])
|
|
||||||
|
|
||||||
# 转换为原始ChatStream以保持兼容性
|
|
||||||
original_stream = self._convert_to_original_stream(optimized_stream)
|
|
||||||
|
|
||||||
return original_stream
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.debug(f"缓存管理器获取流失败,使用原始方法: {e}")
|
|
||||||
|
|
||||||
# 回退到原始方法
|
|
||||||
# 检查内存中是否存在
|
# 检查内存中是否存在
|
||||||
if stream_id in self.streams:
|
if stream_id in self.streams:
|
||||||
stream = self.streams[stream_id]
|
stream = self.streams[stream_id]
|
||||||
@@ -615,12 +466,13 @@ class ChatManager:
|
|||||||
stream.user_info = user_info
|
stream.user_info = user_info
|
||||||
if group_info:
|
if group_info:
|
||||||
stream.group_info = group_info
|
stream.group_info = group_info
|
||||||
from .message import MessageRecv # 延迟导入,避免循环引用
|
|
||||||
|
# 检查是否有最后一条消息(现在使用 DatabaseMessages)
|
||||||
if stream_id in self.last_messages and isinstance(self.last_messages[stream_id], MessageRecv):
|
from src.common.data_models.database_data_model import DatabaseMessages
|
||||||
|
if stream_id in self.last_messages and isinstance(self.last_messages[stream_id], DatabaseMessages):
|
||||||
await stream.set_context(self.last_messages[stream_id])
|
await stream.set_context(self.last_messages[stream_id])
|
||||||
else:
|
else:
|
||||||
logger.error(f"聊天流 {stream_id} 不在最后消息列表中,可能是新创建的")
|
logger.debug(f"聊天流 {stream_id} 不在最后消息列表中,可能是新创建的或还没有消息")
|
||||||
return stream
|
return stream
|
||||||
|
|
||||||
# 检查数据库中是否存在
|
# 检查数据库中是否存在
|
||||||
@@ -679,19 +531,27 @@ class ChatManager:
|
|||||||
raise e
|
raise e
|
||||||
|
|
||||||
stream = copy.deepcopy(stream)
|
stream = copy.deepcopy(stream)
|
||||||
from .message import MessageRecv # 延迟导入,避免循环引用
|
from src.common.data_models.database_data_model import DatabaseMessages
|
||||||
|
|
||||||
if stream_id in self.last_messages and isinstance(self.last_messages[stream_id], MessageRecv):
|
if stream_id in self.last_messages and isinstance(self.last_messages[stream_id], DatabaseMessages):
|
||||||
await stream.set_context(self.last_messages[stream_id])
|
await stream.set_context(self.last_messages[stream_id])
|
||||||
else:
|
else:
|
||||||
logger.error(f"聊天流 {stream_id} 不在最后消息列表中,可能是新创建的")
|
logger.debug(f"聊天流 {stream_id} 不在最后消息列表中,可能是新创建的")
|
||||||
|
|
||||||
# 确保 ChatStream 有自己的 context_manager
|
# 确保 ChatStream 有自己的 context_manager
|
||||||
if not hasattr(stream, "context_manager"):
|
if not hasattr(stream, "context_manager"):
|
||||||
# 创建新的单流上下文管理器
|
|
||||||
from src.chat.message_manager.context_manager import SingleStreamContextManager
|
from src.chat.message_manager.context_manager import SingleStreamContextManager
|
||||||
|
from src.common.data_models.message_manager_data_model import StreamContext
|
||||||
|
from src.plugin_system.base.component_types import ChatMode, ChatType
|
||||||
|
|
||||||
stream.context_manager = SingleStreamContextManager(stream_id=stream_id, context=stream.stream_context)
|
stream.context_manager = SingleStreamContextManager(
|
||||||
|
stream_id=stream_id,
|
||||||
|
context=StreamContext(
|
||||||
|
stream_id=stream_id,
|
||||||
|
chat_type=ChatType.GROUP if stream.group_info else ChatType.PRIVATE,
|
||||||
|
chat_mode=ChatMode.NORMAL,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
# 保存到内存和数据库
|
# 保存到内存和数据库
|
||||||
self.streams[stream_id] = stream
|
self.streams[stream_id] = stream
|
||||||
@@ -700,10 +560,12 @@ class ChatManager:
|
|||||||
|
|
||||||
async def get_stream(self, stream_id: str) -> ChatStream | None:
|
async def get_stream(self, stream_id: str) -> ChatStream | None:
|
||||||
"""通过stream_id获取聊天流"""
|
"""通过stream_id获取聊天流"""
|
||||||
|
from src.common.data_models.database_data_model import DatabaseMessages
|
||||||
|
|
||||||
stream = self.streams.get(stream_id)
|
stream = self.streams.get(stream_id)
|
||||||
if not stream:
|
if not stream:
|
||||||
return None
|
return None
|
||||||
if stream_id in self.last_messages:
|
if stream_id in self.last_messages and isinstance(self.last_messages[stream_id], DatabaseMessages):
|
||||||
await stream.set_context(self.last_messages[stream_id])
|
await stream.set_context(self.last_messages[stream_id])
|
||||||
return stream
|
return stream
|
||||||
|
|
||||||
@@ -921,9 +783,16 @@ class ChatManager:
|
|||||||
# 确保 ChatStream 有自己的 context_manager
|
# 确保 ChatStream 有自己的 context_manager
|
||||||
if not hasattr(stream, "context_manager"):
|
if not hasattr(stream, "context_manager"):
|
||||||
from src.chat.message_manager.context_manager import SingleStreamContextManager
|
from src.chat.message_manager.context_manager import SingleStreamContextManager
|
||||||
|
from src.common.data_models.message_manager_data_model import StreamContext
|
||||||
|
from src.plugin_system.base.component_types import ChatMode, ChatType
|
||||||
|
|
||||||
stream.context_manager = SingleStreamContextManager(
|
stream.context_manager = SingleStreamContextManager(
|
||||||
stream_id=stream.stream_id, context=stream.stream_context
|
stream_id=stream.stream_id,
|
||||||
|
context=StreamContext(
|
||||||
|
stream_id=stream.stream_id,
|
||||||
|
chat_type=ChatType.GROUP if stream.group_info else ChatType.PRIVATE,
|
||||||
|
chat_mode=ChatMode.NORMAL,
|
||||||
|
),
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"从数据库加载所有聊天流失败 (SQLAlchemy): {e}", exc_info=True)
|
logger.error(f"从数据库加载所有聊天流失败 (SQLAlchemy): {e}", exc_info=True)
|
||||||
@@ -932,46 +801,6 @@ class ChatManager:
|
|||||||
chat_manager = None
|
chat_manager = None
|
||||||
|
|
||||||
|
|
||||||
def _convert_to_original_stream(self, optimized_stream) -> "ChatStream":
|
|
||||||
"""将OptimizedChatStream转换为原始ChatStream以保持兼容性"""
|
|
||||||
try:
|
|
||||||
# 创建原始ChatStream实例
|
|
||||||
original_stream = ChatStream(
|
|
||||||
stream_id=optimized_stream.stream_id,
|
|
||||||
platform=optimized_stream.platform,
|
|
||||||
user_info=optimized_stream._get_effective_user_info(),
|
|
||||||
group_info=optimized_stream._get_effective_group_info(),
|
|
||||||
)
|
|
||||||
|
|
||||||
# 复制状态
|
|
||||||
original_stream.create_time = optimized_stream.create_time
|
|
||||||
original_stream.last_active_time = optimized_stream.last_active_time
|
|
||||||
original_stream.sleep_pressure = optimized_stream.sleep_pressure
|
|
||||||
original_stream.base_interest_energy = optimized_stream.base_interest_energy
|
|
||||||
original_stream._focus_energy = optimized_stream._focus_energy
|
|
||||||
original_stream.no_reply_consecutive = optimized_stream.no_reply_consecutive
|
|
||||||
original_stream.saved = optimized_stream.saved
|
|
||||||
|
|
||||||
# 复制上下文信息(如果存在)
|
|
||||||
if hasattr(optimized_stream, "_stream_context") and optimized_stream._stream_context:
|
|
||||||
original_stream.stream_context = optimized_stream._stream_context
|
|
||||||
|
|
||||||
if hasattr(optimized_stream, "_context_manager") and optimized_stream._context_manager:
|
|
||||||
original_stream.context_manager = optimized_stream._context_manager
|
|
||||||
|
|
||||||
return original_stream
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"转换OptimizedChatStream失败: {e}")
|
|
||||||
# 如果转换失败,创建一个新的原始流
|
|
||||||
return ChatStream(
|
|
||||||
stream_id=optimized_stream.stream_id,
|
|
||||||
platform=optimized_stream.platform,
|
|
||||||
user_info=optimized_stream._get_effective_user_info(),
|
|
||||||
group_info=optimized_stream._get_effective_group_info(),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def get_chat_manager():
|
def get_chat_manager():
|
||||||
global chat_manager
|
global chat_manager
|
||||||
if chat_manager is None:
|
if chat_manager is None:
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ import base64
|
|||||||
import time
|
import time
|
||||||
from abc import ABCMeta, abstractmethod
|
from abc import ABCMeta, abstractmethod
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any, Optional
|
from typing import Any, Optional, Union
|
||||||
|
|
||||||
import urllib3
|
import urllib3
|
||||||
from maim_message import BaseMessageInfo, MessageBase, Seg, UserInfo
|
from maim_message import BaseMessageInfo, MessageBase, Seg, UserInfo
|
||||||
@@ -13,6 +13,7 @@ from src.chat.utils.self_voice_cache import consume_self_voice_text
|
|||||||
from src.chat.utils.utils_image import get_image_manager
|
from src.chat.utils.utils_image import get_image_manager
|
||||||
from src.chat.utils.utils_video import get_video_analyzer, is_video_analysis_available
|
from src.chat.utils.utils_video import get_video_analyzer, is_video_analysis_available
|
||||||
from src.chat.utils.utils_voice import get_voice_text
|
from src.chat.utils.utils_voice import get_voice_text
|
||||||
|
from src.common.data_models.database_data_model import DatabaseMessages
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.config.config import global_config
|
from src.config.config import global_config
|
||||||
|
|
||||||
@@ -43,7 +44,7 @@ class Message(MessageBase, metaclass=ABCMeta):
|
|||||||
user_info: UserInfo,
|
user_info: UserInfo,
|
||||||
message_segment: Seg | None = None,
|
message_segment: Seg | None = None,
|
||||||
timestamp: float | None = None,
|
timestamp: float | None = None,
|
||||||
reply: Optional["MessageRecv"] = None,
|
reply: Optional["DatabaseMessages"] = None,
|
||||||
processed_plain_text: str = "",
|
processed_plain_text: str = "",
|
||||||
):
|
):
|
||||||
# 使用传入的时间戳或当前时间
|
# 使用传入的时间戳或当前时间
|
||||||
@@ -95,346 +96,12 @@ class Message(MessageBase, metaclass=ABCMeta):
|
|||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class MessageRecv(Message):
|
|
||||||
"""接收消息类,用于处理从MessageCQ序列化的消息"""
|
|
||||||
|
|
||||||
def __init__(self, message_dict: dict[str, Any]):
|
# MessageRecv 类已被完全移除,现在统一使用 DatabaseMessages
|
||||||
"""从MessageCQ的字典初始化
|
# 如需从消息字典创建 DatabaseMessages,请使用:
|
||||||
|
# from src.chat.message_receive.message_processor import process_message_from_dict
|
||||||
Args:
|
#
|
||||||
message_dict: MessageCQ序列化后的字典
|
# 迁移完成日期: 2025-10-31
|
||||||
"""
|
|
||||||
# Manually initialize attributes from MessageBase and Message
|
|
||||||
self.message_info = BaseMessageInfo.from_dict(message_dict.get("message_info", {}))
|
|
||||||
self.message_segment = Seg.from_dict(message_dict.get("message_segment", {}))
|
|
||||||
self.raw_message = message_dict.get("raw_message")
|
|
||||||
|
|
||||||
self.chat_stream = None
|
|
||||||
self.reply = None
|
|
||||||
self.processed_plain_text = message_dict.get("processed_plain_text", "")
|
|
||||||
self.memorized_times = 0
|
|
||||||
|
|
||||||
# MessageRecv specific attributes
|
|
||||||
self.is_emoji = False
|
|
||||||
self.has_emoji = False
|
|
||||||
self.is_picid = False
|
|
||||||
self.has_picid = False
|
|
||||||
self.is_voice = False
|
|
||||||
self.is_video = False
|
|
||||||
self.is_mentioned = None
|
|
||||||
self.is_notify = False # 是否为notice消息
|
|
||||||
self.is_public_notice = False # 是否为公共notice
|
|
||||||
self.notice_type = None # notice类型
|
|
||||||
self.is_at = False
|
|
||||||
self.is_command = False
|
|
||||||
|
|
||||||
self.priority_mode = "interest"
|
|
||||||
self.priority_info = None
|
|
||||||
self.interest_value: float = 0.0
|
|
||||||
|
|
||||||
self.key_words = []
|
|
||||||
self.key_words_lite = []
|
|
||||||
|
|
||||||
# 解析additional_config中的notice信息
|
|
||||||
if self.message_info.additional_config and isinstance(self.message_info.additional_config, dict):
|
|
||||||
self.is_notify = self.message_info.additional_config.get("is_notice", False)
|
|
||||||
self.is_public_notice = self.message_info.additional_config.get("is_public_notice", False)
|
|
||||||
self.notice_type = self.message_info.additional_config.get("notice_type")
|
|
||||||
|
|
||||||
def update_chat_stream(self, chat_stream: "ChatStream"):
|
|
||||||
self.chat_stream = chat_stream
|
|
||||||
|
|
||||||
def to_database_message(self) -> "DatabaseMessages":
|
|
||||||
"""将 MessageRecv 转换为 DatabaseMessages 对象
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
DatabaseMessages: 数据库消息对象
|
|
||||||
"""
|
|
||||||
from src.common.data_models.database_data_model import DatabaseMessages
|
|
||||||
import json
|
|
||||||
import time
|
|
||||||
|
|
||||||
message_info = self.message_info
|
|
||||||
msg_user_info = getattr(message_info, "user_info", None)
|
|
||||||
stream_user_info = getattr(self.chat_stream, "user_info", None) if self.chat_stream else None
|
|
||||||
group_info = getattr(self.chat_stream, "group_info", None) if self.chat_stream else None
|
|
||||||
|
|
||||||
message_id = message_info.message_id or ""
|
|
||||||
message_time = message_info.time if hasattr(message_info, "time") and message_info.time is not None else time.time()
|
|
||||||
is_mentioned = None
|
|
||||||
if isinstance(self.is_mentioned, bool):
|
|
||||||
is_mentioned = self.is_mentioned
|
|
||||||
elif isinstance(self.is_mentioned, int | float):
|
|
||||||
is_mentioned = self.is_mentioned != 0
|
|
||||||
|
|
||||||
# 提取用户信息
|
|
||||||
user_id = ""
|
|
||||||
user_nickname = ""
|
|
||||||
user_cardname = None
|
|
||||||
user_platform = ""
|
|
||||||
if msg_user_info:
|
|
||||||
user_id = str(getattr(msg_user_info, "user_id", "") or "")
|
|
||||||
user_nickname = getattr(msg_user_info, "user_nickname", "") or ""
|
|
||||||
user_cardname = getattr(msg_user_info, "user_cardname", None)
|
|
||||||
user_platform = getattr(msg_user_info, "platform", "") or ""
|
|
||||||
elif stream_user_info:
|
|
||||||
user_id = str(getattr(stream_user_info, "user_id", "") or "")
|
|
||||||
user_nickname = getattr(stream_user_info, "user_nickname", "") or ""
|
|
||||||
user_cardname = getattr(stream_user_info, "user_cardname", None)
|
|
||||||
user_platform = getattr(stream_user_info, "platform", "") or ""
|
|
||||||
|
|
||||||
# 提取聊天流信息
|
|
||||||
chat_user_id = str(getattr(stream_user_info, "user_id", "") or "") if stream_user_info else ""
|
|
||||||
chat_user_nickname = getattr(stream_user_info, "user_nickname", "") or "" if stream_user_info else ""
|
|
||||||
chat_user_cardname = getattr(stream_user_info, "user_cardname", None) if stream_user_info else None
|
|
||||||
chat_user_platform = getattr(stream_user_info, "platform", "") or "" if stream_user_info else ""
|
|
||||||
|
|
||||||
group_id = getattr(group_info, "group_id", None) if group_info else None
|
|
||||||
group_name = getattr(group_info, "group_name", None) if group_info else None
|
|
||||||
group_platform = getattr(group_info, "platform", None) if group_info else None
|
|
||||||
|
|
||||||
# 准备 additional_config
|
|
||||||
additional_config_str = None
|
|
||||||
try:
|
|
||||||
import orjson
|
|
||||||
|
|
||||||
additional_config_data = {}
|
|
||||||
|
|
||||||
# 首先获取adapter传递的additional_config
|
|
||||||
if hasattr(message_info, 'additional_config') and message_info.additional_config:
|
|
||||||
if isinstance(message_info.additional_config, dict):
|
|
||||||
additional_config_data = message_info.additional_config.copy()
|
|
||||||
elif isinstance(message_info.additional_config, str):
|
|
||||||
try:
|
|
||||||
additional_config_data = orjson.loads(message_info.additional_config)
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"无法解析 additional_config JSON: {e}")
|
|
||||||
additional_config_data = {}
|
|
||||||
|
|
||||||
# 添加notice相关标志
|
|
||||||
if self.is_notify:
|
|
||||||
additional_config_data["is_notice"] = True
|
|
||||||
additional_config_data["notice_type"] = self.notice_type or "unknown"
|
|
||||||
additional_config_data["is_public_notice"] = bool(self.is_public_notice)
|
|
||||||
|
|
||||||
# 添加format_info到additional_config中
|
|
||||||
if hasattr(message_info, 'format_info') and message_info.format_info:
|
|
||||||
try:
|
|
||||||
format_info_dict = message_info.format_info.to_dict()
|
|
||||||
additional_config_data["format_info"] = format_info_dict
|
|
||||||
logger.debug(f"[message.py] 嵌入 format_info 到 additional_config: {format_info_dict}")
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"将 format_info 转换为字典失败: {e}")
|
|
||||||
|
|
||||||
# 序列化为JSON字符串
|
|
||||||
if additional_config_data:
|
|
||||||
additional_config_str = orjson.dumps(additional_config_data).decode("utf-8")
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"准备 additional_config 失败: {e}")
|
|
||||||
|
|
||||||
# 创建数据库消息对象
|
|
||||||
db_message = DatabaseMessages(
|
|
||||||
message_id=message_id,
|
|
||||||
time=float(message_time),
|
|
||||||
chat_id=self.chat_stream.stream_id if self.chat_stream else "",
|
|
||||||
processed_plain_text=self.processed_plain_text,
|
|
||||||
display_message=self.processed_plain_text,
|
|
||||||
is_mentioned=is_mentioned,
|
|
||||||
is_at=bool(self.is_at) if self.is_at is not None else None,
|
|
||||||
is_emoji=bool(self.is_emoji),
|
|
||||||
is_picid=bool(self.is_picid),
|
|
||||||
is_command=bool(self.is_command),
|
|
||||||
is_notify=bool(self.is_notify),
|
|
||||||
is_public_notice=bool(self.is_public_notice),
|
|
||||||
notice_type=self.notice_type,
|
|
||||||
additional_config=additional_config_str,
|
|
||||||
user_id=user_id,
|
|
||||||
user_nickname=user_nickname,
|
|
||||||
user_cardname=user_cardname,
|
|
||||||
user_platform=user_platform,
|
|
||||||
chat_info_stream_id=self.chat_stream.stream_id if self.chat_stream else "",
|
|
||||||
chat_info_platform=self.chat_stream.platform if self.chat_stream else "",
|
|
||||||
chat_info_create_time=float(self.chat_stream.create_time) if self.chat_stream else 0.0,
|
|
||||||
chat_info_last_active_time=float(self.chat_stream.last_active_time) if self.chat_stream else 0.0,
|
|
||||||
chat_info_user_id=chat_user_id,
|
|
||||||
chat_info_user_nickname=chat_user_nickname,
|
|
||||||
chat_info_user_cardname=chat_user_cardname,
|
|
||||||
chat_info_user_platform=chat_user_platform,
|
|
||||||
chat_info_group_id=group_id,
|
|
||||||
chat_info_group_name=group_name,
|
|
||||||
chat_info_group_platform=group_platform,
|
|
||||||
)
|
|
||||||
|
|
||||||
# 同步兴趣度等衍生属性
|
|
||||||
db_message.interest_value = getattr(self, "interest_value", 0.0)
|
|
||||||
setattr(db_message, "should_reply", getattr(self, "should_reply", False))
|
|
||||||
setattr(db_message, "should_act", getattr(self, "should_act", False))
|
|
||||||
|
|
||||||
return db_message
|
|
||||||
|
|
||||||
async def process(self) -> None:
|
|
||||||
"""处理消息内容,生成纯文本和详细文本
|
|
||||||
|
|
||||||
这个方法必须在创建实例后显式调用,因为它包含异步操作。
|
|
||||||
"""
|
|
||||||
self.processed_plain_text = await self._process_message_segments(self.message_segment)
|
|
||||||
|
|
||||||
async def _process_single_segment(self, segment: Seg) -> str:
|
|
||||||
"""处理单个消息段
|
|
||||||
|
|
||||||
Args:
|
|
||||||
segment: 消息段
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
str: 处理后的文本
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
if segment.type == "text":
|
|
||||||
self.is_picid = False
|
|
||||||
self.is_emoji = False
|
|
||||||
self.is_video = False
|
|
||||||
return segment.data # type: ignore
|
|
||||||
elif segment.type == "at":
|
|
||||||
self.is_picid = False
|
|
||||||
self.is_emoji = False
|
|
||||||
self.is_video = False
|
|
||||||
# 处理at消息,格式为"昵称:QQ号"
|
|
||||||
if isinstance(segment.data, str) and ":" in segment.data:
|
|
||||||
nickname, qq_id = segment.data.split(":", 1)
|
|
||||||
return f"@{nickname}"
|
|
||||||
return f"@{segment.data}" if isinstance(segment.data, str) else "@未知用户"
|
|
||||||
elif segment.type == "image":
|
|
||||||
# 如果是base64图片数据
|
|
||||||
if isinstance(segment.data, str):
|
|
||||||
self.has_picid = True
|
|
||||||
self.is_picid = True
|
|
||||||
self.is_emoji = False
|
|
||||||
self.is_video = False
|
|
||||||
image_manager = get_image_manager()
|
|
||||||
# print(f"segment.data: {segment.data}")
|
|
||||||
_, processed_text = await image_manager.process_image(segment.data)
|
|
||||||
return processed_text
|
|
||||||
return "[发了一张图片,网卡了加载不出来]"
|
|
||||||
elif segment.type == "emoji":
|
|
||||||
self.has_emoji = True
|
|
||||||
self.is_emoji = True
|
|
||||||
self.is_picid = False
|
|
||||||
self.is_voice = False
|
|
||||||
self.is_video = False
|
|
||||||
if isinstance(segment.data, str):
|
|
||||||
return await get_image_manager().get_emoji_description(segment.data)
|
|
||||||
return "[发了一个表情包,网卡了加载不出来]"
|
|
||||||
elif segment.type == "voice":
|
|
||||||
self.is_picid = False
|
|
||||||
self.is_emoji = False
|
|
||||||
self.is_voice = True
|
|
||||||
self.is_video = False
|
|
||||||
|
|
||||||
# 检查消息是否由机器人自己发送
|
|
||||||
if self.message_info and self.message_info.user_info and str(self.message_info.user_info.user_id) == str(global_config.bot.qq_account):
|
|
||||||
logger.info(f"检测到机器人自身发送的语音消息 (User ID: {self.message_info.user_info.user_id}),尝试从缓存获取文本。")
|
|
||||||
if isinstance(segment.data, str):
|
|
||||||
cached_text = consume_self_voice_text(segment.data)
|
|
||||||
if cached_text:
|
|
||||||
logger.info(f"成功从缓存中获取语音文本: '{cached_text[:70]}...'")
|
|
||||||
return f"[语音:{cached_text}]"
|
|
||||||
else:
|
|
||||||
logger.warning("机器人自身语音消息缓存未命中,将回退到标准语音识别。")
|
|
||||||
|
|
||||||
# 标准语音识别流程 (也作为缓存未命中的后备方案)
|
|
||||||
if isinstance(segment.data, str):
|
|
||||||
return await get_voice_text(segment.data)
|
|
||||||
return "[发了一段语音,网卡了加载不出来]"
|
|
||||||
elif segment.type == "mention_bot":
|
|
||||||
self.is_picid = False
|
|
||||||
self.is_emoji = False
|
|
||||||
self.is_voice = False
|
|
||||||
self.is_video = False
|
|
||||||
self.is_mentioned = float(segment.data) # type: ignore
|
|
||||||
return ""
|
|
||||||
elif segment.type == "priority_info":
|
|
||||||
self.is_picid = False
|
|
||||||
self.is_emoji = False
|
|
||||||
self.is_voice = False
|
|
||||||
if isinstance(segment.data, dict):
|
|
||||||
# 处理优先级信息
|
|
||||||
self.priority_mode = "priority"
|
|
||||||
self.priority_info = segment.data
|
|
||||||
"""
|
|
||||||
{
|
|
||||||
'message_type': 'vip', # vip or normal
|
|
||||||
'message_priority': 1.0, # 优先级,大为优先,float
|
|
||||||
}
|
|
||||||
"""
|
|
||||||
return ""
|
|
||||||
elif segment.type == "file":
|
|
||||||
if isinstance(segment.data, dict):
|
|
||||||
file_name = segment.data.get('name', '未知文件')
|
|
||||||
file_size = segment.data.get('size', '未知大小')
|
|
||||||
return f"[文件:{file_name} ({file_size}字节)]"
|
|
||||||
return "[收到一个文件]"
|
|
||||||
elif segment.type == "video":
|
|
||||||
self.is_picid = False
|
|
||||||
self.is_emoji = False
|
|
||||||
self.is_voice = False
|
|
||||||
self.is_video = True
|
|
||||||
logger.info(f"接收到视频消息,数据类型: {type(segment.data)}")
|
|
||||||
|
|
||||||
# 检查视频分析功能是否可用
|
|
||||||
if not is_video_analysis_available():
|
|
||||||
logger.warning("⚠️ Rust视频处理模块不可用,跳过视频分析")
|
|
||||||
return "[视频]"
|
|
||||||
|
|
||||||
if global_config.video_analysis.enable:
|
|
||||||
logger.info("已启用视频识别,开始识别")
|
|
||||||
if isinstance(segment.data, dict):
|
|
||||||
try:
|
|
||||||
# 从Adapter接收的视频数据
|
|
||||||
video_base64 = segment.data.get("base64")
|
|
||||||
filename = segment.data.get("filename", "video.mp4")
|
|
||||||
|
|
||||||
logger.info(f"视频文件名: {filename}")
|
|
||||||
logger.info(f"Base64数据长度: {len(video_base64) if video_base64 else 0}")
|
|
||||||
|
|
||||||
if video_base64:
|
|
||||||
# 解码base64视频数据
|
|
||||||
video_bytes = base64.b64decode(video_base64)
|
|
||||||
logger.info(f"解码后视频大小: {len(video_bytes)} 字节")
|
|
||||||
|
|
||||||
# 使用video analyzer分析视频
|
|
||||||
video_analyzer = get_video_analyzer()
|
|
||||||
result = await video_analyzer.analyze_video_from_bytes(
|
|
||||||
video_bytes, filename, prompt=global_config.video_analysis.batch_analysis_prompt
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.info(f"视频分析结果: {result}")
|
|
||||||
|
|
||||||
# 返回视频分析结果
|
|
||||||
summary = result.get("summary", "")
|
|
||||||
if summary:
|
|
||||||
return f"[视频内容] {summary}"
|
|
||||||
else:
|
|
||||||
return "[已收到视频,但分析失败]"
|
|
||||||
else:
|
|
||||||
logger.warning("视频消息中没有base64数据")
|
|
||||||
return "[收到视频消息,但数据异常]"
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"视频处理失败: {e!s}")
|
|
||||||
import traceback
|
|
||||||
|
|
||||||
logger.error(f"错误详情: {traceback.format_exc()}")
|
|
||||||
return "[收到视频,但处理时出现错误]"
|
|
||||||
else:
|
|
||||||
logger.warning(f"视频消息数据不是字典格式: {type(segment.data)}")
|
|
||||||
return "[发了一个视频,但格式不支持]"
|
|
||||||
else:
|
|
||||||
return ""
|
|
||||||
else:
|
|
||||||
logger.warning(f"未知的消息段类型: {segment.type}")
|
|
||||||
return f"[{segment.type} 消息]"
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"处理消息段失败: {e!s}, 类型: {segment.type}, 数据: {segment.data}")
|
|
||||||
return f"[处理失败的{segment.type}消息]"
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -447,7 +114,7 @@ class MessageProcessBase(Message):
|
|||||||
chat_stream: "ChatStream",
|
chat_stream: "ChatStream",
|
||||||
bot_user_info: UserInfo,
|
bot_user_info: UserInfo,
|
||||||
message_segment: Seg | None = None,
|
message_segment: Seg | None = None,
|
||||||
reply: Optional["MessageRecv"] = None,
|
reply: Optional["DatabaseMessages"] = None,
|
||||||
thinking_start_time: float = 0,
|
thinking_start_time: float = 0,
|
||||||
timestamp: float | None = None,
|
timestamp: float | None = None,
|
||||||
):
|
):
|
||||||
@@ -548,7 +215,7 @@ class MessageSending(MessageProcessBase):
|
|||||||
sender_info: UserInfo | None, # 用来记录发送者信息
|
sender_info: UserInfo | None, # 用来记录发送者信息
|
||||||
message_segment: Seg,
|
message_segment: Seg,
|
||||||
display_message: str = "",
|
display_message: str = "",
|
||||||
reply: Optional["MessageRecv"] = None,
|
reply: Optional["DatabaseMessages"] = None,
|
||||||
is_head: bool = False,
|
is_head: bool = False,
|
||||||
is_emoji: bool = False,
|
is_emoji: bool = False,
|
||||||
thinking_start_time: float = 0,
|
thinking_start_time: float = 0,
|
||||||
@@ -567,7 +234,11 @@ class MessageSending(MessageProcessBase):
|
|||||||
|
|
||||||
# 发送状态特有属性
|
# 发送状态特有属性
|
||||||
self.sender_info = sender_info
|
self.sender_info = sender_info
|
||||||
self.reply_to_message_id = reply.message_info.message_id if reply else None
|
# 从 DatabaseMessages 获取 message_id
|
||||||
|
if reply:
|
||||||
|
self.reply_to_message_id = reply.message_id
|
||||||
|
else:
|
||||||
|
self.reply_to_message_id = None
|
||||||
self.is_head = is_head
|
self.is_head = is_head
|
||||||
self.is_emoji = is_emoji
|
self.is_emoji = is_emoji
|
||||||
self.apply_set_reply_logic = apply_set_reply_logic
|
self.apply_set_reply_logic = apply_set_reply_logic
|
||||||
@@ -582,14 +253,18 @@ class MessageSending(MessageProcessBase):
|
|||||||
def build_reply(self):
|
def build_reply(self):
|
||||||
"""设置回复消息"""
|
"""设置回复消息"""
|
||||||
if self.reply:
|
if self.reply:
|
||||||
self.reply_to_message_id = self.reply.message_info.message_id
|
# 从 DatabaseMessages 获取 message_id
|
||||||
self.message_segment = Seg(
|
message_id = self.reply.message_id
|
||||||
type="seglist",
|
|
||||||
data=[
|
if message_id:
|
||||||
Seg(type="reply", data=self.reply.message_info.message_id), # type: ignore
|
self.reply_to_message_id = message_id
|
||||||
self.message_segment,
|
self.message_segment = Seg(
|
||||||
],
|
type="seglist",
|
||||||
)
|
data=[
|
||||||
|
Seg(type="reply", data=message_id), # type: ignore
|
||||||
|
self.message_segment,
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
async def process(self) -> None:
|
async def process(self) -> None:
|
||||||
"""处理消息内容,生成纯文本和详细文本"""
|
"""处理消息内容,生成纯文本和详细文本"""
|
||||||
@@ -607,48 +282,5 @@ class MessageSending(MessageProcessBase):
|
|||||||
return self.message_info.group_info is None or self.message_info.group_info.group_id is None
|
return self.message_info.group_info is None or self.message_info.group_info.group_id is None
|
||||||
|
|
||||||
|
|
||||||
def message_recv_from_dict(message_dict: dict) -> MessageRecv:
|
# message_recv_from_dict 和 message_from_db_dict 函数已被移除
|
||||||
return MessageRecv(message_dict)
|
# 请使用: from src.chat.message_receive.message_processor import process_message_from_dict
|
||||||
|
|
||||||
def message_from_db_dict(db_dict: dict) -> MessageRecv:
|
|
||||||
"""从数据库字典创建MessageRecv实例"""
|
|
||||||
# 转换扁平的数据库字典为嵌套结构
|
|
||||||
message_info_dict = {
|
|
||||||
"platform": db_dict.get("chat_info_platform"),
|
|
||||||
"message_id": db_dict.get("message_id"),
|
|
||||||
"time": db_dict.get("time"),
|
|
||||||
"group_info": {
|
|
||||||
"platform": db_dict.get("chat_info_group_platform"),
|
|
||||||
"group_id": db_dict.get("chat_info_group_id"),
|
|
||||||
"group_name": db_dict.get("chat_info_group_name"),
|
|
||||||
},
|
|
||||||
"user_info": {
|
|
||||||
"platform": db_dict.get("user_platform"),
|
|
||||||
"user_id": db_dict.get("user_id"),
|
|
||||||
"user_nickname": db_dict.get("user_nickname"),
|
|
||||||
"user_cardname": db_dict.get("user_cardname"),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
processed_text = db_dict.get("processed_plain_text", "")
|
|
||||||
|
|
||||||
# 构建 MessageRecv 需要的字典
|
|
||||||
recv_dict = {
|
|
||||||
"message_info": message_info_dict,
|
|
||||||
"message_segment": {"type": "text", "data": processed_text}, # 从纯文本重建消息段
|
|
||||||
"raw_message": None, # 数据库中未存储原始消息
|
|
||||||
"processed_plain_text": processed_text,
|
|
||||||
}
|
|
||||||
|
|
||||||
# 创建 MessageRecv 实例
|
|
||||||
msg = MessageRecv(recv_dict)
|
|
||||||
|
|
||||||
# 从数据库字典中填充其他可选字段
|
|
||||||
msg.interest_value = db_dict.get("interest_value", 0.0)
|
|
||||||
msg.is_mentioned = db_dict.get("is_mentioned")
|
|
||||||
msg.priority_mode = db_dict.get("priority_mode", "interest")
|
|
||||||
msg.priority_info = db_dict.get("priority_info")
|
|
||||||
msg.is_emoji = db_dict.get("is_emoji", False)
|
|
||||||
msg.is_picid = db_dict.get("is_picid", False)
|
|
||||||
|
|
||||||
return msg
|
|
||||||
|
|||||||
493
src/chat/message_receive/message_processor.py
Normal file
493
src/chat/message_receive/message_processor.py
Normal file
@@ -0,0 +1,493 @@
|
|||||||
|
"""消息处理工具模块
|
||||||
|
将原 MessageRecv 的消息处理逻辑提取为独立函数,
|
||||||
|
直接从适配器消息字典生成 DatabaseMessages
|
||||||
|
"""
|
||||||
|
import base64
|
||||||
|
import time
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import orjson
|
||||||
|
from maim_message import BaseMessageInfo, Seg
|
||||||
|
|
||||||
|
from src.chat.utils.self_voice_cache import consume_self_voice_text
|
||||||
|
from src.chat.utils.utils_image import get_image_manager
|
||||||
|
from src.chat.utils.utils_video import get_video_analyzer, is_video_analysis_available
|
||||||
|
from src.chat.utils.utils_voice import get_voice_text
|
||||||
|
from src.common.data_models.database_data_model import DatabaseMessages
|
||||||
|
from src.common.logger import get_logger
|
||||||
|
from src.config.config import global_config
|
||||||
|
|
||||||
|
logger = get_logger("message_processor")
|
||||||
|
|
||||||
|
|
||||||
|
async def process_message_from_dict(message_dict: dict[str, Any], stream_id: str, platform: str) -> DatabaseMessages:
|
||||||
|
"""从适配器消息字典处理并生成 DatabaseMessages
|
||||||
|
|
||||||
|
这个函数整合了原 MessageRecv 的所有处理逻辑:
|
||||||
|
1. 解析 message_segment 并异步处理内容(图片、语音、视频等)
|
||||||
|
2. 提取所有消息元数据
|
||||||
|
3. 直接构造 DatabaseMessages 对象
|
||||||
|
|
||||||
|
Args:
|
||||||
|
message_dict: MessageCQ序列化后的字典
|
||||||
|
stream_id: 聊天流ID
|
||||||
|
platform: 平台标识
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
DatabaseMessages: 处理完成的数据库消息对象
|
||||||
|
"""
|
||||||
|
# 解析基础信息
|
||||||
|
message_info = BaseMessageInfo.from_dict(message_dict.get("message_info", {}))
|
||||||
|
message_segment = Seg.from_dict(message_dict.get("message_segment", {}))
|
||||||
|
|
||||||
|
# 初始化处理状态
|
||||||
|
processing_state = {
|
||||||
|
"is_emoji": False,
|
||||||
|
"has_emoji": False,
|
||||||
|
"is_picid": False,
|
||||||
|
"has_picid": False,
|
||||||
|
"is_voice": False,
|
||||||
|
"is_video": False,
|
||||||
|
"is_mentioned": None,
|
||||||
|
"is_at": False,
|
||||||
|
"priority_mode": "interest",
|
||||||
|
"priority_info": None,
|
||||||
|
}
|
||||||
|
|
||||||
|
# 异步处理消息段,生成纯文本
|
||||||
|
processed_plain_text = await _process_message_segments(message_segment, processing_state, message_info)
|
||||||
|
|
||||||
|
# 解析 notice 信息
|
||||||
|
is_notify = False
|
||||||
|
is_public_notice = False
|
||||||
|
notice_type = None
|
||||||
|
if message_info.additional_config and isinstance(message_info.additional_config, dict):
|
||||||
|
is_notify = message_info.additional_config.get("is_notice", False)
|
||||||
|
is_public_notice = message_info.additional_config.get("is_public_notice", False)
|
||||||
|
notice_type = message_info.additional_config.get("notice_type")
|
||||||
|
|
||||||
|
# 提取用户信息
|
||||||
|
user_info = message_info.user_info
|
||||||
|
user_id = str(user_info.user_id) if user_info and user_info.user_id else ""
|
||||||
|
user_nickname = (user_info.user_nickname or "") if user_info else ""
|
||||||
|
user_cardname = user_info.user_cardname if user_info else None
|
||||||
|
user_platform = (user_info.platform or "") if user_info else ""
|
||||||
|
|
||||||
|
# 提取群组信息
|
||||||
|
group_info = message_info.group_info
|
||||||
|
group_id = group_info.group_id if group_info else None
|
||||||
|
group_name = group_info.group_name if group_info else None
|
||||||
|
group_platform = group_info.platform if group_info else None
|
||||||
|
|
||||||
|
# 生成 chat_id
|
||||||
|
if group_info and group_id:
|
||||||
|
chat_id = f"{platform}_{group_id}"
|
||||||
|
elif user_info and user_info.user_id:
|
||||||
|
chat_id = f"{platform}_{user_info.user_id}_private"
|
||||||
|
else:
|
||||||
|
chat_id = stream_id
|
||||||
|
|
||||||
|
# 准备 additional_config
|
||||||
|
additional_config_str = _prepare_additional_config(message_info, is_notify, is_public_notice, notice_type)
|
||||||
|
|
||||||
|
# 提取 reply_to
|
||||||
|
reply_to = _extract_reply_from_segment(message_segment)
|
||||||
|
|
||||||
|
# 构造 DatabaseMessages
|
||||||
|
message_time = message_info.time if hasattr(message_info, "time") and message_info.time is not None else time.time()
|
||||||
|
message_id = message_info.message_id or ""
|
||||||
|
|
||||||
|
# 处理 is_mentioned
|
||||||
|
is_mentioned = None
|
||||||
|
mentioned_value = processing_state.get("is_mentioned")
|
||||||
|
if isinstance(mentioned_value, bool):
|
||||||
|
is_mentioned = mentioned_value
|
||||||
|
elif isinstance(mentioned_value, (int, float)):
|
||||||
|
is_mentioned = mentioned_value != 0
|
||||||
|
|
||||||
|
db_message = DatabaseMessages(
|
||||||
|
message_id=message_id,
|
||||||
|
time=float(message_time),
|
||||||
|
chat_id=chat_id,
|
||||||
|
reply_to=reply_to,
|
||||||
|
processed_plain_text=processed_plain_text,
|
||||||
|
display_message=processed_plain_text,
|
||||||
|
is_mentioned=is_mentioned,
|
||||||
|
is_at=bool(processing_state.get("is_at", False)),
|
||||||
|
is_emoji=bool(processing_state.get("is_emoji", False)),
|
||||||
|
is_picid=bool(processing_state.get("is_picid", False)),
|
||||||
|
is_command=False, # 将在后续处理中设置
|
||||||
|
is_notify=bool(is_notify),
|
||||||
|
is_public_notice=bool(is_public_notice),
|
||||||
|
notice_type=notice_type,
|
||||||
|
additional_config=additional_config_str,
|
||||||
|
user_id=user_id,
|
||||||
|
user_nickname=user_nickname,
|
||||||
|
user_cardname=user_cardname,
|
||||||
|
user_platform=user_platform,
|
||||||
|
chat_info_stream_id=stream_id,
|
||||||
|
chat_info_platform=platform,
|
||||||
|
chat_info_create_time=0.0, # 将由 ChatStream 填充
|
||||||
|
chat_info_last_active_time=0.0, # 将由 ChatStream 填充
|
||||||
|
chat_info_user_id=user_id,
|
||||||
|
chat_info_user_nickname=user_nickname,
|
||||||
|
chat_info_user_cardname=user_cardname,
|
||||||
|
chat_info_user_platform=user_platform,
|
||||||
|
chat_info_group_id=group_id,
|
||||||
|
chat_info_group_name=group_name,
|
||||||
|
chat_info_group_platform=group_platform,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 设置优先级信息
|
||||||
|
if processing_state.get("priority_mode"):
|
||||||
|
setattr(db_message, "priority_mode", processing_state["priority_mode"])
|
||||||
|
if processing_state.get("priority_info"):
|
||||||
|
setattr(db_message, "priority_info", processing_state["priority_info"])
|
||||||
|
|
||||||
|
# 设置其他运行时属性
|
||||||
|
setattr(db_message, "is_voice", bool(processing_state.get("is_voice", False)))
|
||||||
|
setattr(db_message, "is_video", bool(processing_state.get("is_video", False)))
|
||||||
|
setattr(db_message, "has_emoji", bool(processing_state.get("has_emoji", False)))
|
||||||
|
setattr(db_message, "has_picid", bool(processing_state.get("has_picid", False)))
|
||||||
|
|
||||||
|
return db_message
|
||||||
|
|
||||||
|
|
||||||
|
async def _process_message_segments(segment: Seg, state: dict, message_info: BaseMessageInfo) -> str:
|
||||||
|
"""递归处理消息段,转换为文字描述
|
||||||
|
|
||||||
|
Args:
|
||||||
|
segment: 要处理的消息段
|
||||||
|
state: 处理状态字典(用于记录消息类型标记)
|
||||||
|
message_info: 消息基础信息(用于某些处理逻辑)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: 处理后的文本
|
||||||
|
"""
|
||||||
|
if segment.type == "seglist":
|
||||||
|
# 处理消息段列表
|
||||||
|
segments_text = []
|
||||||
|
for seg in segment.data:
|
||||||
|
processed = await _process_message_segments(seg, state, message_info)
|
||||||
|
if processed:
|
||||||
|
segments_text.append(processed)
|
||||||
|
return " ".join(segments_text)
|
||||||
|
else:
|
||||||
|
# 处理单个消息段
|
||||||
|
return await _process_single_segment(segment, state, message_info)
|
||||||
|
|
||||||
|
|
||||||
|
async def _process_single_segment(segment: Seg, state: dict, message_info: BaseMessageInfo) -> str:
|
||||||
|
"""处理单个消息段
|
||||||
|
|
||||||
|
Args:
|
||||||
|
segment: 消息段
|
||||||
|
state: 处理状态字典
|
||||||
|
message_info: 消息基础信息
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: 处理后的文本
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
if segment.type == "text":
|
||||||
|
state["is_picid"] = False
|
||||||
|
state["is_emoji"] = False
|
||||||
|
state["is_video"] = False
|
||||||
|
return segment.data
|
||||||
|
|
||||||
|
elif segment.type == "at":
|
||||||
|
state["is_picid"] = False
|
||||||
|
state["is_emoji"] = False
|
||||||
|
state["is_video"] = False
|
||||||
|
state["is_at"] = True
|
||||||
|
# 处理at消息,格式为"昵称:QQ号"
|
||||||
|
if isinstance(segment.data, str) and ":" in segment.data:
|
||||||
|
nickname, qq_id = segment.data.split(":", 1)
|
||||||
|
return f"@{nickname}"
|
||||||
|
return f"@{segment.data}" if isinstance(segment.data, str) else "@未知用户"
|
||||||
|
|
||||||
|
elif segment.type == "image":
|
||||||
|
# 如果是base64图片数据
|
||||||
|
if isinstance(segment.data, str):
|
||||||
|
state["has_picid"] = True
|
||||||
|
state["is_picid"] = True
|
||||||
|
state["is_emoji"] = False
|
||||||
|
state["is_video"] = False
|
||||||
|
image_manager = get_image_manager()
|
||||||
|
_, processed_text = await image_manager.process_image(segment.data)
|
||||||
|
return processed_text
|
||||||
|
return "[发了一张图片,网卡了加载不出来]"
|
||||||
|
|
||||||
|
elif segment.type == "emoji":
|
||||||
|
state["has_emoji"] = True
|
||||||
|
state["is_emoji"] = True
|
||||||
|
state["is_picid"] = False
|
||||||
|
state["is_voice"] = False
|
||||||
|
state["is_video"] = False
|
||||||
|
if isinstance(segment.data, str):
|
||||||
|
return await get_image_manager().get_emoji_description(segment.data)
|
||||||
|
return "[发了一个表情包,网卡了加载不出来]"
|
||||||
|
|
||||||
|
elif segment.type == "voice":
|
||||||
|
state["is_picid"] = False
|
||||||
|
state["is_emoji"] = False
|
||||||
|
state["is_voice"] = True
|
||||||
|
state["is_video"] = False
|
||||||
|
|
||||||
|
# 检查消息是否由机器人自己发送
|
||||||
|
if message_info and message_info.user_info and str(message_info.user_info.user_id) == str(global_config.bot.qq_account):
|
||||||
|
logger.info(f"检测到机器人自身发送的语音消息 (User ID: {message_info.user_info.user_id}),尝试从缓存获取文本。")
|
||||||
|
if isinstance(segment.data, str):
|
||||||
|
cached_text = consume_self_voice_text(segment.data)
|
||||||
|
if cached_text:
|
||||||
|
logger.info(f"成功从缓存中获取语音文本: '{cached_text[:70]}...'")
|
||||||
|
return f"[语音:{cached_text}]"
|
||||||
|
else:
|
||||||
|
logger.warning("机器人自身语音消息缓存未命中,将回退到标准语音识别。")
|
||||||
|
|
||||||
|
# 标准语音识别流程
|
||||||
|
if isinstance(segment.data, str):
|
||||||
|
return await get_voice_text(segment.data)
|
||||||
|
return "[发了一段语音,网卡了加载不出来]"
|
||||||
|
|
||||||
|
elif segment.type == "mention_bot":
|
||||||
|
state["is_picid"] = False
|
||||||
|
state["is_emoji"] = False
|
||||||
|
state["is_voice"] = False
|
||||||
|
state["is_video"] = False
|
||||||
|
state["is_mentioned"] = float(segment.data)
|
||||||
|
return ""
|
||||||
|
|
||||||
|
elif segment.type == "priority_info":
|
||||||
|
state["is_picid"] = False
|
||||||
|
state["is_emoji"] = False
|
||||||
|
state["is_voice"] = False
|
||||||
|
if isinstance(segment.data, dict):
|
||||||
|
# 处理优先级信息
|
||||||
|
state["priority_mode"] = "priority"
|
||||||
|
state["priority_info"] = segment.data
|
||||||
|
return ""
|
||||||
|
|
||||||
|
elif segment.type == "file":
|
||||||
|
if isinstance(segment.data, dict):
|
||||||
|
file_name = segment.data.get('name', '未知文件')
|
||||||
|
file_size = segment.data.get('size', '未知大小')
|
||||||
|
return f"[文件:{file_name} ({file_size}字节)]"
|
||||||
|
return "[收到一个文件]"
|
||||||
|
|
||||||
|
elif segment.type == "video":
|
||||||
|
state["is_picid"] = False
|
||||||
|
state["is_emoji"] = False
|
||||||
|
state["is_voice"] = False
|
||||||
|
state["is_video"] = True
|
||||||
|
logger.info(f"接收到视频消息,数据类型: {type(segment.data)}")
|
||||||
|
|
||||||
|
# 检查视频分析功能是否可用
|
||||||
|
if not is_video_analysis_available():
|
||||||
|
logger.warning("⚠️ Rust视频处理模块不可用,跳过视频分析")
|
||||||
|
return "[视频]"
|
||||||
|
|
||||||
|
if global_config.video_analysis.enable:
|
||||||
|
logger.info("已启用视频识别,开始识别")
|
||||||
|
if isinstance(segment.data, dict):
|
||||||
|
try:
|
||||||
|
# 从Adapter接收的视频数据
|
||||||
|
video_base64 = segment.data.get("base64")
|
||||||
|
filename = segment.data.get("filename", "video.mp4")
|
||||||
|
|
||||||
|
logger.info(f"视频文件名: {filename}")
|
||||||
|
logger.info(f"Base64数据长度: {len(video_base64) if video_base64 else 0}")
|
||||||
|
|
||||||
|
if video_base64:
|
||||||
|
# 解码base64视频数据
|
||||||
|
video_bytes = base64.b64decode(video_base64)
|
||||||
|
logger.info(f"解码后视频大小: {len(video_bytes)} 字节")
|
||||||
|
|
||||||
|
# 使用video analyzer分析视频
|
||||||
|
video_analyzer = get_video_analyzer()
|
||||||
|
result = await video_analyzer.analyze_video_from_bytes(
|
||||||
|
video_bytes, filename, prompt=global_config.video_analysis.batch_analysis_prompt
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(f"视频分析结果: {result}")
|
||||||
|
|
||||||
|
# 返回视频分析结果
|
||||||
|
summary = result.get("summary", "")
|
||||||
|
if summary:
|
||||||
|
return f"[视频内容] {summary}"
|
||||||
|
else:
|
||||||
|
return "[已收到视频,但分析失败]"
|
||||||
|
else:
|
||||||
|
logger.warning("视频消息中没有base64数据")
|
||||||
|
return "[收到视频消息,但数据异常]"
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"视频处理失败: {e!s}")
|
||||||
|
import traceback
|
||||||
|
logger.error(f"错误详情: {traceback.format_exc()}")
|
||||||
|
return "[收到视频,但处理时出现错误]"
|
||||||
|
else:
|
||||||
|
logger.warning(f"视频消息数据不是字典格式: {type(segment.data)}")
|
||||||
|
return "[发了一个视频,但格式不支持]"
|
||||||
|
else:
|
||||||
|
return ""
|
||||||
|
else:
|
||||||
|
logger.warning(f"未知的消息段类型: {segment.type}")
|
||||||
|
return f"[{segment.type} 消息]"
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"处理消息段失败: {e!s}, 类型: {segment.type}, 数据: {segment.data}")
|
||||||
|
return f"[处理失败的{segment.type}消息]"
|
||||||
|
|
||||||
|
|
||||||
|
def _prepare_additional_config(message_info: BaseMessageInfo, is_notify: bool, is_public_notice: bool, notice_type: str | None) -> str | None:
|
||||||
|
"""准备 additional_config,包含 format_info 和 notice 信息
|
||||||
|
|
||||||
|
Args:
|
||||||
|
message_info: 消息基础信息
|
||||||
|
is_notify: 是否为notice消息
|
||||||
|
is_public_notice: 是否为公共notice
|
||||||
|
notice_type: notice类型
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str | None: JSON 字符串格式的 additional_config,如果为空则返回 None
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
additional_config_data = {}
|
||||||
|
|
||||||
|
# 首先获取adapter传递的additional_config
|
||||||
|
if hasattr(message_info, 'additional_config') and message_info.additional_config:
|
||||||
|
if isinstance(message_info.additional_config, dict):
|
||||||
|
additional_config_data = message_info.additional_config.copy()
|
||||||
|
elif isinstance(message_info.additional_config, str):
|
||||||
|
try:
|
||||||
|
additional_config_data = orjson.loads(message_info.additional_config)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"无法解析 additional_config JSON: {e}")
|
||||||
|
additional_config_data = {}
|
||||||
|
|
||||||
|
# 添加notice相关标志
|
||||||
|
if is_notify:
|
||||||
|
additional_config_data["is_notice"] = True
|
||||||
|
additional_config_data["notice_type"] = notice_type or "unknown"
|
||||||
|
additional_config_data["is_public_notice"] = bool(is_public_notice)
|
||||||
|
|
||||||
|
# 添加format_info到additional_config中
|
||||||
|
if hasattr(message_info, 'format_info') and message_info.format_info:
|
||||||
|
try:
|
||||||
|
format_info_dict = message_info.format_info.to_dict()
|
||||||
|
additional_config_data["format_info"] = format_info_dict
|
||||||
|
logger.debug(f"[message_processor] 嵌入 format_info 到 additional_config: {format_info_dict}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"将 format_info 转换为字典失败: {e}")
|
||||||
|
|
||||||
|
# 序列化为JSON字符串
|
||||||
|
if additional_config_data:
|
||||||
|
return orjson.dumps(additional_config_data).decode("utf-8")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"准备 additional_config 失败: {e}")
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_reply_from_segment(segment: Seg) -> str | None:
|
||||||
|
"""从消息段中提取reply_to信息
|
||||||
|
|
||||||
|
Args:
|
||||||
|
segment: 消息段
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str | None: 回复的消息ID,如果没有则返回None
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
if hasattr(segment, "type") and segment.type == "seglist":
|
||||||
|
# 递归搜索seglist中的reply段
|
||||||
|
if hasattr(segment, "data") and segment.data:
|
||||||
|
for seg in segment.data:
|
||||||
|
reply_id = _extract_reply_from_segment(seg)
|
||||||
|
if reply_id:
|
||||||
|
return reply_id
|
||||||
|
elif hasattr(segment, "type") and segment.type == "reply":
|
||||||
|
# 找到reply段,返回message_id
|
||||||
|
return str(segment.data) if segment.data else None
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"提取reply_to信息失败: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# DatabaseMessages 扩展工具函数
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
def get_message_info_from_db_message(db_message: DatabaseMessages) -> BaseMessageInfo:
|
||||||
|
"""从 DatabaseMessages 重建 BaseMessageInfo(用于需要 message_info 的遗留代码)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db_message: DatabaseMessages 对象
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
BaseMessageInfo: 重建的消息信息对象
|
||||||
|
"""
|
||||||
|
from maim_message import UserInfo, GroupInfo
|
||||||
|
|
||||||
|
# 从 DatabaseMessages 的 user_info 转换为 maim_message.UserInfo
|
||||||
|
user_info = UserInfo(
|
||||||
|
platform=db_message.user_info.platform,
|
||||||
|
user_id=db_message.user_info.user_id,
|
||||||
|
user_nickname=db_message.user_info.user_nickname,
|
||||||
|
user_cardname=db_message.user_info.user_cardname or ""
|
||||||
|
)
|
||||||
|
|
||||||
|
# 从 DatabaseMessages 的 group_info 转换为 maim_message.GroupInfo(如果存在)
|
||||||
|
group_info = None
|
||||||
|
if db_message.group_info:
|
||||||
|
group_info = GroupInfo(
|
||||||
|
platform=db_message.group_info.group_platform or "",
|
||||||
|
group_id=db_message.group_info.group_id,
|
||||||
|
group_name=db_message.group_info.group_name
|
||||||
|
)
|
||||||
|
|
||||||
|
# 解析 additional_config(从 JSON 字符串到字典)
|
||||||
|
additional_config = None
|
||||||
|
if db_message.additional_config:
|
||||||
|
try:
|
||||||
|
additional_config = orjson.loads(db_message.additional_config)
|
||||||
|
except Exception:
|
||||||
|
# 如果解析失败,保持为字符串
|
||||||
|
pass
|
||||||
|
|
||||||
|
# 创建 BaseMessageInfo
|
||||||
|
message_info = BaseMessageInfo(
|
||||||
|
platform=db_message.chat_info.platform,
|
||||||
|
message_id=db_message.message_id,
|
||||||
|
time=db_message.time,
|
||||||
|
user_info=user_info,
|
||||||
|
group_info=group_info,
|
||||||
|
additional_config=additional_config # type: ignore
|
||||||
|
)
|
||||||
|
|
||||||
|
return message_info
|
||||||
|
|
||||||
|
|
||||||
|
def set_db_message_runtime_attr(db_message: DatabaseMessages, attr_name: str, value: Any) -> None:
|
||||||
|
"""安全地为 DatabaseMessages 设置运行时属性
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db_message: DatabaseMessages 对象
|
||||||
|
attr_name: 属性名
|
||||||
|
value: 属性值
|
||||||
|
"""
|
||||||
|
setattr(db_message, attr_name, value)
|
||||||
|
|
||||||
|
|
||||||
|
def get_db_message_runtime_attr(db_message: DatabaseMessages, attr_name: str, default: Any = None) -> Any:
|
||||||
|
"""安全地获取 DatabaseMessages 的运行时属性
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db_message: DatabaseMessages 对象
|
||||||
|
attr_name: 属性名
|
||||||
|
default: 默认值
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
属性值或默认值
|
||||||
|
"""
|
||||||
|
return getattr(db_message, attr_name, default)
|
||||||
434
src/chat/message_receive/message_recv_backup.py
Normal file
434
src/chat/message_receive/message_recv_backup.py
Normal file
@@ -0,0 +1,434 @@
|
|||||||
|
# MessageRecv 类备份 - 已从 message.py 中移除
|
||||||
|
# 备份日期: 2025-10-31
|
||||||
|
# 此类已被 DatabaseMessages 完全取代
|
||||||
|
|
||||||
|
# MessageRecv 类已被移除
|
||||||
|
# 现在所有消息处理都使用 DatabaseMessages
|
||||||
|
# 如果需要从消息字典创建 DatabaseMessages,请使用 message_processor.process_message_from_dict()
|
||||||
|
#
|
||||||
|
# 历史参考: MessageRecv 曾经是接收消息的包装类,现已被 DatabaseMessages 完全取代
|
||||||
|
# 迁移完成日期: 2025-10-31
|
||||||
|
|
||||||
|
"""
|
||||||
|
# 以下是已删除的 MessageRecv 类(保留作为参考)
|
||||||
|
class MessageRecv:
|
||||||
|
接收消息类 - DatabaseMessages 的轻量级包装器
|
||||||
|
|
||||||
|
这个类现在主要作为适配器层,处理外部消息格式并内部使用 DatabaseMessages。
|
||||||
|
保留此类是为了向后兼容性和处理 message_segment 的异步逻辑。
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, message_dict: dict[str, Any]):
|
||||||
|
"""从MessageCQ的字典初始化
|
||||||
|
|
||||||
|
Args:
|
||||||
|
message_dict: MessageCQ序列化后的字典
|
||||||
|
"""
|
||||||
|
# 保留原始消息信息用于某些场景
|
||||||
|
self.message_info = BaseMessageInfo.from_dict(message_dict.get("message_info", {}))
|
||||||
|
self.message_segment = Seg.from_dict(message_dict.get("message_segment", {}))
|
||||||
|
self.raw_message = message_dict.get("raw_message")
|
||||||
|
|
||||||
|
# 处理状态(在process()之前临时使用)
|
||||||
|
self._processing_state = {
|
||||||
|
"is_emoji": False,
|
||||||
|
"has_emoji": False,
|
||||||
|
"is_picid": False,
|
||||||
|
"has_picid": False,
|
||||||
|
"is_voice": False,
|
||||||
|
"is_video": False,
|
||||||
|
"is_mentioned": None,
|
||||||
|
"is_at": False,
|
||||||
|
"priority_mode": "interest",
|
||||||
|
"priority_info": None,
|
||||||
|
}
|
||||||
|
|
||||||
|
self.chat_stream = None
|
||||||
|
self.reply = None
|
||||||
|
self.processed_plain_text = message_dict.get("processed_plain_text", "")
|
||||||
|
|
||||||
|
# 解析additional_config中的notice信息
|
||||||
|
self.is_notify = False
|
||||||
|
self.is_public_notice = False
|
||||||
|
self.notice_type = None
|
||||||
|
if self.message_info.additional_config and isinstance(self.message_info.additional_config, dict):
|
||||||
|
self.is_notify = self.message_info.additional_config.get("is_notice", False)
|
||||||
|
self.is_public_notice = self.message_info.additional_config.get("is_public_notice", False)
|
||||||
|
self.notice_type = self.message_info.additional_config.get("notice_type")
|
||||||
|
|
||||||
|
# 兼容性属性 - 代理到 _processing_state
|
||||||
|
@property
|
||||||
|
def is_emoji(self) -> bool:
|
||||||
|
return self._processing_state["is_emoji"]
|
||||||
|
|
||||||
|
@is_emoji.setter
|
||||||
|
def is_emoji(self, value: bool):
|
||||||
|
self._processing_state["is_emoji"] = value
|
||||||
|
|
||||||
|
@property
|
||||||
|
def has_emoji(self) -> bool:
|
||||||
|
return self._processing_state["has_emoji"]
|
||||||
|
|
||||||
|
@has_emoji.setter
|
||||||
|
def has_emoji(self, value: bool):
|
||||||
|
self._processing_state["has_emoji"] = value
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_picid(self) -> bool:
|
||||||
|
return self._processing_state["is_picid"]
|
||||||
|
|
||||||
|
@is_picid.setter
|
||||||
|
def is_picid(self, value: bool):
|
||||||
|
self._processing_state["is_picid"] = value
|
||||||
|
|
||||||
|
@property
|
||||||
|
def has_picid(self) -> bool:
|
||||||
|
return self._processing_state["has_picid"]
|
||||||
|
|
||||||
|
@has_picid.setter
|
||||||
|
def has_picid(self, value: bool):
|
||||||
|
self._processing_state["has_picid"] = value
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_voice(self) -> bool:
|
||||||
|
return self._processing_state["is_voice"]
|
||||||
|
|
||||||
|
@is_voice.setter
|
||||||
|
def is_voice(self, value: bool):
|
||||||
|
self._processing_state["is_voice"] = value
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_video(self) -> bool:
|
||||||
|
return self._processing_state["is_video"]
|
||||||
|
|
||||||
|
@is_video.setter
|
||||||
|
def is_video(self, value: bool):
|
||||||
|
self._processing_state["is_video"] = value
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_mentioned(self):
|
||||||
|
return self._processing_state["is_mentioned"]
|
||||||
|
|
||||||
|
@is_mentioned.setter
|
||||||
|
def is_mentioned(self, value):
|
||||||
|
self._processing_state["is_mentioned"] = value
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_at(self) -> bool:
|
||||||
|
return self._processing_state["is_at"]
|
||||||
|
|
||||||
|
@is_at.setter
|
||||||
|
def is_at(self, value: bool):
|
||||||
|
self._processing_state["is_at"] = value
|
||||||
|
|
||||||
|
@property
|
||||||
|
def priority_mode(self) -> str:
|
||||||
|
return self._processing_state["priority_mode"]
|
||||||
|
|
||||||
|
@priority_mode.setter
|
||||||
|
def priority_mode(self, value: str):
|
||||||
|
self._processing_state["priority_mode"] = value
|
||||||
|
|
||||||
|
@property
|
||||||
|
def priority_info(self):
|
||||||
|
return self._processing_state["priority_info"]
|
||||||
|
|
||||||
|
@priority_info.setter
|
||||||
|
def priority_info(self, value):
|
||||||
|
self._processing_state["priority_info"] = value
|
||||||
|
|
||||||
|
# 其他常用属性
|
||||||
|
interest_value: float = 0.0
|
||||||
|
is_command: bool = False
|
||||||
|
memorized_times: int = 0
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
"""dataclass 初始化后处理"""
|
||||||
|
self.key_words = []
|
||||||
|
self.key_words_lite = []
|
||||||
|
|
||||||
|
def update_chat_stream(self, chat_stream: "ChatStream"):
|
||||||
|
self.chat_stream = chat_stream
|
||||||
|
|
||||||
|
def to_database_message(self) -> "DatabaseMessages":
|
||||||
|
"""将 MessageRecv 转换为 DatabaseMessages 对象
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
DatabaseMessages: 数据库消息对象
|
||||||
|
"""
|
||||||
|
import time
|
||||||
|
|
||||||
|
message_info = self.message_info
|
||||||
|
msg_user_info = getattr(message_info, "user_info", None)
|
||||||
|
stream_user_info = getattr(self.chat_stream, "user_info", None) if self.chat_stream else None
|
||||||
|
group_info = getattr(self.chat_stream, "group_info", None) if self.chat_stream else None
|
||||||
|
|
||||||
|
message_id = message_info.message_id or ""
|
||||||
|
message_time = message_info.time if hasattr(message_info, "time") and message_info.time is not None else time.time()
|
||||||
|
is_mentioned = None
|
||||||
|
if isinstance(self.is_mentioned, bool):
|
||||||
|
is_mentioned = self.is_mentioned
|
||||||
|
elif isinstance(self.is_mentioned, int | float):
|
||||||
|
is_mentioned = self.is_mentioned != 0
|
||||||
|
|
||||||
|
# 提取用户信息
|
||||||
|
user_id = ""
|
||||||
|
user_nickname = ""
|
||||||
|
user_cardname = None
|
||||||
|
user_platform = ""
|
||||||
|
if msg_user_info:
|
||||||
|
user_id = str(getattr(msg_user_info, "user_id", "") or "")
|
||||||
|
user_nickname = getattr(msg_user_info, "user_nickname", "") or ""
|
||||||
|
user_cardname = getattr(msg_user_info, "user_cardname", None)
|
||||||
|
user_platform = getattr(msg_user_info, "platform", "") or ""
|
||||||
|
elif stream_user_info:
|
||||||
|
user_id = str(getattr(stream_user_info, "user_id", "") or "")
|
||||||
|
user_nickname = getattr(stream_user_info, "user_nickname", "") or ""
|
||||||
|
user_cardname = getattr(stream_user_info, "user_cardname", None)
|
||||||
|
user_platform = getattr(stream_user_info, "platform", "") or ""
|
||||||
|
|
||||||
|
# 提取聊天流信息
|
||||||
|
chat_user_id = str(getattr(stream_user_info, "user_id", "") or "") if stream_user_info else ""
|
||||||
|
chat_user_nickname = getattr(stream_user_info, "user_nickname", "") or "" if stream_user_info else ""
|
||||||
|
chat_user_cardname = getattr(stream_user_info, "user_cardname", None) if stream_user_info else None
|
||||||
|
chat_user_platform = getattr(stream_user_info, "platform", "") or "" if stream_user_info else ""
|
||||||
|
|
||||||
|
group_id = getattr(group_info, "group_id", None) if group_info else None
|
||||||
|
group_name = getattr(group_info, "group_name", None) if group_info else None
|
||||||
|
group_platform = getattr(group_info, "platform", None) if group_info else None
|
||||||
|
|
||||||
|
# 准备 additional_config
|
||||||
|
additional_config_str = None
|
||||||
|
try:
|
||||||
|
import orjson
|
||||||
|
|
||||||
|
additional_config_data = {}
|
||||||
|
|
||||||
|
# 首先获取adapter传递的additional_config
|
||||||
|
if hasattr(message_info, 'additional_config') and message_info.additional_config:
|
||||||
|
if isinstance(message_info.additional_config, dict):
|
||||||
|
additional_config_data = message_info.additional_config.copy()
|
||||||
|
elif isinstance(message_info.additional_config, str):
|
||||||
|
try:
|
||||||
|
additional_config_data = orjson.loads(message_info.additional_config)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"无法解析 additional_config JSON: {e}")
|
||||||
|
additional_config_data = {}
|
||||||
|
|
||||||
|
# 添加notice相关标志
|
||||||
|
if self.is_notify:
|
||||||
|
additional_config_data["is_notice"] = True
|
||||||
|
additional_config_data["notice_type"] = self.notice_type or "unknown"
|
||||||
|
additional_config_data["is_public_notice"] = bool(self.is_public_notice)
|
||||||
|
|
||||||
|
# 添加format_info到additional_config中
|
||||||
|
if hasattr(message_info, 'format_info') and message_info.format_info:
|
||||||
|
try:
|
||||||
|
format_info_dict = message_info.format_info.to_dict()
|
||||||
|
additional_config_data["format_info"] = format_info_dict
|
||||||
|
logger.debug(f"[message.py] 嵌入 format_info 到 additional_config: {format_info_dict}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"将 format_info 转换为字典失败: {e}")
|
||||||
|
|
||||||
|
# 序列化为JSON字符串
|
||||||
|
if additional_config_data:
|
||||||
|
additional_config_str = orjson.dumps(additional_config_data).decode("utf-8")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"准备 additional_config 失败: {e}")
|
||||||
|
|
||||||
|
# 创建数据库消息对象
|
||||||
|
db_message = DatabaseMessages(
|
||||||
|
message_id=message_id,
|
||||||
|
time=float(message_time),
|
||||||
|
chat_id=self.chat_stream.stream_id if self.chat_stream else "",
|
||||||
|
processed_plain_text=self.processed_plain_text,
|
||||||
|
display_message=self.processed_plain_text,
|
||||||
|
is_mentioned=is_mentioned,
|
||||||
|
is_at=bool(self.is_at) if self.is_at is not None else None,
|
||||||
|
is_emoji=bool(self.is_emoji),
|
||||||
|
is_picid=bool(self.is_picid),
|
||||||
|
is_command=bool(self.is_command),
|
||||||
|
is_notify=bool(self.is_notify),
|
||||||
|
is_public_notice=bool(self.is_public_notice),
|
||||||
|
notice_type=self.notice_type,
|
||||||
|
additional_config=additional_config_str,
|
||||||
|
user_id=user_id,
|
||||||
|
user_nickname=user_nickname,
|
||||||
|
user_cardname=user_cardname,
|
||||||
|
user_platform=user_platform,
|
||||||
|
chat_info_stream_id=self.chat_stream.stream_id if self.chat_stream else "",
|
||||||
|
chat_info_platform=self.chat_stream.platform if self.chat_stream else "",
|
||||||
|
chat_info_create_time=float(self.chat_stream.create_time) if self.chat_stream else 0.0,
|
||||||
|
chat_info_last_active_time=float(self.chat_stream.last_active_time) if self.chat_stream else 0.0,
|
||||||
|
chat_info_user_id=chat_user_id,
|
||||||
|
chat_info_user_nickname=chat_user_nickname,
|
||||||
|
chat_info_user_cardname=chat_user_cardname,
|
||||||
|
chat_info_user_platform=chat_user_platform,
|
||||||
|
chat_info_group_id=group_id,
|
||||||
|
chat_info_group_name=group_name,
|
||||||
|
chat_info_group_platform=group_platform,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 同步兴趣度等衍生属性
|
||||||
|
db_message.interest_value = getattr(self, "interest_value", 0.0)
|
||||||
|
setattr(db_message, "should_reply", getattr(self, "should_reply", False))
|
||||||
|
setattr(db_message, "should_act", getattr(self, "should_act", False))
|
||||||
|
|
||||||
|
return db_message
|
||||||
|
|
||||||
|
async def process(self) -> None:
|
||||||
|
"""处理消息内容,生成纯文本和详细文本
|
||||||
|
|
||||||
|
这个方法必须在创建实例后显式调用,因为它包含异步操作。
|
||||||
|
"""
|
||||||
|
self.processed_plain_text = await self._process_message_segments(self.message_segment)
|
||||||
|
|
||||||
|
async def _process_single_segment(self, segment: Seg) -> str:
|
||||||
|
"""处理单个消息段
|
||||||
|
|
||||||
|
Args:
|
||||||
|
segment: 消息段
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: 处理后的文本
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
if segment.type == "text":
|
||||||
|
self.is_picid = False
|
||||||
|
self.is_emoji = False
|
||||||
|
self.is_video = False
|
||||||
|
return segment.data # type: ignore
|
||||||
|
elif segment.type == "at":
|
||||||
|
self.is_picid = False
|
||||||
|
self.is_emoji = False
|
||||||
|
self.is_video = False
|
||||||
|
# 处理at消息,格式为"昵称:QQ号"
|
||||||
|
if isinstance(segment.data, str) and ":" in segment.data:
|
||||||
|
nickname, qq_id = segment.data.split(":", 1)
|
||||||
|
return f"@{nickname}"
|
||||||
|
return f"@{segment.data}" if isinstance(segment.data, str) else "@未知用户"
|
||||||
|
elif segment.type == "image":
|
||||||
|
# 如果是base64图片数据
|
||||||
|
if isinstance(segment.data, str):
|
||||||
|
self.has_picid = True
|
||||||
|
self.is_picid = True
|
||||||
|
self.is_emoji = False
|
||||||
|
self.is_video = False
|
||||||
|
image_manager = get_image_manager()
|
||||||
|
# print(f"segment.data: {segment.data}")
|
||||||
|
_, processed_text = await image_manager.process_image(segment.data)
|
||||||
|
return processed_text
|
||||||
|
return "[发了一张图片,网卡了加载不出来]"
|
||||||
|
elif segment.type == "emoji":
|
||||||
|
self.has_emoji = True
|
||||||
|
self.is_emoji = True
|
||||||
|
self.is_picid = False
|
||||||
|
self.is_voice = False
|
||||||
|
self.is_video = False
|
||||||
|
if isinstance(segment.data, str):
|
||||||
|
return await get_image_manager().get_emoji_description(segment.data)
|
||||||
|
return "[发了一个表情包,网卡了加载不出来]"
|
||||||
|
elif segment.type == "voice":
|
||||||
|
self.is_picid = False
|
||||||
|
self.is_emoji = False
|
||||||
|
self.is_voice = True
|
||||||
|
self.is_video = False
|
||||||
|
|
||||||
|
# 检查消息是否由机器人自己发送
|
||||||
|
if self.message_info and self.message_info.user_info and str(self.message_info.user_info.user_id) == str(global_config.bot.qq_account):
|
||||||
|
logger.info(f"检测到机器人自身发送的语音消息 (User ID: {self.message_info.user_info.user_id}),尝试从缓存获取文本。")
|
||||||
|
if isinstance(segment.data, str):
|
||||||
|
cached_text = consume_self_voice_text(segment.data)
|
||||||
|
if cached_text:
|
||||||
|
logger.info(f"成功从缓存中获取语音文本: '{cached_text[:70]}...'")
|
||||||
|
return f"[语音:{cached_text}]"
|
||||||
|
else:
|
||||||
|
logger.warning("机器人自身语音消息缓存未命中,将回退到标准语音识别。")
|
||||||
|
|
||||||
|
# 标准语音识别流程 (也作为缓存未命中的后备方案)
|
||||||
|
if isinstance(segment.data, str):
|
||||||
|
return await get_voice_text(segment.data)
|
||||||
|
return "[发了一段语音,网卡了加载不出来]"
|
||||||
|
elif segment.type == "mention_bot":
|
||||||
|
self.is_picid = False
|
||||||
|
self.is_emoji = False
|
||||||
|
self.is_voice = False
|
||||||
|
self.is_video = False
|
||||||
|
self.is_mentioned = float(segment.data) # type: ignore
|
||||||
|
return ""
|
||||||
|
elif segment.type == "priority_info":
|
||||||
|
self.is_picid = False
|
||||||
|
self.is_emoji = False
|
||||||
|
self.is_voice = False
|
||||||
|
if isinstance(segment.data, dict):
|
||||||
|
# 处理优先级信息
|
||||||
|
self.priority_mode = "priority"
|
||||||
|
self.priority_info = segment.data
|
||||||
|
"""
|
||||||
|
{
|
||||||
|
'message_type': 'vip', # vip or normal
|
||||||
|
'message_priority': 1.0, # 优先级,大为优先,float
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
return ""
|
||||||
|
elif segment.type == "file":
|
||||||
|
if isinstance(segment.data, dict):
|
||||||
|
file_name = segment.data.get('name', '未知文件')
|
||||||
|
file_size = segment.data.get('size', '未知大小')
|
||||||
|
return f"[文件:{file_name} ({file_size}字节)]"
|
||||||
|
return "[收到一个文件]"
|
||||||
|
elif segment.type == "video":
|
||||||
|
self.is_picid = False
|
||||||
|
self.is_emoji = False
|
||||||
|
self.is_voice = False
|
||||||
|
self.is_video = True
|
||||||
|
logger.info(f"接收到视频消息,数据类型: {type(segment.data)}")
|
||||||
|
|
||||||
|
# 检查视频分析功能是否可用
|
||||||
|
if not is_video_analysis_available():
|
||||||
|
logger.warning("⚠️ Rust视频处理模块不可用,跳过视频分析")
|
||||||
|
return "[视频]"
|
||||||
|
|
||||||
|
if global_config.video_analysis.enable:
|
||||||
|
logger.info("已启用视频识别,开始识别")
|
||||||
|
if isinstance(segment.data, dict):
|
||||||
|
try:
|
||||||
|
# 从Adapter接收的视频数据
|
||||||
|
video_base64 = segment.data.get("base64")
|
||||||
|
filename = segment.data.get("filename", "video.mp4")
|
||||||
|
|
||||||
|
logger.info(f"视频文件名: {filename}")
|
||||||
|
logger.info(f"Base64数据长度: {len(video_base64) if video_base64 else 0}")
|
||||||
|
|
||||||
|
if video_base64:
|
||||||
|
# 解码base64视频数据
|
||||||
|
video_bytes = base64.b64decode(video_base64)
|
||||||
|
logger.info(f"解码后视频大小: {len(video_bytes)} 字节")
|
||||||
|
|
||||||
|
# 使用video analyzer分析视频
|
||||||
|
video_analyzer = get_video_analyzer()
|
||||||
|
result = await video_analyzer.analyze_video_from_bytes(
|
||||||
|
video_bytes, filename, prompt=global_config.video_analysis.batch_analysis_prompt
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(f"视频分析结果: {result}")
|
||||||
|
|
||||||
|
# 返回视频分析结果
|
||||||
|
summary = result.get("summary", "")
|
||||||
|
if summary:
|
||||||
|
return f"[视频内容] {summary}"
|
||||||
|
else:
|
||||||
|
return "[已收到视频,但分析失败]"
|
||||||
|
else:
|
||||||
|
logger.warning("视频消息中没有base64数据")
|
||||||
|
return "[收到视频消息,但数据异常]"
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"视频处理失败: {e!s}")
|
||||||
|
import traceback
|
||||||
|
|
||||||
|
logger.error(f"错误详情: {traceback.format_exc()}")
|
||||||
|
return "[收到视频,但处理时出现错误]"
|
||||||
|
else:
|
||||||
|
logger.warning(f"视频消息数据不是字典格式: {type(segment.data)}")
|
||||||
|
return "[发了一个视频,但格式不支持]"
|
||||||
|
else:
|
||||||
@@ -9,8 +9,10 @@ from src.common.database.sqlalchemy_database_api import get_db_session
|
|||||||
from src.common.database.sqlalchemy_models import Images, Messages
|
from src.common.database.sqlalchemy_models import Images, Messages
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
|
|
||||||
|
from src.common.data_models.database_data_model import DatabaseMessages
|
||||||
|
|
||||||
from .chat_stream import ChatStream
|
from .chat_stream import ChatStream
|
||||||
from .message import MessageRecv, MessageSending
|
from .message import MessageSending
|
||||||
|
|
||||||
logger = get_logger("message_storage")
|
logger = get_logger("message_storage")
|
||||||
|
|
||||||
@@ -34,97 +36,166 @@ class MessageStorage:
|
|||||||
return []
|
return []
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def store_message(message: MessageSending | MessageRecv, chat_stream: ChatStream) -> None:
|
async def store_message(message: DatabaseMessages | MessageSending, chat_stream: ChatStream) -> None:
|
||||||
"""存储消息到数据库"""
|
"""存储消息到数据库"""
|
||||||
try:
|
try:
|
||||||
# 过滤敏感信息的正则模式
|
# 过滤敏感信息的正则模式
|
||||||
pattern = r"<MainRule>.*?</MainRule>|<schedule>.*?</schedule>|<UserMessage>.*?</UserMessage>"
|
pattern = r"<MainRule>.*?</MainRule>|<schedule>.*?</schedule>|<UserMessage>.*?</UserMessage>"
|
||||||
|
|
||||||
processed_plain_text = message.processed_plain_text
|
# 如果是 DatabaseMessages,直接使用它的字段
|
||||||
|
if isinstance(message, DatabaseMessages):
|
||||||
if processed_plain_text:
|
processed_plain_text = message.processed_plain_text
|
||||||
processed_plain_text = await MessageStorage.replace_image_descriptions(processed_plain_text)
|
if processed_plain_text:
|
||||||
# 增加对None的防御性处理
|
processed_plain_text = await MessageStorage.replace_image_descriptions(processed_plain_text)
|
||||||
safe_processed_plain_text = processed_plain_text or ""
|
safe_processed_plain_text = processed_plain_text or ""
|
||||||
filtered_processed_plain_text = re.sub(pattern, "", safe_processed_plain_text, flags=re.DOTALL)
|
filtered_processed_plain_text = re.sub(pattern, "", safe_processed_plain_text, flags=re.DOTALL)
|
||||||
else:
|
|
||||||
filtered_processed_plain_text = ""
|
|
||||||
|
|
||||||
if isinstance(message, MessageSending):
|
|
||||||
display_message = message.display_message
|
|
||||||
if display_message:
|
|
||||||
filtered_display_message = re.sub(pattern, "", display_message, flags=re.DOTALL)
|
|
||||||
else:
|
else:
|
||||||
# 如果没有设置display_message,使用processed_plain_text作为显示消息
|
filtered_processed_plain_text = ""
|
||||||
filtered_display_message = (
|
|
||||||
re.sub(pattern, "", (message.processed_plain_text or ""), flags=re.DOTALL)
|
display_message = message.display_message or message.processed_plain_text or ""
|
||||||
)
|
filtered_display_message = re.sub(pattern, "", display_message, flags=re.DOTALL)
|
||||||
interest_value = 0
|
|
||||||
is_mentioned = False
|
# 直接从 DatabaseMessages 获取所有字段
|
||||||
reply_to = message.reply_to
|
msg_id = message.message_id
|
||||||
priority_mode = ""
|
msg_time = message.time
|
||||||
priority_info = {}
|
chat_id = message.chat_id
|
||||||
is_emoji = False
|
reply_to = "" # DatabaseMessages 没有 reply_to 字段
|
||||||
is_picid = False
|
|
||||||
is_notify = False
|
|
||||||
is_command = False
|
|
||||||
key_words = ""
|
|
||||||
key_words_lite = ""
|
|
||||||
else:
|
|
||||||
filtered_display_message = ""
|
|
||||||
interest_value = message.interest_value
|
|
||||||
is_mentioned = message.is_mentioned
|
is_mentioned = message.is_mentioned
|
||||||
reply_to = ""
|
interest_value = message.interest_value or 0.0
|
||||||
priority_mode = message.priority_mode
|
priority_mode = "" # DatabaseMessages 没有 priority_mode
|
||||||
priority_info = message.priority_info
|
priority_info_json = None # DatabaseMessages 没有 priority_info
|
||||||
is_emoji = message.is_emoji
|
is_emoji = message.is_emoji or False
|
||||||
is_picid = message.is_picid
|
is_picid = message.is_picid or False
|
||||||
is_notify = message.is_notify
|
is_notify = message.is_notify or False
|
||||||
is_command = message.is_command
|
is_command = message.is_command or False
|
||||||
# 序列化关键词列表为JSON字符串
|
key_words = "" # DatabaseMessages 没有 key_words
|
||||||
key_words = MessageStorage._serialize_keywords(message.key_words)
|
key_words_lite = ""
|
||||||
key_words_lite = MessageStorage._serialize_keywords(message.key_words_lite)
|
memorized_times = 0 # DatabaseMessages 没有 memorized_times
|
||||||
|
|
||||||
|
# 使用 DatabaseMessages 中的嵌套对象信息
|
||||||
|
user_platform = message.user_info.platform if message.user_info else ""
|
||||||
|
user_id = message.user_info.user_id if message.user_info else ""
|
||||||
|
user_nickname = message.user_info.user_nickname if message.user_info else ""
|
||||||
|
user_cardname = message.user_info.user_cardname if message.user_info else None
|
||||||
|
|
||||||
|
chat_info_stream_id = message.chat_info.stream_id if message.chat_info else ""
|
||||||
|
chat_info_platform = message.chat_info.platform if message.chat_info else ""
|
||||||
|
chat_info_create_time = message.chat_info.create_time if message.chat_info else 0.0
|
||||||
|
chat_info_last_active_time = message.chat_info.last_active_time if message.chat_info else 0.0
|
||||||
|
chat_info_user_platform = message.chat_info.user_info.platform if message.chat_info and message.chat_info.user_info else ""
|
||||||
|
chat_info_user_id = message.chat_info.user_info.user_id if message.chat_info and message.chat_info.user_info else ""
|
||||||
|
chat_info_user_nickname = message.chat_info.user_info.user_nickname if message.chat_info and message.chat_info.user_info else ""
|
||||||
|
chat_info_user_cardname = message.chat_info.user_info.user_cardname if message.chat_info and message.chat_info.user_info else None
|
||||||
|
chat_info_group_platform = message.group_info.group_platform if message.group_info else None
|
||||||
|
chat_info_group_id = message.group_info.group_id if message.group_info else None
|
||||||
|
chat_info_group_name = message.group_info.group_name if message.group_info else None
|
||||||
|
|
||||||
|
else:
|
||||||
|
# MessageSending 处理逻辑
|
||||||
|
processed_plain_text = message.processed_plain_text
|
||||||
|
|
||||||
chat_info_dict = chat_stream.to_dict()
|
if processed_plain_text:
|
||||||
user_info_dict = message.message_info.user_info.to_dict() # type: ignore
|
processed_plain_text = await MessageStorage.replace_image_descriptions(processed_plain_text)
|
||||||
|
# 增加对None的防御性处理
|
||||||
|
safe_processed_plain_text = processed_plain_text or ""
|
||||||
|
filtered_processed_plain_text = re.sub(pattern, "", safe_processed_plain_text, flags=re.DOTALL)
|
||||||
|
else:
|
||||||
|
filtered_processed_plain_text = ""
|
||||||
|
|
||||||
# message_id 现在是 TextField,直接使用字符串值
|
if isinstance(message, MessageSending):
|
||||||
msg_id = message.message_info.message_id
|
display_message = message.display_message
|
||||||
|
if display_message:
|
||||||
|
filtered_display_message = re.sub(pattern, "", display_message, flags=re.DOTALL)
|
||||||
|
else:
|
||||||
|
# 如果没有设置display_message,使用processed_plain_text作为显示消息
|
||||||
|
filtered_display_message = (
|
||||||
|
re.sub(pattern, "", (message.processed_plain_text or ""), flags=re.DOTALL)
|
||||||
|
)
|
||||||
|
interest_value = 0
|
||||||
|
is_mentioned = False
|
||||||
|
reply_to = message.reply_to
|
||||||
|
priority_mode = ""
|
||||||
|
priority_info = {}
|
||||||
|
is_emoji = False
|
||||||
|
is_picid = False
|
||||||
|
is_notify = False
|
||||||
|
is_command = False
|
||||||
|
key_words = ""
|
||||||
|
key_words_lite = ""
|
||||||
|
else:
|
||||||
|
filtered_display_message = ""
|
||||||
|
interest_value = message.interest_value
|
||||||
|
is_mentioned = message.is_mentioned
|
||||||
|
reply_to = ""
|
||||||
|
priority_mode = message.priority_mode
|
||||||
|
priority_info = message.priority_info
|
||||||
|
is_emoji = message.is_emoji
|
||||||
|
is_picid = message.is_picid
|
||||||
|
is_notify = message.is_notify
|
||||||
|
is_command = message.is_command
|
||||||
|
# 序列化关键词列表为JSON字符串
|
||||||
|
key_words = MessageStorage._serialize_keywords(message.key_words)
|
||||||
|
key_words_lite = MessageStorage._serialize_keywords(message.key_words_lite)
|
||||||
|
|
||||||
# 安全地获取 group_info, 如果为 None 则视为空字典
|
chat_info_dict = chat_stream.to_dict()
|
||||||
group_info_from_chat = chat_info_dict.get("group_info") or {}
|
user_info_dict = message.message_info.user_info.to_dict() # type: ignore
|
||||||
# 安全地获取 user_info, 如果为 None 则视为空字典 (以防万一)
|
|
||||||
user_info_from_chat = chat_info_dict.get("user_info") or {}
|
|
||||||
|
|
||||||
# 将priority_info字典序列化为JSON字符串,以便存储到数据库的Text字段
|
# message_id 现在是 TextField,直接使用字符串值
|
||||||
priority_info_json = orjson.dumps(priority_info).decode("utf-8") if priority_info else None
|
msg_id = message.message_info.message_id
|
||||||
|
msg_time = float(message.message_info.time or time.time())
|
||||||
|
chat_id = chat_stream.stream_id
|
||||||
|
memorized_times = message.memorized_times
|
||||||
|
|
||||||
|
# 安全地获取 group_info, 如果为 None 则视为空字典
|
||||||
|
group_info_from_chat = chat_info_dict.get("group_info") or {}
|
||||||
|
# 安全地获取 user_info, 如果为 None 则视为空字典 (以防万一)
|
||||||
|
user_info_from_chat = chat_info_dict.get("user_info") or {}
|
||||||
|
|
||||||
|
# 将priority_info字典序列化为JSON字符串,以便存储到数据库的Text字段
|
||||||
|
priority_info_json = orjson.dumps(priority_info).decode("utf-8") if priority_info else None
|
||||||
|
|
||||||
|
user_platform = user_info_dict.get("platform")
|
||||||
|
user_id = user_info_dict.get("user_id")
|
||||||
|
user_nickname = user_info_dict.get("user_nickname")
|
||||||
|
user_cardname = user_info_dict.get("user_cardname")
|
||||||
|
|
||||||
|
chat_info_stream_id = chat_info_dict.get("stream_id")
|
||||||
|
chat_info_platform = chat_info_dict.get("platform")
|
||||||
|
chat_info_create_time = float(chat_info_dict.get("create_time", 0.0))
|
||||||
|
chat_info_last_active_time = float(chat_info_dict.get("last_active_time", 0.0))
|
||||||
|
chat_info_user_platform = user_info_from_chat.get("platform")
|
||||||
|
chat_info_user_id = user_info_from_chat.get("user_id")
|
||||||
|
chat_info_user_nickname = user_info_from_chat.get("user_nickname")
|
||||||
|
chat_info_user_cardname = user_info_from_chat.get("user_cardname")
|
||||||
|
chat_info_group_platform = group_info_from_chat.get("platform")
|
||||||
|
chat_info_group_id = group_info_from_chat.get("group_id")
|
||||||
|
chat_info_group_name = group_info_from_chat.get("group_name")
|
||||||
|
|
||||||
# 获取数据库会话
|
# 获取数据库会话
|
||||||
|
|
||||||
new_message = Messages(
|
new_message = Messages(
|
||||||
message_id=msg_id,
|
message_id=msg_id,
|
||||||
time=float(message.message_info.time or time.time()),
|
time=msg_time,
|
||||||
chat_id=chat_stream.stream_id,
|
chat_id=chat_id,
|
||||||
reply_to=reply_to,
|
reply_to=reply_to,
|
||||||
is_mentioned=is_mentioned,
|
is_mentioned=is_mentioned,
|
||||||
chat_info_stream_id=chat_info_dict.get("stream_id"),
|
chat_info_stream_id=chat_info_stream_id,
|
||||||
chat_info_platform=chat_info_dict.get("platform"),
|
chat_info_platform=chat_info_platform,
|
||||||
chat_info_user_platform=user_info_from_chat.get("platform"),
|
chat_info_user_platform=chat_info_user_platform,
|
||||||
chat_info_user_id=user_info_from_chat.get("user_id"),
|
chat_info_user_id=chat_info_user_id,
|
||||||
chat_info_user_nickname=user_info_from_chat.get("user_nickname"),
|
chat_info_user_nickname=chat_info_user_nickname,
|
||||||
chat_info_user_cardname=user_info_from_chat.get("user_cardname"),
|
chat_info_user_cardname=chat_info_user_cardname,
|
||||||
chat_info_group_platform=group_info_from_chat.get("platform"),
|
chat_info_group_platform=chat_info_group_platform,
|
||||||
chat_info_group_id=group_info_from_chat.get("group_id"),
|
chat_info_group_id=chat_info_group_id,
|
||||||
chat_info_group_name=group_info_from_chat.get("group_name"),
|
chat_info_group_name=chat_info_group_name,
|
||||||
chat_info_create_time=float(chat_info_dict.get("create_time", 0.0)),
|
chat_info_create_time=chat_info_create_time,
|
||||||
chat_info_last_active_time=float(chat_info_dict.get("last_active_time", 0.0)),
|
chat_info_last_active_time=chat_info_last_active_time,
|
||||||
user_platform=user_info_dict.get("platform"),
|
user_platform=user_platform,
|
||||||
user_id=user_info_dict.get("user_id"),
|
user_id=user_id,
|
||||||
user_nickname=user_info_dict.get("user_nickname"),
|
user_nickname=user_nickname,
|
||||||
user_cardname=user_info_dict.get("user_cardname"),
|
user_cardname=user_cardname,
|
||||||
processed_plain_text=filtered_processed_plain_text,
|
processed_plain_text=filtered_processed_plain_text,
|
||||||
display_message=filtered_display_message,
|
display_message=filtered_display_message,
|
||||||
memorized_times=message.memorized_times,
|
memorized_times=memorized_times,
|
||||||
interest_value=interest_value,
|
interest_value=interest_value,
|
||||||
priority_mode=priority_mode,
|
priority_mode=priority_mode,
|
||||||
priority_info=priority_info_json,
|
priority_info=priority_info_json,
|
||||||
@@ -145,36 +216,43 @@ class MessageStorage:
|
|||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def update_message(message):
|
async def update_message(message_data: dict):
|
||||||
"""更新消息ID"""
|
"""更新消息ID(从消息字典)"""
|
||||||
try:
|
try:
|
||||||
mmc_message_id = message.message_info.message_id
|
# 从字典中提取信息
|
||||||
|
message_info = message_data.get("message_info", {})
|
||||||
|
mmc_message_id = message_info.get("message_id")
|
||||||
|
|
||||||
|
message_segment = message_data.get("message_segment", {})
|
||||||
|
segment_type = message_segment.get("type") if isinstance(message_segment, dict) else None
|
||||||
|
segment_data = message_segment.get("data", {}) if isinstance(message_segment, dict) else {}
|
||||||
|
|
||||||
qq_message_id = None
|
qq_message_id = None
|
||||||
|
|
||||||
logger.debug(f"尝试更新消息ID: {mmc_message_id}, 消息段类型: {message.message_segment.type}")
|
logger.debug(f"尝试更新消息ID: {mmc_message_id}, 消息段类型: {segment_type}")
|
||||||
|
|
||||||
# 根据消息段类型提取message_id
|
# 根据消息段类型提取message_id
|
||||||
if message.message_segment.type == "notify":
|
if segment_type == "notify":
|
||||||
qq_message_id = message.message_segment.data.get("id")
|
qq_message_id = segment_data.get("id")
|
||||||
elif message.message_segment.type == "text":
|
elif segment_type == "text":
|
||||||
qq_message_id = message.message_segment.data.get("id")
|
qq_message_id = segment_data.get("id")
|
||||||
elif message.message_segment.type == "reply":
|
elif segment_type == "reply":
|
||||||
qq_message_id = message.message_segment.data.get("id")
|
qq_message_id = segment_data.get("id")
|
||||||
if qq_message_id:
|
if qq_message_id:
|
||||||
logger.debug(f"从reply消息段获取到消息ID: {qq_message_id}")
|
logger.debug(f"从reply消息段获取到消息ID: {qq_message_id}")
|
||||||
elif message.message_segment.type == "adapter_response":
|
elif segment_type == "adapter_response":
|
||||||
logger.debug("适配器响应消息,不需要更新ID")
|
logger.debug("适配器响应消息,不需要更新ID")
|
||||||
return
|
return
|
||||||
elif message.message_segment.type == "adapter_command":
|
elif segment_type == "adapter_command":
|
||||||
logger.debug("适配器命令消息,不需要更新ID")
|
logger.debug("适配器命令消息,不需要更新ID")
|
||||||
return
|
return
|
||||||
else:
|
else:
|
||||||
logger.debug(f"未知的消息段类型: {message.message_segment.type},跳过ID更新")
|
logger.debug(f"未知的消息段类型: {segment_type},跳过ID更新")
|
||||||
return
|
return
|
||||||
|
|
||||||
if not qq_message_id:
|
if not qq_message_id:
|
||||||
logger.debug(f"消息段类型 {message.message_segment.type} 中未找到有效的message_id,跳过更新")
|
logger.debug(f"消息段类型 {segment_type} 中未找到有效的message_id,跳过更新")
|
||||||
logger.debug(f"消息段数据: {message.message_segment.data}")
|
logger.debug(f"消息段数据: {segment_data}")
|
||||||
return
|
return
|
||||||
|
|
||||||
# 使用上下文管理器确保session正确管理
|
# 使用上下文管理器确保session正确管理
|
||||||
|
|||||||
@@ -137,7 +137,7 @@ class ActionModifier:
|
|||||||
logger.debug(f"{self.log_prefix}阶段一移除动作: {disabled_action_name},原因: 用户自行禁用")
|
logger.debug(f"{self.log_prefix}阶段一移除动作: {disabled_action_name},原因: 用户自行禁用")
|
||||||
|
|
||||||
# === 第二阶段:检查动作的关联类型 ===
|
# === 第二阶段:检查动作的关联类型 ===
|
||||||
chat_context = self.chat_stream.stream_context
|
chat_context = self.chat_stream.context_manager.context
|
||||||
current_actions_s2 = self.action_manager.get_using_actions()
|
current_actions_s2 = self.action_manager.get_using_actions()
|
||||||
type_mismatched_actions = self._check_action_associated_types(current_actions_s2, chat_context)
|
type_mismatched_actions = self._check_action_associated_types(current_actions_s2, chat_context)
|
||||||
|
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ from typing import Any
|
|||||||
|
|
||||||
from src.chat.express.expression_selector import expression_selector
|
from src.chat.express.expression_selector import expression_selector
|
||||||
from src.chat.message_receive.chat_stream import ChatStream
|
from src.chat.message_receive.chat_stream import ChatStream
|
||||||
from src.chat.message_receive.message import MessageRecv, MessageSending, Seg, UserInfo
|
from src.chat.message_receive.message import MessageSending, Seg, UserInfo
|
||||||
from src.chat.message_receive.uni_message_sender import HeartFCSender
|
from src.chat.message_receive.uni_message_sender import HeartFCSender
|
||||||
from src.chat.utils.chat_message_builder import (
|
from src.chat.utils.chat_message_builder import (
|
||||||
build_readable_messages,
|
build_readable_messages,
|
||||||
@@ -1733,7 +1733,7 @@ class DefaultReplyer:
|
|||||||
is_emoji: bool,
|
is_emoji: bool,
|
||||||
thinking_start_time: float,
|
thinking_start_time: float,
|
||||||
display_message: str,
|
display_message: str,
|
||||||
anchor_message: MessageRecv | None = None,
|
anchor_message: DatabaseMessages | None = None,
|
||||||
) -> MessageSending:
|
) -> MessageSending:
|
||||||
"""构建单个发送消息"""
|
"""构建单个发送消息"""
|
||||||
|
|
||||||
@@ -1743,8 +1743,11 @@ class DefaultReplyer:
|
|||||||
platform=self.chat_stream.platform,
|
platform=self.chat_stream.platform,
|
||||||
)
|
)
|
||||||
|
|
||||||
# await anchor_message.process()
|
# 从 DatabaseMessages 获取 sender_info
|
||||||
sender_info = anchor_message.message_info.user_info if anchor_message else None
|
if anchor_message:
|
||||||
|
sender_info = anchor_message.user_info
|
||||||
|
else:
|
||||||
|
sender_info = None
|
||||||
|
|
||||||
return MessageSending(
|
return MessageSending(
|
||||||
message_id=message_id, # 使用片段的唯一ID
|
message_id=message_id, # 使用片段的唯一ID
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ import rjieba
|
|||||||
from maim_message import UserInfo
|
from maim_message import UserInfo
|
||||||
|
|
||||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||||
from src.chat.message_receive.message import MessageRecv
|
# MessageRecv 已被移除,现在使用 DatabaseMessages
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.common.message_repository import count_messages, find_messages
|
from src.common.message_repository import count_messages, find_messages
|
||||||
from src.config.config import global_config, model_config
|
from src.config.config import global_config, model_config
|
||||||
@@ -41,34 +41,58 @@ def db_message_to_str(message_dict: dict) -> str:
|
|||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
def is_mentioned_bot_in_message(message: MessageRecv) -> tuple[bool, float]:
|
def is_mentioned_bot_in_message(message) -> tuple[bool, float]:
|
||||||
"""检查消息是否提到了机器人"""
|
"""检查消息是否提到了机器人
|
||||||
|
|
||||||
|
Args:
|
||||||
|
message: DatabaseMessages 消息对象
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
tuple[bool, float]: (是否提及, 提及概率)
|
||||||
|
"""
|
||||||
keywords = [global_config.bot.nickname]
|
keywords = [global_config.bot.nickname]
|
||||||
nicknames = global_config.bot.alias_names
|
nicknames = global_config.bot.alias_names
|
||||||
reply_probability = 0.0
|
reply_probability = 0.0
|
||||||
is_at = False
|
is_at = False
|
||||||
is_mentioned = False
|
is_mentioned = False
|
||||||
if message.is_mentioned is not None:
|
|
||||||
return bool(message.is_mentioned), message.is_mentioned
|
# 检查 is_mentioned 属性
|
||||||
if (
|
mentioned_attr = getattr(message, "is_mentioned", None)
|
||||||
message.message_info.additional_config is not None
|
if mentioned_attr is not None:
|
||||||
and message.message_info.additional_config.get("is_mentioned") is not None
|
|
||||||
):
|
|
||||||
try:
|
try:
|
||||||
reply_probability = float(message.message_info.additional_config.get("is_mentioned")) # type: ignore
|
return bool(mentioned_attr), float(mentioned_attr)
|
||||||
|
except (ValueError, TypeError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
# 检查 additional_config
|
||||||
|
additional_config = None
|
||||||
|
|
||||||
|
# DatabaseMessages: additional_config 是 JSON 字符串
|
||||||
|
if message.additional_config:
|
||||||
|
try:
|
||||||
|
import orjson
|
||||||
|
additional_config = orjson.loads(message.additional_config)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
if additional_config and additional_config.get("is_mentioned") is not None:
|
||||||
|
try:
|
||||||
|
reply_probability = float(additional_config.get("is_mentioned")) # type: ignore
|
||||||
is_mentioned = True
|
is_mentioned = True
|
||||||
return is_mentioned, reply_probability
|
return is_mentioned, reply_probability
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(str(e))
|
logger.warning(str(e))
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"消息中包含不合理的设置 is_mentioned: {message.message_info.additional_config.get('is_mentioned')}"
|
f"消息中包含不合理的设置 is_mentioned: {additional_config.get('is_mentioned')}"
|
||||||
)
|
)
|
||||||
|
|
||||||
if global_config.bot.nickname in message.processed_plain_text:
|
# 检查消息文本内容
|
||||||
|
processed_text = message.processed_plain_text or ""
|
||||||
|
if global_config.bot.nickname in processed_text:
|
||||||
is_mentioned = True
|
is_mentioned = True
|
||||||
|
|
||||||
for alias_name in global_config.bot.alias_names:
|
for alias_name in global_config.bot.alias_names:
|
||||||
if alias_name in message.processed_plain_text:
|
if alias_name in processed_text:
|
||||||
is_mentioned = True
|
is_mentioned = True
|
||||||
|
|
||||||
# 判断是否被@
|
# 判断是否被@
|
||||||
@@ -110,7 +134,6 @@ def is_mentioned_bot_in_message(message: MessageRecv) -> tuple[bool, float]:
|
|||||||
logger.debug("被提及,回复概率设置为100%")
|
logger.debug("被提及,回复概率设置为100%")
|
||||||
return is_mentioned, reply_probability
|
return is_mentioned, reply_probability
|
||||||
|
|
||||||
|
|
||||||
async def get_embedding(text, request_type="embedding") -> list[float] | None:
|
async def get_embedding(text, request_type="embedding") -> list[float] | None:
|
||||||
"""获取文本的embedding向量"""
|
"""获取文本的embedding向量"""
|
||||||
# 每次都创建新的LLMRequest实例以避免事件循环冲突
|
# 每次都创建新的LLMRequest实例以避免事件循环冲突
|
||||||
|
|||||||
@@ -9,15 +9,18 @@ from src.common.logger import get_logger
|
|||||||
logger = get_logger("db_migration")
|
logger = get_logger("db_migration")
|
||||||
|
|
||||||
|
|
||||||
async def check_and_migrate_database():
|
async def check_and_migrate_database(existing_engine=None):
|
||||||
"""
|
"""
|
||||||
异步检查数据库结构并自动迁移。
|
异步检查数据库结构并自动迁移。
|
||||||
- 自动创建不存在的表。
|
- 自动创建不存在的表。
|
||||||
- 自动为现有表添加缺失的列。
|
- 自动为现有表添加缺失的列。
|
||||||
- 自动为现有表创建缺失的索引。
|
- 自动为现有表创建缺失的索引。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
existing_engine: 可选的已存在的数据库引擎。如果提供,将使用该引擎;否则获取全局引擎。
|
||||||
"""
|
"""
|
||||||
logger.info("正在检查数据库结构并执行自动迁移...")
|
logger.info("正在检查数据库结构并执行自动迁移...")
|
||||||
engine = await get_engine()
|
engine = existing_engine if existing_engine is not None else await get_engine()
|
||||||
|
|
||||||
async with engine.connect() as connection:
|
async with engine.connect() as connection:
|
||||||
# 在同步上下文中运行inspector操作
|
# 在同步上下文中运行inspector操作
|
||||||
|
|||||||
@@ -780,12 +780,8 @@ async def initialize_database() -> tuple["AsyncEngine", async_sessionmaker[Async
|
|||||||
_SessionLocal = async_sessionmaker(bind=_engine, class_=AsyncSession, expire_on_commit=False)
|
_SessionLocal = async_sessionmaker(bind=_engine, class_=AsyncSession, expire_on_commit=False)
|
||||||
|
|
||||||
# 迁移
|
# 迁移
|
||||||
try:
|
from src.common.database.db_migration import check_and_migrate_database
|
||||||
from src.common.database.db_migration import check_and_migrate_database
|
await check_and_migrate_database(existing_engine=_engine)
|
||||||
await check_and_migrate_database(existing_engine=_engine)
|
|
||||||
except TypeError:
|
|
||||||
from src.common.database.db_migration import check_and_migrate_database as _legacy_migrate
|
|
||||||
await _legacy_migrate()
|
|
||||||
|
|
||||||
if config.database_type == "sqlite":
|
if config.database_type == "sqlite":
|
||||||
await enable_sqlite_wal_mode(_engine)
|
await enable_sqlite_wal_mode(_engine)
|
||||||
|
|||||||
@@ -2,7 +2,6 @@ import math
|
|||||||
import random
|
import random
|
||||||
import time
|
import time
|
||||||
|
|
||||||
from src.chat.message_receive.message import MessageRecv
|
|
||||||
from src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_by_timestamp_with_chat_inclusive
|
from src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_by_timestamp_with_chat_inclusive
|
||||||
from src.chat.utils.prompt import Prompt, global_prompt_manager
|
from src.chat.utils.prompt import Prompt, global_prompt_manager
|
||||||
from src.common.data_models.database_data_model import DatabaseMessages
|
from src.common.data_models.database_data_model import DatabaseMessages
|
||||||
@@ -98,7 +97,7 @@ class ChatMood:
|
|||||||
if not hasattr(self, "last_change_time"):
|
if not hasattr(self, "last_change_time"):
|
||||||
self.last_change_time = 0
|
self.last_change_time = 0
|
||||||
|
|
||||||
async def update_mood_by_message(self, message: MessageRecv | DatabaseMessages, interested_rate: float):
|
async def update_mood_by_message(self, message: DatabaseMessages, interested_rate: float):
|
||||||
# 确保异步初始化已完成
|
# 确保异步初始化已完成
|
||||||
await self._initialize()
|
await self._initialize()
|
||||||
|
|
||||||
@@ -109,11 +108,8 @@ class ChatMood:
|
|||||||
|
|
||||||
self.regression_count = 0
|
self.regression_count = 0
|
||||||
|
|
||||||
# 处理不同类型的消息对象
|
# 使用 DatabaseMessages 的时间字段
|
||||||
if isinstance(message, MessageRecv):
|
message_time = message.time
|
||||||
message_time = message.message_info.time
|
|
||||||
else: # DatabaseMessages
|
|
||||||
message_time = message.time
|
|
||||||
|
|
||||||
# 防止负时间差
|
# 防止负时间差
|
||||||
during_last_time = max(0, message_time - self.last_change_time)
|
during_last_time = max(0, message_time - self.last_change_time)
|
||||||
|
|||||||
@@ -86,13 +86,16 @@ async def file_to_stream(
|
|||||||
import asyncio
|
import asyncio
|
||||||
import time
|
import time
|
||||||
import traceback
|
import traceback
|
||||||
from typing import Any
|
from typing import Any, TYPE_CHECKING
|
||||||
|
|
||||||
from maim_message import Seg, UserInfo
|
from maim_message import Seg, UserInfo
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from src.common.data_models.database_data_model import DatabaseMessages
|
||||||
|
|
||||||
# 导入依赖
|
# 导入依赖
|
||||||
from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager
|
from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager
|
||||||
from src.chat.message_receive.message import MessageRecv, MessageSending
|
from src.chat.message_receive.message import MessageSending
|
||||||
from src.chat.message_receive.uni_message_sender import HeartFCSender
|
from src.chat.message_receive.uni_message_sender import HeartFCSender
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.config.config import global_config
|
from src.config.config import global_config
|
||||||
@@ -104,84 +107,53 @@ logger = get_logger("send_api")
|
|||||||
_adapter_response_pool: dict[str, asyncio.Future] = {}
|
_adapter_response_pool: dict[str, asyncio.Future] = {}
|
||||||
|
|
||||||
|
|
||||||
def message_dict_to_message_recv(message_dict: dict[str, Any]) -> MessageRecv | None:
|
def message_dict_to_db_message(message_dict: dict[str, Any]) -> "DatabaseMessages | None":
|
||||||
"""查找要回复的消息
|
"""从消息字典构建 DatabaseMessages 对象
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
message_dict: 消息字典或 DatabaseMessages 对象
|
message_dict: 消息字典或 DatabaseMessages 对象
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Optional[MessageRecv]: 找到的消息,如果没找到则返回None
|
Optional[DatabaseMessages]: 构建的消息对象,如果构建失败则返回None
|
||||||
"""
|
"""
|
||||||
# 兼容 DatabaseMessages 对象和字典
|
from src.common.data_models.database_data_model import DatabaseMessages
|
||||||
if isinstance(message_dict, dict):
|
|
||||||
user_platform = message_dict.get("user_platform", "")
|
|
||||||
user_id = message_dict.get("user_id", "")
|
|
||||||
user_nickname = message_dict.get("user_nickname", "")
|
|
||||||
user_cardname = message_dict.get("user_cardname", "")
|
|
||||||
chat_info_group_id = message_dict.get("chat_info_group_id")
|
|
||||||
chat_info_group_platform = message_dict.get("chat_info_group_platform", "")
|
|
||||||
chat_info_group_name = message_dict.get("chat_info_group_name", "")
|
|
||||||
chat_info_platform = message_dict.get("chat_info_platform", "")
|
|
||||||
message_id = message_dict.get("message_id") or message_dict.get("chat_info_message_id") or message_dict.get("id")
|
|
||||||
time_val = message_dict.get("time")
|
|
||||||
additional_config = message_dict.get("additional_config")
|
|
||||||
processed_plain_text = message_dict.get("processed_plain_text")
|
|
||||||
else:
|
|
||||||
# DatabaseMessages 对象
|
|
||||||
user_platform = getattr(message_dict, "user_platform", "")
|
|
||||||
user_id = getattr(message_dict, "user_id", "")
|
|
||||||
user_nickname = getattr(message_dict, "user_nickname", "")
|
|
||||||
user_cardname = getattr(message_dict, "user_cardname", "")
|
|
||||||
chat_info_group_id = getattr(message_dict, "chat_info_group_id", None)
|
|
||||||
chat_info_group_platform = getattr(message_dict, "chat_info_group_platform", "")
|
|
||||||
chat_info_group_name = getattr(message_dict, "chat_info_group_name", "")
|
|
||||||
chat_info_platform = getattr(message_dict, "chat_info_platform", "")
|
|
||||||
message_id = getattr(message_dict, "message_id", None)
|
|
||||||
time_val = getattr(message_dict, "time", None)
|
|
||||||
additional_config = getattr(message_dict, "additional_config", None)
|
|
||||||
processed_plain_text = getattr(message_dict, "processed_plain_text", "")
|
|
||||||
|
|
||||||
# 构建MessageRecv对象
|
# 如果已经是 DatabaseMessages,直接返回
|
||||||
user_info = {
|
if isinstance(message_dict, DatabaseMessages):
|
||||||
"platform": user_platform,
|
return message_dict
|
||||||
"user_id": user_id,
|
|
||||||
"user_nickname": user_nickname,
|
# 从字典提取信息
|
||||||
"user_cardname": user_cardname,
|
user_platform = message_dict.get("user_platform", "")
|
||||||
}
|
user_id = message_dict.get("user_id", "")
|
||||||
|
user_nickname = message_dict.get("user_nickname", "")
|
||||||
group_info = {}
|
user_cardname = message_dict.get("user_cardname", "")
|
||||||
if chat_info_group_id:
|
chat_info_group_id = message_dict.get("chat_info_group_id")
|
||||||
group_info = {
|
chat_info_group_platform = message_dict.get("chat_info_group_platform", "")
|
||||||
"platform": chat_info_group_platform,
|
chat_info_group_name = message_dict.get("chat_info_group_name", "")
|
||||||
"group_id": chat_info_group_id,
|
chat_info_platform = message_dict.get("chat_info_platform", "")
|
||||||
"group_name": chat_info_group_name,
|
message_id = message_dict.get("message_id") or message_dict.get("chat_info_message_id") or message_dict.get("id")
|
||||||
}
|
time_val = message_dict.get("time", time.time())
|
||||||
|
additional_config = message_dict.get("additional_config")
|
||||||
format_info = {"content_format": "", "accept_format": ""}
|
processed_plain_text = message_dict.get("processed_plain_text", "")
|
||||||
template_info = {"template_items": {}}
|
|
||||||
|
# DatabaseMessages 使用扁平参数构造
|
||||||
message_info = {
|
db_message = DatabaseMessages(
|
||||||
"platform": chat_info_platform,
|
message_id=message_id or "temp_reply_id",
|
||||||
"message_id": message_id,
|
time=time_val,
|
||||||
"time": time_val,
|
user_id=user_id,
|
||||||
"group_info": group_info,
|
user_nickname=user_nickname,
|
||||||
"user_info": user_info,
|
user_cardname=user_cardname,
|
||||||
"additional_config": additional_config,
|
user_platform=user_platform,
|
||||||
"format_info": format_info,
|
chat_info_group_id=chat_info_group_id,
|
||||||
"template_info": template_info,
|
chat_info_group_name=chat_info_group_name,
|
||||||
}
|
chat_info_group_platform=chat_info_group_platform,
|
||||||
|
chat_info_platform=chat_info_platform,
|
||||||
new_message_dict = {
|
processed_plain_text=processed_plain_text,
|
||||||
"message_info": message_info,
|
additional_config=additional_config
|
||||||
"raw_message": processed_plain_text,
|
)
|
||||||
"processed_plain_text": processed_plain_text,
|
|
||||||
}
|
logger.info(f"[SendAPI] 构建回复消息对象,发送者: {user_nickname}")
|
||||||
|
return db_message
|
||||||
message_recv = MessageRecv(new_message_dict)
|
|
||||||
|
|
||||||
logger.info(f"[SendAPI] 找到匹配的回复消息,发送者: {user_nickname}")
|
|
||||||
return message_recv
|
|
||||||
|
|
||||||
|
|
||||||
def put_adapter_response(request_id: str, response_data: dict) -> None:
|
def put_adapter_response(request_id: str, response_data: dict) -> None:
|
||||||
@@ -285,17 +257,17 @@ async def _send_to_target(
|
|||||||
"message_id": "temp_reply_id", # 临时ID
|
"message_id": "temp_reply_id", # 临时ID
|
||||||
"time": time.time()
|
"time": time.time()
|
||||||
}
|
}
|
||||||
anchor_message = message_dict_to_message_recv(message_dict=temp_message_dict)
|
anchor_message = message_dict_to_db_message(message_dict=temp_message_dict)
|
||||||
else:
|
else:
|
||||||
anchor_message = None
|
anchor_message = None
|
||||||
reply_to_platform_id = f"{target_stream.platform}:{sender_id}" if anchor_message else None
|
reply_to_platform_id = f"{target_stream.platform}:{sender_id}" if anchor_message else None
|
||||||
|
|
||||||
elif reply_to_message:
|
elif reply_to_message:
|
||||||
anchor_message = message_dict_to_message_recv(message_dict=reply_to_message)
|
anchor_message = message_dict_to_db_message(message_dict=reply_to_message)
|
||||||
if anchor_message:
|
if anchor_message:
|
||||||
anchor_message.update_chat_stream(target_stream)
|
# DatabaseMessages 不需要 update_chat_stream,它是纯数据对象
|
||||||
reply_to_platform_id = (
|
reply_to_platform_id = (
|
||||||
f"{anchor_message.message_info.platform}:{anchor_message.message_info.user_info.user_id}"
|
f"{anchor_message.chat_info.platform}:{anchor_message.user_info.user_id}"
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
reply_to_platform_id = None
|
reply_to_platform_id = None
|
||||||
|
|||||||
@@ -1,10 +1,14 @@
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
from src.chat.message_receive.message import MessageRecv
|
from src.common.data_models.database_data_model import DatabaseMessages
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.plugin_system.apis import send_api
|
from src.plugin_system.apis import send_api
|
||||||
from src.plugin_system.base.component_types import ChatType, CommandInfo, ComponentType
|
from src.plugin_system.base.component_types import ChatType, CommandInfo, ComponentType
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from src.chat.message_receive.chat_stream import ChatStream
|
||||||
|
|
||||||
logger = get_logger("base_command")
|
logger = get_logger("base_command")
|
||||||
|
|
||||||
|
|
||||||
@@ -29,11 +33,11 @@ class BaseCommand(ABC):
|
|||||||
chat_type_allow: ChatType = ChatType.ALL
|
chat_type_allow: ChatType = ChatType.ALL
|
||||||
"""允许的聊天类型,默认为所有类型"""
|
"""允许的聊天类型,默认为所有类型"""
|
||||||
|
|
||||||
def __init__(self, message: MessageRecv, plugin_config: dict | None = None):
|
def __init__(self, message: DatabaseMessages, plugin_config: dict | None = None):
|
||||||
"""初始化Command组件
|
"""初始化Command组件
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
message: 接收到的消息对象
|
message: 接收到的消息对象(DatabaseMessages)
|
||||||
plugin_config: 插件配置字典
|
plugin_config: 插件配置字典
|
||||||
"""
|
"""
|
||||||
self.message = message
|
self.message = message
|
||||||
@@ -41,6 +45,9 @@ class BaseCommand(ABC):
|
|||||||
self.plugin_config = plugin_config or {} # 直接存储插件配置字典
|
self.plugin_config = plugin_config or {} # 直接存储插件配置字典
|
||||||
|
|
||||||
self.log_prefix = "[Command]"
|
self.log_prefix = "[Command]"
|
||||||
|
|
||||||
|
# chat_stream 会在运行时被 bot.py 设置
|
||||||
|
self.chat_stream: "ChatStream | None" = None
|
||||||
|
|
||||||
# 从类属性获取chat_type_allow设置
|
# 从类属性获取chat_type_allow设置
|
||||||
self.chat_type_allow = getattr(self.__class__, "chat_type_allow", ChatType.ALL)
|
self.chat_type_allow = getattr(self.__class__, "chat_type_allow", ChatType.ALL)
|
||||||
@@ -49,7 +56,7 @@ class BaseCommand(ABC):
|
|||||||
|
|
||||||
# 验证聊天类型限制
|
# 验证聊天类型限制
|
||||||
if not self._validate_chat_type():
|
if not self._validate_chat_type():
|
||||||
is_group = hasattr(self.message, "is_group_message") and self.message.is_group_message
|
is_group = message.group_info is not None
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"{self.log_prefix} Command '{self.command_name}' 不支持当前聊天类型: "
|
f"{self.log_prefix} Command '{self.command_name}' 不支持当前聊天类型: "
|
||||||
f"{'群聊' if is_group else '私聊'}, 允许类型: {self.chat_type_allow.value}"
|
f"{'群聊' if is_group else '私聊'}, 允许类型: {self.chat_type_allow.value}"
|
||||||
@@ -72,8 +79,8 @@ class BaseCommand(ABC):
|
|||||||
if self.chat_type_allow == ChatType.ALL:
|
if self.chat_type_allow == ChatType.ALL:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
# 检查是否为群聊消息
|
# 检查是否为群聊消息(DatabaseMessages使用group_info来判断)
|
||||||
is_group = self.message.message_info.group_info
|
is_group = self.message.group_info is not None
|
||||||
|
|
||||||
if self.chat_type_allow == ChatType.GROUP and is_group:
|
if self.chat_type_allow == ChatType.GROUP and is_group:
|
||||||
return True
|
return True
|
||||||
@@ -137,12 +144,11 @@ class BaseCommand(ABC):
|
|||||||
bool: 是否发送成功
|
bool: 是否发送成功
|
||||||
"""
|
"""
|
||||||
# 获取聊天流信息
|
# 获取聊天流信息
|
||||||
chat_stream = self.message.chat_stream
|
if not self.chat_stream or not hasattr(self.chat_stream, "stream_id"):
|
||||||
if not chat_stream or not hasattr(chat_stream, "stream_id"):
|
|
||||||
logger.error(f"{self.log_prefix} 缺少聊天流或stream_id")
|
logger.error(f"{self.log_prefix} 缺少聊天流或stream_id")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
return await send_api.text_to_stream(text=content, stream_id=chat_stream.stream_id, reply_to=reply_to)
|
return await send_api.text_to_stream(text=content, stream_id=self.chat_stream.stream_id, reply_to=reply_to)
|
||||||
|
|
||||||
async def send_type(
|
async def send_type(
|
||||||
self, message_type: str, content: str, display_message: str = "", typing: bool = False, reply_to: str = ""
|
self, message_type: str, content: str, display_message: str = "", typing: bool = False, reply_to: str = ""
|
||||||
@@ -160,15 +166,14 @@ class BaseCommand(ABC):
|
|||||||
bool: 是否发送成功
|
bool: 是否发送成功
|
||||||
"""
|
"""
|
||||||
# 获取聊天流信息
|
# 获取聊天流信息
|
||||||
chat_stream = self.message.chat_stream
|
if not self.chat_stream or not hasattr(self.chat_stream, "stream_id"):
|
||||||
if not chat_stream or not hasattr(chat_stream, "stream_id"):
|
|
||||||
logger.error(f"{self.log_prefix} 缺少聊天流或stream_id")
|
logger.error(f"{self.log_prefix} 缺少聊天流或stream_id")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
return await send_api.custom_to_stream(
|
return await send_api.custom_to_stream(
|
||||||
message_type=message_type,
|
message_type=message_type,
|
||||||
content=content,
|
content=content,
|
||||||
stream_id=chat_stream.stream_id,
|
stream_id=self.chat_stream.stream_id,
|
||||||
display_message=display_message,
|
display_message=display_message,
|
||||||
typing=typing,
|
typing=typing,
|
||||||
reply_to=reply_to,
|
reply_to=reply_to,
|
||||||
@@ -190,8 +195,7 @@ class BaseCommand(ABC):
|
|||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
# 获取聊天流信息
|
# 获取聊天流信息
|
||||||
chat_stream = self.message.chat_stream
|
if not self.chat_stream or not hasattr(self.chat_stream, "stream_id"):
|
||||||
if not chat_stream or not hasattr(chat_stream, "stream_id"):
|
|
||||||
logger.error(f"{self.log_prefix} 缺少聊天流或stream_id")
|
logger.error(f"{self.log_prefix} 缺少聊天流或stream_id")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
@@ -200,7 +204,7 @@ class BaseCommand(ABC):
|
|||||||
|
|
||||||
success = await send_api.command_to_stream(
|
success = await send_api.command_to_stream(
|
||||||
command=command_data,
|
command=command_data,
|
||||||
stream_id=chat_stream.stream_id,
|
stream_id=self.chat_stream.stream_id,
|
||||||
storage_message=storage_message,
|
storage_message=storage_message,
|
||||||
display_message=display_message,
|
display_message=display_message,
|
||||||
)
|
)
|
||||||
@@ -225,12 +229,11 @@ class BaseCommand(ABC):
|
|||||||
Returns:
|
Returns:
|
||||||
bool: 是否发送成功
|
bool: 是否发送成功
|
||||||
"""
|
"""
|
||||||
chat_stream = self.message.chat_stream
|
if not self.chat_stream or not hasattr(self.chat_stream, "stream_id"):
|
||||||
if not chat_stream or not hasattr(chat_stream, "stream_id"):
|
|
||||||
logger.error(f"{self.log_prefix} 缺少聊天流或stream_id")
|
logger.error(f"{self.log_prefix} 缺少聊天流或stream_id")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
return await send_api.emoji_to_stream(emoji_base64, chat_stream.stream_id)
|
return await send_api.emoji_to_stream(emoji_base64, self.chat_stream.stream_id)
|
||||||
|
|
||||||
async def send_image(self, image_base64: str) -> bool:
|
async def send_image(self, image_base64: str) -> bool:
|
||||||
"""发送图片
|
"""发送图片
|
||||||
@@ -241,12 +244,11 @@ class BaseCommand(ABC):
|
|||||||
Returns:
|
Returns:
|
||||||
bool: 是否发送成功
|
bool: 是否发送成功
|
||||||
"""
|
"""
|
||||||
chat_stream = self.message.chat_stream
|
if not self.chat_stream or not hasattr(self.chat_stream, "stream_id"):
|
||||||
if not chat_stream or not hasattr(chat_stream, "stream_id"):
|
|
||||||
logger.error(f"{self.log_prefix} 缺少聊天流或stream_id")
|
logger.error(f"{self.log_prefix} 缺少聊天流或stream_id")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
return await send_api.image_to_stream(image_base64, chat_stream.stream_id)
|
return await send_api.image_to_stream(image_base64, self.chat_stream.stream_id)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_command_info(cls) -> "CommandInfo":
|
def get_command_info(cls) -> "CommandInfo":
|
||||||
|
|||||||
@@ -5,8 +5,9 @@
|
|||||||
|
|
||||||
import re
|
import re
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
from src.chat.message_receive.message import MessageRecv
|
from src.common.data_models.database_data_model import DatabaseMessages
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.config.config import global_config
|
from src.config.config import global_config
|
||||||
from src.plugin_system.apis import send_api
|
from src.plugin_system.apis import send_api
|
||||||
@@ -14,6 +15,9 @@ from src.plugin_system.base.base_command import BaseCommand
|
|||||||
from src.plugin_system.base.command_args import CommandArgs
|
from src.plugin_system.base.command_args import CommandArgs
|
||||||
from src.plugin_system.base.component_types import ChatType, ComponentType, PlusCommandInfo
|
from src.plugin_system.base.component_types import ChatType, ComponentType, PlusCommandInfo
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from src.chat.message_receive.chat_stream import ChatStream
|
||||||
|
|
||||||
logger = get_logger("plus_command")
|
logger = get_logger("plus_command")
|
||||||
|
|
||||||
|
|
||||||
@@ -50,23 +54,26 @@ class PlusCommand(ABC):
|
|||||||
intercept_message: bool = False
|
intercept_message: bool = False
|
||||||
"""是否拦截消息,不进行后续处理"""
|
"""是否拦截消息,不进行后续处理"""
|
||||||
|
|
||||||
def __init__(self, message: MessageRecv, plugin_config: dict | None = None):
|
def __init__(self, message: DatabaseMessages, plugin_config: dict | None = None):
|
||||||
"""初始化命令组件
|
"""初始化命令组件
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
message: 接收到的消息对象
|
message: 接收到的消息对象(DatabaseMessages)
|
||||||
plugin_config: 插件配置字典
|
plugin_config: 插件配置字典
|
||||||
"""
|
"""
|
||||||
self.message = message
|
self.message = message
|
||||||
self.plugin_config = plugin_config or {}
|
self.plugin_config = plugin_config or {}
|
||||||
self.log_prefix = "[PlusCommand]"
|
self.log_prefix = "[PlusCommand]"
|
||||||
|
|
||||||
|
# chat_stream 会在运行时被 bot.py 设置
|
||||||
|
self.chat_stream: "ChatStream | None" = None
|
||||||
|
|
||||||
# 解析命令参数
|
# 解析命令参数
|
||||||
self._parse_command()
|
self._parse_command()
|
||||||
|
|
||||||
# 验证聊天类型限制
|
# 验证聊天类型限制
|
||||||
if not self._validate_chat_type():
|
if not self._validate_chat_type():
|
||||||
is_group = self.message.message_info.group_info.group_id
|
is_group = message.group_info is not None
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"{self.log_prefix} 命令 '{self.command_name}' 不支持当前聊天类型: "
|
f"{self.log_prefix} 命令 '{self.command_name}' 不支持当前聊天类型: "
|
||||||
f"{'群聊' if is_group else '私聊'}, 允许类型: {self.chat_type_allow.value}"
|
f"{'群聊' if is_group else '私聊'}, 允许类型: {self.chat_type_allow.value}"
|
||||||
@@ -124,8 +131,8 @@ class PlusCommand(ABC):
|
|||||||
if self.chat_type_allow == ChatType.ALL:
|
if self.chat_type_allow == ChatType.ALL:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
# 检查是否为群聊消息
|
# 检查是否为群聊消息(DatabaseMessages使用group_info判断)
|
||||||
is_group = hasattr(self.message.message_info, "group_info") and self.message.message_info.group_info
|
is_group = self.message.group_info is not None
|
||||||
|
|
||||||
if self.chat_type_allow == ChatType.GROUP and is_group:
|
if self.chat_type_allow == ChatType.GROUP and is_group:
|
||||||
return True
|
return True
|
||||||
@@ -152,7 +159,7 @@ class PlusCommand(ABC):
|
|||||||
|
|
||||||
def _is_exact_command_call(self) -> bool:
|
def _is_exact_command_call(self) -> bool:
|
||||||
"""检查是否是精确的命令调用(无参数)"""
|
"""检查是否是精确的命令调用(无参数)"""
|
||||||
if not hasattr(self.message, "plain_text") or not self.message.processed_plain_text:
|
if not self.message.processed_plain_text:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
plain_text = self.message.processed_plain_text.strip()
|
plain_text = self.message.processed_plain_text.strip()
|
||||||
@@ -218,12 +225,11 @@ class PlusCommand(ABC):
|
|||||||
bool: 是否发送成功
|
bool: 是否发送成功
|
||||||
"""
|
"""
|
||||||
# 获取聊天流信息
|
# 获取聊天流信息
|
||||||
chat_stream = self.message.chat_stream
|
if not self.chat_stream or not hasattr(self.chat_stream, "stream_id"):
|
||||||
if not chat_stream or not hasattr(chat_stream, "stream_id"):
|
|
||||||
logger.error(f"{self.log_prefix} 缺少聊天流或stream_id")
|
logger.error(f"{self.log_prefix} 缺少聊天流或stream_id")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
return await send_api.text_to_stream(text=content, stream_id=chat_stream.stream_id, reply_to=reply_to)
|
return await send_api.text_to_stream(text=content, stream_id=self.chat_stream.stream_id, reply_to=reply_to)
|
||||||
|
|
||||||
async def send_type(
|
async def send_type(
|
||||||
self, message_type: str, content: str, display_message: str = "", typing: bool = False, reply_to: str = ""
|
self, message_type: str, content: str, display_message: str = "", typing: bool = False, reply_to: str = ""
|
||||||
@@ -241,15 +247,14 @@ class PlusCommand(ABC):
|
|||||||
bool: 是否发送成功
|
bool: 是否发送成功
|
||||||
"""
|
"""
|
||||||
# 获取聊天流信息
|
# 获取聊天流信息
|
||||||
chat_stream = self.message.chat_stream
|
if not self.chat_stream or not hasattr(self.chat_stream, "stream_id"):
|
||||||
if not chat_stream or not hasattr(chat_stream, "stream_id"):
|
|
||||||
logger.error(f"{self.log_prefix} 缺少聊天流或stream_id")
|
logger.error(f"{self.log_prefix} 缺少聊天流或stream_id")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
return await send_api.custom_to_stream(
|
return await send_api.custom_to_stream(
|
||||||
message_type=message_type,
|
message_type=message_type,
|
||||||
content=content,
|
content=content,
|
||||||
stream_id=chat_stream.stream_id,
|
stream_id=self.chat_stream.stream_id,
|
||||||
display_message=display_message,
|
display_message=display_message,
|
||||||
typing=typing,
|
typing=typing,
|
||||||
reply_to=reply_to,
|
reply_to=reply_to,
|
||||||
@@ -264,12 +269,11 @@ class PlusCommand(ABC):
|
|||||||
Returns:
|
Returns:
|
||||||
bool: 是否发送成功
|
bool: 是否发送成功
|
||||||
"""
|
"""
|
||||||
chat_stream = self.message.chat_stream
|
if not self.chat_stream or not hasattr(self.chat_stream, "stream_id"):
|
||||||
if not chat_stream or not hasattr(chat_stream, "stream_id"):
|
|
||||||
logger.error(f"{self.log_prefix} 缺少聊天流或stream_id")
|
logger.error(f"{self.log_prefix} 缺少聊天流或stream_id")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
return await send_api.emoji_to_stream(emoji_base64, chat_stream.stream_id)
|
return await send_api.emoji_to_stream(emoji_base64, self.chat_stream.stream_id)
|
||||||
|
|
||||||
async def send_image(self, image_base64: str) -> bool:
|
async def send_image(self, image_base64: str) -> bool:
|
||||||
"""发送图片
|
"""发送图片
|
||||||
@@ -280,12 +284,11 @@ class PlusCommand(ABC):
|
|||||||
Returns:
|
Returns:
|
||||||
bool: 是否发送成功
|
bool: 是否发送成功
|
||||||
"""
|
"""
|
||||||
chat_stream = self.message.chat_stream
|
if not self.chat_stream or not hasattr(self.chat_stream, "stream_id"):
|
||||||
if not chat_stream or not hasattr(chat_stream, "stream_id"):
|
|
||||||
logger.error(f"{self.log_prefix} 缺少聊天流或stream_id")
|
logger.error(f"{self.log_prefix} 缺少聊天流或stream_id")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
return await send_api.image_to_stream(image_base64, chat_stream.stream_id)
|
return await send_api.image_to_stream(image_base64, self.chat_stream.stream_id)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_plus_command_info(cls) -> "PlusCommandInfo":
|
def get_plus_command_info(cls) -> "PlusCommandInfo":
|
||||||
@@ -340,12 +343,12 @@ class PlusCommandAdapter(BaseCommand):
|
|||||||
将PlusCommand适配到现有的插件系统,继承BaseCommand
|
将PlusCommand适配到现有的插件系统,继承BaseCommand
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, plus_command_class, message: MessageRecv, plugin_config: dict | None = None):
|
def __init__(self, plus_command_class, message: DatabaseMessages, plugin_config: dict | None = None):
|
||||||
"""初始化适配器
|
"""初始化适配器
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
plus_command_class: PlusCommand子类
|
plus_command_class: PlusCommand子类
|
||||||
message: 消息对象
|
message: 消息对象(DatabaseMessages)
|
||||||
plugin_config: 插件配置
|
plugin_config: 插件配置
|
||||||
"""
|
"""
|
||||||
# 先设置必要的类属性
|
# 先设置必要的类属性
|
||||||
@@ -400,7 +403,7 @@ def create_plus_command_adapter(plus_command_class):
|
|||||||
command_pattern = plus_command_class._generate_command_pattern()
|
command_pattern = plus_command_class._generate_command_pattern()
|
||||||
chat_type_allow = getattr(plus_command_class, "chat_type_allow", ChatType.ALL)
|
chat_type_allow = getattr(plus_command_class, "chat_type_allow", ChatType.ALL)
|
||||||
|
|
||||||
def __init__(self, message: MessageRecv, plugin_config: dict | None = None):
|
def __init__(self, message: DatabaseMessages, plugin_config: dict | None = None):
|
||||||
super().__init__(message, plugin_config)
|
super().__init__(message, plugin_config)
|
||||||
self.plus_command = plus_command_class(message, plugin_config)
|
self.plus_command = plus_command_class(message, plugin_config)
|
||||||
self.priority = getattr(plus_command_class, "priority", 0)
|
self.priority = getattr(plus_command_class, "priority", 0)
|
||||||
|
|||||||
@@ -410,11 +410,9 @@ class ChatterPlanExecutor:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# 添加到chat_stream的已读消息中
|
# 添加到chat_stream的已读消息中
|
||||||
if hasattr(chat_stream, "stream_context") and chat_stream.stream_context:
|
chat_stream.context_manager.context.history_messages.append(bot_message)
|
||||||
chat_stream.stream_context.history_messages.append(bot_message)
|
logger.debug(f"机器人回复已添加到已读消息: {reply_content[:50]}...")
|
||||||
logger.debug(f"机器人回复已添加到已读消息: {reply_content[:50]}...")
|
|
||||||
else:
|
|
||||||
logger.warning("chat_stream没有stream_context,无法添加已读消息")
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"添加机器人回复到已读消息时出错: {e}")
|
logger.error(f"添加机器人回复到已读消息时出错: {e}")
|
||||||
|
|||||||
@@ -96,7 +96,7 @@ class ProactiveThinkingMessageHandler(BaseEventHandler):
|
|||||||
"""处理消息事件
|
"""处理消息事件
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
kwargs: 事件参数,格式为 {"message": MessageRecv}
|
kwargs: 事件参数,格式为 {"message": DatabaseMessages}
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
HandlerResult: 处理结果
|
HandlerResult: 处理结果
|
||||||
@@ -104,7 +104,7 @@ class ProactiveThinkingMessageHandler(BaseEventHandler):
|
|||||||
if not kwargs:
|
if not kwargs:
|
||||||
return HandlerResult(success=True, continue_process=True, message=None)
|
return HandlerResult(success=True, continue_process=True, message=None)
|
||||||
|
|
||||||
# 从 kwargs 中获取 MessageRecv 对象
|
# 从 kwargs 中获取 DatabaseMessages 对象
|
||||||
message = kwargs.get("message")
|
message = kwargs.get("message")
|
||||||
if not message or not hasattr(message, "chat_stream"):
|
if not message or not hasattr(message, "chat_stream"):
|
||||||
return HandlerResult(success=True, continue_process=True, message=None)
|
return HandlerResult(success=True, continue_process=True, message=None)
|
||||||
|
|||||||
Reference in New Issue
Block a user