This commit is contained in:
tt-P607
2025-12-14 13:45:35 +08:00
31 changed files with 567 additions and 7224 deletions

View File

@@ -9,6 +9,8 @@ from collections import defaultdict
from dataclasses import dataclass, field
from typing import Any
from sqlalchemy.exc import SQLAlchemyError
from src.common.database.compatibility import get_db_session
from src.common.database.core.models import ChatStreams
from src.common.logger import get_logger
@@ -159,20 +161,27 @@ class BatchDatabaseWriter:
logger.info("批量写入循环结束")
async def _collect_batch(self) -> list[StreamUpdatePayload]:
"""收集一个批次的数据"""
batch = []
deadline = time.time() + self.flush_interval
"""收集一个批次的数据
- 自适应刷新:队列增长加快时缩短等待时间
- 避免长时间空转:添加轻微抖动以分散竞争
"""
batch: list[StreamUpdatePayload] = []
# Adjust the flush window based on the current queue length (shortened to at most 40%)
qsize = self.write_queue.qsize()
adapt_factor = 1.0
if qsize > 0:
adapt_factor = max(0.4, min(1.0, self.batch_size / max(1, qsize)))
deadline = time.time() + (self.flush_interval * adapt_factor)
while len(batch) < self.batch_size and time.time() < deadline:
try:
# Compute the remaining wait time
remaining_time = max(0, deadline - time.time())
remaining_time = max(0.0, deadline - time.time())
if remaining_time == 0:
break
payload = await asyncio.wait_for(self.write_queue.get(), timeout=remaining_time)
# Slight slack ("jitter") so multiple coroutines do not contend for the queue at the same instant
jitter = 0.002
payload = await asyncio.wait_for(self.write_queue.get(), timeout=remaining_time + jitter)
batch.append(payload)
except asyncio.TimeoutError:
break
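The hunk above swaps a fixed flush deadline for an adaptive one. A minimal, self-contained sketch of the same idea, assuming an `asyncio.Queue` plus the `batch_size` and `flush_interval` attributes shown above:

```python
import asyncio
import time

async def collect_batch(queue: asyncio.Queue, batch_size: int, flush_interval: float) -> list:
    """Collect up to batch_size items, shrinking the wait window when the queue is backed up."""
    batch: list = []
    # Scale the window down (to at most 40% here) once qsize exceeds batch_size.
    qsize = queue.qsize()
    adapt_factor = max(0.4, min(1.0, batch_size / max(1, qsize))) if qsize else 1.0
    deadline = time.time() + flush_interval * adapt_factor
    while len(batch) < batch_size:
        remaining = deadline - time.time()
        if remaining <= 0:
            break
        try:
            batch.append(await asyncio.wait_for(queue.get(), timeout=remaining))
        except asyncio.TimeoutError:
            break
    return batch
```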
@@ -208,48 +217,52 @@ class BatchDatabaseWriter:
logger.debug(f"批量写入完成: {len(batch)} 个更新,耗时 {time.time() - start_time:.3f}s")
except Exception as e:
except SQLAlchemyError as e:
self.stats["failed_writes"] += 1
logger.error(f"批量写入失败: {e}")
# 降级到单个写入
for payload in batch:
try:
await self._direct_write(payload.stream_id, payload.update_data)
except Exception as single_e:
except SQLAlchemyError as single_e:
logger.error(f"单个写入也失败: {single_e}")
async def _batch_write_to_database(self, payloads: list[StreamUpdatePayload]):
"""批量写入数据库"""
"""批量写入数据库(单事务、多值 UPSERT"""
if global_config is None:
raise RuntimeError("Global config is not initialized")
if not payloads:
return
# Pre-assemble row data, ensuring every row contains stream_id
rows: list[dict[str, Any]] = []
for p in payloads:
row = {"stream_id": p.stream_id}
row.update(p.update_data)
rows.append(row)
async with get_db_session() as session:
for payload in payloads:
stream_id = payload.stream_id
update_data = payload.update_data
# Choose the insert/update strategy based on the database type
if global_config.database.database_type == "sqlite":
from sqlalchemy.dialects.sqlite import insert as sqlite_insert
stmt = sqlite_insert(ChatStreams).values(stream_id=stream_id, **update_data)
stmt = stmt.on_conflict_do_update(index_elements=["stream_id"], set_=update_data)
elif global_config.database.database_type == "postgresql":
from sqlalchemy.dialects.postgresql import insert as pg_insert
stmt = pg_insert(ChatStreams).values(stream_id=stream_id, **update_data)
stmt = stmt.on_conflict_do_update(
index_elements=[ChatStreams.stream_id],
set_=update_data
)
else:
# Default to SQLite syntax
from sqlalchemy.dialects.sqlite import insert as sqlite_insert
stmt = sqlite_insert(ChatStreams).values(stream_id=stream_id, **update_data)
stmt = stmt.on_conflict_do_update(index_elements=["stream_id"], set_=update_data)
# Commit in a single transaction, significantly reducing I/O
if global_config.database.database_type == "postgresql":
from sqlalchemy.dialects.postgresql import insert as pg_insert
stmt = pg_insert(ChatStreams).values(rows)
stmt = stmt.on_conflict_do_update(
index_elements=[ChatStreams.stream_id],
set_={k: getattr(stmt.excluded, k) for k in rows[0].keys() if k != "stream_id"}
)
await session.execute(stmt)
await session.commit()
else:
# Default: sqlite
from sqlalchemy.dialects.sqlite import insert as sqlite_insert
stmt = sqlite_insert(ChatStreams).values(rows)
stmt = stmt.on_conflict_do_update(
index_elements=["stream_id"],
set_={k: getattr(stmt.excluded, k) for k in rows[0].keys() if k != "stream_id"}
)
await session.execute(stmt)
await session.commit()
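For reference, a minimal sketch (SQLite dialect; not the project's exact code) of the single-statement multi-row UPSERT pattern used above. It assumes every row dict carries the same keys, since the `set_` mapping is derived from `rows[0]`; heterogeneous rows would need grouping first:

```python
from sqlalchemy.dialects.sqlite import insert as sqlite_insert

def build_upsert(table, rows: list[dict]):
    """One INSERT ... ON CONFLICT DO UPDATE covering all rows."""
    stmt = sqlite_insert(table).values(rows)
    # EXCLUDED refers to the incoming row that hit the unique constraint,
    # so each conflicting row is updated with its own new values.
    return stmt.on_conflict_do_update(
        index_elements=["stream_id"],
        set_={k: getattr(stmt.excluded, k) for k in rows[0] if k != "stream_id"},
    )
```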
async def _direct_write(self, stream_id: str, update_data: dict[str, Any]):
"""直接写入数据库(降级方案)"""
if global_config is None:

View File

@@ -55,7 +55,7 @@ async def conversation_loop(
stream_id: str,
get_context_func: Callable[[str], Awaitable["StreamContext | None"]],
calculate_interval_func: Callable[[str, bool], Awaitable[float]],
flush_cache_func: Callable[[str], Awaitable[None]],
flush_cache_func: Callable[[str], Awaitable[list[Any]]],
check_force_dispatch_func: Callable[["StreamContext", int], bool],
is_running_func: Callable[[], bool],
) -> AsyncIterator[ConversationTick]:
@@ -121,7 +121,7 @@ async def conversation_loop(
except asyncio.CancelledError:
logger.info(f" [生成器] stream={stream_id[:8]}, 被取消")
break
except Exception as e:
except Exception as e: # noqa: BLE001
logger.error(f" [生成器] stream={stream_id[:8]}, 出错: {e}")
await asyncio.sleep(5.0)
@@ -151,10 +151,10 @@ async def run_chat_stream(
# Create the tick generator
tick_generator = conversation_loop(
stream_id=stream_id,
get_context_func=manager._get_stream_context,
calculate_interval_func=manager._calculate_interval,
flush_cache_func=manager._flush_cached_messages_to_unread,
check_force_dispatch_func=manager._needs_force_dispatch_for_context,
get_context_func=manager._get_stream_context, # noqa: SLF001
calculate_interval_func=manager._calculate_interval, # noqa: SLF001
flush_cache_func=manager._flush_cached_messages_to_unread, # noqa: SLF001
check_force_dispatch_func=manager._needs_force_dispatch_for_context, # noqa: SLF001
is_running_func=lambda: manager.is_running,
)
@@ -162,13 +162,13 @@ async def run_chat_stream(
async for tick in tick_generator:
try:
# Fetch the stream context
context = await manager._get_stream_context(stream_id)
context = await manager._get_stream_context(stream_id) # noqa: SLF001
if not context:
continue
# Concurrency guard: check whether processing is already in progress
if context.is_chatter_processing:
if manager._recover_stale_chatter_state(stream_id, context):
if manager._recover_stale_chatter_state(stream_id, context): # noqa: SLF001
logger.warning(f" [驱动器] stream={stream_id[:8]}, 处理标志残留已修复")
else:
logger.debug(f" [驱动器] stream={stream_id[:8]}, Chatter正在处理跳过此Tick")
@@ -182,17 +182,18 @@ async def run_chat_stream(
# Update the energy value
try:
await manager._update_stream_energy(stream_id, context)
await manager._update_stream_energy(stream_id, context) # noqa: SLF001
except Exception as e:
logger.debug(f"更新能量失败: {e}")
# 处理消息
assert global_config is not None
try:
success = await asyncio.wait_for(
manager._process_stream_messages(stream_id, context),
global_config.chat.thinking_timeout
)
async with manager._processing_semaphore:
success = await asyncio.wait_for(
manager._process_stream_messages(stream_id, context), # noqa: SLF001
global_config.chat.thinking_timeout,
)
except asyncio.TimeoutError:
logger.warning(f" [驱动器] stream={stream_id[:8]}, Tick#{tick.tick_count}, 处理超时")
success = False
@@ -208,7 +209,7 @@ async def run_chat_stream(
except asyncio.CancelledError:
raise
except Exception as e:
except Exception as e: # noqa: BLE001
logger.error(f" [驱动器] stream={stream_id[:8]}, 处理Tick时出错: {e}")
manager.stats["total_failures"] += 1
@@ -221,7 +222,7 @@ async def run_chat_stream(
if context and context.stream_loop_task:
context.stream_loop_task = None
logger.debug(f" [驱动器] stream={stream_id[:8]}, 清理任务记录")
except Exception as e:
except Exception as e: # noqa: BLE001
logger.debug(f"清理任务记录失败: {e}")
@@ -268,6 +269,9 @@ class StreamLoopManager:
# Stream start locks: prevent concurrently launching multiple tasks for the same stream
self._stream_start_locks: dict[str, asyncio.Lock] = {}
# Concurrency control: cap the number of Chatter processing tasks running at once
self._processing_semaphore = asyncio.Semaphore(self.max_concurrent_streams)
logger.info(f"Stream loop manager initialized (max concurrent streams: {self.max_concurrent_streams})")
# ========================================================================
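The semaphore added above caps how many `_process_stream_messages` calls run at once. A minimal sketch of the semaphore-plus-timeout pattern, with hypothetical `coros` and `timeout` standing in for the real handlers and `global_config.chat.thinking_timeout`:

```python
import asyncio

async def run_bounded(coros, limit: int, timeout: float) -> list[bool]:
    """Run coroutines with at most `limit` in flight, each under a timeout."""
    sem = asyncio.Semaphore(limit)

    async def _one(coro) -> bool:
        async with sem:  # held only while the work actually runs
            try:
                await asyncio.wait_for(coro, timeout)
                return True
            except asyncio.TimeoutError:
                return False

    return await asyncio.gather(*[_one(c) for c in coros])
```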

View File

@@ -104,9 +104,17 @@ class MessageManager:
if not chat_stream:
logger.warning(f"MessageManager.add_message: 聊天流 {stream_id} 不存在")
return
# Start the stream loop task (if not already started)
await stream_loop_manager.start_stream_loop(stream_id)
# Fast check: if a driver is already running, skip the redundant start and its unnecessary await
context = chat_stream.context
if not (context.stream_loop_task and not context.stream_loop_task.done()):
# Start the driver task asynchronously; avoids blocking message enqueueing under high concurrency
await stream_loop_manager.start_stream_loop(stream_id)
# Check for and handle message interruption
await self._check_and_handle_interruption(chat_stream, message)
# Enqueue the message
await chat_stream.context.add_message(message)
except Exception as e:
@@ -476,8 +484,7 @@ class MessageManager:
is_processing: whether processing is in progress
"""
try:
# Try to update the StreamContext processing state
import asyncio
# Try to update the StreamContext processing state (uses the top-level asyncio import)
async def _update_context():
try:
chat_manager = get_chat_manager()
@@ -492,7 +499,7 @@ class MessageManager:
try:
loop = asyncio.get_event_loop()
if loop.is_running():
asyncio.create_task(_update_context())
self._update_context_task = asyncio.create_task(_update_context())
else:
# Skip if the event loop is not running
logger.debug("Event loop not running; skipping StreamContext state update")
@@ -512,8 +519,7 @@ class MessageManager:
bool: whether processing is in progress
"""
try:
# Try to read the processing state from the StreamContext
import asyncio
# Try to read the processing state from the StreamContext (uses the top-level asyncio import)
async def _get_context_status():
try:
chat_manager = get_chat_manager()

View File

@@ -1,6 +1,8 @@
import asyncio
import hashlib
import time
from functools import lru_cache
from typing import ClassVar
from rich.traceback import install
from sqlalchemy.dialects.postgresql import insert as pg_insert
@@ -25,6 +27,9 @@ _background_tasks: set[asyncio.Task] = set()
class ChatStream:
"""聊天流对象,存储一个完整的聊天上下文"""
# 类级别的缓存,用于存储计算过的兴趣值(避免重复计算)
_interest_cache: ClassVar[dict] = {}
def __init__(
self,
stream_id: str,
@@ -159,7 +164,19 @@ class ChatStream:
return None
async def _calculate_message_interest(self, db_message):
"""计算消息兴趣值并更新消息对象"""
"""计算消息兴趣值并更新消息对象 - 优化版本使用缓存"""
# 使用消息ID作为缓存键
cache_key = getattr(db_message, "message_id", None)
# 检查缓存
if cache_key and cache_key in ChatStream._interest_cache:
cached_result = ChatStream._interest_cache[cache_key]
db_message.interest_value = cached_result["interest_value"]
db_message.should_reply = cached_result["should_reply"]
db_message.should_act = cached_result["should_act"]
logger.debug(f"消息 {cache_key} 使用缓存的兴趣值: {cached_result['interest_value']:.3f}")
return
try:
from src.chat.interest_system.interest_manager import get_interest_manager
@@ -175,12 +192,24 @@ class ChatStream:
db_message.should_reply = result.should_reply
db_message.should_act = result.should_act
# Cache the result
if cache_key:
ChatStream._interest_cache[cache_key] = {
"interest_value": result.interest_value,
"should_reply": result.should_reply,
"should_act": result.should_act,
}
# Cap the cache size to prevent unbounded memory growth (keep the most recent 5000 entries)
if len(ChatStream._interest_cache) > 5000:
oldest_key = next(iter(ChatStream._interest_cache))
del ChatStream._interest_cache[oldest_key]
logger.debug(
f"Message {db_message.message_id} interest value updated: {result.interest_value:.3f}, "
f"Message {cache_key} interest value updated: {result.interest_value:.3f}, "
f"should_reply: {result.should_reply}, should_act: {result.should_act}"
)
else:
logger.warning(f"Message {db_message.message_id} interest computation failed: {result.error_message}")
logger.warning(f"Message {cache_key} interest computation failed: {result.error_message}")
# Use default values
db_message.interest_value = 0.3
db_message.should_reply = False
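The eviction above leans on dict insertion order: `next(iter(...))` yields the oldest inserted key, giving FIFO (not true LRU) eviction. A self-contained sketch of the same bounded-cache idea:

```python
class BoundedCache:
    """FIFO-bounded dict cache; Python dicts preserve insertion order."""

    def __init__(self, max_size: int = 5000):
        self._data: dict = {}
        self._max_size = max_size

    def put(self, key, value) -> None:
        self._data[key] = value
        if len(self._data) > self._max_size:
            # next(iter(...)) is the oldest inserted key -> FIFO eviction.
            del self._data[next(iter(self._data))]

    def get(self, key, default=None):
        return self._data.get(key, default)
```

An `OrderedDict` with `move_to_end()` on reads would upgrade this to true LRU at a small cost per hit.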
@@ -362,21 +391,24 @@ class ChatManager:
self.last_messages[stream_id] = message
# logger.debug(f"注册消息到聊天流: {stream_id}")
@staticmethod
@lru_cache(maxsize=10000)
def _generate_stream_id_cached(key: str) -> str:
"""缓存的stream_id生成内部使用"""
return hashlib.sha256(key.encode()).hexdigest()
@staticmethod
def _generate_stream_id(platform: str, user_info: DatabaseUserInfo | None, group_info: DatabaseGroupInfo | None = None) -> str:
"""生成聊天流唯一ID"""
"""生成聊天流唯一ID - 使用缓存优化"""
if not user_info and not group_info:
raise ValueError("用户信息或群组信息必须提供")
if group_info:
# Combine the key components
components = [platform, str(group_info.group_id)]
key = f"{platform}_{group_info.group_id}"
else:
components = [platform, str(user_info.user_id), "private"] # type: ignore
key = f"{platform}_{user_info.user_id}_private" # type: ignore
# Use SHA-256 to generate the unique ID
key = "_".join(components)
return hashlib.sha256(key.encode()).hexdigest()
return ChatManager._generate_stream_id_cached(key)
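`lru_cache(maxsize=10000)` memoizes the digest per key string, which is safe because the output depends only on the input. A quick check of the idiom (hypothetical standalone function):

```python
import hashlib
from functools import lru_cache

@lru_cache(maxsize=10000)
def stream_id_for(key: str) -> str:
    """Deterministic ID; safe to memoize because the output depends only on `key`."""
    return hashlib.sha256(key.encode()).hexdigest()

assert stream_id_for("qq_12345") == stream_id_for("qq_12345")  # second call is a cache hit
print(stream_id_for.cache_info())  # e.g. CacheInfo(hits=1, misses=1, ...)
```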
@staticmethod
def get_stream_id(platform: str, id: str, is_group: bool = True) -> str:
@@ -503,12 +535,19 @@ class ChatManager:
return stream
async def get_stream(self, stream_id: str) -> ChatStream | None:
"""通过stream_id获取聊天流"""
"""通过stream_id获取聊天流 - 优化版本"""
stream = self.streams.get(stream_id)
if not stream:
return None
if stream_id in self.last_messages and isinstance(self.last_messages[stream_id], DatabaseMessages):
await stream.set_context(self.last_messages[stream_id])
# Only set the context when necessary (avoids redundant calls)
if stream_id not in self.last_messages:
return stream
last_message = self.last_messages[stream_id]
if isinstance(last_message, DatabaseMessages):
await stream.set_context(last_message)
return stream
def get_stream_by_info(
@@ -536,30 +575,30 @@ class ChatManager:
Returns:
dict[str, ChatStream]: dictionary of all chat streams, keyed by stream_id with ChatStream values
"""
return self.streams.copy() # Return a copy to prevent external modification
return self.streams
@staticmethod
def _prepare_stream_data(stream_data_dict: dict) -> dict:
"""准备聊天流保存数据"""
user_info_d = stream_data_dict.get("user_info")
group_info_d = stream_data_dict.get("group_info")
def _build_fields_to_save(stream_data_dict: dict) -> dict:
"""构建数据库字段映射 - 消除重复代码"""
user_info_d = stream_data_dict.get("user_info") or {}
group_info_d = stream_data_dict.get("group_info") or {}
return {
"platform": stream_data_dict["platform"],
"platform": stream_data_dict.get("platform", "") or "",
"create_time": stream_data_dict["create_time"],
"last_active_time": stream_data_dict["last_active_time"],
"user_platform": user_info_d["platform"] if user_info_d else "",
"user_id": user_info_d["user_id"] if user_info_d else "",
"user_nickname": user_info_d["user_nickname"] if user_info_d else "",
"user_cardname": user_info_d.get("user_cardname", "") if user_info_d else None,
"group_platform": group_info_d["platform"] if group_info_d else "",
"group_id": group_info_d["group_id"] if group_info_d else "",
"group_name": group_info_d["group_name"] if group_info_d else "",
"user_platform": user_info_d.get("platform", ""),
"user_id": user_info_d.get("user_id", ""),
"user_nickname": user_info_d.get("user_nickname", ""),
"user_cardname": user_info_d.get("user_cardname"),
"group_platform": group_info_d.get("platform", ""),
"group_id": group_info_d.get("group_id", ""),
"group_name": group_info_d.get("group_name", ""),
"energy_value": stream_data_dict.get("energy_value", 5.0),
"sleep_pressure": stream_data_dict.get("sleep_pressure", 0.0),
"focus_energy": stream_data_dict.get("focus_energy", 0.5),
# New dynamic interest system fields
"base_interest_energy": stream_data_dict.get("base_interest_energy", 0.5),
"message_interest_total": stream_data_dict.get("message_interest_total", 0.0),
"message_count": stream_data_dict.get("message_count", 0),
@@ -570,6 +609,11 @@ class ChatManager:
"interruption_count": stream_data_dict.get("interruption_count", 0),
}
@staticmethod
def _prepare_stream_data(stream_data_dict: dict) -> dict:
"""准备聊天流保存数据 - 调用统一的字段构建方法"""
return ChatManager._build_fields_to_save(stream_data_dict)
@staticmethod
async def _save_stream(stream: ChatStream):
"""保存聊天流到数据库 - 优化版本使用异步批量写入"""
@@ -624,38 +668,12 @@ class ChatManager:
raise RuntimeError("Global config is not initialized")
async with get_db_session() as session:
user_info_d = s_data_dict.get("user_info")
group_info_d = s_data_dict.get("group_info")
fields_to_save = {
"platform": s_data_dict.get("platform", "") or "",
"create_time": s_data_dict["create_time"],
"last_active_time": s_data_dict["last_active_time"],
"user_platform": user_info_d["platform"] if user_info_d else "",
"user_id": user_info_d["user_id"] if user_info_d else "",
"user_nickname": user_info_d["user_nickname"] if user_info_d else "",
"user_cardname": user_info_d.get("user_cardname", "") if user_info_d else None,
"group_platform": group_info_d.get("platform", "") or "" if group_info_d else "",
"group_id": group_info_d["group_id"] if group_info_d else "",
"group_name": group_info_d["group_name"] if group_info_d else "",
"energy_value": s_data_dict.get("energy_value", 5.0),
"sleep_pressure": s_data_dict.get("sleep_pressure", 0.0),
"focus_energy": s_data_dict.get("focus_energy", 0.5),
# New dynamic interest system fields
"base_interest_energy": s_data_dict.get("base_interest_energy", 0.5),
"message_interest_total": s_data_dict.get("message_interest_total", 0.0),
"message_count": s_data_dict.get("message_count", 0),
"action_count": s_data_dict.get("action_count", 0),
"reply_count": s_data_dict.get("reply_count", 0),
"last_interaction_time": s_data_dict.get("last_interaction_time", time.time()),
"consecutive_no_reply": s_data_dict.get("consecutive_no_reply", 0),
"interruption_count": s_data_dict.get("interruption_count", 0),
}
fields_to_save = ChatManager._build_fields_to_save(s_data_dict)
if global_config.database.database_type == "sqlite":
stmt = sqlite_insert(ChatStreams).values(stream_id=s_data_dict["stream_id"], **fields_to_save)
stmt = stmt.on_conflict_do_update(index_elements=["stream_id"], set_=fields_to_save)
elif global_config.database.database_type == "postgresql":
stmt = pg_insert(ChatStreams).values(stream_id=s_data_dict["stream_id"], **fields_to_save)
# PostgreSQL requires the constraint parameter or correct index_elements
stmt = stmt.on_conflict_do_update(
index_elements=[ChatStreams.stream_id],
set_=fields_to_save
@@ -678,14 +696,16 @@ class ChatManager:
await self._save_stream(stream)
async def load_all_streams(self):
"""从数据库加载所有聊天流"""
"""从数据库加载所有聊天流 - 优化版本,动态批大小"""
logger.debug("正在从数据库加载所有聊天流")
async def _db_load_all_streams_async():
loaded_streams_data = []
# Use a CRUD bulk query
# Use a CRUD bulk query - drop the hard-coded limit=100000 in favor of smarter pagination
crud = CRUDBase(ChatStreams)
all_streams = await crud.get_multi(limit=100000) # fetch all chat streams
# Fetch the total count first to tune the batch size
all_streams = await crud.get_multi(limit=None) # fetch all chat streams
for model_instance in all_streams:
user_info_data = {
@@ -733,8 +753,6 @@ class ChatManager:
stream.saved = True
self.streams[stream.stream_id] = stream
# Do not set the context during async loading, to avoid tangled dependencies
# if stream.stream_id in self.last_messages:
# await stream.set_context(self.last_messages[stream.stream_id])
except Exception as e:
logger.error(f"从数据库加载所有聊天流失败 (SQLAlchemy): {e}")

View File

@@ -30,7 +30,7 @@ from __future__ import annotations
import os
import re
import traceback
from typing import TYPE_CHECKING, Any, cast
from typing import TYPE_CHECKING, Any, ClassVar, cast
from mofox_wire import MessageEnvelope, MessageRuntime
@@ -53,6 +53,22 @@ logger = get_logger("message_handler")
# Project root directory
PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))
# Cache of precompiled regular expressions (avoids recompilation)
_compiled_regex_cache: dict[str, re.Pattern] = {}
# Hard-coded filter keywords (cached in a module-level constant instead of rebuilding a list each call)
_MEDIA_FAILURE_KEYWORDS = frozenset(["[表情包(描述生成失败)]", "[图片(描述生成失败)]"])
def _get_compiled_pattern(pattern: str) -> re.Pattern | None:
"""获取编译的正则表达式,使用缓存避免重复编译"""
if pattern not in _compiled_regex_cache:
try:
_compiled_regex_cache[pattern] = re.compile(pattern)
except re.error as e:
logger.warning(f"正则表达式编译失败: {pattern}, 错误: {e}")
return None
return _compiled_regex_cache.get(pattern)
def _check_ban_words(text: str, chat: "ChatStream", userinfo) -> bool:
"""检查消息是否包含过滤词"""
if global_config is None:
@@ -65,11 +81,13 @@ def _check_ban_words(text: str, chat: "ChatStream", userinfo) -> bool:
return True
return False
def _check_ban_regex(text: str, chat: "ChatStream", userinfo) -> bool:
"""检查消息是否匹配过滤正则表达式"""
"""检查消息是否匹配过滤正则表达式 - 优化版本使用预编译缓存"""
if global_config is None:
return False
for pattern in global_config.message_receive.ban_msgs_regex:
if re.search(pattern, text):
compiled_pattern = _get_compiled_pattern(pattern)
if compiled_pattern and compiled_pattern.search(text):
chat_name = chat.group_info.group_name if chat.group_info else "private chat"
logger.info(f"[{chat_name}]{userinfo.user_nickname}:{text}")
logger.info(f"[Regex filter] message matched {pattern}, filtered")
@@ -97,6 +115,10 @@ class MessageHandler:
4. Normal message handling: fire events, store, update mood
"""
# Class-level caches: command lookup results (reduces repeated queries)
_plus_command_cache: ClassVar[dict[str, Any]] = {}
_base_command_cache: ClassVar[dict[str, Any]] = {}
def __init__(self):
self._started = False
self._message_manager_started = False
@@ -108,6 +130,36 @@ class MessageHandler:
"""设置 CoreSinkManager 引用"""
self._core_sink_manager = manager
async def _get_or_create_chat_stream(
self, platform: str, user_info: dict | None, group_info: dict | None
) -> "ChatStream":
"""获取或创建聊天流 - 统一方法"""
from src.chat.message_receive.chat_stream import get_chat_manager
return await get_chat_manager().get_or_create_stream(
platform=platform,
user_info=DatabaseUserInfo.from_dict(cast(dict[str, Any], user_info)) if user_info else None,
group_info=DatabaseGroupInfo.from_dict(cast(dict[str, Any], group_info)) if group_info else None,
)
async def _process_message_to_database(
self, envelope: MessageEnvelope, chat: "ChatStream"
) -> DatabaseMessages:
"""将消息信封转换为 DatabaseMessages - 统一方法"""
from src.chat.message_receive.message_processor import process_message_from_dict
message = await process_message_from_dict(
message_dict=envelope,
stream_id=chat.stream_id,
platform=chat.platform
)
# Fill in chat stream timing info
message.chat_info.create_time = chat.create_time
message.chat_info.last_active_time = chat.last_active_time
return message
def register_handlers(self, runtime: MessageRuntime) -> None:
"""
Register message handlers and hooks with the MessageRuntime
@@ -279,25 +331,10 @@ class MessageHandler:
# Get or create the chat stream
platform = message_info.get("platform", "unknown")
from src.chat.message_receive.chat_stream import get_chat_manager
chat = await get_chat_manager().get_or_create_stream(
platform=platform,
user_info=DatabaseUserInfo.from_dict(cast(dict[str, Any], user_info)) if user_info else None, # type: ignore
group_info=DatabaseGroupInfo.from_dict(cast(dict[str, Any], group_info)) if group_info else None,
)
chat = await self._get_or_create_chat_stream(platform, user_info, group_info)
# Convert the message envelope into DatabaseMessages
from src.chat.message_receive.message_processor import process_message_from_dict
message = await process_message_from_dict(
message_dict=envelope,
stream_id=chat.stream_id,
platform=chat.platform
)
# Fill in chat stream timing info
message.chat_info.create_time = chat.create_time
message.chat_info.last_active_time = chat.last_active_time
message = await self._process_message_to_database(envelope, chat)
# Mark as a notice message
message.is_notify = True
@@ -337,8 +374,7 @@ class MessageHandler:
except Exception as e:
logger.error(f"处理 Notice 消息时出错: {e}")
import traceback
traceback.print_exc()
logger.error(traceback.format_exc())
return None
async def _add_notice_to_manager(
@@ -429,25 +465,10 @@ class MessageHandler:
# Get or create the chat stream
platform = message_info.get("platform", "unknown")
from src.chat.message_receive.chat_stream import get_chat_manager
chat = await get_chat_manager().get_or_create_stream(
platform=platform,
user_info=DatabaseUserInfo.from_dict(cast(dict[str, Any], user_info)) if user_info else None, # type: ignore
group_info=DatabaseGroupInfo.from_dict(cast(dict[str, Any], group_info)) if group_info else None,
)
chat = await self._get_or_create_chat_stream(platform, user_info, group_info)
# Convert the message envelope into DatabaseMessages
from src.chat.message_receive.message_processor import process_message_from_dict
message = await process_message_from_dict(
message_dict=envelope,
stream_id=chat.stream_id,
platform=chat.platform
)
# Fill in chat stream timing info
message.chat_info.create_time = chat.create_time
message.chat_info.last_active_time = chat.last_active_time
message = await self._process_message_to_database(envelope, chat)
# Register the message with the chat manager
from src.chat.message_receive.chat_stream import get_chat_manager
@@ -462,9 +483,8 @@ class MessageHandler:
logger.info(f"[{chat_name}]{user_nickname}:{message.processed_plain_text}\u001b[0m")
# Hard-coded filtering
failure_keywords = ["[表情包(描述生成失败)]", "[图片(描述生成失败)]"]
processed_text = message.processed_plain_text or ""
if any(keyword in processed_text for keyword in failure_keywords):
if any(keyword in processed_text for keyword in _MEDIA_FAILURE_KEYWORDS):
logger.info(f"[硬编码过滤] 检测到媒体内容处理失败({processed_text}),消息被静默处理。")
return None

View File

@@ -3,6 +3,7 @@
Builds message data in mofox-wire TypedDict form, then converts it into DatabaseMessages
"""
import base64
import re
import time
from typing import Any
@@ -20,6 +21,15 @@ from src.config.config import global_config
logger = get_logger("message_processor")
# Precompiled regular expression
_AT_PATTERN = re.compile(r"^([^:]+):(.+)$")
# Constant definitions: segment type sets
RECURSIVE_SEGMENT_TYPES = frozenset(["seglist"])
MEDIA_SEGMENT_TYPES = frozenset(["image", "emoji", "voice", "video"])
METADATA_SEGMENT_TYPES = frozenset(["mention_bot", "priority_info"])
SPECIAL_SEGMENT_TYPES = frozenset(["at", "reply", "file"])
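`_AT_PATTERN` splits an at-segment payload of the form "nickname:QQ-id" on the first colon only, because `[^:]+` cannot cross the separator. A quick illustration:

```python
import re

_AT_PATTERN = re.compile(r"^([^:]+):(.+)$")

m = _AT_PATTERN.match("Alice:12345")
assert m and m.groups() == ("Alice", "12345")

# Colons inside the ID part survive, since only the first colon separates the groups.
m = _AT_PATTERN.match("Bob:12:34")
assert m and m.groups() == ("Bob", "12:34")
```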
async def process_message_from_dict(message_dict: MessageEnvelope, stream_id: str, platform: str) -> DatabaseMessages:
"""从适配器消息字典处理并生成 DatabaseMessages
@@ -101,7 +111,7 @@ async def process_message_from_dict(message_dict: MessageEnvelope, stream_id: st
mentioned_value = processing_state.get("is_mentioned")
if isinstance(mentioned_value, bool):
is_mentioned = mentioned_value
elif isinstance(mentioned_value, (int, float)):
elif isinstance(mentioned_value, int | float):
is_mentioned = mentioned_value != 0
# Build DatabaseMessages from TypedDict-style data
@@ -223,13 +233,12 @@ async def _process_single_segment(
state["is_at"] = True
# Handle at segments; the output format is "@<nickname:QQ-id>"
if isinstance(seg_data, str):
if ":" in seg_data:
# Standard format: "nickname:QQ-id"
nickname, qq_id = seg_data.split(":", 1)
match = _AT_PATTERN.match(seg_data)
if match:
nickname, qq_id = match.groups()
return f"@<{nickname}:{qq_id}>"
else:
logger.warning(f"[at处理] 无法解析格式: '{seg_data}'")
return f"@{seg_data}"
logger.warning(f"[at处理] 无法解析格式: '{seg_data}'")
return f"@{seg_data}"
logger.warning(f"[at处理] 数据类型异常: {type(seg_data)}")
return f"@{seg_data}" if isinstance(seg_data, str) else "@未知用户"
@@ -272,7 +281,7 @@ async def _process_single_segment(
return "[发了一段语音,网卡了加载不出来]"
elif seg_type == "mention_bot":
if isinstance(seg_data, (int, float)):
if isinstance(seg_data, int | float):
state["is_mentioned"] = float(seg_data)
return ""
@@ -368,19 +377,18 @@ def _prepare_additional_config(
str | None: additional_config as a JSON string, or None if empty
"""
try:
additional_config_data = {}
# First fetch the additional_config passed by the adapter
additional_config_raw = message_info.get("additional_config")
if additional_config_raw:
if isinstance(additional_config_raw, dict):
additional_config_data = additional_config_raw.copy()
elif isinstance(additional_config_raw, str):
try:
additional_config_data = orjson.loads(additional_config_raw)
except Exception as e:
logger.warning(f"无法解析 additional_config JSON: {e}")
additional_config_data = {}
if isinstance(additional_config_raw, dict):
additional_config_data = additional_config_raw.copy()
elif isinstance(additional_config_raw, str):
try:
additional_config_data = orjson.loads(additional_config_raw)
except Exception as e:
logger.warning(f"无法解析 additional_config JSON: {e}")
additional_config_data = {}
else:
additional_config_data = {}
# Add notice-related flags
if is_notify:

View File

@@ -1,4 +1,5 @@
import asyncio
import collections
import re
import time
import traceback
@@ -19,6 +20,16 @@ if TYPE_CHECKING:
logger = get_logger("message_storage")
# Precompiled regular expressions (avoid recompilation)
_COMPILED_FILTER_PATTERN = re.compile(
r"<MainRule>.*?</MainRule>|<schedule>.*?</schedule>|<UserMessage>.*?</UserMessage>",
re.DOTALL
)
_COMPILED_IMAGE_PATTERN = re.compile(r"\[图片:([^\]]+)\]")
# Global regex cache
_regex_cache: dict[str, re.Pattern] = {}
class MessageStorageBatcher:
"""
@@ -116,25 +127,28 @@ class MessageStorageBatcher:
async def flush(self, force: bool = False):
"""执行批量写入, 支持强制落库和延迟提交策略。"""
async with self._flush_barrier:
# Atomically swap out the message queue so the lock is held only briefly
async with self._lock:
messages_to_store = list(self.pending_messages)
self.pending_messages.clear()
if not self.pending_messages:
return
messages_to_store = self.pending_messages
self.pending_messages = collections.deque(maxlen=self.batch_size)
if messages_to_store:
prepared_messages: list[dict[str, Any]] = []
for msg_data in messages_to_store:
try:
message_dict = await self._prepare_message_dict(
msg_data["message"],
msg_data["chat_stream"],
)
if message_dict:
prepared_messages.append(message_dict)
except Exception as e:
logger.error(f"准备消息数据失败: {e}")
# 处理消息,这部分不在锁内执行,提高并发性
prepared_messages: list[dict[str, Any]] = []
for msg_data in messages_to_store:
try:
message_dict = await self._prepare_message_dict(
msg_data["message"],
msg_data["chat_stream"],
)
if message_dict:
prepared_messages.append(message_dict)
except Exception as e:
logger.error(f"准备消息数据失败: {e}")
if prepared_messages:
self._prepared_buffer.extend(prepared_messages)
if prepared_messages:
self._prepared_buffer.extend(prepared_messages)
await self._maybe_commit_buffer(force=force)
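The reworked `flush` holds the lock only for an O(1) pointer swap and prepares messages outside it. A minimal sketch of the swap-under-lock pattern, assuming only asyncio and collections:

```python
import asyncio
import collections

class Batcher:
    def __init__(self, batch_size: int = 100):
        self._lock = asyncio.Lock()
        self._pending: collections.deque = collections.deque(maxlen=batch_size)
        self._batch_size = batch_size

    async def flush(self) -> list:
        # Critical section: just a pointer swap, constant time regardless of queue length.
        async with self._lock:
            if not self._pending:
                return []
            drained, self._pending = self._pending, collections.deque(maxlen=self._batch_size)
        # Expensive per-item work happens outside the lock.
        return list(drained)
```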
@@ -200,102 +214,66 @@ class MessageStorageBatcher:
return message_dict
async def _prepare_message_object(self, message, chat_stream):
"""准备消息对象(从原 store_message 逻辑提取)"""
"""准备消息对象(从原 store_message 逻辑提取) - 优化版本"""
try:
pattern = r"<MainRule>.*?</MainRule>|<schedule>.*?</schedule>|<UserMessage>.*?</UserMessage>"
if not isinstance(message, DatabaseMessages):
logger.error("MessageStorageBatcher expects DatabaseMessages instances")
return None
# Optimization: use the precompiled regex
processed_plain_text = message.processed_plain_text or ""
if processed_plain_text:
processed_plain_text = await MessageStorage.replace_image_descriptions(processed_plain_text)
filtered_processed_plain_text = re.sub(
pattern, "", processed_plain_text or "", flags=re.DOTALL
)
filtered_processed_plain_text = _COMPILED_FILTER_PATTERN.sub("", processed_plain_text)
display_message = message.display_message or message.processed_plain_text or ""
filtered_display_message = re.sub(pattern, "", display_message, flags=re.DOTALL)
filtered_display_message = _COMPILED_FILTER_PATTERN.sub("", display_message)
msg_id = message.message_id
msg_time = message.time
chat_id = message.chat_id
reply_to = message.reply_to or ""
is_mentioned = message.is_mentioned
interest_value = message.interest_value or 0.0
priority_mode = message.priority_mode
priority_info_json = message.priority_info
is_emoji = message.is_emoji or False
is_picid = message.is_picid or False
is_notify = message.is_notify or False
is_command = message.is_command or False
is_public_notice = message.is_public_notice or False
notice_type = message.notice_type
actions = orjson.dumps(message.actions).decode("utf-8") if message.actions else None
should_reply = message.should_reply
should_act = message.should_act
additional_config = message.additional_config
key_words = MessageStorage._serialize_keywords(message.key_words)
key_words_lite = MessageStorage._serialize_keywords(message.key_words_lite)
memorized_times = getattr(message, "memorized_times", 0)
user_platform = message.user_info.platform if message.user_info else ""
user_id = message.user_info.user_id if message.user_info else ""
user_nickname = message.user_info.user_nickname if message.user_info else ""
user_cardname = message.user_info.user_cardname if message.user_info else None
chat_info_stream_id = message.chat_info.stream_id if message.chat_info else ""
chat_info_platform = message.chat_info.platform if message.chat_info else ""
chat_info_create_time = message.chat_info.create_time if message.chat_info else 0.0
chat_info_last_active_time = message.chat_info.last_active_time if message.chat_info else 0.0
chat_info_user_platform = message.chat_info.user_info.platform if message.chat_info and message.chat_info.user_info else ""
chat_info_user_id = message.chat_info.user_info.user_id if message.chat_info and message.chat_info.user_info else ""
chat_info_user_nickname = message.chat_info.user_info.user_nickname if message.chat_info and message.chat_info.user_info else ""
chat_info_user_cardname = message.chat_info.user_info.user_cardname if message.chat_info and message.chat_info.user_info else None
chat_info_group_platform = message.group_info.platform if message.group_info else None
chat_info_group_id = message.group_info.group_id if message.group_info else None
chat_info_group_name = message.group_info.group_name if message.group_info else None
# Optimization: build the dict in one pass, avoiding repeated conditionals
user_info = message.user_info or {}
chat_info = message.chat_info or {}
chat_info_user = chat_info.user_info or {} if chat_info else {}
group_info = message.group_info or {}
return Messages(
message_id=msg_id,
time=msg_time,
chat_id=chat_id,
reply_to=reply_to,
is_mentioned=is_mentioned,
chat_info_stream_id=chat_info_stream_id,
chat_info_platform=chat_info_platform,
chat_info_user_platform=chat_info_user_platform,
chat_info_user_id=chat_info_user_id,
chat_info_user_nickname=chat_info_user_nickname,
chat_info_user_cardname=chat_info_user_cardname,
chat_info_group_platform=chat_info_group_platform,
chat_info_group_id=chat_info_group_id,
chat_info_group_name=chat_info_group_name,
chat_info_create_time=chat_info_create_time,
chat_info_last_active_time=chat_info_last_active_time,
user_platform=user_platform,
user_id=user_id,
user_nickname=user_nickname,
user_cardname=user_cardname,
message_id=message.message_id,
time=message.time,
chat_id=message.chat_id,
reply_to=message.reply_to or "",
is_mentioned=message.is_mentioned,
chat_info_stream_id=chat_info.stream_id if chat_info else "",
chat_info_platform=chat_info.platform if chat_info else "",
chat_info_user_platform=chat_info_user.platform if chat_info_user else "",
chat_info_user_id=chat_info_user.user_id if chat_info_user else "",
chat_info_user_nickname=chat_info_user.user_nickname if chat_info_user else "",
chat_info_user_cardname=chat_info_user.user_cardname if chat_info_user else None,
chat_info_group_platform=group_info.platform if group_info else None,
chat_info_group_id=group_info.group_id if group_info else None,
chat_info_group_name=group_info.group_name if group_info else None,
chat_info_create_time=chat_info.create_time if chat_info else 0.0,
chat_info_last_active_time=chat_info.last_active_time if chat_info else 0.0,
user_platform=user_info.platform if user_info else "",
user_id=user_info.user_id if user_info else "",
user_nickname=user_info.user_nickname if user_info else "",
user_cardname=user_info.user_cardname if user_info else None,
processed_plain_text=filtered_processed_plain_text,
display_message=filtered_display_message,
memorized_times=memorized_times,
interest_value=interest_value,
priority_mode=priority_mode,
priority_info=priority_info_json,
additional_config=additional_config,
is_emoji=is_emoji,
is_picid=is_picid,
is_notify=is_notify,
is_command=is_command,
is_public_notice=is_public_notice,
notice_type=notice_type,
actions=actions,
should_reply=should_reply,
should_act=should_act,
key_words=key_words,
key_words_lite=key_words_lite,
memorized_times=getattr(message, "memorized_times", 0),
interest_value=message.interest_value or 0.0,
priority_mode=message.priority_mode,
priority_info=message.priority_info,
additional_config=message.additional_config,
is_emoji=message.is_emoji or False,
is_picid=message.is_picid or False,
is_notify=message.is_notify or False,
is_command=message.is_command or False,
is_public_notice=message.is_public_notice or False,
notice_type=message.notice_type,
actions=orjson.dumps(message.actions).decode("utf-8") if message.actions else None,
should_reply=message.should_reply,
should_act=message.should_act,
key_words=MessageStorage._serialize_keywords(message.key_words),
key_words_lite=MessageStorage._serialize_keywords(message.key_words_lite),
)
except Exception as e:
@@ -474,7 +452,7 @@ class MessageStorage:
@staticmethod
async def update_message(message_data: dict, use_batch: bool = True):
"""
Update the message ID from the message dict
Update the message ID from the message dict - optimized version
Optimization: add a batching option that merges multiple updates, reducing database round-trips
@@ -491,25 +469,23 @@ class MessageStorage:
segment_type = message_segment.get("type") if isinstance(message_segment, dict) else None
segment_data = message_segment.get("data", {}) if isinstance(message_segment, dict) else {}
qq_message_id = None
# Optimization: predefine type sets to avoid repeated if-elif checks
SKIPPED_TYPES = {"adapter_response", "adapter_command"}
VALID_ID_TYPES = {"notify", "text", "reply"}
logger.debug(f"尝试更新消息ID: {mmc_message_id}, 消息段类型: {segment_type}")
# 根据消息段类型提取message_id
if segment_type == "notify":
# 检查是否是需要跳过的类型
if segment_type in SKIPPED_TYPES:
logger.debug(f"跳过消息段类型: {segment_type}")
return
# 尝试获取消息ID
qq_message_id = None
if segment_type in VALID_ID_TYPES:
qq_message_id = segment_data.get("id")
elif segment_type == "text":
qq_message_id = segment_data.get("id")
elif segment_type == "reply":
qq_message_id = segment_data.get("id")
if qq_message_id:
if segment_type == "reply" and qq_message_id:
logger.debug(f"从reply消息段获取到消息ID: {qq_message_id}")
elif segment_type == "adapter_response":
logger.debug("适配器响应消息不需要更新ID")
return
elif segment_type == "adapter_command":
logger.debug("适配器命令消息不需要更新ID")
return
else:
logger.debug(f"未知的消息段类型: {segment_type}跳过ID更新")
return
@@ -552,22 +528,20 @@ class MessageStorage:
@staticmethod
async def replace_image_descriptions(text: str) -> str:
"""异步地将文本中的所有[图片:描述]标记替换为[picid:image_id]"""
pattern = r"\[图片:([^\]]+)\]"
"""异步地将文本中的所有[图片:描述]标记替换为[picid:image_id] - 优化版本"""
# 如果没有匹配项,提前返回以提高效率
if not re.search(pattern, text):
if not _COMPILED_IMAGE_PATTERN.search(text):
return text
# re.sub不支持异步替换函数所以我们需要手动迭代和替换
new_text = []
last_end = 0
for match in re.finditer(pattern, text):
for match in _COMPILED_IMAGE_PATTERN.finditer(text):
# Append the text between the previous match and the current one
new_text.append(text[last_end:match.start()])
description = match.group(1).strip()
replacement = match.group(0) # By default, fall back to the original matched text
replacement = match.group(0) # By default, fall back to the original matched text
try:
async with get_db_session() as session:
# Query the database for the most recent image record with this description
@@ -633,19 +607,49 @@ class MessageStorage:
interest_map: dict[str, float],
reply_map: dict[str, bool] | None = None,
) -> None:
"""批量更新消息的兴趣度与回复标记"""
"""批量更新消息的兴趣度与回复标记 - 优化版本"""
if not interest_map:
return
try:
async with get_db_session() as session:
for message_id, interest_value in interest_map.items():
values = {"interest_value": interest_value}
if reply_map and message_id in reply_map:
values["should_reply"] = reply_map[message_id]
# Note: in SQLAlchemy 2.0, ORM update + executemany takes the "Bulk UPDATE by
# Primary Key" path, which requires every parameter row to include the primary key (Messages.id).
# We update by message_id here, so use the Core Table + bindparam instead.
from sqlalchemy import bindparam, update
stmt = update(Messages).where(Messages.message_id == message_id).values(**values)
await session.execute(stmt)
messages_table = Messages.__table__
interest_mappings: list[dict[str, Any]] = [
{"b_message_id": message_id, "b_interest_value": interest_value}
for message_id, interest_value in interest_map.items()
]
if interest_mappings:
stmt_interest = (
update(messages_table)
.where(messages_table.c.message_id == bindparam("b_message_id"))
.values(interest_value=bindparam("b_interest_value"))
)
await session.execute(stmt_interest, interest_mappings)
if reply_map:
reply_mappings: list[dict[str, Any]] = [
{"b_message_id": message_id, "b_should_reply": should_reply}
for message_id, should_reply in reply_map.items()
if message_id in interest_map
]
if reply_mappings and len(reply_mappings) != len(reply_map):
logger.debug(
f"Bulk should_reply update filtered out {len(reply_map) - len(reply_mappings)} records not in the interest-update set"
)
)
if reply_mappings:
stmt_reply = (
update(messages_table)
.where(messages_table.c.message_id == bindparam("b_message_id"))
.values(should_reply=bindparam("b_should_reply"))
)
await session.execute(stmt_reply, reply_mappings)
await session.commit()
logger.debug(f"批量更新兴趣度 {len(interest_map)} 条记录")

View File

@@ -1,9 +1,10 @@
from threading import Lock
from typing import Any, Literal
from pydantic import Field
from pydantic import Field, PrivateAttr
from src.config.config_base import ValidatedConfigBase
from src.config.official_configs import InnerConfig
class APIProvider(ValidatedConfigBase):
@@ -21,6 +22,9 @@ class APIProvider(ValidatedConfigBase):
)
retry_interval: int = Field(default=10, ge=0, description="Retry interval: time to wait between retries after a failed API call (seconds)")
_api_key_lock: Lock = PrivateAttr(default_factory=Lock)
_api_key_index: int = PrivateAttr(default=0)
@classmethod
def validate_base_url(cls, v):
"""验证base_url确保URL格式正确"""
@@ -44,11 +48,6 @@ class APIProvider(ValidatedConfigBase):
raise ValueError("API密钥必须是字符串或字符串列表")
return v
def __init__(self, **data):
super().__init__(**data)
self._api_key_lock = Lock()
self._api_key_index = 0
def get_api_key(self) -> str:
with self._api_key_lock:
if isinstance(self.api_key, str):
@@ -134,6 +133,7 @@ class ModelTaskConfig(ValidatedConfigBase):
replyer_private: TaskConfig = Field(..., description="Primary normal_chat replyer model config (used for private chats)")
maizone: TaskConfig = Field(..., description="Dedicated maizone model")
emotion: TaskConfig = Field(..., description="Emotion model config")
mood: TaskConfig = Field(..., description="Mood model config")
vlm: TaskConfig = Field(..., description="Vision-language model config")
voice: TaskConfig = Field(..., description="Speech recognition model config")
tool_use: TaskConfig = Field(..., description="Focused tool-use model config")
@@ -178,14 +178,26 @@ class ModelTaskConfig(ValidatedConfigBase):
class APIAdapterConfig(ValidatedConfigBase):
"""API Adapter配置类"""
inner: InnerConfig = Field(..., description="配置元信息")
models: list[ModelInfo] = Field(..., min_length=1, description="模型列表")
model_task_config: ModelTaskConfig = Field(..., description="模型任务配置")
api_providers: list[APIProvider] = Field(..., min_length=1, description="API提供商列表")
_api_providers_dict: dict[str, APIProvider] = PrivateAttr(default_factory=dict)
_models_dict: dict[str, ModelInfo] = PrivateAttr(default_factory=dict)
def __init__(self, **data):
super().__init__(**data)
self.api_providers_dict = {provider.name: provider for provider in self.api_providers}
self.models_dict = {model.name: model for model in self.models}
self._api_providers_dict = {provider.name: provider for provider in self.api_providers}
self._models_dict = {model.name: model for model in self.models}
@property
def api_providers_dict(self) -> dict[str, APIProvider]:
return self._api_providers_dict
@property
def models_dict(self) -> dict[str, ModelInfo]:
return self._models_dict
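Moving the lock and the derived lookup dicts into `PrivateAttr` keeps them out of Pydantic's validation and serialization while still giving each instance its own state. A minimal sketch of the pattern on a hypothetical model:

```python
from threading import Lock
from pydantic import BaseModel, Field, PrivateAttr

class Provider(BaseModel):
    name: str
    api_keys: list[str] = Field(default_factory=list)
    # Private attrs are not fields: never validated, serialized, or required in __init__.
    _lock: Lock = PrivateAttr(default_factory=Lock)
    _index: int = PrivateAttr(default=0)

    def next_key(self) -> str:
        with self._lock:  # round-robin under a lock for thread safety
            key = self.api_keys[self._index % len(self.api_keys)]
            self._index += 1
            return key

p = Provider(name="demo", api_keys=["k1", "k2"])
assert [p.next_key(), p.next_key(), p.next_key()] == ["k1", "k2", "k1"]
assert "_index" not in p.model_dump()  # private state stays out of dumps
```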
@classmethod
def validate_models_list(cls, v):

View File

@@ -1,10 +1,14 @@
import os
import shutil
import sys
import typing
import types
from datetime import datetime
from pathlib import Path
from typing import Any, get_args, get_origin
import tomlkit
from pydantic import Field
from pydantic import BaseModel, Field, PrivateAttr
from rich.traceback import install
from tomlkit import TOMLDocument
from tomlkit.items import KeyType, Table
@@ -25,6 +29,8 @@ from src.config.official_configs import (
EmojiConfig,
ExperimentalConfig,
ExpressionConfig,
InnerConfig,
LogConfig,
KokoroFlowChatterConfig,
LPMMKnowledgeConfig,
MemoryConfig,
@@ -65,7 +71,7 @@ TEMPLATE_DIR = os.path.join(PROJECT_ROOT, "template")
# Since mai_version in the config file is never auto-updated, it is hard-coded here
# Updates to this field must strictly follow Semantic Versioning: https://semver.org/lang/zh-CN/
MMC_VERSION = "0.13.1-alpha.1"
MMC_VERSION = "0.13.1-alpha.2"
# Global configuration variables
_CONFIG_INITIALIZED = False
@@ -180,6 +186,76 @@ def _remove_obsolete_keys(target: TOMLDocument | dict | Table, reference: TOMLDo
_remove_obsolete_keys(target[key], reference[key]) # type: ignore
def _prune_unknown_keys_by_schema(target: TOMLDocument | Table, schema_model: type[BaseModel]):
"""
Recursively remove unknown config keys based on the Pydantic schema (including repeatable AoT tables).
Notes:
- Only keys absent from the schema are removed, so deprecated options do not linger across versions.
- For list[BaseModel] fields (TOML [[...]] arrays of tables), each element is visited and pruned recursively.
- Free-form fields such as dict[str, Any] get no key-level pruning.
"""
def _strip_optional(annotation: Any) -> Any:
origin = get_origin(annotation)
if origin is None:
return annotation
# Handle both "X | None" and Union[..., None]
union_type = getattr(types, "UnionType", None)
if origin is union_type or origin is typing.Union:
args = [a for a in get_args(annotation) if a is not type(None)]
if len(args) == 1:
return args[0]
return annotation
def _is_model_type(annotation: Any) -> bool:
return isinstance(annotation, type) and issubclass(annotation, BaseModel)
def _prune_table(table: TOMLDocument | Table, model: type[BaseModel]):
name_by_key: dict[str, str] = {}
allowed_keys: set[str] = set()
for field_name, field_info in model.model_fields.items():
allowed_keys.add(field_name)
name_by_key[field_name] = field_name
alias = getattr(field_info, "alias", None)
if isinstance(alias, str) and alias:
allowed_keys.add(alias)
name_by_key[alias] = field_name
for key in list(table.keys()):
if key not in allowed_keys:
del table[key]
continue
field_name = name_by_key[key]
field_info = model.model_fields[field_name]
annotation = _strip_optional(getattr(field_info, "annotation", Any))
value = table.get(key)
if value is None:
continue
if _is_model_type(annotation) and isinstance(value, (TOMLDocument, Table)):
_prune_table(value, annotation)
continue
origin = get_origin(annotation)
if origin is list:
args = get_args(annotation)
elem_ann = _strip_optional(args[0]) if args else Any
# list[BaseModel] corresponds to a TOML AoT ([[...]])
if _is_model_type(elem_ann) and hasattr(value, "__iter__"):
for item in value:
if isinstance(item, (TOMLDocument, Table)):
_prune_table(item, elem_ann)
_prune_table(target, schema_model)
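A toy illustration of the pruning idea, assuming a minimal schema; `model_fields` is the Pydantic v2 mapping of declared fields that drives the allowed-key set above:

```python
import tomlkit
from pydantic import BaseModel

class Bot(BaseModel):
    name: str = "mai"

class Schema(BaseModel):
    bot: Bot = Bot()

doc = tomlkit.parse("""
[bot]
name = "mai"
legacy_option = true   # no longer in the schema
""")

# _prune_unknown_keys_by_schema(doc, Schema) would drop `legacy_option`;
# the core of what it does per table is:
allowed = set(Schema.model_fields["bot"].annotation.model_fields)
for key in [k for k in doc["bot"].keys() if k not in allowed]:
    del doc["bot"][key]
assert "legacy_option" not in doc["bot"]
```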
def _update_dict(target: TOMLDocument | dict | Table, source: TOMLDocument | dict):
"""
Update the target dict with values from the source dict
@@ -232,13 +308,14 @@ def _update_dict(target: TOMLDocument | dict | Table, source: TOMLDocument | dic
target[key] = value
def _update_config_generic(config_name: str, template_name: str):
def _update_config_generic(config_name: str, template_name: str, schema_model: type[BaseModel] | None = None):
"""
Generic config file update function
Args:
config_name: config file name (without extension), e.g. 'bot_config' or 'model_config'
template_name: template file name (without extension), e.g. 'bot_config_template' or 'model_config_template'
schema_model: Pydantic model used to prune unknown keys (prevents deprecated options lingering across versions)
"""
# Resolve the root directory path
old_config_dir = os.path.join(CONFIG_DIR, "old")
@@ -355,11 +432,14 @@ def _update_config_generic(config_name: str, template_name: str):
logger.info(f"开始合并{config_name}新旧配置...")
_update_dict(new_config, old_config)
# Remove old options that no longer exist in the new template
# Remove unknown/deprecated old options (especially repeatable [[...]] sections)
logger.info(f"Removing deprecated options from {config_name}...")
with open(template_path, encoding="utf-8") as f:
template_doc = tomlkit.load(f)
_remove_obsolete_keys(new_config, template_doc)
if schema_model is not None:
_prune_unknown_keys_by_schema(new_config, schema_model)
else:
with open(template_path, encoding="utf-8") as f:
template_doc = tomlkit.load(f)
_remove_obsolete_keys(new_config, template_doc)
logger.info(f"已移除{config_name}中已废弃的配置项")
# 保存更新后的配置(保留注释和格式)
@@ -370,18 +450,18 @@ def _update_config_generic(config_name: str, template_name: str):
def update_config():
"""更新bot_config.toml配置文件"""
_update_config_generic("bot_config", "bot_config_template")
_update_config_generic("bot_config", "bot_config_template", schema_model=Config)
def update_model_config():
"""更新model_config.toml配置文件"""
_update_config_generic("model_config", "model_config_template")
_update_config_generic("model_config", "model_config_template", schema_model=APIAdapterConfig)
class Config(ValidatedConfigBase):
"""总配置类"""
MMC_VERSION: str = Field(default=MMC_VERSION, description="MaiCore版本号")
inner: InnerConfig = Field(..., description="配置元信息")
database: DatabaseConfig = Field(..., description="数据库配置")
bot: BotConfig = Field(..., description="机器人基本配置")
@@ -397,6 +477,7 @@ class Config(ValidatedConfigBase):
chinese_typo: ChineseTypoConfig = Field(..., description="Chinese typo configuration")
response_post_process: ResponsePostProcessConfig = Field(..., description="Response post-processing configuration")
response_splitter: ResponseSplitterConfig = Field(..., description="Response splitter configuration")
log: LogConfig = Field(..., description="Logging configuration")
experimental: ExperimentalConfig = Field(default_factory=lambda: ExperimentalConfig(), description="Experimental feature configuration")
message_bus: MessageBusConfig = Field(..., description="Message bus configuration")
lpmm_knowledge: LPMMKnowledgeConfig = Field(..., description="LPMM knowledge configuration")
@@ -433,18 +514,34 @@ class Config(ValidatedConfigBase):
default_factory=lambda: PluginHttpSystemConfig(), description="Plugin HTTP endpoint system configuration"
)
@property
def MMC_VERSION(self) -> str: # noqa: N802
return MMC_VERSION
class APIAdapterConfig(ValidatedConfigBase):
"""API Adapter配置类"""
inner: InnerConfig = Field(..., description="配置元信息")
models: list[ModelInfo] = Field(..., min_length=1, description="模型列表")
model_task_config: ModelTaskConfig = Field(..., description="模型任务配置")
api_providers: list[APIProvider] = Field(..., min_length=1, description="API提供商列表")
_api_providers_dict: dict[str, APIProvider] = PrivateAttr(default_factory=dict)
_models_dict: dict[str, ModelInfo] = PrivateAttr(default_factory=dict)
def __init__(self, **data):
super().__init__(**data)
self.api_providers_dict = {provider.name: provider for provider in self.api_providers}
self.models_dict = {model.name: model for model in self.models}
self._api_providers_dict = {provider.name: provider for provider in self.api_providers}
self._models_dict = {model.name: model for model in self.models}
@property
def api_providers_dict(self) -> dict[str, APIProvider]:
return self._api_providers_dict
@property
def models_dict(self) -> dict[str, ModelInfo]:
return self._models_dict
@classmethod
def validate_models_list(cls, v):
@@ -502,9 +599,14 @@ def load_config(config_path: str) -> Config:
Returns:
A Config object
"""
# Read the config file
with open(config_path, encoding="utf-8") as f:
config_data = tomlkit.load(f)
# Read the config file (unknown/deprecated options are removed automatically)
original_text = Path(config_path).read_text(encoding="utf-8")
config_data = tomlkit.parse(original_text)
_prune_unknown_keys_by_schema(config_data, Config)
new_text = tomlkit.dumps(config_data)
if new_text != original_text:
Path(config_path).write_text(new_text, encoding="utf-8")
logger.warning(f"已自动移除 {config_path} 中未知/废弃配置项")
# 将 tomlkit 对象转换为纯 Python 字典,避免 Pydantic 严格模式下的类型验证问题
# tomlkit 返回的是特殊类型(如 Array、String 等),虽然继承自 Python 标准类型,
@@ -530,11 +632,16 @@ def api_ada_load_config(config_path: str) -> APIAdapterConfig:
Returns:
An APIAdapterConfig object
"""
# Read the config file
with open(config_path, encoding="utf-8") as f:
config_data = tomlkit.load(f)
# Read the config file (unknown/deprecated options are removed automatically)
original_text = Path(config_path).read_text(encoding="utf-8")
config_data = tomlkit.parse(original_text)
_prune_unknown_keys_by_schema(config_data, APIAdapterConfig)
new_text = tomlkit.dumps(config_data)
if new_text != original_text:
Path(config_path).write_text(new_text, encoding="utf-8")
logger.warning(f"已自动移除 {config_path} 中未知/废弃配置项")
config_dict = dict(config_data)
config_dict = config_data.unwrap()
try:
logger.debug("正在解析和验证API适配器配置文件...")

View File

@@ -142,7 +142,7 @@ class ValidatedConfigBase(BaseModel):
"""带验证的配置基类继承自Pydantic BaseModel"""
model_config = {
"extra": "allow", # 允许额外字段
"extra": "forbid", # 禁止额外字段(防止跨版本遗留废弃配置项)
"validate_assignment": True, # 验证赋值
"arbitrary_types_allowed": True, # 允许任意类型
"strict": True, # 如果设为 True 会完全禁用类型转换

View File

@@ -13,6 +13,12 @@ from src.config.config_base import ValidatedConfigBase
"""
class InnerConfig(ValidatedConfigBase):
"""配置文件元信息"""
version: str = Field(..., description="配置文件版本号(用于配置文件升级与兼容性检查)")
class DatabaseConfig(ValidatedConfigBase):
"""数据库配置类"""
@@ -191,9 +197,9 @@ class NoticeConfig(ValidatedConfigBase):
enable_notice_trigger_chat: bool = Field(default=True, description="Whether notice messages may trigger the chat flow")
notice_in_prompt: bool = Field(default=True, description="Whether to show recent notice messages in the prompt")
notice_prompt_limit: int = Field(default=5, ge=1, le=20, description="Maximum number of notices shown in the prompt")
notice_time_window: int = Field(default=3600, ge=60, le=86400, description="Notice time window (seconds)")
notice_time_window: int = Field(default=3600, ge=10, le=86400, description="Notice time window (seconds)")
max_notices_per_chat: int = Field(default=30, ge=10, le=100, description="Upper bound on notices kept per chat")
notice_retention_time: int = Field(default=86400, ge=3600, le=604800, description="Notice retention time (seconds)")
notice_retention_time: int = Field(default=86400, ge=10, le=604800, description="Notice retention time (seconds)")
class ExpressionRule(ValidatedConfigBase):
@@ -588,6 +594,20 @@ class ResponseSplitterConfig(ValidatedConfigBase):
enable_kaomoji_protection: bool = Field(default=False, description="Enable kaomoji protection")
class LogConfig(ValidatedConfigBase):
"""日志配置类"""
date_style: str = Field(default="m-d H:i:s", description="日期格式")
log_level_style: str = Field(default="lite", description="日志级别样式")
color_text: str = Field(default="full", description="日志文本颜色")
log_level: str = Field(default="INFO", description="全局日志级别(向下兼容,优先级低于分别设置)")
file_retention_days: int = Field(default=7, description="文件日志保留天数0=禁用文件日志,-1=永不删除")
console_log_level: str = Field(default="INFO", description="控制台日志级别")
file_log_level: str = Field(default="DEBUG", description="文件日志级别")
suppress_libraries: list[str] = Field(default_factory=list, description="完全屏蔽日志的第三方库列表")
library_log_levels: dict[str, str] = Field(default_factory=dict, description="设置特定库的日志级别")
class DebugConfig(ValidatedConfigBase):
"""调试配置类"""
@@ -703,6 +723,7 @@ class WebSearchConfig(ValidatedConfigBase):
enable_url_tool: bool = Field(default=True, description="Enable the URL tool")
tavily_api_keys: list[str] = Field(default_factory=lambda: [], description="Tavily API key list (supports round-robin rotation)")
exa_api_keys: list[str] = Field(default_factory=lambda: [], description="exa API key list (supports round-robin rotation)")
metaso_api_keys: list[str] = Field(default_factory=lambda: [], description="Metaso API key list (supports round-robin rotation)")
searxng_instances: list[str] = Field(default_factory=list, description="SearXNG instance URL list")
searxng_api_keys: list[str] = Field(default_factory=list, description="SearXNG instance API key list")
serper_api_keys: list[str] = Field(default_factory=list, description="serper API key list")
@@ -988,6 +1009,12 @@ class KokoroFlowChatterConfig(ValidatedConfigBase):
description="开启后KFC将接管所有私聊消息关闭后私聊消息将由AFC处理"
)
# --- Working mode ---
mode: Literal["unified", "split"] = Field(
default="split",
description='Working mode: "unified" (single call) or "split" (two calls: planner + replyer)',
)
# --- Core behavior configuration ---
max_wait_seconds_default: int = Field(
default=300, ge=30, le=3600,
@@ -998,6 +1025,12 @@ class KokoroFlowChatterConfig(ValidatedConfigBase):
description="是否在等待期间启用心理活动更新"
)
# --- Custom decision prompt ---
custom_decision_prompt: str = Field(
default="",
description="Custom prompt guiding KFC decision behavior (unified: affects the whole flow; split: affects only the planner)",
)
waiting: KokoroFlowChatterWaitingConfig = Field(
default_factory=KokoroFlowChatterWaitingConfig,
description="等待策略配置(默认等待时间、倍率等)",

View File

@@ -451,7 +451,7 @@ class UnifiedMemoryManager:
(0.3, 10.0, 0.4),
(0.1, 15.0, 0.6),
]
for threshold, min_val, factor in occupancy_thresholds:
if occupancy >= threshold:
return max(min_val, base_interval * factor)
@@ -461,24 +461,24 @@ class UnifiedMemoryManager:
async def _transfer_blocks_to_short_term(self, blocks: list[MemoryBlock]) -> None:
"""实际转换逻辑在后台执行(优化:并行处理多个块,批量触发唤醒)"""
logger.debug(f"正在后台处理 {len(blocks)} 个感知记忆块")
# 优化:使用 asyncio.gather 并行处理转移
async def _transfer_single(block: MemoryBlock) -> tuple[MemoryBlock, bool]:
try:
stm = await self.short_term_manager.add_from_block(block)
if not stm:
return block, False
await self.perceptual_manager.remove_block(block.id)
logger.debug(f"✓ 记忆块 {block.id} 已被转移到短期记忆 {stm.id}")
return block, True
except Exception as exc:
logger.error(f"后台转移失败,记忆块 {block.id}: {exc}")
return block, False
# Process all blocks in parallel
results = await asyncio.gather(*[_transfer_single(block) for block in blocks], return_exceptions=True)
# Count the successful transfers
success_count = sum(1 for result in results if isinstance(result, tuple) and result[1])
if success_count > 0:
@@ -491,7 +491,7 @@ class UnifiedMemoryManager:
seen = set()
decay = 0.15
manual_queries: list[dict[str, Any]] = []
for raw in queries:
text = (raw or "").strip()
if text and text not in seen:
@@ -517,7 +517,7 @@ class UnifiedMemoryManager:
"top_k": self._config["long_term"]["search_top_k"],
"use_multi_query": bool(manual_queries),
}
if recent_chat_history or manual_queries:
context: dict[str, Any] = {}
if recent_chat_history:
@@ -541,7 +541,7 @@ class UnifiedMemoryManager:
mem_id = mem.get("id")
else:
mem_id = getattr(mem, "id", None)
# Dedup check
if mem_id and mem_id in seen_ids:
continue
@@ -600,7 +600,7 @@ class UnifiedMemoryManager:
new_memories.append(memory)
if mem_id:
cached_ids.add(mem_id)
if new_memories:
transfer_cache.extend(new_memories)
logger.debug(
@@ -632,7 +632,7 @@ class UnifiedMemoryManager:
await self.short_term_manager.clear_transferred_memories(
result["transferred_memory_ids"]
)
# Optimization: keep only the memories that were not transferred (list comprehension)
transfer_cache = [
m