重构消息处理并用DatabaseMessages替换MessageRecv

-更新PlusCommand以使用DatabaseMessages而不是MessageRecv。
-将消息处理逻辑重构到一个新模块message_processor.py中,以处理消息段并从消息字典中创建DatabaseMessages。
-删除了已弃用的MessageRecv类及其相关逻辑。
-调整了各种插件以适应新的DatabaseMessages结构。
-增强了消息处理功能中的错误处理和日志记录。
This commit is contained in:
Windpicker-owo
2025-10-31 19:24:58 +08:00
parent 50260818a8
commit 371041c9db
23 changed files with 1520 additions and 1801 deletions

View File

@@ -1,303 +0,0 @@
"""
关系追踪工具集成测试脚本
注意:此脚本需要在完整的应用环境中运行
建议通过 bot.py 启动后在交互式环境中测试
"""
import asyncio
async def test_user_profile_tool():
"""测试用户画像工具"""
print("\n" + "=" * 80)
print("测试 UserProfileTool")
print("=" * 80)
from src.plugins.built_in.affinity_flow_chatter.user_profile_tool import UserProfileTool
from src.common.database.sqlalchemy_database_api import db_query
from src.common.database.sqlalchemy_models import UserRelationships
tool = UserProfileTool()
print(f"✅ 工具名称: {tool.name}")
print(f" 工具描述: {tool.description}")
# 执行工具
test_user_id = "integration_test_user_001"
result = await tool.execute({
"target_user_id": test_user_id,
"user_aliases": "测试小明,TestMing,小明君",
"impression_description": "这是一个集成测试用户性格开朗活泼喜欢技术讨论对AI和编程特别感兴趣。经常提出有深度的问题。",
"preference_keywords": "AI,Python,深度学习,游戏开发,科幻小说",
"affection_score": 0.85
})
print(f"\n✅ 工具执行结果:")
print(f" 类型: {result.get('type')}")
print(f" 内容: {result.get('content')}")
# 验证数据库
db_data = await db_query(
UserRelationships,
filters={"user_id": test_user_id},
limit=1
)
if db_data:
data = db_data[0]
print(f"\n✅ 数据库验证:")
print(f" user_id: {data.get('user_id')}")
print(f" user_aliases: {data.get('user_aliases')}")
print(f" relationship_text: {data.get('relationship_text', '')[:80]}...")
print(f" preference_keywords: {data.get('preference_keywords')}")
print(f" relationship_score: {data.get('relationship_score')}")
return True
else:
print(f"\n❌ 数据库中未找到数据")
return False
async def test_chat_stream_impression_tool():
"""测试聊天流印象工具"""
print("\n" + "=" * 80)
print("测试 ChatStreamImpressionTool")
print("=" * 80)
from src.plugins.built_in.affinity_flow_chatter.chat_stream_impression_tool import ChatStreamImpressionTool
from src.common.database.sqlalchemy_database_api import db_query
from src.common.database.sqlalchemy_models import ChatStreams, get_db_session
# 准备测试数据:先创建一条 ChatStreams 记录
test_stream_id = "integration_test_stream_001"
print(f"🔧 准备测试数据:创建聊天流记录 {test_stream_id}")
import time
current_time = time.time()
async with get_db_session() as session:
new_stream = ChatStreams(
stream_id=test_stream_id,
create_time=current_time,
last_active_time=current_time,
platform="QQ",
user_platform="QQ",
user_id="test_user_123",
user_nickname="测试用户",
group_name="测试技术交流群",
group_platform="QQ",
group_id="test_group_456",
stream_impression_text="", # 初始为空
stream_chat_style="",
stream_topic_keywords="",
stream_interest_score=0.5
)
session.add(new_stream)
await session.commit()
print(f"✅ 测试聊天流记录已创建")
tool = ChatStreamImpressionTool()
print(f"✅ 工具名称: {tool.name}")
print(f" 工具描述: {tool.description}")
# 执行工具
result = await tool.execute({
"stream_id": test_stream_id,
"impression_description": "这是一个技术交流群成员主要是程序员和AI爱好者。大家经常分享最新的技术文章讨论编程问题氛围友好且专业。",
"chat_style": "专业技术交流,活跃讨论,互帮互助,知识分享",
"topic_keywords": "Python开发,机器学习,AI应用,Web后端,数据分析,开源项目",
"interest_score": 0.90
})
print(f"\n✅ 工具执行结果:")
print(f" 类型: {result.get('type')}")
print(f" 内容: {result.get('content')}")
# 验证数据库
db_data = await db_query(
ChatStreams,
filters={"stream_id": test_stream_id},
limit=1
)
if db_data:
data = db_data[0]
print(f"\n✅ 数据库验证:")
print(f" stream_id: {data.get('stream_id')}")
print(f" stream_impression_text: {data.get('stream_impression_text', '')[:80]}...")
print(f" stream_chat_style: {data.get('stream_chat_style')}")
print(f" stream_topic_keywords: {data.get('stream_topic_keywords')}")
print(f" stream_interest_score: {data.get('stream_interest_score')}")
return True
else:
print(f"\n❌ 数据库中未找到数据")
return False
async def test_relationship_info_build():
"""测试关系信息构建"""
print("\n" + "=" * 80)
print("测试关系信息构建(提示词集成)")
print("=" * 80)
from src.person_info.relationship_fetcher import relationship_fetcher_manager
test_stream_id = "integration_test_stream_001"
test_person_id = "test_person_999" # 使用一个可能不存在的ID来测试
fetcher = relationship_fetcher_manager.get_fetcher(test_stream_id)
print(f"✅ RelationshipFetcher 已创建")
# 测试聊天流印象构建
print(f"\n🔍 构建聊天流印象...")
stream_info = await fetcher.build_chat_stream_impression(test_stream_id)
if stream_info:
print(f"✅ 聊天流印象构建成功")
print(f"\n{'=' * 80}")
print(stream_info)
print(f"{'=' * 80}")
else:
print(f"⚠️ 聊天流印象为空(可能测试数据不存在)")
return True
async def cleanup_test_data():
"""清理测试数据"""
print("\n" + "=" * 80)
print("清理测试数据")
print("=" * 80)
from src.common.database.sqlalchemy_database_api import db_query
from src.common.database.sqlalchemy_models import UserRelationships, ChatStreams
try:
# 清理用户数据
await db_query(
UserRelationships,
query_type="delete",
filters={"user_id": "integration_test_user_001"}
)
print("✅ 用户测试数据已清理")
# 清理聊天流数据
await db_query(
ChatStreams,
query_type="delete",
filters={"stream_id": "integration_test_stream_001"}
)
print("✅ 聊天流测试数据已清理")
return True
except Exception as e:
print(f"⚠️ 清理失败: {e}")
return False
async def run_all_tests():
"""运行所有测试"""
print("\n" + "=" * 80)
print("关系追踪工具集成测试")
print("=" * 80)
results = {}
# 测试1
try:
results["UserProfileTool"] = await test_user_profile_tool()
except Exception as e:
print(f"\n❌ UserProfileTool 测试失败: {e}")
import traceback
traceback.print_exc()
results["UserProfileTool"] = False
# 测试2
try:
results["ChatStreamImpressionTool"] = await test_chat_stream_impression_tool()
except Exception as e:
print(f"\n❌ ChatStreamImpressionTool 测试失败: {e}")
import traceback
traceback.print_exc()
results["ChatStreamImpressionTool"] = False
# 测试3
try:
results["RelationshipFetcher"] = await test_relationship_info_build()
except Exception as e:
print(f"\n❌ RelationshipFetcher 测试失败: {e}")
import traceback
traceback.print_exc()
results["RelationshipFetcher"] = False
# 清理
try:
await cleanup_test_data()
except Exception as e:
print(f"\n⚠️ 清理测试数据失败: {e}")
# 总结
print("\n" + "=" * 80)
print("测试总结")
print("=" * 80)
passed = sum(1 for r in results.values() if r)
total = len(results)
for test_name, result in results.items():
status = "✅ 通过" if result else "❌ 失败"
print(f"{status} - {test_name}")
print(f"\n总计: {passed}/{total} 测试通过")
if passed == total:
print("\n🎉 所有测试通过!")
else:
print(f"\n⚠️ {total - passed} 个测试失败")
return passed == total
# 使用说明
print("""
============================================================================
关系追踪工具集成测试脚本
============================================================================
此脚本需要在完整的应用环境中运行。
使用方法1: 在 bot.py 中添加测试调用
-----------------------------------
在 bot.py 的 main() 函数中添加:
# 测试关系追踪工具
from tests.integration_test_relationship_tools import run_all_tests
await run_all_tests()
使用方法2: 在 Python REPL 中运行
-----------------------------------
启动 bot.py 后,在 Python 调试控制台中执行:
import asyncio
from tests.integration_test_relationship_tools import run_all_tests
asyncio.create_task(run_all_tests())
使用方法3: 直接在此文件底部运行
-----------------------------------
取消注释下面的代码,然后确保已启动应用环境
============================================================================
""")
# 如果需要直接运行(需要应用环境已启动)
if __name__ == "__main__":
print("\n⚠️ 警告: 直接运行此脚本可能会失败,因为缺少应用环境")
print("建议在 bot.py 启动后的环境中运行\n")
try:
asyncio.run(run_all_tests())
except Exception as e:
print(f"\n❌ 测试失败: {e}")
print("\n建议:")
print("1. 确保已启动 bot.py")
print("2. 在 Python 调试控制台中运行测试")
print("3. 或在 bot.py 中添加测试调用")

View File

@@ -27,6 +27,6 @@
"venvPath": ".",
"venv": ".venv",
"executionEnvironments": [
{"root": "src"}
{"root": "."}
]
}

View File

