重构消息处理并用DatabaseMessages替换MessageRecv

-更新PlusCommand以使用DatabaseMessages而不是MessageRecv。
-将消息处理逻辑重构到一个新模块message_processor.py中,以处理消息段并从消息字典中创建DatabaseMessages。
-删除了已弃用的MessageRecv类及其相关逻辑。
-调整了各种插件以适应新的DatabaseMessages结构。
-增强了消息处理功能中的错误处理和日志记录。
This commit is contained in:
Windpicker-owo
2025-10-31 19:24:58 +08:00
parent 50260818a8
commit 371041c9db
23 changed files with 1520 additions and 1801 deletions

View File

@@ -1,303 +0,0 @@
"""
关系追踪工具集成测试脚本
注意:此脚本需要在完整的应用环境中运行
建议通过 bot.py 启动后在交互式环境中测试
"""
import asyncio
async def test_user_profile_tool():
"""测试用户画像工具"""
print("\n" + "=" * 80)
print("测试 UserProfileTool")
print("=" * 80)
from src.plugins.built_in.affinity_flow_chatter.user_profile_tool import UserProfileTool
from src.common.database.sqlalchemy_database_api import db_query
from src.common.database.sqlalchemy_models import UserRelationships
tool = UserProfileTool()
print(f"✅ 工具名称: {tool.name}")
print(f" 工具描述: {tool.description}")
# 执行工具
test_user_id = "integration_test_user_001"
result = await tool.execute({
"target_user_id": test_user_id,
"user_aliases": "测试小明,TestMing,小明君",
"impression_description": "这是一个集成测试用户性格开朗活泼喜欢技术讨论对AI和编程特别感兴趣。经常提出有深度的问题。",
"preference_keywords": "AI,Python,深度学习,游戏开发,科幻小说",
"affection_score": 0.85
})
print(f"\n✅ 工具执行结果:")
print(f" 类型: {result.get('type')}")
print(f" 内容: {result.get('content')}")
# 验证数据库
db_data = await db_query(
UserRelationships,
filters={"user_id": test_user_id},
limit=1
)
if db_data:
data = db_data[0]
print(f"\n✅ 数据库验证:")
print(f" user_id: {data.get('user_id')}")
print(f" user_aliases: {data.get('user_aliases')}")
print(f" relationship_text: {data.get('relationship_text', '')[:80]}...")
print(f" preference_keywords: {data.get('preference_keywords')}")
print(f" relationship_score: {data.get('relationship_score')}")
return True
else:
print(f"\n❌ 数据库中未找到数据")
return False
async def test_chat_stream_impression_tool():
"""测试聊天流印象工具"""
print("\n" + "=" * 80)
print("测试 ChatStreamImpressionTool")
print("=" * 80)
from src.plugins.built_in.affinity_flow_chatter.chat_stream_impression_tool import ChatStreamImpressionTool
from src.common.database.sqlalchemy_database_api import db_query
from src.common.database.sqlalchemy_models import ChatStreams, get_db_session
# 准备测试数据:先创建一条 ChatStreams 记录
test_stream_id = "integration_test_stream_001"
print(f"🔧 准备测试数据:创建聊天流记录 {test_stream_id}")
import time
current_time = time.time()
async with get_db_session() as session:
new_stream = ChatStreams(
stream_id=test_stream_id,
create_time=current_time,
last_active_time=current_time,
platform="QQ",
user_platform="QQ",
user_id="test_user_123",
user_nickname="测试用户",
group_name="测试技术交流群",
group_platform="QQ",
group_id="test_group_456",
stream_impression_text="", # 初始为空
stream_chat_style="",
stream_topic_keywords="",
stream_interest_score=0.5
)
session.add(new_stream)
await session.commit()
print(f"✅ 测试聊天流记录已创建")
tool = ChatStreamImpressionTool()
print(f"✅ 工具名称: {tool.name}")
print(f" 工具描述: {tool.description}")
# 执行工具
result = await tool.execute({
"stream_id": test_stream_id,
"impression_description": "这是一个技术交流群成员主要是程序员和AI爱好者。大家经常分享最新的技术文章讨论编程问题氛围友好且专业。",
"chat_style": "专业技术交流,活跃讨论,互帮互助,知识分享",
"topic_keywords": "Python开发,机器学习,AI应用,Web后端,数据分析,开源项目",
"interest_score": 0.90
})
print(f"\n✅ 工具执行结果:")
print(f" 类型: {result.get('type')}")
print(f" 内容: {result.get('content')}")
# 验证数据库
db_data = await db_query(
ChatStreams,
filters={"stream_id": test_stream_id},
limit=1
)
if db_data:
data = db_data[0]
print(f"\n✅ 数据库验证:")
print(f" stream_id: {data.get('stream_id')}")
print(f" stream_impression_text: {data.get('stream_impression_text', '')[:80]}...")
print(f" stream_chat_style: {data.get('stream_chat_style')}")
print(f" stream_topic_keywords: {data.get('stream_topic_keywords')}")
print(f" stream_interest_score: {data.get('stream_interest_score')}")
return True
else:
print(f"\n❌ 数据库中未找到数据")
return False
async def test_relationship_info_build():
"""测试关系信息构建"""
print("\n" + "=" * 80)
print("测试关系信息构建(提示词集成)")
print("=" * 80)
from src.person_info.relationship_fetcher import relationship_fetcher_manager
test_stream_id = "integration_test_stream_001"
test_person_id = "test_person_999" # 使用一个可能不存在的ID来测试
fetcher = relationship_fetcher_manager.get_fetcher(test_stream_id)
print(f"✅ RelationshipFetcher 已创建")
# 测试聊天流印象构建
print(f"\n🔍 构建聊天流印象...")
stream_info = await fetcher.build_chat_stream_impression(test_stream_id)
if stream_info:
print(f"✅ 聊天流印象构建成功")
print(f"\n{'=' * 80}")
print(stream_info)
print(f"{'=' * 80}")
else:
print(f"⚠️ 聊天流印象为空(可能测试数据不存在)")
return True
async def cleanup_test_data():
"""清理测试数据"""
print("\n" + "=" * 80)
print("清理测试数据")
print("=" * 80)
from src.common.database.sqlalchemy_database_api import db_query
from src.common.database.sqlalchemy_models import UserRelationships, ChatStreams
try:
# 清理用户数据
await db_query(
UserRelationships,
query_type="delete",
filters={"user_id": "integration_test_user_001"}
)
print("✅ 用户测试数据已清理")
# 清理聊天流数据
await db_query(
ChatStreams,
query_type="delete",
filters={"stream_id": "integration_test_stream_001"}
)
print("✅ 聊天流测试数据已清理")
return True
except Exception as e:
print(f"⚠️ 清理失败: {e}")
return False
async def run_all_tests():
"""运行所有测试"""
print("\n" + "=" * 80)
print("关系追踪工具集成测试")
print("=" * 80)
results = {}
# 测试1
try:
results["UserProfileTool"] = await test_user_profile_tool()
except Exception as e:
print(f"\n❌ UserProfileTool 测试失败: {e}")
import traceback
traceback.print_exc()
results["UserProfileTool"] = False
# 测试2
try:
results["ChatStreamImpressionTool"] = await test_chat_stream_impression_tool()
except Exception as e:
print(f"\n❌ ChatStreamImpressionTool 测试失败: {e}")
import traceback
traceback.print_exc()
results["ChatStreamImpressionTool"] = False
# 测试3
try:
results["RelationshipFetcher"] = await test_relationship_info_build()
except Exception as e:
print(f"\n❌ RelationshipFetcher 测试失败: {e}")
import traceback
traceback.print_exc()
results["RelationshipFetcher"] = False
# 清理
try:
await cleanup_test_data()
except Exception as e:
print(f"\n⚠️ 清理测试数据失败: {e}")
# 总结
print("\n" + "=" * 80)
print("测试总结")
print("=" * 80)
passed = sum(1 for r in results.values() if r)
total = len(results)
for test_name, result in results.items():
status = "✅ 通过" if result else "❌ 失败"
print(f"{status} - {test_name}")
print(f"\n总计: {passed}/{total} 测试通过")
if passed == total:
print("\n🎉 所有测试通过!")
else:
print(f"\n⚠️ {total - passed} 个测试失败")
return passed == total
# 使用说明
print("""
============================================================================
关系追踪工具集成测试脚本
============================================================================
此脚本需要在完整的应用环境中运行。
使用方法1: 在 bot.py 中添加测试调用
-----------------------------------
在 bot.py 的 main() 函数中添加:
# 测试关系追踪工具
from tests.integration_test_relationship_tools import run_all_tests
await run_all_tests()
使用方法2: 在 Python REPL 中运行
-----------------------------------
启动 bot.py 后,在 Python 调试控制台中执行:
import asyncio
from tests.integration_test_relationship_tools import run_all_tests
asyncio.create_task(run_all_tests())
使用方法3: 直接在此文件底部运行
-----------------------------------
取消注释下面的代码,然后确保已启动应用环境
============================================================================
""")
# 如果需要直接运行(需要应用环境已启动)
if __name__ == "__main__":
print("\n⚠️ 警告: 直接运行此脚本可能会失败,因为缺少应用环境")
print("建议在 bot.py 启动后的环境中运行\n")
try:
asyncio.run(run_all_tests())
except Exception as e:
print(f"\n❌ 测试失败: {e}")
print("\n建议:")
print("1. 确保已启动 bot.py")
print("2. 在 Python 调试控制台中运行测试")
print("3. 或在 bot.py 中添加测试调用")

View File

@@ -27,6 +27,6 @@
"venvPath": ".", "venvPath": ".",
"venv": ".venv", "venv": ".venv",
"executionEnvironments": [ "executionEnvironments": [
{"root": "src"} {"root": "."}
] ]
} }

View File

@@ -6,7 +6,7 @@
import re import re
from src.chat.message_receive.message import MessageRecv from src.common.data_models.database_data_model import DatabaseMessages
from src.common.logger import get_logger from src.common.logger import get_logger
logger = get_logger("anti_injector.message_processor") logger = get_logger("anti_injector.message_processor")
@@ -15,7 +15,7 @@ logger = get_logger("anti_injector.message_processor")
class MessageProcessor: class MessageProcessor:
"""消息内容处理器""" """消息内容处理器"""
def extract_text_content(self, message: MessageRecv) -> str: def extract_text_content(self, message: DatabaseMessages) -> str:
"""提取消息中的文本内容,过滤掉引用的历史内容 """提取消息中的文本内容,过滤掉引用的历史内容
Args: Args:
@@ -64,7 +64,7 @@ class MessageProcessor:
return new_content return new_content
@staticmethod @staticmethod
def check_whitelist(message: MessageRecv, whitelist: list) -> tuple | None: def check_whitelist(message: DatabaseMessages, whitelist: list) -> tuple | None:
"""检查用户白名单 """检查用户白名单
Args: Args:
@@ -74,8 +74,8 @@ class MessageProcessor:
Returns: Returns:
如果在白名单中返回结果元组否则返回None 如果在白名单中返回结果元组否则返回None
""" """
user_id = message.message_info.user_info.user_id user_id = message.user_info.user_id
platform = message.message_info.platform platform = message.chat_info.platform
# 检查用户白名单:格式为 [[platform, user_id], ...] # 检查用户白名单:格式为 [[platform, user_id], ...]
for whitelist_entry in whitelist: for whitelist_entry in whitelist:

View File

@@ -29,7 +29,6 @@ class SingleStreamContextManager:
# 配置参数 # 配置参数
self.max_context_size = max_context_size or getattr(global_config.chat, "max_context_size", 100) self.max_context_size = max_context_size or getattr(global_config.chat, "max_context_size", 100)
self.context_ttl = getattr(global_config.chat, "context_ttl", 24 * 3600) # 24小时
# 元数据 # 元数据
self.created_time = time.time() self.created_time = time.time()
@@ -93,27 +92,24 @@ class SingleStreamContextManager:
return True return True
else: else:
logger.warning(f"消息缓存系统添加失败,回退到直接添加: {self.stream_id}") logger.warning(f"消息缓存系统添加失败,回退到直接添加: {self.stream_id}")
except ImportError:
logger.debug("MessageManager不可用使用直接添加模式")
except Exception as e: except Exception as e:
logger.warning(f"消息缓存系统异常,回退到直接添加: {self.stream_id}, error={e}") logger.warning(f"消息缓存系统异常,回退到直接添加: {self.stream_id}, error={e}")
# 回退方案:直接添加到未读消息 # 回退方案:直接添加到未读消息
message.is_read = False message.is_read = False
self.context.unread_messages.append(message) self.context.unread_messages.append(message)
# 自动检测和更新chat type # 自动检测和更新chat type
self._detect_chat_type(message) self._detect_chat_type(message)
# 在上下文管理器中计算兴趣值 # 在上下文管理器中计算兴趣值
await self._calculate_message_interest(message) await self._calculate_message_interest(message)
self.total_messages += 1 self.total_messages += 1
self.last_access_time = time.time() self.last_access_time = time.time()
# 启动流的循环任务(如果还未启动) # 启动流的循环任务(如果还未启动)
asyncio.create_task(stream_loop_manager.start_stream_loop(self.stream_id)) asyncio.create_task(stream_loop_manager.start_stream_loop(self.stream_id))
logger.debug(f"添加消息{message.processed_plain_text}到单流上下文: {self.stream_id}") logger.debug(f"添加消息{message.processed_plain_text}到单流上下文: {self.stream_id}")
return True return True
except Exception as e: except Exception as e:
logger.error(f"添加消息到单流上下文失败 {self.stream_id}: {e}", exc_info=True) logger.error(f"添加消息到单流上下文失败 {self.stream_id}: {e}", exc_info=True)
return False return False

View File

