diff --git a/integration_test_relationship_tools.py b/integration_test_relationship_tools.py
deleted file mode 100644
index a2ac3a7fa..000000000
--- a/integration_test_relationship_tools.py
+++ /dev/null
@@ -1,303 +0,0 @@
-"""
-关系追踪工具集成测试脚本
-
-注意:此脚本需要在完整的应用环境中运行
-建议通过 bot.py 启动后在交互式环境中测试
-"""
-
-import asyncio
-
-
-async def test_user_profile_tool():
- """测试用户画像工具"""
- print("\n" + "=" * 80)
- print("测试 UserProfileTool")
- print("=" * 80)
-
- from src.plugins.built_in.affinity_flow_chatter.user_profile_tool import UserProfileTool
- from src.common.database.sqlalchemy_database_api import db_query
- from src.common.database.sqlalchemy_models import UserRelationships
-
- tool = UserProfileTool()
- print(f"✅ 工具名称: {tool.name}")
- print(f" 工具描述: {tool.description}")
-
- # 执行工具
- test_user_id = "integration_test_user_001"
- result = await tool.execute({
- "target_user_id": test_user_id,
- "user_aliases": "测试小明,TestMing,小明君",
- "impression_description": "这是一个集成测试用户,性格开朗活泼,喜欢技术讨论,对AI和编程特别感兴趣。经常提出有深度的问题。",
- "preference_keywords": "AI,Python,深度学习,游戏开发,科幻小说",
- "affection_score": 0.85
- })
-
- print(f"\n✅ 工具执行结果:")
- print(f" 类型: {result.get('type')}")
- print(f" 内容: {result.get('content')}")
-
- # 验证数据库
- db_data = await db_query(
- UserRelationships,
- filters={"user_id": test_user_id},
- limit=1
- )
-
- if db_data:
- data = db_data[0]
- print(f"\n✅ 数据库验证:")
- print(f" user_id: {data.get('user_id')}")
- print(f" user_aliases: {data.get('user_aliases')}")
- print(f" relationship_text: {data.get('relationship_text', '')[:80]}...")
- print(f" preference_keywords: {data.get('preference_keywords')}")
- print(f" relationship_score: {data.get('relationship_score')}")
- return True
- else:
- print(f"\n❌ 数据库中未找到数据")
- return False
-
-
-async def test_chat_stream_impression_tool():
- """测试聊天流印象工具"""
- print("\n" + "=" * 80)
- print("测试 ChatStreamImpressionTool")
- print("=" * 80)
-
- from src.plugins.built_in.affinity_flow_chatter.chat_stream_impression_tool import ChatStreamImpressionTool
- from src.common.database.sqlalchemy_database_api import db_query
- from src.common.database.sqlalchemy_models import ChatStreams, get_db_session
-
- # 准备测试数据:先创建一条 ChatStreams 记录
- test_stream_id = "integration_test_stream_001"
- print(f"🔧 准备测试数据:创建聊天流记录 {test_stream_id}")
-
- import time
- current_time = time.time()
-
- async with get_db_session() as session:
- new_stream = ChatStreams(
- stream_id=test_stream_id,
- create_time=current_time,
- last_active_time=current_time,
- platform="QQ",
- user_platform="QQ",
- user_id="test_user_123",
- user_nickname="测试用户",
- group_name="测试技术交流群",
- group_platform="QQ",
- group_id="test_group_456",
- stream_impression_text="", # 初始为空
- stream_chat_style="",
- stream_topic_keywords="",
- stream_interest_score=0.5
- )
- session.add(new_stream)
- await session.commit()
- print(f"✅ 测试聊天流记录已创建")
-
- tool = ChatStreamImpressionTool()
- print(f"✅ 工具名称: {tool.name}")
- print(f" 工具描述: {tool.description}")
-
- # 执行工具
- result = await tool.execute({
- "stream_id": test_stream_id,
- "impression_description": "这是一个技术交流群,成员主要是程序员和AI爱好者。大家经常分享最新的技术文章,讨论编程问题,氛围友好且专业。",
- "chat_style": "专业技术交流,活跃讨论,互帮互助,知识分享",
- "topic_keywords": "Python开发,机器学习,AI应用,Web后端,数据分析,开源项目",
- "interest_score": 0.90
- })
-
- print(f"\n✅ 工具执行结果:")
- print(f" 类型: {result.get('type')}")
- print(f" 内容: {result.get('content')}")
-
- # 验证数据库
- db_data = await db_query(
- ChatStreams,
- filters={"stream_id": test_stream_id},
- limit=1
- )
-
- if db_data:
- data = db_data[0]
- print(f"\n✅ 数据库验证:")
- print(f" stream_id: {data.get('stream_id')}")
- print(f" stream_impression_text: {data.get('stream_impression_text', '')[:80]}...")
- print(f" stream_chat_style: {data.get('stream_chat_style')}")
- print(f" stream_topic_keywords: {data.get('stream_topic_keywords')}")
- print(f" stream_interest_score: {data.get('stream_interest_score')}")
- return True
- else:
- print(f"\n❌ 数据库中未找到数据")
- return False
-
-
-async def test_relationship_info_build():
- """测试关系信息构建"""
- print("\n" + "=" * 80)
- print("测试关系信息构建(提示词集成)")
- print("=" * 80)
-
- from src.person_info.relationship_fetcher import relationship_fetcher_manager
-
- test_stream_id = "integration_test_stream_001"
- test_person_id = "test_person_999" # 使用一个可能不存在的ID来测试
-
- fetcher = relationship_fetcher_manager.get_fetcher(test_stream_id)
- print(f"✅ RelationshipFetcher 已创建")
-
- # 测试聊天流印象构建
- print(f"\n🔍 构建聊天流印象...")
- stream_info = await fetcher.build_chat_stream_impression(test_stream_id)
-
- if stream_info:
- print(f"✅ 聊天流印象构建成功")
- print(f"\n{'=' * 80}")
- print(stream_info)
- print(f"{'=' * 80}")
- else:
- print(f"⚠️ 聊天流印象为空(可能测试数据不存在)")
-
- return True
-
-
-async def cleanup_test_data():
- """清理测试数据"""
- print("\n" + "=" * 80)
- print("清理测试数据")
- print("=" * 80)
-
- from src.common.database.sqlalchemy_database_api import db_query
- from src.common.database.sqlalchemy_models import UserRelationships, ChatStreams
-
- try:
- # 清理用户数据
- await db_query(
- UserRelationships,
- query_type="delete",
- filters={"user_id": "integration_test_user_001"}
- )
- print("✅ 用户测试数据已清理")
-
- # 清理聊天流数据
- await db_query(
- ChatStreams,
- query_type="delete",
- filters={"stream_id": "integration_test_stream_001"}
- )
- print("✅ 聊天流测试数据已清理")
-
- return True
- except Exception as e:
- print(f"⚠️ 清理失败: {e}")
- return False
-
-
-async def run_all_tests():
- """运行所有测试"""
- print("\n" + "=" * 80)
- print("关系追踪工具集成测试")
- print("=" * 80)
-
- results = {}
-
- # 测试1
- try:
- results["UserProfileTool"] = await test_user_profile_tool()
- except Exception as e:
- print(f"\n❌ UserProfileTool 测试失败: {e}")
- import traceback
- traceback.print_exc()
- results["UserProfileTool"] = False
-
- # 测试2
- try:
- results["ChatStreamImpressionTool"] = await test_chat_stream_impression_tool()
- except Exception as e:
- print(f"\n❌ ChatStreamImpressionTool 测试失败: {e}")
- import traceback
- traceback.print_exc()
- results["ChatStreamImpressionTool"] = False
-
- # 测试3
- try:
- results["RelationshipFetcher"] = await test_relationship_info_build()
- except Exception as e:
- print(f"\n❌ RelationshipFetcher 测试失败: {e}")
- import traceback
- traceback.print_exc()
- results["RelationshipFetcher"] = False
-
- # 清理
- try:
- await cleanup_test_data()
- except Exception as e:
- print(f"\n⚠️ 清理测试数据失败: {e}")
-
- # 总结
- print("\n" + "=" * 80)
- print("测试总结")
- print("=" * 80)
-
- passed = sum(1 for r in results.values() if r)
- total = len(results)
-
- for test_name, result in results.items():
- status = "✅ 通过" if result else "❌ 失败"
- print(f"{status} - {test_name}")
-
- print(f"\n总计: {passed}/{total} 测试通过")
-
- if passed == total:
- print("\n🎉 所有测试通过!")
- else:
- print(f"\n⚠️ {total - passed} 个测试失败")
-
- return passed == total
-
-
-# 使用说明
-print("""
-============================================================================
-关系追踪工具集成测试脚本
-============================================================================
-
-此脚本需要在完整的应用环境中运行。
-
-使用方法1: 在 bot.py 中添加测试调用
------------------------------------
-在 bot.py 的 main() 函数中添加:
-
- # 测试关系追踪工具
- from tests.integration_test_relationship_tools import run_all_tests
- await run_all_tests()
-
-使用方法2: 在 Python REPL 中运行
------------------------------------
-启动 bot.py 后,在 Python 调试控制台中执行:
-
- import asyncio
- from tests.integration_test_relationship_tools import run_all_tests
- asyncio.create_task(run_all_tests())
-
-使用方法3: 直接在此文件底部运行
------------------------------------
-取消注释下面的代码,然后确保已启动应用环境
-============================================================================
-""")
-
-
-# 如果需要直接运行(需要应用环境已启动)
-if __name__ == "__main__":
- print("\n⚠️ 警告: 直接运行此脚本可能会失败,因为缺少应用环境")
- print("建议在 bot.py 启动后的环境中运行\n")
-
- try:
- asyncio.run(run_all_tests())
- except Exception as e:
- print(f"\n❌ 测试失败: {e}")
- print("\n建议:")
- print("1. 确保已启动 bot.py")
- print("2. 在 Python 调试控制台中运行测试")
- print("3. 或在 bot.py 中添加测试调用")
diff --git a/pyrightconfig.json b/pyrightconfig.json
index 3cffac58c..adf9c8dcf 100644
--- a/pyrightconfig.json
+++ b/pyrightconfig.json
@@ -27,6 +27,6 @@
"venvPath": ".",
"venv": ".venv",
"executionEnvironments": [
- {"root": "src"}
+ {"root": "."}
]
}
diff --git a/src/chat/antipromptinjector/processors/message_processor.py b/src/chat/antipromptinjector/processors/message_processor.py
index 0e37efc0d..b13baff13 100644
--- a/src/chat/antipromptinjector/processors/message_processor.py
+++ b/src/chat/antipromptinjector/processors/message_processor.py
@@ -6,7 +6,7 @@
import re
-from src.chat.message_receive.message import MessageRecv
+from src.common.data_models.database_data_model import DatabaseMessages
from src.common.logger import get_logger
logger = get_logger("anti_injector.message_processor")
@@ -15,7 +15,7 @@ logger = get_logger("anti_injector.message_processor")
class MessageProcessor:
"""消息内容处理器"""
- def extract_text_content(self, message: MessageRecv) -> str:
+ def extract_text_content(self, message: DatabaseMessages) -> str:
"""提取消息中的文本内容,过滤掉引用的历史内容
Args:
@@ -64,7 +64,7 @@ class MessageProcessor:
return new_content
@staticmethod
- def check_whitelist(message: MessageRecv, whitelist: list) -> tuple | None:
+ def check_whitelist(message: DatabaseMessages, whitelist: list) -> tuple | None:
"""检查用户白名单
Args:
@@ -74,8 +74,8 @@ class MessageProcessor:
Returns:
如果在白名单中返回结果元组,否则返回None
"""
- user_id = message.message_info.user_info.user_id
- platform = message.message_info.platform
+ user_id = message.user_info.user_id
+ platform = message.chat_info.platform
# 检查用户白名单:格式为 [[platform, user_id], ...]
for whitelist_entry in whitelist:
diff --git a/src/chat/message_manager/context_manager.py b/src/chat/message_manager/context_manager.py
index 41bf47781..c569fde6b 100644
--- a/src/chat/message_manager/context_manager.py
+++ b/src/chat/message_manager/context_manager.py
@@ -29,7 +29,6 @@ class SingleStreamContextManager:
# 配置参数
self.max_context_size = max_context_size or getattr(global_config.chat, "max_context_size", 100)
- self.context_ttl = getattr(global_config.chat, "context_ttl", 24 * 3600) # 24小时
# 元数据
self.created_time = time.time()
@@ -93,27 +92,24 @@ class SingleStreamContextManager:
return True
else:
logger.warning(f"消息缓存系统添加失败,回退到直接添加: {self.stream_id}")
-
- except ImportError:
- logger.debug("MessageManager不可用,使用直接添加模式")
except Exception as e:
logger.warning(f"消息缓存系统异常,回退到直接添加: {self.stream_id}, error={e}")
- # 回退方案:直接添加到未读消息
- message.is_read = False
- self.context.unread_messages.append(message)
+ # 回退方案:直接添加到未读消息
+ message.is_read = False
+ self.context.unread_messages.append(message)
- # 自动检测和更新chat type
- self._detect_chat_type(message)
+ # 自动检测和更新chat type
+ self._detect_chat_type(message)
- # 在上下文管理器中计算兴趣值
- await self._calculate_message_interest(message)
- self.total_messages += 1
- self.last_access_time = time.time()
- # 启动流的循环任务(如果还未启动)
- asyncio.create_task(stream_loop_manager.start_stream_loop(self.stream_id))
- logger.debug(f"添加消息{message.processed_plain_text}到单流上下文: {self.stream_id}")
- return True
+ # 在上下文管理器中计算兴趣值
+ await self._calculate_message_interest(message)
+ self.total_messages += 1
+ self.last_access_time = time.time()
+ # 启动流的循环任务(如果还未启动)
+ asyncio.create_task(stream_loop_manager.start_stream_loop(self.stream_id))
+ logger.debug(f"添加消息{message.processed_plain_text}到单流上下文: {self.stream_id}")
+ return True
except Exception as e:
logger.error(f"添加消息到单流上下文失败 {self.stream_id}: {e}", exc_info=True)
return False
diff --git a/src/chat/message_manager/message_manager.py b/src/chat/message_manager/message_manager.py
index 49c169640..b54fc8bdc 100644
--- a/src/chat/message_manager/message_manager.py
+++ b/src/chat/message_manager/message_manager.py
@@ -71,14 +71,6 @@ class MessageManager:
except Exception as e:
logger.error(f"启动批量数据库写入器失败: {e}")
- # 启动流缓存管理器
- try:
- from src.chat.message_manager.stream_cache_manager import init_stream_cache_manager
-
- await init_stream_cache_manager()
- except Exception as e:
- logger.error(f"启动流缓存管理器失败: {e}")
-
# 启动消息缓存系统(内置)
logger.info("📦 消息缓存系统已启动")
@@ -116,15 +108,6 @@ class MessageManager:
except Exception as e:
logger.error(f"停止批量数据库写入器失败: {e}")
- # 停止流缓存管理器
- try:
- from src.chat.message_manager.stream_cache_manager import shutdown_stream_cache_manager
-
- await shutdown_stream_cache_manager()
- logger.info("🗄️ 流缓存管理器已停止")
- except Exception as e:
- logger.error(f"停止流缓存管理器失败: {e}")
-
# 停止消息缓存系统(内置)
self.message_caches.clear()
self.stream_processing_status.clear()
@@ -152,7 +135,7 @@ class MessageManager:
# 检查是否为notice消息
if self._is_notice_message(message):
# Notice消息处理 - 添加到全局管理器
- logger.info(f"📢 检测到notice消息: message_id={message.message_id}, is_notify={message.is_notify}, notice_type={getattr(message, 'notice_type', None)}")
+ logger.info(f"📢 检测到notice消息: notice_type={getattr(message, 'notice_type', None)}")
await self._handle_notice_message(stream_id, message)
# 根据配置决定是否继续处理(触发聊天流程)
@@ -206,39 +189,6 @@ class MessageManager:
except Exception as e:
logger.error(f"更新消息 {message_id} 时发生错误: {e}")
- async def bulk_update_messages(self, stream_id: str, updates: list[dict[str, Any]]) -> int:
- """批量更新消息信息,降低更新频率"""
- if not updates:
- return 0
-
- try:
- chat_manager = get_chat_manager()
- chat_stream = await chat_manager.get_stream(stream_id)
- if not chat_stream:
- logger.warning(f"MessageManager.bulk_update_messages: 聊天流 {stream_id} 不存在")
- return 0
-
- updated_count = 0
- for item in updates:
- message_id = item.get("message_id")
- if not message_id:
- continue
-
- payload = {key: value for key, value in item.items() if key != "message_id" and value is not None}
-
- if not payload:
- continue
-
- success = await chat_stream.context_manager.update_message(message_id, payload)
- if success:
- updated_count += 1
-
- if updated_count:
- logger.debug(f"批量更新消息 {updated_count} 条 (stream={stream_id})")
- return updated_count
- except Exception as e:
- logger.error(f"批量更新聊天流 {stream_id} 消息失败: {e}")
- return 0
async def add_action(self, stream_id: str, message_id: str, action: str):
"""添加动作到消息"""
@@ -266,7 +216,7 @@ class MessageManager:
logger.warning(f"停用流失败: 聊天流 {stream_id} 不存在")
return
- context = chat_stream.stream_context
+ context = chat_stream.context_manager.context
context.is_active = False
# 取消处理任务
@@ -288,7 +238,7 @@ class MessageManager:
logger.warning(f"激活流失败: 聊天流 {stream_id} 不存在")
return
- context = chat_stream.stream_context
+ context = chat_stream.context_manager.context
context.is_active = True
logger.info(f"激活聊天流: {stream_id}")
@@ -304,7 +254,7 @@ class MessageManager:
if not chat_stream:
return None
- context = chat_stream.stream_context
+ context = chat_stream.context_manager.context
unread_count = len(chat_stream.context_manager.get_unread_messages())
return StreamStats(
@@ -447,7 +397,7 @@ class MessageManager:
await asyncio.sleep(0.1)
# 获取当前的stream context
- context = chat_stream.stream_context
+ context = chat_stream.context_manager.context
# 确保有未读消息需要处理
unread_messages = context.get_unread_messages()
diff --git a/src/chat/message_manager/stream_cache_manager.py b/src/chat/message_manager/stream_cache_manager.py
deleted file mode 100644
index ea85c3855..000000000
--- a/src/chat/message_manager/stream_cache_manager.py
+++ /dev/null
@@ -1,377 +0,0 @@
-"""
-流缓存管理器 - 使用优化版聊天流和智能缓存策略
-提供分层缓存和自动清理功能
-"""
-
-import asyncio
-import time
-from collections import OrderedDict
-from dataclasses import dataclass
-
-from maim_message import GroupInfo, UserInfo
-
-from src.chat.message_receive.optimized_chat_stream import OptimizedChatStream, create_optimized_chat_stream
-from src.common.logger import get_logger
-
-logger = get_logger("stream_cache_manager")
-
-
-@dataclass
-class StreamCacheStats:
- """缓存统计信息"""
-
- hot_cache_size: int = 0
- warm_storage_size: int = 0
- cold_storage_size: int = 0
- total_memory_usage: int = 0 # 估算的内存使用(字节)
- cache_hits: int = 0
- cache_misses: int = 0
- evictions: int = 0
- last_cleanup_time: float = 0
-
-
-class TieredStreamCache:
- """分层流缓存管理器"""
-
- def __init__(
- self,
- max_hot_size: int = 100,
- max_warm_size: int = 500,
- max_cold_size: int = 2000,
- cleanup_interval: float = 300.0, # 5分钟清理一次
- hot_timeout: float = 1800.0, # 30分钟未访问降级到warm
- warm_timeout: float = 7200.0, # 2小时未访问降级到cold
- cold_timeout: float = 86400.0, # 24小时未访问删除
- ):
- self.max_hot_size = max_hot_size
- self.max_warm_size = max_warm_size
- self.max_cold_size = max_cold_size
- self.cleanup_interval = cleanup_interval
- self.hot_timeout = hot_timeout
- self.warm_timeout = warm_timeout
- self.cold_timeout = cold_timeout
-
- # 三层缓存存储
- self.hot_cache: OrderedDict[str, OptimizedChatStream] = OrderedDict() # 热数据(LRU)
- self.warm_storage: dict[str, tuple[OptimizedChatStream, float]] = {} # 温数据(最后访问时间)
- self.cold_storage: dict[str, tuple[OptimizedChatStream, float]] = {} # 冷数据(最后访问时间)
-
- # 统计信息
- self.stats = StreamCacheStats()
-
- # 清理任务
- self.cleanup_task: asyncio.Task | None = None
- self.is_running = False
-
- logger.info(f"分层流缓存管理器初始化完成 (hot:{max_hot_size}, warm:{max_warm_size}, cold:{max_cold_size})")
-
- async def start(self):
- """启动缓存管理器"""
- if self.is_running:
- logger.warning("缓存管理器已经在运行")
- return
-
- self.is_running = True
- self.cleanup_task = asyncio.create_task(self._cleanup_loop(), name="stream_cache_cleanup")
-
- async def stop(self):
- """停止缓存管理器"""
- if not self.is_running:
- return
-
- self.is_running = False
-
- if self.cleanup_task and not self.cleanup_task.done():
- self.cleanup_task.cancel()
- try:
- await asyncio.wait_for(self.cleanup_task, timeout=10.0)
- except asyncio.TimeoutError:
- logger.warning("缓存清理任务停止超时")
- except Exception as e:
- logger.error(f"停止缓存清理任务时出错: {e}")
-
- logger.info("分层流缓存管理器已停止")
-
- async def get_or_create_stream(
- self,
- stream_id: str,
- platform: str,
- user_info: UserInfo,
- group_info: GroupInfo | None = None,
- data: dict | None = None,
- ) -> OptimizedChatStream:
- """获取或创建流 - 优化版本"""
- current_time = time.time()
-
- # 1. 检查热缓存
- if stream_id in self.hot_cache:
- stream = self.hot_cache[stream_id]
- # 移动到末尾(LRU更新)
- self.hot_cache.move_to_end(stream_id)
- self.stats.cache_hits += 1
- logger.debug(f"热缓存命中: {stream_id}")
- return stream.create_snapshot()
-
- # 2. 检查温存储
- if stream_id in self.warm_storage:
- stream, last_access = self.warm_storage[stream_id]
- self.warm_storage[stream_id] = (stream, current_time)
- self.stats.cache_hits += 1
- logger.debug(f"温缓存命中: {stream_id}")
- # 提升到热缓存
- await self._promote_to_hot(stream_id, stream)
- return stream.create_snapshot()
-
- # 3. 检查冷存储
- if stream_id in self.cold_storage:
- stream, last_access = self.cold_storage[stream_id]
- self.cold_storage[stream_id] = (stream, current_time)
- self.stats.cache_hits += 1
- logger.debug(f"冷缓存命中: {stream_id}")
- # 提升到温缓存
- await self._promote_to_warm(stream_id, stream)
- return stream.create_snapshot()
-
- # 4. 缓存未命中,创建新流
- self.stats.cache_misses += 1
- stream = create_optimized_chat_stream(
- stream_id=stream_id, platform=platform, user_info=user_info, group_info=group_info, data=data
- )
- logger.debug(f"缓存未命中,创建新流: {stream_id}")
-
- # 添加到热缓存
- await self._add_to_hot(stream_id, stream)
-
- return stream
-
- async def _add_to_hot(self, stream_id: str, stream: OptimizedChatStream):
- """添加到热缓存"""
- # 检查是否需要驱逐
- if len(self.hot_cache) >= self.max_hot_size:
- await self._evict_from_hot()
-
- self.hot_cache[stream_id] = stream
- self.stats.hot_cache_size = len(self.hot_cache)
-
- async def _promote_to_hot(self, stream_id: str, stream: OptimizedChatStream):
- """提升到热缓存"""
- # 从温存储中移除
- if stream_id in self.warm_storage:
- del self.warm_storage[stream_id]
- self.stats.warm_storage_size = len(self.warm_storage)
-
- # 添加到热缓存
- await self._add_to_hot(stream_id, stream)
- logger.debug(f"流 {stream_id} 提升到热缓存")
-
- async def _promote_to_warm(self, stream_id: str, stream: OptimizedChatStream):
- """提升到温缓存"""
- # 从冷存储中移除
- if stream_id in self.cold_storage:
- del self.cold_storage[stream_id]
- self.stats.cold_storage_size = len(self.cold_storage)
-
- # 添加到温存储
- if len(self.warm_storage) >= self.max_warm_size:
- await self._evict_from_warm()
-
- current_time = time.time()
- self.warm_storage[stream_id] = (stream, current_time)
- self.stats.warm_storage_size = len(self.warm_storage)
- logger.debug(f"流 {stream_id} 提升到温缓存")
-
- async def _evict_from_hot(self):
- """从热缓存驱逐最久未使用的流"""
- if not self.hot_cache:
- return
-
- # LRU驱逐
- stream_id, stream = self.hot_cache.popitem(last=False)
- self.stats.evictions += 1
- logger.debug(f"从热缓存驱逐: {stream_id}")
-
- # 移动到温存储
- if len(self.warm_storage) < self.max_warm_size:
- current_time = time.time()
- self.warm_storage[stream_id] = (stream, current_time)
- self.stats.warm_storage_size = len(self.warm_storage)
- else:
- # 温存储也满了,直接删除
- logger.debug(f"温存储已满,删除流: {stream_id}")
-
- self.stats.hot_cache_size = len(self.hot_cache)
-
- async def _evict_from_warm(self):
- """从温存储驱逐最久未使用的流"""
- if not self.warm_storage:
- return
-
- # 找到最久未访问的流
- oldest_stream_id = min(self.warm_storage.keys(), key=lambda k: self.warm_storage[k][1])
- stream, last_access = self.warm_storage.pop(oldest_stream_id)
- self.stats.evictions += 1
- logger.debug(f"从温存储驱逐: {oldest_stream_id}")
-
- # 移动到冷存储
- if len(self.cold_storage) < self.max_cold_size:
- current_time = time.time()
- self.cold_storage[oldest_stream_id] = (stream, current_time)
- self.stats.cold_storage_size = len(self.cold_storage)
- else:
- # 冷存储也满了,直接删除
- logger.debug(f"冷存储已满,删除流: {oldest_stream_id}")
-
- self.stats.warm_storage_size = len(self.warm_storage)
-
- async def _cleanup_loop(self):
- """清理循环"""
- logger.info("流缓存清理循环启动")
-
- while self.is_running:
- try:
- await asyncio.sleep(self.cleanup_interval)
- await self._perform_cleanup()
- except asyncio.CancelledError:
- logger.info("流缓存清理循环被取消")
- break
- except Exception as e:
- logger.error(f"流缓存清理出错: {e}")
-
- logger.info("流缓存清理循环结束")
-
- async def _perform_cleanup(self):
- """执行清理操作"""
- current_time = time.time()
- cleanup_stats = {
- "hot_to_warm": 0,
- "warm_to_cold": 0,
- "cold_removed": 0,
- }
-
- # 1. 检查热缓存超时
- hot_to_demote = []
- for stream_id, stream in self.hot_cache.items():
- # 获取最后访问时间(简化:使用创建时间作为近似)
- last_access = getattr(stream, "last_active_time", stream.create_time)
- if current_time - last_access > self.hot_timeout:
- hot_to_demote.append(stream_id)
-
- for stream_id in hot_to_demote:
- stream = self.hot_cache.pop(stream_id)
- current_time_local = time.time()
- self.warm_storage[stream_id] = (stream, current_time_local)
- cleanup_stats["hot_to_warm"] += 1
-
- # 2. 检查温存储超时
- warm_to_demote = []
- for stream_id, (stream, last_access) in self.warm_storage.items():
- if current_time - last_access > self.warm_timeout:
- warm_to_demote.append(stream_id)
-
- for stream_id in warm_to_demote:
- stream, last_access = self.warm_storage.pop(stream_id)
- self.cold_storage[stream_id] = (stream, last_access)
- cleanup_stats["warm_to_cold"] += 1
-
- # 3. 检查冷存储超时
- cold_to_remove = []
- for stream_id, (stream, last_access) in self.cold_storage.items():
- if current_time - last_access > self.cold_timeout:
- cold_to_remove.append(stream_id)
-
- for stream_id in cold_to_remove:
- self.cold_storage.pop(stream_id)
- cleanup_stats["cold_removed"] += 1
-
- # 更新统计信息
- self.stats.hot_cache_size = len(self.hot_cache)
- self.stats.warm_storage_size = len(self.warm_storage)
- self.stats.cold_storage_size = len(self.cold_storage)
- self.stats.last_cleanup_time = current_time
-
- # 估算内存使用(粗略估计)
- self.stats.total_memory_usage = (
- len(self.hot_cache) * 1024 # 每个热流约1KB
- + len(self.warm_storage) * 512 # 每个温流约512B
- + len(self.cold_storage) * 256 # 每个冷流约256B
- )
-
- if sum(cleanup_stats.values()) > 0:
- logger.info(
- f"缓存清理完成: {cleanup_stats['hot_to_warm']}热→温, "
- f"{cleanup_stats['warm_to_cold']}温→冷, "
- f"{cleanup_stats['cold_removed']}冷删除"
- )
-
- def get_stats(self) -> StreamCacheStats:
- """获取缓存统计信息"""
- # 计算命中率
- total_requests = self.stats.cache_hits + self.stats.cache_misses
- hit_rate = self.stats.cache_hits / total_requests if total_requests > 0 else 0
-
- stats_copy = StreamCacheStats(
- hot_cache_size=self.stats.hot_cache_size,
- warm_storage_size=self.stats.warm_storage_size,
- cold_storage_size=self.stats.cold_storage_size,
- total_memory_usage=self.stats.total_memory_usage,
- cache_hits=self.stats.cache_hits,
- cache_misses=self.stats.cache_misses,
- evictions=self.stats.evictions,
- last_cleanup_time=self.stats.last_cleanup_time,
- )
-
- # 添加命中率信息
- stats_copy.hit_rate = hit_rate
-
- return stats_copy
-
- def clear_cache(self):
- """清空所有缓存"""
- self.hot_cache.clear()
- self.warm_storage.clear()
- self.cold_storage.clear()
-
- self.stats.hot_cache_size = 0
- self.stats.warm_storage_size = 0
- self.stats.cold_storage_size = 0
- self.stats.total_memory_usage = 0
-
- logger.info("所有缓存已清空")
-
- async def get_stream_snapshot(self, stream_id: str) -> OptimizedChatStream | None:
- """获取流的快照(不修改缓存状态)"""
- if stream_id in self.hot_cache:
- return self.hot_cache[stream_id].create_snapshot()
- elif stream_id in self.warm_storage:
- return self.warm_storage[stream_id][0].create_snapshot()
- elif stream_id in self.cold_storage:
- return self.cold_storage[stream_id][0].create_snapshot()
- return None
-
- def get_cached_stream_ids(self) -> set[str]:
- """获取所有缓存的流ID"""
- return set(self.hot_cache.keys()) | set(self.warm_storage.keys()) | set(self.cold_storage.keys())
-
-
-# 全局缓存管理器实例
-_cache_manager: TieredStreamCache | None = None
-
-
-def get_stream_cache_manager() -> TieredStreamCache:
- """获取流缓存管理器实例"""
- global _cache_manager
- if _cache_manager is None:
- _cache_manager = TieredStreamCache()
- return _cache_manager
-
-
-async def init_stream_cache_manager():
- """初始化流缓存管理器"""
- manager = get_stream_cache_manager()
- await manager.start()
-
-
-async def shutdown_stream_cache_manager():
- """关闭流缓存管理器"""
- manager = get_stream_cache_manager()
- await manager.stop()
diff --git a/src/chat/message_receive/bot.py b/src/chat/message_receive/bot.py
index 1096852cf..32e8c90e5 100644
--- a/src/chat/message_receive/bot.py
+++ b/src/chat/message_receive/bot.py
@@ -9,10 +9,10 @@ from maim_message import UserInfo
from src.chat.antipromptinjector import initialize_anti_injector
from src.chat.message_manager import message_manager
from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager
-from src.chat.message_receive.message import MessageRecv
from src.chat.message_receive.storage import MessageStorage
from src.chat.utils.prompt import create_prompt_async, global_prompt_manager
from src.chat.utils.utils import is_mentioned_bot_in_message
+from src.common.data_models.database_data_model import DatabaseMessages
from src.common.logger import get_logger
from src.config.config import global_config
from src.mood.mood_manager import mood_manager # 导入情绪管理器
@@ -105,10 +105,10 @@ class ChatBot:
self._started = True
- async def _process_plus_commands(self, message: MessageRecv):
+ async def _process_plus_commands(self, message: DatabaseMessages, chat: ChatStream):
"""独立处理PlusCommand系统"""
try:
- text = message.processed_plain_text
+ text = message.processed_plain_text or ""
# 获取配置的命令前缀
from src.config.config import global_config
@@ -166,10 +166,10 @@ class ChatBot:
# 检查命令是否被禁用
if (
- message.chat_stream
- and message.chat_stream.stream_id
+ chat
+ and chat.stream_id
and plus_command_name
- in global_announcement_manager.get_disabled_chat_commands(message.chat_stream.stream_id)
+ in global_announcement_manager.get_disabled_chat_commands(chat.stream_id)
):
logger.info("用户禁用的PlusCommand,跳过处理")
return False, None, True
@@ -181,11 +181,14 @@ class ChatBot:
# 创建PlusCommand实例
plus_command_instance = plus_command_class(message, plugin_config)
+
+ # 为插件实例设置 chat_stream 运行时属性
+ setattr(plus_command_instance, "chat_stream", chat)
try:
# 检查聊天类型限制
if not plus_command_instance.is_chat_type_allowed():
- is_group = message.message_info.group_info
+ is_group = chat.group_info is not None
logger.info(
f"PlusCommand {plus_command_class.__name__} 不支持当前聊天类型: {'群聊' if is_group else '私聊'}"
)
@@ -225,11 +228,11 @@ class ChatBot:
logger.error(f"处理PlusCommand时出错: {e}")
return False, None, True # 出错时继续处理消息
- async def _process_commands_with_new_system(self, message: MessageRecv):
+ async def _process_commands_with_new_system(self, message: DatabaseMessages, chat: ChatStream):
# sourcery skip: use-named-expression
"""使用新插件系统处理命令"""
try:
- text = message.processed_plain_text
+ text = message.processed_plain_text or ""
# 使用新的组件注册中心查找命令
command_result = component_registry.find_command_by_text(text)
@@ -238,10 +241,10 @@ class ChatBot:
plugin_name = command_info.plugin_name
command_name = command_info.name
if (
- message.chat_stream
- and message.chat_stream.stream_id
+ chat
+ and chat.stream_id
and command_name
- in global_announcement_manager.get_disabled_chat_commands(message.chat_stream.stream_id)
+ in global_announcement_manager.get_disabled_chat_commands(chat.stream_id)
):
logger.info("用户禁用的命令,跳过处理")
return False, None, True
@@ -254,11 +257,14 @@ class ChatBot:
# 创建命令实例
command_instance: BaseCommand = command_class(message, plugin_config)
command_instance.set_matched_groups(matched_groups)
+
+ # 为插件实例设置 chat_stream 运行时属性
+ setattr(command_instance, "chat_stream", chat)
try:
# 检查聊天类型限制
if not command_instance.is_chat_type_allowed():
- is_group = message.message_info.group_info
+ is_group = chat.group_info is not None
logger.info(
f"命令 {command_class.__name__} 不支持当前聊天类型: {'群聊' if is_group else '私聊'}"
)
@@ -295,13 +301,20 @@ class ChatBot:
logger.error(f"处理命令时出错: {e}")
return False, None, True # 出错时继续处理消息
- async def handle_notice_message(self, message: MessageRecv):
+ async def handle_notice_message(self, message: DatabaseMessages):
"""处理notice消息
notice消息是系统事件通知(如禁言、戳一戳等),具有以下特点:
1. 默认不触发聊天流程,只记录
2. 可通过配置开启触发聊天流程
3. 会在提示词中展示
+
+ Args:
+ message: DatabaseMessages 对象
+
+ Returns:
+ bool: True表示notice已完整处理(需要存储并终止后续流程)
+ False表示不是notice或notice需要继续处理(触发聊天流程)
"""
# 检查是否是notice消息
if message.is_notify:
@@ -309,53 +322,42 @@ class ChatBot:
# 根据配置决定是否触发聊天流程
if not global_config.notice.enable_notice_trigger_chat:
- logger.debug("notice消息不触发聊天流程(配置已关闭)")
- return True # 返回True表示已处理,不继续后续流程
+ logger.debug("notice消息不触发聊天流程(配置已关闭),将存储后终止")
+ return True # 返回True:需要在调用处存储并终止
else:
- logger.debug("notice消息触发聊天流程(配置已开启)")
- return False # 返回False表示继续处理,触发聊天流程
+ logger.debug("notice消息触发聊天流程(配置已开启),继续处理")
+ return False # 返回False:继续正常流程,作为普通消息处理
# 兼容旧的notice判断方式
- if message.message_info.message_id == "notice":
- message.is_notify = True
+ if message.message_id == "notice":
+ # 为 DatabaseMessages 设置 is_notify 运行时属性
+ from src.chat.message_receive.message_processor import set_db_message_runtime_attr
+ set_db_message_runtime_attr(message, "is_notify", True)
logger.info("旧格式notice消息")
# 同样根据配置决定
if not global_config.notice.enable_notice_trigger_chat:
- return True
+ logger.debug("旧格式notice消息不触发聊天流程,将存储后终止")
+ return True # 需要存储并终止
else:
- return False
+ logger.debug("旧格式notice消息触发聊天流程,继续处理")
+ return False # 继续正常流程
- # 处理适配器响应消息
- if hasattr(message, "message_segment") and message.message_segment:
- if message.message_segment.type == "adapter_response":
- await self.handle_adapter_response(message)
- return True
- elif message.message_segment.type == "adapter_command":
- # 适配器命令消息不需要进一步处理
- logger.debug("收到适配器命令消息,跳过后续处理")
- return True
+ # DatabaseMessages 不再有 message_segment,适配器响应处理已在消息处理阶段完成
+ # 这里保留逻辑以防万一,但实际上不会再执行到
+ return False # 不是notice消息,继续正常流程
- return False
-
- async def handle_adapter_response(self, message: MessageRecv):
- """处理适配器命令响应"""
+ async def handle_adapter_response(self, message: DatabaseMessages):
+ """处理适配器命令响应
+
+ 注意: 此方法目前未被调用,但保留以备将来使用
+ """
try:
from src.plugin_system.apis.send_api import put_adapter_response
- seg_data = message.message_segment.data
- if isinstance(seg_data, dict):
- request_id = seg_data.get("request_id")
- response_data = seg_data.get("response")
- else:
- request_id = None
- response_data = None
-
- if request_id and response_data:
- logger.debug(f"收到适配器响应: request_id={request_id}")
- put_adapter_response(request_id, response_data)
- else:
- logger.warning("适配器响应消息格式不正确")
+ # DatabaseMessages 使用 message_segments 字段存储消息段
+ # 注意: 这可能需要根据实际使用情况进行调整
+ logger.warning("handle_adapter_response 方法被调用,但目前未实现对 DatabaseMessages 的支持")
except Exception as e:
logger.error(f"处理适配器响应时出错: {e}")
@@ -381,9 +383,6 @@ class ChatBot:
await self._ensure_started()
# 控制握手等消息可能缺少 message_info,这里直接跳过避免 KeyError
- if not isinstance(message_data, dict):
- logger.warning(f"收到无法解析的消息类型: {type(message_data)},已跳过")
- return
message_info = message_data.get("message_info")
if not isinstance(message_info, dict):
logger.debug(
@@ -392,8 +391,6 @@ class ChatBot:
)
return
- platform = message_info.get("platform")
-
if message_info.get("group_info") is not None:
message_info["group_info"]["group_id"] = str(
message_info["group_info"]["group_id"]
@@ -404,74 +401,94 @@ class ChatBot:
)
# print(message_data)
# logger.debug(str(message_data))
- message = MessageRecv(message_data)
-
- group_info = message.message_info.group_info
- user_info = message.message_info.user_info
- if message.message_info.additional_config:
- sent_message = message.message_info.additional_config.get("echo", False)
+
+ # 先提取基础信息检查是否是自身消息上报
+ from maim_message import BaseMessageInfo
+ temp_message_info = BaseMessageInfo.from_dict(message_data.get("message_info", {}))
+ if temp_message_info.additional_config:
+ sent_message = temp_message_info.additional_config.get("echo", False)
if sent_message: # 这一段只是为了在一切处理前劫持上报的自身消息,用于更新message_id,需要ada支持上报事件,实际测试中不会对正常使用造成任何问题
- await MessageStorage.update_message(message)
+ # 直接使用消息字典更新,不再需要创建 MessageRecv
+ await MessageStorage.update_message(message_data)
return
+
+ group_info = temp_message_info.group_info
+ user_info = temp_message_info.user_info
- get_chat_manager().register_message(message)
-
+ # 获取或创建聊天流
chat = await get_chat_manager().get_or_create_stream(
- platform=message.message_info.platform, # type: ignore
+ platform=temp_message_info.platform, # type: ignore
user_info=user_info, # type: ignore
group_info=group_info,
)
- message.update_chat_stream(chat)
-
- # 处理消息内容,生成纯文本
- await message.process()
-
+ # 使用新的消息处理器直接生成 DatabaseMessages
+ from src.chat.message_receive.message_processor import process_message_from_dict
+ message = await process_message_from_dict(
+ message_dict=message_data,
+ stream_id=chat.stream_id,
+ platform=chat.platform
+ )
+
+ # 填充聊天流时间信息
+ message.chat_info.create_time = chat.create_time
+ message.chat_info.last_active_time = chat.last_active_time
+
+ # 注册消息到聊天管理器
+ get_chat_manager().register_message(message)
+
+ # 检测是否提及机器人
message.is_mentioned, _ = is_mentioned_bot_in_message(message)
# 在这里打印[所见]日志,确保在所有处理和过滤之前记录
chat_name = chat.group_info.group_name if chat.group_info else "私聊"
- if message.message_info.user_info:
- logger.info(
- f"[{chat_name}]{message.message_info.user_info.user_nickname}:{message.processed_plain_text}\u001b[0m"
- )
+ user_nickname = message.user_info.user_nickname if message.user_info else "未知用户"
+ logger.info(
+ f"[{chat_name}]{user_nickname}:{message.processed_plain_text}\u001b[0m"
+ )
# 在此添加硬编码过滤,防止回复图片处理失败的消息
failure_keywords = ["[表情包(描述生成失败)]", "[图片(描述生成失败)]"]
- if any(keyword in message.processed_plain_text for keyword in failure_keywords):
- logger.info(f"[硬编码过滤] 检测到媒体内容处理失败({message.processed_plain_text}),消息被静默处理。")
+ processed_text = message.processed_plain_text or ""
+ if any(keyword in processed_text for keyword in failure_keywords):
+ logger.info(f"[硬编码过滤] 检测到媒体内容处理失败({processed_text}),消息被静默处理。")
return
# 处理notice消息
+ # notice_handled=True: 表示notice不触发聊天,需要在此存储并终止
+ # notice_handled=False: 表示notice触发聊天或不是notice,继续正常流程
notice_handled = await self.handle_notice_message(message)
if notice_handled:
- # notice消息已处理,使用统一的转换方法
+ # notice消息不触发聊天流程,在此进行存储和记录后终止
try:
- # 直接转换为 DatabaseMessages
- db_message = message.to_database_message()
-
+ # message 已经是 DatabaseMessages,直接使用
# 添加到message_manager(这会将notice添加到全局notice管理器)
- await message_manager.add_message(message.chat_stream.stream_id, db_message)
- logger.info(f"✅ Notice消息已添加到message_manager: type={message.notice_type}, stream={message.chat_stream.stream_id}")
+ await message_manager.add_message(chat.stream_id, message)
+ logger.info(f"✅ Notice消息已添加到message_manager: type={message.notice_type}, stream={chat.stream_id}")
except Exception as e:
logger.error(f"Notice消息添加到message_manager失败: {e}", exc_info=True)
- # 存储后直接返回
- await MessageStorage.store_message(message, chat)
- logger.debug("notice消息已存储,跳过后续处理")
+ # 存储notice消息到数据库(需要更新 storage.py 支持 DatabaseMessages)
+ # 暂时跳过存储,等待更新 storage.py
+ logger.debug("notice消息已添加到message_manager(存储功能待更新)")
return
+
+ # 如果notice_handled=False,则继续执行后续流程
+ # 对于启用触发聊天的notice,会在后续的正常流程中被存储和处理
# 过滤检查
+ # DatabaseMessages 使用 display_message 作为原始消息表示
+ raw_text = message.display_message or message.processed_plain_text or ""
if _check_ban_words(message.processed_plain_text, chat, user_info) or _check_ban_regex( # type: ignore
- message.raw_message, # type: ignore
+ raw_text,
chat,
user_info, # type: ignore
):
return
# 命令处理 - 首先尝试PlusCommand独立处理
- is_plus_command, plus_cmd_result, plus_continue_process = await self._process_plus_commands(message)
+ is_plus_command, plus_cmd_result, plus_continue_process = await self._process_plus_commands(message, chat)
# 如果是PlusCommand且不需要继续处理,则直接返回
if is_plus_command and not plus_continue_process:
@@ -481,7 +498,7 @@ class ChatBot:
# 如果不是PlusCommand,尝试传统的BaseCommand处理
if not is_plus_command:
- is_command, cmd_result, continue_process = await self._process_commands_with_new_system(message)
+ is_command, cmd_result, continue_process = await self._process_commands_with_new_system(message, chat)
# 如果是命令且不需要继续处理,则直接返回
if is_command and not continue_process:
@@ -493,24 +510,14 @@ class ChatBot:
if result and not result.all_continue_process():
raise UserWarning(f"插件{result.get_summary().get('stopped_handlers', '')}于消息到达时取消了消息处理")
- # TODO:暂不可用
+ # TODO:暂不可用 - DatabaseMessages 不再有 message_info.template_info
# 确认从接口发来的message是否有自定义的prompt模板信息
- if message.message_info.template_info and not message.message_info.template_info.template_default:
- template_group_name: str | None = message.message_info.template_info.template_name # type: ignore
- template_items = message.message_info.template_info.template_items
- async with global_prompt_manager.async_message_scope(template_group_name):
- if isinstance(template_items, dict):
- for k in template_items.keys():
- await create_prompt_async(template_items[k], k)
- logger.debug(f"注册{template_items[k]},{k}")
- else:
- template_group_name = None
+ # 这个功能需要在 adapter 层通过 additional_config 传递
+ template_group_name = None
async def preprocess():
- # 使用统一的转换方法创建数据库消息对象
- db_message = message.to_database_message()
-
- group_info = getattr(message.chat_stream, "group_info", None)
+ # message 已经是 DatabaseMessages,直接使用
+ group_info = chat.group_info
# 先交给消息管理器处理,计算兴趣度等衍生数据
try:
@@ -527,31 +534,15 @@ class ChatBot:
should_process_in_manager = False
if should_process_in_manager:
- await message_manager.add_message(message.chat_stream.stream_id, db_message)
- logger.debug(f"消息已添加到消息管理器: {message.chat_stream.stream_id}")
+ await message_manager.add_message(chat.stream_id, message)
+ logger.debug(f"消息已添加到消息管理器: {chat.stream_id}")
except Exception as e:
logger.error(f"消息添加到消息管理器失败: {e}")
- # 将兴趣度结果同步回原始消息,便于后续流程使用
- message.interest_value = getattr(db_message, "interest_value", getattr(message, "interest_value", 0.0))
- setattr(
- message,
- "should_reply",
- getattr(db_message, "should_reply", getattr(message, "should_reply", False)),
- )
- setattr(message, "should_act", getattr(db_message, "should_act", getattr(message, "should_act", False)))
-
# 存储消息到数据库,只进行一次写入
try:
- await MessageStorage.store_message(message, message.chat_stream)
- logger.debug(
- "消息已存储到数据库: %s (interest=%.3f, should_reply=%s, should_act=%s)",
- message.message_info.message_id,
- getattr(message, "interest_value", -1.0),
- getattr(message, "should_reply", None),
- getattr(message, "should_act", None),
- )
+ await MessageStorage.store_message(message, chat)
except Exception as e:
logger.error(f"存储消息到数据库失败: {e}")
traceback.print_exc()
@@ -560,13 +551,13 @@ class ChatBot:
try:
if global_config.mood.enable_mood:
# 获取兴趣度用于情绪更新
- interest_rate = getattr(message, "interest_value", 0.0)
+ interest_rate = message.interest_value
if interest_rate is None:
interest_rate = 0.0
logger.debug(f"开始更新情绪状态,兴趣度: {interest_rate:.2f}")
# 获取当前聊天的情绪对象并更新情绪状态
- chat_mood = mood_manager.get_mood_by_chat_id(message.chat_stream.stream_id)
+ chat_mood = mood_manager.get_mood_by_chat_id(chat.stream_id)
await chat_mood.update_mood_by_message(message, interest_rate)
logger.debug("情绪状态更新完成")
except Exception as e:
diff --git a/src/chat/message_receive/chat_stream.py b/src/chat/message_receive/chat_stream.py
index c22d755fb..a79b91400 100644
--- a/src/chat/message_receive/chat_stream.py
+++ b/src/chat/message_receive/chat_stream.py
@@ -12,13 +12,10 @@ from sqlalchemy.dialects.sqlite import insert as sqlite_insert
from src.common.database.sqlalchemy_database_api import get_db_session
from src.common.database.sqlalchemy_models import ChatStreams # 新增导入
+from src.common.data_models.database_data_model import DatabaseMessages
from src.common.logger import get_logger
from src.config.config import global_config # 新增导入
-# 避免循环导入,使用TYPE_CHECKING进行类型提示
-if TYPE_CHECKING:
- from .message import MessageRecv
-
install(extra_lines=3)
@@ -33,7 +30,7 @@ class ChatStream:
self,
stream_id: str,
platform: str,
- user_info: UserInfo,
+ user_info: UserInfo | None = None,
group_info: GroupInfo | None = None,
data: dict | None = None,
):
@@ -46,20 +43,18 @@ class ChatStream:
self.sleep_pressure = data.get("sleep_pressure", 0.0) if data else 0.0
self.saved = False
- # 使用StreamContext替代ChatMessageContext
+ # 创建单流上下文管理器(包含StreamContext)
+ from src.chat.message_manager.context_manager import SingleStreamContextManager
from src.common.data_models.message_manager_data_model import StreamContext
from src.plugin_system.base.component_types import ChatMode, ChatType
- # 创建StreamContext
- self.stream_context: StreamContext = StreamContext(
- stream_id=stream_id, chat_type=ChatType.GROUP if group_info else ChatType.PRIVATE, chat_mode=ChatMode.NORMAL
- )
-
- # 创建单流上下文管理器
- from src.chat.message_manager.context_manager import SingleStreamContextManager
-
self.context_manager: SingleStreamContextManager = SingleStreamContextManager(
- stream_id=stream_id, context=self.stream_context
+ stream_id=stream_id,
+ context=StreamContext(
+ stream_id=stream_id,
+ chat_type=ChatType.GROUP if group_info else ChatType.PRIVATE,
+ chat_mode=ChatMode.NORMAL,
+ ),
)
# 基础参数
@@ -88,13 +83,12 @@ class ChatStream:
new_stream._focus_energy = self._focus_energy
new_stream.no_reply_consecutive = self.no_reply_consecutive
- # 复制 stream_context,但跳过 processing_task
- new_stream.stream_context = copy.deepcopy(self.stream_context, memo)
- if hasattr(new_stream.stream_context, "processing_task"):
- new_stream.stream_context.processing_task = None
-
- # 复制 context_manager
+ # 复制 context_manager(包含 stream_context)
new_stream.context_manager = copy.deepcopy(self.context_manager, memo)
+
+ # 清理 processing_task(如果存在)
+ if hasattr(new_stream.context_manager.context, "processing_task"):
+ new_stream.context_manager.context.processing_task = None
return new_stream
@@ -111,11 +105,11 @@ class ChatStream:
"focus_energy": self.focus_energy,
# 基础兴趣度
"base_interest_energy": self.base_interest_energy,
- # stream_context基本信息
- "stream_context_chat_type": self.stream_context.chat_type.value,
- "stream_context_chat_mode": self.stream_context.chat_mode.value,
+ # stream_context基本信息(通过context_manager访问)
+ "stream_context_chat_type": self.context_manager.context.chat_type.value,
+ "stream_context_chat_mode": self.context_manager.context.chat_mode.value,
# 统计信息
- "interruption_count": self.stream_context.interruption_count,
+ "interruption_count": self.context_manager.context.interruption_count,
}
@classmethod
@@ -132,27 +126,19 @@ class ChatStream:
data=data,
)
- # 恢复stream_context信息
+ # 恢复stream_context信息(通过context_manager访问)
if "stream_context_chat_type" in data:
from src.plugin_system.base.component_types import ChatMode, ChatType
- instance.stream_context.chat_type = ChatType(data["stream_context_chat_type"])
+ instance.context_manager.context.chat_type = ChatType(data["stream_context_chat_type"])
if "stream_context_chat_mode" in data:
from src.plugin_system.base.component_types import ChatMode, ChatType
- instance.stream_context.chat_mode = ChatMode(data["stream_context_chat_mode"])
+ instance.context_manager.context.chat_mode = ChatMode(data["stream_context_chat_mode"])
# 恢复interruption_count信息
if "interruption_count" in data:
- instance.stream_context.interruption_count = data["interruption_count"]
-
- # 确保 context_manager 已初始化
- if not hasattr(instance, "context_manager"):
- from src.chat.message_manager.context_manager import SingleStreamContextManager
-
- instance.context_manager = SingleStreamContextManager(
- stream_id=instance.stream_id, context=instance.stream_context
- )
+ instance.context_manager.context.interruption_count = data["interruption_count"]
return instance
@@ -160,156 +146,44 @@ class ChatStream:
"""获取原始的、未哈希的聊天流ID字符串"""
if self.group_info:
return f"{self.platform}:{self.group_info.group_id}:group"
- else:
+ elif self.user_info:
return f"{self.platform}:{self.user_info.user_id}:private"
+ else:
+ return f"{self.platform}:unknown:private"
def update_active_time(self):
"""更新最后活跃时间"""
self.last_active_time = time.time()
self.saved = False
- async def set_context(self, message: "MessageRecv"):
- """设置聊天消息上下文"""
- # 将MessageRecv转换为DatabaseMessages并设置到stream_context
- import json
-
- from src.common.data_models.database_data_model import DatabaseMessages
-
- # 安全获取message_info中的数据
- message_info = getattr(message, "message_info", {})
- user_info = getattr(message_info, "user_info", {})
- group_info = getattr(message_info, "group_info", {})
-
- # 提取reply_to信息(从message_segment中查找reply类型的段)
- reply_to = None
- if hasattr(message, "message_segment") and message.message_segment:
- reply_to = self._extract_reply_from_segment(message.message_segment)
-
- # 完整的数据转移逻辑
- db_message = DatabaseMessages(
- # 基础消息信息
- message_id=getattr(message, "message_id", ""),
- time=getattr(message, "time", time.time()),
- chat_id=self._generate_chat_id(message_info),
- reply_to=reply_to,
- # 兴趣度相关
- interest_value=getattr(message, "interest_value", 0.0),
- # 关键词
- key_words=json.dumps(getattr(message, "key_words", []), ensure_ascii=False)
- if getattr(message, "key_words", None)
- else None,
- key_words_lite=json.dumps(getattr(message, "key_words_lite", []), ensure_ascii=False)
- if getattr(message, "key_words_lite", None)
- else None,
- # 消息状态标记
- is_mentioned=getattr(message, "is_mentioned", None),
- is_at=getattr(message, "is_at", False),
- is_emoji=getattr(message, "is_emoji", False),
- is_picid=getattr(message, "is_picid", False),
- is_voice=getattr(message, "is_voice", False),
- is_video=getattr(message, "is_video", False),
- is_command=getattr(message, "is_command", False),
- is_notify=getattr(message, "is_notify", False),
- is_public_notice=getattr(message, "is_public_notice", False),
- notice_type=getattr(message, "notice_type", None),
- # 消息内容
- processed_plain_text=getattr(message, "processed_plain_text", ""),
- display_message=getattr(message, "processed_plain_text", ""), # 默认使用processed_plain_text
- # 优先级信息
- priority_mode=getattr(message, "priority_mode", None),
- priority_info=json.dumps(getattr(message, "priority_info", None))
- if getattr(message, "priority_info", None)
- else None,
- # 额外配置 - 需要将 format_info 嵌入到 additional_config 中
- additional_config=self._prepare_additional_config(message_info),
- # 用户信息
- user_id=str(getattr(user_info, "user_id", "")),
- user_nickname=getattr(user_info, "user_nickname", ""),
- user_cardname=getattr(user_info, "user_cardname", None),
- user_platform=getattr(user_info, "platform", ""),
- # 群组信息
- chat_info_group_id=getattr(group_info, "group_id", None),
- chat_info_group_name=getattr(group_info, "group_name", None),
- chat_info_group_platform=getattr(group_info, "platform", None),
- # 聊天流信息
- chat_info_user_id=str(getattr(user_info, "user_id", "")),
- chat_info_user_nickname=getattr(user_info, "user_nickname", ""),
- chat_info_user_cardname=getattr(user_info, "user_cardname", None),
- chat_info_user_platform=getattr(user_info, "platform", ""),
- chat_info_stream_id=self.stream_id,
- chat_info_platform=self.platform,
- chat_info_create_time=self.create_time,
- chat_info_last_active_time=self.last_active_time,
- # 新增兴趣度系统字段 - 添加安全处理
- actions=self._safe_get_actions(message),
- should_reply=getattr(message, "should_reply", False),
- should_act=getattr(message, "should_act", False),
- )
-
- self.stream_context.set_current_message(db_message)
- self.stream_context.priority_mode = getattr(message, "priority_mode", None)
- self.stream_context.priority_info = getattr(message, "priority_info", None)
-
- # 调试日志:记录数据转移情况
- logger.debug(
- f"消息数据转移完成 - message_id: {db_message.message_id}, "
- f"chat_id: {db_message.chat_id}, "
- f"is_mentioned: {db_message.is_mentioned}, "
- f"is_emoji: {db_message.is_emoji}, "
- f"is_picid: {db_message.is_picid}, "
- f"interest_value: {db_message.interest_value}"
- )
-
- def _prepare_additional_config(self, message_info) -> str | None:
- """
- 准备 additional_config,将 format_info 嵌入其中
-
- 这个方法模仿 storage.py 中的逻辑,确保 DatabaseMessages 中的 additional_config
- 包含 format_info,使得 action_modifier 能够正确获取适配器支持的消息类型
-
+ async def set_context(self, message: DatabaseMessages):
+ """设置聊天消息上下文
+
Args:
- message_info: BaseMessageInfo 对象
-
- Returns:
- str | None: JSON 字符串格式的 additional_config,如果为空则返回 None
+ message: DatabaseMessages 对象,直接使用不需要转换
"""
- import orjson
-
- # 首先获取adapter传递的additional_config
- additional_config_data = {}
- if hasattr(message_info, 'additional_config') and message_info.additional_config:
- if isinstance(message_info.additional_config, dict):
- additional_config_data = message_info.additional_config.copy()
- elif isinstance(message_info.additional_config, str):
- # 如果是字符串,尝试解析
- try:
- additional_config_data = orjson.loads(message_info.additional_config)
- except Exception as e:
- logger.warning(f"无法解析 additional_config JSON: {e}")
- additional_config_data = {}
+ # 直接使用传入的 DatabaseMessages,设置到上下文中
+ self.context_manager.context.set_current_message(message)
- # 然后添加format_info到additional_config中
- if hasattr(message_info, 'format_info') and message_info.format_info:
- try:
- format_info_dict = message_info.format_info.to_dict()
- additional_config_data["format_info"] = format_info_dict
- logger.debug(f"嵌入 format_info 到 additional_config: {format_info_dict}")
- except Exception as e:
- logger.warning(f"将 format_info 转换为字典失败: {e}")
- else:
- logger.warning(f"[问题] 消息缺少 format_info: message_id={getattr(message_info, 'message_id', 'unknown')}")
- logger.warning("[问题] 这可能导致 Action 无法正确检查适配器支持的类型")
-
- # 序列化为JSON字符串
- if additional_config_data:
- try:
- return orjson.dumps(additional_config_data).decode("utf-8")
- except Exception as e:
- logger.error(f"序列化 additional_config 失败: {e}")
- return None
- return None
+ # 设置优先级信息(如果存在)
+ priority_mode = getattr(message, "priority_mode", None)
+ priority_info = getattr(message, "priority_info", None)
+ if priority_mode:
+ self.context_manager.context.priority_mode = priority_mode
+ if priority_info:
+ self.context_manager.context.priority_info = priority_info
- def _safe_get_actions(self, message: "MessageRecv") -> list | None:
+ # 调试日志
+ logger.debug(
+ f"消息上下文已设置 - message_id: {message.message_id}, "
+ f"chat_id: {message.chat_id}, "
+ f"is_mentioned: {message.is_mentioned}, "
+ f"is_emoji: {message.is_emoji}, "
+ f"is_picid: {message.is_picid}, "
+ f"interest_value: {message.interest_value}"
+ )
+
+ def _safe_get_actions(self, message: DatabaseMessages) -> list | None:
"""安全获取消息的actions字段"""
import json
@@ -380,23 +254,6 @@ class ChatStream:
if hasattr(db_message, "should_act"):
db_message.should_act = False
- def _extract_reply_from_segment(self, segment) -> str | None:
- """从消息段中提取reply_to信息"""
- try:
- if hasattr(segment, "type") and segment.type == "seglist":
- # 递归搜索seglist中的reply段
- if hasattr(segment, "data") and segment.data:
- for seg in segment.data:
- reply_id = self._extract_reply_from_segment(seg)
- if reply_id:
- return reply_id
- elif hasattr(segment, "type") and segment.type == "reply":
- # 找到reply段,返回message_id
- return str(segment.data) if segment.data else None
- except Exception as e:
- logger.warning(f"提取reply_to信息失败: {e}")
- return None
-
def _generate_chat_id(self, message_info) -> str:
"""生成chat_id,基于群组或用户信息"""
try:
@@ -493,8 +350,10 @@ class ChatManager:
def __init__(self):
if not self._initialized:
+ from src.common.data_models.database_data_model import DatabaseMessages
+
self.streams: dict[str, ChatStream] = {} # stream_id -> ChatStream
- self.last_messages: dict[str, "MessageRecv"] = {} # stream_id -> last_message
+ self.last_messages: dict[str, DatabaseMessages] = {} # stream_id -> last_message
# try:
# async with get_db_session() as session:
# db.connect(reuse_if_open=True)
@@ -528,12 +387,30 @@ class ChatManager:
except Exception as e:
logger.error(f"聊天流自动保存失败: {e!s}")
- def register_message(self, message: "MessageRecv"):
+ def register_message(self, message: DatabaseMessages):
"""注册消息到聊天流"""
+ # 从 DatabaseMessages 提取平台和用户/群组信息
+ from maim_message import UserInfo, GroupInfo
+
+ user_info = UserInfo(
+ platform=message.user_info.platform,
+ user_id=message.user_info.user_id,
+ user_nickname=message.user_info.user_nickname,
+ user_cardname=message.user_info.user_cardname or ""
+ )
+
+ group_info = None
+ if message.group_info:
+ group_info = GroupInfo(
+ platform=message.group_info.group_platform or "",
+ group_id=message.group_info.group_id,
+ group_name=message.group_info.group_name
+ )
+
stream_id = self._generate_stream_id(
- message.message_info.platform, # type: ignore
- message.message_info.user_info,
- message.message_info.group_info,
+ message.chat_info.platform,
+ user_info,
+ group_info,
)
self.last_messages[stream_id] = message
# logger.debug(f"注册消息到聊天流: {stream_id}")
@@ -578,32 +455,6 @@ class ChatManager:
try:
stream_id = self._generate_stream_id(platform, user_info, group_info)
- # 优先使用缓存管理器(优化版本)
- try:
- from src.chat.message_manager.stream_cache_manager import get_stream_cache_manager
-
- cache_manager = get_stream_cache_manager()
-
- if cache_manager.is_running:
- optimized_stream = await cache_manager.get_or_create_stream(
- stream_id=stream_id, platform=platform, user_info=user_info, group_info=group_info
- )
-
- # 设置消息上下文
- from .message import MessageRecv
-
- if stream_id in self.last_messages and isinstance(self.last_messages[stream_id], MessageRecv):
- optimized_stream.set_context(self.last_messages[stream_id])
-
- # 转换为原始ChatStream以保持兼容性
- original_stream = self._convert_to_original_stream(optimized_stream)
-
- return original_stream
-
- except Exception as e:
- logger.debug(f"缓存管理器获取流失败,使用原始方法: {e}")
-
- # 回退到原始方法
# 检查内存中是否存在
if stream_id in self.streams:
stream = self.streams[stream_id]
@@ -615,12 +466,13 @@ class ChatManager:
stream.user_info = user_info
if group_info:
stream.group_info = group_info
- from .message import MessageRecv # 延迟导入,避免循环引用
-
- if stream_id in self.last_messages and isinstance(self.last_messages[stream_id], MessageRecv):
+
+ # 检查是否有最后一条消息(现在使用 DatabaseMessages)
+ from src.common.data_models.database_data_model import DatabaseMessages
+ if stream_id in self.last_messages and isinstance(self.last_messages[stream_id], DatabaseMessages):
await stream.set_context(self.last_messages[stream_id])
else:
- logger.error(f"聊天流 {stream_id} 不在最后消息列表中,可能是新创建的")
+ logger.debug(f"聊天流 {stream_id} 不在最后消息列表中,可能是新创建的或还没有消息")
return stream
# 检查数据库中是否存在
@@ -679,19 +531,27 @@ class ChatManager:
raise e
stream = copy.deepcopy(stream)
- from .message import MessageRecv # 延迟导入,避免循环引用
+ from src.common.data_models.database_data_model import DatabaseMessages
- if stream_id in self.last_messages and isinstance(self.last_messages[stream_id], MessageRecv):
+ if stream_id in self.last_messages and isinstance(self.last_messages[stream_id], DatabaseMessages):
await stream.set_context(self.last_messages[stream_id])
else:
- logger.error(f"聊天流 {stream_id} 不在最后消息列表中,可能是新创建的")
+ logger.debug(f"聊天流 {stream_id} 不在最后消息列表中,可能是新创建的")
# 确保 ChatStream 有自己的 context_manager
if not hasattr(stream, "context_manager"):
- # 创建新的单流上下文管理器
from src.chat.message_manager.context_manager import SingleStreamContextManager
+ from src.common.data_models.message_manager_data_model import StreamContext
+ from src.plugin_system.base.component_types import ChatMode, ChatType
- stream.context_manager = SingleStreamContextManager(stream_id=stream_id, context=stream.stream_context)
+ stream.context_manager = SingleStreamContextManager(
+ stream_id=stream_id,
+ context=StreamContext(
+ stream_id=stream_id,
+ chat_type=ChatType.GROUP if stream.group_info else ChatType.PRIVATE,
+ chat_mode=ChatMode.NORMAL,
+ ),
+ )
# 保存到内存和数据库
self.streams[stream_id] = stream
@@ -700,10 +560,12 @@ class ChatManager:
async def get_stream(self, stream_id: str) -> ChatStream | None:
"""通过stream_id获取聊天流"""
+ from src.common.data_models.database_data_model import DatabaseMessages
+
stream = self.streams.get(stream_id)
if not stream:
return None
- if stream_id in self.last_messages:
+ if stream_id in self.last_messages and isinstance(self.last_messages[stream_id], DatabaseMessages):
await stream.set_context(self.last_messages[stream_id])
return stream
@@ -921,9 +783,16 @@ class ChatManager:
# 确保 ChatStream 有自己的 context_manager
if not hasattr(stream, "context_manager"):
from src.chat.message_manager.context_manager import SingleStreamContextManager
+ from src.common.data_models.message_manager_data_model import StreamContext
+ from src.plugin_system.base.component_types import ChatMode, ChatType
stream.context_manager = SingleStreamContextManager(
- stream_id=stream.stream_id, context=stream.stream_context
+ stream_id=stream.stream_id,
+ context=StreamContext(
+ stream_id=stream.stream_id,
+ chat_type=ChatType.GROUP if stream.group_info else ChatType.PRIVATE,
+ chat_mode=ChatMode.NORMAL,
+ ),
)
except Exception as e:
logger.error(f"从数据库加载所有聊天流失败 (SQLAlchemy): {e}", exc_info=True)
@@ -932,46 +801,6 @@ class ChatManager:
chat_manager = None
-def _convert_to_original_stream(self, optimized_stream) -> "ChatStream":
- """将OptimizedChatStream转换为原始ChatStream以保持兼容性"""
- try:
- # 创建原始ChatStream实例
- original_stream = ChatStream(
- stream_id=optimized_stream.stream_id,
- platform=optimized_stream.platform,
- user_info=optimized_stream._get_effective_user_info(),
- group_info=optimized_stream._get_effective_group_info(),
- )
-
- # 复制状态
- original_stream.create_time = optimized_stream.create_time
- original_stream.last_active_time = optimized_stream.last_active_time
- original_stream.sleep_pressure = optimized_stream.sleep_pressure
- original_stream.base_interest_energy = optimized_stream.base_interest_energy
- original_stream._focus_energy = optimized_stream._focus_energy
- original_stream.no_reply_consecutive = optimized_stream.no_reply_consecutive
- original_stream.saved = optimized_stream.saved
-
- # 复制上下文信息(如果存在)
- if hasattr(optimized_stream, "_stream_context") and optimized_stream._stream_context:
- original_stream.stream_context = optimized_stream._stream_context
-
- if hasattr(optimized_stream, "_context_manager") and optimized_stream._context_manager:
- original_stream.context_manager = optimized_stream._context_manager
-
- return original_stream
-
- except Exception as e:
- logger.error(f"转换OptimizedChatStream失败: {e}")
- # 如果转换失败,创建一个新的原始流
- return ChatStream(
- stream_id=optimized_stream.stream_id,
- platform=optimized_stream.platform,
- user_info=optimized_stream._get_effective_user_info(),
- group_info=optimized_stream._get_effective_group_info(),
- )
-
-
def get_chat_manager():
global chat_manager
if chat_manager is None:
diff --git a/src/chat/message_receive/message.py b/src/chat/message_receive/message.py
index b7603b2f3..9cdaa337f 100644
--- a/src/chat/message_receive/message.py
+++ b/src/chat/message_receive/message.py
@@ -2,7 +2,7 @@ import base64
import time
from abc import ABCMeta, abstractmethod
from dataclasses import dataclass
-from typing import Any, Optional
+from typing import Any, Optional, Union
import urllib3
from maim_message import BaseMessageInfo, MessageBase, Seg, UserInfo
@@ -13,6 +13,7 @@ from src.chat.utils.self_voice_cache import consume_self_voice_text
from src.chat.utils.utils_image import get_image_manager
from src.chat.utils.utils_video import get_video_analyzer, is_video_analysis_available
from src.chat.utils.utils_voice import get_voice_text
+from src.common.data_models.database_data_model import DatabaseMessages
from src.common.logger import get_logger
from src.config.config import global_config
@@ -43,7 +44,7 @@ class Message(MessageBase, metaclass=ABCMeta):
user_info: UserInfo,
message_segment: Seg | None = None,
timestamp: float | None = None,
- reply: Optional["MessageRecv"] = None,
+ reply: Optional["DatabaseMessages"] = None,
processed_plain_text: str = "",
):
# 使用传入的时间戳或当前时间
@@ -95,346 +96,12 @@ class Message(MessageBase, metaclass=ABCMeta):
@dataclass
-class MessageRecv(Message):
- """接收消息类,用于处理从MessageCQ序列化的消息"""
- def __init__(self, message_dict: dict[str, Any]):
- """从MessageCQ的字典初始化
-
- Args:
- message_dict: MessageCQ序列化后的字典
- """
- # Manually initialize attributes from MessageBase and Message
- self.message_info = BaseMessageInfo.from_dict(message_dict.get("message_info", {}))
- self.message_segment = Seg.from_dict(message_dict.get("message_segment", {}))
- self.raw_message = message_dict.get("raw_message")
-
- self.chat_stream = None
- self.reply = None
- self.processed_plain_text = message_dict.get("processed_plain_text", "")
- self.memorized_times = 0
-
- # MessageRecv specific attributes
- self.is_emoji = False
- self.has_emoji = False
- self.is_picid = False
- self.has_picid = False
- self.is_voice = False
- self.is_video = False
- self.is_mentioned = None
- self.is_notify = False # 是否为notice消息
- self.is_public_notice = False # 是否为公共notice
- self.notice_type = None # notice类型
- self.is_at = False
- self.is_command = False
-
- self.priority_mode = "interest"
- self.priority_info = None
- self.interest_value: float = 0.0
-
- self.key_words = []
- self.key_words_lite = []
-
- # 解析additional_config中的notice信息
- if self.message_info.additional_config and isinstance(self.message_info.additional_config, dict):
- self.is_notify = self.message_info.additional_config.get("is_notice", False)
- self.is_public_notice = self.message_info.additional_config.get("is_public_notice", False)
- self.notice_type = self.message_info.additional_config.get("notice_type")
-
- def update_chat_stream(self, chat_stream: "ChatStream"):
- self.chat_stream = chat_stream
-
- def to_database_message(self) -> "DatabaseMessages":
- """将 MessageRecv 转换为 DatabaseMessages 对象
-
- Returns:
- DatabaseMessages: 数据库消息对象
- """
- from src.common.data_models.database_data_model import DatabaseMessages
- import json
- import time
-
- message_info = self.message_info
- msg_user_info = getattr(message_info, "user_info", None)
- stream_user_info = getattr(self.chat_stream, "user_info", None) if self.chat_stream else None
- group_info = getattr(self.chat_stream, "group_info", None) if self.chat_stream else None
-
- message_id = message_info.message_id or ""
- message_time = message_info.time if hasattr(message_info, "time") and message_info.time is not None else time.time()
- is_mentioned = None
- if isinstance(self.is_mentioned, bool):
- is_mentioned = self.is_mentioned
- elif isinstance(self.is_mentioned, int | float):
- is_mentioned = self.is_mentioned != 0
-
- # 提取用户信息
- user_id = ""
- user_nickname = ""
- user_cardname = None
- user_platform = ""
- if msg_user_info:
- user_id = str(getattr(msg_user_info, "user_id", "") or "")
- user_nickname = getattr(msg_user_info, "user_nickname", "") or ""
- user_cardname = getattr(msg_user_info, "user_cardname", None)
- user_platform = getattr(msg_user_info, "platform", "") or ""
- elif stream_user_info:
- user_id = str(getattr(stream_user_info, "user_id", "") or "")
- user_nickname = getattr(stream_user_info, "user_nickname", "") or ""
- user_cardname = getattr(stream_user_info, "user_cardname", None)
- user_platform = getattr(stream_user_info, "platform", "") or ""
-
- # 提取聊天流信息
- chat_user_id = str(getattr(stream_user_info, "user_id", "") or "") if stream_user_info else ""
- chat_user_nickname = getattr(stream_user_info, "user_nickname", "") or "" if stream_user_info else ""
- chat_user_cardname = getattr(stream_user_info, "user_cardname", None) if stream_user_info else None
- chat_user_platform = getattr(stream_user_info, "platform", "") or "" if stream_user_info else ""
-
- group_id = getattr(group_info, "group_id", None) if group_info else None
- group_name = getattr(group_info, "group_name", None) if group_info else None
- group_platform = getattr(group_info, "platform", None) if group_info else None
-
- # 准备 additional_config
- additional_config_str = None
- try:
- import orjson
-
- additional_config_data = {}
-
- # 首先获取adapter传递的additional_config
- if hasattr(message_info, 'additional_config') and message_info.additional_config:
- if isinstance(message_info.additional_config, dict):
- additional_config_data = message_info.additional_config.copy()
- elif isinstance(message_info.additional_config, str):
- try:
- additional_config_data = orjson.loads(message_info.additional_config)
- except Exception as e:
- logger.warning(f"无法解析 additional_config JSON: {e}")
- additional_config_data = {}
-
- # 添加notice相关标志
- if self.is_notify:
- additional_config_data["is_notice"] = True
- additional_config_data["notice_type"] = self.notice_type or "unknown"
- additional_config_data["is_public_notice"] = bool(self.is_public_notice)
-
- # 添加format_info到additional_config中
- if hasattr(message_info, 'format_info') and message_info.format_info:
- try:
- format_info_dict = message_info.format_info.to_dict()
- additional_config_data["format_info"] = format_info_dict
- logger.debug(f"[message.py] 嵌入 format_info 到 additional_config: {format_info_dict}")
- except Exception as e:
- logger.warning(f"将 format_info 转换为字典失败: {e}")
-
- # 序列化为JSON字符串
- if additional_config_data:
- additional_config_str = orjson.dumps(additional_config_data).decode("utf-8")
- except Exception as e:
- logger.error(f"准备 additional_config 失败: {e}")
-
- # 创建数据库消息对象
- db_message = DatabaseMessages(
- message_id=message_id,
- time=float(message_time),
- chat_id=self.chat_stream.stream_id if self.chat_stream else "",
- processed_plain_text=self.processed_plain_text,
- display_message=self.processed_plain_text,
- is_mentioned=is_mentioned,
- is_at=bool(self.is_at) if self.is_at is not None else None,
- is_emoji=bool(self.is_emoji),
- is_picid=bool(self.is_picid),
- is_command=bool(self.is_command),
- is_notify=bool(self.is_notify),
- is_public_notice=bool(self.is_public_notice),
- notice_type=self.notice_type,
- additional_config=additional_config_str,
- user_id=user_id,
- user_nickname=user_nickname,
- user_cardname=user_cardname,
- user_platform=user_platform,
- chat_info_stream_id=self.chat_stream.stream_id if self.chat_stream else "",
- chat_info_platform=self.chat_stream.platform if self.chat_stream else "",
- chat_info_create_time=float(self.chat_stream.create_time) if self.chat_stream else 0.0,
- chat_info_last_active_time=float(self.chat_stream.last_active_time) if self.chat_stream else 0.0,
- chat_info_user_id=chat_user_id,
- chat_info_user_nickname=chat_user_nickname,
- chat_info_user_cardname=chat_user_cardname,
- chat_info_user_platform=chat_user_platform,
- chat_info_group_id=group_id,
- chat_info_group_name=group_name,
- chat_info_group_platform=group_platform,
- )
-
- # 同步兴趣度等衍生属性
- db_message.interest_value = getattr(self, "interest_value", 0.0)
- setattr(db_message, "should_reply", getattr(self, "should_reply", False))
- setattr(db_message, "should_act", getattr(self, "should_act", False))
-
- return db_message
-
- async def process(self) -> None:
- """处理消息内容,生成纯文本和详细文本
-
- 这个方法必须在创建实例后显式调用,因为它包含异步操作。
- """
- self.processed_plain_text = await self._process_message_segments(self.message_segment)
-
- async def _process_single_segment(self, segment: Seg) -> str:
- """处理单个消息段
-
- Args:
- segment: 消息段
-
- Returns:
- str: 处理后的文本
- """
- try:
- if segment.type == "text":
- self.is_picid = False
- self.is_emoji = False
- self.is_video = False
- return segment.data # type: ignore
- elif segment.type == "at":
- self.is_picid = False
- self.is_emoji = False
- self.is_video = False
- # 处理at消息,格式为"昵称:QQ号"
- if isinstance(segment.data, str) and ":" in segment.data:
- nickname, qq_id = segment.data.split(":", 1)
- return f"@{nickname}"
- return f"@{segment.data}" if isinstance(segment.data, str) else "@未知用户"
- elif segment.type == "image":
- # 如果是base64图片数据
- if isinstance(segment.data, str):
- self.has_picid = True
- self.is_picid = True
- self.is_emoji = False
- self.is_video = False
- image_manager = get_image_manager()
- # print(f"segment.data: {segment.data}")
- _, processed_text = await image_manager.process_image(segment.data)
- return processed_text
- return "[发了一张图片,网卡了加载不出来]"
- elif segment.type == "emoji":
- self.has_emoji = True
- self.is_emoji = True
- self.is_picid = False
- self.is_voice = False
- self.is_video = False
- if isinstance(segment.data, str):
- return await get_image_manager().get_emoji_description(segment.data)
- return "[发了一个表情包,网卡了加载不出来]"
- elif segment.type == "voice":
- self.is_picid = False
- self.is_emoji = False
- self.is_voice = True
- self.is_video = False
-
- # 检查消息是否由机器人自己发送
- if self.message_info and self.message_info.user_info and str(self.message_info.user_info.user_id) == str(global_config.bot.qq_account):
- logger.info(f"检测到机器人自身发送的语音消息 (User ID: {self.message_info.user_info.user_id}),尝试从缓存获取文本。")
- if isinstance(segment.data, str):
- cached_text = consume_self_voice_text(segment.data)
- if cached_text:
- logger.info(f"成功从缓存中获取语音文本: '{cached_text[:70]}...'")
- return f"[语音:{cached_text}]"
- else:
- logger.warning("机器人自身语音消息缓存未命中,将回退到标准语音识别。")
-
- # 标准语音识别流程 (也作为缓存未命中的后备方案)
- if isinstance(segment.data, str):
- return await get_voice_text(segment.data)
- return "[发了一段语音,网卡了加载不出来]"
- elif segment.type == "mention_bot":
- self.is_picid = False
- self.is_emoji = False
- self.is_voice = False
- self.is_video = False
- self.is_mentioned = float(segment.data) # type: ignore
- return ""
- elif segment.type == "priority_info":
- self.is_picid = False
- self.is_emoji = False
- self.is_voice = False
- if isinstance(segment.data, dict):
- # 处理优先级信息
- self.priority_mode = "priority"
- self.priority_info = segment.data
- """
- {
- 'message_type': 'vip', # vip or normal
- 'message_priority': 1.0, # 优先级,大为优先,float
- }
- """
- return ""
- elif segment.type == "file":
- if isinstance(segment.data, dict):
- file_name = segment.data.get('name', '未知文件')
- file_size = segment.data.get('size', '未知大小')
- return f"[文件:{file_name} ({file_size}字节)]"
- return "[收到一个文件]"
- elif segment.type == "video":
- self.is_picid = False
- self.is_emoji = False
- self.is_voice = False
- self.is_video = True
- logger.info(f"接收到视频消息,数据类型: {type(segment.data)}")
-
- # 检查视频分析功能是否可用
- if not is_video_analysis_available():
- logger.warning("⚠️ Rust视频处理模块不可用,跳过视频分析")
- return "[视频]"
-
- if global_config.video_analysis.enable:
- logger.info("已启用视频识别,开始识别")
- if isinstance(segment.data, dict):
- try:
- # 从Adapter接收的视频数据
- video_base64 = segment.data.get("base64")
- filename = segment.data.get("filename", "video.mp4")
-
- logger.info(f"视频文件名: {filename}")
- logger.info(f"Base64数据长度: {len(video_base64) if video_base64 else 0}")
-
- if video_base64:
- # 解码base64视频数据
- video_bytes = base64.b64decode(video_base64)
- logger.info(f"解码后视频大小: {len(video_bytes)} 字节")
-
- # 使用video analyzer分析视频
- video_analyzer = get_video_analyzer()
- result = await video_analyzer.analyze_video_from_bytes(
- video_bytes, filename, prompt=global_config.video_analysis.batch_analysis_prompt
- )
-
- logger.info(f"视频分析结果: {result}")
-
- # 返回视频分析结果
- summary = result.get("summary", "")
- if summary:
- return f"[视频内容] {summary}"
- else:
- return "[已收到视频,但分析失败]"
- else:
- logger.warning("视频消息中没有base64数据")
- return "[收到视频消息,但数据异常]"
- except Exception as e:
- logger.error(f"视频处理失败: {e!s}")
- import traceback
-
- logger.error(f"错误详情: {traceback.format_exc()}")
- return "[收到视频,但处理时出现错误]"
- else:
- logger.warning(f"视频消息数据不是字典格式: {type(segment.data)}")
- return "[发了一个视频,但格式不支持]"
- else:
- return ""
- else:
- logger.warning(f"未知的消息段类型: {segment.type}")
- return f"[{segment.type} 消息]"
- except Exception as e:
- logger.error(f"处理消息段失败: {e!s}, 类型: {segment.type}, 数据: {segment.data}")
- return f"[处理失败的{segment.type}消息]"
+# MessageRecv 类已被完全移除,现在统一使用 DatabaseMessages
+# 如需从消息字典创建 DatabaseMessages,请使用:
+# from src.chat.message_receive.message_processor import process_message_from_dict
+#
+# 迁移完成日期: 2025-10-31
@dataclass
@@ -447,7 +114,7 @@ class MessageProcessBase(Message):
chat_stream: "ChatStream",
bot_user_info: UserInfo,
message_segment: Seg | None = None,
- reply: Optional["MessageRecv"] = None,
+ reply: Optional["DatabaseMessages"] = None,
thinking_start_time: float = 0,
timestamp: float | None = None,
):
@@ -548,7 +215,7 @@ class MessageSending(MessageProcessBase):
sender_info: UserInfo | None, # 用来记录发送者信息
message_segment: Seg,
display_message: str = "",
- reply: Optional["MessageRecv"] = None,
+ reply: Optional["DatabaseMessages"] = None,
is_head: bool = False,
is_emoji: bool = False,
thinking_start_time: float = 0,
@@ -567,7 +234,11 @@ class MessageSending(MessageProcessBase):
# 发送状态特有属性
self.sender_info = sender_info
- self.reply_to_message_id = reply.message_info.message_id if reply else None
+ # 从 DatabaseMessages 获取 message_id
+ if reply:
+ self.reply_to_message_id = reply.message_id
+ else:
+ self.reply_to_message_id = None
self.is_head = is_head
self.is_emoji = is_emoji
self.apply_set_reply_logic = apply_set_reply_logic
@@ -582,14 +253,18 @@ class MessageSending(MessageProcessBase):
def build_reply(self):
"""设置回复消息"""
if self.reply:
- self.reply_to_message_id = self.reply.message_info.message_id
- self.message_segment = Seg(
- type="seglist",
- data=[
- Seg(type="reply", data=self.reply.message_info.message_id), # type: ignore
- self.message_segment,
- ],
- )
+ # 从 DatabaseMessages 获取 message_id
+ message_id = self.reply.message_id
+
+ if message_id:
+ self.reply_to_message_id = message_id
+ self.message_segment = Seg(
+ type="seglist",
+ data=[
+ Seg(type="reply", data=message_id), # type: ignore
+ self.message_segment,
+ ],
+ )
async def process(self) -> None:
"""处理消息内容,生成纯文本和详细文本"""
@@ -607,48 +282,5 @@ class MessageSending(MessageProcessBase):
return self.message_info.group_info is None or self.message_info.group_info.group_id is None
-def message_recv_from_dict(message_dict: dict) -> MessageRecv:
- return MessageRecv(message_dict)
-
-def message_from_db_dict(db_dict: dict) -> MessageRecv:
- """从数据库字典创建MessageRecv实例"""
- # 转换扁平的数据库字典为嵌套结构
- message_info_dict = {
- "platform": db_dict.get("chat_info_platform"),
- "message_id": db_dict.get("message_id"),
- "time": db_dict.get("time"),
- "group_info": {
- "platform": db_dict.get("chat_info_group_platform"),
- "group_id": db_dict.get("chat_info_group_id"),
- "group_name": db_dict.get("chat_info_group_name"),
- },
- "user_info": {
- "platform": db_dict.get("user_platform"),
- "user_id": db_dict.get("user_id"),
- "user_nickname": db_dict.get("user_nickname"),
- "user_cardname": db_dict.get("user_cardname"),
- },
- }
-
- processed_text = db_dict.get("processed_plain_text", "")
-
- # 构建 MessageRecv 需要的字典
- recv_dict = {
- "message_info": message_info_dict,
- "message_segment": {"type": "text", "data": processed_text}, # 从纯文本重建消息段
- "raw_message": None, # 数据库中未存储原始消息
- "processed_plain_text": processed_text,
- }
-
- # 创建 MessageRecv 实例
- msg = MessageRecv(recv_dict)
-
- # 从数据库字典中填充其他可选字段
- msg.interest_value = db_dict.get("interest_value", 0.0)
- msg.is_mentioned = db_dict.get("is_mentioned")
- msg.priority_mode = db_dict.get("priority_mode", "interest")
- msg.priority_info = db_dict.get("priority_info")
- msg.is_emoji = db_dict.get("is_emoji", False)
- msg.is_picid = db_dict.get("is_picid", False)
-
- return msg
+# message_recv_from_dict 和 message_from_db_dict 函数已被移除
+# 请使用: from src.chat.message_receive.message_processor import process_message_from_dict
diff --git a/src/chat/message_receive/message_processor.py b/src/chat/message_receive/message_processor.py
new file mode 100644
index 000000000..5da582710
--- /dev/null
+++ b/src/chat/message_receive/message_processor.py
@@ -0,0 +1,493 @@
+"""消息处理工具模块
+将原 MessageRecv 的消息处理逻辑提取为独立函数,
+直接从适配器消息字典生成 DatabaseMessages
+"""
+import base64
+import time
+from typing import Any
+
+import orjson
+from maim_message import BaseMessageInfo, Seg
+
+from src.chat.utils.self_voice_cache import consume_self_voice_text
+from src.chat.utils.utils_image import get_image_manager
+from src.chat.utils.utils_video import get_video_analyzer, is_video_analysis_available
+from src.chat.utils.utils_voice import get_voice_text
+from src.common.data_models.database_data_model import DatabaseMessages
+from src.common.logger import get_logger
+from src.config.config import global_config
+
+logger = get_logger("message_processor")
+
+
+async def process_message_from_dict(message_dict: dict[str, Any], stream_id: str, platform: str) -> DatabaseMessages:
+ """从适配器消息字典处理并生成 DatabaseMessages
+
+ 这个函数整合了原 MessageRecv 的所有处理逻辑:
+ 1. 解析 message_segment 并异步处理内容(图片、语音、视频等)
+ 2. 提取所有消息元数据
+ 3. 直接构造 DatabaseMessages 对象
+
+ Args:
+ message_dict: MessageCQ序列化后的字典
+ stream_id: 聊天流ID
+ platform: 平台标识
+
+ Returns:
+ DatabaseMessages: 处理完成的数据库消息对象
+ """
+ # 解析基础信息
+ message_info = BaseMessageInfo.from_dict(message_dict.get("message_info", {}))
+ message_segment = Seg.from_dict(message_dict.get("message_segment", {}))
+
+ # 初始化处理状态
+ processing_state = {
+ "is_emoji": False,
+ "has_emoji": False,
+ "is_picid": False,
+ "has_picid": False,
+ "is_voice": False,
+ "is_video": False,
+ "is_mentioned": None,
+ "is_at": False,
+ "priority_mode": "interest",
+ "priority_info": None,
+ }
+
+ # 异步处理消息段,生成纯文本
+ processed_plain_text = await _process_message_segments(message_segment, processing_state, message_info)
+
+ # 解析 notice 信息
+ is_notify = False
+ is_public_notice = False
+ notice_type = None
+ if message_info.additional_config and isinstance(message_info.additional_config, dict):
+ is_notify = message_info.additional_config.get("is_notice", False)
+ is_public_notice = message_info.additional_config.get("is_public_notice", False)
+ notice_type = message_info.additional_config.get("notice_type")
+
+ # 提取用户信息
+ user_info = message_info.user_info
+ user_id = str(user_info.user_id) if user_info and user_info.user_id else ""
+ user_nickname = (user_info.user_nickname or "") if user_info else ""
+ user_cardname = user_info.user_cardname if user_info else None
+ user_platform = (user_info.platform or "") if user_info else ""
+
+ # 提取群组信息
+ group_info = message_info.group_info
+ group_id = group_info.group_id if group_info else None
+ group_name = group_info.group_name if group_info else None
+ group_platform = group_info.platform if group_info else None
+
+ # 生成 chat_id
+ if group_info and group_id:
+ chat_id = f"{platform}_{group_id}"
+ elif user_info and user_info.user_id:
+ chat_id = f"{platform}_{user_info.user_id}_private"
+ else:
+ chat_id = stream_id
+
+ # 准备 additional_config
+ additional_config_str = _prepare_additional_config(message_info, is_notify, is_public_notice, notice_type)
+
+ # 提取 reply_to
+ reply_to = _extract_reply_from_segment(message_segment)
+
+ # 构造 DatabaseMessages
+ message_time = message_info.time if hasattr(message_info, "time") and message_info.time is not None else time.time()
+ message_id = message_info.message_id or ""
+
+ # 处理 is_mentioned
+ is_mentioned = None
+ mentioned_value = processing_state.get("is_mentioned")
+ if isinstance(mentioned_value, bool):
+ is_mentioned = mentioned_value
+ elif isinstance(mentioned_value, (int, float)):
+ is_mentioned = mentioned_value != 0
+
+ db_message = DatabaseMessages(
+ message_id=message_id,
+ time=float(message_time),
+ chat_id=chat_id,
+ reply_to=reply_to,
+ processed_plain_text=processed_plain_text,
+ display_message=processed_plain_text,
+ is_mentioned=is_mentioned,
+ is_at=bool(processing_state.get("is_at", False)),
+ is_emoji=bool(processing_state.get("is_emoji", False)),
+ is_picid=bool(processing_state.get("is_picid", False)),
+ is_command=False, # 将在后续处理中设置
+ is_notify=bool(is_notify),
+ is_public_notice=bool(is_public_notice),
+ notice_type=notice_type,
+ additional_config=additional_config_str,
+ user_id=user_id,
+ user_nickname=user_nickname,
+ user_cardname=user_cardname,
+ user_platform=user_platform,
+ chat_info_stream_id=stream_id,
+ chat_info_platform=platform,
+ chat_info_create_time=0.0, # 将由 ChatStream 填充
+ chat_info_last_active_time=0.0, # 将由 ChatStream 填充
+ chat_info_user_id=user_id,
+ chat_info_user_nickname=user_nickname,
+ chat_info_user_cardname=user_cardname,
+ chat_info_user_platform=user_platform,
+ chat_info_group_id=group_id,
+ chat_info_group_name=group_name,
+ chat_info_group_platform=group_platform,
+ )
+
+ # 设置优先级信息
+ if processing_state.get("priority_mode"):
+ setattr(db_message, "priority_mode", processing_state["priority_mode"])
+ if processing_state.get("priority_info"):
+ setattr(db_message, "priority_info", processing_state["priority_info"])
+
+ # 设置其他运行时属性
+ setattr(db_message, "is_voice", bool(processing_state.get("is_voice", False)))
+ setattr(db_message, "is_video", bool(processing_state.get("is_video", False)))
+ setattr(db_message, "has_emoji", bool(processing_state.get("has_emoji", False)))
+ setattr(db_message, "has_picid", bool(processing_state.get("has_picid", False)))
+
+ return db_message
+
+
+async def _process_message_segments(segment: Seg, state: dict, message_info: BaseMessageInfo) -> str:
+ """递归处理消息段,转换为文字描述
+
+ Args:
+ segment: 要处理的消息段
+ state: 处理状态字典(用于记录消息类型标记)
+ message_info: 消息基础信息(用于某些处理逻辑)
+
+ Returns:
+ str: 处理后的文本
+ """
+ if segment.type == "seglist":
+ # 处理消息段列表
+ segments_text = []
+ for seg in segment.data:
+ processed = await _process_message_segments(seg, state, message_info)
+ if processed:
+ segments_text.append(processed)
+ return " ".join(segments_text)
+ else:
+ # 处理单个消息段
+ return await _process_single_segment(segment, state, message_info)
+
+
+async def _process_single_segment(segment: Seg, state: dict, message_info: BaseMessageInfo) -> str:
+ """处理单个消息段
+
+ Args:
+ segment: 消息段
+ state: 处理状态字典
+ message_info: 消息基础信息
+
+ Returns:
+ str: 处理后的文本
+ """
+ try:
+ if segment.type == "text":
+ state["is_picid"] = False
+ state["is_emoji"] = False
+ state["is_video"] = False
+ return segment.data
+
+ elif segment.type == "at":
+ state["is_picid"] = False
+ state["is_emoji"] = False
+ state["is_video"] = False
+ state["is_at"] = True
+ # 处理at消息,格式为"昵称:QQ号"
+ if isinstance(segment.data, str) and ":" in segment.data:
+ nickname, qq_id = segment.data.split(":", 1)
+ return f"@{nickname}"
+ return f"@{segment.data}" if isinstance(segment.data, str) else "@未知用户"
+
+ elif segment.type == "image":
+ # 如果是base64图片数据
+ if isinstance(segment.data, str):
+ state["has_picid"] = True
+ state["is_picid"] = True
+ state["is_emoji"] = False
+ state["is_video"] = False
+ image_manager = get_image_manager()
+ _, processed_text = await image_manager.process_image(segment.data)
+ return processed_text
+ return "[发了一张图片,网卡了加载不出来]"
+
+ elif segment.type == "emoji":
+ state["has_emoji"] = True
+ state["is_emoji"] = True
+ state["is_picid"] = False
+ state["is_voice"] = False
+ state["is_video"] = False
+ if isinstance(segment.data, str):
+ return await get_image_manager().get_emoji_description(segment.data)
+ return "[发了一个表情包,网卡了加载不出来]"
+
+ elif segment.type == "voice":
+ state["is_picid"] = False
+ state["is_emoji"] = False
+ state["is_voice"] = True
+ state["is_video"] = False
+
+ # 检查消息是否由机器人自己发送
+ if message_info and message_info.user_info and str(message_info.user_info.user_id) == str(global_config.bot.qq_account):
+ logger.info(f"检测到机器人自身发送的语音消息 (User ID: {message_info.user_info.user_id}),尝试从缓存获取文本。")
+ if isinstance(segment.data, str):
+ cached_text = consume_self_voice_text(segment.data)
+ if cached_text:
+ logger.info(f"成功从缓存中获取语音文本: '{cached_text[:70]}...'")
+ return f"[语音:{cached_text}]"
+ else:
+ logger.warning("机器人自身语音消息缓存未命中,将回退到标准语音识别。")
+
+ # 标准语音识别流程
+ if isinstance(segment.data, str):
+ return await get_voice_text(segment.data)
+ return "[发了一段语音,网卡了加载不出来]"
+
+ elif segment.type == "mention_bot":
+ state["is_picid"] = False
+ state["is_emoji"] = False
+ state["is_voice"] = False
+ state["is_video"] = False
+ state["is_mentioned"] = float(segment.data)
+ return ""
+
+ elif segment.type == "priority_info":
+ state["is_picid"] = False
+ state["is_emoji"] = False
+ state["is_voice"] = False
+ if isinstance(segment.data, dict):
+ # 处理优先级信息
+ state["priority_mode"] = "priority"
+ state["priority_info"] = segment.data
+ return ""
+
+ elif segment.type == "file":
+ if isinstance(segment.data, dict):
+ file_name = segment.data.get('name', '未知文件')
+ file_size = segment.data.get('size', '未知大小')
+ return f"[文件:{file_name} ({file_size}字节)]"
+ return "[收到一个文件]"
+
+ elif segment.type == "video":
+ state["is_picid"] = False
+ state["is_emoji"] = False
+ state["is_voice"] = False
+ state["is_video"] = True
+ logger.info(f"接收到视频消息,数据类型: {type(segment.data)}")
+
+ # 检查视频分析功能是否可用
+ if not is_video_analysis_available():
+ logger.warning("⚠️ Rust视频处理模块不可用,跳过视频分析")
+ return "[视频]"
+
+ if global_config.video_analysis.enable:
+ logger.info("已启用视频识别,开始识别")
+ if isinstance(segment.data, dict):
+ try:
+ # 从Adapter接收的视频数据
+ video_base64 = segment.data.get("base64")
+ filename = segment.data.get("filename", "video.mp4")
+
+ logger.info(f"视频文件名: {filename}")
+ logger.info(f"Base64数据长度: {len(video_base64) if video_base64 else 0}")
+
+ if video_base64:
+ # 解码base64视频数据
+ video_bytes = base64.b64decode(video_base64)
+ logger.info(f"解码后视频大小: {len(video_bytes)} 字节")
+
+ # 使用video analyzer分析视频
+ video_analyzer = get_video_analyzer()
+ result = await video_analyzer.analyze_video_from_bytes(
+ video_bytes, filename, prompt=global_config.video_analysis.batch_analysis_prompt
+ )
+
+ logger.info(f"视频分析结果: {result}")
+
+ # 返回视频分析结果
+ summary = result.get("summary", "")
+ if summary:
+ return f"[视频内容] {summary}"
+ else:
+ return "[已收到视频,但分析失败]"
+ else:
+ logger.warning("视频消息中没有base64数据")
+ return "[收到视频消息,但数据异常]"
+ except Exception as e:
+ logger.error(f"视频处理失败: {e!s}")
+ import traceback
+ logger.error(f"错误详情: {traceback.format_exc()}")
+ return "[收到视频,但处理时出现错误]"
+ else:
+ logger.warning(f"视频消息数据不是字典格式: {type(segment.data)}")
+ return "[发了一个视频,但格式不支持]"
+ else:
+ return ""
+ else:
+ logger.warning(f"未知的消息段类型: {segment.type}")
+ return f"[{segment.type} 消息]"
+
+ except Exception as e:
+ logger.error(f"处理消息段失败: {e!s}, 类型: {segment.type}, 数据: {segment.data}")
+ return f"[处理失败的{segment.type}消息]"
+
+
+def _prepare_additional_config(message_info: BaseMessageInfo, is_notify: bool, is_public_notice: bool, notice_type: str | None) -> str | None:
+ """准备 additional_config,包含 format_info 和 notice 信息
+
+ Args:
+ message_info: 消息基础信息
+ is_notify: 是否为notice消息
+ is_public_notice: 是否为公共notice
+ notice_type: notice类型
+
+ Returns:
+ str | None: JSON 字符串格式的 additional_config,如果为空则返回 None
+ """
+ try:
+ additional_config_data = {}
+
+ # 首先获取adapter传递的additional_config
+ if hasattr(message_info, 'additional_config') and message_info.additional_config:
+ if isinstance(message_info.additional_config, dict):
+ additional_config_data = message_info.additional_config.copy()
+ elif isinstance(message_info.additional_config, str):
+ try:
+ additional_config_data = orjson.loads(message_info.additional_config)
+ except Exception as e:
+ logger.warning(f"无法解析 additional_config JSON: {e}")
+ additional_config_data = {}
+
+ # 添加notice相关标志
+ if is_notify:
+ additional_config_data["is_notice"] = True
+ additional_config_data["notice_type"] = notice_type or "unknown"
+ additional_config_data["is_public_notice"] = bool(is_public_notice)
+
+ # 添加format_info到additional_config中
+ if hasattr(message_info, 'format_info') and message_info.format_info:
+ try:
+ format_info_dict = message_info.format_info.to_dict()
+ additional_config_data["format_info"] = format_info_dict
+ logger.debug(f"[message_processor] 嵌入 format_info 到 additional_config: {format_info_dict}")
+ except Exception as e:
+ logger.warning(f"将 format_info 转换为字典失败: {e}")
+
+ # 序列化为JSON字符串
+ if additional_config_data:
+ return orjson.dumps(additional_config_data).decode("utf-8")
+ except Exception as e:
+ logger.error(f"准备 additional_config 失败: {e}")
+
+ return None
+
+
+def _extract_reply_from_segment(segment: Seg) -> str | None:
+ """从消息段中提取reply_to信息
+
+ Args:
+ segment: 消息段
+
+ Returns:
+ str | None: 回复的消息ID,如果没有则返回None
+ """
+ try:
+ if hasattr(segment, "type") and segment.type == "seglist":
+ # 递归搜索seglist中的reply段
+ if hasattr(segment, "data") and segment.data:
+ for seg in segment.data:
+ reply_id = _extract_reply_from_segment(seg)
+ if reply_id:
+ return reply_id
+ elif hasattr(segment, "type") and segment.type == "reply":
+ # 找到reply段,返回message_id
+ return str(segment.data) if segment.data else None
+ except Exception as e:
+ logger.warning(f"提取reply_to信息失败: {e}")
+ return None
+
+
+# =============================================================================
+# DatabaseMessages 扩展工具函数
+# =============================================================================
+
+def get_message_info_from_db_message(db_message: DatabaseMessages) -> BaseMessageInfo:
+ """从 DatabaseMessages 重建 BaseMessageInfo(用于需要 message_info 的遗留代码)
+
+ Args:
+ db_message: DatabaseMessages 对象
+
+ Returns:
+ BaseMessageInfo: 重建的消息信息对象
+ """
+ from maim_message import UserInfo, GroupInfo
+
+ # 从 DatabaseMessages 的 user_info 转换为 maim_message.UserInfo
+ user_info = UserInfo(
+ platform=db_message.user_info.platform,
+ user_id=db_message.user_info.user_id,
+ user_nickname=db_message.user_info.user_nickname,
+ user_cardname=db_message.user_info.user_cardname or ""
+ )
+
+ # 从 DatabaseMessages 的 group_info 转换为 maim_message.GroupInfo(如果存在)
+ group_info = None
+ if db_message.group_info:
+ group_info = GroupInfo(
+ platform=db_message.group_info.group_platform or "",
+ group_id=db_message.group_info.group_id,
+ group_name=db_message.group_info.group_name
+ )
+
+ # 解析 additional_config(从 JSON 字符串到字典)
+ additional_config = None
+ if db_message.additional_config:
+ try:
+ additional_config = orjson.loads(db_message.additional_config)
+ except Exception:
+ # 如果解析失败,保持为字符串
+ pass
+
+ # 创建 BaseMessageInfo
+ message_info = BaseMessageInfo(
+ platform=db_message.chat_info.platform,
+ message_id=db_message.message_id,
+ time=db_message.time,
+ user_info=user_info,
+ group_info=group_info,
+ additional_config=additional_config # type: ignore
+ )
+
+ return message_info
+
+
+def set_db_message_runtime_attr(db_message: DatabaseMessages, attr_name: str, value: Any) -> None:
+ """安全地为 DatabaseMessages 设置运行时属性
+
+ Args:
+ db_message: DatabaseMessages 对象
+ attr_name: 属性名
+ value: 属性值
+ """
+ setattr(db_message, attr_name, value)
+
+
+def get_db_message_runtime_attr(db_message: DatabaseMessages, attr_name: str, default: Any = None) -> Any:
+ """安全地获取 DatabaseMessages 的运行时属性
+
+ Args:
+ db_message: DatabaseMessages 对象
+ attr_name: 属性名
+ default: 默认值
+
+ Returns:
+ 属性值或默认值
+ """
+ return getattr(db_message, attr_name, default)
diff --git a/src/chat/message_receive/message_recv_backup.py b/src/chat/message_receive/message_recv_backup.py
new file mode 100644
index 000000000..3f1943ebf
--- /dev/null
+++ b/src/chat/message_receive/message_recv_backup.py
@@ -0,0 +1,434 @@
+# MessageRecv 类备份 - 已从 message.py 中移除
+# 备份日期: 2025-10-31
+# 此类已被 DatabaseMessages 完全取代
+
+# MessageRecv 类已被移除
+# 现在所有消息处理都使用 DatabaseMessages
+# 如果需要从消息字典创建 DatabaseMessages,请使用 message_processor.process_message_from_dict()
+#
+# 历史参考: MessageRecv 曾经是接收消息的包装类,现已被 DatabaseMessages 完全取代
+# 迁移完成日期: 2025-10-31
+
+"""
+# 以下是已删除的 MessageRecv 类(保留作为参考)
+class MessageRecv:
+ 接收消息类 - DatabaseMessages 的轻量级包装器
+
+ 这个类现在主要作为适配器层,处理外部消息格式并内部使用 DatabaseMessages。
+ 保留此类是为了向后兼容性和处理 message_segment 的异步逻辑。
+"""
+
+ def __init__(self, message_dict: dict[str, Any]):
+ """从MessageCQ的字典初始化
+
+ Args:
+ message_dict: MessageCQ序列化后的字典
+ """
+ # 保留原始消息信息用于某些场景
+ self.message_info = BaseMessageInfo.from_dict(message_dict.get("message_info", {}))
+ self.message_segment = Seg.from_dict(message_dict.get("message_segment", {}))
+ self.raw_message = message_dict.get("raw_message")
+
+ # 处理状态(在process()之前临时使用)
+ self._processing_state = {
+ "is_emoji": False,
+ "has_emoji": False,
+ "is_picid": False,
+ "has_picid": False,
+ "is_voice": False,
+ "is_video": False,
+ "is_mentioned": None,
+ "is_at": False,
+ "priority_mode": "interest",
+ "priority_info": None,
+ }
+
+ self.chat_stream = None
+ self.reply = None
+ self.processed_plain_text = message_dict.get("processed_plain_text", "")
+
+ # 解析additional_config中的notice信息
+ self.is_notify = False
+ self.is_public_notice = False
+ self.notice_type = None
+ if self.message_info.additional_config and isinstance(self.message_info.additional_config, dict):
+ self.is_notify = self.message_info.additional_config.get("is_notice", False)
+ self.is_public_notice = self.message_info.additional_config.get("is_public_notice", False)
+ self.notice_type = self.message_info.additional_config.get("notice_type")
+
+ # 兼容性属性 - 代理到 _processing_state
+ @property
+ def is_emoji(self) -> bool:
+ return self._processing_state["is_emoji"]
+
+ @is_emoji.setter
+ def is_emoji(self, value: bool):
+ self._processing_state["is_emoji"] = value
+
+ @property
+ def has_emoji(self) -> bool:
+ return self._processing_state["has_emoji"]
+
+ @has_emoji.setter
+ def has_emoji(self, value: bool):
+ self._processing_state["has_emoji"] = value
+
+ @property
+ def is_picid(self) -> bool:
+ return self._processing_state["is_picid"]
+
+ @is_picid.setter
+ def is_picid(self, value: bool):
+ self._processing_state["is_picid"] = value
+
+ @property
+ def has_picid(self) -> bool:
+ return self._processing_state["has_picid"]
+
+ @has_picid.setter
+ def has_picid(self, value: bool):
+ self._processing_state["has_picid"] = value
+
+ @property
+ def is_voice(self) -> bool:
+ return self._processing_state["is_voice"]
+
+ @is_voice.setter
+ def is_voice(self, value: bool):
+ self._processing_state["is_voice"] = value
+
+ @property
+ def is_video(self) -> bool:
+ return self._processing_state["is_video"]
+
+ @is_video.setter
+ def is_video(self, value: bool):
+ self._processing_state["is_video"] = value
+
+ @property
+ def is_mentioned(self):
+ return self._processing_state["is_mentioned"]
+
+ @is_mentioned.setter
+ def is_mentioned(self, value):
+ self._processing_state["is_mentioned"] = value
+
+ @property
+ def is_at(self) -> bool:
+ return self._processing_state["is_at"]
+
+ @is_at.setter
+ def is_at(self, value: bool):
+ self._processing_state["is_at"] = value
+
+ @property
+ def priority_mode(self) -> str:
+ return self._processing_state["priority_mode"]
+
+ @priority_mode.setter
+ def priority_mode(self, value: str):
+ self._processing_state["priority_mode"] = value
+
+ @property
+ def priority_info(self):
+ return self._processing_state["priority_info"]
+
+ @priority_info.setter
+ def priority_info(self, value):
+ self._processing_state["priority_info"] = value
+
+ # 其他常用属性
+ interest_value: float = 0.0
+ is_command: bool = False
+ memorized_times: int = 0
+
+ def __post_init__(self):
+ """dataclass 初始化后处理"""
+ self.key_words = []
+ self.key_words_lite = []
+
+ def update_chat_stream(self, chat_stream: "ChatStream"):
+ self.chat_stream = chat_stream
+
+ def to_database_message(self) -> "DatabaseMessages":
+ """将 MessageRecv 转换为 DatabaseMessages 对象
+
+ Returns:
+ DatabaseMessages: 数据库消息对象
+ """
+ import time
+
+ message_info = self.message_info
+ msg_user_info = getattr(message_info, "user_info", None)
+ stream_user_info = getattr(self.chat_stream, "user_info", None) if self.chat_stream else None
+ group_info = getattr(self.chat_stream, "group_info", None) if self.chat_stream else None
+
+ message_id = message_info.message_id or ""
+ message_time = message_info.time if hasattr(message_info, "time") and message_info.time is not None else time.time()
+ is_mentioned = None
+ if isinstance(self.is_mentioned, bool):
+ is_mentioned = self.is_mentioned
+ elif isinstance(self.is_mentioned, int | float):
+ is_mentioned = self.is_mentioned != 0
+
+ # 提取用户信息
+ user_id = ""
+ user_nickname = ""
+ user_cardname = None
+ user_platform = ""
+ if msg_user_info:
+ user_id = str(getattr(msg_user_info, "user_id", "") or "")
+ user_nickname = getattr(msg_user_info, "user_nickname", "") or ""
+ user_cardname = getattr(msg_user_info, "user_cardname", None)
+ user_platform = getattr(msg_user_info, "platform", "") or ""
+ elif stream_user_info:
+ user_id = str(getattr(stream_user_info, "user_id", "") or "")
+ user_nickname = getattr(stream_user_info, "user_nickname", "") or ""
+ user_cardname = getattr(stream_user_info, "user_cardname", None)
+ user_platform = getattr(stream_user_info, "platform", "") or ""
+
+ # 提取聊天流信息
+ chat_user_id = str(getattr(stream_user_info, "user_id", "") or "") if stream_user_info else ""
+ chat_user_nickname = getattr(stream_user_info, "user_nickname", "") or "" if stream_user_info else ""
+ chat_user_cardname = getattr(stream_user_info, "user_cardname", None) if stream_user_info else None
+ chat_user_platform = getattr(stream_user_info, "platform", "") or "" if stream_user_info else ""
+
+ group_id = getattr(group_info, "group_id", None) if group_info else None
+ group_name = getattr(group_info, "group_name", None) if group_info else None
+ group_platform = getattr(group_info, "platform", None) if group_info else None
+
+ # 准备 additional_config
+ additional_config_str = None
+ try:
+ import orjson
+
+ additional_config_data = {}
+
+ # 首先获取adapter传递的additional_config
+ if hasattr(message_info, 'additional_config') and message_info.additional_config:
+ if isinstance(message_info.additional_config, dict):
+ additional_config_data = message_info.additional_config.copy()
+ elif isinstance(message_info.additional_config, str):
+ try:
+ additional_config_data = orjson.loads(message_info.additional_config)
+ except Exception as e:
+ logger.warning(f"无法解析 additional_config JSON: {e}")
+ additional_config_data = {}
+
+ # 添加notice相关标志
+ if self.is_notify:
+ additional_config_data["is_notice"] = True
+ additional_config_data["notice_type"] = self.notice_type or "unknown"
+ additional_config_data["is_public_notice"] = bool(self.is_public_notice)
+
+ # 添加format_info到additional_config中
+ if hasattr(message_info, 'format_info') and message_info.format_info:
+ try:
+ format_info_dict = message_info.format_info.to_dict()
+ additional_config_data["format_info"] = format_info_dict
+ logger.debug(f"[message.py] 嵌入 format_info 到 additional_config: {format_info_dict}")
+ except Exception as e:
+ logger.warning(f"将 format_info 转换为字典失败: {e}")
+
+ # 序列化为JSON字符串
+ if additional_config_data:
+ additional_config_str = orjson.dumps(additional_config_data).decode("utf-8")
+ except Exception as e:
+ logger.error(f"准备 additional_config 失败: {e}")
+
+ # 创建数据库消息对象
+ db_message = DatabaseMessages(
+ message_id=message_id,
+ time=float(message_time),
+ chat_id=self.chat_stream.stream_id if self.chat_stream else "",
+ processed_plain_text=self.processed_plain_text,
+ display_message=self.processed_plain_text,
+ is_mentioned=is_mentioned,
+ is_at=bool(self.is_at) if self.is_at is not None else None,
+ is_emoji=bool(self.is_emoji),
+ is_picid=bool(self.is_picid),
+ is_command=bool(self.is_command),
+ is_notify=bool(self.is_notify),
+ is_public_notice=bool(self.is_public_notice),
+ notice_type=self.notice_type,
+ additional_config=additional_config_str,
+ user_id=user_id,
+ user_nickname=user_nickname,
+ user_cardname=user_cardname,
+ user_platform=user_platform,
+ chat_info_stream_id=self.chat_stream.stream_id if self.chat_stream else "",
+ chat_info_platform=self.chat_stream.platform if self.chat_stream else "",
+ chat_info_create_time=float(self.chat_stream.create_time) if self.chat_stream else 0.0,
+ chat_info_last_active_time=float(self.chat_stream.last_active_time) if self.chat_stream else 0.0,
+ chat_info_user_id=chat_user_id,
+ chat_info_user_nickname=chat_user_nickname,
+ chat_info_user_cardname=chat_user_cardname,
+ chat_info_user_platform=chat_user_platform,
+ chat_info_group_id=group_id,
+ chat_info_group_name=group_name,
+ chat_info_group_platform=group_platform,
+ )
+
+ # 同步兴趣度等衍生属性
+ db_message.interest_value = getattr(self, "interest_value", 0.0)
+ setattr(db_message, "should_reply", getattr(self, "should_reply", False))
+ setattr(db_message, "should_act", getattr(self, "should_act", False))
+
+ return db_message
+
+ async def process(self) -> None:
+ """处理消息内容,生成纯文本和详细文本
+
+ 这个方法必须在创建实例后显式调用,因为它包含异步操作。
+ """
+ self.processed_plain_text = await self._process_message_segments(self.message_segment)
+
+ async def _process_single_segment(self, segment: Seg) -> str:
+ """处理单个消息段
+
+ Args:
+ segment: 消息段
+
+ Returns:
+ str: 处理后的文本
+ """
+ try:
+ if segment.type == "text":
+ self.is_picid = False
+ self.is_emoji = False
+ self.is_video = False
+ return segment.data # type: ignore
+ elif segment.type == "at":
+ self.is_picid = False
+ self.is_emoji = False
+ self.is_video = False
+ # 处理at消息,格式为"昵称:QQ号"
+ if isinstance(segment.data, str) and ":" in segment.data:
+ nickname, qq_id = segment.data.split(":", 1)
+ return f"@{nickname}"
+ return f"@{segment.data}" if isinstance(segment.data, str) else "@未知用户"
+ elif segment.type == "image":
+ # 如果是base64图片数据
+ if isinstance(segment.data, str):
+ self.has_picid = True
+ self.is_picid = True
+ self.is_emoji = False
+ self.is_video = False
+ image_manager = get_image_manager()
+ # print(f"segment.data: {segment.data}")
+ _, processed_text = await image_manager.process_image(segment.data)
+ return processed_text
+ return "[发了一张图片,网卡了加载不出来]"
+ elif segment.type == "emoji":
+ self.has_emoji = True
+ self.is_emoji = True
+ self.is_picid = False
+ self.is_voice = False
+ self.is_video = False
+ if isinstance(segment.data, str):
+ return await get_image_manager().get_emoji_description(segment.data)
+ return "[发了一个表情包,网卡了加载不出来]"
+ elif segment.type == "voice":
+ self.is_picid = False
+ self.is_emoji = False
+ self.is_voice = True
+ self.is_video = False
+
+ # 检查消息是否由机器人自己发送
+ if self.message_info and self.message_info.user_info and str(self.message_info.user_info.user_id) == str(global_config.bot.qq_account):
+ logger.info(f"检测到机器人自身发送的语音消息 (User ID: {self.message_info.user_info.user_id}),尝试从缓存获取文本。")
+ if isinstance(segment.data, str):
+ cached_text = consume_self_voice_text(segment.data)
+ if cached_text:
+ logger.info(f"成功从缓存中获取语音文本: '{cached_text[:70]}...'")
+ return f"[语音:{cached_text}]"
+ else:
+ logger.warning("机器人自身语音消息缓存未命中,将回退到标准语音识别。")
+
+ # 标准语音识别流程 (也作为缓存未命中的后备方案)
+ if isinstance(segment.data, str):
+ return await get_voice_text(segment.data)
+ return "[发了一段语音,网卡了加载不出来]"
+ elif segment.type == "mention_bot":
+ self.is_picid = False
+ self.is_emoji = False
+ self.is_voice = False
+ self.is_video = False
+ self.is_mentioned = float(segment.data) # type: ignore
+ return ""
+ elif segment.type == "priority_info":
+ self.is_picid = False
+ self.is_emoji = False
+ self.is_voice = False
+ if isinstance(segment.data, dict):
+ # 处理优先级信息
+ self.priority_mode = "priority"
+ self.priority_info = segment.data
+ """
+ {
+ 'message_type': 'vip', # vip or normal
+ 'message_priority': 1.0, # 优先级,大为优先,float
+ }
+ """
+ return ""
+ elif segment.type == "file":
+ if isinstance(segment.data, dict):
+ file_name = segment.data.get('name', '未知文件')
+ file_size = segment.data.get('size', '未知大小')
+ return f"[文件:{file_name} ({file_size}字节)]"
+ return "[收到一个文件]"
+ elif segment.type == "video":
+ self.is_picid = False
+ self.is_emoji = False
+ self.is_voice = False
+ self.is_video = True
+ logger.info(f"接收到视频消息,数据类型: {type(segment.data)}")
+
+ # 检查视频分析功能是否可用
+ if not is_video_analysis_available():
+ logger.warning("⚠️ Rust视频处理模块不可用,跳过视频分析")
+ return "[视频]"
+
+ if global_config.video_analysis.enable:
+ logger.info("已启用视频识别,开始识别")
+ if isinstance(segment.data, dict):
+ try:
+ # 从Adapter接收的视频数据
+ video_base64 = segment.data.get("base64")
+ filename = segment.data.get("filename", "video.mp4")
+
+ logger.info(f"视频文件名: {filename}")
+ logger.info(f"Base64数据长度: {len(video_base64) if video_base64 else 0}")
+
+ if video_base64:
+ # 解码base64视频数据
+ video_bytes = base64.b64decode(video_base64)
+ logger.info(f"解码后视频大小: {len(video_bytes)} 字节")
+
+ # 使用video analyzer分析视频
+ video_analyzer = get_video_analyzer()
+ result = await video_analyzer.analyze_video_from_bytes(
+ video_bytes, filename, prompt=global_config.video_analysis.batch_analysis_prompt
+ )
+
+ logger.info(f"视频分析结果: {result}")
+
+ # 返回视频分析结果
+ summary = result.get("summary", "")
+ if summary:
+ return f"[视频内容] {summary}"
+ else:
+ return "[已收到视频,但分析失败]"
+ else:
+ logger.warning("视频消息中没有base64数据")
+ return "[收到视频消息,但数据异常]"
+ except Exception as e:
+ logger.error(f"视频处理失败: {e!s}")
+ import traceback
+
+ logger.error(f"错误详情: {traceback.format_exc()}")
+ return "[收到视频,但处理时出现错误]"
+ else:
+ logger.warning(f"视频消息数据不是字典格式: {type(segment.data)}")
+ return "[发了一个视频,但格式不支持]"
+ else:
diff --git a/src/chat/message_receive/storage.py b/src/chat/message_receive/storage.py
index edf9bb9c8..9b2b54991 100644
--- a/src/chat/message_receive/storage.py
+++ b/src/chat/message_receive/storage.py
@@ -9,8 +9,10 @@ from src.common.database.sqlalchemy_database_api import get_db_session
from src.common.database.sqlalchemy_models import Images, Messages
from src.common.logger import get_logger
+from src.common.data_models.database_data_model import DatabaseMessages
+
from .chat_stream import ChatStream
-from .message import MessageRecv, MessageSending
+from .message import MessageSending
logger = get_logger("message_storage")
@@ -34,97 +36,166 @@ class MessageStorage:
return []
@staticmethod
- async def store_message(message: MessageSending | MessageRecv, chat_stream: ChatStream) -> None:
+ async def store_message(message: DatabaseMessages | MessageSending, chat_stream: ChatStream) -> None:
"""存储消息到数据库"""
try:
# 过滤敏感信息的正则模式
pattern = r".*?|.*?|.*?"
- processed_plain_text = message.processed_plain_text
-
- if processed_plain_text:
- processed_plain_text = await MessageStorage.replace_image_descriptions(processed_plain_text)
- # 增加对None的防御性处理
- safe_processed_plain_text = processed_plain_text or ""
- filtered_processed_plain_text = re.sub(pattern, "", safe_processed_plain_text, flags=re.DOTALL)
- else:
- filtered_processed_plain_text = ""
-
- if isinstance(message, MessageSending):
- display_message = message.display_message
- if display_message:
- filtered_display_message = re.sub(pattern, "", display_message, flags=re.DOTALL)
+ # 如果是 DatabaseMessages,直接使用它的字段
+ if isinstance(message, DatabaseMessages):
+ processed_plain_text = message.processed_plain_text
+ if processed_plain_text:
+ processed_plain_text = await MessageStorage.replace_image_descriptions(processed_plain_text)
+ safe_processed_plain_text = processed_plain_text or ""
+ filtered_processed_plain_text = re.sub(pattern, "", safe_processed_plain_text, flags=re.DOTALL)
else:
- # 如果没有设置display_message,使用processed_plain_text作为显示消息
- filtered_display_message = (
- re.sub(pattern, "", (message.processed_plain_text or ""), flags=re.DOTALL)
- )
- interest_value = 0
- is_mentioned = False
- reply_to = message.reply_to
- priority_mode = ""
- priority_info = {}
- is_emoji = False
- is_picid = False
- is_notify = False
- is_command = False
- key_words = ""
- key_words_lite = ""
- else:
- filtered_display_message = ""
- interest_value = message.interest_value
+ filtered_processed_plain_text = ""
+
+ display_message = message.display_message or message.processed_plain_text or ""
+ filtered_display_message = re.sub(pattern, "", display_message, flags=re.DOTALL)
+
+ # 直接从 DatabaseMessages 获取所有字段
+ msg_id = message.message_id
+ msg_time = message.time
+ chat_id = message.chat_id
+ reply_to = "" # DatabaseMessages 没有 reply_to 字段
is_mentioned = message.is_mentioned
- reply_to = ""
- priority_mode = message.priority_mode
- priority_info = message.priority_info
- is_emoji = message.is_emoji
- is_picid = message.is_picid
- is_notify = message.is_notify
- is_command = message.is_command
- # 序列化关键词列表为JSON字符串
- key_words = MessageStorage._serialize_keywords(message.key_words)
- key_words_lite = MessageStorage._serialize_keywords(message.key_words_lite)
+ interest_value = message.interest_value or 0.0
+ priority_mode = "" # DatabaseMessages 没有 priority_mode
+ priority_info_json = None # DatabaseMessages 没有 priority_info
+ is_emoji = message.is_emoji or False
+ is_picid = message.is_picid or False
+ is_notify = message.is_notify or False
+ is_command = message.is_command or False
+ key_words = "" # DatabaseMessages 没有 key_words
+ key_words_lite = ""
+ memorized_times = 0 # DatabaseMessages 没有 memorized_times
+
+ # 使用 DatabaseMessages 中的嵌套对象信息
+ user_platform = message.user_info.platform if message.user_info else ""
+ user_id = message.user_info.user_id if message.user_info else ""
+ user_nickname = message.user_info.user_nickname if message.user_info else ""
+ user_cardname = message.user_info.user_cardname if message.user_info else None
+
+ chat_info_stream_id = message.chat_info.stream_id if message.chat_info else ""
+ chat_info_platform = message.chat_info.platform if message.chat_info else ""
+ chat_info_create_time = message.chat_info.create_time if message.chat_info else 0.0
+ chat_info_last_active_time = message.chat_info.last_active_time if message.chat_info else 0.0
+ chat_info_user_platform = message.chat_info.user_info.platform if message.chat_info and message.chat_info.user_info else ""
+ chat_info_user_id = message.chat_info.user_info.user_id if message.chat_info and message.chat_info.user_info else ""
+ chat_info_user_nickname = message.chat_info.user_info.user_nickname if message.chat_info and message.chat_info.user_info else ""
+ chat_info_user_cardname = message.chat_info.user_info.user_cardname if message.chat_info and message.chat_info.user_info else None
+ chat_info_group_platform = message.group_info.group_platform if message.group_info else None
+ chat_info_group_id = message.group_info.group_id if message.group_info else None
+ chat_info_group_name = message.group_info.group_name if message.group_info else None
+
+ else:
+ # MessageSending 处理逻辑
+ processed_plain_text = message.processed_plain_text
- chat_info_dict = chat_stream.to_dict()
- user_info_dict = message.message_info.user_info.to_dict() # type: ignore
+ if processed_plain_text:
+ processed_plain_text = await MessageStorage.replace_image_descriptions(processed_plain_text)
+ # 增加对None的防御性处理
+ safe_processed_plain_text = processed_plain_text or ""
+ filtered_processed_plain_text = re.sub(pattern, "", safe_processed_plain_text, flags=re.DOTALL)
+ else:
+ filtered_processed_plain_text = ""
- # message_id 现在是 TextField,直接使用字符串值
- msg_id = message.message_info.message_id
+ if isinstance(message, MessageSending):
+ display_message = message.display_message
+ if display_message:
+ filtered_display_message = re.sub(pattern, "", display_message, flags=re.DOTALL)
+ else:
+ # 如果没有设置display_message,使用processed_plain_text作为显示消息
+ filtered_display_message = (
+ re.sub(pattern, "", (message.processed_plain_text or ""), flags=re.DOTALL)
+ )
+ interest_value = 0
+ is_mentioned = False
+ reply_to = message.reply_to
+ priority_mode = ""
+ priority_info = {}
+ is_emoji = False
+ is_picid = False
+ is_notify = False
+ is_command = False
+ key_words = ""
+ key_words_lite = ""
+ else:
+ filtered_display_message = ""
+ interest_value = message.interest_value
+ is_mentioned = message.is_mentioned
+ reply_to = ""
+ priority_mode = message.priority_mode
+ priority_info = message.priority_info
+ is_emoji = message.is_emoji
+ is_picid = message.is_picid
+ is_notify = message.is_notify
+ is_command = message.is_command
+ # 序列化关键词列表为JSON字符串
+ key_words = MessageStorage._serialize_keywords(message.key_words)
+ key_words_lite = MessageStorage._serialize_keywords(message.key_words_lite)
- # 安全地获取 group_info, 如果为 None 则视为空字典
- group_info_from_chat = chat_info_dict.get("group_info") or {}
- # 安全地获取 user_info, 如果为 None 则视为空字典 (以防万一)
- user_info_from_chat = chat_info_dict.get("user_info") or {}
+ chat_info_dict = chat_stream.to_dict()
+ user_info_dict = message.message_info.user_info.to_dict() # type: ignore
- # 将priority_info字典序列化为JSON字符串,以便存储到数据库的Text字段
- priority_info_json = orjson.dumps(priority_info).decode("utf-8") if priority_info else None
+ # message_id 现在是 TextField,直接使用字符串值
+ msg_id = message.message_info.message_id
+ msg_time = float(message.message_info.time or time.time())
+ chat_id = chat_stream.stream_id
+ memorized_times = message.memorized_times
+
+ # 安全地获取 group_info, 如果为 None 则视为空字典
+ group_info_from_chat = chat_info_dict.get("group_info") or {}
+ # 安全地获取 user_info, 如果为 None 则视为空字典 (以防万一)
+ user_info_from_chat = chat_info_dict.get("user_info") or {}
+
+ # 将priority_info字典序列化为JSON字符串,以便存储到数据库的Text字段
+ priority_info_json = orjson.dumps(priority_info).decode("utf-8") if priority_info else None
+
+ user_platform = user_info_dict.get("platform")
+ user_id = user_info_dict.get("user_id")
+ user_nickname = user_info_dict.get("user_nickname")
+ user_cardname = user_info_dict.get("user_cardname")
+
+ chat_info_stream_id = chat_info_dict.get("stream_id")
+ chat_info_platform = chat_info_dict.get("platform")
+ chat_info_create_time = float(chat_info_dict.get("create_time", 0.0))
+ chat_info_last_active_time = float(chat_info_dict.get("last_active_time", 0.0))
+ chat_info_user_platform = user_info_from_chat.get("platform")
+ chat_info_user_id = user_info_from_chat.get("user_id")
+ chat_info_user_nickname = user_info_from_chat.get("user_nickname")
+ chat_info_user_cardname = user_info_from_chat.get("user_cardname")
+ chat_info_group_platform = group_info_from_chat.get("platform")
+ chat_info_group_id = group_info_from_chat.get("group_id")
+ chat_info_group_name = group_info_from_chat.get("group_name")
# 获取数据库会话
-
new_message = Messages(
message_id=msg_id,
- time=float(message.message_info.time or time.time()),
- chat_id=chat_stream.stream_id,
+ time=msg_time,
+ chat_id=chat_id,
reply_to=reply_to,
is_mentioned=is_mentioned,
- chat_info_stream_id=chat_info_dict.get("stream_id"),
- chat_info_platform=chat_info_dict.get("platform"),
- chat_info_user_platform=user_info_from_chat.get("platform"),
- chat_info_user_id=user_info_from_chat.get("user_id"),
- chat_info_user_nickname=user_info_from_chat.get("user_nickname"),
- chat_info_user_cardname=user_info_from_chat.get("user_cardname"),
- chat_info_group_platform=group_info_from_chat.get("platform"),
- chat_info_group_id=group_info_from_chat.get("group_id"),
- chat_info_group_name=group_info_from_chat.get("group_name"),
- chat_info_create_time=float(chat_info_dict.get("create_time", 0.0)),
- chat_info_last_active_time=float(chat_info_dict.get("last_active_time", 0.0)),
- user_platform=user_info_dict.get("platform"),
- user_id=user_info_dict.get("user_id"),
- user_nickname=user_info_dict.get("user_nickname"),
- user_cardname=user_info_dict.get("user_cardname"),
+ chat_info_stream_id=chat_info_stream_id,
+ chat_info_platform=chat_info_platform,
+ chat_info_user_platform=chat_info_user_platform,
+ chat_info_user_id=chat_info_user_id,
+ chat_info_user_nickname=chat_info_user_nickname,
+ chat_info_user_cardname=chat_info_user_cardname,
+ chat_info_group_platform=chat_info_group_platform,
+ chat_info_group_id=chat_info_group_id,
+ chat_info_group_name=chat_info_group_name,
+ chat_info_create_time=chat_info_create_time,
+ chat_info_last_active_time=chat_info_last_active_time,
+ user_platform=user_platform,
+ user_id=user_id,
+ user_nickname=user_nickname,
+ user_cardname=user_cardname,
processed_plain_text=filtered_processed_plain_text,
display_message=filtered_display_message,
- memorized_times=message.memorized_times,
+ memorized_times=memorized_times,
interest_value=interest_value,
priority_mode=priority_mode,
priority_info=priority_info_json,
@@ -145,36 +216,43 @@ class MessageStorage:
traceback.print_exc()
@staticmethod
- async def update_message(message):
- """更新消息ID"""
+ async def update_message(message_data: dict):
+ """更新消息ID(从消息字典)"""
try:
- mmc_message_id = message.message_info.message_id
+ # 从字典中提取信息
+ message_info = message_data.get("message_info", {})
+ mmc_message_id = message_info.get("message_id")
+
+ message_segment = message_data.get("message_segment", {})
+ segment_type = message_segment.get("type") if isinstance(message_segment, dict) else None
+ segment_data = message_segment.get("data", {}) if isinstance(message_segment, dict) else {}
+
qq_message_id = None
- logger.debug(f"尝试更新消息ID: {mmc_message_id}, 消息段类型: {message.message_segment.type}")
+ logger.debug(f"尝试更新消息ID: {mmc_message_id}, 消息段类型: {segment_type}")
# 根据消息段类型提取message_id
- if message.message_segment.type == "notify":
- qq_message_id = message.message_segment.data.get("id")
- elif message.message_segment.type == "text":
- qq_message_id = message.message_segment.data.get("id")
- elif message.message_segment.type == "reply":
- qq_message_id = message.message_segment.data.get("id")
+ if segment_type == "notify":
+ qq_message_id = segment_data.get("id")
+ elif segment_type == "text":
+ qq_message_id = segment_data.get("id")
+ elif segment_type == "reply":
+ qq_message_id = segment_data.get("id")
if qq_message_id:
logger.debug(f"从reply消息段获取到消息ID: {qq_message_id}")
- elif message.message_segment.type == "adapter_response":
+ elif segment_type == "adapter_response":
logger.debug("适配器响应消息,不需要更新ID")
return
- elif message.message_segment.type == "adapter_command":
+ elif segment_type == "adapter_command":
logger.debug("适配器命令消息,不需要更新ID")
return
else:
- logger.debug(f"未知的消息段类型: {message.message_segment.type},跳过ID更新")
+ logger.debug(f"未知的消息段类型: {segment_type},跳过ID更新")
return
if not qq_message_id:
- logger.debug(f"消息段类型 {message.message_segment.type} 中未找到有效的message_id,跳过更新")
- logger.debug(f"消息段数据: {message.message_segment.data}")
+ logger.debug(f"消息段类型 {segment_type} 中未找到有效的message_id,跳过更新")
+ logger.debug(f"消息段数据: {segment_data}")
return
# 使用上下文管理器确保session正确管理
diff --git a/src/chat/planner_actions/action_modifier.py b/src/chat/planner_actions/action_modifier.py
index 35a17d675..7ea2b4785 100644
--- a/src/chat/planner_actions/action_modifier.py
+++ b/src/chat/planner_actions/action_modifier.py
@@ -137,7 +137,7 @@ class ActionModifier:
logger.debug(f"{self.log_prefix}阶段一移除动作: {disabled_action_name},原因: 用户自行禁用")
# === 第二阶段:检查动作的关联类型 ===
- chat_context = self.chat_stream.stream_context
+ chat_context = self.chat_stream.context_manager.context
current_actions_s2 = self.action_manager.get_using_actions()
type_mismatched_actions = self._check_action_associated_types(current_actions_s2, chat_context)
diff --git a/src/chat/replyer/default_generator.py b/src/chat/replyer/default_generator.py
index 8a07948fb..66f33ce09 100644
--- a/src/chat/replyer/default_generator.py
+++ b/src/chat/replyer/default_generator.py
@@ -13,7 +13,7 @@ from typing import Any
from src.chat.express.expression_selector import expression_selector
from src.chat.message_receive.chat_stream import ChatStream
-from src.chat.message_receive.message import MessageRecv, MessageSending, Seg, UserInfo
+from src.chat.message_receive.message import MessageSending, Seg, UserInfo
from src.chat.message_receive.uni_message_sender import HeartFCSender
from src.chat.utils.chat_message_builder import (
build_readable_messages,
@@ -1733,7 +1733,7 @@ class DefaultReplyer:
is_emoji: bool,
thinking_start_time: float,
display_message: str,
- anchor_message: MessageRecv | None = None,
+ anchor_message: DatabaseMessages | None = None,
) -> MessageSending:
"""构建单个发送消息"""
@@ -1743,8 +1743,11 @@ class DefaultReplyer:
platform=self.chat_stream.platform,
)
- # await anchor_message.process()
- sender_info = anchor_message.message_info.user_info if anchor_message else None
+ # 从 DatabaseMessages 获取 sender_info
+ if anchor_message:
+ sender_info = anchor_message.user_info
+ else:
+ sender_info = None
return MessageSending(
message_id=message_id, # 使用片段的唯一ID
diff --git a/src/chat/utils/utils.py b/src/chat/utils/utils.py
index 496e50673..bc0641c9d 100644
--- a/src/chat/utils/utils.py
+++ b/src/chat/utils/utils.py
@@ -11,7 +11,7 @@ import rjieba
from maim_message import UserInfo
from src.chat.message_receive.chat_stream import get_chat_manager
-from src.chat.message_receive.message import MessageRecv
+# MessageRecv 已被移除,现在使用 DatabaseMessages
from src.common.logger import get_logger
from src.common.message_repository import count_messages, find_messages
from src.config.config import global_config, model_config
@@ -41,34 +41,58 @@ def db_message_to_str(message_dict: dict) -> str:
return result
-def is_mentioned_bot_in_message(message: MessageRecv) -> tuple[bool, float]:
- """检查消息是否提到了机器人"""
+def is_mentioned_bot_in_message(message) -> tuple[bool, float]:
+ """检查消息是否提到了机器人
+
+ Args:
+ message: DatabaseMessages 消息对象
+
+ Returns:
+ tuple[bool, float]: (是否提及, 提及概率)
+ """
keywords = [global_config.bot.nickname]
nicknames = global_config.bot.alias_names
reply_probability = 0.0
is_at = False
is_mentioned = False
- if message.is_mentioned is not None:
- return bool(message.is_mentioned), message.is_mentioned
- if (
- message.message_info.additional_config is not None
- and message.message_info.additional_config.get("is_mentioned") is not None
- ):
+
+ # 检查 is_mentioned 属性
+ mentioned_attr = getattr(message, "is_mentioned", None)
+ if mentioned_attr is not None:
try:
- reply_probability = float(message.message_info.additional_config.get("is_mentioned")) # type: ignore
+ return bool(mentioned_attr), float(mentioned_attr)
+ except (ValueError, TypeError):
+ pass
+
+ # 检查 additional_config
+ additional_config = None
+
+ # DatabaseMessages: additional_config 是 JSON 字符串
+ if message.additional_config:
+ try:
+ import orjson
+ additional_config = orjson.loads(message.additional_config)
+ except Exception:
+ pass
+
+ if additional_config and additional_config.get("is_mentioned") is not None:
+ try:
+ reply_probability = float(additional_config.get("is_mentioned")) # type: ignore
is_mentioned = True
return is_mentioned, reply_probability
except Exception as e:
logger.warning(str(e))
logger.warning(
- f"消息中包含不合理的设置 is_mentioned: {message.message_info.additional_config.get('is_mentioned')}"
+ f"消息中包含不合理的设置 is_mentioned: {additional_config.get('is_mentioned')}"
)
- if global_config.bot.nickname in message.processed_plain_text:
+ # 检查消息文本内容
+ processed_text = message.processed_plain_text or ""
+ if global_config.bot.nickname in processed_text:
is_mentioned = True
for alias_name in global_config.bot.alias_names:
- if alias_name in message.processed_plain_text:
+ if alias_name in processed_text:
is_mentioned = True
# 判断是否被@
@@ -110,7 +134,6 @@ def is_mentioned_bot_in_message(message: MessageRecv) -> tuple[bool, float]:
logger.debug("被提及,回复概率设置为100%")
return is_mentioned, reply_probability
-
async def get_embedding(text, request_type="embedding") -> list[float] | None:
"""获取文本的embedding向量"""
# 每次都创建新的LLMRequest实例以避免事件循环冲突
diff --git a/src/common/database/db_migration.py b/src/common/database/db_migration.py
index 2ab7ba13e..fad348bf9 100644
--- a/src/common/database/db_migration.py
+++ b/src/common/database/db_migration.py
@@ -9,15 +9,18 @@ from src.common.logger import get_logger
logger = get_logger("db_migration")
-async def check_and_migrate_database():
+async def check_and_migrate_database(existing_engine=None):
"""
异步检查数据库结构并自动迁移。
- 自动创建不存在的表。
- 自动为现有表添加缺失的列。
- 自动为现有表创建缺失的索引。
+
+ Args:
+ existing_engine: 可选的已存在的数据库引擎。如果提供,将使用该引擎;否则获取全局引擎。
"""
logger.info("正在检查数据库结构并执行自动迁移...")
- engine = await get_engine()
+ engine = existing_engine if existing_engine is not None else await get_engine()
async with engine.connect() as connection:
# 在同步上下文中运行inspector操作
diff --git a/src/common/database/sqlalchemy_models.py b/src/common/database/sqlalchemy_models.py
index 9f03aa43c..287f0fc29 100644
--- a/src/common/database/sqlalchemy_models.py
+++ b/src/common/database/sqlalchemy_models.py
@@ -780,12 +780,8 @@ async def initialize_database() -> tuple["AsyncEngine", async_sessionmaker[Async
_SessionLocal = async_sessionmaker(bind=_engine, class_=AsyncSession, expire_on_commit=False)
# 迁移
- try:
- from src.common.database.db_migration import check_and_migrate_database
- await check_and_migrate_database(existing_engine=_engine)
- except TypeError:
- from src.common.database.db_migration import check_and_migrate_database as _legacy_migrate
- await _legacy_migrate()
+ from src.common.database.db_migration import check_and_migrate_database
+ await check_and_migrate_database(existing_engine=_engine)
if config.database_type == "sqlite":
await enable_sqlite_wal_mode(_engine)
diff --git a/src/mood/mood_manager.py b/src/mood/mood_manager.py
index 14f1dfef5..ef52b93a1 100644
--- a/src/mood/mood_manager.py
+++ b/src/mood/mood_manager.py
@@ -2,7 +2,6 @@ import math
import random
import time
-from src.chat.message_receive.message import MessageRecv
from src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_by_timestamp_with_chat_inclusive
from src.chat.utils.prompt import Prompt, global_prompt_manager
from src.common.data_models.database_data_model import DatabaseMessages
@@ -98,7 +97,7 @@ class ChatMood:
if not hasattr(self, "last_change_time"):
self.last_change_time = 0
- async def update_mood_by_message(self, message: MessageRecv | DatabaseMessages, interested_rate: float):
+ async def update_mood_by_message(self, message: DatabaseMessages, interested_rate: float):
# 确保异步初始化已完成
await self._initialize()
@@ -109,11 +108,8 @@ class ChatMood:
self.regression_count = 0
- # 处理不同类型的消息对象
- if isinstance(message, MessageRecv):
- message_time = message.message_info.time
- else: # DatabaseMessages
- message_time = message.time
+ # 使用 DatabaseMessages 的时间字段
+ message_time = message.time
# 防止负时间差
during_last_time = max(0, message_time - self.last_change_time)
diff --git a/src/plugin_system/apis/send_api.py b/src/plugin_system/apis/send_api.py
index 96f0e4b09..71786562d 100644
--- a/src/plugin_system/apis/send_api.py
+++ b/src/plugin_system/apis/send_api.py
@@ -86,13 +86,16 @@ async def file_to_stream(
import asyncio
import time
import traceback
-from typing import Any
+from typing import Any, TYPE_CHECKING
from maim_message import Seg, UserInfo
+if TYPE_CHECKING:
+ from src.common.data_models.database_data_model import DatabaseMessages
+
# 导入依赖
from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager
-from src.chat.message_receive.message import MessageRecv, MessageSending
+from src.chat.message_receive.message import MessageSending
from src.chat.message_receive.uni_message_sender import HeartFCSender
from src.common.logger import get_logger
from src.config.config import global_config
@@ -104,84 +107,53 @@ logger = get_logger("send_api")
_adapter_response_pool: dict[str, asyncio.Future] = {}
-def message_dict_to_message_recv(message_dict: dict[str, Any]) -> MessageRecv | None:
- """查找要回复的消息
+def message_dict_to_db_message(message_dict: dict[str, Any]) -> "DatabaseMessages | None":
+ """从消息字典构建 DatabaseMessages 对象
Args:
message_dict: 消息字典或 DatabaseMessages 对象
Returns:
- Optional[MessageRecv]: 找到的消息,如果没找到则返回None
+ Optional[DatabaseMessages]: 构建的消息对象,如果构建失败则返回None
"""
- # 兼容 DatabaseMessages 对象和字典
- if isinstance(message_dict, dict):
- user_platform = message_dict.get("user_platform", "")
- user_id = message_dict.get("user_id", "")
- user_nickname = message_dict.get("user_nickname", "")
- user_cardname = message_dict.get("user_cardname", "")
- chat_info_group_id = message_dict.get("chat_info_group_id")
- chat_info_group_platform = message_dict.get("chat_info_group_platform", "")
- chat_info_group_name = message_dict.get("chat_info_group_name", "")
- chat_info_platform = message_dict.get("chat_info_platform", "")
- message_id = message_dict.get("message_id") or message_dict.get("chat_info_message_id") or message_dict.get("id")
- time_val = message_dict.get("time")
- additional_config = message_dict.get("additional_config")
- processed_plain_text = message_dict.get("processed_plain_text")
- else:
- # DatabaseMessages 对象
- user_platform = getattr(message_dict, "user_platform", "")
- user_id = getattr(message_dict, "user_id", "")
- user_nickname = getattr(message_dict, "user_nickname", "")
- user_cardname = getattr(message_dict, "user_cardname", "")
- chat_info_group_id = getattr(message_dict, "chat_info_group_id", None)
- chat_info_group_platform = getattr(message_dict, "chat_info_group_platform", "")
- chat_info_group_name = getattr(message_dict, "chat_info_group_name", "")
- chat_info_platform = getattr(message_dict, "chat_info_platform", "")
- message_id = getattr(message_dict, "message_id", None)
- time_val = getattr(message_dict, "time", None)
- additional_config = getattr(message_dict, "additional_config", None)
- processed_plain_text = getattr(message_dict, "processed_plain_text", "")
+ from src.common.data_models.database_data_model import DatabaseMessages
- # 构建MessageRecv对象
- user_info = {
- "platform": user_platform,
- "user_id": user_id,
- "user_nickname": user_nickname,
- "user_cardname": user_cardname,
- }
-
- group_info = {}
- if chat_info_group_id:
- group_info = {
- "platform": chat_info_group_platform,
- "group_id": chat_info_group_id,
- "group_name": chat_info_group_name,
- }
-
- format_info = {"content_format": "", "accept_format": ""}
- template_info = {"template_items": {}}
-
- message_info = {
- "platform": chat_info_platform,
- "message_id": message_id,
- "time": time_val,
- "group_info": group_info,
- "user_info": user_info,
- "additional_config": additional_config,
- "format_info": format_info,
- "template_info": template_info,
- }
-
- new_message_dict = {
- "message_info": message_info,
- "raw_message": processed_plain_text,
- "processed_plain_text": processed_plain_text,
- }
-
- message_recv = MessageRecv(new_message_dict)
-
- logger.info(f"[SendAPI] 找到匹配的回复消息,发送者: {user_nickname}")
- return message_recv
+ # 如果已经是 DatabaseMessages,直接返回
+ if isinstance(message_dict, DatabaseMessages):
+ return message_dict
+
+ # 从字典提取信息
+ user_platform = message_dict.get("user_platform", "")
+ user_id = message_dict.get("user_id", "")
+ user_nickname = message_dict.get("user_nickname", "")
+ user_cardname = message_dict.get("user_cardname", "")
+ chat_info_group_id = message_dict.get("chat_info_group_id")
+ chat_info_group_platform = message_dict.get("chat_info_group_platform", "")
+ chat_info_group_name = message_dict.get("chat_info_group_name", "")
+ chat_info_platform = message_dict.get("chat_info_platform", "")
+ message_id = message_dict.get("message_id") or message_dict.get("chat_info_message_id") or message_dict.get("id")
+ time_val = message_dict.get("time", time.time())
+ additional_config = message_dict.get("additional_config")
+ processed_plain_text = message_dict.get("processed_plain_text", "")
+
+ # DatabaseMessages 使用扁平参数构造
+ db_message = DatabaseMessages(
+ message_id=message_id or "temp_reply_id",
+ time=time_val,
+ user_id=user_id,
+ user_nickname=user_nickname,
+ user_cardname=user_cardname,
+ user_platform=user_platform,
+ chat_info_group_id=chat_info_group_id,
+ chat_info_group_name=chat_info_group_name,
+ chat_info_group_platform=chat_info_group_platform,
+ chat_info_platform=chat_info_platform,
+ processed_plain_text=processed_plain_text,
+ additional_config=additional_config
+ )
+
+ logger.info(f"[SendAPI] 构建回复消息对象,发送者: {user_nickname}")
+ return db_message
def put_adapter_response(request_id: str, response_data: dict) -> None:
@@ -285,17 +257,17 @@ async def _send_to_target(
"message_id": "temp_reply_id", # 临时ID
"time": time.time()
}
- anchor_message = message_dict_to_message_recv(message_dict=temp_message_dict)
+ anchor_message = message_dict_to_db_message(message_dict=temp_message_dict)
else:
anchor_message = None
reply_to_platform_id = f"{target_stream.platform}:{sender_id}" if anchor_message else None
elif reply_to_message:
- anchor_message = message_dict_to_message_recv(message_dict=reply_to_message)
+ anchor_message = message_dict_to_db_message(message_dict=reply_to_message)
if anchor_message:
- anchor_message.update_chat_stream(target_stream)
+ # DatabaseMessages 不需要 update_chat_stream,它是纯数据对象
reply_to_platform_id = (
- f"{anchor_message.message_info.platform}:{anchor_message.message_info.user_info.user_id}"
+ f"{anchor_message.chat_info.platform}:{anchor_message.user_info.user_id}"
)
else:
reply_to_platform_id = None
diff --git a/src/plugin_system/base/base_command.py b/src/plugin_system/base/base_command.py
index 9cb41ed04..7076bbba6 100644
--- a/src/plugin_system/base/base_command.py
+++ b/src/plugin_system/base/base_command.py
@@ -1,10 +1,14 @@
from abc import ABC, abstractmethod
+from typing import TYPE_CHECKING
-from src.chat.message_receive.message import MessageRecv
+from src.common.data_models.database_data_model import DatabaseMessages
from src.common.logger import get_logger
from src.plugin_system.apis import send_api
from src.plugin_system.base.component_types import ChatType, CommandInfo, ComponentType
+if TYPE_CHECKING:
+ from src.chat.message_receive.chat_stream import ChatStream
+
logger = get_logger("base_command")
@@ -29,11 +33,11 @@ class BaseCommand(ABC):
chat_type_allow: ChatType = ChatType.ALL
"""允许的聊天类型,默认为所有类型"""
- def __init__(self, message: MessageRecv, plugin_config: dict | None = None):
+ def __init__(self, message: DatabaseMessages, plugin_config: dict | None = None):
"""初始化Command组件
Args:
- message: 接收到的消息对象
+ message: 接收到的消息对象(DatabaseMessages)
plugin_config: 插件配置字典
"""
self.message = message
@@ -41,6 +45,9 @@ class BaseCommand(ABC):
self.plugin_config = plugin_config or {} # 直接存储插件配置字典
self.log_prefix = "[Command]"
+
+ # chat_stream 会在运行时被 bot.py 设置
+ self.chat_stream: "ChatStream | None" = None
# 从类属性获取chat_type_allow设置
self.chat_type_allow = getattr(self.__class__, "chat_type_allow", ChatType.ALL)
@@ -49,7 +56,7 @@ class BaseCommand(ABC):
# 验证聊天类型限制
if not self._validate_chat_type():
- is_group = hasattr(self.message, "is_group_message") and self.message.is_group_message
+ is_group = message.group_info is not None
logger.warning(
f"{self.log_prefix} Command '{self.command_name}' 不支持当前聊天类型: "
f"{'群聊' if is_group else '私聊'}, 允许类型: {self.chat_type_allow.value}"
@@ -72,8 +79,8 @@ class BaseCommand(ABC):
if self.chat_type_allow == ChatType.ALL:
return True
- # 检查是否为群聊消息
- is_group = self.message.message_info.group_info
+ # 检查是否为群聊消息(DatabaseMessages使用group_info来判断)
+ is_group = self.message.group_info is not None
if self.chat_type_allow == ChatType.GROUP and is_group:
return True
@@ -137,12 +144,11 @@ class BaseCommand(ABC):
bool: 是否发送成功
"""
# 获取聊天流信息
- chat_stream = self.message.chat_stream
- if not chat_stream or not hasattr(chat_stream, "stream_id"):
+ if not self.chat_stream or not hasattr(self.chat_stream, "stream_id"):
logger.error(f"{self.log_prefix} 缺少聊天流或stream_id")
return False
- return await send_api.text_to_stream(text=content, stream_id=chat_stream.stream_id, reply_to=reply_to)
+ return await send_api.text_to_stream(text=content, stream_id=self.chat_stream.stream_id, reply_to=reply_to)
async def send_type(
self, message_type: str, content: str, display_message: str = "", typing: bool = False, reply_to: str = ""
@@ -160,15 +166,14 @@ class BaseCommand(ABC):
bool: 是否发送成功
"""
# 获取聊天流信息
- chat_stream = self.message.chat_stream
- if not chat_stream or not hasattr(chat_stream, "stream_id"):
+ if not self.chat_stream or not hasattr(self.chat_stream, "stream_id"):
logger.error(f"{self.log_prefix} 缺少聊天流或stream_id")
return False
return await send_api.custom_to_stream(
message_type=message_type,
content=content,
- stream_id=chat_stream.stream_id,
+ stream_id=self.chat_stream.stream_id,
display_message=display_message,
typing=typing,
reply_to=reply_to,
@@ -190,8 +195,7 @@ class BaseCommand(ABC):
"""
try:
# 获取聊天流信息
- chat_stream = self.message.chat_stream
- if not chat_stream or not hasattr(chat_stream, "stream_id"):
+ if not self.chat_stream or not hasattr(self.chat_stream, "stream_id"):
logger.error(f"{self.log_prefix} 缺少聊天流或stream_id")
return False
@@ -200,7 +204,7 @@ class BaseCommand(ABC):
success = await send_api.command_to_stream(
command=command_data,
- stream_id=chat_stream.stream_id,
+ stream_id=self.chat_stream.stream_id,
storage_message=storage_message,
display_message=display_message,
)
@@ -225,12 +229,11 @@ class BaseCommand(ABC):
Returns:
bool: 是否发送成功
"""
- chat_stream = self.message.chat_stream
- if not chat_stream or not hasattr(chat_stream, "stream_id"):
+ if not self.chat_stream or not hasattr(self.chat_stream, "stream_id"):
logger.error(f"{self.log_prefix} 缺少聊天流或stream_id")
return False
- return await send_api.emoji_to_stream(emoji_base64, chat_stream.stream_id)
+ return await send_api.emoji_to_stream(emoji_base64, self.chat_stream.stream_id)
async def send_image(self, image_base64: str) -> bool:
"""发送图片
@@ -241,12 +244,11 @@ class BaseCommand(ABC):
Returns:
bool: 是否发送成功
"""
- chat_stream = self.message.chat_stream
- if not chat_stream or not hasattr(chat_stream, "stream_id"):
+ if not self.chat_stream or not hasattr(self.chat_stream, "stream_id"):
logger.error(f"{self.log_prefix} 缺少聊天流或stream_id")
return False
- return await send_api.image_to_stream(image_base64, chat_stream.stream_id)
+ return await send_api.image_to_stream(image_base64, self.chat_stream.stream_id)
@classmethod
def get_command_info(cls) -> "CommandInfo":
diff --git a/src/plugin_system/base/plus_command.py b/src/plugin_system/base/plus_command.py
index e442d76c1..b53846fc2 100644
--- a/src/plugin_system/base/plus_command.py
+++ b/src/plugin_system/base/plus_command.py
@@ -5,8 +5,9 @@
import re
from abc import ABC, abstractmethod
+from typing import TYPE_CHECKING
-from src.chat.message_receive.message import MessageRecv
+from src.common.data_models.database_data_model import DatabaseMessages
from src.common.logger import get_logger
from src.config.config import global_config
from src.plugin_system.apis import send_api
@@ -14,6 +15,9 @@ from src.plugin_system.base.base_command import BaseCommand
from src.plugin_system.base.command_args import CommandArgs
from src.plugin_system.base.component_types import ChatType, ComponentType, PlusCommandInfo
+if TYPE_CHECKING:
+ from src.chat.message_receive.chat_stream import ChatStream
+
logger = get_logger("plus_command")
@@ -50,23 +54,26 @@ class PlusCommand(ABC):
intercept_message: bool = False
"""是否拦截消息,不进行后续处理"""
- def __init__(self, message: MessageRecv, plugin_config: dict | None = None):
+ def __init__(self, message: DatabaseMessages, plugin_config: dict | None = None):
"""初始化命令组件
Args:
- message: 接收到的消息对象
+ message: 接收到的消息对象(DatabaseMessages)
plugin_config: 插件配置字典
"""
self.message = message
self.plugin_config = plugin_config or {}
self.log_prefix = "[PlusCommand]"
+
+ # chat_stream 会在运行时被 bot.py 设置
+ self.chat_stream: "ChatStream | None" = None
# 解析命令参数
self._parse_command()
# 验证聊天类型限制
if not self._validate_chat_type():
- is_group = self.message.message_info.group_info.group_id
+ is_group = message.group_info is not None
logger.warning(
f"{self.log_prefix} 命令 '{self.command_name}' 不支持当前聊天类型: "
f"{'群聊' if is_group else '私聊'}, 允许类型: {self.chat_type_allow.value}"
@@ -124,8 +131,8 @@ class PlusCommand(ABC):
if self.chat_type_allow == ChatType.ALL:
return True
- # 检查是否为群聊消息
- is_group = hasattr(self.message.message_info, "group_info") and self.message.message_info.group_info
+ # 检查是否为群聊消息(DatabaseMessages使用group_info判断)
+ is_group = self.message.group_info is not None
if self.chat_type_allow == ChatType.GROUP and is_group:
return True
@@ -152,7 +159,7 @@ class PlusCommand(ABC):
def _is_exact_command_call(self) -> bool:
"""检查是否是精确的命令调用(无参数)"""
- if not hasattr(self.message, "plain_text") or not self.message.processed_plain_text:
+ if not self.message.processed_plain_text:
return False
plain_text = self.message.processed_plain_text.strip()
@@ -218,12 +225,11 @@ class PlusCommand(ABC):
bool: 是否发送成功
"""
# 获取聊天流信息
- chat_stream = self.message.chat_stream
- if not chat_stream or not hasattr(chat_stream, "stream_id"):
+ if not self.chat_stream or not hasattr(self.chat_stream, "stream_id"):
logger.error(f"{self.log_prefix} 缺少聊天流或stream_id")
return False
- return await send_api.text_to_stream(text=content, stream_id=chat_stream.stream_id, reply_to=reply_to)
+ return await send_api.text_to_stream(text=content, stream_id=self.chat_stream.stream_id, reply_to=reply_to)
async def send_type(
self, message_type: str, content: str, display_message: str = "", typing: bool = False, reply_to: str = ""
@@ -241,15 +247,14 @@ class PlusCommand(ABC):
bool: 是否发送成功
"""
# 获取聊天流信息
- chat_stream = self.message.chat_stream
- if not chat_stream or not hasattr(chat_stream, "stream_id"):
+ if not self.chat_stream or not hasattr(self.chat_stream, "stream_id"):
logger.error(f"{self.log_prefix} 缺少聊天流或stream_id")
return False
return await send_api.custom_to_stream(
message_type=message_type,
content=content,
- stream_id=chat_stream.stream_id,
+ stream_id=self.chat_stream.stream_id,
display_message=display_message,
typing=typing,
reply_to=reply_to,
@@ -264,12 +269,11 @@ class PlusCommand(ABC):
Returns:
bool: 是否发送成功
"""
- chat_stream = self.message.chat_stream
- if not chat_stream or not hasattr(chat_stream, "stream_id"):
+ if not self.chat_stream or not hasattr(self.chat_stream, "stream_id"):
logger.error(f"{self.log_prefix} 缺少聊天流或stream_id")
return False
- return await send_api.emoji_to_stream(emoji_base64, chat_stream.stream_id)
+ return await send_api.emoji_to_stream(emoji_base64, self.chat_stream.stream_id)
async def send_image(self, image_base64: str) -> bool:
"""发送图片
@@ -280,12 +284,11 @@ class PlusCommand(ABC):
Returns:
bool: 是否发送成功
"""
- chat_stream = self.message.chat_stream
- if not chat_stream or not hasattr(chat_stream, "stream_id"):
+ if not self.chat_stream or not hasattr(self.chat_stream, "stream_id"):
logger.error(f"{self.log_prefix} 缺少聊天流或stream_id")
return False
- return await send_api.image_to_stream(image_base64, chat_stream.stream_id)
+ return await send_api.image_to_stream(image_base64, self.chat_stream.stream_id)
@classmethod
def get_plus_command_info(cls) -> "PlusCommandInfo":
@@ -340,12 +343,12 @@ class PlusCommandAdapter(BaseCommand):
将PlusCommand适配到现有的插件系统,继承BaseCommand
"""
- def __init__(self, plus_command_class, message: MessageRecv, plugin_config: dict | None = None):
+ def __init__(self, plus_command_class, message: DatabaseMessages, plugin_config: dict | None = None):
"""初始化适配器
Args:
plus_command_class: PlusCommand子类
- message: 消息对象
+ message: 消息对象(DatabaseMessages)
plugin_config: 插件配置
"""
# 先设置必要的类属性
@@ -400,7 +403,7 @@ def create_plus_command_adapter(plus_command_class):
command_pattern = plus_command_class._generate_command_pattern()
chat_type_allow = getattr(plus_command_class, "chat_type_allow", ChatType.ALL)
- def __init__(self, message: MessageRecv, plugin_config: dict | None = None):
+ def __init__(self, message: DatabaseMessages, plugin_config: dict | None = None):
super().__init__(message, plugin_config)
self.plus_command = plus_command_class(message, plugin_config)
self.priority = getattr(plus_command_class, "priority", 0)
diff --git a/src/plugins/built_in/affinity_flow_chatter/plan_executor.py b/src/plugins/built_in/affinity_flow_chatter/plan_executor.py
index 3af389f9f..e150e7e62 100644
--- a/src/plugins/built_in/affinity_flow_chatter/plan_executor.py
+++ b/src/plugins/built_in/affinity_flow_chatter/plan_executor.py
@@ -410,11 +410,9 @@ class ChatterPlanExecutor:
)
# 添加到chat_stream的已读消息中
- if hasattr(chat_stream, "stream_context") and chat_stream.stream_context:
- chat_stream.stream_context.history_messages.append(bot_message)
- logger.debug(f"机器人回复已添加到已读消息: {reply_content[:50]}...")
- else:
- logger.warning("chat_stream没有stream_context,无法添加已读消息")
+ chat_stream.context_manager.context.history_messages.append(bot_message)
+ logger.debug(f"机器人回复已添加到已读消息: {reply_content[:50]}...")
+
except Exception as e:
logger.error(f"添加机器人回复到已读消息时出错: {e}")
diff --git a/src/plugins/built_in/affinity_flow_chatter/proactive_thinking_event.py b/src/plugins/built_in/affinity_flow_chatter/proactive_thinking_event.py
index 2a719da83..7b8ffdad1 100644
--- a/src/plugins/built_in/affinity_flow_chatter/proactive_thinking_event.py
+++ b/src/plugins/built_in/affinity_flow_chatter/proactive_thinking_event.py
@@ -96,7 +96,7 @@ class ProactiveThinkingMessageHandler(BaseEventHandler):
"""处理消息事件
Args:
- kwargs: 事件参数,格式为 {"message": MessageRecv}
+ kwargs: 事件参数,格式为 {"message": DatabaseMessages}
Returns:
HandlerResult: 处理结果
@@ -104,7 +104,7 @@ class ProactiveThinkingMessageHandler(BaseEventHandler):
if not kwargs:
return HandlerResult(success=True, continue_process=True, message=None)
- # 从 kwargs 中获取 MessageRecv 对象
+ # 从 kwargs 中获取 DatabaseMessages 对象
message = kwargs.get("message")
if not message or not hasattr(message, "chat_stream"):
return HandlerResult(success=True, continue_process=True, message=None)