@@ -6,7 +6,7 @@
import re
from src.chat.message_receive.message import MessageRecv
from src.common.data_models.database_data_model import DatabaseMessages
from src.common.logger import get_logger
logger = get_logger("anti_injector.message_processor")
@@ -15,7 +15,7 @@ logger = get_logger("anti_injector.message_processor")
class MessageProcessor:
"""消息内容处理器"""
def extract_text_content(self, message: MessageRecv) -> str:
def extract_text_content(self, message: DatabaseMessages) -> str:
"""提取消息中的文本内容,过滤掉引用的历史内容
Args:
@@ -64,7 +64,7 @@ class MessageProcessor:
return new_content
@staticmethod
def check_whitelist(message: MessageRecv, whitelist: list) -> tuple | None:
def check_whitelist(message: DatabaseMessages, whitelist: list) -> tuple | None:
"""检查用户白名单
Args:
@@ -74,8 +74,8 @@ class MessageProcessor:
Returns:
如果在白名单中返回结果元组否则返回None
"""
user_id = message.message_info.user_info.user_id
platform = message.message_info.platform
user_id = message.user_info.user_id
platform = message.chat_info.platform
# 检查用户白名单:格式为 [[platform, user_id], ...]
for whitelist_entry in whitelist:

View File

@@ -29,7 +29,6 @@ class SingleStreamContextManager:
# 配置参数
self.max_context_size = max_context_size or getattr(global_config.chat, "max_context_size", 100)
self.context_ttl = getattr(global_config.chat, "context_ttl", 24 * 3600) # 24小时
# 元数据
self.created_time = time.time()
@@ -93,27 +92,24 @@ class SingleStreamContextManager:
return True
else:
logger.warning(f"消息缓存系统添加失败,回退到直接添加: {self.stream_id}")
except ImportError:
logger.debug("MessageManager不可用使用直接添加模式")
except Exception as e:
logger.warning(f"消息缓存系统异常,回退到直接添加: {self.stream_id}, error={e}")
# 回退方案:直接添加到未读消息
message.is_read = False
self.context.unread_messages.append(message)
# 回退方案:直接添加到未读消息
message.is_read = False
self.context.unread_messages.append(message)
# 自动检测和更新chat type
self._detect_chat_type(message)
# 自动检测和更新chat type
self._detect_chat_type(message)
# 在上下文管理器中计算兴趣值
await self._calculate_message_interest(message)
self.total_messages += 1
self.last_access_time = time.time()
# 启动流的循环任务(如果还未启动)
asyncio.create_task(stream_loop_manager.start_stream_loop(self.stream_id))
logger.debug(f"添加消息{message.processed_plain_text}到单流上下文: {self.stream_id}")
return True
# 在上下文管理器中计算兴趣值
await self._calculate_message_interest(message)
self.total_messages += 1
self.last_access_time = time.time()
# 启动流的循环任务(如果还未启动)
asyncio.create_task(stream_loop_manager.start_stream_loop(self.stream_id))
logger.debug(f"添加消息{message.processed_plain_text}到单流上下文: {self.stream_id}")
return True
except Exception as e:
logger.error(f"添加消息到单流上下文失败 {self.stream_id}: {e}", exc_info=True)
return False

View File

@@ -71,14 +71,6 @@ class MessageManager:
except Exception as e:
logger.error(f"启动批量数据库写入器失败: {e}")
# 启动流缓存管理器
try:
from src.chat.message_manager.stream_cache_manager import init_stream_cache_manager
await init_stream_cache_manager()
except Exception as e:
logger.error(f"启动流缓存管理器失败: {e}")
# 启动消息缓存系统(内置)
logger.info("📦 消息缓存系统已启动")
@@ -116,15 +108,6 @@ class MessageManager:
except Exception as e:
logger.error(f"停止批量数据库写入器失败: {e}")
# 停止流缓存管理器
try:
from src.chat.message_manager.stream_cache_manager import shutdown_stream_cache_manager
await shutdown_stream_cache_manager()
logger.info("🗄️ 流缓存管理器已停止")
except Exception as e:
logger.error(f"停止流缓存管理器失败: {e}")
# 停止消息缓存系统(内置)
self.message_caches.clear()
self.stream_processing_status.clear()
@@ -152,7 +135,7 @@ class MessageManager:
# 检查是否为notice消息
if self._is_notice_message(message):
# Notice消息处理 - 添加到全局管理器
logger.info(f"📢 检测到notice消息: message_id={message.message_id}, is_notify={message.is_notify}, notice_type={getattr(message, 'notice_type', None)}")
logger.info(f"📢 检测到notice消息: notice_type={getattr(message, 'notice_type', None)}")
await self._handle_notice_message(stream_id, message)
# 根据配置决定是否继续处理(触发聊天流程)
@@ -206,39 +189,6 @@ class MessageManager:
except Exception as e:
logger.error(f"更新消息 {message_id} 时发生错误: {e}")
async def bulk_update_messages(self, stream_id: str, updates: list[dict[str, Any]]) -> int:
"""批量更新消息信息,降低更新频率"""
if not updates:
return 0
try:
chat_manager = get_chat_manager()
chat_stream = await chat_manager.get_stream(stream_id)
if not chat_stream:
logger.warning(f"MessageManager.bulk_update_messages: 聊天流 {stream_id} 不存在")
return 0
updated_count = 0
for item in updates:
message_id = item.get("message_id")
if not message_id:
continue
payload = {key: value for key, value in item.items() if key != "message_id" and value is not None}
if not payload:
continue
success = await chat_stream.context_manager.update_message(message_id, payload)
if success:
updated_count += 1
if updated_count:
logger.debug(f"批量更新消息 {updated_count} 条 (stream={stream_id})")
return updated_count
except Exception as e:
logger.error(f"批量更新聊天流 {stream_id} 消息失败: {e}")
return 0
async def add_action(self, stream_id: str, message_id: str, action: str):
"""添加动作到消息"""
@@ -266,7 +216,7 @@ class MessageManager:
logger.warning(f"停用流失败: 聊天流 {stream_id} 不存在")
return
context = chat_stream.stream_context
context = chat_stream.context_manager.context
context.is_active = False
# 取消处理任务
@@ -288,7 +238,7 @@ class MessageManager:
logger.warning(f"激活流失败: 聊天流 {stream_id} 不存在")
return
context = chat_stream.stream_context
context = chat_stream.context_manager.context
context.is_active = True
logger.info(f"激活聊天流: {stream_id}")
@@ -304,7 +254,7 @@ class MessageManager:
if not chat_stream:
return None
context = chat_stream.stream_context
context = chat_stream.context_manager.context
unread_count = len(chat_stream.context_manager.get_unread_messages())
return StreamStats(
@@ -447,7 +397,7 @@ class MessageManager:
await asyncio.sleep(0.1)
# 获取当前的stream context
context = chat_stream.stream_context
context = chat_stream.context_manager.context
# 确保有未读消息需要处理
unread_messages = context.get_unread_messages()

View File

@@ -1,377 +0,0 @@
"""
流缓存管理器 - 使用优化版聊天流和智能缓存策略
提供分层缓存和自动清理功能
"""
import asyncio
import time
from collections import OrderedDict
from dataclasses import dataclass
from maim_message import GroupInfo, UserInfo
from src.chat.message_receive.optimized_chat_stream import OptimizedChatStream, create_optimized_chat_stream
from src.common.logger import get_logger
logger = get_logger("stream_cache_manager")
@dataclass
class StreamCacheStats:
"""缓存统计信息"""
hot_cache_size: int = 0
warm_storage_size: int = 0
cold_storage_size: int = 0
total_memory_usage: int = 0 # 估算的内存使用(字节)
cache_hits: int = 0
cache_misses: int = 0
evictions: int = 0
last_cleanup_time: float = 0
class TieredStreamCache:
"""分层流缓存管理器"""
def __init__(
self,
max_hot_size: int = 100,
max_warm_size: int = 500,
max_cold_size: int = 2000,
cleanup_interval: float = 300.0, # 5分钟清理一次
hot_timeout: float = 1800.0, # 30分钟未访问降级到warm
warm_timeout: float = 7200.0, # 2小时未访问降级到cold
cold_timeout: float = 86400.0, # 24小时未访问删除
):
self.max_hot_size = max_hot_size
self.max_warm_size = max_warm_size
self.max_cold_size = max_cold_size
self.cleanup_interval = cleanup_interval
self.hot_timeout = hot_timeout
self.warm_timeout = warm_timeout
self.cold_timeout = cold_timeout
# 三层缓存存储
self.hot_cache: OrderedDict[str, OptimizedChatStream] = OrderedDict() # 热数据LRU
self.warm_storage: dict[str, tuple[OptimizedChatStream, float]] = {} # 温数据(最后访问时间)
self.cold_storage: dict[str, tuple[OptimizedChatStream, float]] = {} # 冷数据(最后访问时间)
# 统计信息
self.stats = StreamCacheStats()
# 清理任务
self.cleanup_task: asyncio.Task | None = None
self.is_running = False
logger.info(f"分层流缓存管理器初始化完成 (hot:{max_hot_size}, warm:{max_warm_size}, cold:{max_cold_size})")
async def start(self):
"""启动缓存管理器"""
if self.is_running:
logger.warning("缓存管理器已经在运行")
return
self.is_running = True
self.cleanup_task = asyncio.create_task(self._cleanup_loop(), name="stream_cache_cleanup")
async def stop(self):
"""停止缓存管理器"""
if not self.is_running:
return
self.is_running = False
if self.cleanup_task and not self.cleanup_task.done():
self.cleanup_task.cancel()
try:
await asyncio.wait_for(self.cleanup_task, timeout=10.0)
except asyncio.TimeoutError:
logger.warning("缓存清理任务停止超时")
except Exception as e:
logger.error(f"停止缓存清理任务时出错: {e}")
logger.info("分层流缓存管理器已停止")
async def get_or_create_stream(
self,
stream_id: str,
platform: str,
user_info: UserInfo,
group_info: GroupInfo | None = None,
data: dict | None = None,
) -> OptimizedChatStream:
"""获取或创建流 - 优化版本"""
current_time = time.time()
# 1. 检查热缓存
if stream_id in self.hot_cache:
stream = self.hot_cache[stream_id]
# 移动到末尾LRU更新
self.hot_cache.move_to_end(stream_id)
self.stats.cache_hits += 1
logger.debug(f"热缓存命中: {stream_id}")
return stream.create_snapshot()
# 2. 检查温存储
if stream_id in self.warm_storage:
stream, last_access = self.warm_storage[stream_id]
self.warm_storage[stream_id] = (stream, current_time)
self.stats.cache_hits += 1
logger.debug(f"温缓存命中: {stream_id}")
# 提升到热缓存
await self._promote_to_hot(stream_id, stream)
return stream.create_snapshot()
# 3. 检查冷存储
if stream_id in self.cold_storage:
stream, last_access = self.cold_storage[stream_id]
self.cold_storage[stream_id] = (stream, current_time)
self.stats.cache_hits += 1
logger.debug(f"冷缓存命中: {stream_id}")
# 提升到温缓存
await self._promote_to_warm(stream_id, stream)
return stream.create_snapshot()
# 4. 缓存未命中,创建新流
self.stats.cache_misses += 1
stream = create_optimized_chat_stream(
stream_id=stream_id, platform=platform, user_info=user_info, group_info=group_info, data=data
)
logger.debug(f"缓存未命中,创建新流: {stream_id}")
# 添加到热缓存
await self._add_to_hot(stream_id, stream)
return stream
async def _add_to_hot(self, stream_id: str, stream: OptimizedChatStream):
"""添加到热缓存"""
# 检查是否需要驱逐
if len(self.hot_cache) >= self.max_hot_size:
await self._evict_from_hot()
self.hot_cache[stream_id] = stream
self.stats.hot_cache_size = len(self.hot_cache)
async def _promote_to_hot(self, stream_id: str, stream: OptimizedChatStream):
"""提升到热缓存"""
# 从温存储中移除
if stream_id in self.warm_storage:
del self.warm_storage[stream_id]
self.stats.warm_storage_size = len(self.warm_storage)
# 添加到热缓存
await self._add_to_hot(stream_id, stream)
logger.debug(f"{stream_id} 提升到热缓存")
async def _promote_to_warm(self, stream_id: str, stream: OptimizedChatStream):
"""提升到温缓存"""
# 从冷存储中移除
if stream_id in self.cold_storage:
del self.cold_storage[stream_id]
self.stats.cold_storage_size = len(self.cold_storage)
# 添加到温存储
if len(self.warm_storage) >= self.max_warm_size:
await self._evict_from_warm()
current_time = time.time()
self.warm_storage[stream_id] = (stream, current_time)
self.stats.warm_storage_size = len(self.warm_storage)
logger.debug(f"{stream_id} 提升到温缓存")
async def _evict_from_hot(self):
"""从热缓存驱逐最久未使用的流"""
if not self.hot_cache:
return
# LRU驱逐
stream_id, stream = self.hot_cache.popitem(last=False)
self.stats.evictions += 1
logger.debug(f"从热缓存驱逐: {stream_id}")
# 移动到温存储
if len(self.warm_storage) < self.max_warm_size:
current_time = time.time()
self.warm_storage[stream_id] = (stream, current_time)
self.stats.warm_storage_size = len(self.warm_storage)
else:
# 温存储也满了,直接删除
logger.debug(f"温存储已满,删除流: {stream_id}")
self.stats.hot_cache_size = len(self.hot_cache)
async def _evict_from_warm(self):
"""从温存储驱逐最久未使用的流"""
if not self.warm_storage:
return
# 找到最久未访问的流
oldest_stream_id = min(self.warm_storage.keys(), key=lambda k: self.warm_storage[k][1])
stream, last_access = self.warm_storage.pop(oldest_stream_id)
self.stats.evictions += 1
logger.debug(f"从温存储驱逐: {oldest_stream_id}")
# 移动到冷存储
if len(self.cold_storage) < self.max_cold_size:
current_time = time.time()
self.cold_storage[oldest_stream_id] = (stream, current_time)
self.stats.cold_storage_size = len(self.cold_storage)
else:
# 冷存储也满了,直接删除
logger.debug(f"冷存储已满,删除流: {oldest_stream_id}")
self.stats.warm_storage_size = len(self.warm_storage)
async def _cleanup_loop(self):
"""清理循环"""
logger.info("流缓存清理循环启动")
while self.is_running:
try:
await asyncio.sleep(self.cleanup_interval)
await self._perform_cleanup()
except asyncio.CancelledError:
logger.info("流缓存清理循环被取消")
break
except Exception as e:
logger.error(f"流缓存清理出错: {e}")
logger.info("流缓存清理循环结束")
async def _perform_cleanup(self):
"""执行清理操作"""
current_time = time.time()
cleanup_stats = {
"hot_to_warm": 0,
"warm_to_cold": 0,
"cold_removed": 0,
}
# 1. 检查热缓存超时
hot_to_demote = []
for stream_id, stream in self.hot_cache.items():
# 获取最后访问时间(简化:使用创建时间作为近似)
last_access = getattr(stream, "last_active_time", stream.create_time)
if current_time - last_access > self.hot_timeout:
hot_to_demote.append(stream_id)
for stream_id in hot_to_demote:
stream = self.hot_cache.pop(stream_id)
current_time_local = time.time()
self.warm_storage[stream_id] = (stream, current_time_local)
cleanup_stats["hot_to_warm"] += 1
# 2. 检查温存储超时
warm_to_demote = []
for stream_id, (stream, last_access) in self.warm_storage.items():
if current_time - last_access > self.warm_timeout:
warm_to_demote.append(stream_id)
for stream_id in warm_to_demote:
stream, last_access = self.warm_storage.pop(stream_id)
self.cold_storage[stream_id] = (stream, last_access)
cleanup_stats["warm_to_cold"] += 1
# 3. 检查冷存储超时
cold_to_remove = []
for stream_id, (stream, last_access) in self.cold_storage.items():
if current_time - last_access > self.cold_timeout:
cold_to_remove.append(stream_id)
for stream_id in cold_to_remove:
self.cold_storage.pop(stream_id)
cleanup_stats["cold_removed"] += 1
# 更新统计信息
self.stats.hot_cache_size = len(self.hot_cache)
self.stats.warm_storage_size = len(self.warm_storage)
self.stats.cold_storage_size = len(self.cold_storage)
self.stats.last_cleanup_time = current_time
# 估算内存使用(粗略估计)
self.stats.total_memory_usage = (
len(self.hot_cache) * 1024 # 每个热流约1KB
+ len(self.warm_storage) * 512 # 每个温流约512B
+ len(self.cold_storage) * 256 # 每个冷流约256B
)
if sum(cleanup_stats.values()) > 0:
logger.info(
f"缓存清理完成: {cleanup_stats['hot_to_warm']}热→温, "
f"{cleanup_stats['warm_to_cold']}温→冷, "
f"{cleanup_stats['cold_removed']}冷删除"
)
def get_stats(self) -> StreamCacheStats:
"""获取缓存统计信息"""
# 计算命中率
total_requests = self.stats.cache_hits + self.stats.cache_misses
hit_rate = self.stats.cache_hits / total_requests if total_requests > 0 else 0
stats_copy = StreamCacheStats(
hot_cache_size=self.stats.hot_cache_size,
warm_storage_size=self.stats.warm_storage_size,
cold_storage_size=self.stats.cold_storage_size,
total_memory_usage=self.stats.total_memory_usage,
cache_hits=self.stats.cache_hits,
cache_misses=self.stats.cache_misses,
evictions=self.stats.evictions,
last_cleanup_time=self.stats.last_cleanup_time,
)
# 添加命中率信息
stats_copy.hit_rate = hit_rate
return stats_copy
def clear_cache(self):
"""清空所有缓存"""
self.hot_cache.clear()
self.warm_storage.clear()
self.cold_storage.clear()
self.stats.hot_cache_size = 0
self.stats.warm_storage_size = 0
self.stats.cold_storage_size = 0
self.stats.total_memory_usage = 0
logger.info("所有缓存已清空")
async def get_stream_snapshot(self, stream_id: str) -> OptimizedChatStream | None:
"""获取流的快照(不修改缓存状态)"""
if stream_id in self.hot_cache:
return self.hot_cache[stream_id].create_snapshot()
elif stream_id in self.warm_storage:
return self.warm_storage[stream_id][0].create_snapshot()
elif stream_id in self.cold_storage:
return self.cold_storage[stream_id][0].create_snapshot()
return None
def get_cached_stream_ids(self) -> set[str]:
"""获取所有缓存的流ID"""
return set(self.hot_cache.keys()) | set(self.warm_storage.keys()) | set(self.cold_storage.keys())
# 全局缓存管理器实例
_cache_manager: TieredStreamCache | None = None
def get_stream_cache_manager() -> TieredStreamCache:
"""获取流缓存管理器实例"""
global _cache_manager
if _cache_manager is None:
_cache_manager = TieredStreamCache()
return _cache_manager
async def init_stream_cache_manager():
"""初始化流缓存管理器"""
manager = get_stream_cache_manager()
await manager.start()
async def shutdown_stream_cache_manager():
"""关闭流缓存管理器"""
manager = get_stream_cache_manager()
await manager.stop()

View File

@@ -9,10 +9,10 @@ from maim_message import UserInfo
from src.chat.antipromptinjector import initialize_anti_injector
from src.chat.message_manager import message_manager
from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager
from src.chat.message_receive.message import MessageRecv
from src.chat.message_receive.storage import MessageStorage
from src.chat.utils.prompt import create_prompt_async, global_prompt_manager
from src.chat.utils.utils import is_mentioned_bot_in_message
from src.common.data_models.database_data_model import DatabaseMessages
from src.common.logger import get_logger
from src.config.config import global_config
from src.mood.mood_manager import mood_manager # 导入情绪管理器
@@ -105,10 +105,10 @@ class ChatBot:
self._started = True
async def _process_plus_commands(self, message: MessageRecv):
async def _process_plus_commands(self, message: DatabaseMessages, chat: ChatStream):
"""独立处理PlusCommand系统"""
try:
text = message.processed_plain_text
text = message.processed_plain_text or ""
# 获取配置的命令前缀
from src.config.config import global_config
@@ -166,10 +166,10 @@ class ChatBot:
# 检查命令是否被禁用
if (
message.chat_stream
and message.chat_stream.stream_id
chat
and chat.stream_id
and plus_command_name
in global_announcement_manager.get_disabled_chat_commands(message.chat_stream.stream_id)
in global_announcement_manager.get_disabled_chat_commands(chat.stream_id)
):
logger.info("用户禁用的PlusCommand跳过处理")
return False, None, True
@@ -181,11 +181,14 @@ class ChatBot:
# 创建PlusCommand实例
plus_command_instance = plus_command_class(message, plugin_config)
# 为插件实例设置 chat_stream 运行时属性
setattr(plus_command_instance, "chat_stream", chat)
try:
# 检查聊天类型限制
if not plus_command_instance.is_chat_type_allowed():
is_group = message.message_info.group_info
is_group = chat.group_info is not None
logger.info(
f"PlusCommand {plus_command_class.__name__} 不支持当前聊天类型: {'群聊' if is_group else '私聊'}"
)
@@ -225,11 +228,11 @@ class ChatBot:
logger.error(f"处理PlusCommand时出错: {e}")
return False, None, True # 出错时继续处理消息
async def _process_commands_with_new_system(self, message: MessageRecv):
async def _process_commands_with_new_system(self, message: DatabaseMessages, chat: ChatStream):
# sourcery skip: use-named-expression
"""使用新插件系统处理命令"""
try:
text = message.processed_plain_text
text = message.processed_plain_text or ""
# 使用新的组件注册中心查找命令
command_result = component_registry.find_command_by_text(text)
@@ -238,10 +241,10 @@ class ChatBot:
plugin_name = command_info.plugin_name
command_name = command_info.name
if (
message.chat_stream
and message.chat_stream.stream_id
chat
and chat.stream_id
and command_name
in global_announcement_manager.get_disabled_chat_commands(message.chat_stream.stream_id)
in global_announcement_manager.get_disabled_chat_commands(chat.stream_id)
):
logger.info("用户禁用的命令,跳过处理")
return False, None, True
@@ -254,11 +257,14 @@ class ChatBot:
# 创建命令实例
command_instance: BaseCommand = command_class(message, plugin_config)
command_instance.set_matched_groups(matched_groups)
# 为插件实例设置 chat_stream 运行时属性
setattr(command_instance, "chat_stream", chat)
try:
# 检查聊天类型限制
if not command_instance.is_chat_type_allowed():
is_group = message.message_info.group_info
is_group = chat.group_info is not None
logger.info(
f"命令 {command_class.__name__} 不支持当前聊天类型: {'群聊' if is_group else '私聊'}"
)
@@ -295,13 +301,20 @@ class ChatBot:
logger.error(f"处理命令时出错: {e}")
return False, None, True # 出错时继续处理消息
async def handle_notice_message(self, message: MessageRecv):
async def handle_notice_message(self, message: DatabaseMessages):
"""处理notice消息
notice消息是系统事件通知如禁言、戳一戳等具有以下特点
1. 默认不触发聊天流程,只记录
2. 可通过配置开启触发聊天流程
3. 会在提示词中展示
Args:
message: DatabaseMessages 对象
Returns:
bool: True表示notice已完整处理需要存储并终止后续流程
False表示不是notice或notice需要继续处理触发聊天流程
"""
# 检查是否是notice消息
if message.is_notify:
@@ -309,53 +322,42 @@ class ChatBot:
# 根据配置决定是否触发聊天流程
if not global_config.notice.enable_notice_trigger_chat:
logger.debug("notice消息不触发聊天流程配置已关闭")
return True # 返回True表示已处理,不继续后续流程
logger.debug("notice消息不触发聊天流程配置已关闭,将存储后终止")
return True # 返回True:需要在调用处存储并终止
else:
logger.debug("notice消息触发聊天流程配置已开启")
return False # 返回False表示继续处理,触发聊天流程
logger.debug("notice消息触发聊天流程配置已开启,继续处理")
return False # 返回False:继续正常流程,作为普通消息处理
# 兼容旧的notice判断方式
if message.message_info.message_id == "notice":
message.is_notify = True
if message.message_id == "notice":
# 为 DatabaseMessages 设置 is_notify 运行时属性
from src.chat.message_receive.message_processor import set_db_message_runtime_attr
set_db_message_runtime_attr(message, "is_notify", True)
logger.info("旧格式notice消息")
# 同样根据配置决定
if not global_config.notice.enable_notice_trigger_chat:
return True
logger.debug("旧格式notice消息不触发聊天流程将存储后终止")
return True # 需要存储并终止
else:
return False
logger.debug("旧格式notice消息触发聊天流程继续处理")
return False # 继续正常流程
# 处理适配器响应消息
if hasattr(message, "message_segment") and message.message_segment:
if message.message_segment.type == "adapter_response":
await self.handle_adapter_response(message)
return True
elif message.message_segment.type == "adapter_command":
# 适配器命令消息不需要进一步处理
logger.debug("收到适配器命令消息,跳过后续处理")
return True
# DatabaseMessages 不再有 message_segment适配器响应处理已在消息处理阶段完成
# 这里保留逻辑以防万一,但实际上不会再执行到
return False # 不是notice消息继续正常流程
return False
async def handle_adapter_response(self, message: MessageRecv):
"""处理适配器命令响应"""
async def handle_adapter_response(self, message: DatabaseMessages):
"""处理适配器命令响应
注意: 此方法目前未被调用,但保留以备将来使用
"""
try:
from src.plugin_system.apis.send_api import put_adapter_response
seg_data = message.message_segment.data
if isinstance(seg_data, dict):
request_id = seg_data.get("request_id")
response_data = seg_data.get("response")
else:
request_id = None
response_data = None
if request_id and response_data:
logger.debug(f"收到适配器响应: request_id={request_id}")
put_adapter_response(request_id, response_data)
else:
logger.warning("适配器响应消息格式不正确")
# DatabaseMessages 使用 message_segments 字段存储消息段
# 注意: 这可能需要根据实际使用情况进行调整
logger.warning("handle_adapter_response 方法被调用,但目前未实现对 DatabaseMessages 的支持")
except Exception as e:
logger.error(f"处理适配器响应时出错: {e}")
@@ -381,9 +383,6 @@ class ChatBot:
await self._ensure_started()
# 控制握手等消息可能缺少 message_info这里直接跳过避免 KeyError
if not isinstance(message_data, dict):
logger.warning(f"收到无法解析的消息类型: {type(message_data)},已跳过")
return
message_info = message_data.get("message_info")
if not isinstance(message_info, dict):
logger.debug(
@@ -392,8 +391,6 @@ class ChatBot:
)
return
platform = message_info.get("platform")
if message_info.get("group_info") is not None:
message_info["group_info"]["group_id"] = str(
message_info["group_info"]["group_id"]
@@ -404,74 +401,94 @@ class ChatBot:
)
# print(message_data)
# logger.debug(str(message_data))
message = MessageRecv(message_data)
group_info = message.message_info.group_info
user_info = message.message_info.user_info
if message.message_info.additional_config:
sent_message = message.message_info.additional_config.get("echo", False)
# 先提取基础信息检查是否是自身消息上报
from maim_message import BaseMessageInfo
temp_message_info = BaseMessageInfo.from_dict(message_data.get("message_info", {}))
if temp_message_info.additional_config:
sent_message = temp_message_info.additional_config.get("echo", False)
if sent_message: # 这一段只是为了在一切处理前劫持上报的自身消息用于更新message_id需要ada支持上报事件实际测试中不会对正常使用造成任何问题
await MessageStorage.update_message(message)
# 直接使用消息字典更新,不再需要创建 MessageRecv
await MessageStorage.update_message(message_data)
return
group_info = temp_message_info.group_info
user_info = temp_message_info.user_info
get_chat_manager().register_message(message)
# 获取或创建聊天流
chat = await get_chat_manager().get_or_create_stream(
platform=message.message_info.platform, # type: ignore
platform=temp_message_info.platform, # type: ignore
user_info=user_info, # type: ignore
group_info=group_info,
)
message.update_chat_stream(chat)
# 处理消息内容,生成纯文本
await message.process()
# 使用新的消息处理器直接生成 DatabaseMessages
from src.chat.message_receive.message_processor import process_message_from_dict
message = await process_message_from_dict(
message_dict=message_data,
stream_id=chat.stream_id,
platform=chat.platform
)
# 填充聊天流时间信息
message.chat_info.create_time = chat.create_time
message.chat_info.last_active_time = chat.last_active_time
# 注册消息到聊天管理器
get_chat_manager().register_message(message)
# 检测是否提及机器人
message.is_mentioned, _ = is_mentioned_bot_in_message(message)
# 在这里打印[所见]日志,确保在所有处理和过滤之前记录
chat_name = chat.group_info.group_name if chat.group_info else "私聊"
if message.message_info.user_info:
logger.info(
f"[{chat_name}]{message.message_info.user_info.user_nickname}:{message.processed_plain_text}\u001b[0m"
)
user_nickname = message.user_info.user_nickname if message.user_info else "未知用户"
logger.info(
f"[{chat_name}]{user_nickname}:{message.processed_plain_text}\u001b[0m"
)
# 在此添加硬编码过滤,防止回复图片处理失败的消息
failure_keywords = ["[表情包(描述生成失败)]", "[图片(描述生成失败)]"]
if any(keyword in message.processed_plain_text for keyword in failure_keywords):
logger.info(f"[硬编码过滤] 检测到媒体内容处理失败({message.processed_plain_text}),消息被静默处理。")
processed_text = message.processed_plain_text or ""
if any(keyword in processed_text for keyword in failure_keywords):
logger.info(f"[硬编码过滤] 检测到媒体内容处理失败({processed_text}),消息被静默处理。")
return
# 处理notice消息
# notice_handled=True: 表示notice不触发聊天需要在此存储并终止
# notice_handled=False: 表示notice触发聊天或不是notice继续正常流程
notice_handled = await self.handle_notice_message(message)
if notice_handled:
# notice消息已处理,使用统一的转换方法
# notice消息不触发聊天流程,在此进行存储和记录后终止
try:
# 直接转换为 DatabaseMessages
db_message = message.to_database_message()
# message 已经是 DatabaseMessages,直接使用
# 添加到message_manager这会将notice添加到全局notice管理器
await message_manager.add_message(message.chat_stream.stream_id, db_message)
logger.info(f"✅ Notice消息已添加到message_manager: type={message.notice_type}, stream={message.chat_stream.stream_id}")
await message_manager.add_message(chat.stream_id, message)
logger.info(f"✅ Notice消息已添加到message_manager: type={message.notice_type}, stream={chat.stream_id}")
except Exception as e:
logger.error(f"Notice消息添加到message_manager失败: {e}", exc_info=True)
# 存储后直接返回
await MessageStorage.store_message(message, chat)
logger.debug("notice消息已存储,跳过后续处理")
# 存储notice消息到数据库需要更新 storage.py 支持 DatabaseMessages
# 暂时跳过存储,等待更新 storage.py
logger.debug("notice消息已添加到message_manager存储功能待更新")
return
# 如果notice_handled=False则继续执行后续流程
# 对于启用触发聊天的notice会在后续的正常流程中被存储和处理
# 过滤检查
# DatabaseMessages 使用 display_message 作为原始消息表示
raw_text = message.display_message or message.processed_plain_text or ""
if _check_ban_words(message.processed_plain_text, chat, user_info) or _check_ban_regex( # type: ignore
message.raw_message, # type: ignore
raw_text,
chat,
user_info, # type: ignore
):
return
# 命令处理 - 首先尝试PlusCommand独立处理
is_plus_command, plus_cmd_result, plus_continue_process = await self._process_plus_commands(message)
is_plus_command, plus_cmd_result, plus_continue_process = await self._process_plus_commands(message, chat)
# 如果是PlusCommand且不需要继续处理则直接返回
if is_plus_command and not plus_continue_process:
@@ -481,7 +498,7 @@ class ChatBot:
# 如果不是PlusCommand尝试传统的BaseCommand处理
if not is_plus_command:
is_command, cmd_result, continue_process = await self._process_commands_with_new_system(message)
is_command, cmd_result, continue_process = await self._process_commands_with_new_system(message, chat)
# 如果是命令且不需要继续处理,则直接返回
if is_command and not continue_process:
@@ -493,24 +510,14 @@ class ChatBot:
if result and not result.all_continue_process():
raise UserWarning(f"插件{result.get_summary().get('stopped_handlers', '')}于消息到达时取消了消息处理")
# TODO:暂不可用
# TODO:暂不可用 - DatabaseMessages 不再有 message_info.template_info
# 确认从接口发来的message是否有自定义的prompt模板信息
if message.message_info.template_info and not message.message_info.template_info.template_default:
template_group_name: str | None = message.message_info.template_info.template_name # type: ignore
template_items = message.message_info.template_info.template_items
async with global_prompt_manager.async_message_scope(template_group_name):
if isinstance(template_items, dict):
for k in template_items.keys():
await create_prompt_async(template_items[k], k)
logger.debug(f"注册{template_items[k]},{k}")
else:
template_group_name = None
# 这个功能需要在 adapter 层通过 additional_config 传递
template_group_name = None
async def preprocess():
# 使用统一的转换方法创建数据库消息对象
db_message = message.to_database_message()
group_info = getattr(message.chat_stream, "group_info", None)
# message 已经是 DatabaseMessages直接使用
group_info = chat.group_info
# 先交给消息管理器处理,计算兴趣度等衍生数据
try:
@@ -527,31 +534,15 @@ class ChatBot:
should_process_in_manager = False
if should_process_in_manager:
await message_manager.add_message(message.chat_stream.stream_id, db_message)
logger.debug(f"消息已添加到消息管理器: {message.chat_stream.stream_id}")
await message_manager.add_message(chat.stream_id, message)
logger.debug(f"消息已添加到消息管理器: {chat.stream_id}")
except Exception as e:
logger.error(f"消息添加到消息管理器失败: {e}")
# 将兴趣度结果同步回原始消息,便于后续流程使用
message.interest_value = getattr(db_message, "interest_value", getattr(message, "interest_value", 0.0))
setattr(
message,
"should_reply",
getattr(db_message, "should_reply", getattr(message, "should_reply", False)),
)
setattr(message, "should_act", getattr(db_message, "should_act", getattr(message, "should_act", False)))
# 存储消息到数据库,只进行一次写入
try:
await MessageStorage.store_message(message, message.chat_stream)
logger.debug(
"消息已存储到数据库: %s (interest=%.3f, should_reply=%s, should_act=%s)",
message.message_info.message_id,
getattr(message, "interest_value", -1.0),
getattr(message, "should_reply", None),
getattr(message, "should_act", None),
)
await MessageStorage.store_message(message, chat)
except Exception as e:
logger.error(f"存储消息到数据库失败: {e}")
traceback.print_exc()
@@ -560,13 +551,13 @@ class ChatBot:
try:
if global_config.mood.enable_mood:
# 获取兴趣度用于情绪更新
interest_rate = getattr(message, "interest_value", 0.0)
interest_rate = message.interest_value
if interest_rate is None:
interest_rate = 0.0
logger.debug(f"开始更新情绪状态,兴趣度: {interest_rate:.2f}")
# 获取当前聊天的情绪对象并更新情绪状态
chat_mood = mood_manager.get_mood_by_chat_id(message.chat_stream.stream_id)
chat_mood = mood_manager.get_mood_by_chat_id(chat.stream_id)
await chat_mood.update_mood_by_message(message, interest_rate)
logger.debug("情绪状态更新完成")
except Exception as e:

View File

@@ -12,13 +12,10 @@ from sqlalchemy.dialects.sqlite import insert as sqlite_insert
from src.common.database.sqlalchemy_database_api import get_db_session
from src.common.database.sqlalchemy_models import ChatStreams # 新增导入
from src.common.data_models.database_data_model import DatabaseMessages
from src.common.logger import get_logger
from src.config.config import global_config # 新增导入
# 避免循环导入使用TYPE_CHECKING进行类型提示
if TYPE_CHECKING:
from .message import MessageRecv
install(extra_lines=3)
@@ -33,7 +30,7 @@ class ChatStream:
self,
stream_id: str,
platform: str,
user_info: UserInfo,
user_info: UserInfo | None = None,
group_info: GroupInfo | None = None,
data: dict | None = None,
):
@@ -46,20 +43,18 @@ class ChatStream:
self.sleep_pressure = data.get("sleep_pressure", 0.0) if data else 0.0
self.saved = False
# 使用StreamContext替代ChatMessageContext
# 创建单流上下文管理器包含StreamContext
from src.chat.message_manager.context_manager import SingleStreamContextManager
from src.common.data_models.message_manager_data_model import StreamContext
from src.plugin_system.base.component_types import ChatMode, ChatType
# 创建StreamContext
self.stream_context: StreamContext = StreamContext(
stream_id=stream_id, chat_type=ChatType.GROUP if group_info else ChatType.PRIVATE, chat_mode=ChatMode.NORMAL
)
# 创建单流上下文管理器
from src.chat.message_manager.context_manager import SingleStreamContextManager
self.context_manager: SingleStreamContextManager = SingleStreamContextManager(
stream_id=stream_id, context=self.stream_context
stream_id=stream_id,
context=StreamContext(
stream_id=stream_id,
chat_type=ChatType.GROUP if group_info else ChatType.PRIVATE,
chat_mode=ChatMode.NORMAL,
),
)
# 基础参数
@@ -88,13 +83,12 @@ class ChatStream:
new_stream._focus_energy = self._focus_energy
new_stream.no_reply_consecutive = self.no_reply_consecutive
# 复制 stream_context但跳过 processing_task
new_stream.stream_context = copy.deepcopy(self.stream_context, memo)
if hasattr(new_stream.stream_context, "processing_task"):
new_stream.stream_context.processing_task = None
# 复制 context_manager
# 复制 context_manager包含 stream_context
new_stream.context_manager = copy.deepcopy(self.context_manager, memo)
# 清理 processing_task如果存在
if hasattr(new_stream.context_manager.context, "processing_task"):
new_stream.context_manager.context.processing_task = None
return new_stream
@@ -111,11 +105,11 @@ class ChatStream:
"focus_energy": self.focus_energy,
# 基础兴趣度
"base_interest_energy": self.base_interest_energy,
# stream_context基本信息
"stream_context_chat_type": self.stream_context.chat_type.value,
"stream_context_chat_mode": self.stream_context.chat_mode.value,
# stream_context基本信息通过context_manager访问
"stream_context_chat_type": self.context_manager.context.chat_type.value,
"stream_context_chat_mode": self.context_manager.context.chat_mode.value,
# 统计信息
"interruption_count": self.stream_context.interruption_count,
"interruption_count": self.context_manager.context.interruption_count,
}
@classmethod
@@ -132,27 +126,19 @@ class ChatStream:
data=data,
)
# 恢复stream_context信息
# 恢复stream_context信息通过context_manager访问
if "stream_context_chat_type" in data:
from src.plugin_system.base.component_types import ChatMode, ChatType
instance.stream_context.chat_type = ChatType(data["stream_context_chat_type"])
instance.context_manager.context.chat_type = ChatType(data["stream_context_chat_type"])
if "stream_context_chat_mode" in data:
from src.plugin_system.base.component_types import ChatMode, ChatType
instance.stream_context.chat_mode = ChatMode(data["stream_context_chat_mode"])
instance.context_manager.context.chat_mode = ChatMode(data["stream_context_chat_mode"])
# 恢复interruption_count信息
if "interruption_count" in data:
instance.stream_context.interruption_count = data["interruption_count"]
# 确保 context_manager 已初始化
if not hasattr(instance, "context_manager"):
from src.chat.message_manager.context_manager import SingleStreamContextManager
instance.context_manager = SingleStreamContextManager(
stream_id=instance.stream_id, context=instance.stream_context
)
instance.context_manager.context.interruption_count = data["interruption_count"]
return instance
@@ -160,156 +146,44 @@ class ChatStream:
"""获取原始的、未哈希的聊天流ID字符串"""
if self.group_info:
return f"{self.platform}:{self.group_info.group_id}:group"
else:
elif self.user_info:
return f"{self.platform}:{self.user_info.user_id}:private"
else:
return f"{self.platform}:unknown:private"
def update_active_time(self):
"""更新最后活跃时间"""
self.last_active_time = time.time()
self.saved = False
async def set_context(self, message: "MessageRecv"):
"""设置聊天消息上下文"""
# 将MessageRecv转换为DatabaseMessages并设置到stream_context
import json
from src.common.data_models.database_data_model import DatabaseMessages
# 安全获取message_info中的数据
message_info = getattr(message, "message_info", {})
user_info = getattr(message_info, "user_info", {})
group_info = getattr(message_info, "group_info", {})
# 提取reply_to信息从message_segment中查找reply类型的段
reply_to = None
if hasattr(message, "message_segment") and message.message_segment:
reply_to = self._extract_reply_from_segment(message.message_segment)
# 完整的数据转移逻辑
db_message = DatabaseMessages(
# 基础消息信息
message_id=getattr(message, "message_id", ""),
time=getattr(message, "time", time.time()),
chat_id=self._generate_chat_id(message_info),
reply_to=reply_to,
# 兴趣度相关
interest_value=getattr(message, "interest_value", 0.0),
# 关键词
key_words=json.dumps(getattr(message, "key_words", []), ensure_ascii=False)
if getattr(message, "key_words", None)
else None,
key_words_lite=json.dumps(getattr(message, "key_words_lite", []), ensure_ascii=False)
if getattr(message, "key_words_lite", None)
else None,
# 消息状态标记
is_mentioned=getattr(message, "is_mentioned", None),
is_at=getattr(message, "is_at", False),
is_emoji=getattr(message, "is_emoji", False),
is_picid=getattr(message, "is_picid", False),
is_voice=getattr(message, "is_voice", False),
is_video=getattr(message, "is_video", False),
is_command=getattr(message, "is_command", False),
is_notify=getattr(message, "is_notify", False),
is_public_notice=getattr(message, "is_public_notice", False),
notice_type=getattr(message, "notice_type", None),
# 消息内容
processed_plain_text=getattr(message, "processed_plain_text", ""),
display_message=getattr(message, "processed_plain_text", ""), # 默认使用processed_plain_text
# 优先级信息
priority_mode=getattr(message, "priority_mode", None),
priority_info=json.dumps(getattr(message, "priority_info", None))
if getattr(message, "priority_info", None)
else None,
# 额外配置 - 需要将 format_info 嵌入到 additional_config 中
additional_config=self._prepare_additional_config(message_info),
# 用户信息
user_id=str(getattr(user_info, "user_id", "")),
user_nickname=getattr(user_info, "user_nickname", ""),
user_cardname=getattr(user_info, "user_cardname", None),
user_platform=getattr(user_info, "platform", ""),
# 群组信息
chat_info_group_id=getattr(group_info, "group_id", None),
chat_info_group_name=getattr(group_info, "group_name", None),
chat_info_group_platform=getattr(group_info, "platform", None),
# 聊天流信息
chat_info_user_id=str(getattr(user_info, "user_id", "")),
chat_info_user_nickname=getattr(user_info, "user_nickname", ""),
chat_info_user_cardname=getattr(user_info, "user_cardname", None),
chat_info_user_platform=getattr(user_info, "platform", ""),
chat_info_stream_id=self.stream_id,
chat_info_platform=self.platform,
chat_info_create_time=self.create_time,
chat_info_last_active_time=self.last_active_time,
# 新增兴趣度系统字段 - 添加安全处理
actions=self._safe_get_actions(message),
should_reply=getattr(message, "should_reply", False),
should_act=getattr(message, "should_act", False),
)
self.stream_context.set_current_message(db_message)
self.stream_context.priority_mode = getattr(message, "priority_mode", None)
self.stream_context.priority_info = getattr(message, "priority_info", None)
# 调试日志:记录数据转移情况
logger.debug(
f"消息数据转移完成 - message_id: {db_message.message_id}, "
f"chat_id: {db_message.chat_id}, "
f"is_mentioned: {db_message.is_mentioned}, "
f"is_emoji: {db_message.is_emoji}, "
f"is_picid: {db_message.is_picid}, "
f"interest_value: {db_message.interest_value}"
)
def _prepare_additional_config(self, message_info) -> str | None:
"""
准备 additional_config将 format_info 嵌入其中
这个方法模仿 storage.py 中的逻辑,确保 DatabaseMessages 中的 additional_config
包含 format_info使得 action_modifier 能够正确获取适配器支持的消息类型
async def set_context(self, message: DatabaseMessages):
"""设置聊天消息上下文
Args:
message_info: BaseMessageInfo 对象
Returns:
str | None: JSON 字符串格式的 additional_config如果为空则返回 None
message: DatabaseMessages 对象,直接使用不需要转换
"""
import orjson
# 首先获取adapter传递的additional_config
additional_config_data = {}
if hasattr(message_info, 'additional_config') and message_info.additional_config:
if isinstance(message_info.additional_config, dict):
additional_config_data = message_info.additional_config.copy()
elif isinstance(message_info.additional_config, str):
# 如果是字符串,尝试解析
try:
additional_config_data = orjson.loads(message_info.additional_config)
except Exception as e:
logger.warning(f"无法解析 additional_config JSON: {e}")
additional_config_data = {}
# 直接使用传入的 DatabaseMessages设置到上下文中
self.context_manager.context.set_current_message(message)
# 然后添加format_info到additional_config中
if hasattr(message_info, 'format_info') and message_info.format_info:
try:
format_info_dict = message_info.format_info.to_dict()
additional_config_data["format_info"] = format_info_dict
logger.debug(f"嵌入 format_info 到 additional_config: {format_info_dict}")
except Exception as e:
logger.warning(f"将 format_info 转换为字典失败: {e}")
else:
logger.warning(f"[问题] 消息缺少 format_info: message_id={getattr(message_info, 'message_id', 'unknown')}")
logger.warning("[问题] 这可能导致 Action 无法正确检查适配器支持的类型")
# 序列化为JSON字符串
if additional_config_data:
try:
return orjson.dumps(additional_config_data).decode("utf-8")
except Exception as e:
logger.error(f"序列化 additional_config 失败: {e}")
return None
return None
# 设置优先级信息(如果存在)
priority_mode = getattr(message, "priority_mode", None)
priority_info = getattr(message, "priority_info", None)
if priority_mode:
self.context_manager.context.priority_mode = priority_mode
if priority_info:
self.context_manager.context.priority_info = priority_info
def _safe_get_actions(self, message: "MessageRecv") -> list | None:
# 调试日志
logger.debug(
f"消息上下文已设置 - message_id: {message.message_id}, "
f"chat_id: {message.chat_id}, "
f"is_mentioned: {message.is_mentioned}, "
f"is_emoji: {message.is_emoji}, "
f"is_picid: {message.is_picid}, "
f"interest_value: {message.interest_value}"
)
def _safe_get_actions(self, message: DatabaseMessages) -> list | None:
"""安全获取消息的actions字段"""
import json
@@ -380,23 +254,6 @@ class ChatStream:
if hasattr(db_message, "should_act"):
db_message.should_act = False
def _extract_reply_from_segment(self, segment) -> str | None:
"""从消息段中提取reply_to信息"""
try:
if hasattr(segment, "type") and segment.type == "seglist":
# 递归搜索seglist中的reply段
if hasattr(segment, "data") and segment.data:
for seg in segment.data:
reply_id = self._extract_reply_from_segment(seg)
if reply_id:
return reply_id
elif hasattr(segment, "type") and segment.type == "reply":
# 找到reply段返回message_id
return str(segment.data) if segment.data else None
except Exception as e:
logger.warning(f"提取reply_to信息失败: {e}")
return None
def _generate_chat_id(self, message_info) -> str:
"""生成chat_id基于群组或用户信息"""
try:
@@ -493,8 +350,10 @@ class ChatManager:
def __init__(self):
if not self._initialized:
from src.common.data_models.database_data_model import DatabaseMessages
self.streams: dict[str, ChatStream] = {} # stream_id -> ChatStream
self.last_messages: dict[str, "MessageRecv"] = {} # stream_id -> last_message
self.last_messages: dict[str, DatabaseMessages] = {} # stream_id -> last_message
# try:
# async with get_db_session() as session:
# db.connect(reuse_if_open=True)
@@ -528,12 +387,30 @@ class ChatManager:
except Exception as e:
logger.error(f"聊天流自动保存失败: {e!s}")
def register_message(self, message: "MessageRecv"):
def register_message(self, message: DatabaseMessages):
"""注册消息到聊天流"""
# 从 DatabaseMessages 提取平台和用户/群组信息
from maim_message import UserInfo, GroupInfo
user_info = UserInfo(
platform=message.user_info.platform,
user_id=message.user_info.user_id,
user_nickname=message.user_info.user_nickname,
user_cardname=message.user_info.user_cardname or ""
)
group_info = None
if message.group_info:
group_info = GroupInfo(
platform=message.group_info.group_platform or "",
group_id=message.group_info.group_id,
group_name=message.group_info.group_name
)
stream_id = self._generate_stream_id(
message.message_info.platform, # type: ignore
message.message_info.user_info,
message.message_info.group_info,
message.chat_info.platform,
user_info,
group_info,
)
self.last_messages[stream_id] = message
# logger.debug(f"注册消息到聊天流: {stream_id}")
@@ -578,32 +455,6 @@ class ChatManager:
try:
stream_id = self._generate_stream_id(platform, user_info, group_info)
# 优先使用缓存管理器(优化版本)
try:
from src.chat.message_manager.stream_cache_manager import get_stream_cache_manager
cache_manager = get_stream_cache_manager()
if cache_manager.is_running:
optimized_stream = await cache_manager.get_or_create_stream(
stream_id=stream_id, platform=platform, user_info=user_info, group_info=group_info
)
# 设置消息上下文
from .message import MessageRecv
if stream_id in self.last_messages and isinstance(self.last_messages[stream_id], MessageRecv):
optimized_stream.set_context(self.last_messages[stream_id])
# 转换为原始ChatStream以保持兼容性
original_stream = self._convert_to_original_stream(optimized_stream)
return original_stream
except Exception as e:
logger.debug(f"缓存管理器获取流失败,使用原始方法: {e}")
# 回退到原始方法
# 检查内存中是否存在
if stream_id in self.streams:
stream = self.streams[stream_id]
@@ -615,12 +466,13 @@ class ChatManager:
stream.user_info = user_info
if group_info:
stream.group_info = group_info
from .message import MessageRecv # 延迟导入,避免循环引用
if stream_id in self.last_messages and isinstance(self.last_messages[stream_id], MessageRecv):
# 检查是否有最后一条消息(现在使用 DatabaseMessages
from src.common.data_models.database_data_model import DatabaseMessages
if stream_id in self.last_messages and isinstance(self.last_messages[stream_id], DatabaseMessages):
await stream.set_context(self.last_messages[stream_id])
else:
logger.error(f"聊天流 {stream_id} 不在最后消息列表中,可能是新创建的")
logger.debug(f"聊天流 {stream_id} 不在最后消息列表中,可能是新创建的或还没有消息")
return stream
# 检查数据库中是否存在
@@ -679,19 +531,27 @@ class ChatManager:
raise e
stream = copy.deepcopy(stream)
from .message import MessageRecv # 延迟导入,避免循环引用
from src.common.data_models.database_data_model import DatabaseMessages
if stream_id in self.last_messages and isinstance(self.last_messages[stream_id], MessageRecv):
if stream_id in self.last_messages and isinstance(self.last_messages[stream_id], DatabaseMessages):
await stream.set_context(self.last_messages[stream_id])
else:
logger.error(f"聊天流 {stream_id} 不在最后消息列表中,可能是新创建的")
logger.debug(f"聊天流 {stream_id} 不在最后消息列表中,可能是新创建的")
# 确保 ChatStream 有自己的 context_manager
if not hasattr(stream, "context_manager"):
# 创建新的单流上下文管理器
from src.chat.message_manager.context_manager import SingleStreamContextManager
from src.common.data_models.message_manager_data_model import StreamContext
from src.plugin_system.base.component_types import ChatMode, ChatType
stream.context_manager = SingleStreamContextManager(stream_id=stream_id, context=stream.stream_context)
stream.context_manager = SingleStreamContextManager(
stream_id=stream_id,
context=StreamContext(
stream_id=stream_id,
chat_type=ChatType.GROUP if stream.group_info else ChatType.PRIVATE,
chat_mode=ChatMode.NORMAL,
),
)
# 保存到内存和数据库
self.streams[stream_id] = stream
@@ -700,10 +560,12 @@ class ChatManager:
async def get_stream(self, stream_id: str) -> ChatStream | None:
"""通过stream_id获取聊天流"""
from src.common.data_models.database_data_model import DatabaseMessages
stream = self.streams.get(stream_id)
if not stream:
return None
if stream_id in self.last_messages:
if stream_id in self.last_messages and isinstance(self.last_messages[stream_id], DatabaseMessages):
await stream.set_context(self.last_messages[stream_id])
return stream
@@ -921,9 +783,16 @@ class ChatManager:
# 确保 ChatStream 有自己的 context_manager
if not hasattr(stream, "context_manager"):
from src.chat.message_manager.context_manager import SingleStreamContextManager
from src.common.data_models.message_manager_data_model import StreamContext
from src.plugin_system.base.component_types import ChatMode, ChatType
stream.context_manager = SingleStreamContextManager(
stream_id=stream.stream_id, context=stream.stream_context
stream_id=stream.stream_id,
context=StreamContext(
stream_id=stream.stream_id,
chat_type=ChatType.GROUP if stream.group_info else ChatType.PRIVATE,
chat_mode=ChatMode.NORMAL,
),
)
except Exception as e:
logger.error(f"从数据库加载所有聊天流失败 (SQLAlchemy): {e}", exc_info=True)
@@ -932,46 +801,6 @@ class ChatManager:
chat_manager = None
def _convert_to_original_stream(self, optimized_stream) -> "ChatStream":
"""将OptimizedChatStream转换为原始ChatStream以保持兼容性"""
try:
# 创建原始ChatStream实例
original_stream = ChatStream(
stream_id=optimized_stream.stream_id,
platform=optimized_stream.platform,
user_info=optimized_stream._get_effective_user_info(),
group_info=optimized_stream._get_effective_group_info(),
)
# 复制状态
original_stream.create_time = optimized_stream.create_time
original_stream.last_active_time = optimized_stream.last_active_time
original_stream.sleep_pressure = optimized_stream.sleep_pressure
original_stream.base_interest_energy = optimized_stream.base_interest_energy
original_stream._focus_energy = optimized_stream._focus_energy
original_stream.no_reply_consecutive = optimized_stream.no_reply_consecutive
original_stream.saved = optimized_stream.saved
# 复制上下文信息(如果存在)
if hasattr(optimized_stream, "_stream_context") and optimized_stream._stream_context:
original_stream.stream_context = optimized_stream._stream_context
if hasattr(optimized_stream, "_context_manager") and optimized_stream._context_manager:
original_stream.context_manager = optimized_stream._context_manager
return original_stream
except Exception as e:
logger.error(f"转换OptimizedChatStream失败: {e}")
# 如果转换失败,创建一个新的原始流
return ChatStream(
stream_id=optimized_stream.stream_id,
platform=optimized_stream.platform,
user_info=optimized_stream._get_effective_user_info(),
group_info=optimized_stream._get_effective_group_info(),
)
def get_chat_manager():
global chat_manager
if chat_manager is None:

View File

@@ -2,7 +2,7 @@ import base64
import time
from abc import ABCMeta, abstractmethod
from dataclasses import dataclass
from typing import Any, Optional
from typing import Any, Optional, Union
import urllib3
from maim_message import BaseMessageInfo, MessageBase, Seg, UserInfo
@@ -13,6 +13,7 @@ from src.chat.utils.self_voice_cache import consume_self_voice_text
from src.chat.utils.utils_image import get_image_manager
from src.chat.utils.utils_video import get_video_analyzer, is_video_analysis_available
from src.chat.utils.utils_voice import get_voice_text
from src.common.data_models.database_data_model import DatabaseMessages
from src.common.logger import get_logger
from src.config.config import global_config
@@ -43,7 +44,7 @@ class Message(MessageBase, metaclass=ABCMeta):
user_info: UserInfo,
message_segment: Seg | None = None,
timestamp: float | None = None,
reply: Optional["MessageRecv"] = None,
reply: Optional["DatabaseMessages"] = None,
processed_plain_text: str = "",
):
# 使用传入的时间戳或当前时间
@@ -95,346 +96,12 @@ class Message(MessageBase, metaclass=ABCMeta):
@dataclass
class MessageRecv(Message):
"""接收消息类用于处理从MessageCQ序列化的消息"""
def __init__(self, message_dict: dict[str, Any]):
"""从MessageCQ的字典初始化
Args:
message_dict: MessageCQ序列化后的字典
"""
# Manually initialize attributes from MessageBase and Message
self.message_info = BaseMessageInfo.from_dict(message_dict.get("message_info", {}))
self.message_segment = Seg.from_dict(message_dict.get("message_segment", {}))
self.raw_message = message_dict.get("raw_message")
self.chat_stream = None
self.reply = None
self.processed_plain_text = message_dict.get("processed_plain_text", "")
self.memorized_times = 0
# MessageRecv specific attributes
self.is_emoji = False
self.has_emoji = False
self.is_picid = False
self.has_picid = False
self.is_voice = False
self.is_video = False
self.is_mentioned = None
self.is_notify = False # 是否为notice消息
self.is_public_notice = False # 是否为公共notice
self.notice_type = None # notice类型
self.is_at = False
self.is_command = False
self.priority_mode = "interest"
self.priority_info = None
self.interest_value: float = 0.0
self.key_words = []
self.key_words_lite = []
# 解析additional_config中的notice信息
if self.message_info.additional_config and isinstance(self.message_info.additional_config, dict):
self.is_notify = self.message_info.additional_config.get("is_notice", False)
self.is_public_notice = self.message_info.additional_config.get("is_public_notice", False)
self.notice_type = self.message_info.additional_config.get("notice_type")
def update_chat_stream(self, chat_stream: "ChatStream"):
self.chat_stream = chat_stream
def to_database_message(self) -> "DatabaseMessages":
"""将 MessageRecv 转换为 DatabaseMessages 对象
Returns:
DatabaseMessages: 数据库消息对象
"""
from src.common.data_models.database_data_model import DatabaseMessages
import json
import time
message_info = self.message_info
msg_user_info = getattr(message_info, "user_info", None)
stream_user_info = getattr(self.chat_stream, "user_info", None) if self.chat_stream else None
group_info = getattr(self.chat_stream, "group_info", None) if self.chat_stream else None
message_id = message_info.message_id or ""
message_time = message_info.time if hasattr(message_info, "time") and message_info.time is not None else time.time()
is_mentioned = None
if isinstance(self.is_mentioned, bool):
is_mentioned = self.is_mentioned
elif isinstance(self.is_mentioned, int | float):
is_mentioned = self.is_mentioned != 0
# 提取用户信息
user_id = ""
user_nickname = ""
user_cardname = None
user_platform = ""
if msg_user_info:
user_id = str(getattr(msg_user_info, "user_id", "") or "")
user_nickname = getattr(msg_user_info, "user_nickname", "") or ""
user_cardname = getattr(msg_user_info, "user_cardname", None)
user_platform = getattr(msg_user_info, "platform", "") or ""
elif stream_user_info:
user_id = str(getattr(stream_user_info, "user_id", "") or "")
user_nickname = getattr(stream_user_info, "user_nickname", "") or ""
user_cardname = getattr(stream_user_info, "user_cardname", None)
user_platform = getattr(stream_user_info, "platform", "") or ""
# 提取聊天流信息
chat_user_id = str(getattr(stream_user_info, "user_id", "") or "") if stream_user_info else ""
chat_user_nickname = getattr(stream_user_info, "user_nickname", "") or "" if stream_user_info else ""
chat_user_cardname = getattr(stream_user_info, "user_cardname", None) if stream_user_info else None
chat_user_platform = getattr(stream_user_info, "platform", "") or "" if stream_user_info else ""
group_id = getattr(group_info, "group_id", None) if group_info else None
group_name = getattr(group_info, "group_name", None) if group_info else None
group_platform = getattr(group_info, "platform", None) if group_info else None
# 准备 additional_config
additional_config_str = None
try:
import orjson
additional_config_data = {}
# 首先获取adapter传递的additional_config
if hasattr(message_info, 'additional_config') and message_info.additional_config:
if isinstance(message_info.additional_config, dict):
additional_config_data = message_info.additional_config.copy()
elif isinstance(message_info.additional_config, str):
try:
additional_config_data = orjson.loads(message_info.additional_config)
except Exception as e:
logger.warning(f"无法解析 additional_config JSON: {e}")
additional_config_data = {}
# 添加notice相关标志
if self.is_notify:
additional_config_data["is_notice"] = True
additional_config_data["notice_type"] = self.notice_type or "unknown"
additional_config_data["is_public_notice"] = bool(self.is_public_notice)
# 添加format_info到additional_config中
if hasattr(message_info, 'format_info') and message_info.format_info:
try:
format_info_dict = message_info.format_info.to_dict()
additional_config_data["format_info"] = format_info_dict
logger.debug(f"[message.py] 嵌入 format_info 到 additional_config: {format_info_dict}")
except Exception as e:
logger.warning(f"将 format_info 转换为字典失败: {e}")
# 序列化为JSON字符串
if additional_config_data:
additional_config_str = orjson.dumps(additional_config_data).decode("utf-8")
except Exception as e:
logger.error(f"准备 additional_config 失败: {e}")
# 创建数据库消息对象
db_message = DatabaseMessages(
message_id=message_id,
time=float(message_time),
chat_id=self.chat_stream.stream_id if self.chat_stream else "",
processed_plain_text=self.processed_plain_text,
display_message=self.processed_plain_text,
is_mentioned=is_mentioned,
is_at=bool(self.is_at) if self.is_at is not None else None,
is_emoji=bool(self.is_emoji),
is_picid=bool(self.is_picid),
is_command=bool(self.is_command),
is_notify=bool(self.is_notify),
is_public_notice=bool(self.is_public_notice),
notice_type=self.notice_type,
additional_config=additional_config_str,
user_id=user_id,
user_nickname=user_nickname,
user_cardname=user_cardname,
user_platform=user_platform,
chat_info_stream_id=self.chat_stream.stream_id if self.chat_stream else "",
chat_info_platform=self.chat_stream.platform if self.chat_stream else "",
chat_info_create_time=float(self.chat_stream.create_time) if self.chat_stream else 0.0,
chat_info_last_active_time=float(self.chat_stream.last_active_time) if self.chat_stream else 0.0,
chat_info_user_id=chat_user_id,
chat_info_user_nickname=chat_user_nickname,
chat_info_user_cardname=chat_user_cardname,
chat_info_user_platform=chat_user_platform,
chat_info_group_id=group_id,
chat_info_group_name=group_name,
chat_info_group_platform=group_platform,
)
# 同步兴趣度等衍生属性
db_message.interest_value = getattr(self, "interest_value", 0.0)
setattr(db_message, "should_reply", getattr(self, "should_reply", False))
setattr(db_message, "should_act", getattr(self, "should_act", False))
return db_message
async def process(self) -> None:
"""处理消息内容,生成纯文本和详细文本
这个方法必须在创建实例后显式调用,因为它包含异步操作。
"""
self.processed_plain_text = await self._process_message_segments(self.message_segment)
async def _process_single_segment(self, segment: Seg) -> str:
"""处理单个消息段
Args:
segment: 消息段
Returns:
str: 处理后的文本
"""
try:
if segment.type == "text":
self.is_picid = False
self.is_emoji = False
self.is_video = False
return segment.data # type: ignore
elif segment.type == "at":
self.is_picid = False
self.is_emoji = False
self.is_video = False
# 处理at消息格式为"昵称:QQ号"
if isinstance(segment.data, str) and ":" in segment.data:
nickname, qq_id = segment.data.split(":", 1)
return f"@{nickname}"
return f"@{segment.data}" if isinstance(segment.data, str) else "@未知用户"
elif segment.type == "image":
# 如果是base64图片数据
if isinstance(segment.data, str):
self.has_picid = True
self.is_picid = True
self.is_emoji = False
self.is_video = False
image_manager = get_image_manager()
# print(f"segment.data: {segment.data}")
_, processed_text = await image_manager.process_image(segment.data)
return processed_text
return "[发了一张图片,网卡了加载不出来]"
elif segment.type == "emoji":
self.has_emoji = True
self.is_emoji = True
self.is_picid = False
self.is_voice = False
self.is_video = False
if isinstance(segment.data, str):
return await get_image_manager().get_emoji_description(segment.data)
return "[发了一个表情包,网卡了加载不出来]"
elif segment.type == "voice":
self.is_picid = False
self.is_emoji = False
self.is_voice = True
self.is_video = False
# 检查消息是否由机器人自己发送
if self.message_info and self.message_info.user_info and str(self.message_info.user_info.user_id) == str(global_config.bot.qq_account):
logger.info(f"检测到机器人自身发送的语音消息 (User ID: {self.message_info.user_info.user_id}),尝试从缓存获取文本。")
if isinstance(segment.data, str):
cached_text = consume_self_voice_text(segment.data)
if cached_text:
logger.info(f"成功从缓存中获取语音文本: '{cached_text[:70]}...'")
return f"[语音:{cached_text}]"
else:
logger.warning("机器人自身语音消息缓存未命中,将回退到标准语音识别。")
# 标准语音识别流程 (也作为缓存未命中的后备方案)
if isinstance(segment.data, str):
return await get_voice_text(segment.data)
return "[发了一段语音,网卡了加载不出来]"
elif segment.type == "mention_bot":
self.is_picid = False
self.is_emoji = False
self.is_voice = False
self.is_video = False
self.is_mentioned = float(segment.data) # type: ignore
return ""
elif segment.type == "priority_info":
self.is_picid = False
self.is_emoji = False
self.is_voice = False
if isinstance(segment.data, dict):
# 处理优先级信息
self.priority_mode = "priority"
self.priority_info = segment.data
"""
{
'message_type': 'vip', # vip or normal
'message_priority': 1.0, # 优先级大为优先float
}
"""
return ""
elif segment.type == "file":
if isinstance(segment.data, dict):
file_name = segment.data.get('name', '未知文件')
file_size = segment.data.get('size', '未知大小')
return f"[文件:{file_name} ({file_size}字节)]"
return "[收到一个文件]"
elif segment.type == "video":
self.is_picid = False
self.is_emoji = False
self.is_voice = False
self.is_video = True
logger.info(f"接收到视频消息,数据类型: {type(segment.data)}")
# 检查视频分析功能是否可用
if not is_video_analysis_available():
logger.warning("⚠️ Rust视频处理模块不可用跳过视频分析")
return "[视频]"
if global_config.video_analysis.enable:
logger.info("已启用视频识别,开始识别")
if isinstance(segment.data, dict):
try:
# 从Adapter接收的视频数据
video_base64 = segment.data.get("base64")
filename = segment.data.get("filename", "video.mp4")
logger.info(f"视频文件名: {filename}")
logger.info(f"Base64数据长度: {len(video_base64) if video_base64 else 0}")
if video_base64:
# 解码base64视频数据
video_bytes = base64.b64decode(video_base64)
logger.info(f"解码后视频大小: {len(video_bytes)} 字节")
# 使用video analyzer分析视频
video_analyzer = get_video_analyzer()
result = await video_analyzer.analyze_video_from_bytes(
video_bytes, filename, prompt=global_config.video_analysis.batch_analysis_prompt
)
logger.info(f"视频分析结果: {result}")
# 返回视频分析结果
summary = result.get("summary", "")
if summary:
return f"[视频内容] {summary}"
else:
return "[已收到视频,但分析失败]"
else:
logger.warning("视频消息中没有base64数据")
return "[收到视频消息,但数据异常]"
except Exception as e:
logger.error(f"视频处理失败: {e!s}")
import traceback
logger.error(f"错误详情: {traceback.format_exc()}")
return "[收到视频,但处理时出现错误]"
else:
logger.warning(f"视频消息数据不是字典格式: {type(segment.data)}")
return "[发了一个视频,但格式不支持]"
else:
return ""
else:
logger.warning(f"未知的消息段类型: {segment.type}")
return f"[{segment.type} 消息]"
except Exception as e:
logger.error(f"处理消息段失败: {e!s}, 类型: {segment.type}, 数据: {segment.data}")
return f"[处理失败的{segment.type}消息]"
# MessageRecv 类已被完全移除,现在统一使用 DatabaseMessages
# 如需从消息字典创建 DatabaseMessages请使用
# from src.chat.message_receive.message_processor import process_message_from_dict
#
# 迁移完成日期: 2025-10-31
@dataclass
@@ -447,7 +114,7 @@ class MessageProcessBase(Message):
chat_stream: "ChatStream",
bot_user_info: UserInfo,
message_segment: Seg | None = None,
reply: Optional["MessageRecv"] = None,
reply: Optional["DatabaseMessages"] = None,
thinking_start_time: float = 0,
timestamp: float | None = None,
):
@@ -548,7 +215,7 @@ class MessageSending(MessageProcessBase):
sender_info: UserInfo | None, # 用来记录发送者信息
message_segment: Seg,
display_message: str = "",
reply: Optional["MessageRecv"] = None,
reply: Optional["DatabaseMessages"] = None,
is_head: bool = False,
is_emoji: bool = False,
thinking_start_time: float = 0,
@@ -567,7 +234,11 @@ class MessageSending(MessageProcessBase):
# 发送状态特有属性
self.sender_info = sender_info
self.reply_to_message_id = reply.message_info.message_id if reply else None
# 从 DatabaseMessages 获取 message_id
if reply:
self.reply_to_message_id = reply.message_id
else:
self.reply_to_message_id = None
self.is_head = is_head
self.is_emoji = is_emoji
self.apply_set_reply_logic = apply_set_reply_logic
@@ -582,14 +253,18 @@ class MessageSending(MessageProcessBase):
def build_reply(self):
"""设置回复消息"""
if self.reply:
self.reply_to_message_id = self.reply.message_info.message_id
self.message_segment = Seg(
type="seglist",
data=[
Seg(type="reply", data=self.reply.message_info.message_id), # type: ignore
self.message_segment,
],
)
# 从 DatabaseMessages 获取 message_id
message_id = self.reply.message_id
if message_id:
self.reply_to_message_id = message_id
self.message_segment = Seg(
type="seglist",
data=[
Seg(type="reply", data=message_id), # type: ignore
self.message_segment,
],
)
async def process(self) -> None:
"""处理消息内容,生成纯文本和详细文本"""
@@ -607,48 +282,5 @@ class MessageSending(MessageProcessBase):
return self.message_info.group_info is None or self.message_info.group_info.group_id is None
def message_recv_from_dict(message_dict: dict) -> MessageRecv:
return MessageRecv(message_dict)
def message_from_db_dict(db_dict: dict) -> MessageRecv:
"""从数据库字典创建MessageRecv实例"""
# 转换扁平的数据库字典为嵌套结构
message_info_dict = {
"platform": db_dict.get("chat_info_platform"),
"message_id": db_dict.get("message_id"),
"time": db_dict.get("time"),
"group_info": {
"platform": db_dict.get("chat_info_group_platform"),
"group_id": db_dict.get("chat_info_group_id"),
"group_name": db_dict.get("chat_info_group_name"),
},
"user_info": {
"platform": db_dict.get("user_platform"),
"user_id": db_dict.get("user_id"),
"user_nickname": db_dict.get("user_nickname"),
"user_cardname": db_dict.get("user_cardname"),
},
}
processed_text = db_dict.get("processed_plain_text", "")
# 构建 MessageRecv 需要的字典
recv_dict = {
"message_info": message_info_dict,
"message_segment": {"type": "text", "data": processed_text}, # 从纯文本重建消息段
"raw_message": None, # 数据库中未存储原始消息
"processed_plain_text": processed_text,
}
# 创建 MessageRecv 实例
msg = MessageRecv(recv_dict)
# 从数据库字典中填充其他可选字段
msg.interest_value = db_dict.get("interest_value", 0.0)
msg.is_mentioned = db_dict.get("is_mentioned")
msg.priority_mode = db_dict.get("priority_mode", "interest")
msg.priority_info = db_dict.get("priority_info")
msg.is_emoji = db_dict.get("is_emoji", False)
msg.is_picid = db_dict.get("is_picid", False)
return msg
# message_recv_from_dictmessage_from_db_dict 函数已被移除
# 请使用: from src.chat.message_receive.message_processor import process_message_from_dict