@@ -71,14 +71,6 @@ class MessageManager:
except Exception as e: except Exception as e:
logger.error(f"启动批量数据库写入器失败: {e}") logger.error(f"启动批量数据库写入器失败: {e}")
# 启动流缓存管理器
try:
from src.chat.message_manager.stream_cache_manager import init_stream_cache_manager
await init_stream_cache_manager()
except Exception as e:
logger.error(f"启动流缓存管理器失败: {e}")
# 启动消息缓存系统(内置) # 启动消息缓存系统(内置)
logger.info("📦 消息缓存系统已启动") logger.info("📦 消息缓存系统已启动")
@@ -116,15 +108,6 @@ class MessageManager:
except Exception as e: except Exception as e:
logger.error(f"停止批量数据库写入器失败: {e}") logger.error(f"停止批量数据库写入器失败: {e}")
# 停止流缓存管理器
try:
from src.chat.message_manager.stream_cache_manager import shutdown_stream_cache_manager
await shutdown_stream_cache_manager()
logger.info("🗄️ 流缓存管理器已停止")
except Exception as e:
logger.error(f"停止流缓存管理器失败: {e}")
# 停止消息缓存系统(内置) # 停止消息缓存系统(内置)
self.message_caches.clear() self.message_caches.clear()
self.stream_processing_status.clear() self.stream_processing_status.clear()
@@ -152,7 +135,7 @@ class MessageManager:
# 检查是否为notice消息 # 检查是否为notice消息
if self._is_notice_message(message): if self._is_notice_message(message):
# Notice消息处理 - 添加到全局管理器 # Notice消息处理 - 添加到全局管理器
logger.info(f"📢 检测到notice消息: message_id={message.message_id}, is_notify={message.is_notify}, notice_type={getattr(message, 'notice_type', None)}") logger.info(f"📢 检测到notice消息: notice_type={getattr(message, 'notice_type', None)}")
await self._handle_notice_message(stream_id, message) await self._handle_notice_message(stream_id, message)
# 根据配置决定是否继续处理(触发聊天流程) # 根据配置决定是否继续处理(触发聊天流程)
@@ -206,39 +189,6 @@ class MessageManager:
except Exception as e: except Exception as e:
logger.error(f"更新消息 {message_id} 时发生错误: {e}") logger.error(f"更新消息 {message_id} 时发生错误: {e}")
async def bulk_update_messages(self, stream_id: str, updates: list[dict[str, Any]]) -> int:
"""批量更新消息信息,降低更新频率"""
if not updates:
return 0
try:
chat_manager = get_chat_manager()
chat_stream = await chat_manager.get_stream(stream_id)
if not chat_stream:
logger.warning(f"MessageManager.bulk_update_messages: 聊天流 {stream_id} 不存在")
return 0
updated_count = 0
for item in updates:
message_id = item.get("message_id")
if not message_id:
continue
payload = {key: value for key, value in item.items() if key != "message_id" and value is not None}
if not payload:
continue
success = await chat_stream.context_manager.update_message(message_id, payload)
if success:
updated_count += 1
if updated_count:
logger.debug(f"批量更新消息 {updated_count} 条 (stream={stream_id})")
return updated_count
except Exception as e:
logger.error(f"批量更新聊天流 {stream_id} 消息失败: {e}")
return 0
async def add_action(self, stream_id: str, message_id: str, action: str): async def add_action(self, stream_id: str, message_id: str, action: str):
"""添加动作到消息""" """添加动作到消息"""
@@ -266,7 +216,7 @@ class MessageManager:
logger.warning(f"停用流失败: 聊天流 {stream_id} 不存在") logger.warning(f"停用流失败: 聊天流 {stream_id} 不存在")
return return
context = chat_stream.stream_context context = chat_stream.context_manager.context
context.is_active = False context.is_active = False
# 取消处理任务 # 取消处理任务
@@ -288,7 +238,7 @@ class MessageManager:
logger.warning(f"激活流失败: 聊天流 {stream_id} 不存在") logger.warning(f"激活流失败: 聊天流 {stream_id} 不存在")
return return
context = chat_stream.stream_context context = chat_stream.context_manager.context
context.is_active = True context.is_active = True
logger.info(f"激活聊天流: {stream_id}") logger.info(f"激活聊天流: {stream_id}")
@@ -304,7 +254,7 @@ class MessageManager:
if not chat_stream: if not chat_stream:
return None return None
context = chat_stream.stream_context context = chat_stream.context_manager.context
unread_count = len(chat_stream.context_manager.get_unread_messages()) unread_count = len(chat_stream.context_manager.get_unread_messages())
return StreamStats( return StreamStats(
@@ -447,7 +397,7 @@ class MessageManager:
await asyncio.sleep(0.1) await asyncio.sleep(0.1)
# 获取当前的stream context # 获取当前的stream context
context = chat_stream.stream_context context = chat_stream.context_manager.context
# 确保有未读消息需要处理 # 确保有未读消息需要处理
unread_messages = context.get_unread_messages() unread_messages = context.get_unread_messages()

View File

@@ -1,377 +0,0 @@
"""
流缓存管理器 - 使用优化版聊天流和智能缓存策略
提供分层缓存和自动清理功能
"""
import asyncio
import time
from collections import OrderedDict
from dataclasses import dataclass
from maim_message import GroupInfo, UserInfo
from src.chat.message_receive.optimized_chat_stream import OptimizedChatStream, create_optimized_chat_stream
from src.common.logger import get_logger
logger = get_logger("stream_cache_manager")
@dataclass
class StreamCacheStats:
"""缓存统计信息"""
hot_cache_size: int = 0
warm_storage_size: int = 0
cold_storage_size: int = 0
total_memory_usage: int = 0 # 估算的内存使用(字节)
cache_hits: int = 0
cache_misses: int = 0
evictions: int = 0
last_cleanup_time: float = 0
class TieredStreamCache:
"""分层流缓存管理器"""
def __init__(
self,
max_hot_size: int = 100,
max_warm_size: int = 500,
max_cold_size: int = 2000,
cleanup_interval: float = 300.0, # 5分钟清理一次
hot_timeout: float = 1800.0, # 30分钟未访问降级到warm
warm_timeout: float = 7200.0, # 2小时未访问降级到cold
cold_timeout: float = 86400.0, # 24小时未访问删除
):
self.max_hot_size = max_hot_size
self.max_warm_size = max_warm_size
self.max_cold_size = max_cold_size
self.cleanup_interval = cleanup_interval
self.hot_timeout = hot_timeout
self.warm_timeout = warm_timeout
self.cold_timeout = cold_timeout
# 三层缓存存储
self.hot_cache: OrderedDict[str, OptimizedChatStream] = OrderedDict() # 热数据LRU
self.warm_storage: dict[str, tuple[OptimizedChatStream, float]] = {} # 温数据(最后访问时间)
self.cold_storage: dict[str, tuple[OptimizedChatStream, float]] = {} # 冷数据(最后访问时间)
# 统计信息
self.stats = StreamCacheStats()
# 清理任务
self.cleanup_task: asyncio.Task | None = None
self.is_running = False
logger.info(f"分层流缓存管理器初始化完成 (hot:{max_hot_size}, warm:{max_warm_size}, cold:{max_cold_size})")
async def start(self):
"""启动缓存管理器"""
if self.is_running:
logger.warning("缓存管理器已经在运行")
return
self.is_running = True
self.cleanup_task = asyncio.create_task(self._cleanup_loop(), name="stream_cache_cleanup")
async def stop(self):
"""停止缓存管理器"""
if not self.is_running:
return
self.is_running = False
if self.cleanup_task and not self.cleanup_task.done():
self.cleanup_task.cancel()
try:
await asyncio.wait_for(self.cleanup_task, timeout=10.0)
except asyncio.TimeoutError:
logger.warning("缓存清理任务停止超时")
except Exception as e:
logger.error(f"停止缓存清理任务时出错: {e}")
logger.info("分层流缓存管理器已停止")
async def get_or_create_stream(
self,
stream_id: str,
platform: str,
user_info: UserInfo,
group_info: GroupInfo | None = None,
data: dict | None = None,
) -> OptimizedChatStream:
"""获取或创建流 - 优化版本"""
current_time = time.time()
# 1. 检查热缓存
if stream_id in self.hot_cache:
stream = self.hot_cache[stream_id]
# 移动到末尾LRU更新
self.hot_cache.move_to_end(stream_id)
self.stats.cache_hits += 1
logger.debug(f"热缓存命中: {stream_id}")
return stream.create_snapshot()
# 2. 检查温存储
if stream_id in self.warm_storage:
stream, last_access = self.warm_storage[stream_id]
self.warm_storage[stream_id] = (stream, current_time)
self.stats.cache_hits += 1
logger.debug(f"温缓存命中: {stream_id}")
# 提升到热缓存
await self._promote_to_hot(stream_id, stream)
return stream.create_snapshot()
# 3. 检查冷存储
if stream_id in self.cold_storage:
stream, last_access = self.cold_storage[stream_id]
self.cold_storage[stream_id] = (stream, current_time)
self.stats.cache_hits += 1
logger.debug(f"冷缓存命中: {stream_id}")
# 提升到温缓存
await self._promote_to_warm(stream_id, stream)
return stream.create_snapshot()
# 4. 缓存未命中,创建新流
self.stats.cache_misses += 1
stream = create_optimized_chat_stream(
stream_id=stream_id, platform=platform, user_info=user_info, group_info=group_info, data=data
)
logger.debug(f"缓存未命中,创建新流: {stream_id}")
# 添加到热缓存
await self._add_to_hot(stream_id, stream)
return stream
async def _add_to_hot(self, stream_id: str, stream: OptimizedChatStream):
"""添加到热缓存"""
# 检查是否需要驱逐
if len(self.hot_cache) >= self.max_hot_size:
await self._evict_from_hot()
self.hot_cache[stream_id] = stream
self.stats.hot_cache_size = len(self.hot_cache)
async def _promote_to_hot(self, stream_id: str, stream: OptimizedChatStream):
"""提升到热缓存"""
# 从温存储中移除
if stream_id in self.warm_storage:
del self.warm_storage[stream_id]
self.stats.warm_storage_size = len(self.warm_storage)
# 添加到热缓存
await self._add_to_hot(stream_id, stream)
logger.debug(f"{stream_id} 提升到热缓存")
async def _promote_to_warm(self, stream_id: str, stream: OptimizedChatStream):
"""提升到温缓存"""
# 从冷存储中移除
if stream_id in self.cold_storage:
del self.cold_storage[stream_id]
self.stats.cold_storage_size = len(self.cold_storage)
# 添加到温存储
if len(self.warm_storage) >= self.max_warm_size:
await self._evict_from_warm()
current_time = time.time()
self.warm_storage[stream_id] = (stream, current_time)
self.stats.warm_storage_size = len(self.warm_storage)
logger.debug(f"{stream_id} 提升到温缓存")
async def _evict_from_hot(self):
"""从热缓存驱逐最久未使用的流"""
if not self.hot_cache:
return
# LRU驱逐
stream_id, stream = self.hot_cache.popitem(last=False)
self.stats.evictions += 1
logger.debug(f"从热缓存驱逐: {stream_id}")
# 移动到温存储
if len(self.warm_storage) < self.max_warm_size:
current_time = time.time()
self.warm_storage[stream_id] = (stream, current_time)
self.stats.warm_storage_size = len(self.warm_storage)
else:
# 温存储也满了,直接删除
logger.debug(f"温存储已满,删除流: {stream_id}")
self.stats.hot_cache_size = len(self.hot_cache)
async def _evict_from_warm(self):
"""从温存储驱逐最久未使用的流"""
if not self.warm_storage:
return
# 找到最久未访问的流
oldest_stream_id = min(self.warm_storage.keys(), key=lambda k: self.warm_storage[k][1])
stream, last_access = self.warm_storage.pop(oldest_stream_id)
self.stats.evictions += 1
logger.debug(f"从温存储驱逐: {oldest_stream_id}")
# 移动到冷存储
if len(self.cold_storage) < self.max_cold_size:
current_time = time.time()
self.cold_storage[oldest_stream_id] = (stream, current_time)
self.stats.cold_storage_size = len(self.cold_storage)
else:
# 冷存储也满了,直接删除
logger.debug(f"冷存储已满,删除流: {oldest_stream_id}")
self.stats.warm_storage_size = len(self.warm_storage)
async def _cleanup_loop(self):
"""清理循环"""
logger.info("流缓存清理循环启动")
while self.is_running:
try:
await asyncio.sleep(self.cleanup_interval)
await self._perform_cleanup()
except asyncio.CancelledError:
logger.info("流缓存清理循环被取消")
break
except Exception as e:
logger.error(f"流缓存清理出错: {e}")
logger.info("流缓存清理循环结束")
async def _perform_cleanup(self):
"""执行清理操作"""
current_time = time.time()
cleanup_stats = {
"hot_to_warm": 0,
"warm_to_cold": 0,
"cold_removed": 0,
}
# 1. 检查热缓存超时
hot_to_demote = []
for stream_id, stream in self.hot_cache.items():
# 获取最后访问时间(简化:使用创建时间作为近似)
last_access = getattr(stream, "last_active_time", stream.create_time)
if current_time - last_access > self.hot_timeout:
hot_to_demote.append(stream_id)
for stream_id in hot_to_demote:
stream = self.hot_cache.pop(stream_id)
current_time_local = time.time()
self.warm_storage[stream_id] = (stream, current_time_local)
cleanup_stats["hot_to_warm"] += 1
# 2. 检查温存储超时
warm_to_demote = []
for stream_id, (stream, last_access) in self.warm_storage.items():
if current_time - last_access > self.warm_timeout:
warm_to_demote.append(stream_id)
for stream_id in warm_to_demote:
stream, last_access = self.warm_storage.pop(stream_id)
self.cold_storage[stream_id] = (stream, last_access)
cleanup_stats["warm_to_cold"] += 1
# 3. 检查冷存储超时
cold_to_remove = []
for stream_id, (stream, last_access) in self.cold_storage.items():
if current_time - last_access > self.cold_timeout:
cold_to_remove.append(stream_id)
for stream_id in cold_to_remove:
self.cold_storage.pop(stream_id)
cleanup_stats["cold_removed"] += 1
# 更新统计信息
self.stats.hot_cache_size = len(self.hot_cache)
self.stats.warm_storage_size = len(self.warm_storage)
self.stats.cold_storage_size = len(self.cold_storage)
self.stats.last_cleanup_time = current_time
# 估算内存使用(粗略估计)
self.stats.total_memory_usage = (
len(self.hot_cache) * 1024 # 每个热流约1KB
+ len(self.warm_storage) * 512 # 每个温流约512B
+ len(self.cold_storage) * 256 # 每个冷流约256B
)
if sum(cleanup_stats.values()) > 0:
logger.info(
f"缓存清理完成: {cleanup_stats['hot_to_warm']}热→温, "
f"{cleanup_stats['warm_to_cold']}温→冷, "
f"{cleanup_stats['cold_removed']}冷删除"
)
def get_stats(self) -> StreamCacheStats:
"""获取缓存统计信息"""
# 计算命中率
total_requests = self.stats.cache_hits + self.stats.cache_misses
hit_rate = self.stats.cache_hits / total_requests if total_requests > 0 else 0
stats_copy = StreamCacheStats(
hot_cache_size=self.stats.hot_cache_size,
warm_storage_size=self.stats.warm_storage_size,
cold_storage_size=self.stats.cold_storage_size,
total_memory_usage=self.stats.total_memory_usage,
cache_hits=self.stats.cache_hits,
cache_misses=self.stats.cache_misses,
evictions=self.stats.evictions,
last_cleanup_time=self.stats.last_cleanup_time,
)
# 添加命中率信息
stats_copy.hit_rate = hit_rate
return stats_copy
def clear_cache(self):
"""清空所有缓存"""
self.hot_cache.clear()
self.warm_storage.clear()
self.cold_storage.clear()
self.stats.hot_cache_size = 0
self.stats.warm_storage_size = 0
self.stats.cold_storage_size = 0
self.stats.total_memory_usage = 0
logger.info("所有缓存已清空")
async def get_stream_snapshot(self, stream_id: str) -> OptimizedChatStream | None:
"""获取流的快照(不修改缓存状态)"""
if stream_id in self.hot_cache:
return self.hot_cache[stream_id].create_snapshot()
elif stream_id in self.warm_storage:
return self.warm_storage[stream_id][0].create_snapshot()
elif stream_id in self.cold_storage:
return self.cold_storage[stream_id][0].create_snapshot()
return None
def get_cached_stream_ids(self) -> set[str]:
"""获取所有缓存的流ID"""
return set(self.hot_cache.keys()) | set(self.warm_storage.keys()) | set(self.cold_storage.keys())
# 全局缓存管理器实例
_cache_manager: TieredStreamCache | None = None
def get_stream_cache_manager() -> TieredStreamCache:
"""获取流缓存管理器实例"""
global _cache_manager
if _cache_manager is None:
_cache_manager = TieredStreamCache()
return _cache_manager
async def init_stream_cache_manager():
"""初始化流缓存管理器"""
manager = get_stream_cache_manager()
await manager.start()
async def shutdown_stream_cache_manager():
"""关闭流缓存管理器"""
manager = get_stream_cache_manager()
await manager.stop()

View File

@@ -9,10 +9,10 @@ from maim_message import UserInfo
from src.chat.antipromptinjector import initialize_anti_injector from src.chat.antipromptinjector import initialize_anti_injector
from src.chat.message_manager import message_manager from src.chat.message_manager import message_manager
from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager
from src.chat.message_receive.message import MessageRecv
from src.chat.message_receive.storage import MessageStorage from src.chat.message_receive.storage import MessageStorage
from src.chat.utils.prompt import create_prompt_async, global_prompt_manager from src.chat.utils.prompt import create_prompt_async, global_prompt_manager
from src.chat.utils.utils import is_mentioned_bot_in_message from src.chat.utils.utils import is_mentioned_bot_in_message
from src.common.data_models.database_data_model import DatabaseMessages
from src.common.logger import get_logger from src.common.logger import get_logger
from src.config.config import global_config from src.config.config import global_config
from src.mood.mood_manager import mood_manager # 导入情绪管理器 from src.mood.mood_manager import mood_manager # 导入情绪管理器
@@ -105,10 +105,10 @@ class ChatBot:
self._started = True self._started = True
async def _process_plus_commands(self, message: MessageRecv): async def _process_plus_commands(self, message: DatabaseMessages, chat: ChatStream):
"""独立处理PlusCommand系统""" """独立处理PlusCommand系统"""
try: try:
text = message.processed_plain_text text = message.processed_plain_text or ""
# 获取配置的命令前缀 # 获取配置的命令前缀
from src.config.config import global_config from src.config.config import global_config
@@ -166,10 +166,10 @@ class ChatBot:
# 检查命令是否被禁用 # 检查命令是否被禁用
if ( if (
message.chat_stream chat
and message.chat_stream.stream_id and chat.stream_id
and plus_command_name and plus_command_name
in global_announcement_manager.get_disabled_chat_commands(message.chat_stream.stream_id) in global_announcement_manager.get_disabled_chat_commands(chat.stream_id)
): ):
logger.info("用户禁用的PlusCommand跳过处理") logger.info("用户禁用的PlusCommand跳过处理")
return False, None, True return False, None, True
@@ -181,11 +181,14 @@ class ChatBot:
# 创建PlusCommand实例 # 创建PlusCommand实例
plus_command_instance = plus_command_class(message, plugin_config) plus_command_instance = plus_command_class(message, plugin_config)
# 为插件实例设置 chat_stream 运行时属性
setattr(plus_command_instance, "chat_stream", chat)
try: try:
# 检查聊天类型限制 # 检查聊天类型限制
if not plus_command_instance.is_chat_type_allowed(): if not plus_command_instance.is_chat_type_allowed():
is_group = message.message_info.group_info is_group = chat.group_info is not None
logger.info( logger.info(
f"PlusCommand {plus_command_class.__name__} 不支持当前聊天类型: {'群聊' if is_group else '私聊'}" f"PlusCommand {plus_command_class.__name__} 不支持当前聊天类型: {'群聊' if is_group else '私聊'}"
) )
@@ -225,11 +228,11 @@ class ChatBot:
logger.error(f"处理PlusCommand时出错: {e}") logger.error(f"处理PlusCommand时出错: {e}")
return False, None, True # 出错时继续处理消息 return False, None, True # 出错时继续处理消息
async def _process_commands_with_new_system(self, message: MessageRecv): async def _process_commands_with_new_system(self, message: DatabaseMessages, chat: ChatStream):
# sourcery skip: use-named-expression # sourcery skip: use-named-expression
"""使用新插件系统处理命令""" """使用新插件系统处理命令"""
try: try:
text = message.processed_plain_text text = message.processed_plain_text or ""
# 使用新的组件注册中心查找命令 # 使用新的组件注册中心查找命令
command_result = component_registry.find_command_by_text(text) command_result = component_registry.find_command_by_text(text)
@@ -238,10 +241,10 @@ class ChatBot:
plugin_name = command_info.plugin_name plugin_name = command_info.plugin_name
command_name = command_info.name command_name = command_info.name
if ( if (
message.chat_stream chat
and message.chat_stream.stream_id and chat.stream_id
and command_name and command_name
in global_announcement_manager.get_disabled_chat_commands(message.chat_stream.stream_id) in global_announcement_manager.get_disabled_chat_commands(chat.stream_id)
): ):
logger.info("用户禁用的命令,跳过处理") logger.info("用户禁用的命令,跳过处理")
return False, None, True return False, None, True
@@ -254,11 +257,14 @@ class ChatBot:
# 创建命令实例 # 创建命令实例
command_instance: BaseCommand = command_class(message, plugin_config) command_instance: BaseCommand = command_class(message, plugin_config)
command_instance.set_matched_groups(matched_groups) command_instance.set_matched_groups(matched_groups)
# 为插件实例设置 chat_stream 运行时属性
setattr(command_instance, "chat_stream", chat)
try: try:
# 检查聊天类型限制 # 检查聊天类型限制
if not command_instance.is_chat_type_allowed(): if not command_instance.is_chat_type_allowed():
is_group = message.message_info.group_info is_group = chat.group_info is not None
logger.info( logger.info(
f"命令 {command_class.__name__} 不支持当前聊天类型: {'群聊' if is_group else '私聊'}" f"命令 {command_class.__name__} 不支持当前聊天类型: {'群聊' if is_group else '私聊'}"
) )
@@ -295,13 +301,20 @@ class ChatBot:
logger.error(f"处理命令时出错: {e}") logger.error(f"处理命令时出错: {e}")
return False, None, True # 出错时继续处理消息 return False, None, True # 出错时继续处理消息
async def handle_notice_message(self, message: MessageRecv): async def handle_notice_message(self, message: DatabaseMessages):
"""处理notice消息 """处理notice消息
notice消息是系统事件通知如禁言、戳一戳等具有以下特点 notice消息是系统事件通知如禁言、戳一戳等具有以下特点
1. 默认不触发聊天流程,只记录 1. 默认不触发聊天流程,只记录
2. 可通过配置开启触发聊天流程 2. 可通过配置开启触发聊天流程
3. 会在提示词中展示 3. 会在提示词中展示
Args:
message: DatabaseMessages 对象
Returns:
bool: True表示notice已完整处理需要存储并终止后续流程
False表示不是notice或notice需要继续处理触发聊天流程
""" """
# 检查是否是notice消息 # 检查是否是notice消息
if message.is_notify: if message.is_notify:
@@ -309,53 +322,42 @@ class ChatBot:
# 根据配置决定是否触发聊天流程 # 根据配置决定是否触发聊天流程
if not global_config.notice.enable_notice_trigger_chat: if not global_config.notice.enable_notice_trigger_chat:
logger.debug("notice消息不触发聊天流程配置已关闭") logger.debug("notice消息不触发聊天流程配置已关闭,将存储后终止")
return True # 返回True表示已处理,不继续后续流程 return True # 返回True:需要在调用处存储并终止
else: else:
logger.debug("notice消息触发聊天流程配置已开启") logger.debug("notice消息触发聊天流程配置已开启,继续处理")
return False # 返回False表示继续处理,触发聊天流程 return False # 返回False:继续正常流程,作为普通消息处理
# 兼容旧的notice判断方式 # 兼容旧的notice判断方式
if message.message_info.message_id == "notice": if message.message_id == "notice":
message.is_notify = True # 为 DatabaseMessages 设置 is_notify 运行时属性
from src.chat.message_receive.message_processor import set_db_message_runtime_attr
set_db_message_runtime_attr(message, "is_notify", True)
logger.info("旧格式notice消息") logger.info("旧格式notice消息")
# 同样根据配置决定 # 同样根据配置决定
if not global_config.notice.enable_notice_trigger_chat: if not global_config.notice.enable_notice_trigger_chat:
return True logger.debug("旧格式notice消息不触发聊天流程将存储后终止")
return True # 需要存储并终止
else: else:
return False logger.debug("旧格式notice消息触发聊天流程继续处理")
return False # 继续正常流程
# 处理适配器响应消息 # DatabaseMessages 不再有 message_segment适配器响应处理已在消息处理阶段完成
if hasattr(message, "message_segment") and message.message_segment: # 这里保留逻辑以防万一,但实际上不会再执行到
if message.message_segment.type == "adapter_response": return False # 不是notice消息继续正常流程
await self.handle_adapter_response(message)
return True
elif message.message_segment.type == "adapter_command":
# 适配器命令消息不需要进一步处理
logger.debug("收到适配器命令消息,跳过后续处理")
return True
return False async def handle_adapter_response(self, message: DatabaseMessages):
"""处理适配器命令响应
async def handle_adapter_response(self, message: MessageRecv):
"""处理适配器命令响应""" 注意: 此方法目前未被调用,但保留以备将来使用
"""
try: try:
from src.plugin_system.apis.send_api import put_adapter_response from src.plugin_system.apis.send_api import put_adapter_response
seg_data = message.message_segment.data # DatabaseMessages 使用 message_segments 字段存储消息段
if isinstance(seg_data, dict): # 注意: 这可能需要根据实际使用情况进行调整
request_id = seg_data.get("request_id") logger.warning("handle_adapter_response 方法被调用,但目前未实现对 DatabaseMessages 的支持")
response_data = seg_data.get("response")
else:
request_id = None
response_data = None
if request_id and response_data:
logger.debug(f"收到适配器响应: request_id={request_id}")
put_adapter_response(request_id, response_data)
else:
logger.warning("适配器响应消息格式不正确")
except Exception as e: except Exception as e:
logger.error(f"处理适配器响应时出错: {e}") logger.error(f"处理适配器响应时出错: {e}")
@@ -381,9 +383,6 @@ class ChatBot:
await self._ensure_started() await self._ensure_started()
# 控制握手等消息可能缺少 message_info这里直接跳过避免 KeyError # 控制握手等消息可能缺少 message_info这里直接跳过避免 KeyError
if not isinstance(message_data, dict):
logger.warning(f"收到无法解析的消息类型: {type(message_data)},已跳过")
return
message_info = message_data.get("message_info") message_info = message_data.get("message_info")
if not isinstance(message_info, dict): if not isinstance(message_info, dict):
logger.debug( logger.debug(
@@ -392,8 +391,6 @@ class ChatBot:
) )
return return
platform = message_info.get("platform")
if message_info.get("group_info") is not None: if message_info.get("group_info") is not None:
message_info["group_info"]["group_id"] = str( message_info["group_info"]["group_id"] = str(
message_info["group_info"]["group_id"] message_info["group_info"]["group_id"]
@@ -404,74 +401,94 @@ class ChatBot:
) )
# print(message_data) # print(message_data)
# logger.debug(str(message_data)) # logger.debug(str(message_data))
message = MessageRecv(message_data)
# 先提取基础信息检查是否是自身消息上报
group_info = message.message_info.group_info from maim_message import BaseMessageInfo
user_info = message.message_info.user_info temp_message_info = BaseMessageInfo.from_dict(message_data.get("message_info", {}))
if message.message_info.additional_config: if temp_message_info.additional_config:
sent_message = message.message_info.additional_config.get("echo", False) sent_message = temp_message_info.additional_config.get("echo", False)
if sent_message: # 这一段只是为了在一切处理前劫持上报的自身消息用于更新message_id需要ada支持上报事件实际测试中不会对正常使用造成任何问题 if sent_message: # 这一段只是为了在一切处理前劫持上报的自身消息用于更新message_id需要ada支持上报事件实际测试中不会对正常使用造成任何问题
await MessageStorage.update_message(message) # 直接使用消息字典更新,不再需要创建 MessageRecv
await MessageStorage.update_message(message_data)
return return
group_info = temp_message_info.group_info
user_info = temp_message_info.user_info
get_chat_manager().register_message(message) # 获取或创建聊天流
chat = await get_chat_manager().get_or_create_stream( chat = await get_chat_manager().get_or_create_stream(
platform=message.message_info.platform, # type: ignore platform=temp_message_info.platform, # type: ignore
user_info=user_info, # type: ignore user_info=user_info, # type: ignore
group_info=group_info, group_info=group_info,
) )
message.update_chat_stream(chat) # 使用新的消息处理器直接生成 DatabaseMessages
from src.chat.message_receive.message_processor import process_message_from_dict
# 处理消息内容,生成纯文本 message = await process_message_from_dict(
await message.process() message_dict=message_data,
stream_id=chat.stream_id,
platform=chat.platform
)
# 填充聊天流时间信息
message.chat_info.create_time = chat.create_time
message.chat_info.last_active_time = chat.last_active_time
# 注册消息到聊天管理器
get_chat_manager().register_message(message)
# 检测是否提及机器人
message.is_mentioned, _ = is_mentioned_bot_in_message(message) message.is_mentioned, _ = is_mentioned_bot_in_message(message)
# 在这里打印[所见]日志,确保在所有处理和过滤之前记录 # 在这里打印[所见]日志,确保在所有处理和过滤之前记录
chat_name = chat.group_info.group_name if chat.group_info else "私聊" chat_name = chat.group_info.group_name if chat.group_info else "私聊"
if message.message_info.user_info: user_nickname = message.user_info.user_nickname if message.user_info else "未知用户"
logger.info( logger.info(
f"[{chat_name}]{message.message_info.user_info.user_nickname}:{message.processed_plain_text}\u001b[0m" f"[{chat_name}]{user_nickname}:{message.processed_plain_text}\u001b[0m"
) )
# 在此添加硬编码过滤,防止回复图片处理失败的消息 # 在此添加硬编码过滤,防止回复图片处理失败的消息
failure_keywords = ["[表情包(描述生成失败)]", "[图片(描述生成失败)]"] failure_keywords = ["[表情包(描述生成失败)]", "[图片(描述生成失败)]"]
if any(keyword in message.processed_plain_text for keyword in failure_keywords): processed_text = message.processed_plain_text or ""
logger.info(f"[硬编码过滤] 检测到媒体内容处理失败({message.processed_plain_text}),消息被静默处理。") if any(keyword in processed_text for keyword in failure_keywords):
logger.info(f"[硬编码过滤] 检测到媒体内容处理失败({processed_text}),消息被静默处理。")
return return
# 处理notice消息 # 处理notice消息
# notice_handled=True: 表示notice不触发聊天需要在此存储并终止
# notice_handled=False: 表示notice触发聊天或不是notice继续正常流程
notice_handled = await self.handle_notice_message(message) notice_handled = await self.handle_notice_message(message)
if notice_handled: if notice_handled:
# notice消息已处理,使用统一的转换方法 # notice消息不触发聊天流程,在此进行存储和记录后终止
try: try:
# 直接转换为 DatabaseMessages # message 已经是 DatabaseMessages,直接使用
db_message = message.to_database_message()
# 添加到message_manager这会将notice添加到全局notice管理器 # 添加到message_manager这会将notice添加到全局notice管理器
await message_manager.add_message(message.chat_stream.stream_id, db_message) await message_manager.add_message(chat.stream_id, message)
logger.info(f"✅ Notice消息已添加到message_manager: type={message.notice_type}, stream={message.chat_stream.stream_id}") logger.info(f"✅ Notice消息已添加到message_manager: type={message.notice_type}, stream={chat.stream_id}")
except Exception as e: except Exception as e:
logger.error(f"Notice消息添加到message_manager失败: {e}", exc_info=True) logger.error(f"Notice消息添加到message_manager失败: {e}", exc_info=True)
# 存储后直接返回 # 存储notice消息到数据库需要更新 storage.py 支持 DatabaseMessages
await MessageStorage.store_message(message, chat) # 暂时跳过存储,等待更新 storage.py
logger.debug("notice消息已存储,跳过后续处理") logger.debug("notice消息已添加到message_manager存储功能待更新")
return return
# 如果notice_handled=False则继续执行后续流程
# 对于启用触发聊天的notice会在后续的正常流程中被存储和处理
# 过滤检查 # 过滤检查
# DatabaseMessages 使用 display_message 作为原始消息表示
raw_text = message.display_message or message.processed_plain_text or ""
if _check_ban_words(message.processed_plain_text, chat, user_info) or _check_ban_regex( # type: ignore if _check_ban_words(message.processed_plain_text, chat, user_info) or _check_ban_regex( # type: ignore
message.raw_message, # type: ignore raw_text,
chat, chat,
user_info, # type: ignore user_info, # type: ignore
): ):
return return
# 命令处理 - 首先尝试PlusCommand独立处理 # 命令处理 - 首先尝试PlusCommand独立处理
is_plus_command, plus_cmd_result, plus_continue_process = await self._process_plus_commands(message) is_plus_command, plus_cmd_result, plus_continue_process = await self._process_plus_commands(message, chat)
# 如果是PlusCommand且不需要继续处理则直接返回 # 如果是PlusCommand且不需要继续处理则直接返回
if is_plus_command and not plus_continue_process: if is_plus_command and not plus_continue_process:
@@ -481,7 +498,7 @@ class ChatBot:
# 如果不是PlusCommand尝试传统的BaseCommand处理 # 如果不是PlusCommand尝试传统的BaseCommand处理
if not is_plus_command: if not is_plus_command:
is_command, cmd_result, continue_process = await self._process_commands_with_new_system(message) is_command, cmd_result, continue_process = await self._process_commands_with_new_system(message, chat)
# 如果是命令且不需要继续处理,则直接返回 # 如果是命令且不需要继续处理,则直接返回
if is_command and not continue_process: if is_command and not continue_process:
@@ -493,24 +510,14 @@ class ChatBot:
if result and not result.all_continue_process(): if result and not result.all_continue_process():
raise UserWarning(f"插件{result.get_summary().get('stopped_handlers', '')}于消息到达时取消了消息处理") raise UserWarning(f"插件{result.get_summary().get('stopped_handlers', '')}于消息到达时取消了消息处理")
# TODO:暂不可用 # TODO:暂不可用 - DatabaseMessages 不再有 message_info.template_info
# 确认从接口发来的message是否有自定义的prompt模板信息 # 确认从接口发来的message是否有自定义的prompt模板信息
if message.message_info.template_info and not message.message_info.template_info.template_default: # 这个功能需要在 adapter 层通过 additional_config 传递
template_group_name: str | None = message.message_info.template_info.template_name # type: ignore template_group_name = None
template_items = message.message_info.template_info.template_items
async with global_prompt_manager.async_message_scope(template_group_name):
if isinstance(template_items, dict):
for k in template_items.keys():
await create_prompt_async(template_items[k], k)
logger.debug(f"注册{template_items[k]},{k}")
else:
template_group_name = None
async def preprocess(): async def preprocess():
# 使用统一的转换方法创建数据库消息对象 # message 已经是 DatabaseMessages直接使用
db_message = message.to_database_message() group_info = chat.group_info
group_info = getattr(message.chat_stream, "group_info", None)
# 先交给消息管理器处理,计算兴趣度等衍生数据 # 先交给消息管理器处理,计算兴趣度等衍生数据
try: try:
@@ -527,31 +534,15 @@ class ChatBot:
should_process_in_manager = False should_process_in_manager = False
if should_process_in_manager: if should_process_in_manager:
await message_manager.add_message(message.chat_stream.stream_id, db_message) await message_manager.add_message(chat.stream_id, message)
logger.debug(f"消息已添加到消息管理器: {message.chat_stream.stream_id}") logger.debug(f"消息已添加到消息管理器: {chat.stream_id}")
except Exception as e: except Exception as e:
logger.error(f"消息添加到消息管理器失败: {e}") logger.error(f"消息添加到消息管理器失败: {e}")
# 将兴趣度结果同步回原始消息,便于后续流程使用
message.interest_value = getattr(db_message, "interest_value", getattr(message, "interest_value", 0.0))
setattr(
message,
"should_reply",
getattr(db_message, "should_reply", getattr(message, "should_reply", False)),
)
setattr(message, "should_act", getattr(db_message, "should_act", getattr(message, "should_act", False)))
# 存储消息到数据库,只进行一次写入 # 存储消息到数据库,只进行一次写入
try: try:
await MessageStorage.store_message(message, message.chat_stream) await MessageStorage.store_message(message, chat)
logger.debug(
"消息已存储到数据库: %s (interest=%.3f, should_reply=%s, should_act=%s)",
message.message_info.message_id,
getattr(message, "interest_value", -1.0),
getattr(message, "should_reply", None),
getattr(message, "should_act", None),
)
except Exception as e: except Exception as e:
logger.error(f"存储消息到数据库失败: {e}") logger.error(f"存储消息到数据库失败: {e}")
traceback.print_exc() traceback.print_exc()
@@ -560,13 +551,13 @@ class ChatBot:
try: try:
if global_config.mood.enable_mood: if global_config.mood.enable_mood:
# 获取兴趣度用于情绪更新 # 获取兴趣度用于情绪更新
interest_rate = getattr(message, "interest_value", 0.0) interest_rate = message.interest_value
if interest_rate is None: if interest_rate is None:
interest_rate = 0.0 interest_rate = 0.0
logger.debug(f"开始更新情绪状态,兴趣度: {interest_rate:.2f}") logger.debug(f"开始更新情绪状态,兴趣度: {interest_rate:.2f}")
# 获取当前聊天的情绪对象并更新情绪状态 # 获取当前聊天的情绪对象并更新情绪状态
chat_mood = mood_manager.get_mood_by_chat_id(message.chat_stream.stream_id) chat_mood = mood_manager.get_mood_by_chat_id(chat.stream_id)
await chat_mood.update_mood_by_message(message, interest_rate) await chat_mood.update_mood_by_message(message, interest_rate)
logger.debug("情绪状态更新完成") logger.debug("情绪状态更新完成")
except Exception as e: except Exception as e:

