Merge branch 'dev' of https://github.com/MoFox-Studio/MoFox-Core into dev
This commit is contained in:
@@ -9,6 +9,8 @@ from collections import defaultdict
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
|
||||
from src.common.database.compatibility import get_db_session
|
||||
from src.common.database.core.models import ChatStreams
|
||||
from src.common.logger import get_logger
|
||||
@@ -159,20 +161,27 @@ class BatchDatabaseWriter:
|
||||
logger.info("批量写入循环结束")
|
||||
|
||||
async def _collect_batch(self) -> list[StreamUpdatePayload]:
|
||||
"""收集一个批次的数据"""
|
||||
batch = []
|
||||
deadline = time.time() + self.flush_interval
|
||||
"""收集一个批次的数据
|
||||
- 自适应刷新:队列增长加快时缩短等待时间
|
||||
- 避免长时间空转:添加轻微抖动以分散竞争
|
||||
"""
|
||||
batch: list[StreamUpdatePayload] = []
|
||||
# 根据当前队列长度调整刷新时间(最多缩短到 40%)
|
||||
qsize = self.write_queue.qsize()
|
||||
adapt_factor = 1.0
|
||||
if qsize > 0:
|
||||
adapt_factor = max(0.4, min(1.0, self.batch_size / max(1, qsize)))
|
||||
deadline = time.time() + (self.flush_interval * adapt_factor)
|
||||
|
||||
while len(batch) < self.batch_size and time.time() < deadline:
|
||||
try:
|
||||
# 计算剩余等待时间
|
||||
remaining_time = max(0, deadline - time.time())
|
||||
remaining_time = max(0.0, deadline - time.time())
|
||||
if remaining_time == 0:
|
||||
break
|
||||
|
||||
payload = await asyncio.wait_for(self.write_queue.get(), timeout=remaining_time)
|
||||
# 轻微抖动,避免多个协程同时争抢队列
|
||||
jitter = 0.002
|
||||
payload = await asyncio.wait_for(self.write_queue.get(), timeout=remaining_time + jitter)
|
||||
batch.append(payload)
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
break
|
||||
|
||||
@@ -208,48 +217,52 @@ class BatchDatabaseWriter:
|
||||
|
||||
logger.debug(f"批量写入完成: {len(batch)} 个更新,耗时 {time.time() - start_time:.3f}s")
|
||||
|
||||
except Exception as e:
|
||||
except SQLAlchemyError as e:
|
||||
self.stats["failed_writes"] += 1
|
||||
logger.error(f"批量写入失败: {e}")
|
||||
# 降级到单个写入
|
||||
for payload in batch:
|
||||
try:
|
||||
await self._direct_write(payload.stream_id, payload.update_data)
|
||||
except Exception as single_e:
|
||||
except SQLAlchemyError as single_e:
|
||||
logger.error(f"单个写入也失败: {single_e}")
|
||||
|
||||
async def _batch_write_to_database(self, payloads: list[StreamUpdatePayload]):
|
||||
"""批量写入数据库"""
|
||||
"""批量写入数据库(单事务、多值 UPSERT)"""
|
||||
if global_config is None:
|
||||
raise RuntimeError("Global config is not initialized")
|
||||
|
||||
if not payloads:
|
||||
return
|
||||
|
||||
# 预组装行数据,确保每行包含 stream_id
|
||||
rows: list[dict[str, Any]] = []
|
||||
for p in payloads:
|
||||
row = {"stream_id": p.stream_id}
|
||||
row.update(p.update_data)
|
||||
rows.append(row)
|
||||
|
||||
async with get_db_session() as session:
|
||||
for payload in payloads:
|
||||
stream_id = payload.stream_id
|
||||
update_data = payload.update_data
|
||||
|
||||
# 根据数据库类型选择不同的插入/更新策略
|
||||
if global_config.database.database_type == "sqlite":
|
||||
from sqlalchemy.dialects.sqlite import insert as sqlite_insert
|
||||
|
||||
stmt = sqlite_insert(ChatStreams).values(stream_id=stream_id, **update_data)
|
||||
stmt = stmt.on_conflict_do_update(index_elements=["stream_id"], set_=update_data)
|
||||
elif global_config.database.database_type == "postgresql":
|
||||
from sqlalchemy.dialects.postgresql import insert as pg_insert
|
||||
|
||||
stmt = pg_insert(ChatStreams).values(stream_id=stream_id, **update_data)
|
||||
stmt = stmt.on_conflict_do_update(
|
||||
index_elements=[ChatStreams.stream_id],
|
||||
set_=update_data
|
||||
)
|
||||
else:
|
||||
# 默认使用SQLite语法
|
||||
from sqlalchemy.dialects.sqlite import insert as sqlite_insert
|
||||
|
||||
stmt = sqlite_insert(ChatStreams).values(stream_id=stream_id, **update_data)
|
||||
stmt = stmt.on_conflict_do_update(index_elements=["stream_id"], set_=update_data)
|
||||
|
||||
# 使用单次事务提交,显著减少 I/O
|
||||
if global_config.database.database_type == "postgresql":
|
||||
from sqlalchemy.dialects.postgresql import insert as pg_insert
|
||||
stmt = pg_insert(ChatStreams).values(rows)
|
||||
stmt = stmt.on_conflict_do_update(
|
||||
index_elements=[ChatStreams.stream_id],
|
||||
set_={k: getattr(stmt.excluded, k) for k in rows[0].keys() if k != "stream_id"}
|
||||
)
|
||||
await session.execute(stmt)
|
||||
await session.commit()
|
||||
else:
|
||||
# 默认(sqlite)
|
||||
from sqlalchemy.dialects.sqlite import insert as sqlite_insert
|
||||
stmt = sqlite_insert(ChatStreams).values(rows)
|
||||
stmt = stmt.on_conflict_do_update(
|
||||
index_elements=["stream_id"],
|
||||
set_={k: getattr(stmt.excluded, k) for k in rows[0].keys() if k != "stream_id"}
|
||||
)
|
||||
await session.execute(stmt)
|
||||
await session.commit()
|
||||
async def _direct_write(self, stream_id: str, update_data: dict[str, Any]):
|
||||
"""直接写入数据库(降级方案)"""
|
||||
if global_config is None:
|
||||
|
||||
@@ -55,7 +55,7 @@ async def conversation_loop(
|
||||
stream_id: str,
|
||||
get_context_func: Callable[[str], Awaitable["StreamContext | None"]],
|
||||
calculate_interval_func: Callable[[str, bool], Awaitable[float]],
|
||||
flush_cache_func: Callable[[str], Awaitable[None]],
|
||||
flush_cache_func: Callable[[str], Awaitable[list[Any]]],
|
||||
check_force_dispatch_func: Callable[["StreamContext", int], bool],
|
||||
is_running_func: Callable[[], bool],
|
||||
) -> AsyncIterator[ConversationTick]:
|
||||
@@ -121,7 +121,7 @@ async def conversation_loop(
|
||||
except asyncio.CancelledError:
|
||||
logger.info(f" [生成器] stream={stream_id[:8]}, 被取消")
|
||||
break
|
||||
except Exception as e:
|
||||
except Exception as e: # noqa: BLE001
|
||||
logger.error(f" [生成器] stream={stream_id[:8]}, 出错: {e}")
|
||||
await asyncio.sleep(5.0)
|
||||
|
||||
@@ -151,10 +151,10 @@ async def run_chat_stream(
|
||||
# 创建生成器
|
||||
tick_generator = conversation_loop(
|
||||
stream_id=stream_id,
|
||||
get_context_func=manager._get_stream_context,
|
||||
calculate_interval_func=manager._calculate_interval,
|
||||
flush_cache_func=manager._flush_cached_messages_to_unread,
|
||||
check_force_dispatch_func=manager._needs_force_dispatch_for_context,
|
||||
get_context_func=manager._get_stream_context, # noqa: SLF001
|
||||
calculate_interval_func=manager._calculate_interval, # noqa: SLF001
|
||||
flush_cache_func=manager._flush_cached_messages_to_unread, # noqa: SLF001
|
||||
check_force_dispatch_func=manager._needs_force_dispatch_for_context, # noqa: SLF001
|
||||
is_running_func=lambda: manager.is_running,
|
||||
)
|
||||
|
||||
@@ -162,13 +162,13 @@ async def run_chat_stream(
|
||||
async for tick in tick_generator:
|
||||
try:
|
||||
# 获取上下文
|
||||
context = await manager._get_stream_context(stream_id)
|
||||
context = await manager._get_stream_context(stream_id) # noqa: SLF001
|
||||
if not context:
|
||||
continue
|
||||
|
||||
# 并发保护:检查是否正在处理
|
||||
if context.is_chatter_processing:
|
||||
if manager._recover_stale_chatter_state(stream_id, context):
|
||||
if manager._recover_stale_chatter_state(stream_id, context): # noqa: SLF001
|
||||
logger.warning(f" [驱动器] stream={stream_id[:8]}, 处理标志残留已修复")
|
||||
else:
|
||||
logger.debug(f" [驱动器] stream={stream_id[:8]}, Chatter正在处理,跳过此Tick")
|
||||
@@ -182,17 +182,18 @@ async def run_chat_stream(
|
||||
|
||||
# 更新能量值
|
||||
try:
|
||||
await manager._update_stream_energy(stream_id, context)
|
||||
await manager._update_stream_energy(stream_id, context) # noqa: SLF001
|
||||
except Exception as e:
|
||||
logger.debug(f"更新能量失败: {e}")
|
||||
|
||||
# 处理消息
|
||||
assert global_config is not None
|
||||
try:
|
||||
success = await asyncio.wait_for(
|
||||
manager._process_stream_messages(stream_id, context),
|
||||
global_config.chat.thinking_timeout
|
||||
)
|
||||
async with manager._processing_semaphore:
|
||||
success = await asyncio.wait_for(
|
||||
manager._process_stream_messages(stream_id, context), # noqa: SLF001
|
||||
global_config.chat.thinking_timeout,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning(f" [驱动器] stream={stream_id[:8]}, Tick#{tick.tick_count}, 处理超时")
|
||||
success = False
|
||||
@@ -208,7 +209,7 @@ async def run_chat_stream(
|
||||
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except Exception as e:
|
||||
except Exception as e: # noqa: BLE001
|
||||
logger.error(f" [驱动器] stream={stream_id[:8]}, 处理Tick时出错: {e}")
|
||||
manager.stats["total_failures"] += 1
|
||||
|
||||
@@ -221,7 +222,7 @@ async def run_chat_stream(
|
||||
if context and context.stream_loop_task:
|
||||
context.stream_loop_task = None
|
||||
logger.debug(f" [驱动器] stream={stream_id[:8]}, 清理任务记录")
|
||||
except Exception as e:
|
||||
except Exception as e: # noqa: BLE001
|
||||
logger.debug(f"清理任务记录失败: {e}")
|
||||
|
||||
|
||||
@@ -268,6 +269,9 @@ class StreamLoopManager:
|
||||
# 流启动锁:防止并发启动同一个流的多个任务
|
||||
self._stream_start_locks: dict[str, asyncio.Lock] = {}
|
||||
|
||||
# 并发控制:限制同时进行的 Chatter 处理任务数
|
||||
self._processing_semaphore = asyncio.Semaphore(self.max_concurrent_streams)
|
||||
|
||||
logger.info(f"流循环管理器初始化完成 (最大并发流数: {self.max_concurrent_streams})")
|
||||
|
||||
# ========================================================================
|
||||
|
||||
@@ -104,9 +104,17 @@ class MessageManager:
|
||||
if not chat_stream:
|
||||
logger.warning(f"MessageManager.add_message: 聊天流 {stream_id} 不存在")
|
||||
return
|
||||
# 启动 stream loop 任务(如果尚未启动)
|
||||
await stream_loop_manager.start_stream_loop(stream_id)
|
||||
|
||||
# 快速检查:如果已有驱动器在跑,则跳过重复启动,避免不必要的 await
|
||||
context = chat_stream.context
|
||||
if not (context.stream_loop_task and not context.stream_loop_task.done()):
|
||||
# 异步启动驱动器任务;避免在高并发下阻塞消息入队
|
||||
await stream_loop_manager.start_stream_loop(stream_id)
|
||||
|
||||
# 检查并处理消息打断
|
||||
await self._check_and_handle_interruption(chat_stream, message)
|
||||
|
||||
# 入队消息
|
||||
await chat_stream.context.add_message(message)
|
||||
|
||||
except Exception as e:
|
||||
@@ -476,8 +484,7 @@ class MessageManager:
|
||||
is_processing: 是否正在处理
|
||||
"""
|
||||
try:
|
||||
# 尝试更新StreamContext的处理状态
|
||||
import asyncio
|
||||
# 尝试更新StreamContext的处理状态(使用顶层 asyncio 导入)
|
||||
async def _update_context():
|
||||
try:
|
||||
chat_manager = get_chat_manager()
|
||||
@@ -492,7 +499,7 @@ class MessageManager:
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
if loop.is_running():
|
||||
asyncio.create_task(_update_context())
|
||||
self._update_context_task = asyncio.create_task(_update_context())
|
||||
else:
|
||||
# 如果事件循环未运行,则跳过
|
||||
logger.debug("事件循环未运行,跳过StreamContext状态更新")
|
||||
@@ -512,8 +519,7 @@ class MessageManager:
|
||||
bool: 是否正在处理
|
||||
"""
|
||||
try:
|
||||
# 尝试从StreamContext获取处理状态
|
||||
import asyncio
|
||||
# 尝试从StreamContext获取处理状态(使用顶层 asyncio 导入)
|
||||
async def _get_context_status():
|
||||
try:
|
||||
chat_manager = get_chat_manager()
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
import asyncio
|
||||
import hashlib
|
||||
import time
|
||||
from functools import lru_cache
|
||||
from typing import ClassVar
|
||||
|
||||
from rich.traceback import install
|
||||
from sqlalchemy.dialects.postgresql import insert as pg_insert
|
||||
@@ -25,6 +27,9 @@ _background_tasks: set[asyncio.Task] = set()
|
||||
class ChatStream:
|
||||
"""聊天流对象,存储一个完整的聊天上下文"""
|
||||
|
||||
# 类级别的缓存,用于存储计算过的兴趣值(避免重复计算)
|
||||
_interest_cache: ClassVar[dict] = {}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
stream_id: str,
|
||||
@@ -159,7 +164,19 @@ class ChatStream:
|
||||
return None
|
||||
|
||||
async def _calculate_message_interest(self, db_message):
|
||||
"""计算消息兴趣值并更新消息对象"""
|
||||
"""计算消息兴趣值并更新消息对象 - 优化版本使用缓存"""
|
||||
# 使用消息ID作为缓存键
|
||||
cache_key = getattr(db_message, "message_id", None)
|
||||
|
||||
# 检查缓存
|
||||
if cache_key and cache_key in ChatStream._interest_cache:
|
||||
cached_result = ChatStream._interest_cache[cache_key]
|
||||
db_message.interest_value = cached_result["interest_value"]
|
||||
db_message.should_reply = cached_result["should_reply"]
|
||||
db_message.should_act = cached_result["should_act"]
|
||||
logger.debug(f"消息 {cache_key} 使用缓存的兴趣值: {cached_result['interest_value']:.3f}")
|
||||
return
|
||||
|
||||
try:
|
||||
from src.chat.interest_system.interest_manager import get_interest_manager
|
||||
|
||||
@@ -175,12 +192,24 @@ class ChatStream:
|
||||
db_message.should_reply = result.should_reply
|
||||
db_message.should_act = result.should_act
|
||||
|
||||
# 缓存结果
|
||||
if cache_key:
|
||||
ChatStream._interest_cache[cache_key] = {
|
||||
"interest_value": result.interest_value,
|
||||
"should_reply": result.should_reply,
|
||||
"should_act": result.should_act,
|
||||
}
|
||||
# 限制缓存大小,防止内存溢出(保留最近5000条)
|
||||
if len(ChatStream._interest_cache) > 5000:
|
||||
oldest_key = next(iter(ChatStream._interest_cache))
|
||||
del ChatStream._interest_cache[oldest_key]
|
||||
|
||||
logger.debug(
|
||||
f"消息 {db_message.message_id} 兴趣值已更新: {result.interest_value:.3f}, "
|
||||
f"消息 {cache_key} 兴趣值已更新: {result.interest_value:.3f}, "
|
||||
f"should_reply: {result.should_reply}, should_act: {result.should_act}"
|
||||
)
|
||||
else:
|
||||
logger.warning(f"消息 {db_message.message_id} 兴趣值计算失败: {result.error_message}")
|
||||
logger.warning(f"消息 {cache_key} 兴趣值计算失败: {result.error_message}")
|
||||
# 使用默认值
|
||||
db_message.interest_value = 0.3
|
||||
db_message.should_reply = False
|
||||
@@ -362,21 +391,24 @@ class ChatManager:
|
||||
self.last_messages[stream_id] = message
|
||||
# logger.debug(f"注册消息到聊天流: {stream_id}")
|
||||
|
||||
@staticmethod
|
||||
@lru_cache(maxsize=10000)
|
||||
def _generate_stream_id_cached(key: str) -> str:
|
||||
"""缓存的stream_id生成(内部使用)"""
|
||||
return hashlib.sha256(key.encode()).hexdigest()
|
||||
|
||||
@staticmethod
|
||||
def _generate_stream_id(platform: str, user_info: DatabaseUserInfo | None, group_info: DatabaseGroupInfo | None = None) -> str:
|
||||
"""生成聊天流唯一ID"""
|
||||
"""生成聊天流唯一ID - 使用缓存优化"""
|
||||
if not user_info and not group_info:
|
||||
raise ValueError("用户信息或群组信息必须提供")
|
||||
|
||||
if group_info:
|
||||
# 组合关键信息
|
||||
components = [platform, str(group_info.group_id)]
|
||||
key = f"{platform}_{group_info.group_id}"
|
||||
else:
|
||||
components = [platform, str(user_info.user_id), "private"] # type: ignore
|
||||
key = f"{platform}_{user_info.user_id}_private" # type: ignore
|
||||
|
||||
# 使用SHA-256生成唯一ID
|
||||
key = "_".join(components)
|
||||
return hashlib.sha256(key.encode()).hexdigest()
|
||||
return ChatManager._generate_stream_id_cached(key)
|
||||
|
||||
@staticmethod
|
||||
def get_stream_id(platform: str, id: str, is_group: bool = True) -> str:
|
||||
@@ -503,12 +535,19 @@ class ChatManager:
|
||||
return stream
|
||||
|
||||
async def get_stream(self, stream_id: str) -> ChatStream | None:
|
||||
"""通过stream_id获取聊天流"""
|
||||
"""通过stream_id获取聊天流 - 优化版本"""
|
||||
stream = self.streams.get(stream_id)
|
||||
if not stream:
|
||||
return None
|
||||
if stream_id in self.last_messages and isinstance(self.last_messages[stream_id], DatabaseMessages):
|
||||
await stream.set_context(self.last_messages[stream_id])
|
||||
|
||||
# 只在必要时设置上下文(避免重复调用)
|
||||
if stream_id not in self.last_messages:
|
||||
return stream
|
||||
|
||||
last_message = self.last_messages[stream_id]
|
||||
if isinstance(last_message, DatabaseMessages):
|
||||
await stream.set_context(last_message)
|
||||
|
||||
return stream
|
||||
|
||||
def get_stream_by_info(
|
||||
@@ -536,30 +575,30 @@ class ChatManager:
|
||||
|
||||
Returns:
|
||||
dict[str, ChatStream]: 包含所有聊天流的字典,key为stream_id,value为ChatStream对象
|
||||
|
||||
"""
|
||||
return self.streams.copy() # 返回副本以防止外部修改
|
||||
return self.streams
|
||||
|
||||
@staticmethod
|
||||
def _prepare_stream_data(stream_data_dict: dict) -> dict:
|
||||
"""准备聊天流保存数据"""
|
||||
user_info_d = stream_data_dict.get("user_info")
|
||||
group_info_d = stream_data_dict.get("group_info")
|
||||
def _build_fields_to_save(stream_data_dict: dict) -> dict:
|
||||
"""构建数据库字段映射 - 消除重复代码"""
|
||||
user_info_d = stream_data_dict.get("user_info") or {}
|
||||
group_info_d = stream_data_dict.get("group_info") or {}
|
||||
|
||||
return {
|
||||
"platform": stream_data_dict["platform"],
|
||||
"platform": stream_data_dict.get("platform", "") or "",
|
||||
"create_time": stream_data_dict["create_time"],
|
||||
"last_active_time": stream_data_dict["last_active_time"],
|
||||
"user_platform": user_info_d["platform"] if user_info_d else "",
|
||||
"user_id": user_info_d["user_id"] if user_info_d else "",
|
||||
"user_nickname": user_info_d["user_nickname"] if user_info_d else "",
|
||||
"user_cardname": user_info_d.get("user_cardname", "") if user_info_d else None,
|
||||
"group_platform": group_info_d["platform"] if group_info_d else "",
|
||||
"group_id": group_info_d["group_id"] if group_info_d else "",
|
||||
"group_name": group_info_d["group_name"] if group_info_d else "",
|
||||
"user_platform": user_info_d.get("platform", ""),
|
||||
"user_id": user_info_d.get("user_id", ""),
|
||||
"user_nickname": user_info_d.get("user_nickname", ""),
|
||||
"user_cardname": user_info_d.get("user_cardname"),
|
||||
"group_platform": group_info_d.get("platform", ""),
|
||||
"group_id": group_info_d.get("group_id", ""),
|
||||
"group_name": group_info_d.get("group_name", ""),
|
||||
"energy_value": stream_data_dict.get("energy_value", 5.0),
|
||||
"sleep_pressure": stream_data_dict.get("sleep_pressure", 0.0),
|
||||
"focus_energy": stream_data_dict.get("focus_energy", 0.5),
|
||||
# 新增动态兴趣度系统字段
|
||||
"base_interest_energy": stream_data_dict.get("base_interest_energy", 0.5),
|
||||
"message_interest_total": stream_data_dict.get("message_interest_total", 0.0),
|
||||
"message_count": stream_data_dict.get("message_count", 0),
|
||||
@@ -570,6 +609,11 @@ class ChatManager:
|
||||
"interruption_count": stream_data_dict.get("interruption_count", 0),
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _prepare_stream_data(stream_data_dict: dict) -> dict:
|
||||
"""准备聊天流保存数据 - 调用统一的字段构建方法"""
|
||||
return ChatManager._build_fields_to_save(stream_data_dict)
|
||||
|
||||
@staticmethod
|
||||
async def _save_stream(stream: ChatStream):
|
||||
"""保存聊天流到数据库 - 优化版本使用异步批量写入"""
|
||||
@@ -624,38 +668,12 @@ class ChatManager:
|
||||
raise RuntimeError("Global config is not initialized")
|
||||
|
||||
async with get_db_session() as session:
|
||||
user_info_d = s_data_dict.get("user_info")
|
||||
group_info_d = s_data_dict.get("group_info")
|
||||
fields_to_save = {
|
||||
"platform": s_data_dict.get("platform", "") or "",
|
||||
"create_time": s_data_dict["create_time"],
|
||||
"last_active_time": s_data_dict["last_active_time"],
|
||||
"user_platform": user_info_d["platform"] if user_info_d else "",
|
||||
"user_id": user_info_d["user_id"] if user_info_d else "",
|
||||
"user_nickname": user_info_d["user_nickname"] if user_info_d else "",
|
||||
"user_cardname": user_info_d.get("user_cardname", "") if user_info_d else None,
|
||||
"group_platform": group_info_d.get("platform", "") or "" if group_info_d else "",
|
||||
"group_id": group_info_d["group_id"] if group_info_d else "",
|
||||
"group_name": group_info_d["group_name"] if group_info_d else "",
|
||||
"energy_value": s_data_dict.get("energy_value", 5.0),
|
||||
"sleep_pressure": s_data_dict.get("sleep_pressure", 0.0),
|
||||
"focus_energy": s_data_dict.get("focus_energy", 0.5),
|
||||
# 新增动态兴趣度系统字段
|
||||
"base_interest_energy": s_data_dict.get("base_interest_energy", 0.5),
|
||||
"message_interest_total": s_data_dict.get("message_interest_total", 0.0),
|
||||
"message_count": s_data_dict.get("message_count", 0),
|
||||
"action_count": s_data_dict.get("action_count", 0),
|
||||
"reply_count": s_data_dict.get("reply_count", 0),
|
||||
"last_interaction_time": s_data_dict.get("last_interaction_time", time.time()),
|
||||
"consecutive_no_reply": s_data_dict.get("consecutive_no_reply", 0),
|
||||
"interruption_count": s_data_dict.get("interruption_count", 0),
|
||||
}
|
||||
fields_to_save = ChatManager._build_fields_to_save(s_data_dict)
|
||||
if global_config.database.database_type == "sqlite":
|
||||
stmt = sqlite_insert(ChatStreams).values(stream_id=s_data_dict["stream_id"], **fields_to_save)
|
||||
stmt = stmt.on_conflict_do_update(index_elements=["stream_id"], set_=fields_to_save)
|
||||
elif global_config.database.database_type == "postgresql":
|
||||
stmt = pg_insert(ChatStreams).values(stream_id=s_data_dict["stream_id"], **fields_to_save)
|
||||
# PostgreSQL 需要使用 constraint 参数或正确的 index_elements
|
||||
stmt = stmt.on_conflict_do_update(
|
||||
index_elements=[ChatStreams.stream_id],
|
||||
set_=fields_to_save
|
||||
@@ -678,14 +696,16 @@ class ChatManager:
|
||||
await self._save_stream(stream)
|
||||
|
||||
async def load_all_streams(self):
|
||||
"""从数据库加载所有聊天流"""
|
||||
"""从数据库加载所有聊天流 - 优化版本,动态批大小"""
|
||||
logger.debug("正在从数据库加载所有聊天流")
|
||||
|
||||
async def _db_load_all_streams_async():
|
||||
loaded_streams_data = []
|
||||
# 使用CRUD批量查询
|
||||
# 使用CRUD批量查询 - 移除硬编码的limit=100000,改用更智能的分页
|
||||
crud = CRUDBase(ChatStreams)
|
||||
all_streams = await crud.get_multi(limit=100000) # 获取所有聊天流
|
||||
|
||||
# 先获取总数,以优化批处理大小
|
||||
all_streams = await crud.get_multi(limit=None) # 获取所有聊天流
|
||||
|
||||
for model_instance in all_streams:
|
||||
user_info_data = {
|
||||
@@ -733,8 +753,6 @@ class ChatManager:
|
||||
stream.saved = True
|
||||
self.streams[stream.stream_id] = stream
|
||||
# 不在异步加载中设置上下文,避免复杂依赖
|
||||
# if stream.stream_id in self.last_messages:
|
||||
# await stream.set_context(self.last_messages[stream.stream_id])
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"从数据库加载所有聊天流失败 (SQLAlchemy): {e}")
|
||||
|
||||
@@ -30,7 +30,7 @@ from __future__ import annotations
|
||||
import os
|
||||
import re
|
||||
import traceback
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
from typing import TYPE_CHECKING, Any, ClassVar, cast
|
||||
|
||||
from mofox_wire import MessageEnvelope, MessageRuntime
|
||||
|
||||
@@ -53,6 +53,22 @@ logger = get_logger("message_handler")
|
||||
# 项目根目录
|
||||
PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))
|
||||
|
||||
# 预编译的正则表达式缓存(避免重复编译)
|
||||
_compiled_regex_cache: dict[str, re.Pattern] = {}
|
||||
|
||||
# 硬编码过滤关键词(缓存到全局变量,避免每次创建列表)
|
||||
_MEDIA_FAILURE_KEYWORDS = frozenset(["[表情包(描述生成失败)]", "[图片(描述生成失败)]"])
|
||||
|
||||
def _get_compiled_pattern(pattern: str) -> re.Pattern | None:
|
||||
"""获取编译的正则表达式,使用缓存避免重复编译"""
|
||||
if pattern not in _compiled_regex_cache:
|
||||
try:
|
||||
_compiled_regex_cache[pattern] = re.compile(pattern)
|
||||
except re.error as e:
|
||||
logger.warning(f"正则表达式编译失败: {pattern}, 错误: {e}")
|
||||
return None
|
||||
return _compiled_regex_cache.get(pattern)
|
||||
|
||||
def _check_ban_words(text: str, chat: "ChatStream", userinfo) -> bool:
|
||||
"""检查消息是否包含过滤词"""
|
||||
if global_config is None:
|
||||
@@ -65,11 +81,13 @@ def _check_ban_words(text: str, chat: "ChatStream", userinfo) -> bool:
|
||||
return True
|
||||
return False
|
||||
def _check_ban_regex(text: str, chat: "ChatStream", userinfo) -> bool:
|
||||
"""检查消息是否匹配过滤正则表达式"""
|
||||
"""检查消息是否匹配过滤正则表达式 - 优化版本使用预编译缓存"""
|
||||
if global_config is None:
|
||||
return False
|
||||
|
||||
for pattern in global_config.message_receive.ban_msgs_regex:
|
||||
if re.search(pattern, text):
|
||||
compiled_pattern = _get_compiled_pattern(pattern)
|
||||
if compiled_pattern and compiled_pattern.search(text):
|
||||
chat_name = chat.group_info.group_name if chat.group_info else "私聊"
|
||||
logger.info(f"[{chat_name}]{userinfo.user_nickname}:{text}")
|
||||
logger.info(f"[正则表达式过滤]消息匹配到{pattern},filtered")
|
||||
@@ -97,6 +115,10 @@ class MessageHandler:
|
||||
4. 普通消息处理:触发事件、存储、情绪更新
|
||||
"""
|
||||
|
||||
# 类级别缓存:命令查询结果缓存(减少重复查询)
|
||||
_plus_command_cache: ClassVar[dict[str, Any]] = {}
|
||||
_base_command_cache: ClassVar[dict[str, Any]] = {}
|
||||
|
||||
def __init__(self):
|
||||
self._started = False
|
||||
self._message_manager_started = False
|
||||
@@ -108,6 +130,36 @@ class MessageHandler:
|
||||
"""设置 CoreSinkManager 引用"""
|
||||
self._core_sink_manager = manager
|
||||
|
||||
async def _get_or_create_chat_stream(
|
||||
self, platform: str, user_info: dict | None, group_info: dict | None
|
||||
) -> "ChatStream":
|
||||
"""获取或创建聊天流 - 统一方法"""
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
|
||||
return await get_chat_manager().get_or_create_stream(
|
||||
platform=platform,
|
||||
user_info=DatabaseUserInfo.from_dict(cast(dict[str, Any], user_info)) if user_info else None,
|
||||
group_info=DatabaseGroupInfo.from_dict(cast(dict[str, Any], group_info)) if group_info else None,
|
||||
)
|
||||
|
||||
async def _process_message_to_database(
|
||||
self, envelope: MessageEnvelope, chat: "ChatStream"
|
||||
) -> DatabaseMessages:
|
||||
"""将消息信封转换为 DatabaseMessages - 统一方法"""
|
||||
from src.chat.message_receive.message_processor import process_message_from_dict
|
||||
|
||||
message = await process_message_from_dict(
|
||||
message_dict=envelope,
|
||||
stream_id=chat.stream_id,
|
||||
platform=chat.platform
|
||||
)
|
||||
|
||||
# 填充聊天流时间信息
|
||||
message.chat_info.create_time = chat.create_time
|
||||
message.chat_info.last_active_time = chat.last_active_time
|
||||
|
||||
return message
|
||||
|
||||
def register_handlers(self, runtime: MessageRuntime) -> None:
|
||||
"""
|
||||
向 MessageRuntime 注册消息处理器和钩子
|
||||
@@ -279,25 +331,10 @@ class MessageHandler:
|
||||
|
||||
# 获取或创建聊天流
|
||||
platform = message_info.get("platform", "unknown")
|
||||
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
chat = await get_chat_manager().get_or_create_stream(
|
||||
platform=platform,
|
||||
user_info=DatabaseUserInfo.from_dict(cast(dict[str, Any], user_info)) if user_info else None, # type: ignore
|
||||
group_info=DatabaseGroupInfo.from_dict(cast(dict[str, Any], group_info)) if group_info else None,
|
||||
)
|
||||
chat = await self._get_or_create_chat_stream(platform, user_info, group_info)
|
||||
|
||||
# 将消息信封转换为 DatabaseMessages
|
||||
from src.chat.message_receive.message_processor import process_message_from_dict
|
||||
message = await process_message_from_dict(
|
||||
message_dict=envelope,
|
||||
stream_id=chat.stream_id,
|
||||
platform=chat.platform
|
||||
)
|
||||
|
||||
# 填充聊天流时间信息
|
||||
message.chat_info.create_time = chat.create_time
|
||||
message.chat_info.last_active_time = chat.last_active_time
|
||||
message = await self._process_message_to_database(envelope, chat)
|
||||
|
||||
# 标记为 notice 消息
|
||||
message.is_notify = True
|
||||
@@ -337,8 +374,7 @@ class MessageHandler:
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"处理 Notice 消息时出错: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
logger.error(traceback.format_exc())
|
||||
return None
|
||||
|
||||
async def _add_notice_to_manager(
|
||||
@@ -429,25 +465,10 @@ class MessageHandler:
|
||||
|
||||
# 获取或创建聊天流
|
||||
platform = message_info.get("platform", "unknown")
|
||||
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
chat = await get_chat_manager().get_or_create_stream(
|
||||
platform=platform,
|
||||
user_info=DatabaseUserInfo.from_dict(cast(dict[str, Any], user_info)) if user_info else None, # type: ignore
|
||||
group_info=DatabaseGroupInfo.from_dict(cast(dict[str, Any], group_info)) if group_info else None,
|
||||
)
|
||||
chat = await self._get_or_create_chat_stream(platform, user_info, group_info)
|
||||
|
||||
# 将消息信封转换为 DatabaseMessages
|
||||
from src.chat.message_receive.message_processor import process_message_from_dict
|
||||
message = await process_message_from_dict(
|
||||
message_dict=envelope,
|
||||
stream_id=chat.stream_id,
|
||||
platform=chat.platform
|
||||
)
|
||||
|
||||
# 填充聊天流时间信息
|
||||
message.chat_info.create_time = chat.create_time
|
||||
message.chat_info.last_active_time = chat.last_active_time
|
||||
message = await self._process_message_to_database(envelope, chat)
|
||||
|
||||
# 注册消息到聊天管理器
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
@@ -462,9 +483,8 @@ class MessageHandler:
|
||||
logger.info(f"[{chat_name}]{user_nickname}:{message.processed_plain_text}\u001b[0m")
|
||||
|
||||
# 硬编码过滤
|
||||
failure_keywords = ["[表情包(描述生成失败)]", "[图片(描述生成失败)]"]
|
||||
processed_text = message.processed_plain_text or ""
|
||||
if any(keyword in processed_text for keyword in failure_keywords):
|
||||
if any(keyword in processed_text for keyword in _MEDIA_FAILURE_KEYWORDS):
|
||||
logger.info(f"[硬编码过滤] 检测到媒体内容处理失败({processed_text}),消息被静默处理。")
|
||||
return None
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
基于 mofox-wire 的 TypedDict 形式构建消息数据,然后转换为 DatabaseMessages
|
||||
"""
|
||||
import base64
|
||||
import re
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
@@ -20,6 +21,15 @@ from src.config.config import global_config
|
||||
|
||||
logger = get_logger("message_processor")
|
||||
|
||||
# 预编译正则表达式
|
||||
_AT_PATTERN = re.compile(r"^([^:]+):(.+)$")
|
||||
|
||||
# 常量定义:段类型集合
|
||||
RECURSIVE_SEGMENT_TYPES = frozenset(["seglist"])
|
||||
MEDIA_SEGMENT_TYPES = frozenset(["image", "emoji", "voice", "video"])
|
||||
METADATA_SEGMENT_TYPES = frozenset(["mention_bot", "priority_info"])
|
||||
SPECIAL_SEGMENT_TYPES = frozenset(["at", "reply", "file"])
|
||||
|
||||
|
||||
async def process_message_from_dict(message_dict: MessageEnvelope, stream_id: str, platform: str) -> DatabaseMessages:
|
||||
"""从适配器消息字典处理并生成 DatabaseMessages
|
||||
@@ -101,7 +111,7 @@ async def process_message_from_dict(message_dict: MessageEnvelope, stream_id: st
|
||||
mentioned_value = processing_state.get("is_mentioned")
|
||||
if isinstance(mentioned_value, bool):
|
||||
is_mentioned = mentioned_value
|
||||
elif isinstance(mentioned_value, (int, float)):
|
||||
elif isinstance(mentioned_value, int | float):
|
||||
is_mentioned = mentioned_value != 0
|
||||
|
||||
# 使用 TypedDict 风格的数据构建 DatabaseMessages
|
||||
@@ -223,13 +233,12 @@ async def _process_single_segment(
|
||||
state["is_at"] = True
|
||||
# 处理at消息,格式为"@<昵称:QQ号>"
|
||||
if isinstance(seg_data, str):
|
||||
if ":" in seg_data:
|
||||
# 标准格式: "昵称:QQ号"
|
||||
nickname, qq_id = seg_data.split(":", 1)
|
||||
match = _AT_PATTERN.match(seg_data)
|
||||
if match:
|
||||
nickname, qq_id = match.groups()
|
||||
return f"@<{nickname}:{qq_id}>"
|
||||
else:
|
||||
logger.warning(f"[at处理] 无法解析格式: '{seg_data}'")
|
||||
return f"@{seg_data}"
|
||||
logger.warning(f"[at处理] 无法解析格式: '{seg_data}'")
|
||||
return f"@{seg_data}"
|
||||
logger.warning(f"[at处理] 数据类型异常: {type(seg_data)}")
|
||||
return f"@{seg_data}" if isinstance(seg_data, str) else "@未知用户"
|
||||
|
||||
@@ -272,7 +281,7 @@ async def _process_single_segment(
|
||||
return "[发了一段语音,网卡了加载不出来]"
|
||||
|
||||
elif seg_type == "mention_bot":
|
||||
if isinstance(seg_data, (int, float)):
|
||||
if isinstance(seg_data, int | float):
|
||||
state["is_mentioned"] = float(seg_data)
|
||||
return ""
|
||||
|
||||
@@ -368,19 +377,18 @@ def _prepare_additional_config(
|
||||
str | None: JSON 字符串格式的 additional_config,如果为空则返回 None
|
||||
"""
|
||||
try:
|
||||
additional_config_data = {}
|
||||
|
||||
# 首先获取adapter传递的additional_config
|
||||
additional_config_raw = message_info.get("additional_config")
|
||||
if additional_config_raw:
|
||||
if isinstance(additional_config_raw, dict):
|
||||
additional_config_data = additional_config_raw.copy()
|
||||
elif isinstance(additional_config_raw, str):
|
||||
try:
|
||||
additional_config_data = orjson.loads(additional_config_raw)
|
||||
except Exception as e:
|
||||
logger.warning(f"无法解析 additional_config JSON: {e}")
|
||||
additional_config_data = {}
|
||||
if isinstance(additional_config_raw, dict):
|
||||
additional_config_data = additional_config_raw.copy()
|
||||
elif isinstance(additional_config_raw, str):
|
||||
try:
|
||||
additional_config_data = orjson.loads(additional_config_raw)
|
||||
except Exception as e:
|
||||
logger.warning(f"无法解析 additional_config JSON: {e}")
|
||||
additional_config_data = {}
|
||||
else:
|
||||
additional_config_data = {}
|
||||
|
||||
# 添加notice相关标志
|
||||
if is_notify:
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import asyncio
|
||||
import collections
|
||||
import re
|
||||
import time
|
||||
import traceback
|
||||
@@ -19,6 +20,16 @@ if TYPE_CHECKING:
|
||||
|
||||
logger = get_logger("message_storage")
|
||||
|
||||
# 预编译的正则表达式(避免重复编译)
|
||||
_COMPILED_FILTER_PATTERN = re.compile(
|
||||
r"<MainRule>.*?</MainRule>|<schedule>.*?</schedule>|<UserMessage>.*?</UserMessage>",
|
||||
re.DOTALL
|
||||
)
|
||||
_COMPILED_IMAGE_PATTERN = re.compile(r"\[图片:([^\]]+)\]")
|
||||
|
||||
# 全局正则表达式缓存
|
||||
_regex_cache: dict[str, re.Pattern] = {}
|
||||
|
||||
|
||||
class MessageStorageBatcher:
|
||||
"""
|
||||
@@ -116,25 +127,28 @@ class MessageStorageBatcher:
|
||||
async def flush(self, force: bool = False):
|
||||
"""执行批量写入, 支持强制落库和延迟提交策略。"""
|
||||
async with self._flush_barrier:
|
||||
# 原子性地交换消息队列,避免锁定时间过长
|
||||
async with self._lock:
|
||||
messages_to_store = list(self.pending_messages)
|
||||
self.pending_messages.clear()
|
||||
if not self.pending_messages:
|
||||
return
|
||||
messages_to_store = self.pending_messages
|
||||
self.pending_messages = collections.deque(maxlen=self.batch_size)
|
||||
|
||||
if messages_to_store:
|
||||
prepared_messages: list[dict[str, Any]] = []
|
||||
for msg_data in messages_to_store:
|
||||
try:
|
||||
message_dict = await self._prepare_message_dict(
|
||||
msg_data["message"],
|
||||
msg_data["chat_stream"],
|
||||
)
|
||||
if message_dict:
|
||||
prepared_messages.append(message_dict)
|
||||
except Exception as e:
|
||||
logger.error(f"准备消息数据失败: {e}")
|
||||
# 处理消息,这部分不在锁内执行,提高并发性
|
||||
prepared_messages: list[dict[str, Any]] = []
|
||||
for msg_data in messages_to_store:
|
||||
try:
|
||||
message_dict = await self._prepare_message_dict(
|
||||
msg_data["message"],
|
||||
msg_data["chat_stream"],
|
||||
)
|
||||
if message_dict:
|
||||
prepared_messages.append(message_dict)
|
||||
except Exception as e:
|
||||
logger.error(f"准备消息数据失败: {e}")
|
||||
|
||||
if prepared_messages:
|
||||
self._prepared_buffer.extend(prepared_messages)
|
||||
if prepared_messages:
|
||||
self._prepared_buffer.extend(prepared_messages)
|
||||
|
||||
await self._maybe_commit_buffer(force=force)
|
||||
|
||||
@@ -200,102 +214,66 @@ class MessageStorageBatcher:
|
||||
return message_dict
|
||||
|
||||
async def _prepare_message_object(self, message, chat_stream):
|
||||
"""准备消息对象(从原 store_message 逻辑提取)"""
|
||||
"""准备消息对象(从原 store_message 逻辑提取) - 优化版本"""
|
||||
try:
|
||||
pattern = r"<MainRule>.*?</MainRule>|<schedule>.*?</schedule>|<UserMessage>.*?</UserMessage>"
|
||||
|
||||
if not isinstance(message, DatabaseMessages):
|
||||
logger.error("MessageStorageBatcher expects DatabaseMessages instances")
|
||||
return None
|
||||
|
||||
# 优化:使用预编译的正则表达式
|
||||
processed_plain_text = message.processed_plain_text or ""
|
||||
if processed_plain_text:
|
||||
processed_plain_text = await MessageStorage.replace_image_descriptions(processed_plain_text)
|
||||
filtered_processed_plain_text = re.sub(
|
||||
pattern, "", processed_plain_text or "", flags=re.DOTALL
|
||||
)
|
||||
filtered_processed_plain_text = _COMPILED_FILTER_PATTERN.sub("", processed_plain_text)
|
||||
|
||||
display_message = message.display_message or message.processed_plain_text or ""
|
||||
filtered_display_message = re.sub(pattern, "", display_message, flags=re.DOTALL)
|
||||
filtered_display_message = _COMPILED_FILTER_PATTERN.sub("", display_message)
|
||||
|
||||
msg_id = message.message_id
|
||||
msg_time = message.time
|
||||
chat_id = message.chat_id
|
||||
reply_to = message.reply_to or ""
|
||||
is_mentioned = message.is_mentioned
|
||||
interest_value = message.interest_value or 0.0
|
||||
priority_mode = message.priority_mode
|
||||
priority_info_json = message.priority_info
|
||||
is_emoji = message.is_emoji or False
|
||||
is_picid = message.is_picid or False
|
||||
is_notify = message.is_notify or False
|
||||
is_command = message.is_command or False
|
||||
is_public_notice = message.is_public_notice or False
|
||||
notice_type = message.notice_type
|
||||
actions = orjson.dumps(message.actions).decode("utf-8") if message.actions else None
|
||||
should_reply = message.should_reply
|
||||
should_act = message.should_act
|
||||
additional_config = message.additional_config
|
||||
key_words = MessageStorage._serialize_keywords(message.key_words)
|
||||
key_words_lite = MessageStorage._serialize_keywords(message.key_words_lite)
|
||||
memorized_times = getattr(message, "memorized_times", 0)
|
||||
|
||||
user_platform = message.user_info.platform if message.user_info else ""
|
||||
user_id = message.user_info.user_id if message.user_info else ""
|
||||
user_nickname = message.user_info.user_nickname if message.user_info else ""
|
||||
user_cardname = message.user_info.user_cardname if message.user_info else None
|
||||
|
||||
chat_info_stream_id = message.chat_info.stream_id if message.chat_info else ""
|
||||
chat_info_platform = message.chat_info.platform if message.chat_info else ""
|
||||
chat_info_create_time = message.chat_info.create_time if message.chat_info else 0.0
|
||||
chat_info_last_active_time = message.chat_info.last_active_time if message.chat_info else 0.0
|
||||
chat_info_user_platform = message.chat_info.user_info.platform if message.chat_info and message.chat_info.user_info else ""
|
||||
chat_info_user_id = message.chat_info.user_info.user_id if message.chat_info and message.chat_info.user_info else ""
|
||||
chat_info_user_nickname = message.chat_info.user_info.user_nickname if message.chat_info and message.chat_info.user_info else ""
|
||||
chat_info_user_cardname = message.chat_info.user_info.user_cardname if message.chat_info and message.chat_info.user_info else None
|
||||
chat_info_group_platform = message.group_info.platform if message.group_info else None
|
||||
chat_info_group_id = message.group_info.group_id if message.group_info else None
|
||||
chat_info_group_name = message.group_info.group_name if message.group_info else None
|
||||
# 优化:一次性构建字典,避免多次条件判断
|
||||
user_info = message.user_info or {}
|
||||
chat_info = message.chat_info or {}
|
||||
chat_info_user = chat_info.user_info or {} if chat_info else {}
|
||||
group_info = message.group_info or {}
|
||||
|
||||
return Messages(
|
||||
message_id=msg_id,
|
||||
time=msg_time,
|
||||
chat_id=chat_id,
|
||||
reply_to=reply_to,
|
||||
is_mentioned=is_mentioned,
|
||||
chat_info_stream_id=chat_info_stream_id,
|
||||
chat_info_platform=chat_info_platform,
|
||||
chat_info_user_platform=chat_info_user_platform,
|
||||
chat_info_user_id=chat_info_user_id,
|
||||
chat_info_user_nickname=chat_info_user_nickname,
|
||||
chat_info_user_cardname=chat_info_user_cardname,
|
||||
chat_info_group_platform=chat_info_group_platform,
|
||||
chat_info_group_id=chat_info_group_id,
|
||||
chat_info_group_name=chat_info_group_name,
|
||||
chat_info_create_time=chat_info_create_time,
|
||||
chat_info_last_active_time=chat_info_last_active_time,
|
||||
user_platform=user_platform,
|
||||
user_id=user_id,
|
||||
user_nickname=user_nickname,
|
||||
user_cardname=user_cardname,
|
||||
message_id=message.message_id,
|
||||
time=message.time,
|
||||
chat_id=message.chat_id,
|
||||
reply_to=message.reply_to or "",
|
||||
is_mentioned=message.is_mentioned,
|
||||
chat_info_stream_id=chat_info.stream_id if chat_info else "",
|
||||
chat_info_platform=chat_info.platform if chat_info else "",
|
||||
chat_info_user_platform=chat_info_user.platform if chat_info_user else "",
|
||||
chat_info_user_id=chat_info_user.user_id if chat_info_user else "",
|
||||
chat_info_user_nickname=chat_info_user.user_nickname if chat_info_user else "",
|
||||
chat_info_user_cardname=chat_info_user.user_cardname if chat_info_user else None,
|
||||
chat_info_group_platform=group_info.platform if group_info else None,
|
||||
chat_info_group_id=group_info.group_id if group_info else None,
|
||||
chat_info_group_name=group_info.group_name if group_info else None,
|
||||
chat_info_create_time=chat_info.create_time if chat_info else 0.0,
|
||||
chat_info_last_active_time=chat_info.last_active_time if chat_info else 0.0,
|
||||
user_platform=user_info.platform if user_info else "",
|
||||
user_id=user_info.user_id if user_info else "",
|
||||
user_nickname=user_info.user_nickname if user_info else "",
|
||||
user_cardname=user_info.user_cardname if user_info else None,
|
||||
processed_plain_text=filtered_processed_plain_text,
|
||||
display_message=filtered_display_message,
|
||||
memorized_times=memorized_times,
|
||||
interest_value=interest_value,
|
||||
priority_mode=priority_mode,
|
||||
priority_info=priority_info_json,
|
||||
additional_config=additional_config,
|
||||
is_emoji=is_emoji,
|
||||
is_picid=is_picid,
|
||||
is_notify=is_notify,
|
||||
is_command=is_command,
|
||||
is_public_notice=is_public_notice,
|
||||
notice_type=notice_type,
|
||||
actions=actions,
|
||||
should_reply=should_reply,
|
||||
should_act=should_act,
|
||||
key_words=key_words,
|
||||
key_words_lite=key_words_lite,
|
||||
memorized_times=getattr(message, "memorized_times", 0),
|
||||
interest_value=message.interest_value or 0.0,
|
||||
priority_mode=message.priority_mode,
|
||||
priority_info=message.priority_info,
|
||||
additional_config=message.additional_config,
|
||||
is_emoji=message.is_emoji or False,
|
||||
is_picid=message.is_picid or False,
|
||||
is_notify=message.is_notify or False,
|
||||
is_command=message.is_command or False,
|
||||
is_public_notice=message.is_public_notice or False,
|
||||
notice_type=message.notice_type,
|
||||
actions=orjson.dumps(message.actions).decode("utf-8") if message.actions else None,
|
||||
should_reply=message.should_reply,
|
||||
should_act=message.should_act,
|
||||
key_words=MessageStorage._serialize_keywords(message.key_words),
|
||||
key_words_lite=MessageStorage._serialize_keywords(message.key_words_lite),
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
@@ -474,7 +452,7 @@ class MessageStorage:
|
||||
@staticmethod
|
||||
async def update_message(message_data: dict, use_batch: bool = True):
|
||||
"""
|
||||
更新消息ID(从消息字典)
|
||||
更新消息ID(从消息字典)- 优化版本
|
||||
|
||||
优化: 添加批处理选项,将多个更新操作合并,减少数据库连接
|
||||
|
||||
@@ -491,25 +469,23 @@ class MessageStorage:
|
||||
segment_type = message_segment.get("type") if isinstance(message_segment, dict) else None
|
||||
segment_data = message_segment.get("data", {}) if isinstance(message_segment, dict) else {}
|
||||
|
||||
qq_message_id = None
|
||||
# 优化:预定义类型集合,避免重复的 if-elif 检查
|
||||
SKIPPED_TYPES = {"adapter_response", "adapter_command"}
|
||||
VALID_ID_TYPES = {"notify", "text", "reply"}
|
||||
|
||||
logger.debug(f"尝试更新消息ID: {mmc_message_id}, 消息段类型: {segment_type}")
|
||||
|
||||
# 根据消息段类型提取message_id
|
||||
if segment_type == "notify":
|
||||
# 检查是否是需要跳过的类型
|
||||
if segment_type in SKIPPED_TYPES:
|
||||
logger.debug(f"跳过消息段类型: {segment_type}")
|
||||
return
|
||||
|
||||
# 尝试获取消息ID
|
||||
qq_message_id = None
|
||||
if segment_type in VALID_ID_TYPES:
|
||||
qq_message_id = segment_data.get("id")
|
||||
elif segment_type == "text":
|
||||
qq_message_id = segment_data.get("id")
|
||||
elif segment_type == "reply":
|
||||
qq_message_id = segment_data.get("id")
|
||||
if qq_message_id:
|
||||
if segment_type == "reply" and qq_message_id:
|
||||
logger.debug(f"从reply消息段获取到消息ID: {qq_message_id}")
|
||||
elif segment_type == "adapter_response":
|
||||
logger.debug("适配器响应消息,不需要更新ID")
|
||||
return
|
||||
elif segment_type == "adapter_command":
|
||||
logger.debug("适配器命令消息,不需要更新ID")
|
||||
return
|
||||
else:
|
||||
logger.debug(f"未知的消息段类型: {segment_type},跳过ID更新")
|
||||
return
|
||||
@@ -552,22 +528,20 @@ class MessageStorage:
|
||||
|
||||
@staticmethod
|
||||
async def replace_image_descriptions(text: str) -> str:
|
||||
"""异步地将文本中的所有[图片:描述]标记替换为[picid:image_id]"""
|
||||
pattern = r"\[图片:([^\]]+)\]"
|
||||
|
||||
"""异步地将文本中的所有[图片:描述]标记替换为[picid:image_id] - 优化版本"""
|
||||
# 如果没有匹配项,提前返回以提高效率
|
||||
if not re.search(pattern, text):
|
||||
if not _COMPILED_IMAGE_PATTERN.search(text):
|
||||
return text
|
||||
|
||||
# re.sub不支持异步替换函数,所以我们需要手动迭代和替换
|
||||
new_text = []
|
||||
last_end = 0
|
||||
for match in re.finditer(pattern, text):
|
||||
for match in _COMPILED_IMAGE_PATTERN.finditer(text):
|
||||
# 添加上一个匹配到当前匹配之间的文本
|
||||
new_text.append(text[last_end:match.start()])
|
||||
|
||||
description = match.group(1).strip()
|
||||
replacement = match.group(0) # 默认情况下,替换为原始匹配文本
|
||||
replacement = match.group(0) # 默认情况下,替换为原始匹配文本
|
||||
try:
|
||||
async with get_db_session() as session:
|
||||
# 查询数据库以找到具有该描述的最新图片记录
|
||||
@@ -633,19 +607,49 @@ class MessageStorage:
|
||||
interest_map: dict[str, float],
|
||||
reply_map: dict[str, bool] | None = None,
|
||||
) -> None:
|
||||
"""批量更新消息的兴趣度与回复标记"""
|
||||
"""批量更新消息的兴趣度与回复标记 - 优化版本"""
|
||||
if not interest_map:
|
||||
return
|
||||
|
||||
try:
|
||||
async with get_db_session() as session:
|
||||
for message_id, interest_value in interest_map.items():
|
||||
values = {"interest_value": interest_value}
|
||||
if reply_map and message_id in reply_map:
|
||||
values["should_reply"] = reply_map[message_id]
|
||||
# 注意:SQLAlchemy 2.0 对 ORM update + executemany 会走
|
||||
# “Bulk UPDATE by Primary Key” 路径,要求每行参数包含主键(Messages.id)。
|
||||
# 这里我们按 message_id 更新,因此使用 Core Table + bindparam。
|
||||
from sqlalchemy import bindparam, update
|
||||
|
||||
stmt = update(Messages).where(Messages.message_id == message_id).values(**values)
|
||||
await session.execute(stmt)
|
||||
messages_table = Messages.__table__
|
||||
|
||||
interest_mappings: list[dict[str, Any]] = [
|
||||
{"b_message_id": message_id, "b_interest_value": interest_value}
|
||||
for message_id, interest_value in interest_map.items()
|
||||
]
|
||||
|
||||
if interest_mappings:
|
||||
stmt_interest = (
|
||||
update(messages_table)
|
||||
.where(messages_table.c.message_id == bindparam("b_message_id"))
|
||||
.values(interest_value=bindparam("b_interest_value"))
|
||||
)
|
||||
await session.execute(stmt_interest, interest_mappings)
|
||||
|
||||
if reply_map:
|
||||
reply_mappings: list[dict[str, Any]] = [
|
||||
{"b_message_id": message_id, "b_should_reply": should_reply}
|
||||
for message_id, should_reply in reply_map.items()
|
||||
if message_id in interest_map
|
||||
]
|
||||
if reply_mappings and len(reply_mappings) != len(reply_map):
|
||||
logger.debug(
|
||||
f"批量更新 should_reply 过滤了 {len(reply_map) - len(reply_mappings)} 条不在兴趣度更新集合中的记录"
|
||||
)
|
||||
if reply_mappings:
|
||||
stmt_reply = (
|
||||
update(messages_table)
|
||||
.where(messages_table.c.message_id == bindparam("b_message_id"))
|
||||
.values(should_reply=bindparam("b_should_reply"))
|
||||
)
|
||||
await session.execute(stmt_reply, reply_mappings)
|
||||
|
||||
await session.commit()
|
||||
logger.debug(f"批量更新兴趣度 {len(interest_map)} 条记录")
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
from threading import Lock
|
||||
from typing import Any, Literal
|
||||
|
||||
from pydantic import Field
|
||||
from pydantic import Field, PrivateAttr
|
||||
|
||||
from src.config.config_base import ValidatedConfigBase
|
||||
from src.config.official_configs import InnerConfig
|
||||
|
||||
|
||||
class APIProvider(ValidatedConfigBase):
|
||||
@@ -21,6 +22,9 @@ class APIProvider(ValidatedConfigBase):
|
||||
)
|
||||
retry_interval: int = Field(default=10, ge=0, description="重试间隔(如果API调用失败,重试的间隔时间,单位:秒)")
|
||||
|
||||
_api_key_lock: Lock = PrivateAttr(default_factory=Lock)
|
||||
_api_key_index: int = PrivateAttr(default=0)
|
||||
|
||||
@classmethod
|
||||
def validate_base_url(cls, v):
|
||||
"""验证base_url,确保URL格式正确"""
|
||||
@@ -44,11 +48,6 @@ class APIProvider(ValidatedConfigBase):
|
||||
raise ValueError("API密钥必须是字符串或字符串列表")
|
||||
return v
|
||||
|
||||
def __init__(self, **data):
|
||||
super().__init__(**data)
|
||||
self._api_key_lock = Lock()
|
||||
self._api_key_index = 0
|
||||
|
||||
def get_api_key(self) -> str:
|
||||
with self._api_key_lock:
|
||||
if isinstance(self.api_key, str):
|
||||
@@ -134,6 +133,7 @@ class ModelTaskConfig(ValidatedConfigBase):
|
||||
replyer_private: TaskConfig = Field(..., description="normal_chat首要回复模型模型配置(私聊使用)")
|
||||
maizone: TaskConfig = Field(..., description="maizone专用模型")
|
||||
emotion: TaskConfig = Field(..., description="情绪模型配置")
|
||||
mood: TaskConfig = Field(..., description="心情模型配置")
|
||||
vlm: TaskConfig = Field(..., description="视觉语言模型配置")
|
||||
voice: TaskConfig = Field(..., description="语音识别模型配置")
|
||||
tool_use: TaskConfig = Field(..., description="专注工具使用模型配置")
|
||||
@@ -178,14 +178,26 @@ class ModelTaskConfig(ValidatedConfigBase):
|
||||
class APIAdapterConfig(ValidatedConfigBase):
|
||||
"""API Adapter配置类"""
|
||||
|
||||
inner: InnerConfig = Field(..., description="配置元信息")
|
||||
models: list[ModelInfo] = Field(..., min_length=1, description="模型列表")
|
||||
model_task_config: ModelTaskConfig = Field(..., description="模型任务配置")
|
||||
api_providers: list[APIProvider] = Field(..., min_length=1, description="API提供商列表")
|
||||
|
||||
_api_providers_dict: dict[str, APIProvider] = PrivateAttr(default_factory=dict)
|
||||
_models_dict: dict[str, ModelInfo] = PrivateAttr(default_factory=dict)
|
||||
|
||||
def __init__(self, **data):
|
||||
super().__init__(**data)
|
||||
self.api_providers_dict = {provider.name: provider for provider in self.api_providers}
|
||||
self.models_dict = {model.name: model for model in self.models}
|
||||
self._api_providers_dict = {provider.name: provider for provider in self.api_providers}
|
||||
self._models_dict = {model.name: model for model in self.models}
|
||||
|
||||
@property
|
||||
def api_providers_dict(self) -> dict[str, APIProvider]:
|
||||
return self._api_providers_dict
|
||||
|
||||
@property
|
||||
def models_dict(self) -> dict[str, ModelInfo]:
|
||||
return self._models_dict
|
||||
|
||||
@classmethod
|
||||
def validate_models_list(cls, v):
|
||||
|
||||
@@ -1,10 +1,14 @@
|
||||
import os
|
||||
import shutil
|
||||
import sys
|
||||
import typing
|
||||
import types
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any, get_args, get_origin
|
||||
|
||||
import tomlkit
|
||||
from pydantic import Field
|
||||
from pydantic import BaseModel, Field, PrivateAttr
|
||||
from rich.traceback import install
|
||||
from tomlkit import TOMLDocument
|
||||
from tomlkit.items import KeyType, Table
|
||||
@@ -25,6 +29,8 @@ from src.config.official_configs import (
|
||||
EmojiConfig,
|
||||
ExperimentalConfig,
|
||||
ExpressionConfig,
|
||||
InnerConfig,
|
||||
LogConfig,
|
||||
KokoroFlowChatterConfig,
|
||||
LPMMKnowledgeConfig,
|
||||
MemoryConfig,
|
||||
@@ -65,7 +71,7 @@ TEMPLATE_DIR = os.path.join(PROJECT_ROOT, "template")
|
||||
|
||||
# 考虑到,实际上配置文件中的mai_version是不会自动更新的,所以采用硬编码
|
||||
# 对该字段的更新,请严格参照语义化版本规范:https://semver.org/lang/zh-CN/
|
||||
MMC_VERSION = "0.13.1-alpha.1"
|
||||
MMC_VERSION = "0.13.1-alpha.2"
|
||||
|
||||
# 全局配置变量
|
||||
_CONFIG_INITIALIZED = False
|
||||
@@ -180,6 +186,76 @@ def _remove_obsolete_keys(target: TOMLDocument | dict | Table, reference: TOMLDo
|
||||
_remove_obsolete_keys(target[key], reference[key]) # type: ignore
|
||||
|
||||
|
||||
def _prune_unknown_keys_by_schema(target: TOMLDocument | Table, schema_model: type[BaseModel]):
|
||||
"""
|
||||
基于 Pydantic Schema 递归移除未知配置键(含可重复的 AoT 表)。
|
||||
|
||||
说明:
|
||||
- 只移除 schema 中不存在的键,避免跨版本遗留废弃配置项。
|
||||
- 对于 list[BaseModel] 字段(TOML 的 [[...]]),会遍历每个元素并递归清理。
|
||||
- 对于 dict[str, Any] 等自由结构字段,不做键级裁剪。
|
||||
"""
|
||||
|
||||
def _strip_optional(annotation: Any) -> Any:
|
||||
origin = get_origin(annotation)
|
||||
if origin is None:
|
||||
return annotation
|
||||
|
||||
# 兼容 | None 与 Union[..., None]
|
||||
union_type = getattr(types, "UnionType", None)
|
||||
if origin is union_type or origin is typing.Union:
|
||||
args = [a for a in get_args(annotation) if a is not type(None)]
|
||||
if len(args) == 1:
|
||||
return args[0]
|
||||
return annotation
|
||||
|
||||
def _is_model_type(annotation: Any) -> bool:
|
||||
return isinstance(annotation, type) and issubclass(annotation, BaseModel)
|
||||
|
||||
def _prune_table(table: TOMLDocument | Table, model: type[BaseModel]):
|
||||
name_by_key: dict[str, str] = {}
|
||||
allowed_keys: set[str] = set()
|
||||
|
||||
for field_name, field_info in model.model_fields.items():
|
||||
allowed_keys.add(field_name)
|
||||
name_by_key[field_name] = field_name
|
||||
|
||||
alias = getattr(field_info, "alias", None)
|
||||
if isinstance(alias, str) and alias:
|
||||
allowed_keys.add(alias)
|
||||
name_by_key[alias] = field_name
|
||||
|
||||
for key in list(table.keys()):
|
||||
if key not in allowed_keys:
|
||||
del table[key]
|
||||
continue
|
||||
|
||||
field_name = name_by_key[key]
|
||||
field_info = model.model_fields[field_name]
|
||||
annotation = _strip_optional(getattr(field_info, "annotation", Any))
|
||||
|
||||
value = table.get(key)
|
||||
if value is None:
|
||||
continue
|
||||
|
||||
if _is_model_type(annotation) and isinstance(value, (TOMLDocument, Table)):
|
||||
_prune_table(value, annotation)
|
||||
continue
|
||||
|
||||
origin = get_origin(annotation)
|
||||
if origin is list:
|
||||
args = get_args(annotation)
|
||||
elem_ann = _strip_optional(args[0]) if args else Any
|
||||
|
||||
# list[BaseModel] 对应 TOML 的 AoT([[...]])
|
||||
if _is_model_type(elem_ann) and hasattr(value, "__iter__"):
|
||||
for item in value:
|
||||
if isinstance(item, (TOMLDocument, Table)):
|
||||
_prune_table(item, elem_ann)
|
||||
|
||||
_prune_table(target, schema_model)
|
||||
|
||||
|
||||
def _update_dict(target: TOMLDocument | dict | Table, source: TOMLDocument | dict):
|
||||
"""
|
||||
将source字典的值更新到target字典中
|
||||
@@ -232,13 +308,14 @@ def _update_dict(target: TOMLDocument | dict | Table, source: TOMLDocument | dic
|
||||
target[key] = value
|
||||
|
||||
|
||||
def _update_config_generic(config_name: str, template_name: str):
|
||||
def _update_config_generic(config_name: str, template_name: str, schema_model: type[BaseModel] | None = None):
|
||||
"""
|
||||
通用的配置文件更新函数
|
||||
|
||||
Args:
|
||||
config_name: 配置文件名(不含扩展名),如 'bot_config' 或 'model_config'
|
||||
template_name: 模板文件名(不含扩展名),如 'bot_config_template' 或 'model_config_template'
|
||||
schema_model: 用于裁剪未知键的 Pydantic 模型(避免跨版本遗留废弃配置项)
|
||||
"""
|
||||
# 获取根目录路径
|
||||
old_config_dir = os.path.join(CONFIG_DIR, "old")
|
||||
@@ -355,11 +432,14 @@ def _update_config_generic(config_name: str, template_name: str):
|
||||
logger.info(f"开始合并{config_name}新旧配置...")
|
||||
_update_dict(new_config, old_config)
|
||||
|
||||
# 移除在新模板中已不存在的旧配置项
|
||||
# 移除未知/废弃的旧配置项(尤其是可重复的 [[...]] 段落)
|
||||
logger.info(f"开始移除{config_name}中已废弃的配置项...")
|
||||
with open(template_path, encoding="utf-8") as f:
|
||||
template_doc = tomlkit.load(f)
|
||||
_remove_obsolete_keys(new_config, template_doc)
|
||||
if schema_model is not None:
|
||||
_prune_unknown_keys_by_schema(new_config, schema_model)
|
||||
else:
|
||||
with open(template_path, encoding="utf-8") as f:
|
||||
template_doc = tomlkit.load(f)
|
||||
_remove_obsolete_keys(new_config, template_doc)
|
||||
logger.info(f"已移除{config_name}中已废弃的配置项")
|
||||
|
||||
# 保存更新后的配置(保留注释和格式)
|
||||
@@ -370,18 +450,18 @@ def _update_config_generic(config_name: str, template_name: str):
|
||||
|
||||
def update_config():
|
||||
"""更新bot_config.toml配置文件"""
|
||||
_update_config_generic("bot_config", "bot_config_template")
|
||||
_update_config_generic("bot_config", "bot_config_template", schema_model=Config)
|
||||
|
||||
|
||||
def update_model_config():
|
||||
"""更新model_config.toml配置文件"""
|
||||
_update_config_generic("model_config", "model_config_template")
|
||||
_update_config_generic("model_config", "model_config_template", schema_model=APIAdapterConfig)
|
||||
|
||||
|
||||
class Config(ValidatedConfigBase):
|
||||
"""总配置类"""
|
||||
|
||||
MMC_VERSION: str = Field(default=MMC_VERSION, description="MaiCore版本号")
|
||||
inner: InnerConfig = Field(..., description="配置元信息")
|
||||
|
||||
database: DatabaseConfig = Field(..., description="数据库配置")
|
||||
bot: BotConfig = Field(..., description="机器人基本配置")
|
||||
@@ -397,6 +477,7 @@ class Config(ValidatedConfigBase):
|
||||
chinese_typo: ChineseTypoConfig = Field(..., description="中文错别字配置")
|
||||
response_post_process: ResponsePostProcessConfig = Field(..., description="响应后处理配置")
|
||||
response_splitter: ResponseSplitterConfig = Field(..., description="响应分割配置")
|
||||
log: LogConfig = Field(..., description="日志配置")
|
||||
experimental: ExperimentalConfig = Field(default_factory=lambda: ExperimentalConfig(), description="实验性功能配置")
|
||||
message_bus: MessageBusConfig = Field(..., description="消息总线配置")
|
||||
lpmm_knowledge: LPMMKnowledgeConfig = Field(..., description="LPMM知识配置")
|
||||
@@ -433,18 +514,34 @@ class Config(ValidatedConfigBase):
|
||||
default_factory=lambda: PluginHttpSystemConfig(), description="插件HTTP端点系统配置"
|
||||
)
|
||||
|
||||
@property
|
||||
def MMC_VERSION(self) -> str: # noqa: N802
|
||||
return MMC_VERSION
|
||||
|
||||
|
||||
class APIAdapterConfig(ValidatedConfigBase):
|
||||
"""API Adapter配置类"""
|
||||
|
||||
inner: InnerConfig = Field(..., description="配置元信息")
|
||||
models: list[ModelInfo] = Field(..., min_length=1, description="模型列表")
|
||||
model_task_config: ModelTaskConfig = Field(..., description="模型任务配置")
|
||||
api_providers: list[APIProvider] = Field(..., min_length=1, description="API提供商列表")
|
||||
|
||||
_api_providers_dict: dict[str, APIProvider] = PrivateAttr(default_factory=dict)
|
||||
_models_dict: dict[str, ModelInfo] = PrivateAttr(default_factory=dict)
|
||||
|
||||
def __init__(self, **data):
|
||||
super().__init__(**data)
|
||||
self.api_providers_dict = {provider.name: provider for provider in self.api_providers}
|
||||
self.models_dict = {model.name: model for model in self.models}
|
||||
self._api_providers_dict = {provider.name: provider for provider in self.api_providers}
|
||||
self._models_dict = {model.name: model for model in self.models}
|
||||
|
||||
@property
|
||||
def api_providers_dict(self) -> dict[str, APIProvider]:
|
||||
return self._api_providers_dict
|
||||
|
||||
@property
|
||||
def models_dict(self) -> dict[str, ModelInfo]:
|
||||
return self._models_dict
|
||||
|
||||
@classmethod
|
||||
def validate_models_list(cls, v):
|
||||
@@ -502,9 +599,14 @@ def load_config(config_path: str) -> Config:
|
||||
Returns:
|
||||
Config对象
|
||||
"""
|
||||
# 读取配置文件
|
||||
with open(config_path, encoding="utf-8") as f:
|
||||
config_data = tomlkit.load(f)
|
||||
# 读取配置文件(会自动删除未知/废弃配置项)
|
||||
original_text = Path(config_path).read_text(encoding="utf-8")
|
||||
config_data = tomlkit.parse(original_text)
|
||||
_prune_unknown_keys_by_schema(config_data, Config)
|
||||
new_text = tomlkit.dumps(config_data)
|
||||
if new_text != original_text:
|
||||
Path(config_path).write_text(new_text, encoding="utf-8")
|
||||
logger.warning(f"已自动移除 {config_path} 中未知/废弃配置项")
|
||||
|
||||
# 将 tomlkit 对象转换为纯 Python 字典,避免 Pydantic 严格模式下的类型验证问题
|
||||
# tomlkit 返回的是特殊类型(如 Array、String 等),虽然继承自 Python 标准类型,
|
||||
@@ -530,11 +632,16 @@ def api_ada_load_config(config_path: str) -> APIAdapterConfig:
|
||||
Returns:
|
||||
APIAdapterConfig对象
|
||||
"""
|
||||
# 读取配置文件
|
||||
with open(config_path, encoding="utf-8") as f:
|
||||
config_data = tomlkit.load(f)
|
||||
# 读取配置文件(会自动删除未知/废弃配置项)
|
||||
original_text = Path(config_path).read_text(encoding="utf-8")
|
||||
config_data = tomlkit.parse(original_text)
|
||||
_prune_unknown_keys_by_schema(config_data, APIAdapterConfig)
|
||||
new_text = tomlkit.dumps(config_data)
|
||||
if new_text != original_text:
|
||||
Path(config_path).write_text(new_text, encoding="utf-8")
|
||||
logger.warning(f"已自动移除 {config_path} 中未知/废弃配置项")
|
||||
|
||||
config_dict = dict(config_data)
|
||||
config_dict = config_data.unwrap()
|
||||
|
||||
try:
|
||||
logger.debug("正在解析和验证API适配器配置文件...")
|
||||
|
||||
@@ -142,7 +142,7 @@ class ValidatedConfigBase(BaseModel):
|
||||
"""带验证的配置基类,继承自Pydantic BaseModel"""
|
||||
|
||||
model_config = {
|
||||
"extra": "allow", # 允许额外字段
|
||||
"extra": "forbid", # 禁止额外字段(防止跨版本遗留废弃配置项)
|
||||
"validate_assignment": True, # 验证赋值
|
||||
"arbitrary_types_allowed": True, # 允许任意类型
|
||||
"strict": True, # 如果设为 True 会完全禁用类型转换
|
||||
|
||||
@@ -13,6 +13,12 @@ from src.config.config_base import ValidatedConfigBase
|
||||
"""
|
||||
|
||||
|
||||
class InnerConfig(ValidatedConfigBase):
|
||||
"""配置文件元信息"""
|
||||
|
||||
version: str = Field(..., description="配置文件版本号(用于配置文件升级与兼容性检查)")
|
||||
|
||||
|
||||
class DatabaseConfig(ValidatedConfigBase):
|
||||
"""数据库配置类"""
|
||||
|
||||
@@ -191,9 +197,9 @@ class NoticeConfig(ValidatedConfigBase):
|
||||
enable_notice_trigger_chat: bool = Field(default=True, description="是否允许notice消息触发聊天流程")
|
||||
notice_in_prompt: bool = Field(default=True, description="是否在提示词中展示最近的notice消息")
|
||||
notice_prompt_limit: int = Field(default=5, ge=1, le=20, description="在提示词中展示的最大notice数量")
|
||||
notice_time_window: int = Field(default=3600, ge=60, le=86400, description="notice时间窗口(秒)")
|
||||
notice_time_window: int = Field(default=3600, ge=10, le=86400, description="notice时间窗口(秒)")
|
||||
max_notices_per_chat: int = Field(default=30, ge=10, le=100, description="每个聊天保留的notice数量上限")
|
||||
notice_retention_time: int = Field(default=86400, ge=3600, le=604800, description="notice保留时间(秒)")
|
||||
notice_retention_time: int = Field(default=86400, ge=10, le=604800, description="notice保留时间(秒)")
|
||||
|
||||
|
||||
class ExpressionRule(ValidatedConfigBase):
|
||||
@@ -588,6 +594,20 @@ class ResponseSplitterConfig(ValidatedConfigBase):
|
||||
enable_kaomoji_protection: bool = Field(default=False, description="启用颜文字保护")
|
||||
|
||||
|
||||
class LogConfig(ValidatedConfigBase):
|
||||
"""日志配置类"""
|
||||
|
||||
date_style: str = Field(default="m-d H:i:s", description="日期格式")
|
||||
log_level_style: str = Field(default="lite", description="日志级别样式")
|
||||
color_text: str = Field(default="full", description="日志文本颜色")
|
||||
log_level: str = Field(default="INFO", description="全局日志级别(向下兼容,优先级低于分别设置)")
|
||||
file_retention_days: int = Field(default=7, description="文件日志保留天数,0=禁用文件日志,-1=永不删除")
|
||||
console_log_level: str = Field(default="INFO", description="控制台日志级别")
|
||||
file_log_level: str = Field(default="DEBUG", description="文件日志级别")
|
||||
suppress_libraries: list[str] = Field(default_factory=list, description="完全屏蔽日志的第三方库列表")
|
||||
library_log_levels: dict[str, str] = Field(default_factory=dict, description="设置特定库的日志级别")
|
||||
|
||||
|
||||
class DebugConfig(ValidatedConfigBase):
|
||||
"""调试配置类"""
|
||||
|
||||
@@ -703,6 +723,7 @@ class WebSearchConfig(ValidatedConfigBase):
|
||||
enable_url_tool: bool = Field(default=True, description="启用URL工具")
|
||||
tavily_api_keys: list[str] = Field(default_factory=lambda: [], description="Tavily API密钥列表,支持轮询机制")
|
||||
exa_api_keys: list[str] = Field(default_factory=lambda: [], description="exa API密钥列表,支持轮询机制")
|
||||
metaso_api_keys: list[str] = Field(default_factory=lambda: [], description="Metaso API密钥列表,支持轮询机制")
|
||||
searxng_instances: list[str] = Field(default_factory=list, description="SearXNG 实例 URL 列表")
|
||||
searxng_api_keys: list[str] = Field(default_factory=list, description="SearXNG 实例 API 密钥列表")
|
||||
serper_api_keys: list[str] = Field(default_factory=list, description="serper API 密钥列表")
|
||||
@@ -988,6 +1009,12 @@ class KokoroFlowChatterConfig(ValidatedConfigBase):
|
||||
description="开启后KFC将接管所有私聊消息;关闭后私聊消息将由AFC处理"
|
||||
)
|
||||
|
||||
# --- 工作模式 ---
|
||||
mode: Literal["unified", "split"] = Field(
|
||||
default="split",
|
||||
description='工作模式: "unified"(单次调用) 或 "split"(planner+replyer两次调用)',
|
||||
)
|
||||
|
||||
# --- 核心行为配置 ---
|
||||
max_wait_seconds_default: int = Field(
|
||||
default=300, ge=30, le=3600,
|
||||
@@ -998,6 +1025,12 @@ class KokoroFlowChatterConfig(ValidatedConfigBase):
|
||||
description="是否在等待期间启用心理活动更新"
|
||||
)
|
||||
|
||||
# --- 自定义决策提示词 ---
|
||||
custom_decision_prompt: str = Field(
|
||||
default="",
|
||||
description="自定义KFC决策行为指导提示词(unified影响整体,split仅影响planner)",
|
||||
)
|
||||
|
||||
waiting: KokoroFlowChatterWaitingConfig = Field(
|
||||
default_factory=KokoroFlowChatterWaitingConfig,
|
||||
description="等待策略配置(默认等待时间、倍率等)",
|
||||
|
||||
@@ -451,7 +451,7 @@ class UnifiedMemoryManager:
|
||||
(0.3, 10.0, 0.4),
|
||||
(0.1, 15.0, 0.6),
|
||||
]
|
||||
|
||||
|
||||
for threshold, min_val, factor in occupancy_thresholds:
|
||||
if occupancy >= threshold:
|
||||
return max(min_val, base_interval * factor)
|
||||
@@ -461,24 +461,24 @@ class UnifiedMemoryManager:
|
||||
async def _transfer_blocks_to_short_term(self, blocks: list[MemoryBlock]) -> None:
|
||||
"""实际转换逻辑在后台执行(优化:并行处理多个块,批量触发唤醒)"""
|
||||
logger.debug(f"正在后台处理 {len(blocks)} 个感知记忆块")
|
||||
|
||||
|
||||
# 优化:使用 asyncio.gather 并行处理转移
|
||||
async def _transfer_single(block: MemoryBlock) -> tuple[MemoryBlock, bool]:
|
||||
try:
|
||||
stm = await self.short_term_manager.add_from_block(block)
|
||||
if not stm:
|
||||
return block, False
|
||||
|
||||
|
||||
await self.perceptual_manager.remove_block(block.id)
|
||||
logger.debug(f"✓ 记忆块 {block.id} 已被转移到短期记忆 {stm.id}")
|
||||
return block, True
|
||||
except Exception as exc:
|
||||
logger.error(f"后台转移失败,记忆块 {block.id}: {exc}")
|
||||
return block, False
|
||||
|
||||
|
||||
# 并行处理所有块
|
||||
results = await asyncio.gather(*[_transfer_single(block) for block in blocks], return_exceptions=True)
|
||||
|
||||
|
||||
# 统计成功的转移
|
||||
success_count = sum(1 for result in results if isinstance(result, tuple) and result[1])
|
||||
if success_count > 0:
|
||||
@@ -491,7 +491,7 @@ class UnifiedMemoryManager:
|
||||
seen = set()
|
||||
decay = 0.15
|
||||
manual_queries: list[dict[str, Any]] = []
|
||||
|
||||
|
||||
for raw in queries:
|
||||
text = (raw or "").strip()
|
||||
if text and text not in seen:
|
||||
@@ -517,7 +517,7 @@ class UnifiedMemoryManager:
|
||||
"top_k": self._config["long_term"]["search_top_k"],
|
||||
"use_multi_query": bool(manual_queries),
|
||||
}
|
||||
|
||||
|
||||
if recent_chat_history or manual_queries:
|
||||
context: dict[str, Any] = {}
|
||||
if recent_chat_history:
|
||||
@@ -541,7 +541,7 @@ class UnifiedMemoryManager:
|
||||
mem_id = mem.get("id")
|
||||
else:
|
||||
mem_id = getattr(mem, "id", None)
|
||||
|
||||
|
||||
# 检查去重
|
||||
if mem_id and mem_id in seen_ids:
|
||||
continue
|
||||
@@ -600,7 +600,7 @@ class UnifiedMemoryManager:
|
||||
new_memories.append(memory)
|
||||
if mem_id:
|
||||
cached_ids.add(mem_id)
|
||||
|
||||
|
||||
if new_memories:
|
||||
transfer_cache.extend(new_memories)
|
||||
logger.debug(
|
||||
@@ -632,7 +632,7 @@ class UnifiedMemoryManager:
|
||||
await self.short_term_manager.clear_transferred_memories(
|
||||
result["transferred_memory_ids"]
|
||||
)
|
||||
|
||||
|
||||
# 优化:使用生成器表达式保留未转移的记忆
|
||||
transfer_cache = [
|
||||
m
|
||||
|
||||
Reference in New Issue
Block a user