View File

@@ -0,0 +1,493 @@
"""消息处理工具模块
将原 MessageRecv 的消息处理逻辑提取为独立函数,
直接从适配器消息字典生成 DatabaseMessages
"""
import base64
import time
from typing import Any
import orjson
from maim_message import BaseMessageInfo, Seg
from src.chat.utils.self_voice_cache import consume_self_voice_text
from src.chat.utils.utils_image import get_image_manager
from src.chat.utils.utils_video import get_video_analyzer, is_video_analysis_available
from src.chat.utils.utils_voice import get_voice_text
from src.common.data_models.database_data_model import DatabaseMessages
from src.common.logger import get_logger
from src.config.config import global_config
logger = get_logger("message_processor")
async def process_message_from_dict(message_dict: dict[str, Any], stream_id: str, platform: str) -> DatabaseMessages:
"""从适配器消息字典处理并生成 DatabaseMessages
这个函数整合了原 MessageRecv 的所有处理逻辑:
1. 解析 message_segment 并异步处理内容(图片、语音、视频等)
2. 提取所有消息元数据
3. 直接构造 DatabaseMessages 对象
Args:
message_dict: MessageCQ序列化后的字典
stream_id: 聊天流ID
platform: 平台标识
Returns:
DatabaseMessages: 处理完成的数据库消息对象
"""
# 解析基础信息
message_info = BaseMessageInfo.from_dict(message_dict.get("message_info", {}))
message_segment = Seg.from_dict(message_dict.get("message_segment", {}))
# 初始化处理状态
processing_state = {
"is_emoji": False,
"has_emoji": False,
"is_picid": False,
"has_picid": False,
"is_voice": False,
"is_video": False,
"is_mentioned": None,
"is_at": False,
"priority_mode": "interest",
"priority_info": None,
}
# 异步处理消息段,生成纯文本
processed_plain_text = await _process_message_segments(message_segment, processing_state, message_info)
# 解析 notice 信息
is_notify = False
is_public_notice = False
notice_type = None
if message_info.additional_config and isinstance(message_info.additional_config, dict):
is_notify = message_info.additional_config.get("is_notice", False)
is_public_notice = message_info.additional_config.get("is_public_notice", False)
notice_type = message_info.additional_config.get("notice_type")
# 提取用户信息
user_info = message_info.user_info
user_id = str(user_info.user_id) if user_info and user_info.user_id else ""
user_nickname = (user_info.user_nickname or "") if user_info else ""
user_cardname = user_info.user_cardname if user_info else None
user_platform = (user_info.platform or "") if user_info else ""
# 提取群组信息
group_info = message_info.group_info
group_id = group_info.group_id if group_info else None
group_name = group_info.group_name if group_info else None
group_platform = group_info.platform if group_info else None
# 生成 chat_id
if group_info and group_id:
chat_id = f"{platform}_{group_id}"
elif user_info and user_info.user_id:
chat_id = f"{platform}_{user_info.user_id}_private"
else:
chat_id = stream_id
# 准备 additional_config
additional_config_str = _prepare_additional_config(message_info, is_notify, is_public_notice, notice_type)
# 提取 reply_to
reply_to = _extract_reply_from_segment(message_segment)
# 构造 DatabaseMessages
message_time = message_info.time if hasattr(message_info, "time") and message_info.time is not None else time.time()
message_id = message_info.message_id or ""
# 处理 is_mentioned
is_mentioned = None
mentioned_value = processing_state.get("is_mentioned")
if isinstance(mentioned_value, bool):
is_mentioned = mentioned_value
elif isinstance(mentioned_value, (int, float)):
is_mentioned = mentioned_value != 0
db_message = DatabaseMessages(
message_id=message_id,
time=float(message_time),
chat_id=chat_id,
reply_to=reply_to,
processed_plain_text=processed_plain_text,
display_message=processed_plain_text,
is_mentioned=is_mentioned,
is_at=bool(processing_state.get("is_at", False)),
is_emoji=bool(processing_state.get("is_emoji", False)),
is_picid=bool(processing_state.get("is_picid", False)),
is_command=False, # 将在后续处理中设置
is_notify=bool(is_notify),
is_public_notice=bool(is_public_notice),
notice_type=notice_type,
additional_config=additional_config_str,
user_id=user_id,
user_nickname=user_nickname,
user_cardname=user_cardname,
user_platform=user_platform,
chat_info_stream_id=stream_id,
chat_info_platform=platform,
chat_info_create_time=0.0, # 将由 ChatStream 填充
chat_info_last_active_time=0.0, # 将由 ChatStream 填充
chat_info_user_id=user_id,
chat_info_user_nickname=user_nickname,
chat_info_user_cardname=user_cardname,
chat_info_user_platform=user_platform,
chat_info_group_id=group_id,
chat_info_group_name=group_name,
chat_info_group_platform=group_platform,
)
# 设置优先级信息
if processing_state.get("priority_mode"):
setattr(db_message, "priority_mode", processing_state["priority_mode"])
if processing_state.get("priority_info"):
setattr(db_message, "priority_info", processing_state["priority_info"])
# 设置其他运行时属性
setattr(db_message, "is_voice", bool(processing_state.get("is_voice", False)))
setattr(db_message, "is_video", bool(processing_state.get("is_video", False)))
setattr(db_message, "has_emoji", bool(processing_state.get("has_emoji", False)))
setattr(db_message, "has_picid", bool(processing_state.get("has_picid", False)))
return db_message
async def _process_message_segments(segment: Seg, state: dict, message_info: BaseMessageInfo) -> str:
"""递归处理消息段,转换为文字描述
Args:
segment: 要处理的消息段
state: 处理状态字典(用于记录消息类型标记)
message_info: 消息基础信息(用于某些处理逻辑)
Returns:
str: 处理后的文本
"""
if segment.type == "seglist":
# 处理消息段列表
segments_text = []
for seg in segment.data:
processed = await _process_message_segments(seg, state, message_info)
if processed:
segments_text.append(processed)
return " ".join(segments_text)
else:
# 处理单个消息段
return await _process_single_segment(segment, state, message_info)
async def _process_single_segment(segment: Seg, state: dict, message_info: BaseMessageInfo) -> str:
"""处理单个消息段
Args:
segment: 消息段
state: 处理状态字典
message_info: 消息基础信息
Returns:
str: 处理后的文本
"""
try:
if segment.type == "text":
state["is_picid"] = False
state["is_emoji"] = False
state["is_video"] = False
return segment.data
elif segment.type == "at":
state["is_picid"] = False
state["is_emoji"] = False
state["is_video"] = False
state["is_at"] = True
# 处理at消息格式为"昵称:QQ号"
if isinstance(segment.data, str) and ":" in segment.data:
nickname, qq_id = segment.data.split(":", 1)
return f"@{nickname}"
return f"@{segment.data}" if isinstance(segment.data, str) else "@未知用户"
elif segment.type == "image":
# 如果是base64图片数据
if isinstance(segment.data, str):
state["has_picid"] = True
state["is_picid"] = True
state["is_emoji"] = False
state["is_video"] = False
image_manager = get_image_manager()
_, processed_text = await image_manager.process_image(segment.data)
return processed_text
return "[发了一张图片,网卡了加载不出来]"
elif segment.type == "emoji":
state["has_emoji"] = True
state["is_emoji"] = True
state["is_picid"] = False
state["is_voice"] = False
state["is_video"] = False
if isinstance(segment.data, str):
return await get_image_manager().get_emoji_description(segment.data)
return "[发了一个表情包,网卡了加载不出来]"
elif segment.type == "voice":
state["is_picid"] = False
state["is_emoji"] = False
state["is_voice"] = True
state["is_video"] = False
# 检查消息是否由机器人自己发送
if message_info and message_info.user_info and str(message_info.user_info.user_id) == str(global_config.bot.qq_account):
logger.info(f"检测到机器人自身发送的语音消息 (User ID: {message_info.user_info.user_id}),尝试从缓存获取文本。")
if isinstance(segment.data, str):
cached_text = consume_self_voice_text(segment.data)
if cached_text:
logger.info(f"成功从缓存中获取语音文本: '{cached_text[:70]}...'")
return f"[语音:{cached_text}]"
else:
logger.warning("机器人自身语音消息缓存未命中,将回退到标准语音识别。")
# 标准语音识别流程
if isinstance(segment.data, str):
return await get_voice_text(segment.data)
return "[发了一段语音,网卡了加载不出来]"
elif segment.type == "mention_bot":
state["is_picid"] = False
state["is_emoji"] = False
state["is_voice"] = False
state["is_video"] = False
state["is_mentioned"] = float(segment.data)
return ""
elif segment.type == "priority_info":
state["is_picid"] = False
state["is_emoji"] = False
state["is_voice"] = False
if isinstance(segment.data, dict):
# 处理优先级信息
state["priority_mode"] = "priority"
state["priority_info"] = segment.data
return ""
elif segment.type == "file":
if isinstance(segment.data, dict):
file_name = segment.data.get('name', '未知文件')
file_size = segment.data.get('size', '未知大小')
return f"[文件:{file_name} ({file_size}字节)]"
return "[收到一个文件]"
elif segment.type == "video":
state["is_picid"] = False
state["is_emoji"] = False
state["is_voice"] = False
state["is_video"] = True
logger.info(f"接收到视频消息,数据类型: {type(segment.data)}")
# 检查视频分析功能是否可用
if not is_video_analysis_available():
logger.warning("⚠️ Rust视频处理模块不可用跳过视频分析")
return "[视频]"
if global_config.video_analysis.enable:
logger.info("已启用视频识别,开始识别")
if isinstance(segment.data, dict):
try:
# 从Adapter接收的视频数据
video_base64 = segment.data.get("base64")
filename = segment.data.get("filename", "video.mp4")
logger.info(f"视频文件名: {filename}")
logger.info(f"Base64数据长度: {len(video_base64) if video_base64 else 0}")
if video_base64:
# 解码base64视频数据
video_bytes = base64.b64decode(video_base64)
logger.info(f"解码后视频大小: {len(video_bytes)} 字节")
# 使用video analyzer分析视频
video_analyzer = get_video_analyzer()
result = await video_analyzer.analyze_video_from_bytes(
video_bytes, filename, prompt=global_config.video_analysis.batch_analysis_prompt
)
logger.info(f"视频分析结果: {result}")
# 返回视频分析结果
summary = result.get("summary", "")
if summary:
return f"[视频内容] {summary}"
else:
return "[已收到视频,但分析失败]"
else:
logger.warning("视频消息中没有base64数据")
return "[收到视频消息,但数据异常]"
except Exception as e:
logger.error(f"视频处理失败: {e!s}")
import traceback
logger.error(f"错误详情: {traceback.format_exc()}")
return "[收到视频,但处理时出现错误]"
else:
logger.warning(f"视频消息数据不是字典格式: {type(segment.data)}")
return "[发了一个视频,但格式不支持]"
else:
return ""
else:
logger.warning(f"未知的消息段类型: {segment.type}")
return f"[{segment.type} 消息]"
except Exception as e:
logger.error(f"处理消息段失败: {e!s}, 类型: {segment.type}, 数据: {segment.data}")
return f"[处理失败的{segment.type}消息]"
def _prepare_additional_config(message_info: BaseMessageInfo, is_notify: bool, is_public_notice: bool, notice_type: str | None) -> str | None:
"""准备 additional_config包含 format_info 和 notice 信息
Args:
message_info: 消息基础信息
is_notify: 是否为notice消息
is_public_notice: 是否为公共notice
notice_type: notice类型
Returns:
str | None: JSON 字符串格式的 additional_config如果为空则返回 None
"""
try:
additional_config_data = {}
# 首先获取adapter传递的additional_config
if hasattr(message_info, 'additional_config') and message_info.additional_config:
if isinstance(message_info.additional_config, dict):
additional_config_data = message_info.additional_config.copy()
elif isinstance(message_info.additional_config, str):
try:
additional_config_data = orjson.loads(message_info.additional_config)
except Exception as e:
logger.warning(f"无法解析 additional_config JSON: {e}")
additional_config_data = {}
# 添加notice相关标志
if is_notify:
additional_config_data["is_notice"] = True
additional_config_data["notice_type"] = notice_type or "unknown"
additional_config_data["is_public_notice"] = bool(is_public_notice)
# 添加format_info到additional_config中
if hasattr(message_info, 'format_info') and message_info.format_info:
try:
format_info_dict = message_info.format_info.to_dict()
additional_config_data["format_info"] = format_info_dict
logger.debug(f"[message_processor] 嵌入 format_info 到 additional_config: {format_info_dict}")
except Exception as e:
logger.warning(f"将 format_info 转换为字典失败: {e}")
# 序列化为JSON字符串
if additional_config_data:
return orjson.dumps(additional_config_data).decode("utf-8")
except Exception as e:
logger.error(f"准备 additional_config 失败: {e}")
return None
def _extract_reply_from_segment(segment: Seg) -> str | None:
"""从消息段中提取reply_to信息
Args:
segment: 消息段
Returns:
str | None: 回复的消息ID如果没有则返回None
"""
try:
if hasattr(segment, "type") and segment.type == "seglist":
# 递归搜索seglist中的reply段
if hasattr(segment, "data") and segment.data:
for seg in segment.data:
reply_id = _extract_reply_from_segment(seg)
if reply_id:
return reply_id
elif hasattr(segment, "type") and segment.type == "reply":
# 找到reply段返回message_id
return str(segment.data) if segment.data else None
except Exception as e:
logger.warning(f"提取reply_to信息失败: {e}")
return None
# =============================================================================
# DatabaseMessages 扩展工具函数
# =============================================================================
def get_message_info_from_db_message(db_message: DatabaseMessages) -> BaseMessageInfo:
"""从 DatabaseMessages 重建 BaseMessageInfo用于需要 message_info 的遗留代码)
Args:
db_message: DatabaseMessages 对象
Returns:
BaseMessageInfo: 重建的消息信息对象
"""
from maim_message import UserInfo, GroupInfo
# 从 DatabaseMessages 的 user_info 转换为 maim_message.UserInfo
user_info = UserInfo(
platform=db_message.user_info.platform,
user_id=db_message.user_info.user_id,
user_nickname=db_message.user_info.user_nickname,
user_cardname=db_message.user_info.user_cardname or ""
)
# 从 DatabaseMessages 的 group_info 转换为 maim_message.GroupInfo如果存在
group_info = None
if db_message.group_info:
group_info = GroupInfo(
platform=db_message.group_info.group_platform or "",
group_id=db_message.group_info.group_id,
group_name=db_message.group_info.group_name
)
# 解析 additional_config从 JSON 字符串到字典)
additional_config = None
if db_message.additional_config:
try:
additional_config = orjson.loads(db_message.additional_config)
except Exception:
# 如果解析失败,保持为字符串
pass
# 创建 BaseMessageInfo
message_info = BaseMessageInfo(
platform=db_message.chat_info.platform,
message_id=db_message.message_id,
time=db_message.time,
user_info=user_info,
group_info=group_info,
additional_config=additional_config # type: ignore
)
return message_info
def set_db_message_runtime_attr(db_message: DatabaseMessages, attr_name: str, value: Any) -> None:
"""安全地为 DatabaseMessages 设置运行时属性
Args:
db_message: DatabaseMessages 对象
attr_name: 属性名
value: 属性值
"""
setattr(db_message, attr_name, value)
def get_db_message_runtime_attr(db_message: DatabaseMessages, attr_name: str, default: Any = None) -> Any:
"""安全地获取 DatabaseMessages 的运行时属性
Args:
db_message: DatabaseMessages 对象
attr_name: 属性名
default: 默认值
Returns:
属性值或默认值
"""
return getattr(db_message, attr_name, default)