View File

@@ -12,13 +12,10 @@ from sqlalchemy.dialects.sqlite import insert as sqlite_insert
from src.common.database.sqlalchemy_database_api import get_db_session from src.common.database.sqlalchemy_database_api import get_db_session
from src.common.database.sqlalchemy_models import ChatStreams # 新增导入 from src.common.database.sqlalchemy_models import ChatStreams # 新增导入
from src.common.data_models.database_data_model import DatabaseMessages
from src.common.logger import get_logger from src.common.logger import get_logger
from src.config.config import global_config # 新增导入 from src.config.config import global_config # 新增导入
# 避免循环导入使用TYPE_CHECKING进行类型提示
if TYPE_CHECKING:
from .message import MessageRecv
install(extra_lines=3) install(extra_lines=3)
@@ -33,7 +30,7 @@ class ChatStream:
self, self,
stream_id: str, stream_id: str,
platform: str, platform: str,
user_info: UserInfo, user_info: UserInfo | None = None,
group_info: GroupInfo | None = None, group_info: GroupInfo | None = None,
data: dict | None = None, data: dict | None = None,
): ):
@@ -46,20 +43,18 @@ class ChatStream:
self.sleep_pressure = data.get("sleep_pressure", 0.0) if data else 0.0 self.sleep_pressure = data.get("sleep_pressure", 0.0) if data else 0.0
self.saved = False self.saved = False
# 使用StreamContext替代ChatMessageContext # 创建单流上下文管理器包含StreamContext
from src.chat.message_manager.context_manager import SingleStreamContextManager
from src.common.data_models.message_manager_data_model import StreamContext from src.common.data_models.message_manager_data_model import StreamContext
from src.plugin_system.base.component_types import ChatMode, ChatType from src.plugin_system.base.component_types import ChatMode, ChatType
# 创建StreamContext
self.stream_context: StreamContext = StreamContext(
stream_id=stream_id, chat_type=ChatType.GROUP if group_info else ChatType.PRIVATE, chat_mode=ChatMode.NORMAL
)
# 创建单流上下文管理器
from src.chat.message_manager.context_manager import SingleStreamContextManager
self.context_manager: SingleStreamContextManager = SingleStreamContextManager( self.context_manager: SingleStreamContextManager = SingleStreamContextManager(
stream_id=stream_id, context=self.stream_context stream_id=stream_id,
context=StreamContext(
stream_id=stream_id,
chat_type=ChatType.GROUP if group_info else ChatType.PRIVATE,
chat_mode=ChatMode.NORMAL,
),
) )
# 基础参数 # 基础参数
@@ -88,13 +83,12 @@ class ChatStream:
new_stream._focus_energy = self._focus_energy new_stream._focus_energy = self._focus_energy
new_stream.no_reply_consecutive = self.no_reply_consecutive new_stream.no_reply_consecutive = self.no_reply_consecutive
# 复制 stream_context但跳过 processing_task # 复制 context_manager包含 stream_context
new_stream.stream_context = copy.deepcopy(self.stream_context, memo)
if hasattr(new_stream.stream_context, "processing_task"):
new_stream.stream_context.processing_task = None
# 复制 context_manager
new_stream.context_manager = copy.deepcopy(self.context_manager, memo) new_stream.context_manager = copy.deepcopy(self.context_manager, memo)
# 清理 processing_task如果存在
if hasattr(new_stream.context_manager.context, "processing_task"):
new_stream.context_manager.context.processing_task = None
return new_stream return new_stream
@@ -111,11 +105,11 @@ class ChatStream:
"focus_energy": self.focus_energy, "focus_energy": self.focus_energy,
# 基础兴趣度 # 基础兴趣度
"base_interest_energy": self.base_interest_energy, "base_interest_energy": self.base_interest_energy,
# stream_context基本信息 # stream_context基本信息通过context_manager访问
"stream_context_chat_type": self.stream_context.chat_type.value, "stream_context_chat_type": self.context_manager.context.chat_type.value,
"stream_context_chat_mode": self.stream_context.chat_mode.value, "stream_context_chat_mode": self.context_manager.context.chat_mode.value,
# 统计信息 # 统计信息
"interruption_count": self.stream_context.interruption_count, "interruption_count": self.context_manager.context.interruption_count,
} }
@classmethod @classmethod
@@ -132,27 +126,19 @@ class ChatStream:
data=data, data=data,
) )
# 恢复stream_context信息 # 恢复stream_context信息通过context_manager访问
if "stream_context_chat_type" in data: if "stream_context_chat_type" in data:
from src.plugin_system.base.component_types import ChatMode, ChatType from src.plugin_system.base.component_types import ChatMode, ChatType
instance.stream_context.chat_type = ChatType(data["stream_context_chat_type"]) instance.context_manager.context.chat_type = ChatType(data["stream_context_chat_type"])
if "stream_context_chat_mode" in data: if "stream_context_chat_mode" in data:
from src.plugin_system.base.component_types import ChatMode, ChatType from src.plugin_system.base.component_types import ChatMode, ChatType
instance.stream_context.chat_mode = ChatMode(data["stream_context_chat_mode"]) instance.context_manager.context.chat_mode = ChatMode(data["stream_context_chat_mode"])
# 恢复interruption_count信息 # 恢复interruption_count信息
if "interruption_count" in data: if "interruption_count" in data:
instance.stream_context.interruption_count = data["interruption_count"] instance.context_manager.context.interruption_count = data["interruption_count"]
# 确保 context_manager 已初始化
if not hasattr(instance, "context_manager"):
from src.chat.message_manager.context_manager import SingleStreamContextManager
instance.context_manager = SingleStreamContextManager(
stream_id=instance.stream_id, context=instance.stream_context
)
return instance return instance
@@ -160,156 +146,44 @@ class ChatStream:
"""获取原始的、未哈希的聊天流ID字符串""" """获取原始的、未哈希的聊天流ID字符串"""
if self.group_info: if self.group_info:
return f"{self.platform}:{self.group_info.group_id}:group" return f"{self.platform}:{self.group_info.group_id}:group"
else: elif self.user_info:
return f"{self.platform}:{self.user_info.user_id}:private" return f"{self.platform}:{self.user_info.user_id}:private"
else:
return f"{self.platform}:unknown:private"
def update_active_time(self): def update_active_time(self):
"""更新最后活跃时间""" """更新最后活跃时间"""
self.last_active_time = time.time() self.last_active_time = time.time()
self.saved = False self.saved = False
async def set_context(self, message: "MessageRecv"): async def set_context(self, message: DatabaseMessages):
"""设置聊天消息上下文""" """设置聊天消息上下文
# 将MessageRecv转换为DatabaseMessages并设置到stream_context
import json
from src.common.data_models.database_data_model import DatabaseMessages
# 安全获取message_info中的数据
message_info = getattr(message, "message_info", {})
user_info = getattr(message_info, "user_info", {})
group_info = getattr(message_info, "group_info", {})
# 提取reply_to信息从message_segment中查找reply类型的段
reply_to = None
if hasattr(message, "message_segment") and message.message_segment:
reply_to = self._extract_reply_from_segment(message.message_segment)
# 完整的数据转移逻辑
db_message = DatabaseMessages(
# 基础消息信息
message_id=getattr(message, "message_id", ""),
time=getattr(message, "time", time.time()),
chat_id=self._generate_chat_id(message_info),
reply_to=reply_to,
# 兴趣度相关
interest_value=getattr(message, "interest_value", 0.0),
# 关键词
key_words=json.dumps(getattr(message, "key_words", []), ensure_ascii=False)
if getattr(message, "key_words", None)
else None,
key_words_lite=json.dumps(getattr(message, "key_words_lite", []), ensure_ascii=False)
if getattr(message, "key_words_lite", None)
else None,
# 消息状态标记
is_mentioned=getattr(message, "is_mentioned", None),
is_at=getattr(message, "is_at", False),
is_emoji=getattr(message, "is_emoji", False),
is_picid=getattr(message, "is_picid", False),
is_voice=getattr(message, "is_voice", False),
is_video=getattr(message, "is_video", False),
is_command=getattr(message, "is_command", False),
is_notify=getattr(message, "is_notify", False),
is_public_notice=getattr(message, "is_public_notice", False),
notice_type=getattr(message, "notice_type", None),
# 消息内容
processed_plain_text=getattr(message, "processed_plain_text", ""),
display_message=getattr(message, "processed_plain_text", ""), # 默认使用processed_plain_text
# 优先级信息
priority_mode=getattr(message, "priority_mode", None),
priority_info=json.dumps(getattr(message, "priority_info", None))
if getattr(message, "priority_info", None)
else None,
# 额外配置 - 需要将 format_info 嵌入到 additional_config 中
additional_config=self._prepare_additional_config(message_info),
# 用户信息
user_id=str(getattr(user_info, "user_id", "")),
user_nickname=getattr(user_info, "user_nickname", ""),
user_cardname=getattr(user_info, "user_cardname", None),
user_platform=getattr(user_info, "platform", ""),
# 群组信息
chat_info_group_id=getattr(group_info, "group_id", None),
chat_info_group_name=getattr(group_info, "group_name", None),
chat_info_group_platform=getattr(group_info, "platform", None),
# 聊天流信息
chat_info_user_id=str(getattr(user_info, "user_id", "")),
chat_info_user_nickname=getattr(user_info, "user_nickname", ""),
chat_info_user_cardname=getattr(user_info, "user_cardname", None),
chat_info_user_platform=getattr(user_info, "platform", ""),
chat_info_stream_id=self.stream_id,
chat_info_platform=self.platform,
chat_info_create_time=self.create_time,
chat_info_last_active_time=self.last_active_time,
# 新增兴趣度系统字段 - 添加安全处理
actions=self._safe_get_actions(message),
should_reply=getattr(message, "should_reply", False),
should_act=getattr(message, "should_act", False),
)
self.stream_context.set_current_message(db_message)
self.stream_context.priority_mode = getattr(message, "priority_mode", None)
self.stream_context.priority_info = getattr(message, "priority_info", None)
# 调试日志:记录数据转移情况
logger.debug(
f"消息数据转移完成 - message_id: {db_message.message_id}, "
f"chat_id: {db_message.chat_id}, "
f"is_mentioned: {db_message.is_mentioned}, "
f"is_emoji: {db_message.is_emoji}, "
f"is_picid: {db_message.is_picid}, "
f"interest_value: {db_message.interest_value}"
)
def _prepare_additional_config(self, message_info) -> str | None:
"""
准备 additional_config将 format_info 嵌入其中
这个方法模仿 storage.py 中的逻辑,确保 DatabaseMessages 中的 additional_config
包含 format_info使得 action_modifier 能够正确获取适配器支持的消息类型
Args: Args:
message_info: BaseMessageInfo 对象 message: DatabaseMessages 对象,直接使用不需要转换
Returns:
str | None: JSON 字符串格式的 additional_config如果为空则返回 None
""" """
import orjson # 直接使用传入的 DatabaseMessages设置到上下文中
self.context_manager.context.set_current_message(message)
# 首先获取adapter传递的additional_config
additional_config_data = {}
if hasattr(message_info, 'additional_config') and message_info.additional_config:
if isinstance(message_info.additional_config, dict):
additional_config_data = message_info.additional_config.copy()
elif isinstance(message_info.additional_config, str):
# 如果是字符串,尝试解析
try:
additional_config_data = orjson.loads(message_info.additional_config)
except Exception as e:
logger.warning(f"无法解析 additional_config JSON: {e}")
additional_config_data = {}
# 然后添加format_info到additional_config中 # 设置优先级信息(如果存在)
if hasattr(message_info, 'format_info') and message_info.format_info: priority_mode = getattr(message, "priority_mode", None)
try: priority_info = getattr(message, "priority_info", None)
format_info_dict = message_info.format_info.to_dict() if priority_mode:
additional_config_data["format_info"] = format_info_dict self.context_manager.context.priority_mode = priority_mode
logger.debug(f"嵌入 format_info 到 additional_config: {format_info_dict}") if priority_info:
except Exception as e: self.context_manager.context.priority_info = priority_info
logger.warning(f"将 format_info 转换为字典失败: {e}")
else:
logger.warning(f"[问题] 消息缺少 format_info: message_id={getattr(message_info, 'message_id', 'unknown')}")
logger.warning("[问题] 这可能导致 Action 无法正确检查适配器支持的类型")
# 序列化为JSON字符串
if additional_config_data:
try:
return orjson.dumps(additional_config_data).decode("utf-8")
except Exception as e:
logger.error(f"序列化 additional_config 失败: {e}")
return None
return None
def _safe_get_actions(self, message: "MessageRecv") -> list | None: # 调试日志
logger.debug(
f"消息上下文已设置 - message_id: {message.message_id}, "
f"chat_id: {message.chat_id}, "
f"is_mentioned: {message.is_mentioned}, "
f"is_emoji: {message.is_emoji}, "
f"is_picid: {message.is_picid}, "
f"interest_value: {message.interest_value}"
)
def _safe_get_actions(self, message: DatabaseMessages) -> list | None:
"""安全获取消息的actions字段""" """安全获取消息的actions字段"""
import json import json
@@ -380,23 +254,6 @@ class ChatStream:
if hasattr(db_message, "should_act"): if hasattr(db_message, "should_act"):
db_message.should_act = False db_message.should_act = False
def _extract_reply_from_segment(self, segment) -> str | None:
"""从消息段中提取reply_to信息"""
try:
if hasattr(segment, "type") and segment.type == "seglist":
# 递归搜索seglist中的reply段
if hasattr(segment, "data") and segment.data:
for seg in segment.data:
reply_id = self._extract_reply_from_segment(seg)
if reply_id:
return reply_id
elif hasattr(segment, "type") and segment.type == "reply":
# 找到reply段返回message_id
return str(segment.data) if segment.data else None
except Exception as e:
logger.warning(f"提取reply_to信息失败: {e}")
return None
def _generate_chat_id(self, message_info) -> str: def _generate_chat_id(self, message_info) -> str:
"""生成chat_id基于群组或用户信息""" """生成chat_id基于群组或用户信息"""
try: try:
@@ -493,8 +350,10 @@ class ChatManager:
def __init__(self): def __init__(self):
if not self._initialized: if not self._initialized:
from src.common.data_models.database_data_model import DatabaseMessages
self.streams: dict[str, ChatStream] = {} # stream_id -> ChatStream self.streams: dict[str, ChatStream] = {} # stream_id -> ChatStream
self.last_messages: dict[str, "MessageRecv"] = {} # stream_id -> last_message self.last_messages: dict[str, DatabaseMessages] = {} # stream_id -> last_message
# try: # try:
# async with get_db_session() as session: # async with get_db_session() as session:
# db.connect(reuse_if_open=True) # db.connect(reuse_if_open=True)
@@ -528,12 +387,30 @@ class ChatManager:
except Exception as e: except Exception as e:
logger.error(f"聊天流自动保存失败: {e!s}") logger.error(f"聊天流自动保存失败: {e!s}")
def register_message(self, message: "MessageRecv"): def register_message(self, message: DatabaseMessages):
"""注册消息到聊天流""" """注册消息到聊天流"""
# 从 DatabaseMessages 提取平台和用户/群组信息
from maim_message import UserInfo, GroupInfo
user_info = UserInfo(
platform=message.user_info.platform,
user_id=message.user_info.user_id,
user_nickname=message.user_info.user_nickname,
user_cardname=message.user_info.user_cardname or ""
)
group_info = None
if message.group_info:
group_info = GroupInfo(
platform=message.group_info.group_platform or "",
group_id=message.group_info.group_id,
group_name=message.group_info.group_name
)
stream_id = self._generate_stream_id( stream_id = self._generate_stream_id(
message.message_info.platform, # type: ignore message.chat_info.platform,
message.message_info.user_info, user_info,
message.message_info.group_info, group_info,
) )
self.last_messages[stream_id] = message self.last_messages[stream_id] = message
# logger.debug(f"注册消息到聊天流: {stream_id}") # logger.debug(f"注册消息到聊天流: {stream_id}")
@@ -578,32 +455,6 @@ class ChatManager:
try: try:
stream_id = self._generate_stream_id(platform, user_info, group_info) stream_id = self._generate_stream_id(platform, user_info, group_info)
# 优先使用缓存管理器(优化版本)
try:
from src.chat.message_manager.stream_cache_manager import get_stream_cache_manager
cache_manager = get_stream_cache_manager()
if cache_manager.is_running:
optimized_stream = await cache_manager.get_or_create_stream(
stream_id=stream_id, platform=platform, user_info=user_info, group_info=group_info
)
# 设置消息上下文
from .message import MessageRecv
if stream_id in self.last_messages and isinstance(self.last_messages[stream_id], MessageRecv):
optimized_stream.set_context(self.last_messages[stream_id])
# 转换为原始ChatStream以保持兼容性
original_stream = self._convert_to_original_stream(optimized_stream)
return original_stream
except Exception as e:
logger.debug(f"缓存管理器获取流失败,使用原始方法: {e}")
# 回退到原始方法
# 检查内存中是否存在 # 检查内存中是否存在
if stream_id in self.streams: if stream_id in self.streams:
stream = self.streams[stream_id] stream = self.streams[stream_id]
@@ -615,12 +466,13 @@ class ChatManager:
stream.user_info = user_info stream.user_info = user_info
if group_info: if group_info:
stream.group_info = group_info stream.group_info = group_info
from .message import MessageRecv # 延迟导入,避免循环引用
# 检查是否有最后一条消息(现在使用 DatabaseMessages
if stream_id in self.last_messages and isinstance(self.last_messages[stream_id], MessageRecv): from src.common.data_models.database_data_model import DatabaseMessages
if stream_id in self.last_messages and isinstance(self.last_messages[stream_id], DatabaseMessages):
await stream.set_context(self.last_messages[stream_id]) await stream.set_context(self.last_messages[stream_id])
else: else:
logger.error(f"聊天流 {stream_id} 不在最后消息列表中,可能是新创建的") logger.debug(f"聊天流 {stream_id} 不在最后消息列表中,可能是新创建的或还没有消息")
return stream return stream
# 检查数据库中是否存在 # 检查数据库中是否存在
@@ -679,19 +531,27 @@ class ChatManager:
raise e raise e
stream = copy.deepcopy(stream) stream = copy.deepcopy(stream)
from .message import MessageRecv # 延迟导入,避免循环引用 from src.common.data_models.database_data_model import DatabaseMessages
if stream_id in self.last_messages and isinstance(self.last_messages[stream_id], MessageRecv): if stream_id in self.last_messages and isinstance(self.last_messages[stream_id], DatabaseMessages):
await stream.set_context(self.last_messages[stream_id]) await stream.set_context(self.last_messages[stream_id])
else: else:
logger.error(f"聊天流 {stream_id} 不在最后消息列表中,可能是新创建的") logger.debug(f"聊天流 {stream_id} 不在最后消息列表中,可能是新创建的")
# 确保 ChatStream 有自己的 context_manager # 确保 ChatStream 有自己的 context_manager
if not hasattr(stream, "context_manager"): if not hasattr(stream, "context_manager"):
# 创建新的单流上下文管理器
from src.chat.message_manager.context_manager import SingleStreamContextManager from src.chat.message_manager.context_manager import SingleStreamContextManager
from src.common.data_models.message_manager_data_model import StreamContext
from src.plugin_system.base.component_types import ChatMode, ChatType
stream.context_manager = SingleStreamContextManager(stream_id=stream_id, context=stream.stream_context) stream.context_manager = SingleStreamContextManager(
stream_id=stream_id,
context=StreamContext(
stream_id=stream_id,
chat_type=ChatType.GROUP if stream.group_info else ChatType.PRIVATE,
chat_mode=ChatMode.NORMAL,
),
)
# 保存到内存和数据库 # 保存到内存和数据库
self.streams[stream_id] = stream self.streams[stream_id] = stream
@@ -700,10 +560,12 @@ class ChatManager:
async def get_stream(self, stream_id: str) -> ChatStream | None: async def get_stream(self, stream_id: str) -> ChatStream | None:
"""通过stream_id获取聊天流""" """通过stream_id获取聊天流"""
from src.common.data_models.database_data_model import DatabaseMessages
stream = self.streams.get(stream_id) stream = self.streams.get(stream_id)
if not stream: if not stream:
return None return None
if stream_id in self.last_messages: if stream_id in self.last_messages and isinstance(self.last_messages[stream_id], DatabaseMessages):
await stream.set_context(self.last_messages[stream_id]) await stream.set_context(self.last_messages[stream_id])
return stream return stream
@@ -921,9 +783,16 @@ class ChatManager:
# 确保 ChatStream 有自己的 context_manager # 确保 ChatStream 有自己的 context_manager
if not hasattr(stream, "context_manager"): if not hasattr(stream, "context_manager"):
from src.chat.message_manager.context_manager import SingleStreamContextManager from src.chat.message_manager.context_manager import SingleStreamContextManager
from src.common.data_models.message_manager_data_model import StreamContext
from src.plugin_system.base.component_types import ChatMode, ChatType
stream.context_manager = SingleStreamContextManager( stream.context_manager = SingleStreamContextManager(
stream_id=stream.stream_id, context=stream.stream_context stream_id=stream.stream_id,
context=StreamContext(
stream_id=stream.stream_id,
chat_type=ChatType.GROUP if stream.group_info else ChatType.PRIVATE,
chat_mode=ChatMode.NORMAL,
),
) )
except Exception as e: except Exception as e:
logger.error(f"从数据库加载所有聊天流失败 (SQLAlchemy): {e}", exc_info=True) logger.error(f"从数据库加载所有聊天流失败 (SQLAlchemy): {e}", exc_info=True)
@@ -932,46 +801,6 @@ class ChatManager:
chat_manager = None chat_manager = None
def _convert_to_original_stream(self, optimized_stream) -> "ChatStream":
"""将OptimizedChatStream转换为原始ChatStream以保持兼容性"""
try:
# 创建原始ChatStream实例
original_stream = ChatStream(
stream_id=optimized_stream.stream_id,
platform=optimized_stream.platform,
user_info=optimized_stream._get_effective_user_info(),
group_info=optimized_stream._get_effective_group_info(),
)
# 复制状态
original_stream.create_time = optimized_stream.create_time
original_stream.last_active_time = optimized_stream.last_active_time
original_stream.sleep_pressure = optimized_stream.sleep_pressure
original_stream.base_interest_energy = optimized_stream.base_interest_energy
original_stream._focus_energy = optimized_stream._focus_energy
original_stream.no_reply_consecutive = optimized_stream.no_reply_consecutive
original_stream.saved = optimized_stream.saved
# 复制上下文信息(如果存在)
if hasattr(optimized_stream, "_stream_context") and optimized_stream._stream_context:
original_stream.stream_context = optimized_stream._stream_context
if hasattr(optimized_stream, "_context_manager") and optimized_stream._context_manager:
original_stream.context_manager = optimized_stream._context_manager
return original_stream
except Exception as e:
logger.error(f"转换OptimizedChatStream失败: {e}")
# 如果转换失败,创建一个新的原始流
return ChatStream(
stream_id=optimized_stream.stream_id,
platform=optimized_stream.platform,
user_info=optimized_stream._get_effective_user_info(),
group_info=optimized_stream._get_effective_group_info(),
)
def get_chat_manager(): def get_chat_manager():
global chat_manager global chat_manager
if chat_manager is None: if chat_manager is None:

