启用数据库预加载器,清理日志 (Enable the database preloader, clean up logging)

bot.py
@@ -14,12 +14,30 @@ from rich.traceback import install
 
 # 初始化日志系统
 from src.common.logger import get_logger, initialize_logging, shutdown_logging
+from src.config.config import MMC_VERSION, global_config, model_config
 
 # 初始化日志和错误显示
 initialize_logging()
 logger = get_logger("main")
 install(extra_lines=3)
 
+
+class StartupStageReporter:
+    """启动阶段报告器"""
+
+    def __init__(self, bound_logger):
+        self._logger = bound_logger
+
+    def emit(self, title: str, **details):
+        detail_pairs = [f"{key}={value}" for key, value in details.items() if value not in (None, "")]
+        if detail_pairs:
+            self._logger.info(f"{title} ({', '.join(detail_pairs)})")
+        else:
+            self._logger.info(title)
+
+
+startup_stage = StartupStageReporter(logger)
+
 # 常量定义
 SUPPORTED_DATABASES = ["sqlite", "postgresql"]
 SHUTDOWN_TIMEOUT = 10.0
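
For illustration, a minimal sketch of what the new StartupStageReporter.emit produces. The stub logger and the detail values below are hypothetical stand-ins; only the class itself comes from the hunk above:

    import logging

    logging.basicConfig(level=logging.INFO, format="%(levelname)s %(message)s")
    reporter = StartupStageReporter(logging.getLogger("demo"))

    reporter.emit("配置加载完成", platform="qq", nickname="MoFox", database="sqlite", models=12)
    # -> INFO 配置加载完成 (platform=qq, nickname=MoFox, database=sqlite, models=12)

    reporter.emit("数据库就绪", engine=None)  # None and "" details are filtered out
    # -> INFO 数据库就绪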
@@ -30,7 +48,7 @@ MAX_ENV_FILE_SIZE = 1024 * 1024  # 1MB限制
 # 设置工作目录为脚本所在目录
 script_dir = os.path.dirname(os.path.abspath(__file__))
 os.chdir(script_dir)
-logger.info("工作目录已设置")
+logger.debug("工作目录已设置")
 
 
 class ConfigManager:
@@ -44,7 +62,7 @@ class ConfigManager:
 
         if not env_file.exists():
             if template_env.exists():
-                logger.info("未找到.env文件,正在从模板创建...")
+                logger.debug("未找到.env文件,正在从模板创建...")
                 try:
                     env_file.write_text(template_env.read_text(encoding="utf-8"), encoding="utf-8")
                     logger.info("已从template/template.env创建.env文件")
@@ -90,7 +108,7 @@ class ConfigManager:
                 return False
 
             load_dotenv()
-            logger.info("环境变量加载成功")
+            logger.debug("环境变量加载成功")
            return True
         except Exception as e:
             logger.error(f"加载环境变量失败: {e}")
@@ -113,7 +131,7 @@ class EULAManager:
         # 从 os.environ 读取(避免重复 I/O)
         eula_confirmed = os.getenv("EULA_CONFIRMED", "").lower()
         if eula_confirmed == "true":
-            logger.info("EULA已通过环境变量确认")
+            logger.debug("EULA已通过环境变量确认")
             return
 
         # 提示用户确认EULA
@@ -290,7 +308,7 @@ class DatabaseManager:
         from src.common.database.core import check_and_migrate_database as initialize_sql_database
         from src.config.config import global_config
 
-        logger.info("正在初始化数据库连接...")
+        logger.debug("正在初始化数据库连接...")
         start_time = time.time()
 
         # 使用线程执行器运行潜在的阻塞操作
@@ -421,10 +439,10 @@ class WebUIManager:
             return False
 
         if WebUIManager._process and WebUIManager._process.returncode is None:
-            logger.info("WebUI 开发服务器已在运行,跳过重复启动")
+            logger.debug("WebUI 开发服务器已在运行,跳过重复启动")
             return True
 
-        logger.info(f"正在启动 WebUI 开发服务器: npm run dev (cwd={webui_dir})")
+        logger.debug(f"正在启动 WebUI 开发服务器: npm run dev (cwd={webui_dir})")
         npm_exe = "npm.cmd" if platform.system().lower() == "windows" else "npm"
         proc = await asyncio.create_subprocess_exec(
             npm_exe,
@@ -475,7 +493,7 @@ class WebUIManager:
 
                 if line:
                     text = line.decode(errors="ignore").rstrip()
-                    logger.info(f"[webui] {text}")
+                    logger.debug(f"[webui] {text}")
                     low = text.lower()
                     if any(k in low for k in success_keywords):
                         detected_success = True
@@ -496,7 +514,7 @@ class WebUIManager:
                 if not line:
                     break
                 text = line.decode(errors="ignore").rstrip()
-                logger.info(f"[webui] {text}")
+                logger.debug(f"[webui] {text}")
         except Exception as e:
             logger.debug(f"webui 日志读取停止: {e}")
 
@@ -538,7 +556,7 @@ class WebUIManager:
                 await WebUIManager._drain_task
             except Exception:
                 pass
-            logger.info("WebUI 开发服务器已停止")
+            logger.debug("WebUI 开发服务器已停止")
             return True
         finally:
             WebUIManager._process = None
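
As background for the WebUIManager hunks above, this is the general shape of the pattern they touch: spawn `npm run dev` with asyncio and drain its output line by line. A self-contained sketch, not the project's actual implementation; the cwd argument and the print target are placeholders:

    import asyncio
    import platform

    async def run_dev_server(cwd: str) -> None:
        npm_exe = "npm.cmd" if platform.system().lower() == "windows" else "npm"
        proc = await asyncio.create_subprocess_exec(
            npm_exe, "run", "dev",
            cwd=cwd,
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.STDOUT,
        )
        assert proc.stdout is not None
        while True:
            line = await proc.stdout.readline()
            if not line:  # EOF: the dev server exited
                break
            text = line.decode(errors="ignore").rstrip()
            print(f"[webui] {text}")  # the commit routes these lines to logger.debug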
@@ -555,22 +573,71 @@ class MaiBotMain:
         try:
             if platform.system().lower() != "windows":
                 time.tzset()  # type: ignore
-                logger.info("时区设置完成")
+                logger.debug("时区设置完成")
             else:
-                logger.info("Windows系统,跳过时区设置")
+                logger.debug("Windows系统,跳过时区设置")
         except Exception as e:
             logger.warning(f"时区设置失败: {e}")
 
+    def _emit_config_summary(self):
+        """输出配置加载阶段摘要"""
+        if not global_config:
+            return
+
+        bot_cfg = getattr(global_config, "bot", None)
+        db_cfg = getattr(global_config, "database", None)
+        platform = getattr(bot_cfg, "platform", "unknown") if bot_cfg else "unknown"
+        nickname = getattr(bot_cfg, "nickname", "unknown") if bot_cfg else "unknown"
+        db_type = getattr(db_cfg, "database_type", "unknown") if db_cfg else "unknown"
+        model_count = len(getattr(model_config, "models", []) or [])
+
+        startup_stage.emit(
+            "配置加载完成",
+            platform=platform,
+            nickname=nickname,
+            database=db_type,
+            models=model_count,
+        )
+
+    def _emit_component_summary(self):
+        """输出组件初始化阶段摘要"""
+        adapter_total = running_adapters = 0
+        plugin_total = 0
+
+        try:
+            from src.plugin_system.core.adapter_manager import get_adapter_manager
+
+            adapter_state = get_adapter_manager().list_adapters()
+            adapter_total = len(adapter_state)
+            running_adapters = sum(1 for info in adapter_state.values() if info.get("running"))
+        except Exception as exc:
+            logger.debug(f"统计适配器信息失败: {exc}")
+
+        try:
+            from src.plugin_system.core.plugin_manager import plugin_manager
+
+            plugin_total = len(plugin_manager.list_loaded_plugins())
+        except Exception as exc:
+            logger.debug(f"统计插件信息失败: {exc}")
+
+        startup_stage.emit(
+            "核心组件初始化完成",
+            adapters=adapter_total,
+            running=running_adapters,
+            plugins=plugin_total,
+        )
+
     async def initialize_database_async(self):
         """异步初始化数据库表结构"""
-        logger.info("正在初始化数据库表结构...")
+        logger.debug("正在初始化数据库表结构")
         try:
             start_time = time.time()
             from src.common.database.core import check_and_migrate_database
 
             await check_and_migrate_database()
             elapsed_time = time.time() - start_time
-            logger.info(f"数据库表结构初始化完成,耗时: {elapsed_time:.2f}秒")
+            db_type = getattr(getattr(global_config, "database", None), "database_type", "unknown")
+            startup_stage.emit("数据库就绪", engine=db_type, elapsed=f"{elapsed_time:.2f}s")
         except Exception as e:
             logger.error(f"数据库表结构初始化失败: {e}")
             raise
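
The summary helpers above read configuration exclusively through getattr chains, so a missing section degrades to "unknown" rather than raising. A tiny standalone illustration (the empty config object is made up for the demo):

    class _EmptyConfig:  # hypothetical: a config with no "database" section
        pass

    cfg = _EmptyConfig()
    db_cfg = getattr(cfg, "database", None)
    db_type = getattr(db_cfg, "database_type", "unknown") if db_cfg else "unknown"
    print(db_type)  # -> unknown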
@@ -590,6 +657,7 @@ class MaiBotMain:
         if not ConfigurationValidator.validate_configuration():
             raise RuntimeError("配置验证失败,请检查配置文件")
 
+        self._emit_config_summary()
         return self.create_main_system()
 
     async def run_async_init(self, main_system):
@@ -600,6 +668,7 @@ class MaiBotMain:
 
         # 初始化主系统
         await main_system.initialize()
+        self._emit_component_summary()
 
         # 显示彩蛋
         EasterEgg.show()
@@ -609,7 +678,7 @@ async def wait_for_user_input():
     """等待用户输入(异步方式)"""
     try:
         if os.getenv("ENVIRONMENT") != "production":
-            logger.info("程序执行完成,按 Ctrl+C 退出...")
+            logger.debug("程序执行完成,按 Ctrl+C 退出...")
             # 使用 asyncio.Event 而不是 sleep 循环
             shutdown_event = asyncio.Event()
             await shutdown_event.wait()
@@ -646,7 +715,17 @@ async def main_async():
 
     # 运行主任务
     main_task = asyncio.create_task(main_system.schedule_tasks())
-    logger.info("麦麦机器人启动完成,开始运行主任务...")
+    bot_cfg = getattr(global_config, "bot", None)
+    platform = getattr(bot_cfg, "platform", "unknown") if bot_cfg else "unknown"
+    nickname = getattr(bot_cfg, "nickname", "MoFox") if bot_cfg else "MoFox"
+    version = getattr(global_config, "MMC_VERSION", MMC_VERSION) if global_config else MMC_VERSION
+    startup_stage.emit(
+        "MoFox 已成功启动",
+        version=version,
+        platform=platform,
+        nickname=nickname,
+    )
+    logger.debug("麦麦机器人启动完成,开始运行主任务")
 
     # 同时运行主任务和用户输入等待
     user_input_done = asyncio.create_task(wait_for_user_input())

@@ -31,12 +31,10 @@ async def clean_permission_nodes():
 
         deleted_count = getattr(result, "rowcount", 0)
         logger.info(f"✅ 已清理 {deleted_count} 个权限节点记录")
-        print(f"✅ 已清理 {deleted_count} 个权限节点记录")
-        print("请重启应用以重新注册权限节点")
 
     except Exception as e:
         logger.error(f"❌ 清理权限节点失败: {e}")
-        print(f"❌ 清理权限节点失败: {e}")
         raise
 
 

@@ -415,7 +415,6 @@ class EmojiManager:
        self.emoji_num_max = global_config.emoji.max_reg_num
        self.emoji_num_max_reach_deletion = global_config.emoji.do_replace
        self.emoji_objects: list[MaiEmoji] = []  # 存储MaiEmoji对象的列表,使用类型注解明确列表元素类型
-        logger.info("启动表情包管理器")
        _ensure_emoji_dir()
        self._initialized = True
        logger.info("启动表情包管理器")
@@ -531,8 +530,8 @@ class EmojiManager:
 
        # 4. 调用LLM进行决策
        decision, _ = await self.llm_emotion_judge.generate_response_async(prompt, temperature=0.5, max_tokens=20)
-        logger.info(f"LLM选择的描述: {text_emotion}")
-        logger.info(f"LLM决策结果: {decision}")
+        logger.debug(f"LLM选择的描述: {text_emotion}")
+        logger.debug(f"LLM决策结果: {decision}")
 
        # 5. 解析LLM的决策结果
        match = re.search(r"(\d+)", decision)
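
The decision parsing right after this hunk pulls the first integer out of the LLM's free-form reply with re.search. A standalone sketch; the sample reply text is invented:

    import re

    decision = "我选择 2,因为它最贴切"  # hypothetical LLM reply
    match = re.search(r"(\d+)", decision)
    choice = int(match.group(1)) if match else None
    print(choice)  # -> 2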
@@ -773,7 +772,7 @@ class EmojiManager:
        # 先从内存中查找
        emoji = await self.get_emoji_from_manager(emoji_hash)
        if emoji and emoji.emotion:
-            logger.info(f"[缓存命中] 从内存获取表情包描述: {emoji.emotion}...")
+            logger.debug(f"[缓存命中] 从内存获取表情包描述: {emoji.emotion}...")
            return ",".join(emoji.emotion)
 
        # 如果内存中没有,从数据库查找
@@ -781,7 +780,7 @@ class EmojiManager:
            emoji_record = await self.get_emoji_from_db(emoji_hash)
            if emoji_record and emoji_record[0].emotion:
                emotion_str = ",".join(emoji_record[0].emotion)
-                logger.info(f"[缓存命中] 从数据库获取表情包描述: {emotion_str[:50]}...")
+                logger.debug(f"[缓存命中] 从数据库获取表情包描述: {emotion_str[:50]}...")
                return emotion_str
        except Exception as e:
            logger.error(f"从数据库查询表情包描述时出错: {e}")
@@ -806,7 +805,7 @@ class EmojiManager:
        # 先从内存中查找
        emoji = await self.get_emoji_from_manager(emoji_hash)
        if emoji and emoji.description:
-            logger.info(f"[缓存命中] 从内存获取表情包描述: {emoji.description[:50]}...")
+            logger.debug(f"[缓存命中] 从内存获取表情包描述: {emoji.description[:50]}...")
            return emoji.description
 
        # 如果内存中没有,从数据库查找(使用 QueryBuilder 启用数据库缓存)
@@ -815,7 +814,7 @@ class EmojiManager:
 
            emoji_record = cast(Emoji | None, await QueryBuilder(Emoji).filter(emoji_hash=emoji_hash).first())
            if emoji_record and emoji_record.description:
-                logger.info(f"[缓存命中] 从数据库获取表情包描述: {emoji_record.description[:50]}...")
+                logger.debug(f"[缓存命中] 从数据库获取表情包描述: {emoji_record.description[:50]}...")
                return emoji_record.description
        except Exception as e:
            logger.error(f"从数据库查询表情包描述时出错: {e}")

@@ -82,7 +82,7 @@ class BotInterestManager:
 
        # 检查embedding配置是否存在
        if not hasattr(model_config.model_task_config, "embedding"):
-            raise RuntimeError("❌ 未找到embedding模型配置")
+            raise RuntimeError("未找到embedding模型配置")
 
        self.embedding_config = model_config.model_task_config.embedding
 
@@ -127,7 +127,7 @@ class BotInterestManager:
            logger.debug("正在保存至数据库...")
            await self._save_interests_to_database(generated_interests)
        else:
-            raise RuntimeError("❌ 兴趣标签生成失败")
+            raise RuntimeError("兴趣标签生成失败")
 
    async def _generate_interests_from_personality(
        self, personality_description: str, personality_id: str
@@ -138,7 +138,7 @@ class BotInterestManager:
 
        # 检查embedding客户端是否可用
        if not hasattr(self, "embedding_request"):
-            raise RuntimeError("❌ Embedding客户端未初始化,无法生成兴趣标签")
+            raise RuntimeError("Embedding客户端未初始化,无法生成兴趣标签")
 
        # 构建提示词
        prompt = f"""
@@ -284,10 +284,10 @@ class BotInterestManager:
                provider = model_config.get_provider(model_info.api_provider)
                original_timeouts[provider.name] = provider.timeout
                if provider.timeout < INIT_TIMEOUT:
-                    logger.debug(f"⏱️ 临时增加 API provider '{provider.name}' 超时: {provider.timeout}s → {INIT_TIMEOUT}s")
+                    logger.debug(f"临时增加 API provider '{provider.name}' 超时: {provider.timeout}s → {INIT_TIMEOUT}s")
                    provider.timeout = INIT_TIMEOUT
            except Exception as e:
-                logger.warning(f"⚠️ 无法修改模型 '{model_name}' 的超时设置: {e}")
+                logger.warning(f"无法修改模型 '{model_name}' 的超时设置: {e}")
 
        # 调用LLM API
        success, response, reasoning_content, model_name = await llm_api.generate_with_model(
@@ -303,28 +303,28 @@ class BotInterestManager:
            try:
                provider = model_config.get_provider(provider_name)
                if provider.timeout != original_timeout:
-                    logger.debug(f"⏱️ 恢复 API provider '{provider_name}' 超时: {provider.timeout}s → {original_timeout}s")
+                    logger.debug(f"恢复 API provider '{provider_name}' 超时: {provider.timeout}s → {original_timeout}s")
                    provider.timeout = original_timeout
            except Exception as e:
-                logger.warning(f"⚠️ 无法恢复 provider '{provider_name}' 的超时设置: {e}")
+                logger.warning(f"无法恢复 provider '{provider_name}' 的超时设置: {e}")
 
        if success and response:
            # 直接返回原始响应,后续使用统一的 JSON 解析工具
            return response
        else:
-            logger.warning("⚠️ LLM返回空响应或调用失败")
+            logger.warning("LLM返回空响应或调用失败")
            return None
 
    except Exception as e:
-        logger.error(f"❌ 调用LLM生成兴趣标签失败: {e}")
-        logger.error("🔍 错误详情:")
+        logger.error(f"调用LLM生成兴趣标签失败: {e}")
+        logger.error("错误详情:")
        traceback.print_exc()
        return None
 
    async def _generate_embeddings_for_tags(self, interests: BotPersonalityInterests):
        """为所有兴趣标签生成embedding(缓存在内存和文件中)"""
        if not hasattr(self, "embedding_request"):
-            raise RuntimeError("❌ Embedding客户端未初始化,无法生成embedding")
+            raise RuntimeError("Embedding客户端未初始化,无法生成embedding")
 
        total_tags = len(interests.interest_tags)
 
@@ -335,7 +335,7 @@ class BotInterestManager:
        filtered_cache = {key: value for key, value in file_cache.items() if key in allowed_keys}
        dropped_cache = len(file_cache) - len(filtered_cache)
        if dropped_cache > 0:
-            logger.debug(f"🧹 跳过 {dropped_cache} 个与当前兴趣标签无关的缓存embedding")
+            logger.debug(f"跳过 {dropped_cache} 个与当前兴趣标签无关的缓存embedding")
        self.embedding_cache.update(filtered_cache)
 
        memory_cached_count = 0
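
The cache-pruning step above keeps only the file-cached embeddings whose keys are still active interest tags. The same logic on toy data (the dictionaries are made up):

    file_cache = {"编程": [0.1, 0.2], "旅行": [0.3, 0.4], "过期标签": [0.5, 0.6]}
    allowed_keys = {"编程", "旅行"}

    filtered_cache = {key: value for key, value in file_cache.items() if key in allowed_keys}
    dropped_cache = len(file_cache) - len(filtered_cache)
    print(dropped_cache)  # -> 1 stale entry is skipped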
@@ -349,10 +349,10 @@ class BotInterestManager:
                tag.embedding = self.embedding_cache[tag.tag_name]
                if file_cache and tag.tag_name in file_cache:
                    file_cached_count += 1
-                    logger.debug(f" [{i}/{total_tags}] 📂 '{tag.tag_name}' - 使用文件缓存")
+                    logger.debug(f" [{i}/{total_tags}] '{tag.tag_name}' - 使用文件缓存")
                else:
                    memory_cached_count += 1
-                    logger.debug(f" [{i}/{total_tags}] 💾 '{tag.tag_name}' - 使用内存缓存")
+                    logger.debug(f" [{i}/{total_tags}] '{tag.tag_name}' - 使用内存缓存")
            else:
                # 动态生成新的embedding
                embedding_text = tag.tag_name
@@ -362,13 +362,13 @@ class BotInterestManager:
                    tag.embedding = embedding  # 设置到 tag 对象(内存中)
                    self.embedding_cache[tag.tag_name] = embedding  # 同时缓存到内存
                    generated_count += 1
-                    logger.debug(f" ✅ '{tag.tag_name}' embedding动态生成成功")
+                    logger.debug(f"'{tag.tag_name}' embedding动态生成成功")
                else:
                    failed_count += 1
-                    logger.warning(f" ❌ '{tag.tag_name}' embedding生成失败")
+                    logger.warning(f"'{tag.tag_name}' embedding生成失败")
 
        if failed_count > 0:
-            raise RuntimeError(f"❌ 有 {failed_count} 个兴趣标签embedding生成失败")
+            raise RuntimeError(f"有 {failed_count} 个兴趣标签embedding生成失败")
 
        # 如果有新生成的embedding,保存到文件
        if generated_count > 0:
@@ -382,7 +382,7 @@ class BotInterestManager:
        cache=False 用于消息内容,避免在 embedding_cache 中长期保留大文本导致内存膨胀。
        """
        if not hasattr(self, "embedding_request"):
-            raise RuntimeError("❌ Embedding请求客户端未初始化")
+            raise RuntimeError("Embedding请求客户端未初始化")
 
        # 检查缓存
        if cache and text in self.embedding_cache:
@@ -390,7 +390,7 @@ class BotInterestManager:
 
        # 使用LLMRequest获取embedding
        if not self.embedding_request:
-            raise RuntimeError("❌ Embedding客户端未初始化")
+            raise RuntimeError("Embedding客户端未初始化")
        embedding, model_name = await self.embedding_request.get_embedding(text)
 
        if embedding and len(embedding) > 0:
@@ -409,7 +409,7 @@ class BotInterestManager:
            self._detected_embedding_dimension = current_dim
            if self.embedding_dimension and self.embedding_dimension != current_dim:
                logger.warning(
-                    "⚠️ 实际embedding维度(%d)与配置值(%d)不一致,请在 model_config.model_task_config.embedding.embedding_dimension 中同步更新",
+                    "实际embedding维度(%d)与配置值(%d)不一致,请在 model_config.model_task_config.embedding.embedding_dimension 中同步更新",
                    current_dim,
                    self.embedding_dimension,
                )
@@ -417,13 +417,13 @@ class BotInterestManager:
                self.embedding_dimension = current_dim
            elif current_dim != self.embedding_dimension:
                logger.warning(
-                    "⚠️ 收到的embedding维度发生变化: 之前=%d, 当前=%d。请确认模型配置是否正确。",
+                    "收到的embedding维度发生变化: 之前=%d, 当前=%d。请确认模型配置是否正确。",
                    self.embedding_dimension,
                    current_dim,
                )
            return embedding_float
        else:
-            raise RuntimeError(f"❌ 返回的embedding为空: {embedding}")
+            raise RuntimeError(f"返回的embedding为空: {embedding}")
 
    async def _generate_message_embedding(self, message_text: str, keywords: list[str]) -> list[float]:
        """为消息生成embedding向量"""
@@ -489,7 +489,7 @@ class BotInterestManager:
        if not active_tags:
            return
 
-        logger.debug(f"🔍 开始计算与 {len(active_tags)} 个兴趣标签的相似度")
+        logger.debug(f"开始计算与 {len(active_tags)} 个兴趣标签的相似度")
 
        for tag in active_tags:
            if tag.embedding:
@@ -501,11 +501,11 @@ class BotInterestManager:
                    if similarity > 0.3:
                        result.add_match(tag.tag_name, weighted_score, keywords)
                        logger.debug(
-                            f" 🏷️ '{tag.tag_name}': 相似度={similarity:.3f}, 权重={tag.weight:.2f}, 加权分数={weighted_score:.3f}"
+                            f"'{tag.tag_name}': 相似度={similarity:.3f}, 权重={tag.weight:.2f}, 加权分数={weighted_score:.3f}"
                        )
 
        except Exception as e:
-            logger.error(f"❌ 计算相似度分数失败: {e}")
+            logger.error(f"计算相似度分数失败: {e}")
 
    async def calculate_interest_match(
        self, message_text: str, keywords: list[str] | None = None, message_embedding: list[float] | None = None
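
The scoring in these hunks logs a similarity, a tag weight, and a weighted score per tag. The diff itself only shows the logged values; the sketch below assumes the conventional combination, cosine similarity times weight, with toy vectors:

    import math

    def cosine(a: list[float], b: list[float]) -> float:
        dot = sum(x * y for x, y in zip(a, b))
        na = math.sqrt(sum(x * x for x in a))
        nb = math.sqrt(sum(x * x for x in b))
        return dot / (na * nb) if na and nb else 0.0

    message_emb = [0.2, 0.9, 0.1]  # toy vectors, not real model output
    tag_emb = [0.1, 0.8, 0.3]
    tag_weight = 1.2

    similarity = cosine(message_emb, tag_emb)
    weighted_score = similarity * tag_weight  # assumed formula
    if similarity > 0.3:  # same cutoff as the hunk above
        print(f"match: sim={similarity:.3f}, score={weighted_score:.3f}")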
@@ -560,7 +560,7 @@ class BotInterestManager:
        medium_threshold = affinity_config.medium_match_interest_threshold
        low_threshold = affinity_config.low_match_interest_threshold
 
-        logger.debug(f"🔍 使用分级相似度阈值: 高={high_threshold}, 中={medium_threshold}, 低={low_threshold}")
+        logger.debug(f"使用分级相似度阈值: 高={high_threshold}, 中={medium_threshold}, 低={low_threshold}")
 
        for tag in active_tags:
            if tag.embedding:
@@ -647,7 +647,7 @@ class BotInterestManager:
        if hasattr(self, "_new_expanded_embeddings_generated") and self._new_expanded_embeddings_generated:
            await self._save_embedding_cache_to_file(self.current_interests.personality_id)
            self._new_expanded_embeddings_generated = False
-            logger.debug("💾 已保存新生成的扩展embedding到缓存文件")
+            logger.debug("已保存新生成的扩展embedding到缓存文件")
 
        return result
 
@@ -671,7 +671,7 @@ class BotInterestManager:
            self.expanded_tag_cache[tag_name] = expanded_tag
            self.expanded_embedding_cache[tag_name] = embedding
            self._new_expanded_embeddings_generated = True  # 标记有新生成的embedding
-            logger.debug(f"✅ 为标签'{tag_name}'生成并缓存扩展embedding: {expanded_tag[:50]}...")
+            logger.debug(f"为标签'{tag_name}'生成并缓存扩展embedding: {expanded_tag[:50]}...")
            return embedding
        except Exception as e:
            logger.warning(f"为标签'{tag_name}'生成扩展embedding失败: {e}")
@@ -700,12 +700,12 @@ class BotInterestManager:
        if self.current_interests:
            for tag in self.current_interests.interest_tags:
                if tag.tag_name == tag_name and tag.expanded:
-                    logger.debug(f"✅ 使用LLM生成的扩展描述: {tag_name} -> {tag.expanded[:50]}...")
+                    logger.debug(f"使用LLM生成的扩展描述: {tag_name} -> {tag.expanded[:50]}...")
                    self.expanded_tag_cache[tag_name] = tag.expanded
                    return tag.expanded
 
        # 🔧 回退策略:基于规则的扩展(用于兼容旧数据或LLM未生成扩展的情况)
-        logger.debug(f"⚠️ 标签'{tag_name}'没有LLM扩展描述,使用规则回退方案")
+        logger.debug(f"标签'{tag_name}'没有LLM扩展描述,使用规则回退方案")
        tag_lower = tag_name.lower()
 
        # 技术编程类标签(具体化描述)
@@ -790,7 +790,7 @@ class BotInterestManager:
            if keyword_lower == tag_name_lower:
                bonus += affinity_config.high_match_interest_threshold * 0.6  # 使用高匹配阈值的60%作为完全匹配奖励
                logger.debug(
-                    f" 🎯 关键词完全匹配: '{keyword}' == '{tag_name}' (+{affinity_config.high_match_interest_threshold * 0.6:.3f})"
+                    f"关键词完全匹配: '{keyword}' == '{tag_name}' (+{affinity_config.high_match_interest_threshold * 0.6:.3f})"
                )
 
            # 包含匹配
@@ -799,14 +799,14 @@ class BotInterestManager:
                    affinity_config.medium_match_interest_threshold * 0.3
                )  # 使用中匹配阈值的30%作为包含匹配奖励
                logger.debug(
-                    f" 🎯 关键词包含匹配: '{keyword}' ⊃ '{tag_name}' (+{affinity_config.medium_match_interest_threshold * 0.3:.3f})"
+                    f"关键词包含匹配: '{keyword}' ⊃ '{tag_name}' (+{affinity_config.medium_match_interest_threshold * 0.3:.3f})"
                )
 
            # 部分匹配(编辑距离)
            elif self._calculate_partial_match(keyword_lower, tag_name_lower):
                bonus += affinity_config.low_match_interest_threshold * 0.4  # 使用低匹配阈值的40%作为部分匹配奖励
                logger.debug(
-                    f" 🎯 关键词部分匹配: '{keyword}' ≈ '{tag_name}' (+{affinity_config.low_match_interest_threshold * 0.4:.3f})"
+                    f"关键词部分匹配: '{keyword}' ≈ '{tag_name}' (+{affinity_config.low_match_interest_threshold * 0.4:.3f})"
                )
 
            if bonus > 0:
@@ -939,11 +939,11 @@ class BotInterestManager:
                return interests
 
            except (orjson.JSONDecodeError, Exception) as e:
-                logger.error(f"❌ 解析兴趣标签JSON失败: {e}")
-                logger.debug(f"🔍 原始JSON数据: {db_interests.interest_tags[:200]}...")
+                logger.error(f"解析兴趣标签JSON失败: {e}")
+                logger.debug(f"原始JSON数据: {db_interests.interest_tags[:200]}...")
                return None
            else:
-                logger.info(f"ℹ️ 数据库中未找到personality_id为 '{personality_id}' 的兴趣标签配置")
+                logger.info(f"数据库中未找到personality_id为 '{personality_id}' 的兴趣标签配置")
                return None
 
        except Exception as e:
@@ -955,11 +955,6 @@ class BotInterestManager:
    async def _save_interests_to_database(self, interests: BotPersonalityInterests):
        """保存兴趣标签到数据库"""
        try:
-            logger.info("💾 正在保存兴趣标签到数据库...")
-            logger.info(f"📋 personality_id: {interests.personality_id}")
-            logger.info(f"🏷️ 兴趣标签数量: {len(interests.interest_tags)}")
-            logger.info(f"🔄 版本: {interests.version}")
-
            # 导入SQLAlchemy相关模块
            import orjson
 
@@ -999,18 +994,18 @@ class BotInterestManager:
 
                if existing_record:
                    # 更新现有记录
-                    logger.info("🔄 更新现有的兴趣标签配置")
+                    logger.info("更新现有的兴趣标签配置")
                    existing_record.interest_tags = json_data.decode("utf-8")
                    existing_record.personality_description = interests.personality_description
                    existing_record.embedding_model = interests.embedding_model
                    existing_record.version = interests.version
                    existing_record.last_updated = interests.last_updated
 
-                    logger.info(f"✅ 成功更新兴趣标签配置,版本: {interests.version}")
+                    logger.info(f"成功更新兴趣标签配置,版本: {interests.version}")
 
                else:
                    # 创建新记录
-                    logger.info("🆕 创建新的兴趣标签配置")
+                    logger.info("创建新的兴趣标签配置")
                    new_record = DBBotPersonalityInterests(
                        personality_id=interests.personality_id,
                        personality_description=interests.personality_description,
@@ -1021,9 +1016,8 @@ class BotInterestManager:
                    )
                    session.add(new_record)
                await session.commit()
-                logger.info(f"✅ 成功创建兴趣标签配置,版本: {interests.version}")
 
-            logger.info("✅ 兴趣标签已成功保存到数据库")
+            logger.info("兴趣标签已成功保存到数据库")
 
            # 验证保存是否成功
            async with get_db_session() as session:
@@ -1039,9 +1033,9 @@ class BotInterestManager:
                    .first()
                )
                if saved_record:
-                    logger.info(f"✅ 验证成功:数据库中存在personality_id为 {interests.personality_id} 的记录")
-                    logger.info(f" 版本: {saved_record.version}")
-                    logger.info(f" 最后更新: {saved_record.last_updated}")
+                    logger.info(f"验证成功:数据库中存在personality_id为 {interests.personality_id} 的记录")
+                    logger.info(f"版本: {saved_record.version}")
+                    logger.info(f"最后更新: {saved_record.last_updated}")
                else:
                    logger.error(f"❌ 验证失败:数据库中未找到personality_id为 {interests.personality_id} 的记录")
 
@@ -1089,13 +1083,12 @@ class BotInterestManager:
            expanded_embeddings = cache_data.get("expanded_embeddings", {})
            if expanded_embeddings:
                self.expanded_embedding_cache.update(expanded_embeddings)
-                logger.info(f"📂 加载 {len(expanded_embeddings)} 个扩展标签embedding缓存")
 
-            logger.info(f"✅ 成功从文件加载 {len(embeddings)} 个标签embedding缓存 (版本: {cache_version}, 模型: {cache_embedding_model})")
+            logger.info(f"成功从文件加载 {len(embeddings)} 个标签embedding缓存 (版本: {cache_version}, 模型: {cache_embedding_model})")
            return embeddings
 
        except Exception as e:
-            logger.warning(f"⚠️ 加载embedding缓存文件失败: {e}")
+            logger.warning(f"加载embedding缓存文件失败: {e}")
            return None
 
    async def _save_embedding_cache_to_file(self, personality_id: str):
@@ -1134,10 +1127,10 @@ class BotInterestManager:
            async with aiofiles.open(cache_file, "wb") as f:
                await f.write(orjson.dumps(cache_data, option=orjson.OPT_INDENT_2))
 
-            logger.debug(f"💾 已保存 {len(self.embedding_cache)} 个标签embedding和 {len(self.expanded_embedding_cache)} 个扩展embedding到缓存文件: {cache_file}")
+            logger.debug(f"已保存 {len(self.embedding_cache)} 个标签embedding和 {len(self.expanded_embedding_cache)} 个扩展embedding到缓存文件: {cache_file}")
 
        except Exception as e:
-            logger.warning(f"⚠️ 保存embedding缓存文件失败: {e}")
+            logger.warning(f"保存embedding缓存文件失败: {e}")
 
    def get_current_interests(self) -> BotPersonalityInterests | None:
        """获取当前的兴趣标签配置"""

@@ -82,7 +82,6 @@ class InterestManager:
            if await calculator.initialize():
                self._current_calculator = calculator
                logger.info(f"兴趣值计算组件注册成功: {calculator.component_name} v{calculator.component_version}")
-                logger.info("系统现在只有一个活跃的兴趣值计算器")
                return True
            else:
                logger.error(f"兴趣值计算组件初始化失败: {calculator.component_name}")

@@ -468,7 +468,7 @@ class EmbeddingStore:
        logger.info(f"使用实际检测到的 embedding 维度: {embedding_dim}")
        self.faiss_index = faiss.IndexFlatIP(embedding_dim)
        self.faiss_index.add(embeddings)
-        logger.info(f"✅ 成功构建 Faiss 索引: {len(embeddings)} 个向量, 维度={embedding_dim}")
+        logger.info(f"成功构建 Faiss 索引: {len(embeddings)} 个向量, 维度={embedding_dim}")
 
    def search_top_k(self, query: list[float], k: int) -> list[tuple[str, float]]:
        """搜索最相似的k个项,以余弦相似度为度量
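
A note on the EmbeddingStore hunk above: faiss.IndexFlatIP ranks by inner product, which equals cosine similarity only for L2-normalized vectors, matching the docstring's "以余弦相似度为度量". Whether normalization happens elsewhere is not visible in this diff; the sketch below shows the standard arrangement with random data and arbitrary sizes:

    import faiss
    import numpy as np

    dim = 8
    embeddings = np.random.rand(100, dim).astype("float32")
    faiss.normalize_L2(embeddings)  # after this, inner product == cosine

    index = faiss.IndexFlatIP(dim)
    index.add(embeddings)

    query = np.random.rand(1, dim).astype("float32")
    faiss.normalize_L2(query)
    scores, ids = index.search(query, 5)  # top-5 cosine scores and row ids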

@@ -308,7 +308,6 @@ async def _process_single_segment(
    filename = seg_data.get("filename", "video.mp4")
 
    logger.info(f"视频文件名: {filename}")
-    logger.info(f"Base64数据长度: {len(video_base64) if video_base64 else 0}")
 
    if video_base64:
        # 解码base64视频数据

@@ -132,7 +132,7 @@ class ActionModifier:
                continue
 
        if removals_s0:
-            logger.info(f"{self.log_prefix} 第0阶段:类型/Chatter过滤 - 移除了 {len(removals_s0)} 个动作")
+            logger.info(f"{self.log_prefix} 第0阶段:类型Chatter过滤 - 移除了 {len(removals_s0)} 个动作")
            for action_name, reason in removals_s0:
                logger.debug(f"{self.log_prefix} - 移除 {action_name}: {reason}")
 

@@ -876,7 +876,6 @@ class DefaultReplyer:
            notice_lines.append("")
 
            result = "\n".join(notice_lines)
-            logger.info(f"notice块构建成功,chat_id={chat_id}, 长度={len(result)}")
            return result
        else:
            logger.debug(f"没有可用的notice文本,chat_id={chat_id}")

@@ -179,40 +179,17 @@ class StatisticOutputTask(AsyncTask):
    @staticmethod
    async def _yield_control(iteration: int, interval: int = 200) -> None:
        """
-        �ڴ�����������ʱ������������첽�¼�ѭ������Ӧ
+        在长时间运行的循环中定期让出控制权,以防止阻塞事件循环
 
-        Args:
-            iteration: ��ǰ�������
-            interval: ÿ�����ٴ��л�һ��
+        :param iteration: 当前迭代次数
+        :param interval: 每隔多少次迭代让出一次控制权
        """
 
        if iteration % interval == 0:
            await asyncio.sleep(0)
 
    async def run(self):
-        try:
-            now = datetime.now()
-            logger.info("正在收集统计数据(异步)...")
-            stats = await self._collect_all_statistics(now)
-            logger.info("统计数据收集完成")
-
-            self._statistic_console_output(stats, now)
-            # 使用新的 HTMLReportGenerator 生成报告
-            chart_data = await self._collect_chart_data(stats)
-            deploy_time = datetime.fromtimestamp(float(local_storage.get("deploy_time", now.timestamp())))  # type: ignore
-            report_generator = HTMLReportGenerator(
-                name_mapping=self.name_mapping,
-                stat_period=self.stat_period,
-                deploy_time=deploy_time,
-            )
-            await report_generator.generate_report(stats, chart_data, now, self.record_file_path)
-            logger.info("统计数据HTML报告输出完成")
-
-        except Exception as e:
-            logger.exception(f"输出统计数据过程中发生异常,错误信息:{e}")
-
-    async def run_async_background(self):
        """
-        备选方案:完全异步后台运行统计输出
+        完全异步后台运行统计输出
        使用此方法可以让统计任务完全非阻塞
        """
 
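
The _yield_control helper kept above is the standard cooperative-yield trick: awaiting asyncio.sleep(0) every N iterations hands the event loop back to other tasks during a long loop. A standalone sketch:

    import asyncio

    async def crunch(items: list[int]) -> int:
        total = 0
        for i, item in enumerate(items):
            total += item * item  # stand-in for real per-item work
            if i % 200 == 0:  # same default interval as _yield_control
                await asyncio.sleep(0)  # let other coroutines run
        return total

    asyncio.run(crunch(list(range(10_000))))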

@@ -1,590 +0,0 @@
-#!/usr/bin/env python3
-"""
-视频分析器模块 - 旧版本兼容模块
-支持多种分析模式:批处理、逐帧、自动选择
-包含Python原生的抽帧功能,作为Rust模块的降级方案
-"""
-
-import asyncio
-import base64
-import io
-import os
-from concurrent.futures import ThreadPoolExecutor
-from pathlib import Path
-from typing import Any
-
-import cv2
-import numpy as np
-from PIL import Image
-
-from src.common.logger import get_logger
-from src.config.config import global_config, model_config
-from src.llm_models.utils_model import LLMRequest
-
-logger = get_logger("utils_video_legacy")
-
-
-def _extract_frames_worker(
-    video_path: str,
-    max_frames: int,
-    frame_quality: int,
-    max_image_size: int,
-    frame_extraction_mode: str,
-    frame_interval_seconds: float | None,
-) -> list[tuple[str, float]] | list[tuple[str, str]]:
-    """线程池中提取视频帧的工作函数"""
-    frames: list[tuple[str, float]] = []
-    try:
-        cap = cv2.VideoCapture(video_path)
-        fps = cap.get(cv2.CAP_PROP_FPS)
-        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
-        duration = total_frames / fps if fps > 0 else 0
-
-        if frame_extraction_mode == "time_interval":
-            # 新模式:按时间间隔抽帧
-            time_interval = frame_interval_seconds or 2.0
-            next_frame_time = 0.0
-            extracted_count = 0  # 初始化提取帧计数器
-
-            while cap.isOpened():
-                ret, frame = cap.read()
-                if not ret:
-                    break
-
-                current_time = cap.get(cv2.CAP_PROP_POS_MSEC) / 1000.0
-
-                if current_time >= next_frame_time:
-                    # 转换为PIL图像并压缩
-                    frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
-                    pil_image = Image.fromarray(frame_rgb)
-
-                    # 调整图像大小
-                    if max(pil_image.size) > max_image_size:
-                        ratio = max_image_size / max(pil_image.size)
-                        new_size = (int(pil_image.size[0] * ratio), int(pil_image.size[1] * ratio))
-                        pil_image = pil_image.resize(new_size, Image.Resampling.LANCZOS)
-
-                    # 转换为base64
-                    buffer = io.BytesIO()
-                    pil_image.save(buffer, format="JPEG", quality=frame_quality)
-                    frame_base64 = base64.b64encode(buffer.getvalue()).decode("utf-8")
-
-                    frames.append((frame_base64, current_time))
-                    extracted_count += 1
-
-                    # 注意:这里不能使用logger,因为在线程池中
-                    # logger.debug(f"提取第{extracted_count}帧 (时间: {current_time:.2f}s)")
-
-                    next_frame_time += time_interval
-        else:
-            # 使用numpy优化帧间隔计算
-            if duration > 0:
-                frame_interval = max(1, int(duration / max_frames * fps))
-            else:
-                frame_interval = 30  # 默认间隔
-
-            # 使用numpy计算目标帧位置
-            target_frames = np.arange(0, min(max_frames, total_frames // frame_interval + 1)) * frame_interval
-            target_frames = target_frames[target_frames < total_frames].astype(int)
-
-            for target_frame in target_frames:
-                # 跳转到目标帧
-                cap.set(cv2.CAP_PROP_POS_FRAMES, target_frame)
-                ret, frame = cap.read()
-                if not ret:
-                    continue
-
-                # 使用numpy优化图像处理
-                frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
-
-                # 转换为PIL图像并使用numpy进行尺寸计算
-                height, width = frame_rgb.shape[:2]
-                max_dim = max(height, width)
-
-                if max_dim > max_image_size:
-                    # 使用numpy计算缩放比例
-                    ratio = max_image_size / max_dim
-                    new_width = int(width * ratio)
-                    new_height = int(height * ratio)
-
-                    # 使用opencv进行高效缩放
-                    frame_resized = cv2.resize(frame_rgb, (new_width, new_height), interpolation=cv2.INTER_LANCZOS4)
-                    pil_image = Image.fromarray(frame_resized)
-                else:
-                    pil_image = Image.fromarray(frame_rgb)
-
-                # 转换为base64
-                buffer = io.BytesIO()
-                pil_image.save(buffer, format="JPEG", quality=frame_quality)
-                frame_base64 = base64.b64encode(buffer.getvalue()).decode("utf-8")
-
-                # 计算时间戳
-                timestamp = target_frame / fps if fps > 0 else 0
-                frames.append((frame_base64, timestamp))
-
-        cap.release()
-        return frames
-
-    except Exception as e:
-        # 返回错误信息
-        return [("ERROR", str(e))]
-
-
-class LegacyVideoAnalyzer:
-    """旧版本兼容的视频分析器类"""
-
-    def __init__(self):
-        """初始化视频分析器"""
-        assert global_config is not None
-        assert model_config is not None
-        # 使用专用的视频分析配置
-        try:
-            self.video_llm = LLMRequest(
-                model_set=model_config.model_task_config.video_analysis, request_type="video_analysis"
-            )
-            logger.info("✅ 使用video_analysis模型配置")
-        except (AttributeError, KeyError) as e:
-            # 如果video_analysis不存在,使用vlm配置
-            self.video_llm = LLMRequest(model_set=model_config.model_task_config.vlm, request_type="vlm")
-            logger.warning(f"video_analysis配置不可用({e}),回退使用vlm配置")
-
-        # 从配置文件读取参数,如果配置不存在则使用默认值
-        config = global_config.video_analysis
-
-        # 使用 getattr 统一获取配置参数,如果配置不存在则使用默认值
-        self.max_frames = getattr(config, "max_frames", 6)
-        self.frame_quality = getattr(config, "frame_quality", 85)
-        self.max_image_size = getattr(config, "max_image_size", 600)
-        self.enable_frame_timing = getattr(config, "enable_frame_timing", True)
-
-        # 从personality配置中获取人格信息
-        try:
-            personality_config = global_config.personality
-            self.personality_core = getattr(personality_config, "personality_core", "是一个积极向上的女大学生")
-            self.personality_side = getattr(
-                personality_config, "personality_side", "用一句话或几句话描述人格的侧面特点"
-            )
-        except AttributeError:
-            # 如果没有personality配置,使用默认值
-            self.personality_core = "是一个积极向上的女大学生"
-            self.personality_side = "用一句话或几句话描述人格的侧面特点"
-
-        self.batch_analysis_prompt = getattr(
-            config,
-            "batch_analysis_prompt",
-            """请以第一人称的视角来观看这一个视频,你看到的这些是从视频中按时间顺序提取的关键帧。
-
-你的核心人设是:{personality_core}。
-你的人格细节是:{personality_side}。
-
-请提供详细的视频内容描述,涵盖以下方面:
-1. 视频的整体内容和主题
-2. 主要人物、对象和场景描述
-3. 动作、情节和时间线发展
-4. 视觉风格和艺术特点
-5. 整体氛围和情感表达
-6. 任何特殊的视觉效果或文字内容
-
-请用中文回答,结果要详细准确。""",
-        )
-
-        # 新增的线程池配置
-        self.use_multiprocessing = getattr(config, "use_multiprocessing", True)
-        self.max_workers = getattr(config, "max_workers", 2)
-        self.frame_extraction_mode = getattr(config, "frame_extraction_mode", "fixed_number")
-        self.frame_interval_seconds = getattr(config, "frame_interval_seconds", 2.0)
-
-        # 将配置文件中的模式映射到内部使用的模式名称
-        config_mode = getattr(config, "analysis_mode", "auto")
-        if config_mode == "batch_frames":
-            self.analysis_mode = "batch"
-        elif config_mode == "frame_by_frame":
-            self.analysis_mode = "sequential"
-        elif config_mode == "auto":
-            self.analysis_mode = "auto"
-        else:
-            logger.warning(f"无效的分析模式: {config_mode},使用默认的auto模式")
-            self.analysis_mode = "auto"
-
-        self.frame_analysis_delay = 0.3  # API调用间隔(秒)
-        self.frame_interval = 1.0  # 抽帧时间间隔(秒)
-        self.batch_size = 3  # 批处理时每批处理的帧数
-        self.timeout = 60.0  # 分析超时时间(秒)
-
-        if config:
-            logger.info("✅ 从配置文件读取视频分析参数")
-        else:
-            logger.warning("配置文件中缺少video_analysis配置,使用默认值")
-
-        # 系统提示词
-        self.system_prompt = "你是一个专业的视频内容分析助手。请仔细观察用户提供的视频关键帧,详细描述视频内容。"
-
-        logger.info(
-            f"✅ 旧版本视频分析器初始化完成,分析模式: {self.analysis_mode}, 线程池: {self.use_multiprocessing}"
-        )
-
-    async def extract_frames(self, video_path: str) -> list[tuple[str, float]]:
-        """提取视频帧 - 支持多进程和单线程模式"""
-        # 先获取视频信息
-        cap = cv2.VideoCapture(video_path)
-        fps = cap.get(cv2.CAP_PROP_FPS)
-        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
-        duration = total_frames / fps if fps > 0 else 0
-        cap.release()
-
-        logger.info(f"视频信息: {total_frames}帧, {fps:.2f}FPS, {duration:.2f}秒")
-
-        # 估算提取帧数
-        if duration > 0:
-            frame_interval = max(1, int(duration / self.max_frames * fps))
-            estimated_frames = min(self.max_frames, total_frames // frame_interval + 1)
-        else:
-            estimated_frames = self.max_frames
-            frame_interval = 1
-
-        logger.info(f"计算得出帧间隔: {frame_interval} (将提取约{estimated_frames}帧)")
-
-        # 根据配置选择处理方式
-        if self.use_multiprocessing:
-            return await self._extract_frames_multiprocess(video_path)
-        else:
-            return await self._extract_frames_fallback(video_path)
-
-    async def _extract_frames_multiprocess(self, video_path: str) -> list[tuple[str, float]]:
-        """线程池版本的帧提取"""
-        loop = asyncio.get_event_loop()
-
-        try:
-            logger.info("🔄 启动线程池帧提取...")
-            # 使用线程池,避免进程间的导入问题
-            with ThreadPoolExecutor(max_workers=1) as executor:
-                frames = await loop.run_in_executor(
-                    executor,
-                    _extract_frames_worker,
-                    video_path,
-                    self.max_frames,
-                    self.frame_quality,
-                    self.max_image_size,
-                    self.frame_extraction_mode,
-                    self.frame_interval_seconds,
-                )
-
-            # 检查是否有错误
-            if frames and frames[0][0] == "ERROR":
-                logger.error(f"线程池帧提取失败: {frames[0][1]}")
-                # 降级到单线程模式
-                logger.info("🔄 降级到单线程模式...")
-                return await self._extract_frames_fallback(video_path)
-
-            logger.info(f"✅ 成功提取{len(frames)}帧 (线程池模式)")
-            return frames  # type: ignore
-
-        except Exception as e:
-            logger.error(f"线程池帧提取失败: {e}")
-            # 降级到原始方法
-            logger.info("🔄 降级到单线程模式...")
-            return await self._extract_frames_fallback(video_path)
-
-    async def _extract_frames_fallback(self, video_path: str) -> list[tuple[str, float]]:
-        """帧提取的降级方法 - 原始异步版本"""
-        frames = []
-        extracted_count = 0
-        cap = cv2.VideoCapture(video_path)
-        fps = cap.get(cv2.CAP_PROP_FPS)
-        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
-        duration = total_frames / fps if fps > 0 else 0
-
-        logger.info(f"视频信息: {total_frames}帧, {fps:.2f}FPS, {duration:.2f}秒")
-
-        if self.frame_extraction_mode == "time_interval":
-            # 新模式:按时间间隔抽帧
-            time_interval = self.frame_interval_seconds
-            next_frame_time = 0.0
-
-            while cap.isOpened():
-                ret, frame = cap.read()
-                if not ret:
-                    break
-
-                current_time = cap.get(cv2.CAP_PROP_POS_MSEC) / 1000.0
-
-                if current_time >= next_frame_time:
-                    # 转换为PIL图像并压缩
-                    frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
-                    pil_image = Image.fromarray(frame_rgb)
-
-                    # 调整图像大小
-                    if max(pil_image.size) > self.max_image_size:
-                        ratio = self.max_image_size / max(pil_image.size)
-                        new_size = (int(pil_image.size[0] * ratio), int(pil_image.size[1] * ratio))
-                        pil_image = pil_image.resize(new_size, Image.Resampling.LANCZOS)
-
-                    # 转换为base64
-                    buffer = io.BytesIO()
-                    pil_image.save(buffer, format="JPEG", quality=self.frame_quality)
-                    frame_base64 = base64.b64encode(buffer.getvalue()).decode("utf-8")
-
-                    frames.append((frame_base64, current_time))
-                    extracted_count += 1
-
-                    logger.debug(f"提取第{extracted_count}帧 (时间: {current_time:.2f}s)")
-
-                    next_frame_time += time_interval
-        else:
-            # 使用numpy优化帧间隔计算
-            if duration > 0:
-                frame_interval = max(1, int(duration / self.max_frames * fps))
-            else:
-                frame_interval = 30  # 默认间隔
-
-            logger.info(
-                f"计算得出帧间隔: {frame_interval} (将提取约{min(self.max_frames, total_frames // frame_interval + 1)}帧)"
-            )
-
-            # 使用numpy计算目标帧位置
-            target_frames = np.arange(0, min(self.max_frames, total_frames // frame_interval + 1)) * frame_interval
-            target_frames = target_frames[target_frames < total_frames].astype(int)
-
-            extracted_count = 0
-
-            for target_frame in target_frames:
-                # 跳转到目标帧
-                cap.set(cv2.CAP_PROP_POS_FRAMES, target_frame)
-                ret, frame = cap.read()
-                if not ret:
-                    continue
-
-                # 使用numpy优化图像处理
-                frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
-
-                # 转换为PIL图像并使用numpy进行尺寸计算
-                height, width = frame_rgb.shape[:2]
-                max_dim = max(height, width)
-
-                if max_dim > self.max_image_size:
-                    # 使用numpy计算缩放比例
-                    ratio = self.max_image_size / max_dim
-                    new_width = int(width * ratio)
-                    new_height = int(height * ratio)
-
-                    # 使用opencv进行高效缩放
-                    frame_resized = cv2.resize(frame_rgb, (new_width, new_height), interpolation=cv2.INTER_LANCZOS4)
-                    pil_image = Image.fromarray(frame_resized)
-                else:
-                    pil_image = Image.fromarray(frame_rgb)
-
-                # 转换为base64
-                buffer = io.BytesIO()
-                pil_image.save(buffer, format="JPEG", quality=self.frame_quality)
-                frame_base64 = base64.b64encode(buffer.getvalue()).decode("utf-8")
-
-                # 计算时间戳
-                timestamp = target_frame / fps if fps > 0 else 0
-                frames.append((frame_base64, timestamp))
-                extracted_count += 1
-
-                logger.debug(f"提取第{extracted_count}帧 (时间: {timestamp:.2f}s, 帧号: {target_frame})")
-
-                # 每提取一帧让步一次
-                await asyncio.sleep(0.001)
-
-        cap.release()
-        logger.info(f"✅ 成功提取{len(frames)}帧")
-        return frames
-
-    async def analyze_frames_batch(self, frames: list[tuple[str, float]], user_question: str | None = None) -> str:
-        """批量分析所有帧"""
-        logger.info(f"开始批量分析{len(frames)}帧")
-
-        if not frames:
-            return "❌ 没有可分析的帧"
-
-        # 构建提示词并格式化人格信息,要不然占位符的那个会爆炸
-        prompt = self.batch_analysis_prompt.format(
-            personality_core=self.personality_core, personality_side=self.personality_side
-        )
-
-        if user_question:
-            prompt += f"\n\n用户问题: {user_question}"
-
-        # 添加帧信息到提示词
-        frame_info = []
-        for i, (_frame_base64, timestamp) in enumerate(frames):
-            if self.enable_frame_timing:
-                frame_info.append(f"第{i + 1}帧 (时间: {timestamp:.2f}s)")
-            else:
-                frame_info.append(f"第{i + 1}帧")
-
-        prompt += f"\n\n视频包含{len(frames)}帧图像:{', '.join(frame_info)}"
-        prompt += "\n\n请基于所有提供的帧图像进行综合分析,关注并描述视频的完整内容和故事发展。"
-
-        try:
-            # 尝试使用多图片分析
-            response = await self._analyze_multiple_frames(frames, prompt)
-            logger.info("✅ 视频识别完成")
-            return response
-
-        except Exception as e:
-            logger.error(f"❌ 视频识别失败: {e}")
-            # 降级到单帧分析
-            logger.warning("降级到单帧分析模式")
-            try:
-                frame_base64, timestamp = frames[0]
-                fallback_prompt = (
-                    prompt
-                    + f"\n\n注意:由于技术限制,当前仅显示第1帧 (时间: {timestamp:.2f}s),视频共有{len(frames)}帧。请基于这一帧进行分析。"
-                )
-
-                response, _ = await self.video_llm.generate_response_for_image(
-                    prompt=fallback_prompt, image_base64=frame_base64, image_format="jpeg"
-                )
-                logger.info("✅ 降级的单帧分析完成")
-                return response
-            except Exception as fallback_e:
-                logger.error(f"❌ 降级分析也失败: {fallback_e}")
-                raise
-
-    async def _analyze_multiple_frames(self, frames: list[tuple[str, float]], prompt: str) -> str:
-        """使用多图片分析方法"""
-        logger.info(f"开始构建包含{len(frames)}帧的分析请求")
-
-        # 导入MessageBuilder用于构建多图片消息
-        from src.llm_models.payload_content.message import MessageBuilder, RoleType
-        from src.llm_models.utils_model import RequestType
|
|
||||||
|
|
||||||
# 构建包含多张图片的消息
|
|
||||||
message_builder = MessageBuilder().set_role(RoleType.User).add_text_content(prompt)
|
|
||||||
|
|
||||||
# 添加所有帧图像
|
|
||||||
for _i, (frame_base64, _timestamp) in enumerate(frames):
|
|
||||||
message_builder.add_image_content("jpeg", frame_base64)
|
|
||||||
# logger.info(f"已添加第{i+1}帧到分析请求 (时间: {timestamp:.2f}s, 图片大小: {len(frame_base64)} chars)")
|
|
||||||
|
|
||||||
message = message_builder.build()
|
|
||||||
# logger.info(f"✅ 多帧消息构建完成,包含{len(frames)}张图片")
|
|
||||||
|
|
||||||
# 获取模型信息和客户端
|
|
||||||
model_info, api_provider, client = self.video_llm._select_model() # type: ignore
|
|
||||||
# logger.info(f"使用模型: {model_info.name} 进行多帧分析")
|
|
||||||
|
|
||||||
# 直接执行多图片请求
|
|
||||||
api_response = await self.video_llm._execute_request( # type: ignore
|
|
||||||
api_provider=api_provider,
|
|
||||||
client=client,
|
|
||||||
request_type=RequestType.RESPONSE,
|
|
||||||
model_info=model_info,
|
|
||||||
message_list=[message],
|
|
||||||
temperature=None,
|
|
||||||
max_tokens=None,
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.info(f"视频识别完成,响应长度: {len(api_response.content or '')} ")
|
|
||||||
return api_response.content or "❌ 未获得响应内容"
|
|
||||||
|
|
||||||
async def analyze_frames_sequential(self, frames: list[tuple[str, float]], user_question: str | None = None) -> str:
|
|
||||||
"""逐帧分析并汇总"""
|
|
||||||
logger.info(f"开始逐帧分析{len(frames)}帧")
|
|
||||||
|
|
||||||
frame_analyses = []
|
|
||||||
|
|
||||||
for i, (frame_base64, timestamp) in enumerate(frames):
|
|
||||||
try:
|
|
||||||
prompt = f"请分析这个视频的第{i + 1}帧"
|
|
||||||
if self.enable_frame_timing:
|
|
||||||
prompt += f" (时间: {timestamp:.2f}s)"
|
|
||||||
prompt += "。描述你看到的内容,包括人物、动作、场景、文字等。"
|
|
||||||
|
|
||||||
if user_question:
|
|
||||||
prompt += f"\n特别关注: {user_question}"
|
|
||||||
|
|
||||||
response, _ = await self.video_llm.generate_response_for_image(
|
|
||||||
prompt=prompt, image_base64=frame_base64, image_format="jpeg"
|
|
||||||
)
|
|
||||||
|
|
||||||
frame_analyses.append(f"第{i + 1}帧 ({timestamp:.2f}s): {response}")
|
|
||||||
logger.debug(f"✅ 第{i + 1}帧分析完成")
|
|
||||||
|
|
||||||
# API调用间隔
|
|
||||||
if i < len(frames) - 1:
|
|
||||||
await asyncio.sleep(self.frame_analysis_delay)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"❌ 第{i + 1}帧分析失败: {e}")
|
|
||||||
frame_analyses.append(f"第{i + 1}帧: 分析失败 - {e}")
|
|
||||||
|
|
||||||
# 生成汇总
|
|
||||||
logger.info("开始生成汇总分析")
|
|
||||||
summary_prompt = f"""基于以下各帧的分析结果,请提供一个完整的视频内容总结:
|
|
||||||
|
|
||||||
{chr(10).join(frame_analyses)}
|
|
||||||
|
|
||||||
请综合所有帧的信息,描述视频的整体内容、故事线、主要元素和特点。"""
|
|
||||||
|
|
||||||
if user_question:
|
|
||||||
summary_prompt += f"\n特别回答用户的问题: {user_question}"
|
|
||||||
|
|
||||||
try:
|
|
||||||
# 使用最后一帧进行汇总分析
|
|
||||||
if frames:
|
|
||||||
last_frame_base64, _ = frames[-1]
|
|
||||||
summary, _ = await self.video_llm.generate_response_for_image(
|
|
||||||
prompt=summary_prompt, image_base64=last_frame_base64, image_format="jpeg"
|
|
||||||
)
|
|
||||||
logger.info("✅ 逐帧分析和汇总完成")
|
|
||||||
return summary
|
|
||||||
else:
|
|
||||||
return "❌ 没有可用于汇总的帧"
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"❌ 汇总分析失败: {e}")
|
|
||||||
# 如果汇总失败,返回各帧分析结果
|
|
||||||
return f"视频逐帧分析结果:\n\n{chr(10).join(frame_analyses)}"
|
|
||||||
|
|
||||||
async def analyze_video(self, video_path: str, user_question: str | None = None) -> str:
|
|
||||||
"""分析视频的主要方法"""
|
|
||||||
try:
|
|
||||||
logger.info(f"开始分析视频: {os.path.basename(video_path)}")
|
|
||||||
|
|
||||||
# 提取帧
|
|
||||||
frames = await self.extract_frames(video_path)
|
|
||||||
if not frames:
|
|
||||||
return "❌ 无法从视频中提取有效帧"
|
|
||||||
|
|
||||||
# 根据模式选择分析方法
|
|
||||||
if self.analysis_mode == "auto":
|
|
||||||
# 智能选择:少于等于3帧用批量,否则用逐帧
|
|
||||||
mode = "batch" if len(frames) <= 3 else "sequential"
|
|
||||||
logger.info(f"自动选择分析模式: {mode} (基于{len(frames)}帧)")
|
|
||||||
else:
|
|
||||||
mode = self.analysis_mode
|
|
||||||
|
|
||||||
# 执行分析
|
|
||||||
if mode == "batch":
|
|
||||||
result = await self.analyze_frames_batch(frames, user_question)
|
|
||||||
else: # sequential
|
|
||||||
result = await self.analyze_frames_sequential(frames, user_question)
|
|
||||||
|
|
||||||
logger.info("✅ 视频分析完成")
|
|
||||||
return result
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
error_msg = f"❌ 视频分析失败: {e!s}"
|
|
||||||
logger.error(error_msg)
|
|
||||||
return error_msg
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def is_supported_video(file_path: str) -> bool:
|
|
||||||
"""检查是否为支持的视频格式"""
|
|
||||||
supported_formats = {".mp4", ".avi", ".mov", ".mkv", ".flv", ".wmv", ".m4v", ".3gp", ".webm"}
|
|
||||||
return Path(file_path).suffix.lower() in supported_formats
|
|
||||||
|
|
||||||
|
|
||||||
# 全局实例
|
|
||||||
_legacy_video_analyzer = None
|
|
||||||
|
|
||||||
|
|
||||||
def get_legacy_video_analyzer() -> LegacyVideoAnalyzer:
|
|
||||||
"""获取旧版本视频分析器实例(单例模式)"""
|
|
||||||
global _legacy_video_analyzer
|
|
||||||
if _legacy_video_analyzer is None:
|
|
||||||
_legacy_video_analyzer = LegacyVideoAnalyzer()
|
|
||||||
return _legacy_video_analyzer
|
|
||||||
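For orientation, a minimal usage sketch of the analyzer above, assuming it runs in the same module (or imports get_legacy_video_analyzer from it); the file name and question are hypothetical placeholders, not part of the commit:

    import asyncio

    async def _demo() -> None:
        analyzer = get_legacy_video_analyzer()
        # Hypothetical input path; any format listed in is_supported_video works
        if analyzer.is_supported_video("sample.mp4"):
            summary = await analyzer.analyze_video("sample.mp4", user_question="视频里发生了什么?")
            print(summary)

    asyncio.run(_demo())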
@@ -154,7 +154,7 @@ class CacheManager:
         if key in self.l1_kv_cache:
             entry = self.l1_kv_cache[key]
             if time.time() < entry["expires_at"]:
-                logger.info(f"命中L1键值缓存: {key}")
+                logger.debug(f"命中L1键值缓存: {key}")
                 return entry["data"]
             else:
                 del self.l1_kv_cache[key]

@@ -178,7 +178,7 @@ class CacheManager:
             hit_index = indices[0][0]
             l1_hit_key = self.l1_vector_id_to_key.get(hit_index)
             if l1_hit_key and l1_hit_key in self.l1_kv_cache:
-                logger.info(f"命中L1语义缓存: {l1_hit_key}")
+                logger.debug(f"命中L1语义缓存: {l1_hit_key}")
                 return self.l1_kv_cache[l1_hit_key]["data"]

         # 步骤 2b: L2 精确缓存 (数据库)

@@ -190,7 +190,7 @@ class CacheManager:
             # 使用 getattr 安全访问属性,避免 Pylance 类型检查错误
             expires_at = getattr(cache_results_obj, "expires_at", 0)
             if time.time() < expires_at:
-                logger.info(f"命中L2键值缓存: {key}")
+                logger.debug(f"命中L2键值缓存: {key}")
                 cache_value = getattr(cache_results_obj, "cache_value", "{}")
                 data = orjson.loads(cache_value)

@@ -228,7 +228,7 @@ class CacheManager:

         if distance != "N/A" and distance < 0.75:
             l2_hit_key = results["ids"][0][0] if isinstance(results["ids"][0], list) else results["ids"][0]
-            logger.info(f"命中L2语义缓存: key='{l2_hit_key}', 距离={distance:.4f}")
+            logger.debug(f"命中L2语义缓存: key='{l2_hit_key}', 距离={distance:.4f}")

             # 从数据库获取缓存数据
             semantic_cache_results_obj = await db_query(
@@ -218,7 +218,7 @@ class CoreSinkManager:
         # 存储引用
         self._process_sinks[adapter_name] = (server, incoming_queue, outgoing_queue)

-        logger.info(f"为适配器 {adapter_name} 创建了 ProcessCoreSink 通信队列")
+        logger.debug(f"为适配器 {adapter_name} 创建了 ProcessCoreSink 通信队列")

         return incoming_queue, outgoing_queue

@@ -237,7 +237,7 @@ class CoreSinkManager:
         task = asyncio.create_task(server.close())
         self._background_tasks.add(task)
         task.add_done_callback(self._background_tasks.discard)
-        logger.info(f"已移除适配器 {adapter_name} 的 ProcessCoreSink 通信队列")
+        logger.debug(f"已移除适配器 {adapter_name} 的 ProcessCoreSink 通信队列")

     async def send_outgoing(
         self,
@@ -6,7 +6,7 @@
 from dataclasses import dataclass, field
 from datetime import datetime
 from typing import Any
-
+from src.config.config import model_config
 from . import BaseDataModel

@@ -55,7 +55,7 @@ class BotPersonalityInterests(BaseDataModel):
     personality_id: str
     personality_description: str  # 人设描述文本
     interest_tags: list[BotInterestTag] = field(default_factory=list)
-    embedding_model: str = "text-embedding-ada-002"  # 使用的embedding模型
+    embedding_model: list[str] = field(default_factory=lambda: model_config.model_task_config.embedding.model_list)  # 使用的embedding模型
     last_updated: datetime = field(default_factory=datetime.now)
     version: int = 1  # 版本号,用于追踪更新
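The `embedding_model` change above swaps a hard-coded model-name string for a list read from `model_config` through `default_factory`. A small self-contained illustration of why the factory form is required for a default like this (the list here is a stand-in, not the project's real config):

    from dataclasses import dataclass, field

    FAKE_MODEL_LIST = ["embedding-a", "embedding-b"]  # stand-in for model_config's model_list

    @dataclass
    class Example:
        # default_factory runs at instance-creation time, so each instance gets
        # its own fresh list and the config is read lazily, not at class definition
        embedding_model: list[str] = field(default_factory=lambda: list(FAKE_MODEL_LIST))

    a, b = Example(), Example()
    a.embedding_model.append("extra")
    print(b.embedding_model)  # unaffected: ['embedding-a', 'embedding-b']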
@@ -546,8 +546,6 @@ class StreamContext(BaseDataModel):
                 removed_count = len(self.history_messages) - self.max_context_size
                 self.history_messages = self.history_messages[-self.max_context_size :]
                 logger.debug(f"[历史加载] 移除了 {removed_count} 条最早的消息以适配当前容量限制")

-            logger.info(f"[历史加载] 成功加载 {loaded_count} 条历史消息到内存: {self.stream_id}")
         else:
             logger.debug(f"无历史消息需要加载: {self.stream_id}")
@@ -21,6 +21,7 @@ from src.common.database.optimization import (
     Priority,
     get_batch_scheduler,
     get_cache,
+    record_preload_access,
 )
 from src.common.logger import get_logger
@@ -145,6 +146,16 @@ class CRUDBase(Generic[T]):
         """
         cache_key = f"{self.model_name}:id:{id}"

+        if use_cache:
+
+            async def _preload_loader() -> dict[str, Any] | None:
+                async with get_db_session() as session:
+                    stmt = select(self.model).where(self.model.id == id)
+                    result = await session.execute(stmt)
+                    instance = result.scalar_one_or_none()
+                    return _model_to_dict(instance) if instance is not None else None
+
+            await record_preload_access(cache_key, loader=_preload_loader)
+
         # 尝试从缓存获取 (缓存的是字典)
         if use_cache:
             cache = await get_cache()
@@ -189,6 +200,21 @@ class CRUDBase(Generic[T]):
         """
         cache_key = f"{self.model_name}:filter:{sorted(filters.items())!s}"

+        filters_copy = dict(filters)
+        if use_cache:
+
+            async def _preload_loader() -> dict[str, Any] | None:
+                async with get_db_session() as session:
+                    stmt = select(self.model)
+                    for key, value in filters_copy.items():
+                        if hasattr(self.model, key):
+                            stmt = stmt.where(getattr(self.model, key) == value)
+
+                    result = await session.execute(stmt)
+                    instance = result.scalar_one_or_none()
+                    return _model_to_dict(instance) if instance is not None else None
+
+            await record_preload_access(cache_key, loader=_preload_loader)
+
         # 尝试从缓存获取 (缓存的是字典)
         if use_cache:
             cache = await get_cache()
@@ -241,6 +267,29 @@ class CRUDBase(Generic[T]):
         """
         cache_key = f"{self.model_name}:multi:{skip}:{limit}:{sorted(filters.items())!s}"

+        filters_copy = dict(filters)
+        if use_cache:
+
+            async def _preload_loader() -> list[dict[str, Any]]:
+                async with get_db_session() as session:
+                    stmt = select(self.model)
+
+                    # 应用过滤条件
+                    for key, value in filters_copy.items():
+                        if hasattr(self.model, key):
+                            if isinstance(value, list | tuple | set):
+                                stmt = stmt.where(getattr(self.model, key).in_(value))
+                            else:
+                                stmt = stmt.where(getattr(self.model, key) == value)
+
+                    # 应用分页
+                    stmt = stmt.offset(skip).limit(limit)
+
+                    result = await session.execute(stmt)
+                    instances = list(result.scalars().all())
+                    return [_model_to_dict(inst) for inst in instances]
+
+            await record_preload_access(cache_key, loader=_preload_loader)
+
         # 尝试从缓存获取 (缓存的是字典列表)
         if use_cache:
             cache = await get_cache()
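The three CRUD hunks above share one pattern: every cached read builds a zero-argument async closure that can re-run the same query later, then hands it to record_preload_access. A stripped-down sketch of that closure pattern, with an in-memory dict standing in for the database (all names here are illustrative):

    import asyncio
    from collections.abc import Awaitable, Callable
    from typing import Any

    registry: dict[str, Callable[[], Awaitable[Any]]] = {}
    _fake_db = {42: {"id": 42, "name": "user-42"}}  # stand-in for the real table

    def register_read(cache_key: str, user_id: int) -> None:
        async def _loader() -> dict[str, Any] | None:
            # The closure captures user_id, so a background refresher can
            # replay the query later without re-supplying any arguments
            return _fake_db.get(user_id)

        registry[cache_key] = _loader

    register_read("user:id:42", 42)
    print(asyncio.run(registry["user:id:42"]()))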
@@ -17,7 +17,7 @@ from sqlalchemy import and_, asc, desc, func, or_, select
 from src.common.database.api.crud import _dict_to_model, _model_to_dict
 from src.common.database.core.models import Base
 from src.common.database.core.session import get_db_session
-from src.common.database.optimization import get_cache
+from src.common.database.optimization import get_cache, record_preload_access
 from src.common.logger import get_logger

 logger = get_logger("database.query")
@@ -273,6 +273,16 @@ class QueryBuilder(Generic[T]):
             模型实例列表或字典列表
         """
         cache_key = ":".join(self._cache_key_parts) + ":all"
+        stmt = self._stmt
+
+        if self._use_cache:
+
+            async def _preload_loader() -> list[dict[str, Any]]:
+                async with get_db_session() as session:
+                    result = await session.execute(stmt)
+                    instances = list(result.scalars().all())
+                    return [_model_to_dict(inst) for inst in instances]
+
+            await record_preload_access(cache_key, loader=_preload_loader)

         # 尝试从缓存获取 (缓存的是字典列表)
         if self._use_cache:
@@ -311,6 +321,16 @@ class QueryBuilder(Generic[T]):
             模型实例或None
         """
         cache_key = ":".join(self._cache_key_parts) + ":first"
+        stmt = self._stmt
+
+        if self._use_cache:
+
+            async def _preload_loader() -> dict[str, Any] | None:
+                async with get_db_session() as session:
+                    result = await session.execute(stmt)
+                    instance = result.scalars().first()
+                    return _model_to_dict(instance) if instance is not None else None
+
+            await record_preload_access(cache_key, loader=_preload_loader)

         # 尝试从缓存获取 (缓存的是字典)
         if self._use_cache:
@@ -349,6 +369,15 @@ class QueryBuilder(Generic[T]):
             记录数量
         """
         cache_key = ":".join(self._cache_key_parts) + ":count"
+        count_stmt = select(func.count()).select_from(self._stmt.subquery())
+
+        if self._use_cache:
+
+            async def _preload_loader() -> int:
+                async with get_db_session() as session:
+                    result = await session.execute(count_stmt)
+                    return result.scalar() or 0
+
+            await record_preload_access(cache_key, loader=_preload_loader)

         # 尝试从缓存获取
         if self._use_cache:

@@ -358,8 +387,6 @@ class QueryBuilder(Generic[T]):
             return cached

         # 构建count查询
-        count_stmt = select(func.count()).select_from(self._stmt.subquery())
-
         # 从数据库查询
         async with get_db_session() as session:
             result = await session.execute(count_stmt)
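The count hunk moves the COUNT statement construction ahead of the cache check so the preload loader and the direct query path share it. The select(func.count()).select_from(stmt.subquery()) idiom can be seen in isolation with a throwaway Core table (hypothetical schema; the SQL is only printed):

    from sqlalchemy import Integer, String, column, func, select, table

    users = table("users", column("id", Integer), column("name", String))

    stmt = select(users).where(users.c.name.like("a%"))
    # Wrap the filtered SELECT as a subquery and count its rows
    count_stmt = select(func.count()).select_from(stmt.subquery())
    print(count_stmt)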
@@ -79,7 +79,7 @@ async def get_engine() -> AsyncEngine:
        elif db_type == "postgresql":
            await _enable_postgresql_optimizations(_engine)

-        logger.info(f"✅ {db_type.upper()} 数据库引擎初始化成功")
+        logger.info(f"{db_type.upper()} 数据库引擎初始化成功")
        return _engine

    except Exception as e:
@@ -116,7 +116,7 @@ def _build_sqlite_config(config) -> tuple[str, dict]:
        },
    }

-    logger.info(f"SQLite配置: {db_path}")
+    logger.debug(f"SQLite配置: {db_path}")
    return url, engine_kwargs
@@ -167,7 +167,7 @@ def _build_postgresql_config(config) -> tuple[str, dict]:
    if connect_args:
        engine_kwargs["connect_args"] = connect_args

-    logger.info(
+    logger.debug(
        f"PostgreSQL配置: {config.postgresql_user}@{config.postgresql_host}:{config.postgresql_port}/{config.postgresql_database}"
    )
    return url, engine_kwargs
@@ -184,7 +184,7 @@ async def close_engine():
        logger.info("正在关闭数据库引擎...")
        await _engine.dispose()
        _engine = None
-        logger.info("✅ 数据库引擎已关闭")
+        logger.info("数据库引擎已关闭")


async def _enable_sqlite_optimizations(engine: AsyncEngine):
@@ -214,8 +214,6 @@ async def _enable_sqlite_optimizations(engine: AsyncEngine):
            # 临时存储使用内存
            await conn.execute(text("PRAGMA temp_store = MEMORY"))

-            logger.info("✅ SQLite性能优化已启用 (WAL模式 + 并发优化)")
-
    except Exception as e:
        logger.warning(f"⚠️ SQLite性能优化失败: {e},将使用默认配置")
@@ -241,8 +239,6 @@ async def _enable_postgresql_optimizations(engine: AsyncEngine):
            # 启用自动 EXPLAIN(可选,用于调试)
            # await conn.execute(text("SET auto_explain.log_min_duration = '1000'"))

-            logger.info("✅ PostgreSQL性能优化已启用")
-
    except Exception as e:
        logger.warning(f"⚠️ PostgreSQL性能优化失败: {e},将使用默认配置")
@@ -28,6 +28,7 @@ from .preloader import (
     DataPreloader,
     close_preloader,
     get_preloader,
+    record_preload_access,
 )

 __all__ = [

@@ -51,4 +52,5 @@ __all__ = [
     "get_batch_scheduler",
     "get_cache",
     "get_preloader",
+    "record_preload_access",
 ]
@@ -13,6 +13,7 @@ from collections import defaultdict
 from collections.abc import Awaitable, Callable
 from dataclasses import dataclass, field
 from typing import Any
+from collections import OrderedDict

 from sqlalchemy import select
 from sqlalchemy.ext.asyncio import AsyncSession

@@ -22,6 +23,15 @@ from src.common.logger import get_logger

 logger = get_logger("preloader")

+# 预加载注册表(用于后台刷新热点数据)
+_preload_loader_registry: OrderedDict[str, Callable[[], Awaitable[Any]]] = OrderedDict()
+_registry_lock = asyncio.Lock()
+_preload_task: asyncio.Task | None = None
+_preload_task_lock = asyncio.Lock()
+_PRELOAD_REGISTRY_LIMIT = 1024
+# 默认后台预加载轮询间隔(秒)
+_DEFAULT_PRELOAD_INTERVAL = 60
+

 @dataclass
 class AccessPattern:
@@ -223,16 +233,19 @@ class DataPreloader:

     async def start_preload_batch(
         self,
-        session: AsyncSession,
         loaders: dict[str, Callable[[], Awaitable[Any]]],
+        limit: int = 100,
     ) -> None:
         """批量启动预加载任务

         Args:
-            session: 数据库会话
             loaders: 数据键到加载函数的映射
+            limit: 参与预加载的热点键数量上限
         """
-        preload_keys = await self.get_preload_keys()
+        if not loaders:
+            return
+
+        preload_keys = await self.get_preload_keys(limit=limit)

         for key in preload_keys:
             if key in loaders:
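Under the new signature, callers no longer pass a session; they hand over the key-to-loader mapping plus an optional hot-key limit. A hedged sketch of the call shape (the loader body and names are illustrative only):

    from collections.abc import Awaitable, Callable
    from typing import Any

    async def refresh_hot_keys(preloader: "DataPreloader") -> None:
        async def load_user_1() -> dict[str, Any]:
            return {"id": 1}  # placeholder payload

        loaders: dict[str, Callable[[], Awaitable[Any]]] = {"user:id:1": load_user_1}
        # No session argument anymore; the loaders themselves own their queries
        await preloader.start_preload_batch(loaders, limit=50)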
@@ -418,6 +431,91 @@ class CommonDataPreloader:
         await self.preloader.preload_data(cache_key, loader)


+# 预加载后台任务与注册表管理
+async def _get_preload_interval() -> float:
+    """获取后台预加载轮询间隔"""
+    try:
+        from src.config.config import global_config
+
+        if global_config and getattr(global_config, "database", None):
+            interval = getattr(global_config.database, "preload_interval", None)
+            if interval:
+                return max(5.0, float(interval))
+    except Exception:
+        # 配置可能未加载或不存在该字段,使用默认值
+        pass
+    return float(_DEFAULT_PRELOAD_INTERVAL)
+
+
+async def _register_preload_loader(
+    cache_key: str,
+    loader: Callable[[], Awaitable[Any]],
+) -> None:
+    """注册用于热点预加载的加载函数"""
+    async with _registry_lock:
+        # move_to_end可以保持最近注册的顺序,便于淘汰旧项
+        _preload_loader_registry[cache_key] = loader
+        _preload_loader_registry.move_to_end(cache_key)
+
+        # 控制注册表大小,避免无限增长
+        while len(_preload_loader_registry) > _PRELOAD_REGISTRY_LIMIT:
+            _preload_loader_registry.popitem(last=False)
+
+
+async def _snapshot_loaders() -> dict[str, Callable[[], Awaitable[Any]]]:
+    """获取当前注册的预加载loader快照"""
+    async with _registry_lock:
+        return dict(_preload_loader_registry)
+
+
+async def _preload_worker() -> None:
+    """后台周期性预加载任务"""
+    while True:
+        try:
+            interval = await _get_preload_interval()
+            loaders = await _snapshot_loaders()
+
+            if loaders:
+                preloader = await get_preloader()
+                await preloader.start_preload_batch(loaders)
+
+            await asyncio.sleep(interval)
+        except asyncio.CancelledError:
+            break
+        except Exception as e:
+            logger.error(f"预加载后台任务异常: {e}")
+            # 避免紧密重试导致CPU占用过高
+            await asyncio.sleep(5)
+
+
+async def _ensure_preload_worker() -> None:
+    """确保后台预加载任务已启动"""
+    global _preload_task
+
+    async with _preload_task_lock:
+        if _preload_task is None or _preload_task.done():
+            _preload_task = asyncio.create_task(_preload_worker())
+
+
+async def record_preload_access(
+    cache_key: str,
+    *,
+    related_keys: list[str] | None = None,
+    loader: Callable[[], Awaitable[Any]] | None = None,
+) -> None:
+    """记录访问并注册预加载loader
+
+    这个入口为上层API(CRUD/Query)提供:记录访问模式、建立关联关系、
+    以及注册用于后续后台预加载的加载函数。
+    """
+    preloader = await get_preloader()
+    await preloader.record_access(cache_key, related_keys)
+
+    if loader is not None:
+        await _register_preload_loader(cache_key, loader)
+        await _ensure_preload_worker()
+
+
 # 全局预加载器实例
 _global_preloader: DataPreloader | None = None
 _preloader_lock = asyncio.Lock()

@@ -438,7 +536,22 @@ async def get_preloader() -> DataPreloader:
 async def close_preloader() -> None:
     """关闭全局预加载器"""
     global _global_preloader
+    global _preload_task
+
+    # 停止后台任务
+    if _preload_task is not None:
+        _preload_task.cancel()
+        try:
+            await _preload_task
+        except asyncio.CancelledError:
+            pass
+        _preload_task = None
+
+    # 清理注册表
+    async with _registry_lock:
+        _preload_loader_registry.clear()
+
+    # 清理预加载器实例
     if _global_preloader is not None:
         await _global_preloader.clear()
         _global_preloader = None
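How an upper-layer read path is expected to use the new entry point, assuming the sketch runs in (or imports record_preload_access from) this module; the in-memory rows are a stand-in for the real SELECT:

    from typing import Any

    _rows = {1: {"id": 1, "value": "hello"}}  # hypothetical table contents

    async def get_row_cached(row_id: int) -> dict[str, Any] | None:
        cache_key = f"demo:id:{row_id}"

        async def _loader() -> dict[str, Any] | None:
            # In the real CRUD/Query code this closure re-runs the SELECT
            return _rows.get(row_id)

        # Records the access pattern, registers the loader, and lazily
        # starts the background refresh task defined above
        await record_preload_access(cache_key, loader=_loader)
        return await _loader()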
@@ -879,14 +879,12 @@ class ModuleColoredConsoleRenderer:
        # sourcery skip: merge-duplicate-blocks
        """渲染日志消息"""

-        # 获取基本信息
        timestamp = event_dict.get("timestamp", "")
        level = event_dict.get("level", "info")
        logger_name = event_dict.get("logger_name", "")
        event = event_dict.get("event", "")

-        # 构建 Rich Text 对象列表
-        parts = []
+        parts: list[Text] = []

        # 日志级别样式配置
        log_level_style = self._config.get("log_level_style", "lite")
@@ -1298,9 +1296,9 @@ def start_log_cleanup_task():
    threading.Thread(target=cleanup_task, daemon=True, name="log-cleanup").start()
    logger = get_logger("logger")
    if retention_days == -1:
-        logger.info("已启动日志任务: 每天 00:00 压缩旧日志(不删除)")
+        logger.debug("已启动日志任务: 每天 00:00 压缩旧日志(不删除)")
    else:
-        logger.info(f"已启动日志任务: 每天 00:00 压缩并删除早于 {retention_days} 天的日志")
+        logger.debug(f"已启动日志任务: 每天 00:00 压缩并删除早于 {retention_days} 天的日志")


def shutdown_logging():
@@ -112,9 +112,6 @@ def start_tracemalloc(max_frames: int = 25) -> None:
    """
    if not tracemalloc.is_tracing():
        tracemalloc.start(max_frames)
-        logger.info("tracemalloc started with max_frames=%s", max_frames)
-    else:
-        logger.info("tracemalloc already started")


def stop_tracemalloc() -> None:
@@ -508,9 +508,9 @@ def load_config(config_path: str) -> Config:

    # 创建Config对象(各个配置类会自动进行 Pydantic 验证)
    try:
-        logger.info("正在解析和验证配置文件...")
+        logger.debug("正在解析和验证配置文件...")
        config = Config.from_dict(config_data)
-        logger.info("配置文件解析和验证完成")
+        logger.debug("配置文件解析和验证完成")

        # 【临时修复】在验证后,手动从原始数据重新加载 master_users
        try:

@@ -520,7 +520,7 @@ def load_config(config_path: str) -> Config:
            raw_master_users = config_dict["permission"]["master_users"]
            # 现在 raw_master_users 就是一个标准的 Python 列表了
            config.permission.master_users = raw_master_users
-            logger.info(f"【临时修复】已手动将 master_users 设置为: {config.permission.master_users}")
+            logger.debug(f"【临时修复】已手动将 master_users 设置为: {config.permission.master_users}")
        except Exception as patch_exc:
            logger.error(f"【临时修复】手动设置 master_users 失败: {patch_exc}")
@@ -545,9 +545,9 @@ def api_ada_load_config(config_path: str) -> APIAdapterConfig:
    config_dict = dict(config_data)

    try:
-        logger.info("正在解析和验证API适配器配置文件...")
+        logger.debug("正在解析和验证API适配器配置文件...")
        config = APIAdapterConfig.from_dict(config_dict)
-        logger.info("API适配器配置文件解析和验证完成")
+        logger.debug("API适配器配置文件解析和验证完成")
        return config
    except Exception as e:
        logger.critical(f"API适配器配置文件解析失败: {e}")
@@ -566,11 +566,11 @@ def initialize_configs_once() -> tuple[Config, APIAdapterConfig]:
        logger.debug("config.py 初始化已执行,跳过重复运行")
        return global_config, model_config

-    logger.info(f"MaiCore当前版本: {MMC_VERSION}")
+    logger.debug(f"MaiCore当前版本: {MMC_VERSION}")
    update_config()
    update_model_config()

-    logger.info("正在品鉴配置文件...")
+    logger.debug("正在品鉴配置文件...")
    global_config = load_config(config_path=os.path.join(CONFIG_DIR, "bot_config.toml"))
    model_config = api_ada_load_config(config_path=os.path.join(CONFIG_DIR, "model_config.toml"))

@@ -581,4 +581,4 @@ def initialize_configs_once() -> tuple[Config, APIAdapterConfig]:
# 同一进程只执行一次初始化,避免重复生成或覆盖配置
global_config, model_config = initialize_configs_once()

-logger.info("非常的新鲜,非常的美味!")
+logger.debug("非常的新鲜,非常的美味!")
@@ -46,7 +46,7 @@ class Individuality:
        personality_hash, _ = self._get_config_hash(bot_nickname, personality_core, personality_side, identity)
        self.bot_person_id = personality_hash
        self.name = bot_nickname
-        logger.info(f"生成的 personality_id: {self.bot_person_id[:16]}... (基于人设文本 hash)")
+        logger.debug(f"生成的 personality_id: {self.bot_person_id[:16]}... (基于人设文本 hash)")

        person_info_manager = get_person_info_manager()
74
src/main.py
@@ -155,7 +155,7 @@ class MainSystem:
            default_enabled = getattr(calc_info, "enabled_by_default", True)

            if not enabled or not default_enabled:
-                logger.info(f"兴趣计算器 {calc_name} 未启用,跳过")
+                logger.debug(f"兴趣计算器 {calc_name} 未启用,跳过")
                continue

            try:

@@ -170,7 +170,7 @@ class MainSystem:
                logger.warning(f"无法找到 {calc_name} 的组件类")
                continue

-            logger.info(f"成功获取 {calc_name} 的组件类: {component_class.__name__}")
+            logger.debug(f"成功获取 {calc_name} 的组件类: {component_class.__name__}")

            # 确保组件是 BaseInterestCalculator 的子类
            if not issubclass(component_class, BaseInterestCalculator):

@@ -191,7 +191,7 @@ class MainSystem:
            # 注册到兴趣管理器
            if await interest_manager.register_calculator(calculator_instance):
                registered_calculators.append(calculator_instance)
-                logger.info(f"成功注册兴趣计算器: {calc_name}")
+                logger.debug(f"成功注册兴趣计算器: {calc_name}")
            else:
                logger.error(f"兴趣计算器 {calc_name} 注册失败")

@@ -199,9 +199,9 @@ class MainSystem:
            logger.error(f"处理兴趣计算器 {calc_name} 时出错: {e}")

        if registered_calculators:
-            logger.info(f"成功注册了 {len(registered_calculators)} 个兴趣计算器")
+            logger.debug(f"成功注册了 {len(registered_calculators)} 个兴趣计算器")
            for calc in registered_calculators:
-                logger.info(f"  - {calc.component_name} v{calc.component_version}")
+                logger.debug(f"  - {calc.component_name} v{calc.component_version}")
        else:
            logger.error("未能成功注册任何兴趣计算器")

@@ -320,7 +320,7 @@ class MainSystem:

        # 并行执行所有清理任务
        if cleanup_tasks:
-            logger.info(f"开始并行执行 {len(cleanup_tasks)} 个清理任务...")
+            logger.debug(f"开始并行执行 {len(cleanup_tasks)} 个清理任务...")
            tasks = [task for _, task in cleanup_tasks]
            task_names = [name for name, _ in cleanup_tasks]

@@ -378,10 +378,10 @@ class MainSystem:
            logger.error("缺少必要的bot配置")
            raise ValueError("Bot配置不完整")

-        logger.info(f"正在唤醒{global_config.bot.nickname}......")
+        logger.debug(f"正在唤醒{global_config.bot.nickname}......")

        # 初始化 CoreSinkManager(包含 MessageRuntime)
-        logger.info("正在初始化 CoreSinkManager...")
+        logger.debug("正在初始化 CoreSinkManager...")
        self.core_sink_manager = await initialize_core_sink_manager()

        # 获取 MessageHandler 并向 MessageRuntime 注册处理器

@@ -390,7 +390,7 @@ class MainSystem:

        # 向 MessageRuntime 注册消息处理器和钩子
        self.message_handler.register_handlers(self.core_sink_manager.runtime)
-        logger.info("CoreSinkManager 和 MessageHandler 初始化完成(使用 MessageRuntime 路由)")
+        logger.debug("CoreSinkManager 和 MessageHandler 初始化完成(使用 MessageRuntime 路由)")

        # 初始化组件
        await self._init_components()

@@ -399,19 +399,11 @@ class MainSystem:
        egg_texts, weights = zip(*EGG_PHRASES)
        selected_egg = choices(egg_texts, weights=weights, k=1)[0]

-        logger.info(f"""
-全部系统初始化完成,{global_config.bot.nickname if global_config and global_config.bot else 'Bot'}已成功唤醒
-=========================================================
-MoFox_Bot(第三方修改版)
-全部组件已成功启动!
-=========================================================
-🌐 项目地址: https://github.com/MoFox-Studio/MoFox-Core
-🏠 官方项目: https://github.com/Mai-with-u/MaiBot
-=========================================================
-这是基于原版MMC的社区改版,包含增强功能和优化(同时也有更多的'特性')
-=========================================================
-小贴士:{selected_egg}
-""")
+        logger.debug(
+            "全部系统初始化完成,%s 已唤醒(彩蛋:%s)",
+            global_config.bot.nickname if global_config and global_config.bot else "Bot",
+            selected_egg,
+        )

    async def _init_components(self) -> None:
        """初始化其他组件"""
@@ -425,7 +417,7 @@ MoFox_Bot(第三方修改版)
        ]

        await asyncio.gather(*base_init_tasks, return_exceptions=True)
-        logger.info("基础定时任务初始化成功")
+        logger.debug("基础定时任务初始化成功")

        # 注册默认事件
        event_manager.init_default_events()

@@ -438,7 +430,7 @@ MoFox_Bot(第三方修改版)
            permission_manager = PermissionManager()
            await permission_manager.initialize()
            permission_api.set_permission_manager(permission_manager)
-            logger.info("权限管理器初始化成功")
+            logger.debug("权限管理器初始化成功")
        except Exception as e:
            logger.error(f"权限管理器初始化失败: {e}")

@@ -451,7 +443,7 @@ MoFox_Bot(第三方修改版)
            self.server.register_router(message_router, prefix="/api")
            self.server.register_router(llm_statistic_router, prefix="/api")
            self.server.register_router(visualizer_router, prefix="/visualizer")
-            logger.info("API路由注册成功")
+            logger.debug("API路由注册成功")
        except Exception as e:
            logger.error(f"注册API路由失败: {e}")
        # 初始化统一调度器

@@ -477,11 +469,11 @@ MoFox_Bot(第三方修改版)

        # 初始化表情管理器
        get_emoji_manager().initialize()
-        logger.info("表情包管理器初始化成功")
+        logger.debug("表情包管理器初始化成功")

        # 启动情绪管理器
        await mood_manager.start()
-        logger.info("情绪管理器初始化成功")
+        logger.debug("情绪管理器初始化成功")

        # 启动聊天管理器的自动保存任务
        from src.chat.message_receive.chat_stream import get_chat_manager

@@ -500,9 +492,9 @@ MoFox_Bot(第三方修改版)
        try:
            if global_config and global_config.memory and global_config.memory.enable:
                from src.memory_graph.manager_singleton import initialize_unified_memory_manager
-                logger.info("三层记忆系统已启用,正在初始化...")
+                logger.debug("三层记忆系统已启用,正在初始化...")
                await initialize_unified_memory_manager()
-                logger.info("三层记忆系统初始化成功")
+                logger.debug("三层记忆系统初始化成功")
            else:
                logger.debug("三层记忆系统未启用(配置中禁用)")
        except Exception as e:

@@ -516,19 +508,19 @@ MoFox_Bot(第三方修改版)
            from src.chat.knowledge.knowledge_lib import initialize_lpmm_knowledge

            initialize_lpmm_knowledge()
-            logger.info("LPMM知识库初始化成功")
+            logger.debug("LPMM知识库初始化成功")
        except Exception as e:
            logger.error(f"LPMM知识库初始化失败: {e}")

        # 消息接收器已在 initialize() 中通过 CoreSinkManager 创建
-        logger.info("核心消息接收器已就绪(通过 CoreSinkManager)")
+        logger.debug("核心消息接收器已就绪(通过 CoreSinkManager)")

        # 启动消息重组器
        try:
            from src.utils.message_chunker import reassembler

            await reassembler.start_cleanup_task()
-            logger.info("消息重组器已启动")
+            logger.debug("消息重组器已启动")
        except Exception as e:
            logger.error(f"启动消息重组器失败: {e}")

@@ -538,11 +530,11 @@ MoFox_Bot(第三方修改版)

            storage_batcher = get_message_storage_batcher()
            await storage_batcher.start()
-            logger.info("消息存储批处理器已启动")
+            logger.debug("消息存储批处理器已启动")

            update_batcher = get_message_update_batcher()
            await update_batcher.start()
-            logger.info("消息更新批处理器已启动")
+            logger.debug("消息更新批处理器已启动")
        except Exception as e:
            logger.error(f"启动消息批处理器失败: {e}")

@@ -551,7 +543,7 @@ MoFox_Bot(第三方修改版)
            from src.chat.message_manager import message_manager

            await message_manager.start()
-            logger.info("消息管理器已启动")
+            logger.debug("消息管理器已启动")
        except Exception as e:
            logger.error(f"启动消息管理器失败: {e}")

@@ -565,7 +557,7 @@ MoFox_Bot(第三方修改版)
        try:
            await event_manager.trigger_event(EventType.ON_START, permission_group="SYSTEM")
            init_time = int(1000 * (time.time() - init_start_time))
-            logger.info(f"初始化完成,神经元放电{init_time}次")
+            logger.debug(f"初始化完成,神经元放电{init_time}次")
        except Exception as e:
            logger.error(f"启动事件触发失败: {e}")

@@ -575,7 +567,7 @@ MoFox_Bot(第三方修改版)

            adapter_manager = get_adapter_manager()
            await adapter_manager.start_all_adapters()
-            logger.info("所有适配器已启动")
+            logger.debug("所有适配器已启动")
        except Exception as e:
            logger.error(f"启动适配器失败: {e}")

@@ -584,7 +576,7 @@ MoFox_Bot(第三方修改版)
            if MEM_MONITOR_ENABLED:
                started = start_background_monitor(interval_sec=2400)
                if started:
-                    logger.info("[DEV] 内存监控已启动 (间隔=2400s ≈ 40min)")
+                    logger.debug("[DEV] 内存监控已启动 (间隔=2400s ≈ 40min)")
        except Exception as e:
            logger.error(f"启动内存监控失败: {e}")

@@ -594,7 +586,7 @@ MoFox_Bot(第三方修改版)
        if global_config and global_config.planning_system and global_config.planning_system.monthly_plan_enable:
            try:
                await monthly_plan_manager.start_monthly_plan_generation()
-                logger.info("月度计划管理器初始化成功")
+                logger.debug("月度计划管理器初始化成功")
            except Exception as e:
                logger.error(f"月度计划管理器初始化失败: {e}")

@@ -603,7 +595,7 @@ MoFox_Bot(第三方修改版)
        try:
            await schedule_manager.load_or_generate_today_schedule()
            await schedule_manager.start_daily_schedule_generation()
-            logger.info("日程表管理器初始化成功")
+            logger.debug("日程表管理器初始化成功")
        except Exception as e:
            logger.error(f"日程表管理器初始化失败: {e}")

@@ -615,7 +607,7 @@ MoFox_Bot(第三方修改版)
            result = init_func()
            if asyncio.iscoroutine(result):
                await result
-            logger.info(f"{component_name}初始化成功")
+            logger.debug(f"{component_name}初始化成功")
            return True
        except Exception as e:
            logger.error(f"{component_name}初始化失败: {e}")
@@ -229,7 +229,7 @@ class NodeMerger:
            是否成功
        """
        try:
-            logger.info(f"合并节点: '{source.content}' ({source.id}) → '{target.content}' ({target.id})")
+            logger.debug(f"合并节点: '{source.content}' ({source.id}) → '{target.content}' ({target.id})")

            # 1. 在图存储中合并节点
            self.graph_store.merge_nodes(source.id, target.id)

@@ -240,7 +240,7 @@ class NodeMerger:
            # 3. 更新所有相关记忆的节点引用
            self._update_memory_references(source.id, target.id)

-            logger.info(f"节点合并成功: {source.id} → {target.id}")
+            logger.debug(f"节点合并成功: {source.id} → {target.id}")
            return True

        except Exception as e:
@@ -657,7 +657,7 @@ class LongTermMemoryManager:
            memory.metadata["transferred_from_stm"] = source_stm.id
            memory.metadata["transfer_time"] = datetime.now().isoformat()

-            logger.info(f"✅ 创建长期记忆: {memory.id} (来自短期记忆 {source_stm.id})")
+            logger.info(f"创建长期记忆: {memory.id} (来自短期记忆 {source_stm.id})")
            # 强制注册 target_id,无论它是否符合 placeholder 格式
            # 这样即使 LLM 使用了中文描述作为 ID (如 "新创建的记忆"), 也能正确映射
            self._register_temp_id(op.target_id, memory.id, temp_id_map, force=True)

@@ -690,7 +690,7 @@ class LongTermMemoryManager:
            success = await self.memory_manager.update_memory(memory_id, **updates)

            if success:
-                logger.info(f"✅ 更新长期记忆: {memory_id}")
+                logger.info(f"更新长期记忆: {memory_id}")
            else:
                logger.error(f"更新长期记忆失败: {memory_id}")

@@ -736,7 +736,7 @@ class LongTermMemoryManager:

                # 3. 异步保存
                asyncio.create_task(self.memory_manager._async_save_graph_store("合并记忆"))
-                logger.info(f"✅ 合并记忆完成: {source_ids} -> {target_id}")
+                logger.info(f"合并记忆完成: {source_ids} -> {target_id}")
            else:
                logger.error(f"合并记忆失败: {source_ids}")

@@ -767,7 +767,7 @@ class LongTermMemoryManager:
            if success:
                # 尝试为新节点生成 embedding (异步)
                asyncio.create_task(self._generate_node_embedding(node_id, content))
-                logger.info(f"✅ 创建节点: {content} ({node_type}) -> {memory_id}")
+                logger.info(f"创建节点: {content} ({node_type}) -> {memory_id}")
                # 强制注册 target_id,无论它是否符合 placeholder 格式
                self._register_temp_id(op.target_id, node_id, temp_id_map, force=True)
                self._register_aliases_from_params(

@@ -798,7 +798,7 @@ class LongTermMemoryManager:
            )

            if success:
-                logger.info(f"✅ 更新节点: {node_id}")
+                logger.info(f"更新节点: {node_id}")
            else:
                logger.error(f"更新节点失败: {node_id}")

@@ -825,7 +825,7 @@ class LongTermMemoryManager:
            for source_id in sources:
                self.memory_manager.graph_store.merge_nodes(source_id, target_id)

-            logger.info(f"✅ 合并节点: {sources} -> {target_id}")
+            logger.info(f"合并节点: {sources} -> {target_id}")

    async def _execute_create_edge(
        self, op: GraphOperation, temp_id_map: dict[str, str]

@@ -860,7 +860,7 @@ class LongTermMemoryManager:
            )

            if edge_id:
-                logger.info(f"✅ 创建边: {source_id} -> {target_id} ({relation})")
+                logger.info(f"创建边: {source_id} -> {target_id} ({relation})")
            else:
                logger.error(f"创建边失败: {op}")

@@ -884,7 +884,7 @@ class LongTermMemoryManager:
            )

            if success:
-                logger.info(f"✅ 更新边: {edge_id}")
+                logger.info(f"更新边: {edge_id}")
            else:
                logger.error(f"更新边失败: {edge_id}")

@@ -901,7 +901,7 @@ class LongTermMemoryManager:
            success = self.memory_manager.graph_store.remove_edge(edge_id)

            if success:
-                logger.info(f"✅ 删除边: {edge_id}")
+                logger.info(f"删除边: {edge_id}")
            else:
                logger.error(f"删除边失败: {edge_id}")

@@ -980,7 +980,7 @@ class LongTermMemoryManager:
                self.memory_manager.graph_store
            )

-            logger.info(f"✅ 长期记忆衰减完成: {decayed_count} 条记忆已更新")
+            logger.info(f"长期记忆衰减完成: {decayed_count} 条记忆已更新")
            return {"decayed_count": decayed_count, "total_memories": len(all_memories)}

        except Exception as e:

@@ -1009,7 +1009,7 @@ class LongTermMemoryManager:
            # 长期记忆的保存由 MemoryManager 负责

            self._initialized = False
-            logger.info("✅ 长期记忆管理器已关闭")
+            logger.info("长期记忆管理器已关闭")

        except Exception as e:
            logger.error(f"关闭长期记忆管理器失败: {e}")
@@ -79,7 +79,7 @@ class MemoryManager:
         self._maintenance_interval_hours = getattr(self.config, "consolidation_interval_hours", 1.0)
         self._maintenance_running = False  # 维护任务运行状态
 
-        logger.info(f"记忆管理器已创建 (data_dir={self.data_dir}, enable={getattr(self.config, 'enable', False)})")
+        logger.debug(f"记忆管理器已创建 (data_dir={self.data_dir}, enable={getattr(self.config, 'enable', False)})")
 
     async def initialize(self) -> None:
         """
@@ -119,7 +119,7 @@ class MemoryManager:
             self.graph_store = GraphStore()
         else:
             stats = self.graph_store.get_statistics()
-            logger.info(
+            logger.debug(
                 f"加载图数据: {stats['total_memories']} 条记忆, "
                 f"{stats['total_nodes']} 个节点, {stats['total_edges']} 条边"
             )
@@ -169,7 +169,7 @@ class MemoryManager:
         )
 
         self._initialized = True
-        logger.info("✅ 记忆管理器初始化完成")
+        logger.info("记忆管理器初始化完成")
 
         # 启动后台维护任务
         self._start_maintenance_task()
@@ -208,7 +208,7 @@ class MemoryManager:
                 pass
 
             self._initialized = False
-            logger.info("✅ 记忆管理器已关闭")
+            logger.info("记忆管理器已关闭")
 
         except Exception as e:
             logger.error(f"关闭记忆管理器失败: {e}")
@@ -1013,11 +1013,11 @@ class MemoryManager:
                 await self.persistence.save_graph_store(self.graph_store)
 
                 logger.info(
-                    f"✅ 自动遗忘完成: 遗忘了 {forgotten_count} 条记忆, "
+                    f"自动遗忘完成: 遗忘了 {forgotten_count} 条记忆, "
                     f"清理了 {orphan_nodes} 个孤立节点, {orphan_edges} 条孤立边"
                 )
             else:
-                logger.info("✅ 自动遗忘完成: 没有需要遗忘的记忆")
+                logger.info("自动遗忘完成: 没有需要遗忘的记忆")
 
             return forgotten_count
 
@@ -1151,7 +1151,7 @@ class MemoryManager:
             await self.initialize()
 
         try:
-            logger.info("🧹 开始记忆整理:检查遗忘 + 清理孤立节点...")
+            logger.info("开始记忆整理:检查遗忘 + 清理孤立节点...")
 
             # 步骤1: 自动遗忘低激活度的记忆
             forgotten_count = await self.auto_forget()
@@ -1166,7 +1166,7 @@ class MemoryManager:
                 "message": "记忆整理完成(仅遗忘和清理孤立节点)"
             }
 
-            logger.info(f"✅ 记忆整理完成: {result}")
+            logger.info(f"记忆整理完成: {result}")
             return result
 
         except Exception as e:
@@ -1274,7 +1274,7 @@ class MemoryManager:
             await self.initialize()
 
         try:
-            logger.info("🔧 开始执行记忆系统维护...")
+            logger.info("开始执行记忆系统维护...")
 
             result = {
                 "forgotten": 0,
@@ -1303,11 +1303,11 @@ class MemoryManager:
             total_time = (datetime.now() - start_time).total_seconds()
             result["total_time"] = total_time
 
-            logger.info(f"✅ 维护完成 (耗时 {total_time:.2f}s): {result}")
+            logger.info(f"维护完成 (耗时 {total_time:.2f}s): {result}")
             return result
 
         except Exception as e:
-            logger.error(f"❌ 维护失败: {e}")
+            logger.error(f"维护失败: {e}")
             return {"error": str(e), "total_time": 0}
 
     async def _lightweight_auto_link_memories(  # 已废弃
@@ -1373,8 +1373,8 @@ class MemoryManager:
             name="memory_maintenance_loop"
         )
 
-        logger.info(
-            f"✅ 记忆维护后台任务已启动 "
+        logger.debug(
+            f"记忆维护后台任务已启动 "
             f"(间隔={self._maintenance_interval_hours}小时)"
         )
 
@@ -1397,7 +1397,7 @@ class MemoryManager:
             except asyncio.CancelledError:
                 logger.debug("维护任务已取消")
 
-            logger.info("✅ 记忆维护后台任务已停止")
+            logger.info("记忆维护后台任务已停止")
             self._maintenance_task = None
 
         except Exception as e:
@@ -66,7 +66,7 @@ async def initialize_memory_manager(
        await _memory_manager.initialize()
 
        _initialized = True
-        logger.info("✅ 全局 MemoryManager 初始化成功")
+        logger.info("全局 MemoryManager 初始化成功")
 
        return _memory_manager
 
@@ -98,7 +98,7 @@ async def shutdown_memory_manager():
    if _memory_manager:
        try:
            await _memory_manager.shutdown()
-            logger.info("✅ 全局 MemoryManager 已关闭")
+            logger.info("全局 MemoryManager 已关闭")
        except Exception as e:
            logger.error(f"关闭 MemoryManager 时出错: {e}")
        finally:
@@ -205,7 +205,7 @@ async def shutdown_unified_memory_manager() -> None:
    try:
        await _unified_memory_manager.shutdown()
        _unified_memory_manager = None
-        logger.info("✅ 统一记忆管理器已关闭")
+        logger.info("统一记忆管理器已关闭")
 
    except Exception as e:
        logger.error(f"关闭统一记忆管理器失败: {e}")
@@ -417,13 +417,13 @@ class ShortTermMemoryManager:
 
         elif decision.operation == ShortTermOperation.DISCARD:
             # 丢弃
-            logger.info(f"🗑️ 丢弃低价值记忆: {decision.reasoning}")
+            logger.debug(f"丢弃低价值记忆: {decision.reasoning}")
             return None
 
         elif decision.operation == ShortTermOperation.KEEP_SEPARATE:
             # 保持独立
             self.memories.append(new_memory)
-            logger.info(f"✅ 保持独立记忆: {new_memory.id}")
+            logger.debug(f"保持独立记忆: {new_memory.id}")
             return new_memory
 
         else:
@@ -579,7 +579,7 @@ class ShortTermMemoryManager:
             for mem in results:
                 mem.update_access()
 
-            logger.info(f"检索到 {len(results)} 条短期记忆")
+            logger.debug(f"检索到 {len(results)} 条短期记忆")
             return results
 
         except Exception as e:
@@ -730,7 +730,7 @@ class ShortTermMemoryManager:
                 memory.embedding = embedding
                 success_count += 1
 
-        logger.info(f"✅ 向量重新生成完成(成功: {success_count}/{len(memories_to_process)})")
+        logger.info(f"向量重新生成完成(成功: {success_count}/{len(memories_to_process)})")
 
     async def shutdown(self) -> None:
         """关闭管理器"""
@@ -744,7 +744,7 @@ class ShortTermMemoryManager:
             await self._save_to_disk()
 
             self._initialized = False
-            logger.info("✅ 短期记忆管理器已关闭")
+            logger.info("短期记忆管理器已关闭")
 
         except Exception as e:
             logger.error(f"关闭短期记忆管理器失败: {e}")
@@ -39,9 +39,6 @@ class GraphStore:
         # 节点 -> {memory_id: [MemoryEdge]},用于快速获取邻接边
         self.node_edge_index: dict[str, dict[str, list[MemoryEdge]]] = {}
 
-        logger.info("初始化图存储")
-
-
     def _register_memory_edges(self, memory: Memory) -> None:
         """在记忆中的边加入邻接索引"""
         for edge in memory.edges:
@@ -825,7 +825,7 @@ class MemoryTools:
             filter_rate = filtered_count / total_candidates
             if filter_rate > 0.5:  # 降低警告阈值到50%
                 logger.warning(
-                    f"⚠️ 过滤率较高 ({filter_rate*100:.1f}%)!"
+                    f"过滤率较高 ({filter_rate*100:.1f}%)!"
                     f"原因:{filter_stats['importance']}个记忆重要性 < {self.search_min_importance}。"
                     f"建议:1) 降低 min_importance 阈值,或 2) 检查记忆质量评分"
                 )
@@ -159,7 +159,7 @@ class UnifiedMemoryManager:
         await self.long_term_manager.initialize()
 
         self._initialized = True
-        logger.info("✅ 统一记忆管理器初始化完成")
+        logger.info("统一记忆管理器初始化完成")
 
         # 启动自动转移任务
         self._start_auto_transfer_task()
@@ -716,7 +716,7 @@ class UnifiedMemoryManager:
             await self.memory_manager.shutdown()
 
             self._initialized = False
-            logger.info("✅ 统一记忆管理器已关闭")
+            logger.info("统一记忆管理器已关闭")
 
         except Exception as e:
             logger.error(f"关闭统一记忆管理器失败: {e}")
@@ -64,10 +64,10 @@ class EmbeddingGenerator:
             self._api_dimension = embedding_config.embedding_dimension
 
             self._api_available = True
-            logger.info(f"✅ Embedding API 初始化成功 (维度: {self._api_dimension})")
+            logger.info(f"Embedding API 初始化成功 (维度: {self._api_dimension})")
 
         except Exception as e:
-            logger.warning(f"⚠️ Embedding API 初始化失败: {e}")
+            logger.warning(f"Embedding API 初始化失败: {e}")
             self._api_available = False
 
 
@@ -745,7 +745,7 @@ class PathScoreExpansion:
             node_type_hints[node_id] = getattr(node_obj_type, "value", str(node_obj_type))
 
         if all_node_ids:
-            logger.info(f"🧠 预处理 {len(all_node_ids)} 个节点的类型信息")
+            logger.debug(f"预处理 {len(all_node_ids)} 个节点的类型信息")
             for nid in all_node_ids:
                 node_attrs = self.graph_store.graph.nodes.get(nid, {}) if hasattr(self.graph_store, "graph") else {}
                 metadata = node_attrs.get("metadata", {}) if isinstance(node_attrs, dict) else {}
@@ -420,14 +420,6 @@ class UnifiedScheduler:
         # 取消所有正在执行的任务
         await self._cancel_all_running_tasks()
 
-        # 显示最终统计
-        stats = self.get_statistics()
-        logger.info(
-            f"调度器最终统计: 总任务={stats['total_tasks']}, "
-            f"执行次数={stats['total_executions']}, "
-            f"失败={stats['total_failures']}"
-        )
-
         # 清理资源
         self._tasks.clear()
         self._tasks_by_name.clear()
@@ -137,7 +137,7 @@ class EventManager:
             return False
 
         event.enabled = True
-        logger.info(f"事件 {event_name} 已启用")
+        logger.debug(f"事件 {event_name} 已启用")
         return True
 
     def disable_event(self, event_name: EventType | str) -> bool:
@@ -155,7 +155,7 @@ class EventManager:
             return False
 
         event.enabled = False
-        logger.info(f"事件 {event_name} 已禁用")
+        logger.debug(f"事件 {event_name} 已禁用")
         return True
 
     def register_event_handler(self, handler_class: type[BaseEventHandler], plugin_config: dict | None = None) -> bool:
@@ -198,7 +198,7 @@ class EventManager:
             self._pending_subscriptions[handler_name] = failed_subscriptions
             logger.warning(f"事件处理器 {handler_name} 的部分订阅失败,已缓存: {failed_subscriptions}")
 
-        logger.info(f"事件处理器 {handler_name} 注册成功")
+        logger.debug(f"事件处理器 {handler_name} 注册成功")
         return True
 
     def get_event_handler(self, handler_name: str) -> BaseEventHandler | None:
@@ -246,7 +246,7 @@ class EventManager:
                 event.subscribers.remove(subscriber)
                 logger.debug(f"事件处理器 {handler_name} 已从事件 {event.name} 取消订阅。")
 
-        logger.info(f"事件处理器 {handler_name} 已被完全移除。")
+        logger.debug(f"事件处理器 {handler_name} 已被完全移除。")
         return True
 
 
@@ -284,7 +284,7 @@ class EventManager:
         # 按权重从高到低排序订阅者
         event.subscribers.sort(key=lambda h: getattr(h, "weight", 0), reverse=True)
 
-        logger.info(f"事件处理器 {handler_name} 成功订阅到事件 {event_name},当前权重排序完成")
+        logger.debug(f"事件处理器 {handler_name} 成功订阅到事件 {event_name},当前权重排序完成")
         return True
 
     def unsubscribe_handler_from_event(self, handler_name: str, event_name: EventType | str) -> bool:
@@ -311,7 +311,7 @@ class EventManager:
                 break
 
         if removed:
-            logger.info(f"事件处理器 {handler_name} 成功从事件 {event_name} 取消订阅")
+            logger.debug(f"事件处理器 {handler_name} 成功从事件 {event_name} 取消订阅")
         else:
             logger.warning(f"事件处理器 {handler_name} 未订阅事件 {event_name}")
 
@@ -50,7 +50,6 @@ class PluginManager:
            core_sink: 核心消息接收器实例(InProcessCoreSink)
        """
        self._core_sink = core_sink
-        logger.info("已设置核心消息接收器")
 
    def add_plugin_directory(self, directory: str) -> bool:
        """添加插件目录"""
@@ -97,7 +97,7 @@ class StreamToolHistoryManager:
             "average_execution_time": 0.0,
         }
 
-        logger.info(f"[{chat_id}] 工具历史记录管理器初始化完成,最大历史: {max_history}")
+        logger.debug(f"[{chat_id}] 工具历史记录管理器初始化完成,最大历史: {max_history}")
 
     async def add_tool_call(self, record: ToolCallRecord) -> None:
         """添加工具调用记录
@@ -141,7 +141,7 @@ class StreamToolHistoryManager:
         if self.enable_memory_cache:
             memory_result = self._search_memory_cache(tool_name, args)
             if memory_result:
-                logger.info(f"[{self.chat_id}] 内存缓存命中: {tool_name}")
+                logger.debug(f"[{self.chat_id}] 内存缓存命中: {tool_name}")
                 return memory_result
 
         # 然后检查全局缓存系统
@@ -436,7 +436,7 @@ def _evict_old_stream_managers() -> None:
         evicted.append(chat_id)
 
     if evicted:
-        logger.info(f"🔧 StreamToolHistoryManager LRU淘汰: 释放了 {len(evicted)} 个不活跃的管理器")
+        logger.debug(f"StreamToolHistoryManager LRU淘汰: 释放了 {len(evicted)} 个不活跃的管理器")
 
 
 def get_stream_tool_history_manager(chat_id: str) -> StreamToolHistoryManager:
@@ -26,15 +26,13 @@ class InterestService:
         """
         try:
             logger.info("开始初始化智能兴趣系统...")
-            logger.info(f"人设ID: {personality_id}, 描述长度: {len(personality_description)}")
-
             await bot_interest_manager.initialize(personality_description, personality_id)
             self.is_initialized = True
             logger.info("智能兴趣系统初始化完成。")
 
             # 显示初始化后的统计信息
             stats = bot_interest_manager.get_interest_stats()
-            logger.info(f"兴趣系统统计: {stats}")
+            logger.debug(f"兴趣系统统计: {stats}")
 
         except Exception as e:
             logger.error(f"初始化智能兴趣系统失败: {e}")
@@ -77,14 +77,6 @@ class AffinityInterestCalculator(BaseInterestCalculator):
         logger.info(f" - 回复冷却减少: {self.reply_cooldown_reduction}")
         logger.info(f" - 最大不回复计数: {self.max_no_reply_count}")
 
-        # 检查 bot_interest_manager 状态
-        try:
-            logger.info(f" - bot_interest_manager 初始化状态: {bot_interest_manager.is_initialized}")
-            if not bot_interest_manager.is_initialized:
-                logger.warning(" - bot_interest_manager 未初始化,这将导致兴趣匹配返回默认值0.3")
-        except Exception as e:
-            logger.error(f" - 检查 bot_interest_manager 时出错: {e}")
-
     async def execute(self, message: "DatabaseMessages") -> InterestCalculationResult:
         """执行AffinityFlow风格的兴趣值计算"""
         try:
@@ -85,14 +85,14 @@ class ProactiveThinkingReplyHandler(BaseEventHandler):
 
             if success:
                 if was_paused:
-                    logger.info(f"[成功] 聊天流 {stream_id} 主动思考已恢复并重置")
+                    logger.info(f"聊天流 {stream_id} 主动思考已恢复并重置")
                 else:
-                    logger.debug(f"[成功] 聊天流 {stream_id} 主动思考任务已重置")
+                    logger.debug(f"聊天流 {stream_id} 主动思考任务已重置")
             else:
-                logger.warning(f"[错误] 重置聊天流 {stream_id} 主动思考任务失败")
+                logger.warning(f"重置聊天流 {stream_id} 主动思考任务失败")
 
         except Exception as e:
-            logger.error(f"❌ 处理reply事件时出错: {e}")
+            logger.error(f"处理reply事件时出错: {e}")
 
         # 总是继续处理其他handler
         return HandlerResult(success=True, continue_process=True, message=None)
@@ -141,7 +141,6 @@ class CounterAttackGenerator:
             if success and response:
                 # 清理响应
                 response = response.strip().strip('"').strip("'")
-                logger.info(f"LLM生成反击响应: {response[:50]}...")
                 return response
 
             return None
@@ -155,25 +155,25 @@ class QZoneService:
             return {"success": False, "message": f"好友'{target_name}'没有关联QQ号"}
 
         qq_account = config_api.get_global_config("bot.qq_account", "")
-        logger.info(f"[DEBUG] 准备获取API客户端,qq_account={qq_account}")
+        logger.debug(f"准备获取API客户端,qq_account={qq_account}")
         api_client = await self._get_api_client(qq_account, stream_id)
         if not api_client:
-            logger.error("[DEBUG] API客户端获取失败,返回错误")
+            logger.error("API客户端获取失败,返回错误")
             return {"success": False, "message": "获取QZone API客户端失败"}
 
-        logger.info("[DEBUG] API客户端获取成功,准备读取说说")
+        logger.debug("API客户端获取成功,准备读取说说")
         num_to_read = self.get_config("read.read_number", 5)
 
         # 尝试执行,如果Cookie失效则自动重试一次
         for retry_count in range(2):  # 最多尝试2次
             try:
-                logger.info(f"[DEBUG] 开始调用 list_feeds,target_qq={target_qq}, num={num_to_read}")
+                logger.debug(f"开始调用 list_feeds,target_qq={target_qq}, num={num_to_read}")
                 feeds = await api_client["list_feeds"](target_qq, num_to_read)
-                logger.info(f"[DEBUG] list_feeds 返回,feeds数量={len(feeds) if feeds else 0}")
+                logger.debug(f"list_feeds 返回,feeds数量={len(feeds) if feeds else 0}")
                 if not feeds:
                     return {"success": True, "message": f"没有从'{target_name}'的空间获取到新说说。"}
 
-                logger.info(f"[DEBUG] 准备处理 {len(feeds)} 条说说")
+                logger.debug(f"准备处理 {len(feeds)} 条说说")
                 total_liked = 0
                 total_commented = 0
                 for feed in feeds:
@@ -624,7 +624,7 @@ class QZoneService:
         raise RuntimeError(f"无法连接到Napcat服务: 超过最大重试次数({max_retries})")
 
     async def _get_api_client(self, qq_account: str, stream_id: str | None) -> dict | None:
-        logger.info(f"[DEBUG] 开始获取API客户端,qq_account={qq_account}")
+        logger.debug(f"开始获取API客户端,qq_account={qq_account}")
         cookies = await self.cookie_service.get_cookies(qq_account, stream_id)
         if not cookies:
             logger.error(
@@ -632,14 +632,14 @@ class QZoneService:
             )
             return None
 
-        logger.info(f"[DEBUG] Cookie获取成功,keys: {list(cookies.keys())}")
+        logger.debug(f"Cookie获取成功,keys: {list(cookies.keys())}")
 
         p_skey = cookies.get("p_skey") or cookies.get("p_skey".upper())
         if not p_skey:
             logger.error(f"获取API客户端失败:Cookie中缺少关键的 'p_skey'。Cookie内容: {cookies}")
             return None
 
-        logger.info("[DEBUG] p_skey获取成功")
+        logger.debug("p_skey获取成功")
 
         gtk = self._generate_gtk(p_skey)
         uin = cookies.get("uin", "").lstrip("o")
@@ -647,7 +647,7 @@ class QZoneService:
             logger.error(f"获取API客户端失败:Cookie中缺少关键的 'uin'。Cookie内容: {cookies}")
             return None
 
-        logger.info(f"[DEBUG] uin={uin}, gtk={gtk}, 准备构造API客户端")
+        logger.debug(f"uin={uin}, gtk={gtk}, 准备构造API客户端")
 
         async def _request(method, url, params=None, data=None, headers=None):
             final_headers = {"referer": f"https://user.qzone.qq.com/{uin}", "origin": "https://user.qzone.qq.com"}
@@ -851,7 +851,7 @@ class QZoneService:
         async def _list_feeds(t_qq: str, num: int) -> list[dict]:
             """获取指定用户说说列表 (统一接口)"""
             try:
-                logger.info(f"[DEBUG] _list_feeds 开始,t_qq={t_qq}, num={num}")
+                logger.debug(f"_list_feeds 开始,t_qq={t_qq}, num={num}")
                 # 统一使用 format=json 获取完整评论
                 params = {
                     "g_tk": gtk,
@@ -865,12 +865,11 @@ class QZoneService:
                     "format": "json",  # 关键:使用JSON格式
                     "need_comment": 1,
                 }
-                logger.info(f"[DEBUG] 准备发送HTTP请求到 {self.LIST_URL}")
+                logger.debug(f"准备发送HTTP请求到 {self.LIST_URL}")
                 res_text = await _request("GET", self.LIST_URL, params=params)
-                logger.info(f"[DEBUG] HTTP请求返回,响应长度={len(res_text)}")
+                logger.debug(f"HTTP请求返回,响应长度={len(res_text)}")
                 json_data = orjson.loads(res_text)
-                logger.info(f"[DEBUG] JSON解析成功,code={json_data.get('code')}")
-
+                logger.debug(f"JSON解析成功,code={json_data.get('code')}")
                 if json_data.get("code") != 0:
                     error_code = json_data.get("code")
                     error_message = json_data.get("message", "未知错误")
@@ -1250,7 +1249,7 @@ class QZoneService:
             logger.error(f"监控好友动态失败: {e}")
             return []
 
-        logger.info("[DEBUG] API客户端构造完成,返回包含6个方法的字典")
+        logger.debug("API客户端构造完成,返回包含6个方法的字典")
         return {
             "publish": _publish,
             "list_feeds": _list_feeds,
@@ -34,7 +34,7 @@ class SendHandler:
         """
         处理来自核心的消息,将其转换为 Napcat 可接受的格式并发送
         """
-        logger.info("接收到来自MoFox-Bot的消息,处理中")
+        logger.debug("接收到来自MoFox-Bot的消息,处理中")
 
         if not envelope:
             logger.warning("空的消息,跳过处理")
@@ -50,13 +50,13 @@ class SendHandler:
         seg_type = segment.get("type")
 
         if seg_type == "command":
-            logger.info("处理命令")
+            logger.debug("处理命令")
             return await self.send_command(envelope)
         if seg_type == "adapter_command":
-            logger.info("处理适配器命令")
+            logger.debug("处理适配器命令")
             return await self.handle_adapter_command(envelope)
         if seg_type == "adapter_response":
-            logger.info("收到adapter_response消息,此消息应该由Bot端处理,跳过")
+            logger.debug("收到adapter_response消息,此消息应该由Bot端处理,跳过")
             return None
 
         return await self.send_normal_message(envelope)
@@ -65,7 +65,6 @@ class SendHandler:
         """
         处理普通消息发送
         """
-        logger.info("处理普通信息中")
         message_info: MessageInfoPayload = envelope.get("message_info", {})
         message_segment: SegPayload = envelope.get("message_segment", {})  # type: ignore[assignment]
 
@@ -487,7 +486,6 @@ class SendHandler:
 
     def handle_set_emoji_like_command(self, args: Dict[str, Any]) -> tuple[str, Dict[str, Any]]:
         """处理设置表情回应命令"""
-        logger.info(f"开始处理表情回应命令, 接收到参数: {args}")
         try:
             message_id = int(args["message_id"])
             emoji_id = int(args["emoji_id"])
@@ -61,7 +61,6 @@ class VoiceUploader:
         }
 
         logger.info(f"正在上传音频文件: {audio_path}")
-        logger.info(f"文件大小: {len(audio_data)} bytes")
 
         async with aiohttp.ClientSession() as session:
             async with session.post(