View File

@@ -0,0 +1,434 @@
# MessageRecv 类备份 - 已从 message.py 中移除
# 备份日期: 2025-10-31
# 此类已被 DatabaseMessages 完全取代
# MessageRecv 类已被移除
# 现在所有消息处理都使用 DatabaseMessages
# 如果需要从消息字典创建 DatabaseMessages请使用 message_processor.process_message_from_dict()
#
# 历史参考: MessageRecv 曾经是接收消息的包装类,现已被 DatabaseMessages 完全取代
# 迁移完成日期: 2025-10-31
"""
# 以下是已删除的 MessageRecv 类(保留作为参考)
class MessageRecv:
接收消息类 - DatabaseMessages 的轻量级包装器
这个类现在主要作为适配器层,处理外部消息格式并内部使用 DatabaseMessages。
保留此类是为了向后兼容性和处理 message_segment 的异步逻辑。
"""
def __init__(self, message_dict: dict[str, Any]):
"""从MessageCQ的字典初始化
Args:
message_dict: MessageCQ序列化后的字典
"""
# 保留原始消息信息用于某些场景
self.message_info = BaseMessageInfo.from_dict(message_dict.get("message_info", {}))
self.message_segment = Seg.from_dict(message_dict.get("message_segment", {}))
self.raw_message = message_dict.get("raw_message")
# 处理状态(在process()之前临时使用)
self._processing_state = {
"is_emoji": False,
"has_emoji": False,
"is_picid": False,
"has_picid": False,
"is_voice": False,
"is_video": False,
"is_mentioned": None,
"is_at": False,
"priority_mode": "interest",
"priority_info": None,
}
self.chat_stream = None
self.reply = None
self.processed_plain_text = message_dict.get("processed_plain_text", "")
# 解析additional_config中的notice信息
self.is_notify = False
self.is_public_notice = False
self.notice_type = None
if self.message_info.additional_config and isinstance(self.message_info.additional_config, dict):
self.is_notify = self.message_info.additional_config.get("is_notice", False)
self.is_public_notice = self.message_info.additional_config.get("is_public_notice", False)
self.notice_type = self.message_info.additional_config.get("notice_type")
# 兼容性属性 - 代理到 _processing_state
@property
def is_emoji(self) -> bool:
return self._processing_state["is_emoji"]
@is_emoji.setter
def is_emoji(self, value: bool):
self._processing_state["is_emoji"] = value
@property
def has_emoji(self) -> bool:
return self._processing_state["has_emoji"]
@has_emoji.setter
def has_emoji(self, value: bool):
self._processing_state["has_emoji"] = value
@property
def is_picid(self) -> bool:
return self._processing_state["is_picid"]
@is_picid.setter
def is_picid(self, value: bool):
self._processing_state["is_picid"] = value
@property
def has_picid(self) -> bool:
return self._processing_state["has_picid"]
@has_picid.setter
def has_picid(self, value: bool):
self._processing_state["has_picid"] = value
@property
def is_voice(self) -> bool:
return self._processing_state["is_voice"]
@is_voice.setter
def is_voice(self, value: bool):
self._processing_state["is_voice"] = value
@property
def is_video(self) -> bool:
return self._processing_state["is_video"]
@is_video.setter
def is_video(self, value: bool):
self._processing_state["is_video"] = value
@property
def is_mentioned(self):
return self._processing_state["is_mentioned"]
@is_mentioned.setter
def is_mentioned(self, value):
self._processing_state["is_mentioned"] = value
@property
def is_at(self) -> bool:
return self._processing_state["is_at"]
@is_at.setter
def is_at(self, value: bool):
self._processing_state["is_at"] = value
@property
def priority_mode(self) -> str:
return self._processing_state["priority_mode"]
@priority_mode.setter
def priority_mode(self, value: str):
self._processing_state["priority_mode"] = value
@property
def priority_info(self):
return self._processing_state["priority_info"]
@priority_info.setter
def priority_info(self, value):
self._processing_state["priority_info"] = value
# 其他常用属性
interest_value: float = 0.0
is_command: bool = False
memorized_times: int = 0
def __post_init__(self):
"""dataclass 初始化后处理"""
self.key_words = []
self.key_words_lite = []
def update_chat_stream(self, chat_stream: "ChatStream"):
self.chat_stream = chat_stream
def to_database_message(self) -> "DatabaseMessages":
"""将 MessageRecv 转换为 DatabaseMessages 对象
Returns:
DatabaseMessages: 数据库消息对象
"""
import time
message_info = self.message_info
msg_user_info = getattr(message_info, "user_info", None)
stream_user_info = getattr(self.chat_stream, "user_info", None) if self.chat_stream else None
group_info = getattr(self.chat_stream, "group_info", None) if self.chat_stream else None
message_id = message_info.message_id or ""
message_time = message_info.time if hasattr(message_info, "time") and message_info.time is not None else time.time()
is_mentioned = None
if isinstance(self.is_mentioned, bool):
is_mentioned = self.is_mentioned
elif isinstance(self.is_mentioned, int | float):
is_mentioned = self.is_mentioned != 0
# 提取用户信息
user_id = ""
user_nickname = ""
user_cardname = None
user_platform = ""
if msg_user_info:
user_id = str(getattr(msg_user_info, "user_id", "") or "")
user_nickname = getattr(msg_user_info, "user_nickname", "") or ""
user_cardname = getattr(msg_user_info, "user_cardname", None)
user_platform = getattr(msg_user_info, "platform", "") or ""
elif stream_user_info:
user_id = str(getattr(stream_user_info, "user_id", "") or "")
user_nickname = getattr(stream_user_info, "user_nickname", "") or ""
user_cardname = getattr(stream_user_info, "user_cardname", None)
user_platform = getattr(stream_user_info, "platform", "") or ""
# 提取聊天流信息
chat_user_id = str(getattr(stream_user_info, "user_id", "") or "") if stream_user_info else ""
chat_user_nickname = getattr(stream_user_info, "user_nickname", "") or "" if stream_user_info else ""
chat_user_cardname = getattr(stream_user_info, "user_cardname", None) if stream_user_info else None
chat_user_platform = getattr(stream_user_info, "platform", "") or "" if stream_user_info else ""
group_id = getattr(group_info, "group_id", None) if group_info else None
group_name = getattr(group_info, "group_name", None) if group_info else None
group_platform = getattr(group_info, "platform", None) if group_info else None
# 准备 additional_config
additional_config_str = None
try:
import orjson
additional_config_data = {}
# 首先获取adapter传递的additional_config
if hasattr(message_info, 'additional_config') and message_info.additional_config:
if isinstance(message_info.additional_config, dict):
additional_config_data = message_info.additional_config.copy()
elif isinstance(message_info.additional_config, str):
try:
additional_config_data = orjson.loads(message_info.additional_config)
except Exception as e:
logger.warning(f"无法解析 additional_config JSON: {e}")
additional_config_data = {}
# 添加notice相关标志
if self.is_notify:
additional_config_data["is_notice"] = True
additional_config_data["notice_type"] = self.notice_type or "unknown"
additional_config_data["is_public_notice"] = bool(self.is_public_notice)
# 添加format_info到additional_config中
if hasattr(message_info, 'format_info') and message_info.format_info:
try:
format_info_dict = message_info.format_info.to_dict()
additional_config_data["format_info"] = format_info_dict
logger.debug(f"[message.py] 嵌入 format_info 到 additional_config: {format_info_dict}")
except Exception as e:
logger.warning(f"将 format_info 转换为字典失败: {e}")
# 序列化为JSON字符串
if additional_config_data:
additional_config_str = orjson.dumps(additional_config_data).decode("utf-8")
except Exception as e:
logger.error(f"准备 additional_config 失败: {e}")
# 创建数据库消息对象
db_message = DatabaseMessages(
message_id=message_id,
time=float(message_time),
chat_id=self.chat_stream.stream_id if self.chat_stream else "",
processed_plain_text=self.processed_plain_text,
display_message=self.processed_plain_text,
is_mentioned=is_mentioned,
is_at=bool(self.is_at) if self.is_at is not None else None,
is_emoji=bool(self.is_emoji),
is_picid=bool(self.is_picid),
is_command=bool(self.is_command),
is_notify=bool(self.is_notify),
is_public_notice=bool(self.is_public_notice),
notice_type=self.notice_type,
additional_config=additional_config_str,
user_id=user_id,
user_nickname=user_nickname,
user_cardname=user_cardname,
user_platform=user_platform,
chat_info_stream_id=self.chat_stream.stream_id if self.chat_stream else "",
chat_info_platform=self.chat_stream.platform if self.chat_stream else "",
chat_info_create_time=float(self.chat_stream.create_time) if self.chat_stream else 0.0,
chat_info_last_active_time=float(self.chat_stream.last_active_time) if self.chat_stream else 0.0,
chat_info_user_id=chat_user_id,
chat_info_user_nickname=chat_user_nickname,
chat_info_user_cardname=chat_user_cardname,
chat_info_user_platform=chat_user_platform,
chat_info_group_id=group_id,
chat_info_group_name=group_name,
chat_info_group_platform=group_platform,
)
# 同步兴趣度等衍生属性
db_message.interest_value = getattr(self, "interest_value", 0.0)
setattr(db_message, "should_reply", getattr(self, "should_reply", False))
setattr(db_message, "should_act", getattr(self, "should_act", False))
return db_message
async def process(self) -> None:
"""处理消息内容,生成纯文本和详细文本
这个方法必须在创建实例后显式调用,因为它包含异步操作。
"""
self.processed_plain_text = await self._process_message_segments(self.message_segment)
async def _process_single_segment(self, segment: Seg) -> str:
"""处理单个消息段
Args:
segment: 消息段
Returns:
str: 处理后的文本
"""
try:
if segment.type == "text":
self.is_picid = False
self.is_emoji = False
self.is_video = False
return segment.data # type: ignore
elif segment.type == "at":
self.is_picid = False
self.is_emoji = False
self.is_video = False
# 处理at消息格式为"昵称:QQ号"
if isinstance(segment.data, str) and ":" in segment.data:
nickname, qq_id = segment.data.split(":", 1)
return f"@{nickname}"
return f"@{segment.data}" if isinstance(segment.data, str) else "@未知用户"
elif segment.type == "image":
# 如果是base64图片数据
if isinstance(segment.data, str):
self.has_picid = True
self.is_picid = True
self.is_emoji = False
self.is_video = False
image_manager = get_image_manager()
# print(f"segment.data: {segment.data}")
_, processed_text = await image_manager.process_image(segment.data)
return processed_text
return "[发了一张图片,网卡了加载不出来]"
elif segment.type == "emoji":
self.has_emoji = True
self.is_emoji = True
self.is_picid = False
self.is_voice = False
self.is_video = False
if isinstance(segment.data, str):
return await get_image_manager().get_emoji_description(segment.data)
return "[发了一个表情包,网卡了加载不出来]"
elif segment.type == "voice":
self.is_picid = False
self.is_emoji = False
self.is_voice = True
self.is_video = False
# 检查消息是否由机器人自己发送
if self.message_info and self.message_info.user_info and str(self.message_info.user_info.user_id) == str(global_config.bot.qq_account):
logger.info(f"检测到机器人自身发送的语音消息 (User ID: {self.message_info.user_info.user_id}),尝试从缓存获取文本。")
if isinstance(segment.data, str):
cached_text = consume_self_voice_text(segment.data)
if cached_text:
logger.info(f"成功从缓存中获取语音文本: '{cached_text[:70]}...'")
return f"[语音:{cached_text}]"
else:
logger.warning("机器人自身语音消息缓存未命中,将回退到标准语音识别。")
# 标准语音识别流程 (也作为缓存未命中的后备方案)
if isinstance(segment.data, str):
return await get_voice_text(segment.data)
return "[发了一段语音,网卡了加载不出来]"
elif segment.type == "mention_bot":
self.is_picid = False
self.is_emoji = False
self.is_voice = False
self.is_video = False
self.is_mentioned = float(segment.data) # type: ignore
return ""
elif segment.type == "priority_info":
self.is_picid = False
self.is_emoji = False
self.is_voice = False
if isinstance(segment.data, dict):
# 处理优先级信息
self.priority_mode = "priority"
self.priority_info = segment.data
"""
{
'message_type': 'vip', # vip or normal
'message_priority': 1.0, # 优先级大为优先float
}
"""
return ""
elif segment.type == "file":
if isinstance(segment.data, dict):
file_name = segment.data.get('name', '未知文件')
file_size = segment.data.get('size', '未知大小')
return f"[文件:{file_name} ({file_size}字节)]"
return "[收到一个文件]"
elif segment.type == "video":
self.is_picid = False
self.is_emoji = False
self.is_voice = False
self.is_video = True
logger.info(f"接收到视频消息,数据类型: {type(segment.data)}")
# 检查视频分析功能是否可用
if not is_video_analysis_available():
logger.warning("⚠️ Rust视频处理模块不可用跳过视频分析")
return "[视频]"
if global_config.video_analysis.enable:
logger.info("已启用视频识别,开始识别")
if isinstance(segment.data, dict):
try:
# 从Adapter接收的视频数据
video_base64 = segment.data.get("base64")
filename = segment.data.get("filename", "video.mp4")
logger.info(f"视频文件名: {filename}")
logger.info(f"Base64数据长度: {len(video_base64) if video_base64 else 0}")
if video_base64:
# 解码base64视频数据
video_bytes = base64.b64decode(video_base64)
logger.info(f"解码后视频大小: {len(video_bytes)} 字节")
# 使用video analyzer分析视频
video_analyzer = get_video_analyzer()
result = await video_analyzer.analyze_video_from_bytes(
video_bytes, filename, prompt=global_config.video_analysis.batch_analysis_prompt
)
logger.info(f"视频分析结果: {result}")
# 返回视频分析结果
summary = result.get("summary", "")
if summary:
return f"[视频内容] {summary}"
else:
return "[已收到视频,但分析失败]"
else:
logger.warning("视频消息中没有base64数据")
return "[收到视频消息,但数据异常]"
except Exception as e:
logger.error(f"视频处理失败: {e!s}")
import traceback
logger.error(f"错误详情: {traceback.format_exc()}")
return "[收到视频,但处理时出现错误]"
else:
logger.warning(f"视频消息数据不是字典格式: {type(segment.data)}")
return "[发了一个视频,但格式不支持]"
else:

View File

@@ -9,8 +9,10 @@ from src.common.database.sqlalchemy_database_api import get_db_session
from src.common.database.sqlalchemy_models import Images, Messages
from src.common.logger import get_logger
from src.common.data_models.database_data_model import DatabaseMessages
from .chat_stream import ChatStream
from .message import MessageRecv, MessageSending
from .message import MessageSending
logger = get_logger("message_storage")
@@ -34,97 +36,166 @@ class MessageStorage:
return []
@staticmethod
async def store_message(message: MessageSending | MessageRecv, chat_stream: ChatStream) -> None:
async def store_message(message: DatabaseMessages | MessageSending, chat_stream: ChatStream) -> None:
"""存储消息到数据库"""
try:
# 过滤敏感信息的正则模式
pattern = r"<MainRule>.*?</MainRule>|<schedule>.*?</schedule>|<UserMessage>.*?</UserMessage>"
processed_plain_text = message.processed_plain_text
if processed_plain_text:
processed_plain_text = await MessageStorage.replace_image_descriptions(processed_plain_text)
# 增加对None的防御性处理
safe_processed_plain_text = processed_plain_text or ""
filtered_processed_plain_text = re.sub(pattern, "", safe_processed_plain_text, flags=re.DOTALL)
else:
filtered_processed_plain_text = ""
if isinstance(message, MessageSending):
display_message = message.display_message
if display_message:
filtered_display_message = re.sub(pattern, "", display_message, flags=re.DOTALL)
# 如果是 DatabaseMessages直接使用它的字段
if isinstance(message, DatabaseMessages):
processed_plain_text = message.processed_plain_text
if processed_plain_text:
processed_plain_text = await MessageStorage.replace_image_descriptions(processed_plain_text)
safe_processed_plain_text = processed_plain_text or ""
filtered_processed_plain_text = re.sub(pattern, "", safe_processed_plain_text, flags=re.DOTALL)
else:
# 如果没有设置display_message使用processed_plain_text作为显示消息
filtered_display_message = (
re.sub(pattern, "", (message.processed_plain_text or ""), flags=re.DOTALL)
)
interest_value = 0
is_mentioned = False
reply_to = message.reply_to
priority_mode = ""
priority_info = {}
is_emoji = False
is_picid = False
is_notify = False
is_command = False
key_words = ""
key_words_lite = ""
else:
filtered_display_message = ""
interest_value = message.interest_value
filtered_processed_plain_text = ""
display_message = message.display_message or message.processed_plain_text or ""
filtered_display_message = re.sub(pattern, "", display_message, flags=re.DOTALL)
# 直接从 DatabaseMessages 获取所有字段
msg_id = message.message_id
msg_time = message.time
chat_id = message.chat_id
reply_to = "" # DatabaseMessages 没有 reply_to 字段
is_mentioned = message.is_mentioned
reply_to = ""
priority_mode = message.priority_mode
priority_info = message.priority_info
is_emoji = message.is_emoji
is_picid = message.is_picid
is_notify = message.is_notify
is_command = message.is_command
# 序列化关键词列表为JSON字符串
key_words = MessageStorage._serialize_keywords(message.key_words)
key_words_lite = MessageStorage._serialize_keywords(message.key_words_lite)
interest_value = message.interest_value or 0.0
priority_mode = "" # DatabaseMessages 没有 priority_mode
priority_info_json = None # DatabaseMessages 没有 priority_info
is_emoji = message.is_emoji or False
is_picid = message.is_picid or False
is_notify = message.is_notify or False
is_command = message.is_command or False
key_words = "" # DatabaseMessages 没有 key_words
key_words_lite = ""
memorized_times = 0 # DatabaseMessages 没有 memorized_times
# 使用 DatabaseMessages 中的嵌套对象信息
user_platform = message.user_info.platform if message.user_info else ""
user_id = message.user_info.user_id if message.user_info else ""
user_nickname = message.user_info.user_nickname if message.user_info else ""
user_cardname = message.user_info.user_cardname if message.user_info else None
chat_info_stream_id = message.chat_info.stream_id if message.chat_info else ""
chat_info_platform = message.chat_info.platform if message.chat_info else ""
chat_info_create_time = message.chat_info.create_time if message.chat_info else 0.0
chat_info_last_active_time = message.chat_info.last_active_time if message.chat_info else 0.0
chat_info_user_platform = message.chat_info.user_info.platform if message.chat_info and message.chat_info.user_info else ""
chat_info_user_id = message.chat_info.user_info.user_id if message.chat_info and message.chat_info.user_info else ""
chat_info_user_nickname = message.chat_info.user_info.user_nickname if message.chat_info and message.chat_info.user_info else ""
chat_info_user_cardname = message.chat_info.user_info.user_cardname if message.chat_info and message.chat_info.user_info else None
chat_info_group_platform = message.group_info.group_platform if message.group_info else None
chat_info_group_id = message.group_info.group_id if message.group_info else None
chat_info_group_name = message.group_info.group_name if message.group_info else None
else:
# MessageSending 处理逻辑
processed_plain_text = message.processed_plain_text
chat_info_dict = chat_stream.to_dict()
user_info_dict = message.message_info.user_info.to_dict() # type: ignore
if processed_plain_text:
processed_plain_text = await MessageStorage.replace_image_descriptions(processed_plain_text)
# 增加对None的防御性处理
safe_processed_plain_text = processed_plain_text or ""
filtered_processed_plain_text = re.sub(pattern, "", safe_processed_plain_text, flags=re.DOTALL)
else:
filtered_processed_plain_text = ""
# message_id 现在是 TextField直接使用字符串值
msg_id = message.message_info.message_id
if isinstance(message, MessageSending):
display_message = message.display_message
if display_message:
filtered_display_message = re.sub(pattern, "", display_message, flags=re.DOTALL)
else:
# 如果没有设置display_message使用processed_plain_text作为显示消息
filtered_display_message = (
re.sub(pattern, "", (message.processed_plain_text or ""), flags=re.DOTALL)
)
interest_value = 0
is_mentioned = False
reply_to = message.reply_to
priority_mode = ""
priority_info = {}
is_emoji = False
is_picid = False
is_notify = False
is_command = False
key_words = ""
key_words_lite = ""
else:
filtered_display_message = ""
interest_value = message.interest_value
is_mentioned = message.is_mentioned
reply_to = ""
priority_mode = message.priority_mode
priority_info = message.priority_info
is_emoji = message.is_emoji
is_picid = message.is_picid
is_notify = message.is_notify
is_command = message.is_command
# 序列化关键词列表为JSON字符串
key_words = MessageStorage._serialize_keywords(message.key_words)
key_words_lite = MessageStorage._serialize_keywords(message.key_words_lite)
# 安全地获取 group_info, 如果为 None 则视为空字典
group_info_from_chat = chat_info_dict.get("group_info") or {}
# 安全地获取 user_info, 如果为 None 则视为空字典 (以防万一)
user_info_from_chat = chat_info_dict.get("user_info") or {}
chat_info_dict = chat_stream.to_dict()
user_info_dict = message.message_info.user_info.to_dict() # type: ignore
# 将priority_info字典序列化为JSON字符串以便存储到数据库的Text字段
priority_info_json = orjson.dumps(priority_info).decode("utf-8") if priority_info else None
# message_id 现在是 TextField直接使用字符串值
msg_id = message.message_info.message_id
msg_time = float(message.message_info.time or time.time())
chat_id = chat_stream.stream_id
memorized_times = message.memorized_times
# 安全地获取 group_info, 如果为 None 则视为空字典
group_info_from_chat = chat_info_dict.get("group_info") or {}
# 安全地获取 user_info, 如果为 None 则视为空字典 (以防万一)
user_info_from_chat = chat_info_dict.get("user_info") or {}
# 将priority_info字典序列化为JSON字符串以便存储到数据库的Text字段
priority_info_json = orjson.dumps(priority_info).decode("utf-8") if priority_info else None
user_platform = user_info_dict.get("platform")
user_id = user_info_dict.get("user_id")
user_nickname = user_info_dict.get("user_nickname")
user_cardname = user_info_dict.get("user_cardname")
chat_info_stream_id = chat_info_dict.get("stream_id")
chat_info_platform = chat_info_dict.get("platform")
chat_info_create_time = float(chat_info_dict.get("create_time", 0.0))
chat_info_last_active_time = float(chat_info_dict.get("last_active_time", 0.0))
chat_info_user_platform = user_info_from_chat.get("platform")
chat_info_user_id = user_info_from_chat.get("user_id")
chat_info_user_nickname = user_info_from_chat.get("user_nickname")
chat_info_user_cardname = user_info_from_chat.get("user_cardname")
chat_info_group_platform = group_info_from_chat.get("platform")
chat_info_group_id = group_info_from_chat.get("group_id")
chat_info_group_name = group_info_from_chat.get("group_name")
# 获取数据库会话
new_message = Messages(
message_id=msg_id,
time=float(message.message_info.time or time.time()),
chat_id=chat_stream.stream_id,
time=msg_time,
chat_id=chat_id,
reply_to=reply_to,
is_mentioned=is_mentioned,
chat_info_stream_id=chat_info_dict.get("stream_id"),
chat_info_platform=chat_info_dict.get("platform"),
chat_info_user_platform=user_info_from_chat.get("platform"),
chat_info_user_id=user_info_from_chat.get("user_id"),
chat_info_user_nickname=user_info_from_chat.get("user_nickname"),
chat_info_user_cardname=user_info_from_chat.get("user_cardname"),
chat_info_group_platform=group_info_from_chat.get("platform"),
chat_info_group_id=group_info_from_chat.get("group_id"),
chat_info_group_name=group_info_from_chat.get("group_name"),
chat_info_create_time=float(chat_info_dict.get("create_time", 0.0)),
chat_info_last_active_time=float(chat_info_dict.get("last_active_time", 0.0)),
user_platform=user_info_dict.get("platform"),
user_id=user_info_dict.get("user_id"),
user_nickname=user_info_dict.get("user_nickname"),
user_cardname=user_info_dict.get("user_cardname"),
chat_info_stream_id=chat_info_stream_id,
chat_info_platform=chat_info_platform,
chat_info_user_platform=chat_info_user_platform,
chat_info_user_id=chat_info_user_id,
chat_info_user_nickname=chat_info_user_nickname,
chat_info_user_cardname=chat_info_user_cardname,
chat_info_group_platform=chat_info_group_platform,
chat_info_group_id=chat_info_group_id,
chat_info_group_name=chat_info_group_name,
chat_info_create_time=chat_info_create_time,
chat_info_last_active_time=chat_info_last_active_time,
user_platform=user_platform,
user_id=user_id,
user_nickname=user_nickname,
user_cardname=user_cardname,
processed_plain_text=filtered_processed_plain_text,
display_message=filtered_display_message,
memorized_times=message.memorized_times,
memorized_times=memorized_times,
interest_value=interest_value,
priority_mode=priority_mode,
priority_info=priority_info_json,
@@ -145,36 +216,43 @@ class MessageStorage:
traceback.print_exc()
@staticmethod
async def update_message(message):
"""更新消息ID"""
async def update_message(message_data: dict):
"""更新消息ID(从消息字典)"""
try:
mmc_message_id = message.message_info.message_id
# 从字典中提取信息
message_info = message_data.get("message_info", {})
mmc_message_id = message_info.get("message_id")
message_segment = message_data.get("message_segment", {})
segment_type = message_segment.get("type") if isinstance(message_segment, dict) else None
segment_data = message_segment.get("data", {}) if isinstance(message_segment, dict) else {}
qq_message_id = None
logger.debug(f"尝试更新消息ID: {mmc_message_id}, 消息段类型: {message.message_segment.type}")
logger.debug(f"尝试更新消息ID: {mmc_message_id}, 消息段类型: {segment_type}")
# 根据消息段类型提取message_id
if message.message_segment.type == "notify":
qq_message_id = message.message_segment.data.get("id")
elif message.message_segment.type == "text":
qq_message_id = message.message_segment.data.get("id")
elif message.message_segment.type == "reply":
qq_message_id = message.message_segment.data.get("id")
if segment_type == "notify":
qq_message_id = segment_data.get("id")
elif segment_type == "text":
qq_message_id = segment_data.get("id")
elif segment_type == "reply":
qq_message_id = segment_data.get("id")
if qq_message_id:
logger.debug(f"从reply消息段获取到消息ID: {qq_message_id}")
elif message.message_segment.type == "adapter_response":
elif segment_type == "adapter_response":
logger.debug("适配器响应消息不需要更新ID")
return
elif message.message_segment.type == "adapter_command":
elif segment_type == "adapter_command":
logger.debug("适配器命令消息不需要更新ID")
return
else:
logger.debug(f"未知的消息段类型: {message.message_segment.type}跳过ID更新")
logger.debug(f"未知的消息段类型: {segment_type}跳过ID更新")
return
if not qq_message_id:
logger.debug(f"消息段类型 {message.message_segment.type} 中未找到有效的message_id跳过更新")
logger.debug(f"消息段数据: {message.message_segment.data}")
logger.debug(f"消息段类型 {segment_type} 中未找到有效的message_id跳过更新")
logger.debug(f"消息段数据: {segment_data}")
return
# 使用上下文管理器确保session正确管理

View File

@@ -137,7 +137,7 @@ class ActionModifier:
logger.debug(f"{self.log_prefix}阶段一移除动作: {disabled_action_name},原因: 用户自行禁用")
# === 第二阶段:检查动作的关联类型 ===
chat_context = self.chat_stream.stream_context
chat_context = self.chat_stream.context_manager.context
current_actions_s2 = self.action_manager.get_using_actions()
type_mismatched_actions = self._check_action_associated_types(current_actions_s2, chat_context)

View File

@@ -13,7 +13,7 @@ from typing import Any
from src.chat.express.expression_selector import expression_selector
from src.chat.message_receive.chat_stream import ChatStream
from src.chat.message_receive.message import MessageRecv, MessageSending, Seg, UserInfo
from src.chat.message_receive.message import MessageSending, Seg, UserInfo
from src.chat.message_receive.uni_message_sender import HeartFCSender
from src.chat.utils.chat_message_builder import (
build_readable_messages,
@@ -1733,7 +1733,7 @@ class DefaultReplyer:
is_emoji: bool,
thinking_start_time: float,
display_message: str,
anchor_message: MessageRecv | None = None,
anchor_message: DatabaseMessages | None = None,
) -> MessageSending:
"""构建单个发送消息"""
@@ -1743,8 +1743,11 @@ class DefaultReplyer:
platform=self.chat_stream.platform,
)
# await anchor_message.process()
sender_info = anchor_message.message_info.user_info if anchor_message else None
# 从 DatabaseMessages 获取 sender_info
if anchor_message:
sender_info = anchor_message.user_info
else:
sender_info = None
return MessageSending(
message_id=message_id, # 使用片段的唯一ID

View File

@@ -11,7 +11,7 @@ import rjieba
from maim_message import UserInfo
from src.chat.message_receive.chat_stream import get_chat_manager
from src.chat.message_receive.message import MessageRecv
# MessageRecv 已被移除,现在使用 DatabaseMessages
from src.common.logger import get_logger
from src.common.message_repository import count_messages, find_messages
from src.config.config import global_config, model_config
@@ -41,34 +41,58 @@ def db_message_to_str(message_dict: dict) -> str:
return result
def is_mentioned_bot_in_message(message: MessageRecv) -> tuple[bool, float]:
"""检查消息是否提到了机器人"""
def is_mentioned_bot_in_message(message) -> tuple[bool, float]:
"""检查消息是否提到了机器人
Args:
message: DatabaseMessages 消息对象
Returns:
tuple[bool, float]: (是否提及, 提及概率)
"""
keywords = [global_config.bot.nickname]
nicknames = global_config.bot.alias_names
reply_probability = 0.0
is_at = False
is_mentioned = False
if message.is_mentioned is not None:
return bool(message.is_mentioned), message.is_mentioned
if (
message.message_info.additional_config is not None
and message.message_info.additional_config.get("is_mentioned") is not None
):
# 检查 is_mentioned 属性
mentioned_attr = getattr(message, "is_mentioned", None)
if mentioned_attr is not None:
try:
reply_probability = float(message.message_info.additional_config.get("is_mentioned")) # type: ignore
return bool(mentioned_attr), float(mentioned_attr)
except (ValueError, TypeError):
pass
# 检查 additional_config
additional_config = None
# DatabaseMessages: additional_config 是 JSON 字符串
if message.additional_config:
try:
import orjson
additional_config = orjson.loads(message.additional_config)
except Exception:
pass
if additional_config and additional_config.get("is_mentioned") is not None:
try:
reply_probability = float(additional_config.get("is_mentioned")) # type: ignore
is_mentioned = True
return is_mentioned, reply_probability
except Exception as e:
logger.warning(str(e))
logger.warning(
f"消息中包含不合理的设置 is_mentioned: {message.message_info.additional_config.get('is_mentioned')}"
f"消息中包含不合理的设置 is_mentioned: {additional_config.get('is_mentioned')}"
)
if global_config.bot.nickname in message.processed_plain_text:
# 检查消息文本内容
processed_text = message.processed_plain_text or ""
if global_config.bot.nickname in processed_text:
is_mentioned = True
for alias_name in global_config.bot.alias_names:
if alias_name in message.processed_plain_text:
if alias_name in processed_text:
is_mentioned = True
# 判断是否被@
@@ -110,7 +134,6 @@ def is_mentioned_bot_in_message(message: MessageRecv) -> tuple[bool, float]:
logger.debug("被提及回复概率设置为100%")
return is_mentioned, reply_probability
async def get_embedding(text, request_type="embedding") -> list[float] | None:
"""获取文本的embedding向量"""
# 每次都创建新的LLMRequest实例以避免事件循环冲突

View File

@@ -9,15 +9,18 @@ from src.common.logger import get_logger
logger = get_logger("db_migration")
async def check_and_migrate_database():
async def check_and_migrate_database(existing_engine=None):
"""
异步检查数据库结构并自动迁移。
- 自动创建不存在的表。
- 自动为现有表添加缺失的列。
- 自动为现有表创建缺失的索引。
Args:
existing_engine: 可选的已存在的数据库引擎。如果提供,将使用该引擎;否则获取全局引擎。
"""
logger.info("正在检查数据库结构并执行自动迁移...")
engine = await get_engine()
engine = existing_engine if existing_engine is not None else await get_engine()
async with engine.connect() as connection:
# 在同步上下文中运行inspector操作

View File

@@ -780,12 +780,8 @@ async def initialize_database() -> tuple["AsyncEngine", async_sessionmaker[Async
_SessionLocal = async_sessionmaker(bind=_engine, class_=AsyncSession, expire_on_commit=False)
# 迁移
try:
from src.common.database.db_migration import check_and_migrate_database
await check_and_migrate_database(existing_engine=_engine)
except TypeError:
from src.common.database.db_migration import check_and_migrate_database as _legacy_migrate
await _legacy_migrate()
from src.common.database.db_migration import check_and_migrate_database
await check_and_migrate_database(existing_engine=_engine)
if config.database_type == "sqlite":
await enable_sqlite_wal_mode(_engine)

View File

@@ -2,7 +2,6 @@ import math
import random
import time
from src.chat.message_receive.message import MessageRecv
from src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_by_timestamp_with_chat_inclusive
from src.chat.utils.prompt import Prompt, global_prompt_manager
from src.common.data_models.database_data_model import DatabaseMessages
@@ -98,7 +97,7 @@ class ChatMood:
if not hasattr(self, "last_change_time"):
self.last_change_time = 0
async def update_mood_by_message(self, message: MessageRecv | DatabaseMessages, interested_rate: float):
async def update_mood_by_message(self, message: DatabaseMessages, interested_rate: float):
# 确保异步初始化已完成
await self._initialize()
@@ -109,11 +108,8 @@ class ChatMood:
self.regression_count = 0
# 处理不同类型的消息对象
if isinstance(message, MessageRecv):
message_time = message.message_info.time
else: # DatabaseMessages
message_time = message.time
# 使用 DatabaseMessages 的时间字段
message_time = message.time
# 防止负时间差
during_last_time = max(0, message_time - self.last_change_time)

View File

@@ -86,13 +86,16 @@ async def file_to_stream(
import asyncio
import time
import traceback
from typing import Any
from typing import Any, TYPE_CHECKING
from maim_message import Seg, UserInfo
if TYPE_CHECKING:
from src.common.data_models.database_data_model import DatabaseMessages
# 导入依赖
from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager
from src.chat.message_receive.message import MessageRecv, MessageSending
from src.chat.message_receive.message import MessageSending
from src.chat.message_receive.uni_message_sender import HeartFCSender
from src.common.logger import get_logger
from src.config.config import global_config
@@ -104,84 +107,53 @@ logger = get_logger("send_api")
_adapter_response_pool: dict[str, asyncio.Future] = {}
def message_dict_to_message_recv(message_dict: dict[str, Any]) -> MessageRecv | None:
"""查找要回复的消息
def message_dict_to_db_message(message_dict: dict[str, Any]) -> "DatabaseMessages | None":
"""从消息字典构建 DatabaseMessages 对象
Args:
message_dict: 消息字典或 DatabaseMessages 对象
Returns:
Optional[MessageRecv]: 找到的消息,如果没找到则返回None
Optional[DatabaseMessages]: 构建的消息对象,如果构建失败则返回None
"""
# 兼容 DatabaseMessages 对象和字典
if isinstance(message_dict, dict):
user_platform = message_dict.get("user_platform", "")
user_id = message_dict.get("user_id", "")
user_nickname = message_dict.get("user_nickname", "")
user_cardname = message_dict.get("user_cardname", "")
chat_info_group_id = message_dict.get("chat_info_group_id")
chat_info_group_platform = message_dict.get("chat_info_group_platform", "")
chat_info_group_name = message_dict.get("chat_info_group_name", "")
chat_info_platform = message_dict.get("chat_info_platform", "")
message_id = message_dict.get("message_id") or message_dict.get("chat_info_message_id") or message_dict.get("id")
time_val = message_dict.get("time")
additional_config = message_dict.get("additional_config")
processed_plain_text = message_dict.get("processed_plain_text")
else:
# DatabaseMessages 对象
user_platform = getattr(message_dict, "user_platform", "")
user_id = getattr(message_dict, "user_id", "")
user_nickname = getattr(message_dict, "user_nickname", "")
user_cardname = getattr(message_dict, "user_cardname", "")
chat_info_group_id = getattr(message_dict, "chat_info_group_id", None)
chat_info_group_platform = getattr(message_dict, "chat_info_group_platform", "")
chat_info_group_name = getattr(message_dict, "chat_info_group_name", "")
chat_info_platform = getattr(message_dict, "chat_info_platform", "")
message_id = getattr(message_dict, "message_id", None)
time_val = getattr(message_dict, "time", None)
additional_config = getattr(message_dict, "additional_config", None)
processed_plain_text = getattr(message_dict, "processed_plain_text", "")
from src.common.data_models.database_data_model import DatabaseMessages
# 构建MessageRecv对象
user_info = {
"platform": user_platform,
"user_id": user_id,
"user_nickname": user_nickname,
"user_cardname": user_cardname,
}
group_info = {}
if chat_info_group_id:
group_info = {
"platform": chat_info_group_platform,
"group_id": chat_info_group_id,
"group_name": chat_info_group_name,
}
format_info = {"content_format": "", "accept_format": ""}
template_info = {"template_items": {}}
message_info = {
"platform": chat_info_platform,
"message_id": message_id,
"time": time_val,
"group_info": group_info,
"user_info": user_info,
"additional_config": additional_config,
"format_info": format_info,
"template_info": template_info,
}
new_message_dict = {
"message_info": message_info,
"raw_message": processed_plain_text,
"processed_plain_text": processed_plain_text,
}
message_recv = MessageRecv(new_message_dict)
logger.info(f"[SendAPI] 找到匹配的回复消息,发送者: {user_nickname}")
return message_recv
# 如果已经是 DatabaseMessages直接返回
if isinstance(message_dict, DatabaseMessages):
return message_dict
# 从字典提取信息
user_platform = message_dict.get("user_platform", "")
user_id = message_dict.get("user_id", "")
user_nickname = message_dict.get("user_nickname", "")
user_cardname = message_dict.get("user_cardname", "")
chat_info_group_id = message_dict.get("chat_info_group_id")
chat_info_group_platform = message_dict.get("chat_info_group_platform", "")
chat_info_group_name = message_dict.get("chat_info_group_name", "")
chat_info_platform = message_dict.get("chat_info_platform", "")
message_id = message_dict.get("message_id") or message_dict.get("chat_info_message_id") or message_dict.get("id")
time_val = message_dict.get("time", time.time())
additional_config = message_dict.get("additional_config")
processed_plain_text = message_dict.get("processed_plain_text", "")
# DatabaseMessages 使用扁平参数构造
db_message = DatabaseMessages(
message_id=message_id or "temp_reply_id",
time=time_val,
user_id=user_id,
user_nickname=user_nickname,
user_cardname=user_cardname,
user_platform=user_platform,
chat_info_group_id=chat_info_group_id,
chat_info_group_name=chat_info_group_name,
chat_info_group_platform=chat_info_group_platform,
chat_info_platform=chat_info_platform,
processed_plain_text=processed_plain_text,
additional_config=additional_config
)
logger.info(f"[SendAPI] 构建回复消息对象,发送者: {user_nickname}")
return db_message
def put_adapter_response(request_id: str, response_data: dict) -> None:
@@ -285,17 +257,17 @@ async def _send_to_target(
"message_id": "temp_reply_id", # 临时ID
"time": time.time()
}
anchor_message = message_dict_to_message_recv(message_dict=temp_message_dict)
anchor_message = message_dict_to_db_message(message_dict=temp_message_dict)
else:
anchor_message = None
reply_to_platform_id = f"{target_stream.platform}:{sender_id}" if anchor_message else None
elif reply_to_message:
anchor_message = message_dict_to_message_recv(message_dict=reply_to_message)
anchor_message = message_dict_to_db_message(message_dict=reply_to_message)
if anchor_message:
anchor_message.update_chat_stream(target_stream)
# DatabaseMessages 不需要 update_chat_stream它是纯数据对象
reply_to_platform_id = (
f"{anchor_message.message_info.platform}:{anchor_message.message_info.user_info.user_id}"
f"{anchor_message.chat_info.platform}:{anchor_message.user_info.user_id}"
)
else:
reply_to_platform_id = None

View File

@@ -1,10 +1,14 @@
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING
from src.chat.message_receive.message import MessageRecv
from src.common.data_models.database_data_model import DatabaseMessages
from src.common.logger import get_logger
from src.plugin_system.apis import send_api
from src.plugin_system.base.component_types import ChatType, CommandInfo, ComponentType
if TYPE_CHECKING:
from src.chat.message_receive.chat_stream import ChatStream
logger = get_logger("base_command")
@@ -29,11 +33,11 @@ class BaseCommand(ABC):
chat_type_allow: ChatType = ChatType.ALL
"""允许的聊天类型,默认为所有类型"""
def __init__(self, message: MessageRecv, plugin_config: dict | None = None):
def __init__(self, message: DatabaseMessages, plugin_config: dict | None = None):
"""初始化Command组件
Args:
message: 接收到的消息对象
message: 接收到的消息对象DatabaseMessages
plugin_config: 插件配置字典
"""
self.message = message
@@ -41,6 +45,9 @@ class BaseCommand(ABC):
self.plugin_config = plugin_config or {} # 直接存储插件配置字典
self.log_prefix = "[Command]"
# chat_stream 会在运行时被 bot.py 设置
self.chat_stream: "ChatStream | None" = None
# 从类属性获取chat_type_allow设置
self.chat_type_allow = getattr(self.__class__, "chat_type_allow", ChatType.ALL)
@@ -49,7 +56,7 @@ class BaseCommand(ABC):
# 验证聊天类型限制
if not self._validate_chat_type():
is_group = hasattr(self.message, "is_group_message") and self.message.is_group_message
is_group = message.group_info is not None
logger.warning(
f"{self.log_prefix} Command '{self.command_name}' 不支持当前聊天类型: "
f"{'群聊' if is_group else '私聊'}, 允许类型: {self.chat_type_allow.value}"
@@ -72,8 +79,8 @@ class BaseCommand(ABC):
if self.chat_type_allow == ChatType.ALL:
return True
# 检查是否为群聊消息
is_group = self.message.message_info.group_info
# 检查是否为群聊消息DatabaseMessages使用group_info来判断
is_group = self.message.group_info is not None
if self.chat_type_allow == ChatType.GROUP and is_group:
return True
@@ -137,12 +144,11 @@ class BaseCommand(ABC):
bool: 是否发送成功
"""
# 获取聊天流信息
chat_stream = self.message.chat_stream
if not chat_stream or not hasattr(chat_stream, "stream_id"):
if not self.chat_stream or not hasattr(self.chat_stream, "stream_id"):
logger.error(f"{self.log_prefix} 缺少聊天流或stream_id")
return False
return await send_api.text_to_stream(text=content, stream_id=chat_stream.stream_id, reply_to=reply_to)
return await send_api.text_to_stream(text=content, stream_id=self.chat_stream.stream_id, reply_to=reply_to)
async def send_type(
self, message_type: str, content: str, display_message: str = "", typing: bool = False, reply_to: str = ""
@@ -160,15 +166,14 @@ class BaseCommand(ABC):
bool: 是否发送成功
"""
# 获取聊天流信息
chat_stream = self.message.chat_stream
if not chat_stream or not hasattr(chat_stream, "stream_id"):
if not self.chat_stream or not hasattr(self.chat_stream, "stream_id"):
logger.error(f"{self.log_prefix} 缺少聊天流或stream_id")
return False
return await send_api.custom_to_stream(
message_type=message_type,
content=content,
stream_id=chat_stream.stream_id,
stream_id=self.chat_stream.stream_id,
display_message=display_message,
typing=typing,
reply_to=reply_to,
@@ -190,8 +195,7 @@ class BaseCommand(ABC):
"""
try:
# 获取聊天流信息
chat_stream = self.message.chat_stream
if not chat_stream or not hasattr(chat_stream, "stream_id"):
if not self.chat_stream or not hasattr(self.chat_stream, "stream_id"):
logger.error(f"{self.log_prefix} 缺少聊天流或stream_id")
return False
@@ -200,7 +204,7 @@ class BaseCommand(ABC):
success = await send_api.command_to_stream(
command=command_data,
stream_id=chat_stream.stream_id,
stream_id=self.chat_stream.stream_id,
storage_message=storage_message,
display_message=display_message,
)
@@ -225,12 +229,11 @@ class BaseCommand(ABC):
Returns:
bool: 是否发送成功
"""
chat_stream = self.message.chat_stream
if not chat_stream or not hasattr(chat_stream, "stream_id"):
if not self.chat_stream or not hasattr(self.chat_stream, "stream_id"):
logger.error(f"{self.log_prefix} 缺少聊天流或stream_id")
return False
return await send_api.emoji_to_stream(emoji_base64, chat_stream.stream_id)
return await send_api.emoji_to_stream(emoji_base64, self.chat_stream.stream_id)
async def send_image(self, image_base64: str) -> bool:
"""发送图片
@@ -241,12 +244,11 @@ class BaseCommand(ABC):
Returns:
bool: 是否发送成功
"""
chat_stream = self.message.chat_stream
if not chat_stream or not hasattr(chat_stream, "stream_id"):
if not self.chat_stream or not hasattr(self.chat_stream, "stream_id"):
logger.error(f"{self.log_prefix} 缺少聊天流或stream_id")
return False
return await send_api.image_to_stream(image_base64, chat_stream.stream_id)
return await send_api.image_to_stream(image_base64, self.chat_stream.stream_id)
@classmethod
def get_command_info(cls) -> "CommandInfo":

View File

@@ -5,8 +5,9 @@
import re
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING
from src.chat.message_receive.message import MessageRecv
from src.common.data_models.database_data_model import DatabaseMessages
from src.common.logger import get_logger
from src.config.config import global_config
from src.plugin_system.apis import send_api
@@ -14,6 +15,9 @@ from src.plugin_system.base.base_command import BaseCommand
from src.plugin_system.base.command_args import CommandArgs
from src.plugin_system.base.component_types import ChatType, ComponentType, PlusCommandInfo
if TYPE_CHECKING:
from src.chat.message_receive.chat_stream import ChatStream
logger = get_logger("plus_command")
@@ -50,23 +54,26 @@ class PlusCommand(ABC):
intercept_message: bool = False
"""是否拦截消息,不进行后续处理"""
def __init__(self, message: MessageRecv, plugin_config: dict | None = None):
def __init__(self, message: DatabaseMessages, plugin_config: dict | None = None):
"""初始化命令组件
Args:
message: 接收到的消息对象
message: 接收到的消息对象DatabaseMessages
plugin_config: 插件配置字典
"""
self.message = message
self.plugin_config = plugin_config or {}
self.log_prefix = "[PlusCommand]"
# chat_stream 会在运行时被 bot.py 设置
self.chat_stream: "ChatStream | None" = None
# 解析命令参数
self._parse_command()
# 验证聊天类型限制
if not self._validate_chat_type():
is_group = self.message.message_info.group_info.group_id
is_group = message.group_info is not None
logger.warning(
f"{self.log_prefix} 命令 '{self.command_name}' 不支持当前聊天类型: "
f"{'群聊' if is_group else '私聊'}, 允许类型: {self.chat_type_allow.value}"
@@ -124,8 +131,8 @@ class PlusCommand(ABC):
if self.chat_type_allow == ChatType.ALL:
return True
# 检查是否为群聊消息
is_group = hasattr(self.message.message_info, "group_info") and self.message.message_info.group_info
# 检查是否为群聊消息DatabaseMessages使用group_info判断
is_group = self.message.group_info is not None
if self.chat_type_allow == ChatType.GROUP and is_group:
return True
@@ -152,7 +159,7 @@ class PlusCommand(ABC):
def _is_exact_command_call(self) -> bool:
"""检查是否是精确的命令调用(无参数)"""
if not hasattr(self.message, "plain_text") or not self.message.processed_plain_text:
if not self.message.processed_plain_text:
return False
plain_text = self.message.processed_plain_text.strip()
@@ -218,12 +225,11 @@ class PlusCommand(ABC):
bool: 是否发送成功
"""
# 获取聊天流信息
chat_stream = self.message.chat_stream
if not chat_stream or not hasattr(chat_stream, "stream_id"):
if not self.chat_stream or not hasattr(self.chat_stream, "stream_id"):
logger.error(f"{self.log_prefix} 缺少聊天流或stream_id")
return False
return await send_api.text_to_stream(text=content, stream_id=chat_stream.stream_id, reply_to=reply_to)
return await send_api.text_to_stream(text=content, stream_id=self.chat_stream.stream_id, reply_to=reply_to)
async def send_type(
self, message_type: str, content: str, display_message: str = "", typing: bool = False, reply_to: str = ""
@@ -241,15 +247,14 @@ class PlusCommand(ABC):
bool: 是否发送成功
"""
# 获取聊天流信息
chat_stream = self.message.chat_stream
if not chat_stream or not hasattr(chat_stream, "stream_id"):
if not self.chat_stream or not hasattr(self.chat_stream, "stream_id"):
logger.error(f"{self.log_prefix} 缺少聊天流或stream_id")
return False
return await send_api.custom_to_stream(
message_type=message_type,
content=content,
stream_id=chat_stream.stream_id,
stream_id=self.chat_stream.stream_id,
display_message=display_message,
typing=typing,
reply_to=reply_to,
@@ -264,12 +269,11 @@ class PlusCommand(ABC):
Returns:
bool: 是否发送成功
"""
chat_stream = self.message.chat_stream
if not chat_stream or not hasattr(chat_stream, "stream_id"):
if not self.chat_stream or not hasattr(self.chat_stream, "stream_id"):
logger.error(f"{self.log_prefix} 缺少聊天流或stream_id")
return False
return await send_api.emoji_to_stream(emoji_base64, chat_stream.stream_id)
return await send_api.emoji_to_stream(emoji_base64, self.chat_stream.stream_id)
async def send_image(self, image_base64: str) -> bool:
"""发送图片
@@ -280,12 +284,11 @@ class PlusCommand(ABC):
Returns:
bool: 是否发送成功
"""
chat_stream = self.message.chat_stream
if not chat_stream or not hasattr(chat_stream, "stream_id"):
if not self.chat_stream or not hasattr(self.chat_stream, "stream_id"):
logger.error(f"{self.log_prefix} 缺少聊天流或stream_id")
return False
return await send_api.image_to_stream(image_base64, chat_stream.stream_id)
return await send_api.image_to_stream(image_base64, self.chat_stream.stream_id)
@classmethod
def get_plus_command_info(cls) -> "PlusCommandInfo":
@@ -340,12 +343,12 @@ class PlusCommandAdapter(BaseCommand):
将PlusCommand适配到现有的插件系统继承BaseCommand
"""
def __init__(self, plus_command_class, message: MessageRecv, plugin_config: dict | None = None):
def __init__(self, plus_command_class, message: DatabaseMessages, plugin_config: dict | None = None):
"""初始化适配器
Args:
plus_command_class: PlusCommand子类
message: 消息对象
message: 消息对象DatabaseMessages
plugin_config: 插件配置
"""
# 先设置必要的类属性
@@ -400,7 +403,7 @@ def create_plus_command_adapter(plus_command_class):
command_pattern = plus_command_class._generate_command_pattern()
chat_type_allow = getattr(plus_command_class, "chat_type_allow", ChatType.ALL)
def __init__(self, message: MessageRecv, plugin_config: dict | None = None):
def __init__(self, message: DatabaseMessages, plugin_config: dict | None = None):
super().__init__(message, plugin_config)
self.plus_command = plus_command_class(message, plugin_config)
self.priority = getattr(plus_command_class, "priority", 0)

View File

@@ -410,11 +410,9 @@ class ChatterPlanExecutor:
)
# 添加到chat_stream的已读消息中
if hasattr(chat_stream, "stream_context") and chat_stream.stream_context:
chat_stream.stream_context.history_messages.append(bot_message)
logger.debug(f"机器人回复已添加到已读消息: {reply_content[:50]}...")
else:
logger.warning("chat_stream没有stream_context无法添加已读消息")
chat_stream.context_manager.context.history_messages.append(bot_message)
logger.debug(f"机器人回复已添加到已读消息: {reply_content[:50]}...")
except Exception as e:
logger.error(f"添加机器人回复到已读消息时出错: {e}")

View File

@@ -96,7 +96,7 @@ class ProactiveThinkingMessageHandler(BaseEventHandler):
"""处理消息事件
Args:
kwargs: 事件参数,格式为 {"message": MessageRecv}
kwargs: 事件参数,格式为 {"message": DatabaseMessages}
Returns:
HandlerResult: 处理结果
@@ -104,7 +104,7 @@ class ProactiveThinkingMessageHandler(BaseEventHandler):
if not kwargs:
return HandlerResult(success=True, continue_process=True, message=None)
# 从 kwargs 中获取 MessageRecv 对象
# 从 kwargs 中获取 DatabaseMessages 对象
message = kwargs.get("message")
if not message or not hasattr(message, "chat_stream"):
return HandlerResult(success=True, continue_process=True, message=None)