View File

@@ -2,7 +2,7 @@ import base64
import time import time
from abc import ABCMeta, abstractmethod from abc import ABCMeta, abstractmethod
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Optional from typing import Any, Optional, Union
import urllib3 import urllib3
from maim_message import BaseMessageInfo, MessageBase, Seg, UserInfo from maim_message import BaseMessageInfo, MessageBase, Seg, UserInfo
@@ -13,6 +13,7 @@ from src.chat.utils.self_voice_cache import consume_self_voice_text
from src.chat.utils.utils_image import get_image_manager from src.chat.utils.utils_image import get_image_manager
from src.chat.utils.utils_video import get_video_analyzer, is_video_analysis_available from src.chat.utils.utils_video import get_video_analyzer, is_video_analysis_available
from src.chat.utils.utils_voice import get_voice_text from src.chat.utils.utils_voice import get_voice_text
from src.common.data_models.database_data_model import DatabaseMessages
from src.common.logger import get_logger from src.common.logger import get_logger
from src.config.config import global_config from src.config.config import global_config
@@ -43,7 +44,7 @@ class Message(MessageBase, metaclass=ABCMeta):
user_info: UserInfo, user_info: UserInfo,
message_segment: Seg | None = None, message_segment: Seg | None = None,
timestamp: float | None = None, timestamp: float | None = None,
reply: Optional["MessageRecv"] = None, reply: Optional["DatabaseMessages"] = None,
processed_plain_text: str = "", processed_plain_text: str = "",
): ):
# 使用传入的时间戳或当前时间 # 使用传入的时间戳或当前时间
@@ -95,346 +96,12 @@ class Message(MessageBase, metaclass=ABCMeta):
@dataclass @dataclass
class MessageRecv(Message):
"""接收消息类用于处理从MessageCQ序列化的消息"""
def __init__(self, message_dict: dict[str, Any]): # MessageRecv 类已被完全移除,现在统一使用 DatabaseMessages
"""从MessageCQ的字典初始化 # 如需从消息字典创建 DatabaseMessages请使用
# from src.chat.message_receive.message_processor import process_message_from_dict
Args: #
message_dict: MessageCQ序列化后的字典 # 迁移完成日期: 2025-10-31
"""
# Manually initialize attributes from MessageBase and Message
self.message_info = BaseMessageInfo.from_dict(message_dict.get("message_info", {}))
self.message_segment = Seg.from_dict(message_dict.get("message_segment", {}))
self.raw_message = message_dict.get("raw_message")
self.chat_stream = None
self.reply = None
self.processed_plain_text = message_dict.get("processed_plain_text", "")
self.memorized_times = 0
# MessageRecv specific attributes
self.is_emoji = False
self.has_emoji = False
self.is_picid = False
self.has_picid = False
self.is_voice = False
self.is_video = False
self.is_mentioned = None
self.is_notify = False # 是否为notice消息
self.is_public_notice = False # 是否为公共notice
self.notice_type = None # notice类型
self.is_at = False
self.is_command = False
self.priority_mode = "interest"
self.priority_info = None
self.interest_value: float = 0.0
self.key_words = []
self.key_words_lite = []
# 解析additional_config中的notice信息
if self.message_info.additional_config and isinstance(self.message_info.additional_config, dict):
self.is_notify = self.message_info.additional_config.get("is_notice", False)
self.is_public_notice = self.message_info.additional_config.get("is_public_notice", False)
self.notice_type = self.message_info.additional_config.get("notice_type")
def update_chat_stream(self, chat_stream: "ChatStream"):
self.chat_stream = chat_stream
def to_database_message(self) -> "DatabaseMessages":
"""将 MessageRecv 转换为 DatabaseMessages 对象
Returns:
DatabaseMessages: 数据库消息对象
"""
from src.common.data_models.database_data_model import DatabaseMessages
import json
import time
message_info = self.message_info
msg_user_info = getattr(message_info, "user_info", None)
stream_user_info = getattr(self.chat_stream, "user_info", None) if self.chat_stream else None
group_info = getattr(self.chat_stream, "group_info", None) if self.chat_stream else None
message_id = message_info.message_id or ""
message_time = message_info.time if hasattr(message_info, "time") and message_info.time is not None else time.time()
is_mentioned = None
if isinstance(self.is_mentioned, bool):
is_mentioned = self.is_mentioned
elif isinstance(self.is_mentioned, int | float):
is_mentioned = self.is_mentioned != 0
# 提取用户信息
user_id = ""
user_nickname = ""
user_cardname = None
user_platform = ""
if msg_user_info:
user_id = str(getattr(msg_user_info, "user_id", "") or "")
user_nickname = getattr(msg_user_info, "user_nickname", "") or ""
user_cardname = getattr(msg_user_info, "user_cardname", None)
user_platform = getattr(msg_user_info, "platform", "") or ""
elif stream_user_info:
user_id = str(getattr(stream_user_info, "user_id", "") or "")
user_nickname = getattr(stream_user_info, "user_nickname", "") or ""
user_cardname = getattr(stream_user_info, "user_cardname", None)
user_platform = getattr(stream_user_info, "platform", "") or ""
# 提取聊天流信息
chat_user_id = str(getattr(stream_user_info, "user_id", "") or "") if stream_user_info else ""
chat_user_nickname = getattr(stream_user_info, "user_nickname", "") or "" if stream_user_info else ""
chat_user_cardname = getattr(stream_user_info, "user_cardname", None) if stream_user_info else None
chat_user_platform = getattr(stream_user_info, "platform", "") or "" if stream_user_info else ""
group_id = getattr(group_info, "group_id", None) if group_info else None
group_name = getattr(group_info, "group_name", None) if group_info else None
group_platform = getattr(group_info, "platform", None) if group_info else None
# 准备 additional_config
additional_config_str = None
try:
import orjson
additional_config_data = {}
# 首先获取adapter传递的additional_config
if hasattr(message_info, 'additional_config') and message_info.additional_config:
if isinstance(message_info.additional_config, dict):
additional_config_data = message_info.additional_config.copy()
elif isinstance(message_info.additional_config, str):
try:
additional_config_data = orjson.loads(message_info.additional_config)
except Exception as e:
logger.warning(f"无法解析 additional_config JSON: {e}")
additional_config_data = {}
# 添加notice相关标志
if self.is_notify:
additional_config_data["is_notice"] = True
additional_config_data["notice_type"] = self.notice_type or "unknown"
additional_config_data["is_public_notice"] = bool(self.is_public_notice)
# 添加format_info到additional_config中
if hasattr(message_info, 'format_info') and message_info.format_info:
try:
format_info_dict = message_info.format_info.to_dict()
additional_config_data["format_info"] = format_info_dict
logger.debug(f"[message.py] 嵌入 format_info 到 additional_config: {format_info_dict}")
except Exception as e:
logger.warning(f"将 format_info 转换为字典失败: {e}")
# 序列化为JSON字符串
if additional_config_data:
additional_config_str = orjson.dumps(additional_config_data).decode("utf-8")
except Exception as e:
logger.error(f"准备 additional_config 失败: {e}")
# 创建数据库消息对象
db_message = DatabaseMessages(
message_id=message_id,
time=float(message_time),
chat_id=self.chat_stream.stream_id if self.chat_stream else "",
processed_plain_text=self.processed_plain_text,
display_message=self.processed_plain_text,
is_mentioned=is_mentioned,
is_at=bool(self.is_at) if self.is_at is not None else None,
is_emoji=bool(self.is_emoji),
is_picid=bool(self.is_picid),
is_command=bool(self.is_command),
is_notify=bool(self.is_notify),
is_public_notice=bool(self.is_public_notice),
notice_type=self.notice_type,
additional_config=additional_config_str,
user_id=user_id,
user_nickname=user_nickname,
user_cardname=user_cardname,
user_platform=user_platform,
chat_info_stream_id=self.chat_stream.stream_id if self.chat_stream else "",
chat_info_platform=self.chat_stream.platform if self.chat_stream else "",
chat_info_create_time=float(self.chat_stream.create_time) if self.chat_stream else 0.0,
chat_info_last_active_time=float(self.chat_stream.last_active_time) if self.chat_stream else 0.0,
chat_info_user_id=chat_user_id,
chat_info_user_nickname=chat_user_nickname,
chat_info_user_cardname=chat_user_cardname,
chat_info_user_platform=chat_user_platform,
chat_info_group_id=group_id,
chat_info_group_name=group_name,
chat_info_group_platform=group_platform,
)
# 同步兴趣度等衍生属性
db_message.interest_value = getattr(self, "interest_value", 0.0)
setattr(db_message, "should_reply", getattr(self, "should_reply", False))
setattr(db_message, "should_act", getattr(self, "should_act", False))
return db_message
async def process(self) -> None:
"""处理消息内容,生成纯文本和详细文本
这个方法必须在创建实例后显式调用,因为它包含异步操作。
"""
self.processed_plain_text = await self._process_message_segments(self.message_segment)
async def _process_single_segment(self, segment: Seg) -> str:
"""处理单个消息段
Args:
segment: 消息段
Returns:
str: 处理后的文本
"""
try:
if segment.type == "text":
self.is_picid = False
self.is_emoji = False
self.is_video = False
return segment.data # type: ignore
elif segment.type == "at":
self.is_picid = False
self.is_emoji = False
self.is_video = False
# 处理at消息格式为"昵称:QQ号"
if isinstance(segment.data, str) and ":" in segment.data:
nickname, qq_id = segment.data.split(":", 1)
return f"@{nickname}"
return f"@{segment.data}" if isinstance(segment.data, str) else "@未知用户"
elif segment.type == "image":
# 如果是base64图片数据
if isinstance(segment.data, str):
self.has_picid = True
self.is_picid = True
self.is_emoji = False
self.is_video = False
image_manager = get_image_manager()
# print(f"segment.data: {segment.data}")
_, processed_text = await image_manager.process_image(segment.data)
return processed_text
return "[发了一张图片,网卡了加载不出来]"
elif segment.type == "emoji":
self.has_emoji = True
self.is_emoji = True
self.is_picid = False
self.is_voice = False
self.is_video = False
if isinstance(segment.data, str):
return await get_image_manager().get_emoji_description(segment.data)
return "[发了一个表情包,网卡了加载不出来]"
elif segment.type == "voice":
self.is_picid = False
self.is_emoji = False
self.is_voice = True
self.is_video = False
# 检查消息是否由机器人自己发送
if self.message_info and self.message_info.user_info and str(self.message_info.user_info.user_id) == str(global_config.bot.qq_account):
logger.info(f"检测到机器人自身发送的语音消息 (User ID: {self.message_info.user_info.user_id}),尝试从缓存获取文本。")
if isinstance(segment.data, str):
cached_text = consume_self_voice_text(segment.data)
if cached_text:
logger.info(f"成功从缓存中获取语音文本: '{cached_text[:70]}...'")
return f"[语音:{cached_text}]"
else:
logger.warning("机器人自身语音消息缓存未命中,将回退到标准语音识别。")
# 标准语音识别流程 (也作为缓存未命中的后备方案)
if isinstance(segment.data, str):
return await get_voice_text(segment.data)
return "[发了一段语音,网卡了加载不出来]"
elif segment.type == "mention_bot":
self.is_picid = False
self.is_emoji = False
self.is_voice = False
self.is_video = False
self.is_mentioned = float(segment.data) # type: ignore
return ""
elif segment.type == "priority_info":
self.is_picid = False
self.is_emoji = False
self.is_voice = False
if isinstance(segment.data, dict):
# 处理优先级信息
self.priority_mode = "priority"
self.priority_info = segment.data
"""
{
'message_type': 'vip', # vip or normal
'message_priority': 1.0, # 优先级大为优先float
}
"""
return ""
elif segment.type == "file":
if isinstance(segment.data, dict):
file_name = segment.data.get('name', '未知文件')
file_size = segment.data.get('size', '未知大小')
return f"[文件:{file_name} ({file_size}字节)]"
return "[收到一个文件]"
elif segment.type == "video":
self.is_picid = False
self.is_emoji = False
self.is_voice = False
self.is_video = True
logger.info(f"接收到视频消息,数据类型: {type(segment.data)}")
# 检查视频分析功能是否可用
if not is_video_analysis_available():
logger.warning("⚠️ Rust视频处理模块不可用跳过视频分析")
return "[视频]"
if global_config.video_analysis.enable:
logger.info("已启用视频识别,开始识别")
if isinstance(segment.data, dict):
try:
# 从Adapter接收的视频数据
video_base64 = segment.data.get("base64")
filename = segment.data.get("filename", "video.mp4")
logger.info(f"视频文件名: {filename}")
logger.info(f"Base64数据长度: {len(video_base64) if video_base64 else 0}")
if video_base64:
# 解码base64视频数据
video_bytes = base64.b64decode(video_base64)
logger.info(f"解码后视频大小: {len(video_bytes)} 字节")
# 使用video analyzer分析视频
video_analyzer = get_video_analyzer()
result = await video_analyzer.analyze_video_from_bytes(
video_bytes, filename, prompt=global_config.video_analysis.batch_analysis_prompt
)
logger.info(f"视频分析结果: {result}")
# 返回视频分析结果
summary = result.get("summary", "")
if summary:
return f"[视频内容] {summary}"
else:
return "[已收到视频,但分析失败]"
else:
logger.warning("视频消息中没有base64数据")
return "[收到视频消息,但数据异常]"
except Exception as e:
logger.error(f"视频处理失败: {e!s}")
import traceback
logger.error(f"错误详情: {traceback.format_exc()}")
return "[收到视频,但处理时出现错误]"
else:
logger.warning(f"视频消息数据不是字典格式: {type(segment.data)}")
return "[发了一个视频,但格式不支持]"
else:
return ""
else:
logger.warning(f"未知的消息段类型: {segment.type}")
return f"[{segment.type} 消息]"
except Exception as e:
logger.error(f"处理消息段失败: {e!s}, 类型: {segment.type}, 数据: {segment.data}")
return f"[处理失败的{segment.type}消息]"
@dataclass @dataclass
@@ -447,7 +114,7 @@ class MessageProcessBase(Message):
chat_stream: "ChatStream", chat_stream: "ChatStream",
bot_user_info: UserInfo, bot_user_info: UserInfo,
message_segment: Seg | None = None, message_segment: Seg | None = None,
reply: Optional["MessageRecv"] = None, reply: Optional["DatabaseMessages"] = None,
thinking_start_time: float = 0, thinking_start_time: float = 0,
timestamp: float | None = None, timestamp: float | None = None,
): ):
@@ -548,7 +215,7 @@ class MessageSending(MessageProcessBase):
sender_info: UserInfo | None, # 用来记录发送者信息 sender_info: UserInfo | None, # 用来记录发送者信息
message_segment: Seg, message_segment: Seg,
display_message: str = "", display_message: str = "",
reply: Optional["MessageRecv"] = None, reply: Optional["DatabaseMessages"] = None,
is_head: bool = False, is_head: bool = False,
is_emoji: bool = False, is_emoji: bool = False,
thinking_start_time: float = 0, thinking_start_time: float = 0,
@@ -567,7 +234,11 @@ class MessageSending(MessageProcessBase):
# 发送状态特有属性 # 发送状态特有属性
self.sender_info = sender_info self.sender_info = sender_info
self.reply_to_message_id = reply.message_info.message_id if reply else None # 从 DatabaseMessages 获取 message_id
if reply:
self.reply_to_message_id = reply.message_id
else:
self.reply_to_message_id = None
self.is_head = is_head self.is_head = is_head
self.is_emoji = is_emoji self.is_emoji = is_emoji
self.apply_set_reply_logic = apply_set_reply_logic self.apply_set_reply_logic = apply_set_reply_logic
@@ -582,14 +253,18 @@ class MessageSending(MessageProcessBase):
def build_reply(self): def build_reply(self):
"""设置回复消息""" """设置回复消息"""
if self.reply: if self.reply:
self.reply_to_message_id = self.reply.message_info.message_id # 从 DatabaseMessages 获取 message_id
self.message_segment = Seg( message_id = self.reply.message_id
type="seglist",
data=[ if message_id:
Seg(type="reply", data=self.reply.message_info.message_id), # type: ignore self.reply_to_message_id = message_id
self.message_segment, self.message_segment = Seg(
], type="seglist",
) data=[
Seg(type="reply", data=message_id), # type: ignore
self.message_segment,
],
)
async def process(self) -> None: async def process(self) -> None:
"""处理消息内容,生成纯文本和详细文本""" """处理消息内容,生成纯文本和详细文本"""
@@ -607,48 +282,5 @@ class MessageSending(MessageProcessBase):
return self.message_info.group_info is None or self.message_info.group_info.group_id is None return self.message_info.group_info is None or self.message_info.group_info.group_id is None
def message_recv_from_dict(message_dict: dict) -> MessageRecv: # message_recv_from_dictmessage_from_db_dict 函数已被移除
return MessageRecv(message_dict) # 请使用: from src.chat.message_receive.message_processor import process_message_from_dict
def message_from_db_dict(db_dict: dict) -> MessageRecv:
"""从数据库字典创建MessageRecv实例"""
# 转换扁平的数据库字典为嵌套结构
message_info_dict = {
"platform": db_dict.get("chat_info_platform"),
"message_id": db_dict.get("message_id"),
"time": db_dict.get("time"),
"group_info": {
"platform": db_dict.get("chat_info_group_platform"),
"group_id": db_dict.get("chat_info_group_id"),
"group_name": db_dict.get("chat_info_group_name"),
},
"user_info": {
"platform": db_dict.get("user_platform"),
"user_id": db_dict.get("user_id"),
"user_nickname": db_dict.get("user_nickname"),
"user_cardname": db_dict.get("user_cardname"),
},
}
processed_text = db_dict.get("processed_plain_text", "")
# 构建 MessageRecv 需要的字典
recv_dict = {
"message_info": message_info_dict,
"message_segment": {"type": "text", "data": processed_text}, # 从纯文本重建消息段
"raw_message": None, # 数据库中未存储原始消息
"processed_plain_text": processed_text,
}
# 创建 MessageRecv 实例
msg = MessageRecv(recv_dict)
# 从数据库字典中填充其他可选字段
msg.interest_value = db_dict.get("interest_value", 0.0)
msg.is_mentioned = db_dict.get("is_mentioned")
msg.priority_mode = db_dict.get("priority_mode", "interest")
msg.priority_info = db_dict.get("priority_info")
msg.is_emoji = db_dict.get("is_emoji", False)
msg.is_picid = db_dict.get("is_picid", False)
return msg

View File

@@ -0,0 +1,493 @@
"""消息处理工具模块
将原 MessageRecv 的消息处理逻辑提取为独立函数,
直接从适配器消息字典生成 DatabaseMessages
"""
import base64
import time
from typing import Any
import orjson
from maim_message import BaseMessageInfo, Seg
from src.chat.utils.self_voice_cache import consume_self_voice_text
from src.chat.utils.utils_image import get_image_manager
from src.chat.utils.utils_video import get_video_analyzer, is_video_analysis_available
from src.chat.utils.utils_voice import get_voice_text
from src.common.data_models.database_data_model import DatabaseMessages
from src.common.logger import get_logger
from src.config.config import global_config
logger = get_logger("message_processor")
async def process_message_from_dict(message_dict: dict[str, Any], stream_id: str, platform: str) -> DatabaseMessages:
"""从适配器消息字典处理并生成 DatabaseMessages
这个函数整合了原 MessageRecv 的所有处理逻辑:
1. 解析 message_segment 并异步处理内容(图片、语音、视频等)
2. 提取所有消息元数据
3. 直接构造 DatabaseMessages 对象
Args:
message_dict: MessageCQ序列化后的字典
stream_id: 聊天流ID
platform: 平台标识
Returns:
DatabaseMessages: 处理完成的数据库消息对象
"""
# 解析基础信息
message_info = BaseMessageInfo.from_dict(message_dict.get("message_info", {}))
message_segment = Seg.from_dict(message_dict.get("message_segment", {}))
# 初始化处理状态
processing_state = {
"is_emoji": False,
"has_emoji": False,
"is_picid": False,
"has_picid": False,
"is_voice": False,
"is_video": False,
"is_mentioned": None,
"is_at": False,
"priority_mode": "interest",
"priority_info": None,
}
# 异步处理消息段,生成纯文本
processed_plain_text = await _process_message_segments(message_segment, processing_state, message_info)
# 解析 notice 信息
is_notify = False
is_public_notice = False
notice_type = None
if message_info.additional_config and isinstance(message_info.additional_config, dict):
is_notify = message_info.additional_config.get("is_notice", False)
is_public_notice = message_info.additional_config.get("is_public_notice", False)
notice_type = message_info.additional_config.get("notice_type")
# 提取用户信息
user_info = message_info.user_info
user_id = str(user_info.user_id) if user_info and user_info.user_id else ""
user_nickname = (user_info.user_nickname or "") if user_info else ""
user_cardname = user_info.user_cardname if user_info else None
user_platform = (user_info.platform or "") if user_info else ""
# 提取群组信息
group_info = message_info.group_info
group_id = group_info.group_id if group_info else None
group_name = group_info.group_name if group_info else None
group_platform = group_info.platform if group_info else None
# 生成 chat_id
if group_info and group_id:
chat_id = f"{platform}_{group_id}"
elif user_info and user_info.user_id:
chat_id = f"{platform}_{user_info.user_id}_private"
else:
chat_id = stream_id
# 准备 additional_config
additional_config_str = _prepare_additional_config(message_info, is_notify, is_public_notice, notice_type)
# 提取 reply_to
reply_to = _extract_reply_from_segment(message_segment)
# 构造 DatabaseMessages
message_time = message_info.time if hasattr(message_info, "time") and message_info.time is not None else time.time()
message_id = message_info.message_id or ""
# 处理 is_mentioned
is_mentioned = None
mentioned_value = processing_state.get("is_mentioned")
if isinstance(mentioned_value, bool):
is_mentioned = mentioned_value
elif isinstance(mentioned_value, (int, float)):
is_mentioned = mentioned_value != 0
db_message = DatabaseMessages(
message_id=message_id,
time=float(message_time),
chat_id=chat_id,
reply_to=reply_to,
processed_plain_text=processed_plain_text,
display_message=processed_plain_text,
is_mentioned=is_mentioned,
is_at=bool(processing_state.get("is_at", False)),
is_emoji=bool(processing_state.get("is_emoji", False)),
is_picid=bool(processing_state.get("is_picid", False)),
is_command=False, # 将在后续处理中设置
is_notify=bool(is_notify),
is_public_notice=bool(is_public_notice),
notice_type=notice_type,
additional_config=additional_config_str,
user_id=user_id,
user_nickname=user_nickname,
user_cardname=user_cardname,
user_platform=user_platform,
chat_info_stream_id=stream_id,
chat_info_platform=platform,
chat_info_create_time=0.0, # 将由 ChatStream 填充
chat_info_last_active_time=0.0, # 将由 ChatStream 填充
chat_info_user_id=user_id,
chat_info_user_nickname=user_nickname,
chat_info_user_cardname=user_cardname,
chat_info_user_platform=user_platform,
chat_info_group_id=group_id,
chat_info_group_name=group_name,
chat_info_group_platform=group_platform,
)
# 设置优先级信息
if processing_state.get("priority_mode"):
setattr(db_message, "priority_mode", processing_state["priority_mode"])
if processing_state.get("priority_info"):
setattr(db_message, "priority_info", processing_state["priority_info"])
# 设置其他运行时属性
setattr(db_message, "is_voice", bool(processing_state.get("is_voice", False)))
setattr(db_message, "is_video", bool(processing_state.get("is_video", False)))
setattr(db_message, "has_emoji", bool(processing_state.get("has_emoji", False)))
setattr(db_message, "has_picid", bool(processing_state.get("has_picid", False)))
return db_message
async def _process_message_segments(segment: Seg, state: dict, message_info: BaseMessageInfo) -> str:
"""递归处理消息段,转换为文字描述
Args:
segment: 要处理的消息段
state: 处理状态字典(用于记录消息类型标记)
message_info: 消息基础信息(用于某些处理逻辑)
Returns:
str: 处理后的文本
"""
if segment.type == "seglist":
# 处理消息段列表
segments_text = []
for seg in segment.data:
processed = await _process_message_segments(seg, state, message_info)
if processed:
segments_text.append(processed)
return " ".join(segments_text)
else:
# 处理单个消息段
return await _process_single_segment(segment, state, message_info)
async def _process_single_segment(segment: Seg, state: dict, message_info: BaseMessageInfo) -> str:
"""处理单个消息段
Args:
segment: 消息段
state: 处理状态字典
message_info: 消息基础信息
Returns:
str: 处理后的文本
"""
try:
if segment.type == "text":
state["is_picid"] = False
state["is_emoji"] = False
state["is_video"] = False
return segment.data
elif segment.type == "at":
state["is_picid"] = False
state["is_emoji"] = False
state["is_video"] = False
state["is_at"] = True
# 处理at消息格式为"昵称:QQ号"
if isinstance(segment.data, str) and ":" in segment.data:
nickname, qq_id = segment.data.split(":", 1)
return f"@{nickname}"
return f"@{segment.data}" if isinstance(segment.data, str) else "@未知用户"
elif segment.type == "image":
# 如果是base64图片数据
if isinstance(segment.data, str):
state["has_picid"] = True
state["is_picid"] = True
state["is_emoji"] = False
state["is_video"] = False
image_manager = get_image_manager()
_, processed_text = await image_manager.process_image(segment.data)
return processed_text
return "[发了一张图片,网卡了加载不出来]"
elif segment.type == "emoji":
state["has_emoji"] = True
state["is_emoji"] = True
state["is_picid"] = False
state["is_voice"] = False
state["is_video"] = False
if isinstance(segment.data, str):
return await get_image_manager().get_emoji_description(segment.data)
return "[发了一个表情包,网卡了加载不出来]"
elif segment.type == "voice":
state["is_picid"] = False
state["is_emoji"] = False
state["is_voice"] = True
state["is_video"] = False
# 检查消息是否由机器人自己发送
if message_info and message_info.user_info and str(message_info.user_info.user_id) == str(global_config.bot.qq_account):
logger.info(f"检测到机器人自身发送的语音消息 (User ID: {message_info.user_info.user_id}),尝试从缓存获取文本。")
if isinstance(segment.data, str):
cached_text = consume_self_voice_text(segment.data)
if cached_text:
logger.info(f"成功从缓存中获取语音文本: '{cached_text[:70]}...'")
return f"[语音:{cached_text}]"
else:
logger.warning("机器人自身语音消息缓存未命中,将回退到标准语音识别。")
# 标准语音识别流程
if isinstance(segment.data, str):
return await get_voice_text(segment.data)
return "[发了一段语音,网卡了加载不出来]"
elif segment.type == "mention_bot":
state["is_picid"] = False
state["is_emoji"] = False
state["is_voice"] = False
state["is_video"] = False
state["is_mentioned"] = float(segment.data)
return ""
elif segment.type == "priority_info":
state["is_picid"] = False
state["is_emoji"] = False
state["is_voice"] = False
if isinstance(segment.data, dict):
# 处理优先级信息
state["priority_mode"] = "priority"
state["priority_info"] = segment.data
return ""
elif segment.type == "file":
if isinstance(segment.data, dict):
file_name = segment.data.get('name', '未知文件')
file_size = segment.data.get('size', '未知大小')
return f"[文件:{file_name} ({file_size}字节)]"
return "[收到一个文件]"
elif segment.type == "video":
state["is_picid"] = False
state["is_emoji"] = False
state["is_voice"] = False
state["is_video"] = True
logger.info(f"接收到视频消息,数据类型: {type(segment.data)}")
# 检查视频分析功能是否可用
if not is_video_analysis_available():
logger.warning("⚠️ Rust视频处理模块不可用跳过视频分析")
return "[视频]"
if global_config.video_analysis.enable:
logger.info("已启用视频识别,开始识别")
if isinstance(segment.data, dict):
try:
# 从Adapter接收的视频数据
video_base64 = segment.data.get("base64")
filename = segment.data.get("filename", "video.mp4")
logger.info(f"视频文件名: {filename}")
logger.info(f"Base64数据长度: {len(video_base64) if video_base64 else 0}")
if video_base64:
# 解码base64视频数据
video_bytes = base64.b64decode(video_base64)
logger.info(f"解码后视频大小: {len(video_bytes)} 字节")
# 使用video analyzer分析视频
video_analyzer = get_video_analyzer()
result = await video_analyzer.analyze_video_from_bytes(
video_bytes, filename, prompt=global_config.video_analysis.batch_analysis_prompt
)
logger.info(f"视频分析结果: {result}")
# 返回视频分析结果
summary = result.get("summary", "")
if summary:
return f"[视频内容] {summary}"
else:
return "[已收到视频,但分析失败]"
else:
logger.warning("视频消息中没有base64数据")
return "[收到视频消息,但数据异常]"
except Exception as e:
logger.error(f"视频处理失败: {e!s}")
import traceback
logger.error(f"错误详情: {traceback.format_exc()}")
return "[收到视频,但处理时出现错误]"
else:
logger.warning(f"视频消息数据不是字典格式: {type(segment.data)}")
return "[发了一个视频,但格式不支持]"
else:
return ""
else:
logger.warning(f"未知的消息段类型: {segment.type}")
return f"[{segment.type} 消息]"
except Exception as e:
logger.error(f"处理消息段失败: {e!s}, 类型: {segment.type}, 数据: {segment.data}")
return f"[处理失败的{segment.type}消息]"
def _prepare_additional_config(message_info: BaseMessageInfo, is_notify: bool, is_public_notice: bool, notice_type: str | None) -> str | None:
"""准备 additional_config包含 format_info 和 notice 信息
Args:
message_info: 消息基础信息
is_notify: 是否为notice消息
is_public_notice: 是否为公共notice
notice_type: notice类型
Returns:
str | None: JSON 字符串格式的 additional_config如果为空则返回 None
"""
try:
additional_config_data = {}
# 首先获取adapter传递的additional_config
if hasattr(message_info, 'additional_config') and message_info.additional_config:
if isinstance(message_info.additional_config, dict):
additional_config_data = message_info.additional_config.copy()
elif isinstance(message_info.additional_config, str):
try:
additional_config_data = orjson.loads(message_info.additional_config)
except Exception as e:
logger.warning(f"无法解析 additional_config JSON: {e}")
additional_config_data = {}
# 添加notice相关标志
if is_notify:
additional_config_data["is_notice"] = True
additional_config_data["notice_type"] = notice_type or "unknown"
additional_config_data["is_public_notice"] = bool(is_public_notice)
# 添加format_info到additional_config中
if hasattr(message_info, 'format_info') and message_info.format_info:
try:
format_info_dict = message_info.format_info.to_dict()
additional_config_data["format_info"] = format_info_dict
logger.debug(f"[message_processor] 嵌入 format_info 到 additional_config: {format_info_dict}")
except Exception as e:
logger.warning(f"将 format_info 转换为字典失败: {e}")
# 序列化为JSON字符串
if additional_config_data:
return orjson.dumps(additional_config_data).decode("utf-8")
except Exception as e:
logger.error(f"准备 additional_config 失败: {e}")
return None
def _extract_reply_from_segment(segment: Seg) -> str | None:
"""从消息段中提取reply_to信息
Args:
segment: 消息段
Returns:
str | None: 回复的消息ID如果没有则返回None
"""
try:
if hasattr(segment, "type") and segment.type == "seglist":
# 递归搜索seglist中的reply段
if hasattr(segment, "data") and segment.data:
for seg in segment.data:
reply_id = _extract_reply_from_segment(seg)
if reply_id:
return reply_id
elif hasattr(segment, "type") and segment.type == "reply":
# 找到reply段返回message_id
return str(segment.data) if segment.data else None
except Exception as e:
logger.warning(f"提取reply_to信息失败: {e}")
return None
# =============================================================================
# DatabaseMessages 扩展工具函数
# =============================================================================
def get_message_info_from_db_message(db_message: DatabaseMessages) -> BaseMessageInfo:
"""从 DatabaseMessages 重建 BaseMessageInfo用于需要 message_info 的遗留代码)
Args:
db_message: DatabaseMessages 对象
Returns:
BaseMessageInfo: 重建的消息信息对象
"""
from maim_message import UserInfo, GroupInfo
# 从 DatabaseMessages 的 user_info 转换为 maim_message.UserInfo
user_info = UserInfo(
platform=db_message.user_info.platform,
user_id=db_message.user_info.user_id,
user_nickname=db_message.user_info.user_nickname,
user_cardname=db_message.user_info.user_cardname or ""
)
# 从 DatabaseMessages 的 group_info 转换为 maim_message.GroupInfo如果存在
group_info = None
if db_message.group_info:
group_info = GroupInfo(
platform=db_message.group_info.group_platform or "",
group_id=db_message.group_info.group_id,
group_name=db_message.group_info.group_name
)
# 解析 additional_config从 JSON 字符串到字典)
additional_config = None
if db_message.additional_config:
try:
additional_config = orjson.loads(db_message.additional_config)
except Exception:
# 如果解析失败,保持为字符串
pass
# 创建 BaseMessageInfo
message_info = BaseMessageInfo(
platform=db_message.chat_info.platform,
message_id=db_message.message_id,
time=db_message.time,
user_info=user_info,
group_info=group_info,
additional_config=additional_config # type: ignore
)
return message_info
def set_db_message_runtime_attr(db_message: DatabaseMessages, attr_name: str, value: Any) -> None:
"""安全地为 DatabaseMessages 设置运行时属性
Args:
db_message: DatabaseMessages 对象
attr_name: 属性名
value: 属性值
"""
setattr(db_message, attr_name, value)
def get_db_message_runtime_attr(db_message: DatabaseMessages, attr_name: str, default: Any = None) -> Any:
"""安全地获取 DatabaseMessages 的运行时属性
Args:
db_message: DatabaseMessages 对象
attr_name: 属性名
default: 默认值
Returns:
属性值或默认值
"""
return getattr(db_message, attr_name, default)

View File

@@ -0,0 +1,434 @@
# MessageRecv 类备份 - 已从 message.py 中移除
# 备份日期: 2025-10-31
# 此类已被 DatabaseMessages 完全取代
# MessageRecv 类已被移除
# 现在所有消息处理都使用 DatabaseMessages
# 如果需要从消息字典创建 DatabaseMessages请使用 message_processor.process_message_from_dict()
#
# 历史参考: MessageRecv 曾经是接收消息的包装类,现已被 DatabaseMessages 完全取代
# 迁移完成日期: 2025-10-31
"""
# 以下是已删除的 MessageRecv 类(保留作为参考)
class MessageRecv:
接收消息类 - DatabaseMessages 的轻量级包装器
这个类现在主要作为适配器层,处理外部消息格式并内部使用 DatabaseMessages。
保留此类是为了向后兼容性和处理 message_segment 的异步逻辑。
"""
def __init__(self, message_dict: dict[str, Any]):
"""从MessageCQ的字典初始化
Args:
message_dict: MessageCQ序列化后的字典
"""
# 保留原始消息信息用于某些场景
self.message_info = BaseMessageInfo.from_dict(message_dict.get("message_info", {}))
self.message_segment = Seg.from_dict(message_dict.get("message_segment", {}))
self.raw_message = message_dict.get("raw_message")
# 处理状态(在process()之前临时使用)
self._processing_state = {
"is_emoji": False,
"has_emoji": False,
"is_picid": False,
"has_picid": False,
"is_voice": False,
"is_video": False,
"is_mentioned": None,
"is_at": False,
"priority_mode": "interest",
"priority_info": None,
}
self.chat_stream = None
self.reply = None
self.processed_plain_text = message_dict.get("processed_plain_text", "")
# 解析additional_config中的notice信息
self.is_notify = False
self.is_public_notice = False
self.notice_type = None
if self.message_info.additional_config and isinstance(self.message_info.additional_config, dict):
self.is_notify = self.message_info.additional_config.get("is_notice", False)
self.is_public_notice = self.message_info.additional_config.get("is_public_notice", False)
self.notice_type = self.message_info.additional_config.get("notice_type")
# 兼容性属性 - 代理到 _processing_state
@property
def is_emoji(self) -> bool:
return self._processing_state["is_emoji"]
@is_emoji.setter
def is_emoji(self, value: bool):
self._processing_state["is_emoji"] = value
@property
def has_emoji(self) -> bool:
return self._processing_state["has_emoji"]
@has_emoji.setter
def has_emoji(self, value: bool):
self._processing_state["has_emoji"] = value
@property
def is_picid(self) -> bool:
return self._processing_state["is_picid"]
@is_picid.setter
def is_picid(self, value: bool):
self._processing_state["is_picid"] = value
@property
def has_picid(self) -> bool:
return self._processing_state["has_picid"]
@has_picid.setter
def has_picid(self, value: bool):
self._processing_state["has_picid"] = value
@property
def is_voice(self) -> bool:
return self._processing_state["is_voice"]
@is_voice.setter
def is_voice(self, value: bool):
self._processing_state["is_voice"] = value
@property
def is_video(self) -> bool:
return self._processing_state["is_video"]
@is_video.setter
def is_video(self, value: bool):
self._processing_state["is_video"] = value
@property
def is_mentioned(self):
return self._processing_state["is_mentioned"]
@is_mentioned.setter
def is_mentioned(self, value):
self._processing_state["is_mentioned"] = value
@property
def is_at(self) -> bool:
return self._processing_state["is_at"]
@is_at.setter
def is_at(self, value: bool):
self._processing_state["is_at"] = value
@property
def priority_mode(self) -> str:
return self._processing_state["priority_mode"]
@priority_mode.setter
def priority_mode(self, value: str):
self._processing_state["priority_mode"] = value
@property
def priority_info(self):
return self._processing_state["priority_info"]
@priority_info.setter
def priority_info(self, value):
self._processing_state["priority_info"] = value
# 其他常用属性
interest_value: float = 0.0
is_command: bool = False
memorized_times: int = 0
def __post_init__(self):
"""dataclass 初始化后处理"""
self.key_words = []
self.key_words_lite = []
def update_chat_stream(self, chat_stream: "ChatStream"):
self.chat_stream = chat_stream
def to_database_message(self) -> "DatabaseMessages":
"""将 MessageRecv 转换为 DatabaseMessages 对象
Returns:
DatabaseMessages: 数据库消息对象
"""
import time
message_info = self.message_info
msg_user_info = getattr(message_info, "user_info", None)
stream_user_info = getattr(self.chat_stream, "user_info", None) if self.chat_stream else None
group_info = getattr(self.chat_stream, "group_info", None) if self.chat_stream else None
message_id = message_info.message_id or ""
message_time = message_info.time if hasattr(message_info, "time") and message_info.time is not None else time.time()
is_mentioned = None
if isinstance(self.is_mentioned, bool):
is_mentioned = self.is_mentioned
elif isinstance(self.is_mentioned, int | float):
is_mentioned = self.is_mentioned != 0
# 提取用户信息
user_id = ""
user_nickname = ""
user_cardname = None
user_platform = ""
if msg_user_info:
user_id = str(getattr(msg_user_info, "user_id", "") or "")
user_nickname = getattr(msg_user_info, "user_nickname", "") or ""
user_cardname = getattr(msg_user_info, "user_cardname", None)
user_platform = getattr(msg_user_info, "platform", "") or ""
elif stream_user_info:
user_id = str(getattr(stream_user_info, "user_id", "") or "")
user_nickname = getattr(stream_user_info, "user_nickname", "") or ""
user_cardname = getattr(stream_user_info, "user_cardname", None)
user_platform = getattr(stream_user_info, "platform", "") or ""
# 提取聊天流信息
chat_user_id = str(getattr(stream_user_info, "user_id", "") or "") if stream_user_info else ""
chat_user_nickname = getattr(stream_user_info, "user_nickname", "") or "" if stream_user_info else ""
chat_user_cardname = getattr(stream_user_info, "user_cardname", None) if stream_user_info else None
chat_user_platform = getattr(stream_user_info, "platform", "") or "" if stream_user_info else ""
group_id = getattr(group_info, "group_id", None) if group_info else None
group_name = getattr(group_info, "group_name", None) if group_info else None
group_platform = getattr(group_info, "platform", None) if group_info else None
# 准备 additional_config
additional_config_str = None
try:
import orjson
additional_config_data = {}
# 首先获取adapter传递的additional_config
if hasattr(message_info, 'additional_config') and message_info.additional_config:
if isinstance(message_info.additional_config, dict):
additional_config_data = message_info.additional_config.copy()
elif isinstance(message_info.additional_config, str):
try:
additional_config_data = orjson.loads(message_info.additional_config)
except Exception as e:
logger.warning(f"无法解析 additional_config JSON: {e}")
additional_config_data = {}
# 添加notice相关标志
if self.is_notify:
additional_config_data["is_notice"] = True
additional_config_data["notice_type"] = self.notice_type or "unknown"
additional_config_data["is_public_notice"] = bool(self.is_public_notice)
# 添加format_info到additional_config中
if hasattr(message_info, 'format_info') and message_info.format_info:
try:
format_info_dict = message_info.format_info.to_dict()
additional_config_data["format_info"] = format_info_dict
logger.debug(f"[message.py] 嵌入 format_info 到 additional_config: {format_info_dict}")
except Exception as e:
logger.warning(f"将 format_info 转换为字典失败: {e}")
# 序列化为JSON字符串
if additional_config_data:
additional_config_str = orjson.dumps(additional_config_data).decode("utf-8")
except Exception as e:
logger.error(f"准备 additional_config 失败: {e}")
# 创建数据库消息对象
db_message = DatabaseMessages(
message_id=message_id,
time=float(message_time),
chat_id=self.chat_stream.stream_id if self.chat_stream else "",
processed_plain_text=self.processed_plain_text,
display_message=self.processed_plain_text,
is_mentioned=is_mentioned,
is_at=bool(self.is_at) if self.is_at is not None else None,
is_emoji=bool(self.is_emoji),
is_picid=bool(self.is_picid),
is_command=bool(self.is_command),
is_notify=bool(self.is_notify),
is_public_notice=bool(self.is_public_notice),
notice_type=self.notice_type,
additional_config=additional_config_str,
user_id=user_id,
user_nickname=user_nickname,
user_cardname=user_cardname,
user_platform=user_platform,
chat_info_stream_id=self.chat_stream.stream_id if self.chat_stream else "",
chat_info_platform=self.chat_stream.platform if self.chat_stream else "",
chat_info_create_time=float(self.chat_stream.create_time) if self.chat_stream else 0.0,
chat_info_last_active_time=float(self.chat_stream.last_active_time) if self.chat_stream else 0.0,
chat_info_user_id=chat_user_id,
chat_info_user_nickname=chat_user_nickname,
chat_info_user_cardname=chat_user_cardname,
chat_info_user_platform=chat_user_platform,
chat_info_group_id=group_id,
chat_info_group_name=group_name,
chat_info_group_platform=group_platform,
)
# 同步兴趣度等衍生属性
db_message.interest_value = getattr(self, "interest_value", 0.0)
setattr(db_message, "should_reply", getattr(self, "should_reply", False))
setattr(db_message, "should_act", getattr(self, "should_act", False))
return db_message
async def process(self) -> None:
"""处理消息内容,生成纯文本和详细文本
这个方法必须在创建实例后显式调用,因为它包含异步操作。
"""
self.processed_plain_text = await self._process_message_segments(self.message_segment)
async def _process_single_segment(self, segment: Seg) -> str:
"""处理单个消息段
Args:
segment: 消息段
Returns:
str: 处理后的文本
"""
try:
if segment.type == "text":
self.is_picid = False
self.is_emoji = False
self.is_video = False
return segment.data # type: ignore
elif segment.type == "at":
self.is_picid = False
self.is_emoji = False
self.is_video = False
# 处理at消息格式为"昵称:QQ号"
if isinstance(segment.data, str) and ":" in segment.data:
nickname, qq_id = segment.data.split(":", 1)
return f"@{nickname}"
return f"@{segment.data}" if isinstance(segment.data, str) else "@未知用户"
elif segment.type == "image":
# 如果是base64图片数据
if isinstance(segment.data, str):
self.has_picid = True
self.is_picid = True
self.is_emoji = False
self.is_video = False
image_manager = get_image_manager()
# print(f"segment.data: {segment.data}")
_, processed_text = await image_manager.process_image(segment.data)
return processed_text
return "[发了一张图片,网卡了加载不出来]"
elif segment.type == "emoji":
self.has_emoji = True
self.is_emoji = True
self.is_picid = False
self.is_voice = False
self.is_video = False
if isinstance(segment.data, str):
return await get_image_manager().get_emoji_description(segment.data)
return "[发了一个表情包,网卡了加载不出来]"
elif segment.type == "voice":
self.is_picid = False
self.is_emoji = False
self.is_voice = True
self.is_video = False
# 检查消息是否由机器人自己发送
if self.message_info and self.message_info.user_info and str(self.message_info.user_info.user_id) == str(global_config.bot.qq_account):
logger.info(f"检测到机器人自身发送的语音消息 (User ID: {self.message_info.user_info.user_id}),尝试从缓存获取文本。")
if isinstance(segment.data, str):
cached_text = consume_self_voice_text(segment.data)
if cached_text:
logger.info(f"成功从缓存中获取语音文本: '{cached_text[:70]}...'")
return f"[语音:{cached_text}]"
else:
logger.warning("机器人自身语音消息缓存未命中,将回退到标准语音识别。")
# 标准语音识别流程 (也作为缓存未命中的后备方案)
if isinstance(segment.data, str):
return await get_voice_text(segment.data)
return "[发了一段语音,网卡了加载不出来]"
elif segment.type == "mention_bot":
self.is_picid = False
self.is_emoji = False
self.is_voice = False
self.is_video = False
self.is_mentioned = float(segment.data) # type: ignore
return ""
elif segment.type == "priority_info":
self.is_picid = False
self.is_emoji = False
self.is_voice = False
if isinstance(segment.data, dict):
# 处理优先级信息
self.priority_mode = "priority"
self.priority_info = segment.data
"""
{
'message_type': 'vip', # vip or normal
'message_priority': 1.0, # 优先级大为优先float
}
"""
return ""
elif segment.type == "file":
if isinstance(segment.data, dict):
file_name = segment.data.get('name', '未知文件')
file_size = segment.data.get('size', '未知大小')
return f"[文件:{file_name} ({file_size}字节)]"
return "[收到一个文件]"
elif segment.type == "video":
self.is_picid = False
self.is_emoji = False
self.is_voice = False
self.is_video = True
logger.info(f"接收到视频消息,数据类型: {type(segment.data)}")
# 检查视频分析功能是否可用
if not is_video_analysis_available():
logger.warning("⚠️ Rust视频处理模块不可用跳过视频分析")
return "[视频]"
if global_config.video_analysis.enable:
logger.info("已启用视频识别,开始识别")
if isinstance(segment.data, dict):
try:
# 从Adapter接收的视频数据
video_base64 = segment.data.get("base64")
filename = segment.data.get("filename", "video.mp4")
logger.info(f"视频文件名: {filename}")
logger.info(f"Base64数据长度: {len(video_base64) if video_base64 else 0}")
if video_base64:
# 解码base64视频数据
video_bytes = base64.b64decode(video_base64)
logger.info(f"解码后视频大小: {len(video_bytes)} 字节")
# 使用video analyzer分析视频
video_analyzer = get_video_analyzer()
result = await video_analyzer.analyze_video_from_bytes(
video_bytes, filename, prompt=global_config.video_analysis.batch_analysis_prompt
)
logger.info(f"视频分析结果: {result}")
# 返回视频分析结果
summary = result.get("summary", "")
if summary:
return f"[视频内容] {summary}"
else:
return "[已收到视频,但分析失败]"
else:
logger.warning("视频消息中没有base64数据")
return "[收到视频消息,但数据异常]"
except Exception as e:
logger.error(f"视频处理失败: {e!s}")
import traceback
logger.error(f"错误详情: {traceback.format_exc()}")
return "[收到视频,但处理时出现错误]"
else:
logger.warning(f"视频消息数据不是字典格式: {type(segment.data)}")
return "[发了一个视频,但格式不支持]"
else:

View File

@@ -9,8 +9,10 @@ from src.common.database.sqlalchemy_database_api import get_db_session
from src.common.database.sqlalchemy_models import Images, Messages from src.common.database.sqlalchemy_models import Images, Messages
from src.common.logger import get_logger from src.common.logger import get_logger
from src.common.data_models.database_data_model import DatabaseMessages
from .chat_stream import ChatStream from .chat_stream import ChatStream
from .message import MessageRecv, MessageSending from .message import MessageSending
logger = get_logger("message_storage") logger = get_logger("message_storage")
@@ -34,97 +36,166 @@ class MessageStorage:
return [] return []
@staticmethod @staticmethod
async def store_message(message: MessageSending | MessageRecv, chat_stream: ChatStream) -> None: async def store_message(message: DatabaseMessages | MessageSending, chat_stream: ChatStream) -> None:
"""存储消息到数据库""" """存储消息到数据库"""
try: try:
# 过滤敏感信息的正则模式 # 过滤敏感信息的正则模式
pattern = r"<MainRule>.*?</MainRule>|<schedule>.*?</schedule>|<UserMessage>.*?</UserMessage>" pattern = r"<MainRule>.*?</MainRule>|<schedule>.*?</schedule>|<UserMessage>.*?</UserMessage>"
processed_plain_text = message.processed_plain_text # 如果是 DatabaseMessages直接使用它的字段
if isinstance(message, DatabaseMessages):
if processed_plain_text: processed_plain_text = message.processed_plain_text
processed_plain_text = await MessageStorage.replace_image_descriptions(processed_plain_text) if processed_plain_text:
# 增加对None的防御性处理 processed_plain_text = await MessageStorage.replace_image_descriptions(processed_plain_text)
safe_processed_plain_text = processed_plain_text or "" safe_processed_plain_text = processed_plain_text or ""
filtered_processed_plain_text = re.sub(pattern, "", safe_processed_plain_text, flags=re.DOTALL) filtered_processed_plain_text = re.sub(pattern, "", safe_processed_plain_text, flags=re.DOTALL)
else:
filtered_processed_plain_text = ""
if isinstance(message, MessageSending):
display_message = message.display_message
if display_message:
filtered_display_message = re.sub(pattern, "", display_message, flags=re.DOTALL)
else: else:
# 如果没有设置display_message使用processed_plain_text作为显示消息 filtered_processed_plain_text = ""
filtered_display_message = (
re.sub(pattern, "", (message.processed_plain_text or ""), flags=re.DOTALL) display_message = message.display_message or message.processed_plain_text or ""
) filtered_display_message = re.sub(pattern, "", display_message, flags=re.DOTALL)
interest_value = 0
is_mentioned = False # 直接从 DatabaseMessages 获取所有字段
reply_to = message.reply_to msg_id = message.message_id
priority_mode = "" msg_time = message.time
priority_info = {} chat_id = message.chat_id
is_emoji = False reply_to = "" # DatabaseMessages 没有 reply_to 字段
is_picid = False
is_notify = False
is_command = False
key_words = ""
key_words_lite = ""
else:
filtered_display_message = ""
interest_value = message.interest_value
is_mentioned = message.is_mentioned is_mentioned = message.is_mentioned
reply_to = "" interest_value = message.interest_value or 0.0
priority_mode = message.priority_mode priority_mode = "" # DatabaseMessages 没有 priority_mode
priority_info = message.priority_info priority_info_json = None # DatabaseMessages 没有 priority_info
is_emoji = message.is_emoji is_emoji = message.is_emoji or False
is_picid = message.is_picid is_picid = message.is_picid or False
is_notify = message.is_notify is_notify = message.is_notify or False
is_command = message.is_command is_command = message.is_command or False
# 序列化关键词列表为JSON字符串 key_words = "" # DatabaseMessages 没有 key_words
key_words = MessageStorage._serialize_keywords(message.key_words) key_words_lite = ""
key_words_lite = MessageStorage._serialize_keywords(message.key_words_lite) memorized_times = 0 # DatabaseMessages 没有 memorized_times
# 使用 DatabaseMessages 中的嵌套对象信息
user_platform = message.user_info.platform if message.user_info else ""
user_id = message.user_info.user_id if message.user_info else ""
user_nickname = message.user_info.user_nickname if message.user_info else ""
user_cardname = message.user_info.user_cardname if message.user_info else None
chat_info_stream_id = message.chat_info.stream_id if message.chat_info else ""
chat_info_platform = message.chat_info.platform if message.chat_info else ""
chat_info_create_time = message.chat_info.create_time if message.chat_info else 0.0
chat_info_last_active_time = message.chat_info.last_active_time if message.chat_info else 0.0
chat_info_user_platform = message.chat_info.user_info.platform if message.chat_info and message.chat_info.user_info else ""
chat_info_user_id = message.chat_info.user_info.user_id if message.chat_info and message.chat_info.user_info else ""
chat_info_user_nickname = message.chat_info.user_info.user_nickname if message.chat_info and message.chat_info.user_info else ""
chat_info_user_cardname = message.chat_info.user_info.user_cardname if message.chat_info and message.chat_info.user_info else None
chat_info_group_platform = message.group_info.group_platform if message.group_info else None
chat_info_group_id = message.group_info.group_id if message.group_info else None
chat_info_group_name = message.group_info.group_name if message.group_info else None
else:
# MessageSending 处理逻辑
processed_plain_text = message.processed_plain_text
chat_info_dict = chat_stream.to_dict() if processed_plain_text:
user_info_dict = message.message_info.user_info.to_dict() # type: ignore processed_plain_text = await MessageStorage.replace_image_descriptions(processed_plain_text)
# 增加对None的防御性处理
safe_processed_plain_text = processed_plain_text or ""
filtered_processed_plain_text = re.sub(pattern, "", safe_processed_plain_text, flags=re.DOTALL)
else:
filtered_processed_plain_text = ""
# message_id 现在是 TextField直接使用字符串值 if isinstance(message, MessageSending):
msg_id = message.message_info.message_id display_message = message.display_message
if display_message:
filtered_display_message = re.sub(pattern, "", display_message, flags=re.DOTALL)
else:
# 如果没有设置display_message使用processed_plain_text作为显示消息
filtered_display_message = (
re.sub(pattern, "", (message.processed_plain_text or ""), flags=re.DOTALL)
)
interest_value = 0
is_mentioned = False
reply_to = message.reply_to
priority_mode = ""
priority_info = {}
is_emoji = False
is_picid = False
is_notify = False
is_command = False
key_words = ""
key_words_lite = ""
else:
filtered_display_message = ""
interest_value = message.interest_value
is_mentioned = message.is_mentioned
reply_to = ""
priority_mode = message.priority_mode
priority_info = message.priority_info
is_emoji = message.is_emoji
is_picid = message.is_picid
is_notify = message.is_notify
is_command = message.is_command
# 序列化关键词列表为JSON字符串
key_words = MessageStorage._serialize_keywords(message.key_words)
key_words_lite = MessageStorage._serialize_keywords(message.key_words_lite)
# 安全地获取 group_info, 如果为 None 则视为空字典 chat_info_dict = chat_stream.to_dict()
group_info_from_chat = chat_info_dict.get("group_info") or {} user_info_dict = message.message_info.user_info.to_dict() # type: ignore
# 安全地获取 user_info, 如果为 None 则视为空字典 (以防万一)
user_info_from_chat = chat_info_dict.get("user_info") or {}
# 将priority_info字典序列化为JSON字符串以便存储到数据库的Text字段 # message_id 现在是 TextField直接使用字符串值
priority_info_json = orjson.dumps(priority_info).decode("utf-8") if priority_info else None msg_id = message.message_info.message_id
msg_time = float(message.message_info.time or time.time())
chat_id = chat_stream.stream_id
memorized_times = message.memorized_times
# 安全地获取 group_info, 如果为 None 则视为空字典
group_info_from_chat = chat_info_dict.get("group_info") or {}
# 安全地获取 user_info, 如果为 None 则视为空字典 (以防万一)
user_info_from_chat = chat_info_dict.get("user_info") or {}
# 将priority_info字典序列化为JSON字符串以便存储到数据库的Text字段
priority_info_json = orjson.dumps(priority_info).decode("utf-8") if priority_info else None
user_platform = user_info_dict.get("platform")
user_id = user_info_dict.get("user_id")
user_nickname = user_info_dict.get("user_nickname")
user_cardname = user_info_dict.get("user_cardname")
chat_info_stream_id = chat_info_dict.get("stream_id")
chat_info_platform = chat_info_dict.get("platform")
chat_info_create_time = float(chat_info_dict.get("create_time", 0.0))
chat_info_last_active_time = float(chat_info_dict.get("last_active_time", 0.0))
chat_info_user_platform = user_info_from_chat.get("platform")
chat_info_user_id = user_info_from_chat.get("user_id")
chat_info_user_nickname = user_info_from_chat.get("user_nickname")
chat_info_user_cardname = user_info_from_chat.get("user_cardname")
chat_info_group_platform = group_info_from_chat.get("platform")
chat_info_group_id = group_info_from_chat.get("group_id")
chat_info_group_name = group_info_from_chat.get("group_name")
# 获取数据库会话 # 获取数据库会话
new_message = Messages( new_message = Messages(
message_id=msg_id, message_id=msg_id,
time=float(message.message_info.time or time.time()), time=msg_time,
chat_id=chat_stream.stream_id, chat_id=chat_id,
reply_to=reply_to, reply_to=reply_to,
is_mentioned=is_mentioned, is_mentioned=is_mentioned,
chat_info_stream_id=chat_info_dict.get("stream_id"), chat_info_stream_id=chat_info_stream_id,
chat_info_platform=chat_info_dict.get("platform"), chat_info_platform=chat_info_platform,
chat_info_user_platform=user_info_from_chat.get("platform"), chat_info_user_platform=chat_info_user_platform,
chat_info_user_id=user_info_from_chat.get("user_id"), chat_info_user_id=chat_info_user_id,
chat_info_user_nickname=user_info_from_chat.get("user_nickname"), chat_info_user_nickname=chat_info_user_nickname,
chat_info_user_cardname=user_info_from_chat.get("user_cardname"), chat_info_user_cardname=chat_info_user_cardname,
chat_info_group_platform=group_info_from_chat.get("platform"), chat_info_group_platform=chat_info_group_platform,
chat_info_group_id=group_info_from_chat.get("group_id"), chat_info_group_id=chat_info_group_id,
chat_info_group_name=group_info_from_chat.get("group_name"), chat_info_group_name=chat_info_group_name,
chat_info_create_time=float(chat_info_dict.get("create_time", 0.0)), chat_info_create_time=chat_info_create_time,
chat_info_last_active_time=float(chat_info_dict.get("last_active_time", 0.0)), chat_info_last_active_time=chat_info_last_active_time,
user_platform=user_info_dict.get("platform"), user_platform=user_platform,
user_id=user_info_dict.get("user_id"), user_id=user_id,
user_nickname=user_info_dict.get("user_nickname"), user_nickname=user_nickname,
user_cardname=user_info_dict.get("user_cardname"), user_cardname=user_cardname,
processed_plain_text=filtered_processed_plain_text, processed_plain_text=filtered_processed_plain_text,
display_message=filtered_display_message, display_message=filtered_display_message,
memorized_times=message.memorized_times, memorized_times=memorized_times,
interest_value=interest_value, interest_value=interest_value,
priority_mode=priority_mode, priority_mode=priority_mode,
priority_info=priority_info_json, priority_info=priority_info_json,
@@ -145,36 +216,43 @@ class MessageStorage:
traceback.print_exc() traceback.print_exc()
@staticmethod @staticmethod
async def update_message(message): async def update_message(message_data: dict):
"""更新消息ID""" """更新消息ID(从消息字典)"""
try: try:
mmc_message_id = message.message_info.message_id # 从字典中提取信息
message_info = message_data.get("message_info", {})
mmc_message_id = message_info.get("message_id")
message_segment = message_data.get("message_segment", {})
segment_type = message_segment.get("type") if isinstance(message_segment, dict) else None
segment_data = message_segment.get("data", {}) if isinstance(message_segment, dict) else {}
qq_message_id = None qq_message_id = None
logger.debug(f"尝试更新消息ID: {mmc_message_id}, 消息段类型: {message.message_segment.type}") logger.debug(f"尝试更新消息ID: {mmc_message_id}, 消息段类型: {segment_type}")
# 根据消息段类型提取message_id # 根据消息段类型提取message_id
if message.message_segment.type == "notify": if segment_type == "notify":
qq_message_id = message.message_segment.data.get("id") qq_message_id = segment_data.get("id")
elif message.message_segment.type == "text": elif segment_type == "text":
qq_message_id = message.message_segment.data.get("id") qq_message_id = segment_data.get("id")
elif message.message_segment.type == "reply": elif segment_type == "reply":
qq_message_id = message.message_segment.data.get("id") qq_message_id = segment_data.get("id")
if qq_message_id: if qq_message_id:
logger.debug(f"从reply消息段获取到消息ID: {qq_message_id}") logger.debug(f"从reply消息段获取到消息ID: {qq_message_id}")
elif message.message_segment.type == "adapter_response": elif segment_type == "adapter_response":
logger.debug("适配器响应消息不需要更新ID") logger.debug("适配器响应消息不需要更新ID")
return return
elif message.message_segment.type == "adapter_command": elif segment_type == "adapter_command":
logger.debug("适配器命令消息不需要更新ID") logger.debug("适配器命令消息不需要更新ID")
return return
else: else:
logger.debug(f"未知的消息段类型: {message.message_segment.type}跳过ID更新") logger.debug(f"未知的消息段类型: {segment_type}跳过ID更新")
return return
if not qq_message_id: if not qq_message_id:
logger.debug(f"消息段类型 {message.message_segment.type} 中未找到有效的message_id跳过更新") logger.debug(f"消息段类型 {segment_type} 中未找到有效的message_id跳过更新")
logger.debug(f"消息段数据: {message.message_segment.data}") logger.debug(f"消息段数据: {segment_data}")
return return
# 使用上下文管理器确保session正确管理 # 使用上下文管理器确保session正确管理

View File

@@ -137,7 +137,7 @@ class ActionModifier:
logger.debug(f"{self.log_prefix}阶段一移除动作: {disabled_action_name},原因: 用户自行禁用") logger.debug(f"{self.log_prefix}阶段一移除动作: {disabled_action_name},原因: 用户自行禁用")
# === 第二阶段:检查动作的关联类型 === # === 第二阶段:检查动作的关联类型 ===
chat_context = self.chat_stream.stream_context chat_context = self.chat_stream.context_manager.context
current_actions_s2 = self.action_manager.get_using_actions() current_actions_s2 = self.action_manager.get_using_actions()
type_mismatched_actions = self._check_action_associated_types(current_actions_s2, chat_context) type_mismatched_actions = self._check_action_associated_types(current_actions_s2, chat_context)

View File

@@ -13,7 +13,7 @@ from typing import Any
from src.chat.express.expression_selector import expression_selector from src.chat.express.expression_selector import expression_selector
from src.chat.message_receive.chat_stream import ChatStream from src.chat.message_receive.chat_stream import ChatStream
from src.chat.message_receive.message import MessageRecv, MessageSending, Seg, UserInfo from src.chat.message_receive.message import MessageSending, Seg, UserInfo
from src.chat.message_receive.uni_message_sender import HeartFCSender from src.chat.message_receive.uni_message_sender import HeartFCSender
from src.chat.utils.chat_message_builder import ( from src.chat.utils.chat_message_builder import (
build_readable_messages, build_readable_messages,
@@ -1733,7 +1733,7 @@ class DefaultReplyer:
is_emoji: bool, is_emoji: bool,
thinking_start_time: float, thinking_start_time: float,
display_message: str, display_message: str,
anchor_message: MessageRecv | None = None, anchor_message: DatabaseMessages | None = None,
) -> MessageSending: ) -> MessageSending:
"""构建单个发送消息""" """构建单个发送消息"""
@@ -1743,8 +1743,11 @@ class DefaultReplyer:
platform=self.chat_stream.platform, platform=self.chat_stream.platform,
) )
# await anchor_message.process() # 从 DatabaseMessages 获取 sender_info
sender_info = anchor_message.message_info.user_info if anchor_message else None if anchor_message:
sender_info = anchor_message.user_info
else:
sender_info = None
return MessageSending( return MessageSending(
message_id=message_id, # 使用片段的唯一ID message_id=message_id, # 使用片段的唯一ID

View File

@@ -11,7 +11,7 @@ import rjieba
from maim_message import UserInfo from maim_message import UserInfo
from src.chat.message_receive.chat_stream import get_chat_manager from src.chat.message_receive.chat_stream import get_chat_manager
from src.chat.message_receive.message import MessageRecv # MessageRecv 已被移除,现在使用 DatabaseMessages
from src.common.logger import get_logger from src.common.logger import get_logger
from src.common.message_repository import count_messages, find_messages from src.common.message_repository import count_messages, find_messages
from src.config.config import global_config, model_config from src.config.config import global_config, model_config
@@ -41,34 +41,58 @@ def db_message_to_str(message_dict: dict) -> str:
return result return result
def is_mentioned_bot_in_message(message: MessageRecv) -> tuple[bool, float]: def is_mentioned_bot_in_message(message) -> tuple[bool, float]:
"""检查消息是否提到了机器人""" """检查消息是否提到了机器人
Args:
message: DatabaseMessages 消息对象
Returns:
tuple[bool, float]: (是否提及, 提及概率)
"""
keywords = [global_config.bot.nickname] keywords = [global_config.bot.nickname]
nicknames = global_config.bot.alias_names nicknames = global_config.bot.alias_names
reply_probability = 0.0 reply_probability = 0.0
is_at = False is_at = False
is_mentioned = False is_mentioned = False
if message.is_mentioned is not None:
return bool(message.is_mentioned), message.is_mentioned # 检查 is_mentioned 属性
if ( mentioned_attr = getattr(message, "is_mentioned", None)
message.message_info.additional_config is not None if mentioned_attr is not None:
and message.message_info.additional_config.get("is_mentioned") is not None
):
try: try:
reply_probability = float(message.message_info.additional_config.get("is_mentioned")) # type: ignore return bool(mentioned_attr), float(mentioned_attr)
except (ValueError, TypeError):
pass
# 检查 additional_config
additional_config = None
# DatabaseMessages: additional_config 是 JSON 字符串
if message.additional_config:
try:
import orjson
additional_config = orjson.loads(message.additional_config)
except Exception:
pass
if additional_config and additional_config.get("is_mentioned") is not None:
try:
reply_probability = float(additional_config.get("is_mentioned")) # type: ignore
is_mentioned = True is_mentioned = True
return is_mentioned, reply_probability return is_mentioned, reply_probability
except Exception as e: except Exception as e:
logger.warning(str(e)) logger.warning(str(e))
logger.warning( logger.warning(
f"消息中包含不合理的设置 is_mentioned: {message.message_info.additional_config.get('is_mentioned')}" f"消息中包含不合理的设置 is_mentioned: {additional_config.get('is_mentioned')}"
) )
if global_config.bot.nickname in message.processed_plain_text: # 检查消息文本内容
processed_text = message.processed_plain_text or ""
if global_config.bot.nickname in processed_text:
is_mentioned = True is_mentioned = True
for alias_name in global_config.bot.alias_names: for alias_name in global_config.bot.alias_names:
if alias_name in message.processed_plain_text: if alias_name in processed_text:
is_mentioned = True is_mentioned = True
# 判断是否被@ # 判断是否被@
@@ -110,7 +134,6 @@ def is_mentioned_bot_in_message(message: MessageRecv) -> tuple[bool, float]:
logger.debug("被提及回复概率设置为100%") logger.debug("被提及回复概率设置为100%")
return is_mentioned, reply_probability return is_mentioned, reply_probability
async def get_embedding(text, request_type="embedding") -> list[float] | None: async def get_embedding(text, request_type="embedding") -> list[float] | None:
"""获取文本的embedding向量""" """获取文本的embedding向量"""
# 每次都创建新的LLMRequest实例以避免事件循环冲突 # 每次都创建新的LLMRequest实例以避免事件循环冲突

View File

@@ -9,15 +9,18 @@ from src.common.logger import get_logger
logger = get_logger("db_migration") logger = get_logger("db_migration")
async def check_and_migrate_database(): async def check_and_migrate_database(existing_engine=None):
""" """
异步检查数据库结构并自动迁移。 异步检查数据库结构并自动迁移。
- 自动创建不存在的表。 - 自动创建不存在的表。
- 自动为现有表添加缺失的列。 - 自动为现有表添加缺失的列。
- 自动为现有表创建缺失的索引。 - 自动为现有表创建缺失的索引。
Args:
existing_engine: 可选的已存在的数据库引擎。如果提供,将使用该引擎;否则获取全局引擎。
""" """
logger.info("正在检查数据库结构并执行自动迁移...") logger.info("正在检查数据库结构并执行自动迁移...")
engine = await get_engine() engine = existing_engine if existing_engine is not None else await get_engine()
async with engine.connect() as connection: async with engine.connect() as connection:
# 在同步上下文中运行inspector操作 # 在同步上下文中运行inspector操作

View File

@@ -780,12 +780,8 @@ async def initialize_database() -> tuple["AsyncEngine", async_sessionmaker[Async
_SessionLocal = async_sessionmaker(bind=_engine, class_=AsyncSession, expire_on_commit=False) _SessionLocal = async_sessionmaker(bind=_engine, class_=AsyncSession, expire_on_commit=False)
# 迁移 # 迁移
try: from src.common.database.db_migration import check_and_migrate_database
from src.common.database.db_migration import check_and_migrate_database await check_and_migrate_database(existing_engine=_engine)
await check_and_migrate_database(existing_engine=_engine)
except TypeError:
from src.common.database.db_migration import check_and_migrate_database as _legacy_migrate
await _legacy_migrate()
if config.database_type == "sqlite": if config.database_type == "sqlite":
await enable_sqlite_wal_mode(_engine) await enable_sqlite_wal_mode(_engine)

View File

@@ -2,7 +2,6 @@ import math
import random import random
import time import time
from src.chat.message_receive.message import MessageRecv
from src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_by_timestamp_with_chat_inclusive from src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_by_timestamp_with_chat_inclusive
from src.chat.utils.prompt import Prompt, global_prompt_manager from src.chat.utils.prompt import Prompt, global_prompt_manager
from src.common.data_models.database_data_model import DatabaseMessages from src.common.data_models.database_data_model import DatabaseMessages
@@ -98,7 +97,7 @@ class ChatMood:
if not hasattr(self, "last_change_time"): if not hasattr(self, "last_change_time"):
self.last_change_time = 0 self.last_change_time = 0
async def update_mood_by_message(self, message: MessageRecv | DatabaseMessages, interested_rate: float): async def update_mood_by_message(self, message: DatabaseMessages, interested_rate: float):
# 确保异步初始化已完成 # 确保异步初始化已完成
await self._initialize() await self._initialize()
@@ -109,11 +108,8 @@ class ChatMood:
self.regression_count = 0 self.regression_count = 0
# 处理不同类型的消息对象 # 使用 DatabaseMessages 的时间字段
if isinstance(message, MessageRecv): message_time = message.time
message_time = message.message_info.time
else: # DatabaseMessages
message_time = message.time
# 防止负时间差 # 防止负时间差
during_last_time = max(0, message_time - self.last_change_time) during_last_time = max(0, message_time - self.last_change_time)

View File

@@ -86,13 +86,16 @@ async def file_to_stream(
import asyncio import asyncio
import time import time
import traceback import traceback
from typing import Any from typing import Any, TYPE_CHECKING
from maim_message import Seg, UserInfo from maim_message import Seg, UserInfo
if TYPE_CHECKING:
from src.common.data_models.database_data_model import DatabaseMessages
# 导入依赖 # 导入依赖
from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager
from src.chat.message_receive.message import MessageRecv, MessageSending from src.chat.message_receive.message import MessageSending
from src.chat.message_receive.uni_message_sender import HeartFCSender from src.chat.message_receive.uni_message_sender import HeartFCSender
from src.common.logger import get_logger from src.common.logger import get_logger
from src.config.config import global_config from src.config.config import global_config
@@ -104,84 +107,53 @@ logger = get_logger("send_api")
_adapter_response_pool: dict[str, asyncio.Future] = {} _adapter_response_pool: dict[str, asyncio.Future] = {}
def message_dict_to_message_recv(message_dict: dict[str, Any]) -> MessageRecv | None: def message_dict_to_db_message(message_dict: dict[str, Any]) -> "DatabaseMessages | None":
"""查找要回复的消息 """从消息字典构建 DatabaseMessages 对象
Args: Args:
message_dict: 消息字典或 DatabaseMessages 对象 message_dict: 消息字典或 DatabaseMessages 对象
Returns: Returns:
Optional[MessageRecv]: 找到的消息,如果没找到则返回None Optional[DatabaseMessages]: 构建的消息对象,如果构建失败则返回None
""" """
# 兼容 DatabaseMessages 对象和字典 from src.common.data_models.database_data_model import DatabaseMessages
if isinstance(message_dict, dict):
user_platform = message_dict.get("user_platform", "")
user_id = message_dict.get("user_id", "")
user_nickname = message_dict.get("user_nickname", "")
user_cardname = message_dict.get("user_cardname", "")
chat_info_group_id = message_dict.get("chat_info_group_id")
chat_info_group_platform = message_dict.get("chat_info_group_platform", "")
chat_info_group_name = message_dict.get("chat_info_group_name", "")
chat_info_platform = message_dict.get("chat_info_platform", "")
message_id = message_dict.get("message_id") or message_dict.get("chat_info_message_id") or message_dict.get("id")
time_val = message_dict.get("time")
additional_config = message_dict.get("additional_config")
processed_plain_text = message_dict.get("processed_plain_text")
else:
# DatabaseMessages 对象
user_platform = getattr(message_dict, "user_platform", "")
user_id = getattr(message_dict, "user_id", "")
user_nickname = getattr(message_dict, "user_nickname", "")
user_cardname = getattr(message_dict, "user_cardname", "")
chat_info_group_id = getattr(message_dict, "chat_info_group_id", None)
chat_info_group_platform = getattr(message_dict, "chat_info_group_platform", "")
chat_info_group_name = getattr(message_dict, "chat_info_group_name", "")
chat_info_platform = getattr(message_dict, "chat_info_platform", "")
message_id = getattr(message_dict, "message_id", None)
time_val = getattr(message_dict, "time", None)
additional_config = getattr(message_dict, "additional_config", None)
processed_plain_text = getattr(message_dict, "processed_plain_text", "")
# 构建MessageRecv对象 # 如果已经是 DatabaseMessages直接返回
user_info = { if isinstance(message_dict, DatabaseMessages):
"platform": user_platform, return message_dict
"user_id": user_id,
"user_nickname": user_nickname, # 从字典提取信息
"user_cardname": user_cardname, user_platform = message_dict.get("user_platform", "")
} user_id = message_dict.get("user_id", "")
user_nickname = message_dict.get("user_nickname", "")
group_info = {} user_cardname = message_dict.get("user_cardname", "")
if chat_info_group_id: chat_info_group_id = message_dict.get("chat_info_group_id")
group_info = { chat_info_group_platform = message_dict.get("chat_info_group_platform", "")
"platform": chat_info_group_platform, chat_info_group_name = message_dict.get("chat_info_group_name", "")
"group_id": chat_info_group_id, chat_info_platform = message_dict.get("chat_info_platform", "")
"group_name": chat_info_group_name, message_id = message_dict.get("message_id") or message_dict.get("chat_info_message_id") or message_dict.get("id")
} time_val = message_dict.get("time", time.time())
additional_config = message_dict.get("additional_config")
format_info = {"content_format": "", "accept_format": ""} processed_plain_text = message_dict.get("processed_plain_text", "")
template_info = {"template_items": {}}
# DatabaseMessages 使用扁平参数构造
message_info = { db_message = DatabaseMessages(
"platform": chat_info_platform, message_id=message_id or "temp_reply_id",
"message_id": message_id, time=time_val,
"time": time_val, user_id=user_id,
"group_info": group_info, user_nickname=user_nickname,
"user_info": user_info, user_cardname=user_cardname,
"additional_config": additional_config, user_platform=user_platform,
"format_info": format_info, chat_info_group_id=chat_info_group_id,
"template_info": template_info, chat_info_group_name=chat_info_group_name,
} chat_info_group_platform=chat_info_group_platform,
chat_info_platform=chat_info_platform,
new_message_dict = { processed_plain_text=processed_plain_text,
"message_info": message_info, additional_config=additional_config
"raw_message": processed_plain_text, )
"processed_plain_text": processed_plain_text,
} logger.info(f"[SendAPI] 构建回复消息对象,发送者: {user_nickname}")
return db_message
message_recv = MessageRecv(new_message_dict)
logger.info(f"[SendAPI] 找到匹配的回复消息,发送者: {user_nickname}")
return message_recv
def put_adapter_response(request_id: str, response_data: dict) -> None: def put_adapter_response(request_id: str, response_data: dict) -> None:
@@ -285,17 +257,17 @@ async def _send_to_target(
"message_id": "temp_reply_id", # 临时ID "message_id": "temp_reply_id", # 临时ID
"time": time.time() "time": time.time()
} }
anchor_message = message_dict_to_message_recv(message_dict=temp_message_dict) anchor_message = message_dict_to_db_message(message_dict=temp_message_dict)
else: else:
anchor_message = None anchor_message = None
reply_to_platform_id = f"{target_stream.platform}:{sender_id}" if anchor_message else None reply_to_platform_id = f"{target_stream.platform}:{sender_id}" if anchor_message else None
elif reply_to_message: elif reply_to_message:
anchor_message = message_dict_to_message_recv(message_dict=reply_to_message) anchor_message = message_dict_to_db_message(message_dict=reply_to_message)
if anchor_message: if anchor_message:
anchor_message.update_chat_stream(target_stream) # DatabaseMessages 不需要 update_chat_stream它是纯数据对象
reply_to_platform_id = ( reply_to_platform_id = (
f"{anchor_message.message_info.platform}:{anchor_message.message_info.user_info.user_id}" f"{anchor_message.chat_info.platform}:{anchor_message.user_info.user_id}"
) )
else: else:
reply_to_platform_id = None reply_to_platform_id = None

View File

@@ -1,10 +1,14 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import TYPE_CHECKING
from src.chat.message_receive.message import MessageRecv from src.common.data_models.database_data_model import DatabaseMessages
from src.common.logger import get_logger from src.common.logger import get_logger
from src.plugin_system.apis import send_api from src.plugin_system.apis import send_api
from src.plugin_system.base.component_types import ChatType, CommandInfo, ComponentType from src.plugin_system.base.component_types import ChatType, CommandInfo, ComponentType
if TYPE_CHECKING:
from src.chat.message_receive.chat_stream import ChatStream
logger = get_logger("base_command") logger = get_logger("base_command")
@@ -29,11 +33,11 @@ class BaseCommand(ABC):
chat_type_allow: ChatType = ChatType.ALL chat_type_allow: ChatType = ChatType.ALL
"""允许的聊天类型,默认为所有类型""" """允许的聊天类型,默认为所有类型"""
def __init__(self, message: MessageRecv, plugin_config: dict | None = None): def __init__(self, message: DatabaseMessages, plugin_config: dict | None = None):
"""初始化Command组件 """初始化Command组件
Args: Args:
message: 接收到的消息对象 message: 接收到的消息对象DatabaseMessages
plugin_config: 插件配置字典 plugin_config: 插件配置字典
""" """
self.message = message self.message = message
@@ -41,6 +45,9 @@ class BaseCommand(ABC):
self.plugin_config = plugin_config or {} # 直接存储插件配置字典 self.plugin_config = plugin_config or {} # 直接存储插件配置字典
self.log_prefix = "[Command]" self.log_prefix = "[Command]"
# chat_stream 会在运行时被 bot.py 设置
self.chat_stream: "ChatStream | None" = None
# 从类属性获取chat_type_allow设置 # 从类属性获取chat_type_allow设置
self.chat_type_allow = getattr(self.__class__, "chat_type_allow", ChatType.ALL) self.chat_type_allow = getattr(self.__class__, "chat_type_allow", ChatType.ALL)
@@ -49,7 +56,7 @@ class BaseCommand(ABC):
# 验证聊天类型限制 # 验证聊天类型限制
if not self._validate_chat_type(): if not self._validate_chat_type():
is_group = hasattr(self.message, "is_group_message") and self.message.is_group_message is_group = message.group_info is not None
logger.warning( logger.warning(
f"{self.log_prefix} Command '{self.command_name}' 不支持当前聊天类型: " f"{self.log_prefix} Command '{self.command_name}' 不支持当前聊天类型: "
f"{'群聊' if is_group else '私聊'}, 允许类型: {self.chat_type_allow.value}" f"{'群聊' if is_group else '私聊'}, 允许类型: {self.chat_type_allow.value}"
@@ -72,8 +79,8 @@ class BaseCommand(ABC):
if self.chat_type_allow == ChatType.ALL: if self.chat_type_allow == ChatType.ALL:
return True return True
# 检查是否为群聊消息 # 检查是否为群聊消息DatabaseMessages使用group_info来判断
is_group = self.message.message_info.group_info is_group = self.message.group_info is not None
if self.chat_type_allow == ChatType.GROUP and is_group: if self.chat_type_allow == ChatType.GROUP and is_group:
return True return True
@@ -137,12 +144,11 @@ class BaseCommand(ABC):
bool: 是否发送成功 bool: 是否发送成功
""" """
# 获取聊天流信息 # 获取聊天流信息
chat_stream = self.message.chat_stream if not self.chat_stream or not hasattr(self.chat_stream, "stream_id"):
if not chat_stream or not hasattr(chat_stream, "stream_id"):
logger.error(f"{self.log_prefix} 缺少聊天流或stream_id") logger.error(f"{self.log_prefix} 缺少聊天流或stream_id")
return False return False
return await send_api.text_to_stream(text=content, stream_id=chat_stream.stream_id, reply_to=reply_to) return await send_api.text_to_stream(text=content, stream_id=self.chat_stream.stream_id, reply_to=reply_to)
async def send_type( async def send_type(
self, message_type: str, content: str, display_message: str = "", typing: bool = False, reply_to: str = "" self, message_type: str, content: str, display_message: str = "", typing: bool = False, reply_to: str = ""
@@ -160,15 +166,14 @@ class BaseCommand(ABC):
bool: 是否发送成功 bool: 是否发送成功
""" """
# 获取聊天流信息 # 获取聊天流信息
chat_stream = self.message.chat_stream if not self.chat_stream or not hasattr(self.chat_stream, "stream_id"):
if not chat_stream or not hasattr(chat_stream, "stream_id"):
logger.error(f"{self.log_prefix} 缺少聊天流或stream_id") logger.error(f"{self.log_prefix} 缺少聊天流或stream_id")
return False return False
return await send_api.custom_to_stream( return await send_api.custom_to_stream(
message_type=message_type, message_type=message_type,
content=content, content=content,
stream_id=chat_stream.stream_id, stream_id=self.chat_stream.stream_id,
display_message=display_message, display_message=display_message,
typing=typing, typing=typing,
reply_to=reply_to, reply_to=reply_to,
@@ -190,8 +195,7 @@ class BaseCommand(ABC):
""" """
try: try:
# 获取聊天流信息 # 获取聊天流信息
chat_stream = self.message.chat_stream if not self.chat_stream or not hasattr(self.chat_stream, "stream_id"):
if not chat_stream or not hasattr(chat_stream, "stream_id"):
logger.error(f"{self.log_prefix} 缺少聊天流或stream_id") logger.error(f"{self.log_prefix} 缺少聊天流或stream_id")
return False return False
@@ -200,7 +204,7 @@ class BaseCommand(ABC):
success = await send_api.command_to_stream( success = await send_api.command_to_stream(
command=command_data, command=command_data,
stream_id=chat_stream.stream_id, stream_id=self.chat_stream.stream_id,
storage_message=storage_message, storage_message=storage_message,
display_message=display_message, display_message=display_message,
) )
@@ -225,12 +229,11 @@ class BaseCommand(ABC):
Returns: Returns:
bool: 是否发送成功 bool: 是否发送成功
""" """
chat_stream = self.message.chat_stream if not self.chat_stream or not hasattr(self.chat_stream, "stream_id"):
if not chat_stream or not hasattr(chat_stream, "stream_id"):
logger.error(f"{self.log_prefix} 缺少聊天流或stream_id") logger.error(f"{self.log_prefix} 缺少聊天流或stream_id")
return False return False
return await send_api.emoji_to_stream(emoji_base64, chat_stream.stream_id) return await send_api.emoji_to_stream(emoji_base64, self.chat_stream.stream_id)
async def send_image(self, image_base64: str) -> bool: async def send_image(self, image_base64: str) -> bool:
"""发送图片 """发送图片
@@ -241,12 +244,11 @@ class BaseCommand(ABC):
Returns: Returns:
bool: 是否发送成功 bool: 是否发送成功
""" """
chat_stream = self.message.chat_stream if not self.chat_stream or not hasattr(self.chat_stream, "stream_id"):
if not chat_stream or not hasattr(chat_stream, "stream_id"):
logger.error(f"{self.log_prefix} 缺少聊天流或stream_id") logger.error(f"{self.log_prefix} 缺少聊天流或stream_id")
return False return False
return await send_api.image_to_stream(image_base64, chat_stream.stream_id) return await send_api.image_to_stream(image_base64, self.chat_stream.stream_id)
@classmethod @classmethod
def get_command_info(cls) -> "CommandInfo": def get_command_info(cls) -> "CommandInfo":

View File

@@ -5,8 +5,9 @@
import re import re
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import TYPE_CHECKING
from src.chat.message_receive.message import MessageRecv from src.common.data_models.database_data_model import DatabaseMessages
from src.common.logger import get_logger from src.common.logger import get_logger
from src.config.config import global_config from src.config.config import global_config
from src.plugin_system.apis import send_api from src.plugin_system.apis import send_api
@@ -14,6 +15,9 @@ from src.plugin_system.base.base_command import BaseCommand
from src.plugin_system.base.command_args import CommandArgs from src.plugin_system.base.command_args import CommandArgs
from src.plugin_system.base.component_types import ChatType, ComponentType, PlusCommandInfo from src.plugin_system.base.component_types import ChatType, ComponentType, PlusCommandInfo
if TYPE_CHECKING:
from src.chat.message_receive.chat_stream import ChatStream
logger = get_logger("plus_command") logger = get_logger("plus_command")
@@ -50,23 +54,26 @@ class PlusCommand(ABC):
intercept_message: bool = False intercept_message: bool = False
"""是否拦截消息,不进行后续处理""" """是否拦截消息,不进行后续处理"""
def __init__(self, message: MessageRecv, plugin_config: dict | None = None): def __init__(self, message: DatabaseMessages, plugin_config: dict | None = None):
"""初始化命令组件 """初始化命令组件
Args: Args:
message: 接收到的消息对象 message: 接收到的消息对象DatabaseMessages
plugin_config: 插件配置字典 plugin_config: 插件配置字典
""" """
self.message = message self.message = message
self.plugin_config = plugin_config or {} self.plugin_config = plugin_config or {}
self.log_prefix = "[PlusCommand]" self.log_prefix = "[PlusCommand]"
# chat_stream 会在运行时被 bot.py 设置
self.chat_stream: "ChatStream | None" = None
# 解析命令参数 # 解析命令参数
self._parse_command() self._parse_command()
# 验证聊天类型限制 # 验证聊天类型限制
if not self._validate_chat_type(): if not self._validate_chat_type():
is_group = self.message.message_info.group_info.group_id is_group = message.group_info is not None
logger.warning( logger.warning(
f"{self.log_prefix} 命令 '{self.command_name}' 不支持当前聊天类型: " f"{self.log_prefix} 命令 '{self.command_name}' 不支持当前聊天类型: "
f"{'群聊' if is_group else '私聊'}, 允许类型: {self.chat_type_allow.value}" f"{'群聊' if is_group else '私聊'}, 允许类型: {self.chat_type_allow.value}"
@@ -124,8 +131,8 @@ class PlusCommand(ABC):
if self.chat_type_allow == ChatType.ALL: if self.chat_type_allow == ChatType.ALL:
return True return True
# 检查是否为群聊消息 # 检查是否为群聊消息DatabaseMessages使用group_info判断
is_group = hasattr(self.message.message_info, "group_info") and self.message.message_info.group_info is_group = self.message.group_info is not None
if self.chat_type_allow == ChatType.GROUP and is_group: if self.chat_type_allow == ChatType.GROUP and is_group:
return True return True
@@ -152,7 +159,7 @@ class PlusCommand(ABC):
def _is_exact_command_call(self) -> bool: def _is_exact_command_call(self) -> bool:
"""检查是否是精确的命令调用(无参数)""" """检查是否是精确的命令调用(无参数)"""
if not hasattr(self.message, "plain_text") or not self.message.processed_plain_text: if not self.message.processed_plain_text:
return False return False
plain_text = self.message.processed_plain_text.strip() plain_text = self.message.processed_plain_text.strip()
@@ -218,12 +225,11 @@ class PlusCommand(ABC):
bool: 是否发送成功 bool: 是否发送成功
""" """
# 获取聊天流信息 # 获取聊天流信息
chat_stream = self.message.chat_stream if not self.chat_stream or not hasattr(self.chat_stream, "stream_id"):
if not chat_stream or not hasattr(chat_stream, "stream_id"):
logger.error(f"{self.log_prefix} 缺少聊天流或stream_id") logger.error(f"{self.log_prefix} 缺少聊天流或stream_id")
return False return False
return await send_api.text_to_stream(text=content, stream_id=chat_stream.stream_id, reply_to=reply_to) return await send_api.text_to_stream(text=content, stream_id=self.chat_stream.stream_id, reply_to=reply_to)
async def send_type( async def send_type(
self, message_type: str, content: str, display_message: str = "", typing: bool = False, reply_to: str = "" self, message_type: str, content: str, display_message: str = "", typing: bool = False, reply_to: str = ""
@@ -241,15 +247,14 @@ class PlusCommand(ABC):
bool: 是否发送成功 bool: 是否发送成功
""" """
# 获取聊天流信息 # 获取聊天流信息
chat_stream = self.message.chat_stream if not self.chat_stream or not hasattr(self.chat_stream, "stream_id"):
if not chat_stream or not hasattr(chat_stream, "stream_id"):
logger.error(f"{self.log_prefix} 缺少聊天流或stream_id") logger.error(f"{self.log_prefix} 缺少聊天流或stream_id")
return False return False
return await send_api.custom_to_stream( return await send_api.custom_to_stream(
message_type=message_type, message_type=message_type,
content=content, content=content,
stream_id=chat_stream.stream_id, stream_id=self.chat_stream.stream_id,
display_message=display_message, display_message=display_message,
typing=typing, typing=typing,
reply_to=reply_to, reply_to=reply_to,
@@ -264,12 +269,11 @@ class PlusCommand(ABC):
Returns: Returns:
bool: 是否发送成功 bool: 是否发送成功
""" """
chat_stream = self.message.chat_stream if not self.chat_stream or not hasattr(self.chat_stream, "stream_id"):
if not chat_stream or not hasattr(chat_stream, "stream_id"):
logger.error(f"{self.log_prefix} 缺少聊天流或stream_id") logger.error(f"{self.log_prefix} 缺少聊天流或stream_id")
return False return False
return await send_api.emoji_to_stream(emoji_base64, chat_stream.stream_id) return await send_api.emoji_to_stream(emoji_base64, self.chat_stream.stream_id)
async def send_image(self, image_base64: str) -> bool: async def send_image(self, image_base64: str) -> bool:
"""发送图片 """发送图片
@@ -280,12 +284,11 @@ class PlusCommand(ABC):
Returns: Returns:
bool: 是否发送成功 bool: 是否发送成功
""" """
chat_stream = self.message.chat_stream if not self.chat_stream or not hasattr(self.chat_stream, "stream_id"):
if not chat_stream or not hasattr(chat_stream, "stream_id"):
logger.error(f"{self.log_prefix} 缺少聊天流或stream_id") logger.error(f"{self.log_prefix} 缺少聊天流或stream_id")
return False return False
return await send_api.image_to_stream(image_base64, chat_stream.stream_id) return await send_api.image_to_stream(image_base64, self.chat_stream.stream_id)
@classmethod @classmethod
def get_plus_command_info(cls) -> "PlusCommandInfo": def get_plus_command_info(cls) -> "PlusCommandInfo":
@@ -340,12 +343,12 @@ class PlusCommandAdapter(BaseCommand):
将PlusCommand适配到现有的插件系统继承BaseCommand 将PlusCommand适配到现有的插件系统继承BaseCommand
""" """
def __init__(self, plus_command_class, message: MessageRecv, plugin_config: dict | None = None): def __init__(self, plus_command_class, message: DatabaseMessages, plugin_config: dict | None = None):
"""初始化适配器 """初始化适配器
Args: Args:
plus_command_class: PlusCommand子类 plus_command_class: PlusCommand子类
message: 消息对象 message: 消息对象DatabaseMessages
plugin_config: 插件配置 plugin_config: 插件配置
""" """
# 先设置必要的类属性 # 先设置必要的类属性
@@ -400,7 +403,7 @@ def create_plus_command_adapter(plus_command_class):
command_pattern = plus_command_class._generate_command_pattern() command_pattern = plus_command_class._generate_command_pattern()
chat_type_allow = getattr(plus_command_class, "chat_type_allow", ChatType.ALL) chat_type_allow = getattr(plus_command_class, "chat_type_allow", ChatType.ALL)
def __init__(self, message: MessageRecv, plugin_config: dict | None = None): def __init__(self, message: DatabaseMessages, plugin_config: dict | None = None):
super().__init__(message, plugin_config) super().__init__(message, plugin_config)
self.plus_command = plus_command_class(message, plugin_config) self.plus_command = plus_command_class(message, plugin_config)
self.priority = getattr(plus_command_class, "priority", 0) self.priority = getattr(plus_command_class, "priority", 0)

View File

@@ -410,11 +410,9 @@ class ChatterPlanExecutor:
) )
# 添加到chat_stream的已读消息中 # 添加到chat_stream的已读消息中
if hasattr(chat_stream, "stream_context") and chat_stream.stream_context: chat_stream.context_manager.context.history_messages.append(bot_message)
chat_stream.stream_context.history_messages.append(bot_message) logger.debug(f"机器人回复已添加到已读消息: {reply_content[:50]}...")
logger.debug(f"机器人回复已添加到已读消息: {reply_content[:50]}...")
else:
logger.warning("chat_stream没有stream_context无法添加已读消息")
except Exception as e: except Exception as e:
logger.error(f"添加机器人回复到已读消息时出错: {e}") logger.error(f"添加机器人回复到已读消息时出错: {e}")

View File

@@ -96,7 +96,7 @@ class ProactiveThinkingMessageHandler(BaseEventHandler):
"""处理消息事件 """处理消息事件
Args: Args:
kwargs: 事件参数,格式为 {"message": MessageRecv} kwargs: 事件参数,格式为 {"message": DatabaseMessages}
Returns: Returns:
HandlerResult: 处理结果 HandlerResult: 处理结果
@@ -104,7 +104,7 @@ class ProactiveThinkingMessageHandler(BaseEventHandler):
if not kwargs: if not kwargs:
return HandlerResult(success=True, continue_process=True, message=None) return HandlerResult(success=True, continue_process=True, message=None)
# 从 kwargs 中获取 MessageRecv 对象 # 从 kwargs 中获取 DatabaseMessages 对象
message = kwargs.get("message") message = kwargs.get("message")
if not message or not hasattr(message, "chat_stream"): if not message or not hasattr(message, "chat_stream"):
return HandlerResult(success=True, continue_process=True, message=None) return HandlerResult(success=True, continue_process=True, message=None)