re-style: 格式化代码
This commit is contained in:
@@ -12,7 +12,7 @@ if __name__ == "__main__":
|
||||
|
||||
# 执行bot.py的代码
|
||||
bot_file = current_dir / "bot.py"
|
||||
with open(bot_file, "r", encoding="utf-8") as f:
|
||||
with open(bot_file, encoding="utf-8") as f:
|
||||
exec(f.read())
|
||||
|
||||
|
||||
|
||||
24
bot.py
24
bot.py
@@ -1,30 +1,30 @@
|
||||
# import asyncio
|
||||
import asyncio
|
||||
import os
|
||||
import platform
|
||||
import sys
|
||||
import time
|
||||
import platform
|
||||
import traceback
|
||||
from pathlib import Path
|
||||
from rich.traceback import install
|
||||
from colorama import init, Fore
|
||||
|
||||
from colorama import Fore, init
|
||||
from dotenv import load_dotenv # 处理.env文件
|
||||
from rich.traceback import install
|
||||
|
||||
# maim_message imports for console input
|
||||
|
||||
# 最早期初始化日志系统,确保所有后续模块都使用正确的日志格式
|
||||
from src.common.logger import initialize_logging, get_logger, shutdown_logging
|
||||
from src.common.logger import get_logger, initialize_logging, shutdown_logging
|
||||
|
||||
# UI日志适配器
|
||||
initialize_logging()
|
||||
|
||||
from src.main import MainSystem # noqa
|
||||
from src import BaseMain # noqa
|
||||
from src.manager.async_task_manager import async_task_manager # noqa
|
||||
from src.chat.knowledge.knowledge_lib import initialize_lpmm_knowledge # noqa
|
||||
from src.config.config import global_config # noqa
|
||||
from src.common.database.database import initialize_sql_database # noqa
|
||||
from src.common.database.sqlalchemy_models import initialize_database as init_db # noqa
|
||||
from src import BaseMain
|
||||
from src.manager.async_task_manager import async_task_manager
|
||||
from src.chat.knowledge.knowledge_lib import initialize_lpmm_knowledge
|
||||
from src.config.config import global_config
|
||||
from src.common.database.database import initialize_sql_database
|
||||
from src.common.database.sqlalchemy_models import initialize_database as init_db
|
||||
|
||||
logger = get_logger("main")
|
||||
|
||||
@@ -247,7 +247,7 @@ if __name__ == "__main__":
|
||||
# The actual shutdown logic is now in the finally block.
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"主程序发生异常: {str(e)} {str(traceback.format_exc())}")
|
||||
logger.error(f"主程序发生异常: {e!s} {traceback.format_exc()!s}")
|
||||
exit_code = 1 # 标记发生错误
|
||||
finally:
|
||||
# 确保 loop 在任何情况下都尝试关闭(如果存在且未关闭)
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Bilibili 插件包
|
||||
提供B站视频观看体验功能,像真实用户一样浏览和评价视频
|
||||
|
||||
@@ -1,16 +1,17 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Bilibili 工具基础模块
|
||||
提供 B 站视频信息获取和视频分析功能
|
||||
"""
|
||||
|
||||
import re
|
||||
import aiohttp
|
||||
import asyncio
|
||||
from typing import Optional, Dict, Any
|
||||
from src.common.logger import get_logger
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
import aiohttp
|
||||
|
||||
from src.chat.utils.utils_video import get_video_analyzer
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("bilibili_tool")
|
||||
|
||||
@@ -25,7 +26,7 @@ class BilibiliVideoAnalyzer:
|
||||
"Referer": "https://www.bilibili.com/",
|
||||
}
|
||||
|
||||
def extract_bilibili_url(self, text: str) -> Optional[str]:
|
||||
def extract_bilibili_url(self, text: str) -> str | None:
|
||||
"""从文本中提取哔哩哔哩视频链接"""
|
||||
# 哔哩哔哩短链接模式
|
||||
short_pattern = re.compile(r"https?://b23\.tv/[\w]+", re.IGNORECASE)
|
||||
@@ -44,7 +45,7 @@ class BilibiliVideoAnalyzer:
|
||||
|
||||
return None
|
||||
|
||||
async def get_video_info(self, url: str) -> Optional[Dict[str, Any]]:
|
||||
async def get_video_info(self, url: str) -> dict[str, Any] | None:
|
||||
"""获取哔哩哔哩视频基本信息"""
|
||||
try:
|
||||
logger.info(f"🔍 解析视频URL: {url}")
|
||||
@@ -127,7 +128,7 @@ class BilibiliVideoAnalyzer:
|
||||
logger.exception("详细错误信息:")
|
||||
return None
|
||||
|
||||
async def get_video_stream_url(self, aid: int, cid: int) -> Optional[str]:
|
||||
async def get_video_stream_url(self, aid: int, cid: int) -> str | None:
|
||||
"""获取视频流URL"""
|
||||
try:
|
||||
logger.info(f"🎥 获取视频流URL: aid={aid}, cid={cid}")
|
||||
@@ -164,7 +165,7 @@ class BilibiliVideoAnalyzer:
|
||||
return stream_url
|
||||
|
||||
# 降级到FLV格式
|
||||
if "durl" in play_data and play_data["durl"]:
|
||||
if play_data.get("durl"):
|
||||
logger.info("📹 使用FLV格式视频流")
|
||||
stream_url = play_data["durl"][0].get("url")
|
||||
if stream_url:
|
||||
@@ -185,7 +186,7 @@ class BilibiliVideoAnalyzer:
|
||||
logger.exception("详细错误信息:")
|
||||
return None
|
||||
|
||||
async def download_video_bytes(self, stream_url: str, max_size_mb: int = 100) -> Optional[bytes]:
|
||||
async def download_video_bytes(self, stream_url: str, max_size_mb: int = 100) -> bytes | None:
|
||||
"""下载视频字节数据
|
||||
|
||||
Args:
|
||||
@@ -244,7 +245,7 @@ class BilibiliVideoAnalyzer:
|
||||
logger.exception("详细错误信息:")
|
||||
return None
|
||||
|
||||
async def analyze_bilibili_video(self, url: str, prompt: str = None) -> Dict[str, Any]:
|
||||
async def analyze_bilibili_video(self, url: str, prompt: str = None) -> dict[str, Any]:
|
||||
"""分析哔哩哔哩视频并返回详细信息和AI分析结果"""
|
||||
try:
|
||||
logger.info(f"🎬 开始分析哔哩哔哩视频: {url}")
|
||||
@@ -322,10 +323,10 @@ class BilibiliVideoAnalyzer:
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"分析哔哩哔哩视频时发生异常: {str(e)}"
|
||||
error_msg = f"分析哔哩哔哩视频时发生异常: {e!s}"
|
||||
logger.error(f"❌ {error_msg}")
|
||||
logger.exception("详细错误信息:") # 记录完整的异常堆栈
|
||||
return {"error": f"分析失败: {str(e)}"}
|
||||
return {"error": f"分析失败: {e!s}"}
|
||||
|
||||
|
||||
# 全局实例
|
||||
|
||||
@@ -1,14 +1,15 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Bilibili 视频观看体验工具
|
||||
支持哔哩哔哩视频链接解析和AI视频内容分析
|
||||
"""
|
||||
|
||||
from typing import Dict, Any, List, Tuple, Type
|
||||
from src.plugin_system import BaseTool, ToolParamType, BasePlugin, register_plugin, ComponentInfo, ConfigField
|
||||
from .bilibli_base import get_bilibili_analyzer
|
||||
from typing import Any
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.plugin_system import BasePlugin, BaseTool, ComponentInfo, ConfigField, ToolParamType, register_plugin
|
||||
|
||||
from .bilibli_base import get_bilibili_analyzer
|
||||
|
||||
logger = get_logger("bilibili_tool")
|
||||
|
||||
@@ -41,7 +42,7 @@ class BilibiliTool(BaseTool):
|
||||
super().__init__(plugin_config)
|
||||
self.analyzer = get_bilibili_analyzer()
|
||||
|
||||
async def execute(self, function_args: Dict[str, Any]) -> Dict[str, Any]:
|
||||
async def execute(self, function_args: dict[str, Any]) -> dict[str, Any]:
|
||||
"""执行哔哩哔哩视频观看体验"""
|
||||
try:
|
||||
url = function_args.get("url", "").strip()
|
||||
@@ -83,7 +84,7 @@ class BilibiliTool(BaseTool):
|
||||
return {"name": self.name, "content": content.strip()}
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"😅 看视频的时候出了点问题: {str(e)}"
|
||||
error_msg = f"😅 看视频的时候出了点问题: {e!s}"
|
||||
logger.error(error_msg)
|
||||
return {"name": self.name, "content": error_msg}
|
||||
|
||||
@@ -104,7 +105,7 @@ class BilibiliTool(BaseTool):
|
||||
|
||||
return base_prompt
|
||||
|
||||
def _format_watch_experience(self, video_info: Dict, ai_analysis: str, interest_focus: str = None) -> str:
|
||||
def _format_watch_experience(self, video_info: dict, ai_analysis: str, interest_focus: str = None) -> str:
|
||||
"""格式化观看体验报告"""
|
||||
|
||||
# 根据播放量生成热度评价
|
||||
@@ -191,8 +192,8 @@ class BilibiliPlugin(BasePlugin):
|
||||
# 插件基本信息
|
||||
plugin_name: str = "bilibili_video_watcher"
|
||||
enable_plugin: bool = True
|
||||
dependencies: List[str] = []
|
||||
python_dependencies: List[str] = []
|
||||
dependencies: list[str] = []
|
||||
python_dependencies: list[str] = []
|
||||
config_file_name: str = "config.toml"
|
||||
|
||||
# 配置节描述
|
||||
@@ -220,6 +221,6 @@ class BilibiliPlugin(BasePlugin):
|
||||
},
|
||||
}
|
||||
|
||||
def get_plugin_components(self) -> List[Tuple[ComponentInfo, Type]]:
|
||||
def get_plugin_components(self) -> list[tuple[ComponentInfo, type]]:
|
||||
"""返回插件包含的工具组件"""
|
||||
return [(BilibiliTool.get_tool_info(), BilibiliTool)]
|
||||
|
||||
@@ -4,14 +4,15 @@ Echo 示例插件
|
||||
展示增强命令系统的使用方法
|
||||
"""
|
||||
|
||||
from typing import List, Tuple, Type, Optional, Union
|
||||
from typing import Union
|
||||
|
||||
from src.plugin_system import (
|
||||
BasePlugin,
|
||||
PlusCommand,
|
||||
CommandArgs,
|
||||
PlusCommandInfo,
|
||||
ConfigField,
|
||||
ChatType,
|
||||
CommandArgs,
|
||||
ConfigField,
|
||||
PlusCommand,
|
||||
PlusCommandInfo,
|
||||
register_plugin,
|
||||
)
|
||||
from src.plugin_system.base.component_types import PythonDependency
|
||||
@@ -27,7 +28,7 @@ class EchoCommand(PlusCommand):
|
||||
chat_type_allow = ChatType.ALL
|
||||
intercept_message = True
|
||||
|
||||
async def execute(self, args: CommandArgs) -> Tuple[bool, Optional[str], bool]:
|
||||
async def execute(self, args: CommandArgs) -> tuple[bool, str | None, bool]:
|
||||
"""执行echo命令"""
|
||||
if args.is_empty():
|
||||
await self.send_text("❓ 请提供要回显的内容\n用法: /echo <内容>")
|
||||
@@ -56,7 +57,7 @@ class HelloCommand(PlusCommand):
|
||||
chat_type_allow = ChatType.ALL
|
||||
intercept_message = True
|
||||
|
||||
async def execute(self, args: CommandArgs) -> Tuple[bool, Optional[str], bool]:
|
||||
async def execute(self, args: CommandArgs) -> tuple[bool, str | None, bool]:
|
||||
"""执行hello命令"""
|
||||
if args.is_empty():
|
||||
await self.send_text("👋 Hello! 很高兴见到你!")
|
||||
@@ -77,7 +78,7 @@ class InfoCommand(PlusCommand):
|
||||
chat_type_allow = ChatType.ALL
|
||||
intercept_message = True
|
||||
|
||||
async def execute(self, args: CommandArgs) -> Tuple[bool, Optional[str], bool]:
|
||||
async def execute(self, args: CommandArgs) -> tuple[bool, str | None, bool]:
|
||||
"""执行info命令"""
|
||||
info_text = (
|
||||
"📋 Echo 示例插件信息\n"
|
||||
@@ -105,7 +106,7 @@ class TestCommand(PlusCommand):
|
||||
chat_type_allow = ChatType.ALL
|
||||
intercept_message = True
|
||||
|
||||
async def execute(self, args: CommandArgs) -> Tuple[bool, Optional[str], bool]:
|
||||
async def execute(self, args: CommandArgs) -> tuple[bool, str | None, bool]:
|
||||
"""执行test命令"""
|
||||
if args.is_empty():
|
||||
help_text = (
|
||||
@@ -166,8 +167,8 @@ class EchoExamplePlugin(BasePlugin):
|
||||
|
||||
plugin_name: str = "echo_example_plugin"
|
||||
enable_plugin: bool = True
|
||||
dependencies: List[str] = []
|
||||
python_dependencies: List[Union[str, "PythonDependency"]] = []
|
||||
dependencies: list[str] = []
|
||||
python_dependencies: list[Union[str, "PythonDependency"]] = []
|
||||
config_file_name: str = "config.toml"
|
||||
|
||||
config_schema = {
|
||||
@@ -187,7 +188,7 @@ class EchoExamplePlugin(BasePlugin):
|
||||
"commands": "命令相关配置",
|
||||
}
|
||||
|
||||
def get_plugin_components(self) -> List[Tuple[PlusCommandInfo, Type]]:
|
||||
def get_plugin_components(self) -> list[tuple[PlusCommandInfo, type]]:
|
||||
"""获取插件组件"""
|
||||
components = []
|
||||
|
||||
|
||||
@@ -1,20 +1,20 @@
|
||||
from typing import List, Tuple, Type, Dict, Any, Optional
|
||||
import logging
|
||||
import random
|
||||
from typing import Any
|
||||
|
||||
from src.plugin_system import (
|
||||
BasePlugin,
|
||||
register_plugin,
|
||||
ComponentInfo,
|
||||
BaseEventHandler,
|
||||
EventType,
|
||||
BaseTool,
|
||||
PlusCommand,
|
||||
CommandArgs,
|
||||
ChatType,
|
||||
BaseAction,
|
||||
ActionActivationType,
|
||||
BaseAction,
|
||||
BaseEventHandler,
|
||||
BasePlugin,
|
||||
BaseTool,
|
||||
ChatType,
|
||||
CommandArgs,
|
||||
ComponentInfo,
|
||||
ConfigField,
|
||||
EventType,
|
||||
PlusCommand,
|
||||
register_plugin,
|
||||
)
|
||||
from src.plugin_system.base.base_event import HandlerResult
|
||||
|
||||
@@ -39,7 +39,7 @@ class GetSystemInfoTool(BaseTool):
|
||||
available_for_llm = True
|
||||
parameters = []
|
||||
|
||||
async def execute(self, function_args: Dict[str, Any]) -> Dict[str, Any]:
|
||||
async def execute(self, function_args: dict[str, Any]) -> dict[str, Any]:
|
||||
return {"name": self.name, "content": "系统版本: 1.0.1, 状态: 运行正常"}
|
||||
|
||||
|
||||
@@ -51,7 +51,7 @@ class HelloCommand(PlusCommand):
|
||||
command_aliases = ["hi", "你好"]
|
||||
chat_type_allow = ChatType.ALL
|
||||
|
||||
async def execute(self, args: CommandArgs) -> Tuple[bool, Optional[str], bool]:
|
||||
async def execute(self, args: CommandArgs) -> tuple[bool, str | None, bool]:
|
||||
greeting = str(self.get_config("greeting.message", "Hello, World! 我是一个由 MoFox_Bot 驱动的插件。"))
|
||||
await self.send_text(greeting)
|
||||
return True, "成功发送问候", True
|
||||
@@ -67,7 +67,7 @@ class RandomEmojiAction(BaseAction):
|
||||
action_require = ["当对话气氛轻松时", "可以用来回应简单的情感表达"]
|
||||
associated_types = ["text"]
|
||||
|
||||
async def execute(self) -> Tuple[bool, str]:
|
||||
async def execute(self) -> tuple[bool, str]:
|
||||
emojis = ["😊", "😂", "👍", "🎉", "🤔", "🤖"]
|
||||
await self.send_text(random.choice(emojis))
|
||||
return True, "成功发送了一个随机表情"
|
||||
@@ -99,9 +99,9 @@ class HelloWorldPlugin(BasePlugin):
|
||||
},
|
||||
}
|
||||
|
||||
def get_plugin_components(self) -> List[Tuple[ComponentInfo, Type]]:
|
||||
def get_plugin_components(self) -> list[tuple[ComponentInfo, type]]:
|
||||
"""根据配置文件动态注册插件的功能组件。"""
|
||||
components: List[Tuple[ComponentInfo, Type]] = []
|
||||
components: list[tuple[ComponentInfo, type]] = []
|
||||
|
||||
components.append((StartupMessageHandler.get_handler_info(), StartupMessageHandler))
|
||||
components.append((GetSystemInfoTool.get_tool_info(), GetSystemInfoTool))
|
||||
|
||||
@@ -70,6 +70,7 @@ dependencies = [
|
||||
"tqdm>=4.67.1",
|
||||
"urllib3>=2.5.0",
|
||||
"uvicorn>=0.35.0",
|
||||
"watchdog>=6.0.0",
|
||||
"websockets>=15.0.1",
|
||||
"aiomysql>=0.2.0",
|
||||
"aiosqlite>=0.21.0",
|
||||
@@ -80,29 +81,41 @@ dependencies = [
|
||||
url = "https://pypi.tuna.tsinghua.edu.cn/simple"
|
||||
default = true
|
||||
|
||||
[tool.uv.sources]
|
||||
amrita = { workspace = true }
|
||||
|
||||
[tool.ruff]
|
||||
|
||||
include = ["*.py"]
|
||||
|
||||
# 行长度设置
|
||||
line-length = 120
|
||||
target-version = "py310"
|
||||
|
||||
[tool.ruff.lint]
|
||||
fixable = ["ALL"]
|
||||
unfixable = []
|
||||
select = [
|
||||
"F", # Pyflakes
|
||||
"W", # pycodestyle warnings
|
||||
"E", # pycodestyle errors
|
||||
"UP", # pyupgrade
|
||||
"ASYNC", # flake8-async
|
||||
"C4", # flake8-comprehensions
|
||||
"T10", # flake8-debugger
|
||||
"PYI", # flake8-pyi
|
||||
"PT", # flake8-pytest-style
|
||||
"Q", # flake8-quotes
|
||||
"RUF", # Ruff-specific rules
|
||||
"I", # isort
|
||||
"PERF", # pylint-performance
|
||||
]
|
||||
ignore = [
|
||||
"E402", # module-import-not-at-top-of-file
|
||||
"E501", # line-too-long
|
||||
"UP037", # quoted-annotation
|
||||
"RUF001", # ambiguous-unicode-character-string
|
||||
"RUF002", # ambiguous-unicode-character-docstring
|
||||
"RUF003", # ambiguous-unicode-character-comment
|
||||
]
|
||||
|
||||
|
||||
# 如果一个变量的名称以下划线开头,即使它未被使用,也不应该被视为错误或警告。
|
||||
dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$"
|
||||
|
||||
# 启用的规则
|
||||
select = [
|
||||
"E", # pycodestyle 错误
|
||||
"F", # pyflakes
|
||||
"B", # flake8-bugbear
|
||||
]
|
||||
|
||||
ignore = ["E711","E501"]
|
||||
|
||||
[tool.ruff.format]
|
||||
docstring-code-format = true
|
||||
indent-style = "space"
|
||||
@@ -124,6 +137,4 @@ skip-magic-trailing-comma = false
|
||||
line-ending = "auto"
|
||||
|
||||
[dependency-groups]
|
||||
lint = [
|
||||
"loguru>=0.7.3",
|
||||
]
|
||||
lint = ["loguru>=0.7.3"]
|
||||
|
||||
@@ -1,10 +1,9 @@
|
||||
import time
|
||||
import sys
|
||||
import os
|
||||
from typing import Dict, List
|
||||
import sys
|
||||
import time
|
||||
|
||||
# Add project root to Python path
|
||||
from src.common.database.database_model import Expression, ChatStreams
|
||||
from src.common.database.database_model import ChatStreams, Expression
|
||||
|
||||
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
sys.path.insert(0, project_root)
|
||||
@@ -30,7 +29,7 @@ def get_chat_name(chat_id: str) -> str:
|
||||
return f"查询失败 ({chat_id})"
|
||||
|
||||
|
||||
def calculate_time_distribution(expressions) -> Dict[str, int]:
|
||||
def calculate_time_distribution(expressions) -> dict[str, int]:
|
||||
"""Calculate distribution of last active time in days"""
|
||||
now = time.time()
|
||||
distribution = {
|
||||
@@ -64,7 +63,7 @@ def calculate_time_distribution(expressions) -> Dict[str, int]:
|
||||
return distribution
|
||||
|
||||
|
||||
def calculate_count_distribution(expressions) -> Dict[str, int]:
|
||||
def calculate_count_distribution(expressions) -> dict[str, int]:
|
||||
"""Calculate distribution of count values"""
|
||||
distribution = {"0-1": 0, "1-2": 0, "2-3": 0, "3-4": 0, "4-5": 0, "5-10": 0, "10+": 0}
|
||||
for expr in expressions:
|
||||
@@ -86,7 +85,7 @@ def calculate_count_distribution(expressions) -> Dict[str, int]:
|
||||
return distribution
|
||||
|
||||
|
||||
def get_top_expressions_by_chat(chat_id: str, top_n: int = 5) -> List[Expression]:
|
||||
def get_top_expressions_by_chat(chat_id: str, top_n: int = 5) -> list[Expression]:
|
||||
"""Get top N most used expressions for a specific chat_id"""
|
||||
return Expression.select().where(Expression.chat_id == chat_id).order_by(Expression.count.desc()).limit(top_n)
|
||||
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import time
|
||||
import sys
|
||||
import os
|
||||
from typing import Dict, List, Tuple, Optional
|
||||
import sys
|
||||
import time
|
||||
from datetime import datetime
|
||||
|
||||
# Add project root to Python path
|
||||
@@ -35,7 +34,7 @@ def format_timestamp(timestamp: float) -> str:
|
||||
return "未知时间"
|
||||
|
||||
|
||||
def calculate_interest_value_distribution(messages) -> Dict[str, int]:
|
||||
def calculate_interest_value_distribution(messages) -> dict[str, int]:
|
||||
"""Calculate distribution of interest_value"""
|
||||
distribution = {
|
||||
"0.000-0.010": 0,
|
||||
@@ -76,7 +75,7 @@ def calculate_interest_value_distribution(messages) -> Dict[str, int]:
|
||||
return distribution
|
||||
|
||||
|
||||
def get_interest_value_stats(messages) -> Dict[str, float]:
|
||||
def get_interest_value_stats(messages) -> dict[str, float]:
|
||||
"""Calculate basic statistics for interest_value"""
|
||||
values = [
|
||||
float(msg.interest_value) for msg in messages if msg.interest_value is not None and msg.interest_value != 0.0
|
||||
@@ -97,7 +96,7 @@ def get_interest_value_stats(messages) -> Dict[str, float]:
|
||||
}
|
||||
|
||||
|
||||
def get_available_chats() -> List[Tuple[str, str, int]]:
|
||||
def get_available_chats() -> list[tuple[str, str, int]]:
|
||||
"""Get all available chats with message counts"""
|
||||
try:
|
||||
# 获取所有有消息的chat_id
|
||||
@@ -130,7 +129,7 @@ def get_available_chats() -> List[Tuple[str, str, int]]:
|
||||
return []
|
||||
|
||||
|
||||
def get_time_range_input() -> Tuple[Optional[float], Optional[float]]:
|
||||
def get_time_range_input() -> tuple[float | None, float | None]:
|
||||
"""Get time range input from user"""
|
||||
print("\n时间范围选择:")
|
||||
print("1. 最近1天")
|
||||
@@ -170,7 +169,7 @@ def get_time_range_input() -> Tuple[Optional[float], Optional[float]]:
|
||||
|
||||
|
||||
def analyze_interest_values(
|
||||
chat_id: Optional[str] = None, start_time: Optional[float] = None, end_time: Optional[float] = None
|
||||
chat_id: str | None = None, start_time: float | None = None, end_time: float | None = None
|
||||
) -> None:
|
||||
"""Analyze interest values with optional filters"""
|
||||
|
||||
|
||||
@@ -1,13 +1,14 @@
|
||||
import tkinter as tk
|
||||
from tkinter import ttk, messagebox, filedialog, colorchooser
|
||||
import orjson
|
||||
from pathlib import Path
|
||||
import threading
|
||||
import toml
|
||||
from datetime import datetime
|
||||
from collections import defaultdict
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
import tkinter as tk
|
||||
from collections import defaultdict
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from tkinter import colorchooser, filedialog, messagebox, ttk
|
||||
|
||||
import orjson
|
||||
import toml
|
||||
|
||||
|
||||
class LogIndex:
|
||||
@@ -409,7 +410,7 @@ class AsyncLogLoader:
|
||||
file_size = os.path.getsize(file_path)
|
||||
processed_size = 0
|
||||
|
||||
with open(file_path, "r", encoding="utf-8") as f:
|
||||
with open(file_path, encoding="utf-8") as f:
|
||||
line_count = 0
|
||||
batch_size = 1000 # 批量处理
|
||||
|
||||
@@ -561,7 +562,7 @@ class LogViewer:
|
||||
|
||||
try:
|
||||
if config_path.exists():
|
||||
with open(config_path, "r", encoding="utf-8") as f:
|
||||
with open(config_path, encoding="utf-8") as f:
|
||||
bot_config = toml.load(f)
|
||||
if "log" in bot_config:
|
||||
self.log_config.update(bot_config["log"])
|
||||
@@ -575,7 +576,7 @@ class LogViewer:
|
||||
|
||||
try:
|
||||
if viewer_config_path.exists():
|
||||
with open(viewer_config_path, "r", encoding="utf-8") as f:
|
||||
with open(viewer_config_path, encoding="utf-8") as f:
|
||||
viewer_config = toml.load(f)
|
||||
if "viewer" in viewer_config:
|
||||
self.viewer_config.update(viewer_config["viewer"])
|
||||
@@ -843,7 +844,7 @@ class LogViewer:
|
||||
mapping_file = Path("config/module_mapping.json")
|
||||
if mapping_file.exists():
|
||||
try:
|
||||
with open(mapping_file, "r", encoding="utf-8") as f:
|
||||
with open(mapping_file, encoding="utf-8") as f:
|
||||
custom_mapping = orjson.loads(f.read())
|
||||
self.module_name_mapping.update(custom_mapping)
|
||||
except Exception as e:
|
||||
@@ -1172,7 +1173,7 @@ class LogViewer:
|
||||
"""读取新的日志条目并返回它们"""
|
||||
new_entries = []
|
||||
new_modules = set() # 收集新发现的模块
|
||||
with open(self.current_log_file, "r", encoding="utf-8") as f:
|
||||
with open(self.current_log_file, encoding="utf-8") as f:
|
||||
f.seek(from_position)
|
||||
line_count = self.log_index.total_entries
|
||||
for line in f:
|
||||
|
||||
@@ -1,36 +1,37 @@
|
||||
import asyncio
|
||||
import datetime
|
||||
import os
|
||||
import shutil
|
||||
import sys
|
||||
import orjson
|
||||
import datetime
|
||||
from pathlib import Path
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from pathlib import Path
|
||||
from threading import Lock
|
||||
from typing import Optional
|
||||
|
||||
import orjson
|
||||
from json_repair import repair_json
|
||||
|
||||
# 将项目根目录添加到 sys.path
|
||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.knowledge.utils.hash import get_sha256
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import model_config
|
||||
from src.chat.knowledge.open_ie import OpenIE
|
||||
from src.chat.knowledge.embedding_store import EmbeddingManager
|
||||
from src.chat.knowledge.kg_manager import KGManager
|
||||
from rich.progress import (
|
||||
Progress,
|
||||
BarColumn,
|
||||
MofNCompleteColumn,
|
||||
Progress,
|
||||
SpinnerColumn,
|
||||
TaskProgressColumn,
|
||||
TextColumn,
|
||||
TimeElapsedColumn,
|
||||
TimeRemainingColumn,
|
||||
TaskProgressColumn,
|
||||
MofNCompleteColumn,
|
||||
SpinnerColumn,
|
||||
TextColumn,
|
||||
)
|
||||
|
||||
from src.chat.knowledge.embedding_store import EmbeddingManager
|
||||
from src.chat.knowledge.kg_manager import KGManager
|
||||
from src.chat.knowledge.open_ie import OpenIE
|
||||
from src.chat.knowledge.utils.hash import get_sha256
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import model_config
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
|
||||
logger = get_logger("LPMM_LearningTool")
|
||||
ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
||||
RAW_DATA_PATH = os.path.join(ROOT_PATH, "data", "lpmm_raw_data")
|
||||
@@ -59,7 +60,7 @@ def clear_cache():
|
||||
|
||||
|
||||
def process_text_file(file_path):
|
||||
with open(file_path, "r", encoding="utf-8") as f:
|
||||
with open(file_path, encoding="utf-8") as f:
|
||||
raw = f.read()
|
||||
return [p.strip() for p in raw.split("\n\n") if p.strip()]
|
||||
|
||||
@@ -86,7 +87,7 @@ def preprocess_raw_data():
|
||||
# --- 模块二:信息提取 ---
|
||||
|
||||
|
||||
def _parse_and_repair_json(json_string: str) -> Optional[dict]:
|
||||
def _parse_and_repair_json(json_string: str) -> dict | None:
|
||||
"""
|
||||
尝试解析JSON字符串,如果失败则尝试修复并重新解析。
|
||||
|
||||
@@ -249,7 +250,7 @@ def extract_information(paragraphs_dict, model_set):
|
||||
# --- 模块三:数据导入 ---
|
||||
|
||||
|
||||
async def import_data(openie_obj: Optional[OpenIE] = None):
|
||||
async def import_data(openie_obj: OpenIE | None = None):
|
||||
"""
|
||||
将OpenIE数据导入知识库(Embedding Store 和 KG)
|
||||
|
||||
|
||||
@@ -4,11 +4,13 @@
|
||||
提供插件manifest文件的创建、验证和管理功能
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import sys
|
||||
import argparse
|
||||
import orjson
|
||||
from pathlib import Path
|
||||
|
||||
import orjson
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.plugin_system.utils.manifest_utils import (
|
||||
ManifestValidator,
|
||||
@@ -124,7 +126,7 @@ def validate_manifest_file(plugin_dir: str) -> bool:
|
||||
return False
|
||||
|
||||
try:
|
||||
with open(manifest_path, "r", encoding="utf-8") as f:
|
||||
with open(manifest_path, encoding="utf-8") as f:
|
||||
manifest_data = orjson.loads(f.read())
|
||||
|
||||
validator = ManifestValidator()
|
||||
|
||||
@@ -1,46 +1,48 @@
|
||||
import os
|
||||
import orjson
|
||||
import sys # 新增系统模块导入
|
||||
|
||||
# import time
|
||||
import pickle
|
||||
import sys # 新增系统模块导入
|
||||
from pathlib import Path
|
||||
|
||||
import orjson
|
||||
|
||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
|
||||
from typing import Dict, Any, List, Optional, Type
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from peewee import Field, IntegrityError, Model
|
||||
from pymongo import MongoClient
|
||||
from pymongo.errors import ConnectionFailure
|
||||
from peewee import Model, Field, IntegrityError
|
||||
|
||||
# Rich 进度条和显示组件
|
||||
from rich.console import Console
|
||||
from rich.panel import Panel
|
||||
from rich.progress import (
|
||||
Progress,
|
||||
TextColumn,
|
||||
BarColumn,
|
||||
TaskProgressColumn,
|
||||
TimeRemainingColumn,
|
||||
TimeElapsedColumn,
|
||||
Progress,
|
||||
SpinnerColumn,
|
||||
TaskProgressColumn,
|
||||
TextColumn,
|
||||
TimeElapsedColumn,
|
||||
TimeRemainingColumn,
|
||||
)
|
||||
from rich.table import Table
|
||||
from rich.panel import Panel
|
||||
# from rich.text import Text
|
||||
|
||||
# from rich.text import Text
|
||||
from src.common.database.database import db
|
||||
from src.common.database.sqlalchemy_models import (
|
||||
ChatStreams,
|
||||
Emoji,
|
||||
Messages,
|
||||
Images,
|
||||
ImageDescriptions,
|
||||
PersonInfo,
|
||||
Knowledges,
|
||||
ThinkingLog,
|
||||
GraphNodes,
|
||||
GraphEdges,
|
||||
GraphNodes,
|
||||
ImageDescriptions,
|
||||
Images,
|
||||
Knowledges,
|
||||
Messages,
|
||||
PersonInfo,
|
||||
ThinkingLog,
|
||||
)
|
||||
from src.common.logger import get_logger
|
||||
|
||||
@@ -54,12 +56,12 @@ class MigrationConfig:
|
||||
"""迁移配置类"""
|
||||
|
||||
mongo_collection: str
|
||||
target_model: Type[Model]
|
||||
field_mapping: Dict[str, str]
|
||||
target_model: type[Model]
|
||||
field_mapping: dict[str, str]
|
||||
batch_size: int = 500
|
||||
enable_validation: bool = True
|
||||
skip_duplicates: bool = True
|
||||
unique_fields: List[str] = field(default_factory=list) # 用于重复检查的字段
|
||||
unique_fields: list[str] = field(default_factory=list) # 用于重复检查的字段
|
||||
|
||||
|
||||
# 数据验证相关类已移除 - 用户要求不要数据验证
|
||||
@@ -73,7 +75,7 @@ class MigrationCheckpoint:
|
||||
processed_count: int
|
||||
last_processed_id: Any
|
||||
timestamp: datetime
|
||||
batch_errors: List[Dict[str, Any]] = field(default_factory=list)
|
||||
batch_errors: list[dict[str, Any]] = field(default_factory=list)
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -88,11 +90,11 @@ class MigrationStats:
|
||||
duplicate_count: int = 0
|
||||
validation_errors: int = 0
|
||||
batch_insert_count: int = 0
|
||||
errors: List[Dict[str, Any]] = field(default_factory=list)
|
||||
start_time: Optional[datetime] = None
|
||||
end_time: Optional[datetime] = None
|
||||
errors: list[dict[str, Any]] = field(default_factory=list)
|
||||
start_time: datetime | None = None
|
||||
end_time: datetime | None = None
|
||||
|
||||
def add_error(self, doc_id: Any, error: str, doc_data: Optional[Dict] = None):
|
||||
def add_error(self, doc_id: Any, error: str, doc_data: dict | None = None):
|
||||
"""添加错误记录"""
|
||||
self.errors.append(
|
||||
{"doc_id": str(doc_id), "error": error, "timestamp": datetime.now().isoformat(), "doc_data": doc_data}
|
||||
@@ -108,10 +110,10 @@ class MigrationStats:
|
||||
class MongoToSQLiteMigrator:
|
||||
"""MongoDB到SQLite数据迁移器 - 使用Peewee ORM"""
|
||||
|
||||
def __init__(self, mongo_uri: Optional[str] = None, database_name: Optional[str] = None):
|
||||
def __init__(self, mongo_uri: str | None = None, database_name: str | None = None):
|
||||
self.database_name = database_name or os.getenv("DATABASE_NAME", "MegBot")
|
||||
self.mongo_uri = mongo_uri or self._build_mongo_uri()
|
||||
self.mongo_client: Optional[MongoClient] = None
|
||||
self.mongo_client: MongoClient | None = None
|
||||
self.mongo_db = None
|
||||
|
||||
# 迁移配置
|
||||
@@ -142,7 +144,7 @@ class MongoToSQLiteMigrator:
|
||||
else:
|
||||
return f"mongodb://{host}:{port}/{self.database_name}"
|
||||
|
||||
def _initialize_migration_configs(self) -> List[MigrationConfig]:
|
||||
def _initialize_migration_configs(self) -> list[MigrationConfig]:
|
||||
"""初始化迁移配置"""
|
||||
return [ # 表情包迁移配置
|
||||
MigrationConfig(
|
||||
@@ -306,7 +308,7 @@ class MongoToSQLiteMigrator:
|
||||
),
|
||||
]
|
||||
|
||||
def _initialize_validation_rules(self) -> Dict[str, Any]:
|
||||
def _initialize_validation_rules(self) -> dict[str, Any]:
|
||||
"""数据验证已禁用 - 返回空字典"""
|
||||
return {}
|
||||
|
||||
@@ -337,7 +339,7 @@ class MongoToSQLiteMigrator:
|
||||
self.mongo_client.close()
|
||||
logger.info("MongoDB连接已关闭")
|
||||
|
||||
def _get_nested_value(self, document: Dict[str, Any], field_path: str) -> Any:
|
||||
def _get_nested_value(self, document: dict[str, Any], field_path: str) -> Any:
|
||||
"""获取嵌套字段的值"""
|
||||
if "." not in field_path:
|
||||
return document.get(field_path)
|
||||
@@ -434,7 +436,7 @@ class MongoToSQLiteMigrator:
|
||||
|
||||
return None
|
||||
|
||||
def _validate_data(self, collection_name: str, data: Dict[str, Any], doc_id: Any, stats: MigrationStats) -> bool:
|
||||
def _validate_data(self, collection_name: str, data: dict[str, Any], doc_id: Any, stats: MigrationStats) -> bool:
|
||||
"""数据验证已禁用 - 始终返回True"""
|
||||
return True
|
||||
|
||||
@@ -454,7 +456,7 @@ class MongoToSQLiteMigrator:
|
||||
except Exception as e:
|
||||
logger.warning(f"保存断点失败: {e}")
|
||||
|
||||
def _load_checkpoint(self, collection_name: str) -> Optional[MigrationCheckpoint]:
|
||||
def _load_checkpoint(self, collection_name: str) -> MigrationCheckpoint | None:
|
||||
"""加载迁移断点"""
|
||||
checkpoint_file = self.checkpoint_dir / f"{collection_name}_checkpoint.pkl"
|
||||
if not checkpoint_file.exists():
|
||||
@@ -467,7 +469,7 @@ class MongoToSQLiteMigrator:
|
||||
logger.warning(f"加载断点失败: {e}")
|
||||
return None
|
||||
|
||||
def _batch_insert(self, model: Type[Model], data_list: List[Dict[str, Any]]) -> int:
|
||||
def _batch_insert(self, model: type[Model], data_list: list[dict[str, Any]]) -> int:
|
||||
"""批量插入数据"""
|
||||
if not data_list:
|
||||
return 0
|
||||
@@ -494,7 +496,7 @@ class MongoToSQLiteMigrator:
|
||||
return success_count
|
||||
|
||||
def _check_duplicate_by_unique_fields(
|
||||
self, model: Type[Model], data: Dict[str, Any], unique_fields: List[str]
|
||||
self, model: type[Model], data: dict[str, Any], unique_fields: list[str]
|
||||
) -> bool:
|
||||
"""根据唯一字段检查重复"""
|
||||
if not unique_fields:
|
||||
@@ -512,7 +514,7 @@ class MongoToSQLiteMigrator:
|
||||
logger.debug(f"重复检查失败: {e}")
|
||||
return False
|
||||
|
||||
def _create_model_instance(self, model: Type[Model], data: Dict[str, Any]) -> Optional[Model]:
|
||||
def _create_model_instance(self, model: type[Model], data: dict[str, Any]) -> Model | None:
|
||||
"""使用ORM创建模型实例"""
|
||||
try:
|
||||
# 过滤掉不存在的字段
|
||||
@@ -669,7 +671,7 @@ class MongoToSQLiteMigrator:
|
||||
|
||||
return stats
|
||||
|
||||
def migrate_all(self) -> Dict[str, MigrationStats]:
|
||||
def migrate_all(self) -> dict[str, MigrationStats]:
|
||||
"""执行所有迁移任务"""
|
||||
logger.info("开始执行数据库迁移...")
|
||||
|
||||
@@ -730,7 +732,7 @@ class MongoToSQLiteMigrator:
|
||||
self._print_migration_summary(all_stats)
|
||||
return all_stats
|
||||
|
||||
def _print_migration_summary(self, all_stats: Dict[str, MigrationStats]):
|
||||
def _print_migration_summary(self, all_stats: dict[str, MigrationStats]):
|
||||
"""使用Rich打印美观的迁移汇总信息"""
|
||||
# 计算总体统计
|
||||
total_processed = sum(stats.processed_count for stats in all_stats.values())
|
||||
@@ -857,7 +859,7 @@ class MongoToSQLiteMigrator:
|
||||
"""添加新的迁移配置"""
|
||||
self.migration_configs.append(config)
|
||||
|
||||
def migrate_single_collection(self, collection_name: str) -> Optional[MigrationStats]:
|
||||
def migrate_single_collection(self, collection_name: str) -> MigrationStats | None:
|
||||
"""迁移单个指定的集合"""
|
||||
config = next((c for c in self.migration_configs if c.mongo_collection == collection_name), None)
|
||||
if not config:
|
||||
@@ -875,7 +877,7 @@ class MongoToSQLiteMigrator:
|
||||
finally:
|
||||
self.disconnect_mongodb()
|
||||
|
||||
def export_error_report(self, all_stats: Dict[str, MigrationStats], filepath: str):
|
||||
def export_error_report(self, all_stats: dict[str, MigrationStats], filepath: str):
|
||||
"""导出错误报告"""
|
||||
error_report = {
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
|
||||
@@ -1,17 +1,16 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
从现有ChromaDB数据重建JSON元数据索引
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import sys
|
||||
import os
|
||||
import sys
|
||||
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
from src.chat.memory_system.memory_system import MemorySystem
|
||||
from src.chat.memory_system.memory_metadata_index import MemoryMetadataIndexEntry
|
||||
from src.chat.memory_system.memory_system import MemorySystem
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -1,12 +1,11 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
轻量烟雾测试:初始化 MemorySystem 并运行一次检索,验证 MemoryMetadata.source 访问不再报错
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import sys
|
||||
import os
|
||||
import sys
|
||||
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
import time
|
||||
import sys
|
||||
import os
|
||||
import re
|
||||
from typing import Dict, List, Tuple, Optional
|
||||
import sys
|
||||
import time
|
||||
from datetime import datetime
|
||||
|
||||
# Add project root to Python path
|
||||
@@ -63,7 +62,7 @@ def format_timestamp(timestamp: float) -> str:
|
||||
return "未知时间"
|
||||
|
||||
|
||||
def calculate_text_length_distribution(messages) -> Dict[str, int]:
|
||||
def calculate_text_length_distribution(messages) -> dict[str, int]:
|
||||
"""Calculate distribution of processed_plain_text length"""
|
||||
distribution = {
|
||||
"0": 0, # 空文本
|
||||
@@ -126,7 +125,7 @@ def calculate_text_length_distribution(messages) -> Dict[str, int]:
|
||||
return distribution
|
||||
|
||||
|
||||
def get_text_length_stats(messages) -> Dict[str, float]:
|
||||
def get_text_length_stats(messages) -> dict[str, float]:
|
||||
"""Calculate basic statistics for processed_plain_text length"""
|
||||
lengths = []
|
||||
null_count = 0
|
||||
@@ -168,7 +167,7 @@ def get_text_length_stats(messages) -> Dict[str, float]:
|
||||
}
|
||||
|
||||
|
||||
def get_available_chats() -> List[Tuple[str, str, int]]:
|
||||
def get_available_chats() -> list[tuple[str, str, int]]:
|
||||
"""Get all available chats with message counts"""
|
||||
try:
|
||||
# 获取所有有消息的chat_id,排除特殊类型消息
|
||||
@@ -202,7 +201,7 @@ def get_available_chats() -> List[Tuple[str, str, int]]:
|
||||
return []
|
||||
|
||||
|
||||
def get_time_range_input() -> Tuple[Optional[float], Optional[float]]:
|
||||
def get_time_range_input() -> tuple[float | None, float | None]:
|
||||
"""Get time range input from user"""
|
||||
print("\n时间范围选择:")
|
||||
print("1. 最近1天")
|
||||
@@ -241,7 +240,7 @@ def get_time_range_input() -> Tuple[Optional[float], Optional[float]]:
|
||||
return None, None
|
||||
|
||||
|
||||
def get_top_longest_messages(messages, top_n: int = 10) -> List[Tuple[str, int, str, str]]:
|
||||
def get_top_longest_messages(messages, top_n: int = 10) -> list[tuple[str, int, str, str]]:
|
||||
"""Get top N longest messages"""
|
||||
message_lengths = []
|
||||
|
||||
@@ -266,7 +265,7 @@ def get_top_longest_messages(messages, top_n: int = 10) -> List[Tuple[str, int,
|
||||
|
||||
|
||||
def analyze_text_lengths(
|
||||
chat_id: Optional[str] = None, start_time: Optional[float] = None, end_time: Optional[float] = None
|
||||
chat_id: str | None = None, start_time: float | None = None, end_time: float | None = None
|
||||
) -> None:
|
||||
"""Analyze processed_plain_text lengths with optional filters"""
|
||||
|
||||
|
||||
@@ -30,7 +30,7 @@ def update_prompt_imports(file_path):
|
||||
print(f"文件不存在: {file_path}")
|
||||
return False
|
||||
|
||||
with open(file_path, "r", encoding="utf-8") as f:
|
||||
with open(file_path, encoding="utf-8") as f:
|
||||
content = f.read()
|
||||
|
||||
# 替换导入语句
|
||||
|
||||
@@ -1,13 +1,15 @@
|
||||
import random
|
||||
from typing import List, Optional, Sequence
|
||||
from colorama import init, Fore
|
||||
from collections.abc import Sequence
|
||||
from typing import List, Optional
|
||||
|
||||
from colorama import Fore, init
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
egg = get_logger("小彩蛋")
|
||||
|
||||
|
||||
def weighted_choice(data: Sequence[str], weights: Optional[List[float]] = None) -> str:
|
||||
def weighted_choice(data: Sequence[str], weights: list[float] | None = None) -> str:
|
||||
"""
|
||||
从 data 中按权重随机返回一条。
|
||||
若 weights 为 None,则所有元素权重默认为 1。
|
||||
|
||||
@@ -3,8 +3,8 @@ MaiBot模块系统
|
||||
包含聊天、情绪、记忆、日程等功能模块
|
||||
"""
|
||||
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from src.chat.emoji_system.emoji_manager import get_emoji_manager
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
|
||||
# 导出主要组件供外部使用
|
||||
__all__ = [
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
MaiBot 反注入系统模块
|
||||
|
||||
@@ -14,25 +13,25 @@ MaiBot 反注入系统模块
|
||||
"""
|
||||
|
||||
from .anti_injector import AntiPromptInjector, get_anti_injector, initialize_anti_injector
|
||||
from .types import DetectionResult, ProcessResult
|
||||
from .core import PromptInjectionDetector, MessageShield
|
||||
from .processors.message_processor import MessageProcessor
|
||||
from .management import AntiInjectionStatistics, UserBanManager
|
||||
from .core import MessageShield, PromptInjectionDetector
|
||||
from .decision import CounterAttackGenerator, ProcessingDecisionMaker
|
||||
from .management import AntiInjectionStatistics, UserBanManager
|
||||
from .processors.message_processor import MessageProcessor
|
||||
from .types import DetectionResult, ProcessResult
|
||||
|
||||
__all__ = [
|
||||
"AntiInjectionStatistics",
|
||||
"AntiPromptInjector",
|
||||
"CounterAttackGenerator",
|
||||
"DetectionResult",
|
||||
"MessageProcessor",
|
||||
"MessageShield",
|
||||
"ProcessResult",
|
||||
"ProcessingDecisionMaker",
|
||||
"PromptInjectionDetector",
|
||||
"UserBanManager",
|
||||
"get_anti_injector",
|
||||
"initialize_anti_injector",
|
||||
"DetectionResult",
|
||||
"ProcessResult",
|
||||
"PromptInjectionDetector",
|
||||
"MessageShield",
|
||||
"MessageProcessor",
|
||||
"AntiInjectionStatistics",
|
||||
"UserBanManager",
|
||||
"CounterAttackGenerator",
|
||||
"ProcessingDecisionMaker",
|
||||
]
|
||||
|
||||
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
LLM反注入系统主模块
|
||||
|
||||
@@ -12,15 +11,16 @@ LLM反注入系统主模块
|
||||
"""
|
||||
|
||||
import time
|
||||
from typing import Optional, Tuple, Dict, Any
|
||||
from typing import Any
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from .types import ProcessResult
|
||||
from .core import PromptInjectionDetector, MessageShield
|
||||
from .processors.message_processor import MessageProcessor
|
||||
from .management import AntiInjectionStatistics, UserBanManager
|
||||
|
||||
from .core import MessageShield, PromptInjectionDetector
|
||||
from .decision import CounterAttackGenerator, ProcessingDecisionMaker
|
||||
from .management import AntiInjectionStatistics, UserBanManager
|
||||
from .processors.message_processor import MessageProcessor
|
||||
from .types import ProcessResult
|
||||
|
||||
logger = get_logger("anti_injector")
|
||||
|
||||
@@ -43,7 +43,7 @@ class AntiPromptInjector:
|
||||
|
||||
async def process_message(
|
||||
self, message_data: dict, chat_stream=None
|
||||
) -> Tuple[ProcessResult, Optional[str], Optional[str]]:
|
||||
) -> tuple[ProcessResult, str | None, str | None]:
|
||||
"""处理字典格式的消息并返回结果
|
||||
|
||||
Args:
|
||||
@@ -102,7 +102,7 @@ class AntiPromptInjector:
|
||||
await self.statistics.update_stats(error_count=1)
|
||||
|
||||
# 异常情况下直接阻止消息
|
||||
return ProcessResult.BLOCKED_INJECTION, None, f"反注入系统异常,消息已阻止: {str(e)}"
|
||||
return ProcessResult.BLOCKED_INJECTION, None, f"反注入系统异常,消息已阻止: {e!s}"
|
||||
|
||||
finally:
|
||||
# 更新处理时间统计
|
||||
@@ -111,7 +111,7 @@ class AntiPromptInjector:
|
||||
|
||||
async def _process_message_internal(
|
||||
self, text_to_detect: str, user_id: str, platform: str, processed_plain_text: str, start_time: float
|
||||
) -> Tuple[ProcessResult, Optional[str], Optional[str]]:
|
||||
) -> tuple[ProcessResult, str | None, str | None]:
|
||||
"""内部消息处理逻辑(共用的检测核心)"""
|
||||
|
||||
# 如果是纯引用消息,直接允许通过
|
||||
@@ -218,7 +218,7 @@ class AntiPromptInjector:
|
||||
return ProcessResult.ALLOWED, None, "消息检查通过"
|
||||
|
||||
async def handle_message_storage(
|
||||
self, result: ProcessResult, modified_content: Optional[str], reason: str, message_data: dict
|
||||
self, result: ProcessResult, modified_content: str | None, reason: str, message_data: dict
|
||||
) -> None:
|
||||
"""处理违禁消息的数据库存储,根据处理模式决定如何处理"""
|
||||
if result == ProcessResult.BLOCKED_INJECTION or result == ProcessResult.COUNTER_ATTACK:
|
||||
@@ -253,9 +253,10 @@ class AntiPromptInjector:
|
||||
async def _delete_message_from_storage(message_data: dict) -> None:
|
||||
"""从数据库中删除违禁消息记录"""
|
||||
try:
|
||||
from src.common.database.sqlalchemy_models import Messages, get_db_session
|
||||
from sqlalchemy import delete
|
||||
|
||||
from src.common.database.sqlalchemy_models import Messages, get_db_session
|
||||
|
||||
message_id = message_data.get("message_id")
|
||||
if not message_id:
|
||||
logger.warning("无法删除消息:缺少message_id")
|
||||
@@ -279,9 +280,10 @@ class AntiPromptInjector:
|
||||
async def _update_message_in_storage(message_data: dict, new_content: str) -> None:
|
||||
"""更新数据库中的消息内容为加盾版本"""
|
||||
try:
|
||||
from src.common.database.sqlalchemy_models import Messages, get_db_session
|
||||
from sqlalchemy import update
|
||||
|
||||
from src.common.database.sqlalchemy_models import Messages, get_db_session
|
||||
|
||||
message_id = message_data.get("message_id")
|
||||
if not message_id:
|
||||
logger.warning("无法更新消息:缺少message_id")
|
||||
@@ -305,7 +307,7 @@ class AntiPromptInjector:
|
||||
except Exception as e:
|
||||
logger.error(f"更新消息内容失败: {e}")
|
||||
|
||||
async def get_stats(self) -> Dict[str, Any]:
|
||||
async def get_stats(self) -> dict[str, Any]:
|
||||
"""获取统计信息"""
|
||||
return await self.statistics.get_stats()
|
||||
|
||||
@@ -315,7 +317,7 @@ class AntiPromptInjector:
|
||||
|
||||
|
||||
# 全局反注入器实例
|
||||
_global_injector: Optional[AntiPromptInjector] = None
|
||||
_global_injector: AntiPromptInjector | None = None
|
||||
|
||||
|
||||
def get_anti_injector() -> AntiPromptInjector:
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
反注入系统核心检测模块
|
||||
|
||||
@@ -10,4 +9,4 @@
|
||||
from .detector import PromptInjectionDetector
|
||||
from .shield import MessageShield
|
||||
|
||||
__all__ = ["PromptInjectionDetector", "MessageShield"]
|
||||
__all__ = ["MessageShield", "PromptInjectionDetector"]
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
提示词注入检测器模块
|
||||
|
||||
@@ -8,19 +7,19 @@
|
||||
3. 缓存机制优化性能
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import re
|
||||
import time
|
||||
import hashlib
|
||||
from typing import Dict, List
|
||||
from dataclasses import asdict
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from ..types import DetectionResult
|
||||
|
||||
# 导入LLM API
|
||||
from src.plugin_system.apis import llm_api
|
||||
|
||||
from ..types import DetectionResult
|
||||
|
||||
logger = get_logger("anti_injector.detector")
|
||||
|
||||
|
||||
@@ -30,8 +29,8 @@ class PromptInjectionDetector:
|
||||
def __init__(self):
|
||||
"""初始化检测器"""
|
||||
self.config = global_config.anti_prompt_injection
|
||||
self._cache: Dict[str, DetectionResult] = {}
|
||||
self._compiled_patterns: List[re.Pattern] = []
|
||||
self._cache: dict[str, DetectionResult] = {}
|
||||
self._compiled_patterns: list[re.Pattern] = []
|
||||
self._compile_patterns()
|
||||
|
||||
def _compile_patterns(self):
|
||||
@@ -224,7 +223,7 @@ class PromptInjectionDetector:
|
||||
matched_patterns=[],
|
||||
processing_time=processing_time,
|
||||
detection_method="llm",
|
||||
reason=f"LLM检测出错: {str(e)}",
|
||||
reason=f"LLM检测出错: {e!s}",
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@@ -250,7 +249,7 @@ class PromptInjectionDetector:
|
||||
请客观分析,避免误判正常对话。"""
|
||||
|
||||
@staticmethod
|
||||
def _parse_llm_response(response: str) -> Dict:
|
||||
def _parse_llm_response(response: str) -> dict:
|
||||
"""解析LLM响应"""
|
||||
try:
|
||||
lines = response.strip().split("\n")
|
||||
@@ -280,7 +279,7 @@ class PromptInjectionDetector:
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"解析LLM响应失败: {e}")
|
||||
return {"is_injection": False, "confidence": 0.0, "reasoning": f"解析失败: {str(e)}"}
|
||||
return {"is_injection": False, "confidence": 0.0, "reasoning": f"解析失败: {e!s}"}
|
||||
|
||||
async def detect(self, message: str) -> DetectionResult:
|
||||
"""执行检测"""
|
||||
@@ -331,7 +330,7 @@ class PromptInjectionDetector:
|
||||
|
||||
return final_result
|
||||
|
||||
def _merge_results(self, results: List[DetectionResult]) -> DetectionResult:
|
||||
def _merge_results(self, results: list[DetectionResult]) -> DetectionResult:
|
||||
"""合并多个检测结果"""
|
||||
if not results:
|
||||
return DetectionResult(reason="无检测结果")
|
||||
@@ -384,7 +383,7 @@ class PromptInjectionDetector:
|
||||
if expired_keys:
|
||||
logger.debug(f"清理了{len(expired_keys)}个过期缓存项")
|
||||
|
||||
def get_cache_stats(self) -> Dict:
|
||||
def get_cache_stats(self) -> dict:
|
||||
"""获取缓存统计信息"""
|
||||
return {
|
||||
"cache_size": len(self._cache),
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
消息加盾模块
|
||||
|
||||
@@ -6,8 +5,6 @@
|
||||
主要通过注入系统提示词来指导AI安全响应。
|
||||
"""
|
||||
|
||||
from typing import List
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
|
||||
@@ -35,7 +32,7 @@ class MessageShield:
|
||||
return SAFETY_SYSTEM_PROMPT
|
||||
|
||||
@staticmethod
|
||||
def is_shield_needed(confidence: float, matched_patterns: List[str]) -> bool:
|
||||
def is_shield_needed(confidence: float, matched_patterns: list[str]) -> bool:
|
||||
"""判断是否需要加盾
|
||||
|
||||
Args:
|
||||
@@ -60,7 +57,7 @@ class MessageShield:
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def create_safety_summary(confidence: float, matched_patterns: List[str]) -> str:
|
||||
def create_safety_summary(confidence: float, matched_patterns: list[str]) -> str:
|
||||
"""创建安全处理摘要
|
||||
|
||||
Args:
|
||||
|
||||
@@ -1,15 +1,13 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
反击消息生成模块
|
||||
|
||||
负责生成个性化的反击消息回应提示词注入攻击
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.plugin_system.apis import llm_api
|
||||
|
||||
from .types import DetectionResult
|
||||
|
||||
logger = get_logger("anti_injector.counter_attack")
|
||||
@@ -55,7 +53,7 @@ class CounterAttackGenerator:
|
||||
|
||||
async def generate_counter_attack_message(
|
||||
self, original_message: str, detection_result: DetectionResult
|
||||
) -> Optional[str]:
|
||||
) -> str | None:
|
||||
"""生成反击消息
|
||||
|
||||
Args:
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
反注入系统决策模块
|
||||
|
||||
@@ -7,7 +6,7 @@
|
||||
- counter_attack: 反击消息生成器
|
||||
"""
|
||||
|
||||
from .decision_maker import ProcessingDecisionMaker
|
||||
from .counter_attack import CounterAttackGenerator
|
||||
from .decision_maker import ProcessingDecisionMaker
|
||||
|
||||
__all__ = ["ProcessingDecisionMaker", "CounterAttackGenerator"]
|
||||
__all__ = ["CounterAttackGenerator", "ProcessingDecisionMaker"]
|
||||
|
||||
@@ -1,15 +1,13 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
反击消息生成模块
|
||||
|
||||
负责生成个性化的反击消息回应提示词注入攻击
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.plugin_system.apis import llm_api
|
||||
|
||||
from ..types import DetectionResult
|
||||
|
||||
logger = get_logger("anti_injector.counter_attack")
|
||||
@@ -55,7 +53,7 @@ class CounterAttackGenerator:
|
||||
|
||||
async def generate_counter_attack_message(
|
||||
self, original_message: str, detection_result: DetectionResult
|
||||
) -> Optional[str]:
|
||||
) -> str | None:
|
||||
"""生成反击消息
|
||||
|
||||
Args:
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
处理决策器模块
|
||||
|
||||
@@ -6,6 +5,7 @@
|
||||
"""
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
from ..types import DetectionResult
|
||||
|
||||
logger = get_logger("anti_injector.decision_maker")
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
处理决策器模块
|
||||
|
||||
@@ -6,6 +5,7 @@
|
||||
"""
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
from .types import DetectionResult
|
||||
|
||||
logger = get_logger("anti_injector.decision_maker")
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
提示词注入检测器模块
|
||||
|
||||
@@ -8,19 +7,19 @@
|
||||
3. 缓存机制优化性能
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import re
|
||||
import time
|
||||
import hashlib
|
||||
from typing import Dict, List
|
||||
from dataclasses import asdict
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from .types import DetectionResult
|
||||
|
||||
# 导入LLM API
|
||||
from src.plugin_system.apis import llm_api
|
||||
|
||||
from .types import DetectionResult
|
||||
|
||||
logger = get_logger("anti_injector.detector")
|
||||
|
||||
|
||||
@@ -30,8 +29,8 @@ class PromptInjectionDetector:
|
||||
def __init__(self):
|
||||
"""初始化检测器"""
|
||||
self.config = global_config.anti_prompt_injection
|
||||
self._cache: Dict[str, DetectionResult] = {}
|
||||
self._compiled_patterns: List[re.Pattern] = []
|
||||
self._cache: dict[str, DetectionResult] = {}
|
||||
self._compiled_patterns: list[re.Pattern] = []
|
||||
self._compile_patterns()
|
||||
|
||||
def _compile_patterns(self):
|
||||
@@ -221,7 +220,7 @@ class PromptInjectionDetector:
|
||||
matched_patterns=[],
|
||||
processing_time=processing_time,
|
||||
detection_method="llm",
|
||||
reason=f"LLM检测出错: {str(e)}",
|
||||
reason=f"LLM检测出错: {e!s}",
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@@ -247,7 +246,7 @@ class PromptInjectionDetector:
|
||||
请客观分析,避免误判正常对话。"""
|
||||
|
||||
@staticmethod
|
||||
def _parse_llm_response(response: str) -> Dict:
|
||||
def _parse_llm_response(response: str) -> dict:
|
||||
"""解析LLM响应"""
|
||||
try:
|
||||
lines = response.strip().split("\n")
|
||||
@@ -277,7 +276,7 @@ class PromptInjectionDetector:
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"解析LLM响应失败: {e}")
|
||||
return {"is_injection": False, "confidence": 0.0, "reasoning": f"解析失败: {str(e)}"}
|
||||
return {"is_injection": False, "confidence": 0.0, "reasoning": f"解析失败: {e!s}"}
|
||||
|
||||
async def detect(self, message: str) -> DetectionResult:
|
||||
"""执行检测"""
|
||||
@@ -328,7 +327,7 @@ class PromptInjectionDetector:
|
||||
|
||||
return final_result
|
||||
|
||||
def _merge_results(self, results: List[DetectionResult]) -> DetectionResult:
|
||||
def _merge_results(self, results: list[DetectionResult]) -> DetectionResult:
|
||||
"""合并多个检测结果"""
|
||||
if not results:
|
||||
return DetectionResult(reason="无检测结果")
|
||||
@@ -381,7 +380,7 @@ class PromptInjectionDetector:
|
||||
if expired_keys:
|
||||
logger.debug(f"清理了{len(expired_keys)}个过期缓存项")
|
||||
|
||||
def get_cache_stats(self) -> Dict:
|
||||
def get_cache_stats(self) -> dict:
|
||||
"""获取缓存统计信息"""
|
||||
return {
|
||||
"cache_size": len(self._cache),
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
反注入系统管理模块
|
||||
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
反注入系统统计模块
|
||||
|
||||
@@ -6,12 +5,12 @@
|
||||
"""
|
||||
|
||||
import datetime
|
||||
from typing import Dict, Any
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.common.database.sqlalchemy_models import AntiInjectionStats, get_db_session
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
|
||||
logger = get_logger("anti_injector.statistics")
|
||||
@@ -94,7 +93,7 @@ class AntiInjectionStatistics:
|
||||
except Exception as e:
|
||||
logger.error(f"更新统计数据失败: {e}")
|
||||
|
||||
async def get_stats(self) -> Dict[str, Any]:
|
||||
async def get_stats(self) -> dict[str, Any]:
|
||||
"""获取统计信息"""
|
||||
try:
|
||||
# 检查反注入系统是否启用
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
用户封禁管理模块
|
||||
|
||||
@@ -6,12 +5,12 @@
|
||||
"""
|
||||
|
||||
import datetime
|
||||
from typing import Optional, Tuple
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.common.database.sqlalchemy_models import BanUser, get_db_session
|
||||
from src.common.logger import get_logger
|
||||
|
||||
from ..types import DetectionResult
|
||||
|
||||
logger = get_logger("anti_injector.user_ban")
|
||||
@@ -28,7 +27,7 @@ class UserBanManager:
|
||||
"""
|
||||
self.config = config
|
||||
|
||||
async def check_user_ban(self, user_id: str, platform: str) -> Optional[Tuple[bool, Optional[str], str]]:
|
||||
async def check_user_ban(self, user_id: str, platform: str) -> tuple[bool, str | None, str] | None:
|
||||
"""检查用户是否被封禁
|
||||
|
||||
Args:
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
反注入系统消息处理模块
|
||||
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
消息内容处理模块
|
||||
|
||||
@@ -6,10 +5,9 @@
|
||||
"""
|
||||
|
||||
import re
|
||||
from typing import Optional
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.message_receive.message import MessageRecv
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("anti_injector.message_processor")
|
||||
|
||||
@@ -66,7 +64,7 @@ class MessageProcessor:
|
||||
return new_content
|
||||
|
||||
@staticmethod
|
||||
def check_whitelist(message: MessageRecv, whitelist: list) -> Optional[tuple]:
|
||||
def check_whitelist(message: MessageRecv, whitelist: list) -> tuple | None:
|
||||
"""检查用户白名单
|
||||
|
||||
Args:
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
反注入系统数据类型定义模块
|
||||
|
||||
@@ -10,7 +9,6 @@
|
||||
"""
|
||||
|
||||
import time
|
||||
from typing import List, Optional
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
|
||||
@@ -31,8 +29,8 @@ class DetectionResult:
|
||||
|
||||
is_injection: bool = False
|
||||
confidence: float = 0.0
|
||||
matched_patterns: List[str] = field(default_factory=list)
|
||||
llm_analysis: Optional[str] = None
|
||||
matched_patterns: list[str] = field(default_factory=list)
|
||||
llm_analysis: str | None = None
|
||||
processing_time: float = 0.0
|
||||
detection_method: str = "unknown"
|
||||
reason: str = ""
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
from typing import Dict, List, Optional, Any
|
||||
import time
|
||||
from src.plugin_system.base.base_chatter import BaseChatter
|
||||
from src.common.data_models.message_manager_data_model import StreamContext
|
||||
from typing import Any
|
||||
|
||||
from src.chat.planner_actions.action_manager import ChatterActionManager
|
||||
from src.plugin_system.base.component_types import ChatType
|
||||
from src.common.data_models.message_manager_data_model import StreamContext
|
||||
from src.common.logger import get_logger
|
||||
from src.plugin_system.base.base_chatter import BaseChatter
|
||||
from src.plugin_system.base.component_types import ChatType
|
||||
|
||||
logger = get_logger("chatter_manager")
|
||||
|
||||
@@ -12,8 +13,8 @@ logger = get_logger("chatter_manager")
|
||||
class ChatterManager:
|
||||
def __init__(self, action_manager: ChatterActionManager):
|
||||
self.action_manager = action_manager
|
||||
self.chatter_classes: Dict[ChatType, List[type]] = {}
|
||||
self.instances: Dict[str, BaseChatter] = {}
|
||||
self.chatter_classes: dict[ChatType, list[type]] = {}
|
||||
self.instances: dict[str, BaseChatter] = {}
|
||||
|
||||
# 管理器统计
|
||||
self.stats = {
|
||||
@@ -46,21 +47,21 @@ class ChatterManager:
|
||||
|
||||
self.stats["chatters_registered"] += 1
|
||||
|
||||
def get_chatter_class(self, chat_type: ChatType) -> Optional[type]:
|
||||
def get_chatter_class(self, chat_type: ChatType) -> type | None:
|
||||
"""获取指定聊天类型的聊天处理器类"""
|
||||
if chat_type in self.chatter_classes:
|
||||
return self.chatter_classes[chat_type][0]
|
||||
return None
|
||||
|
||||
def get_supported_chat_types(self) -> List[ChatType]:
|
||||
def get_supported_chat_types(self) -> list[ChatType]:
|
||||
"""获取支持的聊天类型列表"""
|
||||
return list(self.chatter_classes.keys())
|
||||
|
||||
def get_registered_chatters(self) -> Dict[ChatType, List[type]]:
|
||||
def get_registered_chatters(self) -> dict[ChatType, list[type]]:
|
||||
"""获取已注册的聊天处理器"""
|
||||
return self.chatter_classes.copy()
|
||||
|
||||
def get_stream_instance(self, stream_id: str) -> Optional[BaseChatter]:
|
||||
def get_stream_instance(self, stream_id: str) -> BaseChatter | None:
|
||||
"""获取指定流的聊天处理器实例"""
|
||||
return self.instances.get(stream_id)
|
||||
|
||||
@@ -139,7 +140,7 @@ class ChatterManager:
|
||||
logger.error(f"处理流 {stream_id} 时发生错误: {e}")
|
||||
raise
|
||||
|
||||
def get_stats(self) -> Dict[str, Any]:
|
||||
def get_stats(self) -> dict[str, Any]:
|
||||
"""获取管理器统计信息"""
|
||||
stats = self.stats.copy()
|
||||
stats["active_instances"] = len(self.instances)
|
||||
|
||||
@@ -1,9 +1,7 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
表情包发送历史记录模块
|
||||
"""
|
||||
|
||||
from typing import List, Dict
|
||||
from collections import deque
|
||||
|
||||
from src.common.logger import get_logger
|
||||
@@ -14,7 +12,7 @@ MAX_HISTORY_SIZE = 5 # 每个聊天会话最多保留最近5条表情历史
|
||||
|
||||
# 使用一个全局字典在内存中存储历史记录
|
||||
# 键是 chat_id,值是一个 deque 对象
|
||||
_history_cache: Dict[str, deque] = {}
|
||||
_history_cache: dict[str, deque] = {}
|
||||
|
||||
|
||||
def add_emoji_to_history(chat_id: str, emoji_description: str):
|
||||
@@ -38,7 +36,7 @@ def add_emoji_to_history(chat_id: str, emoji_description: str):
|
||||
logger.debug(f"已将表情 '{emoji_description}' 添加到聊天 {chat_id} 的内存历史中")
|
||||
|
||||
|
||||
def get_recent_emojis(chat_id: str, limit: int = 5) -> List[str]:
|
||||
def get_recent_emojis(chat_id: str, limit: int = 5) -> list[str]:
|
||||
"""
|
||||
从内存中获取最近发送的表情包描述列表。
|
||||
|
||||
|
||||
@@ -1,23 +1,24 @@
|
||||
import asyncio
|
||||
import base64
|
||||
import binascii
|
||||
import hashlib
|
||||
import io
|
||||
import os
|
||||
import random
|
||||
import re
|
||||
import time
|
||||
import traceback
|
||||
import io
|
||||
import re
|
||||
import binascii
|
||||
from typing import Any, Optional
|
||||
|
||||
from typing import Optional, Tuple, List, Any
|
||||
from PIL import Image
|
||||
from rich.traceback import install
|
||||
from sqlalchemy import select
|
||||
|
||||
from src.chat.utils.utils_image import get_image_manager, image_path_to_base64
|
||||
from src.common.database.sqlalchemy_database_api import get_db_session
|
||||
from src.common.database.sqlalchemy_models import Emoji, Images
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config, model_config
|
||||
from src.chat.utils.utils_image import image_path_to_base64, get_image_manager
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
|
||||
install(extra_lines=3)
|
||||
@@ -47,14 +48,14 @@ class MaiEmoji:
|
||||
self.embedding = []
|
||||
self.hash = "" # 初始为空,在创建实例时会计算
|
||||
self.description = ""
|
||||
self.emotion: List[str] = []
|
||||
self.emotion: list[str] = []
|
||||
self.usage_count = 0
|
||||
self.last_used_time = time.time()
|
||||
self.register_time = time.time()
|
||||
self.is_deleted = False # 标记是否已被删除
|
||||
self.format = ""
|
||||
|
||||
async def initialize_hash_format(self) -> Optional[bool]:
|
||||
async def initialize_hash_format(self) -> bool | None:
|
||||
"""从文件创建表情包实例, 计算哈希值和格式"""
|
||||
try:
|
||||
# 使用 full_path 检查文件是否存在
|
||||
@@ -105,7 +106,7 @@ class MaiEmoji:
|
||||
self.is_deleted = True
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"[初始化错误] 初始化表情包时发生未预期错误 ({self.filename}): {str(e)}")
|
||||
logger.error(f"[初始化错误] 初始化表情包时发生未预期错误 ({self.filename}): {e!s}")
|
||||
logger.error(traceback.format_exc())
|
||||
self.is_deleted = True
|
||||
return None
|
||||
@@ -142,7 +143,7 @@ class MaiEmoji:
|
||||
self.path = EMOJI_REGISTERED_DIR
|
||||
# self.filename 保持不变
|
||||
except Exception as move_error:
|
||||
logger.error(f"[错误] 移动文件失败: {str(move_error)}")
|
||||
logger.error(f"[错误] 移动文件失败: {move_error!s}")
|
||||
# 如果移动失败,尝试将实例状态恢复?暂时不处理,仅返回失败
|
||||
return False
|
||||
|
||||
@@ -174,11 +175,11 @@ class MaiEmoji:
|
||||
return True
|
||||
|
||||
except Exception as db_error:
|
||||
logger.error(f"[错误] 保存数据库失败 ({self.filename}): {str(db_error)}")
|
||||
logger.error(f"[错误] 保存数据库失败 ({self.filename}): {db_error!s}")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[错误] 注册表情包失败 ({self.filename}): {str(e)}")
|
||||
logger.error(f"[错误] 注册表情包失败 ({self.filename}): {e!s}")
|
||||
logger.error(traceback.format_exc())
|
||||
return False
|
||||
|
||||
@@ -198,7 +199,7 @@ class MaiEmoji:
|
||||
os.remove(file_to_delete)
|
||||
logger.debug(f"[删除] 文件: {file_to_delete}")
|
||||
except Exception as e:
|
||||
logger.error(f"[错误] 删除文件失败 {file_to_delete}: {str(e)}")
|
||||
logger.error(f"[错误] 删除文件失败 {file_to_delete}: {e!s}")
|
||||
# 文件删除失败,但仍然尝试删除数据库记录
|
||||
|
||||
# 2. 删除数据库记录
|
||||
@@ -214,7 +215,7 @@ class MaiEmoji:
|
||||
result = 1 # Successfully deleted one record
|
||||
await session.commit()
|
||||
except Exception as e:
|
||||
logger.error(f"[错误] 删除数据库记录时出错: {str(e)}")
|
||||
logger.error(f"[错误] 删除数据库记录时出错: {e!s}")
|
||||
result = 0
|
||||
|
||||
if result > 0:
|
||||
@@ -233,11 +234,11 @@ class MaiEmoji:
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[错误] 删除表情包失败 ({self.filename}): {str(e)}")
|
||||
logger.error(f"[错误] 删除表情包失败 ({self.filename}): {e!s}")
|
||||
return False
|
||||
|
||||
|
||||
def _emoji_objects_to_readable_list(emoji_objects: List["MaiEmoji"]) -> List[str]:
|
||||
def _emoji_objects_to_readable_list(emoji_objects: list["MaiEmoji"]) -> list[str]:
|
||||
"""将表情包对象列表转换为可读的字符串列表
|
||||
|
||||
参数:
|
||||
@@ -256,7 +257,7 @@ def _emoji_objects_to_readable_list(emoji_objects: List["MaiEmoji"]) -> List[str
|
||||
return emoji_info_list
|
||||
|
||||
|
||||
def _to_emoji_objects(data: Any) -> Tuple[List["MaiEmoji"], int]:
|
||||
def _to_emoji_objects(data: Any) -> tuple[list["MaiEmoji"], int]:
|
||||
emoji_objects = []
|
||||
load_errors = 0
|
||||
emoji_data_list = list(data)
|
||||
@@ -300,7 +301,7 @@ def _to_emoji_objects(data: Any) -> Tuple[List["MaiEmoji"], int]:
|
||||
logger.error(f"[加载错误] 初始化 MaiEmoji 失败 ({full_path}): {ve}")
|
||||
load_errors += 1
|
||||
except Exception as e:
|
||||
logger.error(f"[加载错误] 处理数据库记录时出错 ({full_path}): {str(e)}")
|
||||
logger.error(f"[加载错误] 处理数据库记录时出错 ({full_path}): {e!s}")
|
||||
load_errors += 1
|
||||
return emoji_objects, load_errors
|
||||
|
||||
@@ -335,7 +336,7 @@ async def clear_temp_emoji() -> None:
|
||||
logger.debug(f"[清理] 删除: {filename}")
|
||||
|
||||
|
||||
async def clean_unused_emojis(emoji_dir: str, emoji_objects: List["MaiEmoji"], removed_count: int) -> int:
|
||||
async def clean_unused_emojis(emoji_dir: str, emoji_objects: list["MaiEmoji"], removed_count: int) -> int:
|
||||
"""清理指定目录中未被 emoji_objects 追踪的表情包文件"""
|
||||
if not os.path.exists(emoji_dir):
|
||||
logger.warning(f"[清理] 目标目录不存在,跳过清理: {emoji_dir}")
|
||||
@@ -361,7 +362,7 @@ async def clean_unused_emojis(emoji_dir: str, emoji_objects: List["MaiEmoji"], r
|
||||
logger.info(f"[清理] 删除未追踪的表情包文件: {file_full_path}")
|
||||
cleaned_count += 1
|
||||
except Exception as e:
|
||||
logger.error(f"[错误] 删除文件时出错 ({file_full_path}): {str(e)}")
|
||||
logger.error(f"[错误] 删除文件时出错 ({file_full_path}): {e!s}")
|
||||
|
||||
if cleaned_count > 0:
|
||||
logger.info(f"[清理] 在目录 {emoji_dir} 中清理了 {cleaned_count} 个破损表情包。")
|
||||
@@ -369,7 +370,7 @@ async def clean_unused_emojis(emoji_dir: str, emoji_objects: List["MaiEmoji"], r
|
||||
logger.info(f"[清理] 目录 {emoji_dir} 中没有需要清理的。")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[错误] 清理未使用表情包文件时出错 ({emoji_dir}): {str(e)}")
|
||||
logger.error(f"[错误] 清理未使用表情包文件时出错 ({emoji_dir}): {e!s}")
|
||||
|
||||
return removed_count + cleaned_count
|
||||
|
||||
@@ -437,9 +438,9 @@ class EmojiManager:
|
||||
emoji_update.last_used_time = time.time() # Update last used time
|
||||
await session.commit()
|
||||
except Exception as e:
|
||||
logger.error(f"记录表情使用失败: {str(e)}")
|
||||
logger.error(f"记录表情使用失败: {e!s}")
|
||||
|
||||
async def get_emoji_for_text(self, text_emotion: str) -> Optional[Tuple[str, str, str]]:
|
||||
async def get_emoji_for_text(self, text_emotion: str) -> tuple[str, str, str] | None:
|
||||
"""
|
||||
根据文本内容,使用LLM选择一个合适的表情包。
|
||||
|
||||
@@ -531,7 +532,7 @@ class EmojiManager:
|
||||
return selected_emoji.full_path, f"[表情包:{selected_emoji.description}]", text_emotion
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"使用LLM获取表情包时发生错误: {str(e)}")
|
||||
logger.error(f"使用LLM获取表情包时发生错误: {e!s}")
|
||||
logger.error(traceback.format_exc())
|
||||
return None
|
||||
|
||||
@@ -578,7 +579,7 @@ class EmojiManager:
|
||||
continue
|
||||
|
||||
except Exception as item_error:
|
||||
logger.error(f"[错误] 处理表情包记录时出错 ({emoji.filename}): {str(item_error)}")
|
||||
logger.error(f"[错误] 处理表情包记录时出错 ({emoji.filename}): {item_error!s}")
|
||||
# 即使出错,也尝试继续检查下一个
|
||||
continue
|
||||
|
||||
@@ -597,7 +598,7 @@ class EmojiManager:
|
||||
logger.info(f"[检查] 已检查 {total_count} 个表情包记录,全部完好")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[错误] 检查表情包完整性失败: {str(e)}")
|
||||
logger.error(f"[错误] 检查表情包完整性失败: {e!s}")
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
async def start_periodic_check_register(self) -> None:
|
||||
@@ -651,7 +652,7 @@ class EmojiManager:
|
||||
os.remove(file_path)
|
||||
logger.warning(f"[清理] 删除注册失败的表情包文件: {filename}")
|
||||
except Exception as e:
|
||||
logger.error(f"[错误] 扫描表情包目录失败: {str(e)}")
|
||||
logger.error(f"[错误] 扫描表情包目录失败: {e!s}")
|
||||
|
||||
await asyncio.sleep(global_config.emoji.check_interval * 60)
|
||||
|
||||
@@ -674,11 +675,11 @@ class EmojiManager:
|
||||
logger.warning(f"[数据库] 加载过程中出现 {load_errors} 个错误。")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[错误] 从数据库加载所有表情包对象失败: {str(e)}")
|
||||
logger.error(f"[错误] 从数据库加载所有表情包对象失败: {e!s}")
|
||||
self.emoji_objects = [] # 加载失败则清空列表
|
||||
self.emoji_num = 0
|
||||
|
||||
async def get_emoji_from_db(self, emoji_hash: Optional[str] = None) -> List["MaiEmoji"]:
|
||||
async def get_emoji_from_db(self, emoji_hash: str | None = None) -> list["MaiEmoji"]:
|
||||
"""获取指定哈希值的表情包并初始化为MaiEmoji类对象列表 (主要用于调试或特定查找)
|
||||
|
||||
参数:
|
||||
@@ -707,7 +708,7 @@ class EmojiManager:
|
||||
return emoji_objects
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[错误] 从数据库获取表情包对象失败: {str(e)}")
|
||||
logger.error(f"[错误] 从数据库获取表情包对象失败: {e!s}")
|
||||
return []
|
||||
|
||||
async def get_emoji_from_manager(self, emoji_hash: str) -> Optional["MaiEmoji"]:
|
||||
@@ -725,7 +726,7 @@ class EmojiManager:
|
||||
return emoji
|
||||
return None # 如果循环结束还没找到,则返回 None
|
||||
|
||||
async def get_emoji_tag_by_hash(self, emoji_hash: str) -> Optional[str]:
|
||||
async def get_emoji_tag_by_hash(self, emoji_hash: str) -> str | None:
|
||||
"""根据哈希值获取已注册表情包的描述
|
||||
|
||||
Args:
|
||||
@@ -753,10 +754,10 @@ class EmojiManager:
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取表情包描述失败 (Hash: {emoji_hash}): {str(e)}")
|
||||
logger.error(f"获取表情包描述失败 (Hash: {emoji_hash}): {e!s}")
|
||||
return None
|
||||
|
||||
async def get_emoji_description_by_hash(self, emoji_hash: str) -> Optional[str]:
|
||||
async def get_emoji_description_by_hash(self, emoji_hash: str) -> str | None:
|
||||
"""根据哈希值获取已注册表情包的描述
|
||||
|
||||
Args:
|
||||
@@ -787,7 +788,7 @@ class EmojiManager:
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取表情包描述失败 (Hash: {emoji_hash}): {str(e)}")
|
||||
logger.error(f"获取表情包描述失败 (Hash: {emoji_hash}): {e!s}")
|
||||
return None
|
||||
|
||||
async def delete_emoji(self, emoji_hash: str) -> bool:
|
||||
@@ -823,7 +824,7 @@ class EmojiManager:
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[错误] 删除表情包失败: {str(e)}")
|
||||
logger.error(f"[错误] 删除表情包失败: {e!s}")
|
||||
logger.error(traceback.format_exc())
|
||||
return False
|
||||
|
||||
@@ -909,11 +910,11 @@ class EmojiManager:
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[错误] 替换表情包失败: {str(e)}")
|
||||
logger.error(f"[错误] 替换表情包失败: {e!s}")
|
||||
logger.error(traceback.format_exc())
|
||||
return False
|
||||
|
||||
async def build_emoji_description(self, image_base64: str) -> Tuple[str, List[str]]:
|
||||
async def build_emoji_description(self, image_base64: str) -> tuple[str, list[str]]:
|
||||
"""
|
||||
获取表情包的详细描述和情感关键词列表。
|
||||
|
||||
@@ -976,14 +977,14 @@ class EmojiManager:
|
||||
|
||||
# 4. 内容审核,确保表情包符合规定
|
||||
if global_config.emoji.content_filtration:
|
||||
prompt = f'''
|
||||
prompt = f"""
|
||||
请根据以下标准审核这个表情包:
|
||||
1. 主题必须符合:"{global_config.emoji.filtration_prompt}"。
|
||||
2. 内容健康,不含色情、暴力、政治敏感等元素。
|
||||
3. 必须是表情包,而不是普通的聊天截图或视频截图。
|
||||
4. 表情包中的文字数量(如果有)不能超过5个。
|
||||
这个表情包是否完全满足以上所有要求?请只回答“是”或“否”。
|
||||
'''
|
||||
"""
|
||||
content, _ = await self.vlm.generate_response_for_image(
|
||||
prompt, image_base64, image_format, temperature=0.1, max_tokens=10
|
||||
)
|
||||
@@ -1023,7 +1024,7 @@ class EmojiManager:
|
||||
return final_description, emotions
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"构建表情包描述时发生严重错误: {str(e)}")
|
||||
logger.error(f"构建表情包描述时发生严重错误: {e!s}")
|
||||
logger.error(traceback.format_exc())
|
||||
return "", []
|
||||
|
||||
@@ -1058,7 +1059,7 @@ class EmojiManager:
|
||||
os.remove(file_full_path)
|
||||
logger.info(f"[清理] 删除重复的待注册文件: {filename}")
|
||||
except Exception as e:
|
||||
logger.error(f"[错误] 删除重复文件失败: {str(e)}")
|
||||
logger.error(f"[错误] 删除重复文件失败: {e!s}")
|
||||
return False # 返回 False 表示未注册新表情
|
||||
|
||||
# 3. 构建描述和情感
|
||||
@@ -1075,7 +1076,7 @@ class EmojiManager:
|
||||
os.remove(file_full_path)
|
||||
logger.info(f"[清理] 删除描述生成失败的文件: {filename}")
|
||||
except Exception as e:
|
||||
logger.error(f"[错误] 删除描述生成失败文件时出错: {str(e)}")
|
||||
logger.error(f"[错误] 删除描述生成失败文件时出错: {e!s}")
|
||||
return False
|
||||
new_emoji.description = description
|
||||
new_emoji.emotion = emotions
|
||||
@@ -1086,7 +1087,7 @@ class EmojiManager:
|
||||
os.remove(file_full_path)
|
||||
logger.info(f"[清理] 删除描述生成异常的文件: {filename}")
|
||||
except Exception as e:
|
||||
logger.error(f"[错误] 删除描述生成异常文件时出错: {str(e)}")
|
||||
logger.error(f"[错误] 删除描述生成异常文件时出错: {e!s}")
|
||||
return False
|
||||
|
||||
# 4. 检查容量并决定是否替换或直接注册
|
||||
@@ -1100,7 +1101,7 @@ class EmojiManager:
|
||||
os.remove(file_full_path) # new_emoji 的 full_path 此时还是源路径
|
||||
logger.info(f"[清理] 删除替换失败的新表情文件: {filename}")
|
||||
except Exception as e:
|
||||
logger.error(f"[错误] 删除替换失败文件时出错: {str(e)}")
|
||||
logger.error(f"[错误] 删除替换失败文件时出错: {e!s}")
|
||||
return False
|
||||
# 替换成功时,replace_a_emoji 内部已处理 new_emoji 的注册和添加到列表
|
||||
return True
|
||||
@@ -1122,11 +1123,11 @@ class EmojiManager:
|
||||
os.remove(file_full_path)
|
||||
logger.info(f"[清理] 删除注册失败的源文件: {filename}")
|
||||
except Exception as e:
|
||||
logger.error(f"[错误] 删除注册失败源文件时出错: {str(e)}")
|
||||
logger.error(f"[错误] 删除注册失败源文件时出错: {e!s}")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[错误] 注册表情包时发生未预期错误 ({filename}): {str(e)}")
|
||||
logger.error(f"[错误] 注册表情包时发生未预期错误 ({filename}): {e!s}")
|
||||
logger.error(traceback.format_exc())
|
||||
# 尝试删除源文件以避免循环处理
|
||||
if os.path.exists(file_full_path):
|
||||
|
||||
@@ -4,24 +4,24 @@
|
||||
"""
|
||||
|
||||
from .energy_manager import (
|
||||
EnergyManager,
|
||||
EnergyLevel,
|
||||
EnergyComponent,
|
||||
EnergyCalculator,
|
||||
InterestEnergyCalculator,
|
||||
ActivityEnergyCalculator,
|
||||
EnergyCalculator,
|
||||
EnergyComponent,
|
||||
EnergyLevel,
|
||||
EnergyManager,
|
||||
InterestEnergyCalculator,
|
||||
RecencyEnergyCalculator,
|
||||
RelationshipEnergyCalculator,
|
||||
energy_manager,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"EnergyManager",
|
||||
"EnergyLevel",
|
||||
"EnergyComponent",
|
||||
"EnergyCalculator",
|
||||
"InterestEnergyCalculator",
|
||||
"ActivityEnergyCalculator",
|
||||
"EnergyCalculator",
|
||||
"EnergyComponent",
|
||||
"EnergyLevel",
|
||||
"EnergyManager",
|
||||
"InterestEnergyCalculator",
|
||||
"RecencyEnergyCalculator",
|
||||
"RelationshipEnergyCalculator",
|
||||
"energy_manager",
|
||||
|
||||
@@ -4,10 +4,10 @@
|
||||
"""
|
||||
|
||||
import time
|
||||
from typing import Dict, List, Optional, Tuple, Any, Union, TypedDict
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, TypedDict
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
@@ -51,8 +51,8 @@ class EnergyContext(TypedDict):
|
||||
"""能量计算上下文"""
|
||||
|
||||
stream_id: str
|
||||
messages: List[Any]
|
||||
user_id: Optional[str]
|
||||
messages: list[Any]
|
||||
user_id: str | None
|
||||
|
||||
|
||||
class EnergyResult(TypedDict):
|
||||
@@ -61,7 +61,7 @@ class EnergyResult(TypedDict):
|
||||
energy: float
|
||||
level: EnergyLevel
|
||||
distribution_interval: float
|
||||
component_scores: Dict[str, float]
|
||||
component_scores: dict[str, float]
|
||||
cached: bool
|
||||
|
||||
|
||||
@@ -69,7 +69,7 @@ class EnergyCalculator(ABC):
|
||||
"""能量计算器抽象基类"""
|
||||
|
||||
@abstractmethod
|
||||
def calculate(self, context: Dict[str, Any]) -> float:
|
||||
def calculate(self, context: dict[str, Any]) -> float:
|
||||
"""计算能量值"""
|
||||
pass
|
||||
|
||||
@@ -82,7 +82,7 @@ class EnergyCalculator(ABC):
|
||||
class InterestEnergyCalculator(EnergyCalculator):
|
||||
"""兴趣度能量计算器"""
|
||||
|
||||
def calculate(self, context: Dict[str, Any]) -> float:
|
||||
def calculate(self, context: dict[str, Any]) -> float:
|
||||
"""基于消息兴趣度计算能量"""
|
||||
messages = context.get("messages", [])
|
||||
if not messages:
|
||||
@@ -120,7 +120,7 @@ class ActivityEnergyCalculator(EnergyCalculator):
|
||||
def __init__(self):
|
||||
self.action_weights = {"reply": 0.4, "react": 0.3, "mention": 0.2, "other": 0.1}
|
||||
|
||||
def calculate(self, context: Dict[str, Any]) -> float:
|
||||
def calculate(self, context: dict[str, Any]) -> float:
|
||||
"""基于活跃度计算能量"""
|
||||
messages = context.get("messages", [])
|
||||
if not messages:
|
||||
@@ -150,7 +150,7 @@ class ActivityEnergyCalculator(EnergyCalculator):
|
||||
class RecencyEnergyCalculator(EnergyCalculator):
|
||||
"""最近性能量计算器"""
|
||||
|
||||
def calculate(self, context: Dict[str, Any]) -> float:
|
||||
def calculate(self, context: dict[str, Any]) -> float:
|
||||
"""基于最近性计算能量"""
|
||||
messages = context.get("messages", [])
|
||||
if not messages:
|
||||
@@ -197,7 +197,7 @@ class RecencyEnergyCalculator(EnergyCalculator):
|
||||
class RelationshipEnergyCalculator(EnergyCalculator):
|
||||
"""关系能量计算器"""
|
||||
|
||||
async def calculate(self, context: Dict[str, Any]) -> float:
|
||||
async def calculate(self, context: dict[str, Any]) -> float:
|
||||
"""基于关系计算能量"""
|
||||
user_id = context.get("user_id")
|
||||
if not user_id:
|
||||
@@ -223,7 +223,7 @@ class EnergyManager:
|
||||
"""能量管理器 - 统一管理所有能量计算"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.calculators: List[EnergyCalculator] = [
|
||||
self.calculators: list[EnergyCalculator] = [
|
||||
InterestEnergyCalculator(),
|
||||
ActivityEnergyCalculator(),
|
||||
RecencyEnergyCalculator(),
|
||||
@@ -231,14 +231,14 @@ class EnergyManager:
|
||||
]
|
||||
|
||||
# 能量缓存
|
||||
self.energy_cache: Dict[str, Tuple[float, float]] = {} # stream_id -> (energy, timestamp)
|
||||
self.energy_cache: dict[str, tuple[float, float]] = {} # stream_id -> (energy, timestamp)
|
||||
self.cache_ttl: int = 60 # 1分钟缓存
|
||||
|
||||
# AFC阈值配置
|
||||
self.thresholds: Dict[str, float] = {"high_match": 0.8, "reply": 0.4, "non_reply": 0.2}
|
||||
self.thresholds: dict[str, float] = {"high_match": 0.8, "reply": 0.4, "non_reply": 0.2}
|
||||
|
||||
# 统计信息
|
||||
self.stats: Dict[str, Union[int, float, str]] = {
|
||||
self.stats: dict[str, int | float | str] = {
|
||||
"total_calculations": 0,
|
||||
"cache_hits": 0,
|
||||
"cache_misses": 0,
|
||||
@@ -272,7 +272,7 @@ class EnergyManager:
|
||||
except Exception as e:
|
||||
logger.warning(f"加载AFC阈值失败,使用默认值: {e}")
|
||||
|
||||
async def calculate_focus_energy(self, stream_id: str, messages: List[Any], user_id: Optional[str] = None) -> float:
|
||||
async def calculate_focus_energy(self, stream_id: str, messages: list[Any], user_id: str | None = None) -> float:
|
||||
"""计算聊天流的focus_energy"""
|
||||
start_time = time.time()
|
||||
|
||||
@@ -297,7 +297,7 @@ class EnergyManager:
|
||||
}
|
||||
|
||||
# 计算各组件能量
|
||||
component_scores: Dict[str, float] = {}
|
||||
component_scores: dict[str, float] = {}
|
||||
total_weight = 0.0
|
||||
|
||||
for calculator in self.calculators:
|
||||
@@ -437,7 +437,7 @@ class EnergyManager:
|
||||
if expired_keys:
|
||||
logger.debug(f"清理了 {len(expired_keys)} 个过期能量缓存")
|
||||
|
||||
def get_statistics(self) -> Dict[str, Any]:
|
||||
def get_statistics(self) -> dict[str, Any]:
|
||||
"""获取统计信息"""
|
||||
return {
|
||||
"cache_size": len(self.energy_cache),
|
||||
@@ -446,7 +446,7 @@ class EnergyManager:
|
||||
"performance_stats": self.stats.copy(),
|
||||
}
|
||||
|
||||
def update_thresholds(self, new_thresholds: Dict[str, float]) -> None:
|
||||
def update_thresholds(self, new_thresholds: dict[str, float]) -> None:
|
||||
"""更新阈值"""
|
||||
self.thresholds.update(new_thresholds)
|
||||
|
||||
|
||||
@@ -1,21 +1,20 @@
|
||||
import time
|
||||
import random
|
||||
import orjson
|
||||
import os
|
||||
import random
|
||||
import time
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from typing import List, Dict, Optional, Any, Tuple
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.common.database.sqlalchemy_database_api import get_db_session
|
||||
import orjson
|
||||
from sqlalchemy import select
|
||||
from src.common.database.sqlalchemy_models import Expression
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import model_config, global_config
|
||||
from src.chat.utils.chat_message_builder import get_raw_msg_by_timestamp_with_chat_inclusive, build_anonymous_messages
|
||||
from src.chat.utils.prompt import Prompt, global_prompt_manager
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from src.chat.utils.chat_message_builder import build_anonymous_messages, get_raw_msg_by_timestamp_with_chat_inclusive
|
||||
from src.chat.utils.prompt import Prompt, global_prompt_manager
|
||||
from src.common.database.sqlalchemy_database_api import get_db_session
|
||||
from src.common.database.sqlalchemy_models import Expression
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config, model_config
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
|
||||
MAX_EXPRESSION_COUNT = 300
|
||||
DECAY_DAYS = 30 # 30天衰减到0.01
|
||||
@@ -193,7 +192,7 @@ class ExpressionLearner:
|
||||
logger.error(f"为聊天流 {self.chat_name} 触发学习失败: {e}")
|
||||
return False
|
||||
|
||||
async def get_expression_by_chat_id(self) -> Tuple[List[Dict[str, float]], List[Dict[str, float]]]:
|
||||
async def get_expression_by_chat_id(self) -> tuple[list[dict[str, float]], list[dict[str, float]]]:
|
||||
"""
|
||||
获取指定chat_id的style和grammar表达方式
|
||||
返回的每个表达方式字典中都包含了source_id, 用于后续的更新操作
|
||||
@@ -341,7 +340,7 @@ class ExpressionLearner:
|
||||
return []
|
||||
|
||||
# 按chat_id分组
|
||||
chat_dict: Dict[str, List[Dict[str, Any]]] = {}
|
||||
chat_dict: dict[str, list[dict[str, Any]]] = {}
|
||||
for chat_id, situation, style in learnt_expressions:
|
||||
if chat_id not in chat_dict:
|
||||
chat_dict[chat_id] = []
|
||||
@@ -398,7 +397,7 @@ class ExpressionLearner:
|
||||
return learnt_expressions
|
||||
return None
|
||||
|
||||
async def learn_expression(self, type: str, num: int = 10) -> Optional[Tuple[List[Tuple[str, str, str]], str]]:
|
||||
async def learn_expression(self, type: str, num: int = 10) -> tuple[list[tuple[str, str, str]], str] | None:
|
||||
"""从指定聊天流学习表达方式
|
||||
|
||||
Args:
|
||||
@@ -416,7 +415,7 @@ class ExpressionLearner:
|
||||
current_time = time.time()
|
||||
|
||||
# 获取上次学习时间
|
||||
random_msg: Optional[List[Dict[str, Any]]] = await get_raw_msg_by_timestamp_with_chat_inclusive(
|
||||
random_msg: list[dict[str, Any]] | None = await get_raw_msg_by_timestamp_with_chat_inclusive(
|
||||
chat_id=self.chat_id,
|
||||
timestamp_start=self.last_learning_time,
|
||||
timestamp_end=current_time,
|
||||
@@ -447,16 +446,16 @@ class ExpressionLearner:
|
||||
|
||||
logger.debug(f"学习{type_str}的response: {response}")
|
||||
|
||||
expressions: List[Tuple[str, str, str]] = self.parse_expression_response(response, chat_id)
|
||||
expressions: list[tuple[str, str, str]] = self.parse_expression_response(response, chat_id)
|
||||
|
||||
return expressions, chat_id
|
||||
|
||||
@staticmethod
|
||||
def parse_expression_response(response: str, chat_id: str) -> List[Tuple[str, str, str]]:
|
||||
def parse_expression_response(response: str, chat_id: str) -> list[tuple[str, str, str]]:
|
||||
"""
|
||||
解析LLM返回的表达风格总结,每一行提取"当"和"使用"之间的内容,存储为(situation, style)元组
|
||||
"""
|
||||
expressions: List[Tuple[str, str, str]] = []
|
||||
expressions: list[tuple[str, str, str]] = []
|
||||
for line in response.splitlines():
|
||||
line = line.strip()
|
||||
if not line:
|
||||
@@ -562,7 +561,7 @@ class ExpressionLearnerManager:
|
||||
if not os.path.exists(expr_file):
|
||||
continue
|
||||
try:
|
||||
with open(expr_file, "r", encoding="utf-8") as f:
|
||||
with open(expr_file, encoding="utf-8") as f:
|
||||
expressions = orjson.loads(f.read())
|
||||
|
||||
if not isinstance(expressions, list):
|
||||
|
||||
@@ -1,18 +1,18 @@
|
||||
import orjson
|
||||
import time
|
||||
import random
|
||||
import hashlib
|
||||
import random
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
from typing import List, Dict, Tuple, Optional, Any
|
||||
import orjson
|
||||
from json_repair import repair_json
|
||||
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import global_config, model_config
|
||||
from src.common.logger import get_logger
|
||||
from sqlalchemy import select
|
||||
from src.common.database.sqlalchemy_models import Expression
|
||||
|
||||
from src.chat.utils.prompt import Prompt, global_prompt_manager
|
||||
from src.common.database.sqlalchemy_database_api import get_db_session
|
||||
from src.common.database.sqlalchemy_models import Expression
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config, model_config
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
|
||||
logger = get_logger("expression_selector")
|
||||
|
||||
@@ -45,7 +45,7 @@ def init_prompt():
|
||||
Prompt(expression_evaluation_prompt, "expression_evaluation_prompt")
|
||||
|
||||
|
||||
def weighted_sample(population: List[Dict], weights: List[float], k: int) -> List[Dict]:
|
||||
def weighted_sample(population: list[dict], weights: list[float], k: int) -> list[dict]:
|
||||
"""按权重随机抽样"""
|
||||
if not population or not weights or k <= 0:
|
||||
return []
|
||||
@@ -95,7 +95,7 @@ class ExpressionSelector:
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def _parse_stream_config_to_chat_id(stream_config_str: str) -> Optional[str]:
|
||||
def _parse_stream_config_to_chat_id(stream_config_str: str) -> str | None:
|
||||
"""解析'platform:id:type'为chat_id(与get_stream_id一致)"""
|
||||
try:
|
||||
parts = stream_config_str.split(":")
|
||||
@@ -114,7 +114,7 @@ class ExpressionSelector:
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def get_related_chat_ids(self, chat_id: str) -> List[str]:
|
||||
def get_related_chat_ids(self, chat_id: str) -> list[str]:
|
||||
"""根据expression.rules配置,获取与当前chat_id相关的所有chat_id(包括自身)"""
|
||||
rules = global_config.expression.rules
|
||||
current_group = None
|
||||
@@ -139,7 +139,7 @@ class ExpressionSelector:
|
||||
|
||||
async def get_random_expressions(
|
||||
self, chat_id: str, total_num: int, style_percentage: float, grammar_percentage: float
|
||||
) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
|
||||
) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]:
|
||||
# sourcery skip: extract-duplicate-method, move-assign
|
||||
# 支持多chat_id合并抽选
|
||||
related_chat_ids = self.get_related_chat_ids(chat_id)
|
||||
@@ -195,7 +195,7 @@ class ExpressionSelector:
|
||||
return selected_style, selected_grammar
|
||||
|
||||
@staticmethod
|
||||
async def update_expressions_count_batch(expressions_to_update: List[Dict[str, Any]], increment: float = 0.1):
|
||||
async def update_expressions_count_batch(expressions_to_update: list[dict[str, Any]], increment: float = 0.1):
|
||||
"""对一批表达方式更新count值,按chat_id+type分组后一次性写入数据库"""
|
||||
if not expressions_to_update:
|
||||
return
|
||||
@@ -240,8 +240,8 @@ class ExpressionSelector:
|
||||
chat_info: str,
|
||||
max_num: int = 10,
|
||||
min_num: int = 5,
|
||||
target_message: Optional[str] = None,
|
||||
) -> List[Dict[str, Any]]:
|
||||
target_message: str | None = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
# sourcery skip: inline-variable, list-comprehension
|
||||
"""使用LLM选择适合的表达方式"""
|
||||
|
||||
|
||||
@@ -16,8 +16,7 @@ Chat Frequency Analyzer
|
||||
"""
|
||||
|
||||
import time as time_module
|
||||
from datetime import datetime, timedelta, time
|
||||
from typing import List, Tuple, Optional
|
||||
from datetime import datetime, time, timedelta
|
||||
|
||||
from .tracker import chat_frequency_tracker
|
||||
|
||||
@@ -42,7 +41,7 @@ class ChatFrequencyAnalyzer:
|
||||
self._cache_ttl_seconds = 60 * 30 # 缓存30分钟
|
||||
|
||||
@staticmethod
|
||||
def _find_peak_windows(timestamps: List[float]) -> List[Tuple[datetime, datetime]]:
|
||||
def _find_peak_windows(timestamps: list[float]) -> list[tuple[datetime, datetime]]:
|
||||
"""
|
||||
使用滑动窗口算法来识别时间戳列表中的高峰时段。
|
||||
|
||||
@@ -59,7 +58,7 @@ class ChatFrequencyAnalyzer:
|
||||
datetimes = [datetime.fromtimestamp(ts) for ts in timestamps]
|
||||
datetimes.sort()
|
||||
|
||||
peak_windows: List[Tuple[datetime, datetime]] = []
|
||||
peak_windows: list[tuple[datetime, datetime]] = []
|
||||
window_start_idx = 0
|
||||
|
||||
for i in range(len(datetimes)):
|
||||
@@ -83,7 +82,7 @@ class ChatFrequencyAnalyzer:
|
||||
|
||||
return peak_windows
|
||||
|
||||
def get_peak_chat_times(self, chat_id: str) -> List[Tuple[time, time]]:
|
||||
def get_peak_chat_times(self, chat_id: str) -> list[tuple[time, time]]:
|
||||
"""
|
||||
获取指定用户的高峰聊天时间段。
|
||||
|
||||
@@ -116,7 +115,7 @@ class ChatFrequencyAnalyzer:
|
||||
|
||||
return peak_time_windows
|
||||
|
||||
def is_in_peak_time(self, chat_id: str, now: Optional[datetime] = None) -> bool:
|
||||
def is_in_peak_time(self, chat_id: str, now: datetime | None = None) -> bool:
|
||||
"""
|
||||
检查当前时间是否处于用户的高峰聊天时段内。
|
||||
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
import orjson
|
||||
import time
|
||||
from typing import Dict, List, Optional
|
||||
from pathlib import Path
|
||||
|
||||
import orjson
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
# 数据存储路径
|
||||
@@ -19,10 +19,10 @@ class ChatFrequencyTracker:
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._timestamps: Dict[str, List[float]] = self._load_timestamps()
|
||||
self._timestamps: dict[str, list[float]] = self._load_timestamps()
|
||||
|
||||
@staticmethod
|
||||
def _load_timestamps() -> Dict[str, List[float]]:
|
||||
def _load_timestamps() -> dict[str, list[float]]:
|
||||
"""从本地文件加载时间戳数据。"""
|
||||
if not TRACKER_FILE.exists():
|
||||
return {}
|
||||
@@ -61,7 +61,7 @@ class ChatFrequencyTracker:
|
||||
logger.debug(f"为 chat_id '{chat_id}' 记录了新的聊天时间: {now}")
|
||||
self._save_timestamps()
|
||||
|
||||
def get_timestamps_for_chat(self, chat_id: str) -> Optional[List[float]]:
|
||||
def get_timestamps_for_chat(self, chat_id: str) -> list[float] | None:
|
||||
"""
|
||||
获取指定聊天的所有时间戳记录。
|
||||
|
||||
|
||||
@@ -18,11 +18,10 @@ Frequency-Based Proactive Trigger
|
||||
import asyncio
|
||||
import time
|
||||
from datetime import datetime
|
||||
from typing import Dict, Optional
|
||||
|
||||
from src.common.logger import get_logger
|
||||
# AFC manager has been moved to chatter plugin
|
||||
|
||||
# AFC manager has been moved to chatter plugin
|
||||
# TODO: 需要重新实现主动思考和睡眠管理功能
|
||||
from .analyzer import chat_frequency_analyzer
|
||||
|
||||
@@ -42,10 +41,10 @@ class FrequencyBasedTrigger:
|
||||
|
||||
def __init__(self):
|
||||
# TODO: 需要重新实现睡眠管理器
|
||||
self._task: Optional[asyncio.Task] = None
|
||||
self._task: asyncio.Task | None = None
|
||||
# 记录上次为用户触发的时间,用于冷却控制
|
||||
# 格式: { "chat_id": timestamp }
|
||||
self._last_triggered: Dict[str, float] = {}
|
||||
self._last_triggered: dict[str, float] = {}
|
||||
|
||||
async def _run_trigger_cycle(self):
|
||||
"""触发器的主要循环逻辑。"""
|
||||
|
||||
@@ -3,13 +3,14 @@
|
||||
提供机器人兴趣标签和智能匹配功能
|
||||
"""
|
||||
|
||||
from .bot_interest_manager import BotInterestManager, bot_interest_manager
|
||||
from src.common.data_models.bot_interest_data_model import BotInterestTag, BotPersonalityInterests, InterestMatchResult
|
||||
|
||||
from .bot_interest_manager import BotInterestManager, bot_interest_manager
|
||||
|
||||
__all__ = [
|
||||
"BotInterestManager",
|
||||
"bot_interest_manager",
|
||||
"BotInterestTag",
|
||||
"BotPersonalityInterests",
|
||||
"InterestMatchResult",
|
||||
"bot_interest_manager",
|
||||
]
|
||||
|
||||
@@ -3,17 +3,18 @@
|
||||
基于人设生成兴趣标签,并使用embedding计算匹配度
|
||||
"""
|
||||
|
||||
import orjson
|
||||
import traceback
|
||||
from typing import List, Dict, Optional, Any
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import orjson
|
||||
from sqlalchemy import select
|
||||
|
||||
from src.common.config_helpers import resolve_embedding_dimension
|
||||
from src.common.data_models.bot_interest_data_model import BotInterestTag, BotPersonalityInterests, InterestMatchResult
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.common.config_helpers import resolve_embedding_dimension
|
||||
from src.common.data_models.bot_interest_data_model import BotPersonalityInterests, BotInterestTag, InterestMatchResult
|
||||
|
||||
logger = get_logger("bot_interest_manager")
|
||||
|
||||
@@ -22,8 +23,8 @@ class BotInterestManager:
|
||||
"""机器人兴趣标签管理器"""
|
||||
|
||||
def __init__(self):
|
||||
self.current_interests: Optional[BotPersonalityInterests] = None
|
||||
self.embedding_cache: Dict[str, List[float]] = {} # embedding缓存
|
||||
self.current_interests: BotPersonalityInterests | None = None
|
||||
self.embedding_cache: dict[str, list[float]] = {} # embedding缓存
|
||||
self._initialized = False
|
||||
|
||||
# Embedding客户端配置
|
||||
@@ -31,7 +32,7 @@ class BotInterestManager:
|
||||
self.embedding_config = None
|
||||
configured_dim = resolve_embedding_dimension()
|
||||
self.embedding_dimension = int(configured_dim) if configured_dim else 0
|
||||
self._detected_embedding_dimension: Optional[int] = None
|
||||
self._detected_embedding_dimension: int | None = None
|
||||
|
||||
@property
|
||||
def is_initialized(self) -> bool:
|
||||
@@ -145,7 +146,7 @@ class BotInterestManager:
|
||||
|
||||
async def _generate_interests_from_personality(
|
||||
self, personality_description: str, personality_id: str
|
||||
) -> Optional[BotPersonalityInterests]:
|
||||
) -> BotPersonalityInterests | None:
|
||||
"""根据人设生成兴趣标签"""
|
||||
try:
|
||||
logger.info("🎨 开始根据人设生成兴趣标签...")
|
||||
@@ -226,14 +227,14 @@ class BotInterestManager:
|
||||
traceback.print_exc()
|
||||
raise
|
||||
|
||||
async def _call_llm_for_interest_generation(self, prompt: str) -> Optional[str]:
|
||||
async def _call_llm_for_interest_generation(self, prompt: str) -> str | None:
|
||||
"""调用LLM生成兴趣标签"""
|
||||
try:
|
||||
logger.info("🔧 配置LLM客户端...")
|
||||
|
||||
# 使用llm_api来处理请求
|
||||
from src.plugin_system.apis import llm_api
|
||||
from src.config.config import model_config
|
||||
from src.plugin_system.apis import llm_api
|
||||
|
||||
# 构建完整的提示词,明确要求只返回纯JSON
|
||||
full_prompt = f"""你是一个专业的机器人人设分析师,擅长根据人设描述生成合适的兴趣标签。
|
||||
@@ -342,7 +343,7 @@ class BotInterestManager:
|
||||
logger.info(f"🗃️ 总缓存大小: {len(self.embedding_cache)}")
|
||||
logger.info("=" * 50)
|
||||
|
||||
async def _get_embedding(self, text: str) -> List[float]:
|
||||
async def _get_embedding(self, text: str) -> list[float]:
|
||||
"""获取文本的embedding向量"""
|
||||
if not hasattr(self, "embedding_request"):
|
||||
raise RuntimeError("❌ Embedding请求客户端未初始化")
|
||||
@@ -383,7 +384,7 @@ class BotInterestManager:
|
||||
else:
|
||||
raise RuntimeError(f"❌ 返回的embedding为空: {embedding}")
|
||||
|
||||
async def _generate_message_embedding(self, message_text: str, keywords: List[str]) -> List[float]:
|
||||
async def _generate_message_embedding(self, message_text: str, keywords: list[str]) -> list[float]:
|
||||
"""为消息生成embedding向量"""
|
||||
# 组合消息文本和关键词作为embedding输入
|
||||
if keywords:
|
||||
@@ -399,7 +400,7 @@ class BotInterestManager:
|
||||
return embedding
|
||||
|
||||
async def _calculate_similarity_scores(
|
||||
self, result: InterestMatchResult, message_embedding: List[float], keywords: List[str]
|
||||
self, result: InterestMatchResult, message_embedding: list[float], keywords: list[str]
|
||||
):
|
||||
"""计算消息与兴趣标签的相似度分数"""
|
||||
try:
|
||||
@@ -428,7 +429,7 @@ class BotInterestManager:
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 计算相似度分数失败: {e}")
|
||||
|
||||
async def calculate_interest_match(self, message_text: str, keywords: List[str] = None) -> InterestMatchResult:
|
||||
async def calculate_interest_match(self, message_text: str, keywords: list[str] = None) -> InterestMatchResult:
|
||||
"""计算消息与机器人兴趣的匹配度"""
|
||||
if not self.current_interests or not self._initialized:
|
||||
raise RuntimeError("❌ 兴趣标签系统未初始化")
|
||||
@@ -528,7 +529,7 @@ class BotInterestManager:
|
||||
)
|
||||
return result
|
||||
|
||||
def _calculate_keyword_match_bonus(self, keywords: List[str], matched_tags: List[str]) -> Dict[str, float]:
|
||||
def _calculate_keyword_match_bonus(self, keywords: list[str], matched_tags: list[str]) -> dict[str, float]:
|
||||
"""计算关键词直接匹配奖励"""
|
||||
if not keywords or not matched_tags:
|
||||
return {}
|
||||
@@ -610,7 +611,7 @@ class BotInterestManager:
|
||||
|
||||
return previous_row[-1]
|
||||
|
||||
def _calculate_cosine_similarity(self, vec1: List[float], vec2: List[float]) -> float:
|
||||
def _calculate_cosine_similarity(self, vec1: list[float], vec2: list[float]) -> float:
|
||||
"""计算余弦相似度"""
|
||||
try:
|
||||
vec1 = np.array(vec1)
|
||||
@@ -629,16 +630,17 @@ class BotInterestManager:
|
||||
logger.error(f"计算余弦相似度失败: {e}")
|
||||
return 0.0
|
||||
|
||||
async def _load_interests_from_database(self, personality_id: str) -> Optional[BotPersonalityInterests]:
|
||||
async def _load_interests_from_database(self, personality_id: str) -> BotPersonalityInterests | None:
|
||||
"""从数据库加载兴趣标签"""
|
||||
try:
|
||||
logger.debug(f"从数据库加载兴趣标签, personality_id: {personality_id}")
|
||||
|
||||
# 导入SQLAlchemy相关模块
|
||||
from src.common.database.sqlalchemy_models import BotPersonalityInterests as DBBotPersonalityInterests
|
||||
from src.common.database.sqlalchemy_database_api import get_db_session
|
||||
import orjson
|
||||
|
||||
from src.common.database.sqlalchemy_database_api import get_db_session
|
||||
from src.common.database.sqlalchemy_models import BotPersonalityInterests as DBBotPersonalityInterests
|
||||
|
||||
async with get_db_session() as session:
|
||||
# 查询最新的兴趣标签配置
|
||||
db_interests = (
|
||||
@@ -716,10 +718,11 @@ class BotInterestManager:
|
||||
logger.info(f"🔄 版本: {interests.version}")
|
||||
|
||||
# 导入SQLAlchemy相关模块
|
||||
from src.common.database.sqlalchemy_models import BotPersonalityInterests as DBBotPersonalityInterests
|
||||
from src.common.database.sqlalchemy_database_api import get_db_session
|
||||
import orjson
|
||||
|
||||
from src.common.database.sqlalchemy_database_api import get_db_session
|
||||
from src.common.database.sqlalchemy_models import BotPersonalityInterests as DBBotPersonalityInterests
|
||||
|
||||
# 将兴趣标签转换为JSON格式
|
||||
tags_data = []
|
||||
for tag in interests.interest_tags:
|
||||
@@ -803,11 +806,11 @@ class BotInterestManager:
|
||||
logger.error("🔍 错误详情:")
|
||||
traceback.print_exc()
|
||||
|
||||
def get_current_interests(self) -> Optional[BotPersonalityInterests]:
|
||||
def get_current_interests(self) -> BotPersonalityInterests | None:
|
||||
"""获取当前的兴趣标签配置"""
|
||||
return self.current_interests
|
||||
|
||||
def get_interest_stats(self) -> Dict[str, Any]:
|
||||
def get_interest_stats(self) -> dict[str, Any]:
|
||||
"""获取兴趣系统统计信息"""
|
||||
if not self.current_interests:
|
||||
return {"initialized": False}
|
||||
|
||||
@@ -1,33 +1,31 @@
|
||||
from dataclasses import dataclass
|
||||
import orjson
|
||||
import os
|
||||
import math
|
||||
import asyncio
|
||||
import math
|
||||
import os
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from dataclasses import dataclass
|
||||
|
||||
# import tqdm
|
||||
import faiss
|
||||
|
||||
from .utils.hash import get_sha256
|
||||
from .global_logger import logger
|
||||
from rich.traceback import install
|
||||
import numpy as np
|
||||
import orjson
|
||||
import pandas as pd
|
||||
from rich.progress import (
|
||||
Progress,
|
||||
BarColumn,
|
||||
MofNCompleteColumn,
|
||||
Progress,
|
||||
SpinnerColumn,
|
||||
TaskProgressColumn,
|
||||
TextColumn,
|
||||
TimeElapsedColumn,
|
||||
TimeRemainingColumn,
|
||||
TaskProgressColumn,
|
||||
MofNCompleteColumn,
|
||||
SpinnerColumn,
|
||||
TextColumn,
|
||||
)
|
||||
from src.config.config import global_config
|
||||
from src.common.config_helpers import resolve_embedding_dimension
|
||||
from rich.traceback import install
|
||||
|
||||
from src.common.config_helpers import resolve_embedding_dimension
|
||||
from src.config.config import global_config
|
||||
|
||||
from .global_logger import logger
|
||||
from .utils.hash import get_sha256
|
||||
|
||||
install(extra_lines=3)
|
||||
|
||||
@@ -79,7 +77,7 @@ def cosine_similarity(a, b):
|
||||
class EmbeddingStoreItem:
|
||||
"""嵌入库中的项"""
|
||||
|
||||
def __init__(self, item_hash: str, embedding: List[float], content: str):
|
||||
def __init__(self, item_hash: str, embedding: list[float], content: str):
|
||||
self.hash = item_hash
|
||||
self.embedding = embedding
|
||||
self.str = content
|
||||
@@ -127,7 +125,7 @@ class EmbeddingStore:
|
||||
self.idx2hash = None
|
||||
|
||||
@staticmethod
|
||||
def _get_embedding(s: str) -> List[float]:
|
||||
def _get_embedding(s: str) -> list[float]:
|
||||
"""获取字符串的嵌入向量,使用完全同步的方式避免事件循环问题"""
|
||||
# 创建新的事件循环并在完成后立即关闭
|
||||
loop = asyncio.new_event_loop()
|
||||
@@ -135,8 +133,8 @@ class EmbeddingStore:
|
||||
|
||||
try:
|
||||
# 创建新的LLMRequest实例
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import model_config
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
|
||||
llm = LLMRequest(model_set=model_config.model_task_config.embedding, request_type="embedding")
|
||||
|
||||
@@ -161,8 +159,8 @@ class EmbeddingStore:
|
||||
|
||||
@staticmethod
|
||||
def _get_embeddings_batch_threaded(
|
||||
strs: List[str], chunk_size: int = 10, max_workers: int = 10, progress_callback=None
|
||||
) -> List[Tuple[str, List[float]]]:
|
||||
strs: list[str], chunk_size: int = 10, max_workers: int = 10, progress_callback=None
|
||||
) -> list[tuple[str, list[float]]]:
|
||||
"""使用多线程批量获取嵌入向量
|
||||
|
||||
Args:
|
||||
@@ -192,8 +190,8 @@ class EmbeddingStore:
|
||||
chunk_results = []
|
||||
|
||||
# 为每个线程创建独立的LLMRequest实例
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import model_config
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
|
||||
try:
|
||||
# 创建线程专用的LLM实例
|
||||
@@ -303,7 +301,7 @@ class EmbeddingStore:
|
||||
path = self.get_test_file_path()
|
||||
if not os.path.exists(path):
|
||||
return None
|
||||
with open(path, "r", encoding="utf-8") as f:
|
||||
with open(path, encoding="utf-8") as f:
|
||||
return orjson.loads(f.read())
|
||||
|
||||
def check_embedding_model_consistency(self):
|
||||
@@ -345,7 +343,7 @@ class EmbeddingStore:
|
||||
logger.info("嵌入模型一致性校验通过。")
|
||||
return True
|
||||
|
||||
def batch_insert_strs(self, strs: List[str], times: int) -> None:
|
||||
def batch_insert_strs(self, strs: list[str], times: int) -> None:
|
||||
"""向库中存入字符串(使用多线程优化)"""
|
||||
if not strs:
|
||||
return
|
||||
@@ -481,7 +479,7 @@ class EmbeddingStore:
|
||||
if os.path.exists(self.idx2hash_file_path):
|
||||
logger.info(f"正在加载{self.namespace}嵌入库的idx2hash映射...")
|
||||
logger.debug(f"正在从文件{self.idx2hash_file_path}中加载{self.namespace}嵌入库的idx2hash映射")
|
||||
with open(self.idx2hash_file_path, "r") as f:
|
||||
with open(self.idx2hash_file_path) as f:
|
||||
self.idx2hash = orjson.loads(f.read())
|
||||
logger.info(f"{self.namespace}嵌入库的idx2hash映射加载成功")
|
||||
else:
|
||||
@@ -511,7 +509,7 @@ class EmbeddingStore:
|
||||
self.faiss_index = faiss.IndexFlatIP(embedding_dim)
|
||||
self.faiss_index.add(embeddings)
|
||||
|
||||
def search_top_k(self, query: List[float], k: int) -> List[Tuple[str, float]]:
|
||||
def search_top_k(self, query: list[float], k: int) -> list[tuple[str, float]]:
|
||||
"""搜索最相似的k个项,以余弦相似度为度量
|
||||
Args:
|
||||
query: 查询的embedding
|
||||
@@ -575,11 +573,11 @@ class EmbeddingManager:
|
||||
"""对所有嵌入库做模型一致性校验"""
|
||||
return self.paragraphs_embedding_store.check_embedding_model_consistency()
|
||||
|
||||
def _store_pg_into_embedding(self, raw_paragraphs: Dict[str, str]):
|
||||
def _store_pg_into_embedding(self, raw_paragraphs: dict[str, str]):
|
||||
"""将段落编码存入Embedding库"""
|
||||
self.paragraphs_embedding_store.batch_insert_strs(list(raw_paragraphs.values()), times=1)
|
||||
|
||||
def _store_ent_into_embedding(self, triple_list_data: Dict[str, List[List[str]]]):
|
||||
def _store_ent_into_embedding(self, triple_list_data: dict[str, list[list[str]]]):
|
||||
"""将实体编码存入Embedding库"""
|
||||
entities = set()
|
||||
for triple_list in triple_list_data.values():
|
||||
@@ -588,7 +586,7 @@ class EmbeddingManager:
|
||||
entities.add(triple[2])
|
||||
self.entities_embedding_store.batch_insert_strs(list(entities), times=2)
|
||||
|
||||
def _store_rel_into_embedding(self, triple_list_data: Dict[str, List[List[str]]]):
|
||||
def _store_rel_into_embedding(self, triple_list_data: dict[str, list[list[str]]]):
|
||||
"""将关系编码存入Embedding库"""
|
||||
graph_triples = [] # a list of unique relation triple (in tuple) from all chunks
|
||||
for triples in triple_list_data.values():
|
||||
@@ -606,8 +604,8 @@ class EmbeddingManager:
|
||||
|
||||
def store_new_data_set(
|
||||
self,
|
||||
raw_paragraphs: Dict[str, str],
|
||||
triple_list_data: Dict[str, List[List[str]]],
|
||||
raw_paragraphs: dict[str, str],
|
||||
triple_list_data: dict[str, list[list[str]]],
|
||||
):
|
||||
if not self.check_all_embedding_model_consistency():
|
||||
raise Exception("嵌入模型与本地存储不一致,请检查模型设置或清空嵌入库后重试。")
|
||||
|
||||
@@ -1,14 +1,15 @@
|
||||
import asyncio
|
||||
import orjson
|
||||
import time
|
||||
from typing import List, Union
|
||||
|
||||
from .global_logger import logger
|
||||
from . import prompt_template
|
||||
from .knowledge_lib import INVALID_ENTITY
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
import orjson
|
||||
from json_repair import repair_json
|
||||
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
|
||||
from . import prompt_template
|
||||
from .global_logger import logger
|
||||
from .knowledge_lib import INVALID_ENTITY
|
||||
|
||||
|
||||
def _extract_json_from_text(text: str):
|
||||
# sourcery skip: assign-if-exp, extract-method
|
||||
@@ -46,7 +47,7 @@ def _extract_json_from_text(text: str):
|
||||
return []
|
||||
|
||||
|
||||
def _entity_extract(llm_req: LLMRequest, paragraph: str) -> List[str]:
|
||||
def _entity_extract(llm_req: LLMRequest, paragraph: str) -> list[str]:
|
||||
# sourcery skip: reintroduce-else, swap-if-else-branches, use-named-expression
|
||||
"""对段落进行实体提取,返回提取出的实体列表(JSON格式)"""
|
||||
entity_extract_context = prompt_template.build_entity_extract_context(paragraph)
|
||||
@@ -92,7 +93,7 @@ def _entity_extract(llm_req: LLMRequest, paragraph: str) -> List[str]:
|
||||
return entity_extract_result
|
||||
|
||||
|
||||
def _rdf_triple_extract(llm_req: LLMRequest, paragraph: str, entities: list) -> List[List[str]]:
|
||||
def _rdf_triple_extract(llm_req: LLMRequest, paragraph: str, entities: list) -> list[list[str]]:
|
||||
"""对段落进行实体提取,返回提取出的实体列表(JSON格式)"""
|
||||
rdf_extract_context = prompt_template.build_rdf_triple_extract_context(
|
||||
paragraph, entities=orjson.dumps(entities).decode("utf-8")
|
||||
@@ -141,7 +142,7 @@ def _rdf_triple_extract(llm_req: LLMRequest, paragraph: str, entities: list) ->
|
||||
|
||||
def info_extract_from_str(
|
||||
llm_client_for_ner: LLMRequest, llm_client_for_rdf: LLMRequest, paragraph: str
|
||||
) -> Union[tuple[None, None], tuple[list[str], list[list[str]]]]:
|
||||
) -> tuple[None, None] | tuple[list[str], list[list[str]]]:
|
||||
try_count = 0
|
||||
while True:
|
||||
try:
|
||||
|
||||
@@ -1,28 +1,26 @@
|
||||
import orjson
|
||||
import os
|
||||
import time
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
import numpy as np
|
||||
import orjson
|
||||
import pandas as pd
|
||||
from quick_algo import di_graph, pagerank
|
||||
from rich.progress import (
|
||||
Progress,
|
||||
BarColumn,
|
||||
MofNCompleteColumn,
|
||||
Progress,
|
||||
SpinnerColumn,
|
||||
TaskProgressColumn,
|
||||
TextColumn,
|
||||
TimeElapsedColumn,
|
||||
TimeRemainingColumn,
|
||||
TaskProgressColumn,
|
||||
MofNCompleteColumn,
|
||||
SpinnerColumn,
|
||||
TextColumn,
|
||||
)
|
||||
from quick_algo import di_graph, pagerank
|
||||
|
||||
|
||||
from .utils.hash import get_sha256
|
||||
from .embedding_store import EmbeddingManager, EmbeddingStoreItem
|
||||
from src.config.config import global_config
|
||||
|
||||
from .embedding_store import EmbeddingManager, EmbeddingStoreItem
|
||||
from .global_logger import logger
|
||||
from .utils.hash import get_sha256
|
||||
|
||||
|
||||
def _get_kg_dir():
|
||||
@@ -87,7 +85,7 @@ class KGManager:
|
||||
raise FileNotFoundError(f"KG图文件{self.graph_data_path}不存在")
|
||||
|
||||
# 加载段落hash
|
||||
with open(self.pg_hash_file_path, "r", encoding="utf-8") as f:
|
||||
with open(self.pg_hash_file_path, encoding="utf-8") as f:
|
||||
data = orjson.loads(f.read())
|
||||
self.stored_paragraph_hashes = set(data["stored_paragraph_hashes"])
|
||||
|
||||
@@ -100,8 +98,8 @@ class KGManager:
|
||||
|
||||
def _build_edges_between_ent(
|
||||
self,
|
||||
node_to_node: Dict[Tuple[str, str], float],
|
||||
triple_list_data: Dict[str, List[List[str]]],
|
||||
node_to_node: dict[tuple[str, str], float],
|
||||
triple_list_data: dict[str, list[list[str]]],
|
||||
):
|
||||
"""构建实体节点之间的关系,同时统计实体出现次数"""
|
||||
for triple_list in triple_list_data.values():
|
||||
@@ -124,8 +122,8 @@ class KGManager:
|
||||
|
||||
@staticmethod
|
||||
def _build_edges_between_ent_pg(
|
||||
node_to_node: Dict[Tuple[str, str], float],
|
||||
triple_list_data: Dict[str, List[List[str]]],
|
||||
node_to_node: dict[tuple[str, str], float],
|
||||
triple_list_data: dict[str, list[list[str]]],
|
||||
):
|
||||
"""构建实体节点与文段节点之间的关系"""
|
||||
for idx in triple_list_data:
|
||||
@@ -136,8 +134,8 @@ class KGManager:
|
||||
|
||||
@staticmethod
|
||||
def _synonym_connect(
|
||||
node_to_node: Dict[Tuple[str, str], float],
|
||||
triple_list_data: Dict[str, List[List[str]]],
|
||||
node_to_node: dict[tuple[str, str], float],
|
||||
triple_list_data: dict[str, list[list[str]]],
|
||||
embedding_manager: EmbeddingManager,
|
||||
) -> int:
|
||||
"""同义词连接"""
|
||||
@@ -208,7 +206,7 @@ class KGManager:
|
||||
|
||||
def _update_graph(
|
||||
self,
|
||||
node_to_node: Dict[Tuple[str, str], float],
|
||||
node_to_node: dict[tuple[str, str], float],
|
||||
embedding_manager: EmbeddingManager,
|
||||
):
|
||||
"""更新KG图结构
|
||||
@@ -280,7 +278,7 @@ class KGManager:
|
||||
|
||||
def build_kg(
|
||||
self,
|
||||
triple_list_data: Dict[str, List[List[str]]],
|
||||
triple_list_data: dict[str, list[list[str]]],
|
||||
embedding_manager: EmbeddingManager,
|
||||
):
|
||||
"""增量式构建KG
|
||||
@@ -317,8 +315,8 @@ class KGManager:
|
||||
|
||||
def kg_search(
|
||||
self,
|
||||
relation_search_result: List[Tuple[Tuple[str, str, str], float]],
|
||||
paragraph_search_result: List[Tuple[str, float]],
|
||||
relation_search_result: list[tuple[tuple[str, str, str], float]],
|
||||
paragraph_search_result: list[tuple[str, float]],
|
||||
embed_manager: EmbeddingManager,
|
||||
):
|
||||
"""RAG搜索与PageRank
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
from src.chat.knowledge.embedding_store import EmbeddingManager
|
||||
from src.chat.knowledge.qa_manager import QAManager
|
||||
from src.chat.knowledge.kg_manager import KGManager
|
||||
from src.chat.knowledge.global_logger import logger
|
||||
from src.config.config import global_config
|
||||
import os
|
||||
|
||||
from src.chat.knowledge.embedding_store import EmbeddingManager
|
||||
from src.chat.knowledge.global_logger import logger
|
||||
from src.chat.knowledge.kg_manager import KGManager
|
||||
from src.chat.knowledge.qa_manager import QAManager
|
||||
from src.config.config import global_config
|
||||
|
||||
INVALID_ENTITY = [
|
||||
"",
|
||||
"你",
|
||||
|
||||
@@ -1,14 +1,15 @@
|
||||
import orjson
|
||||
import os
|
||||
import glob
|
||||
from typing import Any, Dict, List
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
import orjson
|
||||
|
||||
from .knowledge_lib import DATA_PATH, INVALID_ENTITY, ROOT_PATH
|
||||
|
||||
from .knowledge_lib import INVALID_ENTITY, ROOT_PATH, DATA_PATH
|
||||
# from src.manager.local_store_manager import local_storage
|
||||
|
||||
|
||||
def _filter_invalid_entities(entities: List[str]) -> List[str]:
|
||||
def _filter_invalid_entities(entities: list[str]) -> list[str]:
|
||||
"""过滤无效的实体"""
|
||||
valid_entities = set()
|
||||
for entity in entities:
|
||||
@@ -20,7 +21,7 @@ def _filter_invalid_entities(entities: List[str]) -> List[str]:
|
||||
return list(valid_entities)
|
||||
|
||||
|
||||
def _filter_invalid_triples(triples: List[List[str]]) -> List[List[str]]:
|
||||
def _filter_invalid_triples(triples: list[list[str]]) -> list[list[str]]:
|
||||
"""过滤无效的三元组"""
|
||||
unique_triples = set()
|
||||
valid_triples = []
|
||||
@@ -62,7 +63,7 @@ class OpenIE:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
docs: List[Dict[str, Any]],
|
||||
docs: list[dict[str, Any]],
|
||||
avg_ent_chars,
|
||||
avg_ent_words,
|
||||
):
|
||||
@@ -112,7 +113,7 @@ class OpenIE:
|
||||
json_files = sorted(glob.glob(os.path.join(openie_dir, "*.json")))
|
||||
data_list = []
|
||||
for file in json_files:
|
||||
with open(file, "r", encoding="utf-8") as f:
|
||||
with open(file, encoding="utf-8") as f:
|
||||
data = orjson.loads(f.read())
|
||||
data_list.append(data)
|
||||
if not data_list:
|
||||
|
||||
@@ -1,15 +1,16 @@
|
||||
import time
|
||||
from typing import Tuple, List, Dict, Optional, Any
|
||||
from typing import Any
|
||||
|
||||
from src.chat.utils.utils import get_embedding
|
||||
from src.config.config import global_config, model_config
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
|
||||
from .global_logger import logger
|
||||
from .embedding_store import EmbeddingManager
|
||||
from .global_logger import logger
|
||||
from .kg_manager import KGManager
|
||||
|
||||
# from .lpmmconfig import global_config
|
||||
from .utils.dyn_topk import dyn_select_top_k
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.chat.utils.utils import get_embedding
|
||||
from src.config.config import global_config, model_config
|
||||
|
||||
MAX_KNOWLEDGE_LENGTH = 10000 # 最大知识长度
|
||||
|
||||
@@ -26,7 +27,7 @@ class QAManager:
|
||||
|
||||
async def process_query(
|
||||
self, question: str
|
||||
) -> Optional[Tuple[List[Tuple[str, float, float]], Optional[Dict[str, float]]]]:
|
||||
) -> tuple[list[tuple[str, float, float]], dict[str, float] | None] | None:
|
||||
"""处理查询"""
|
||||
|
||||
# 生成问题的Embedding
|
||||
@@ -98,7 +99,7 @@ class QAManager:
|
||||
|
||||
return result, ppr_node_weights
|
||||
|
||||
async def get_knowledge(self, question: str) -> Optional[Dict[str, Any]]:
|
||||
async def get_knowledge(self, question: str) -> dict[str, Any] | None:
|
||||
"""
|
||||
获取知识,返回结构化字典
|
||||
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
from typing import List, Any, Tuple
|
||||
from typing import Any
|
||||
|
||||
|
||||
def dyn_select_top_k(
|
||||
score: List[Tuple[Any, float]], jmp_factor: float, var_factor: float
|
||||
) -> List[Tuple[Any, float, float]]:
|
||||
score: list[tuple[Any, float]], jmp_factor: float, var_factor: float
|
||||
) -> list[tuple[Any, float, float]]:
|
||||
"""动态TopK选择"""
|
||||
# 检查输入列表是否为空
|
||||
if not score:
|
||||
|
||||
@@ -1,37 +1,35 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
简化记忆系统模块
|
||||
移除即时记忆和长期记忆分类,实现统一记忆架构和智能遗忘机制
|
||||
"""
|
||||
|
||||
# 核心数据结构
|
||||
# 激活器
|
||||
from .enhanced_memory_activator import MemoryActivator, enhanced_memory_activator, memory_activator
|
||||
from .memory_chunk import (
|
||||
ConfidenceLevel,
|
||||
ContentStructure,
|
||||
ImportanceLevel,
|
||||
MemoryChunk,
|
||||
MemoryMetadata,
|
||||
ContentStructure,
|
||||
MemoryType,
|
||||
ImportanceLevel,
|
||||
ConfidenceLevel,
|
||||
create_memory_chunk,
|
||||
)
|
||||
|
||||
# 兼容性别名
|
||||
from .memory_chunk import MemoryChunk as Memory
|
||||
|
||||
# 遗忘引擎
|
||||
from .memory_forgetting_engine import MemoryForgettingEngine, ForgettingConfig, get_memory_forgetting_engine
|
||||
|
||||
# Vector DB存储系统
|
||||
from .vector_memory_storage_v2 import VectorMemoryStorage, VectorStorageConfig, get_vector_memory_storage
|
||||
|
||||
# 记忆核心系统
|
||||
from .memory_system import MemorySystem, MemorySystemConfig, get_memory_system, initialize_memory_system
|
||||
from .memory_forgetting_engine import ForgettingConfig, MemoryForgettingEngine, get_memory_forgetting_engine
|
||||
|
||||
# 记忆管理器
|
||||
from .memory_manager import MemoryManager, MemoryResult, memory_manager
|
||||
|
||||
# 激活器
|
||||
from .enhanced_memory_activator import MemoryActivator, memory_activator, enhanced_memory_activator
|
||||
# 记忆核心系统
|
||||
from .memory_system import MemorySystem, MemorySystemConfig, get_memory_system, initialize_memory_system
|
||||
|
||||
# 兼容性别名
|
||||
from .memory_chunk import MemoryChunk as Memory
|
||||
# Vector DB存储系统
|
||||
from .vector_memory_storage_v2 import VectorMemoryStorage, VectorStorageConfig, get_vector_memory_storage
|
||||
|
||||
__all__ = [
|
||||
# 核心数据结构
|
||||
|
||||
@@ -1,17 +1,17 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
增强记忆系统适配器
|
||||
将增强记忆系统集成到现有MoFox Bot架构中
|
||||
"""
|
||||
|
||||
import time
|
||||
from typing import Dict, List, Optional, Any
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.memory_system.integration_layer import MemoryIntegrationLayer, IntegrationConfig, IntegrationMode
|
||||
from src.chat.memory_system.memory_chunk import MemoryChunk, MemoryType
|
||||
from src.chat.memory_system.integration_layer import IntegrationConfig, IntegrationMode, MemoryIntegrationLayer
|
||||
from src.chat.memory_system.memory_formatter import FormatterConfig, format_memories_for_llm
|
||||
|
||||
from src.chat.memory_system.memory_chunk import MemoryChunk, MemoryType
|
||||
from src.common.logger import get_logger
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
|
||||
logger = get_logger(__name__)
|
||||
@@ -47,10 +47,10 @@ class AdapterConfig:
|
||||
class EnhancedMemoryAdapter:
|
||||
"""增强记忆系统适配器"""
|
||||
|
||||
def __init__(self, llm_model: LLMRequest, config: Optional[AdapterConfig] = None):
|
||||
def __init__(self, llm_model: LLMRequest, config: AdapterConfig | None = None):
|
||||
self.llm_model = llm_model
|
||||
self.config = config or AdapterConfig()
|
||||
self.integration_layer: Optional[MemoryIntegrationLayer] = None
|
||||
self.integration_layer: MemoryIntegrationLayer | None = None
|
||||
self._initialized = False
|
||||
|
||||
# 统计信息
|
||||
@@ -96,7 +96,7 @@ class EnhancedMemoryAdapter:
|
||||
# 如果初始化失败,禁用增强记忆功能
|
||||
self.config.enable_enhanced_memory = False
|
||||
|
||||
async def process_conversation_memory(self, context: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
|
||||
async def process_conversation_memory(self, context: dict[str, Any] | None = None) -> dict[str, Any]:
|
||||
"""处理对话记忆,以上下文为唯一输入"""
|
||||
if not self._initialized or not self.config.enable_enhanced_memory:
|
||||
return {"success": False, "error": "Enhanced memory not available"}
|
||||
@@ -105,7 +105,7 @@ class EnhancedMemoryAdapter:
|
||||
self.adapter_stats["total_processed"] += 1
|
||||
|
||||
try:
|
||||
payload_context: Dict[str, Any] = dict(context or {})
|
||||
payload_context: dict[str, Any] = dict(context or {})
|
||||
|
||||
conversation_text = payload_context.get("conversation_text")
|
||||
if not conversation_text:
|
||||
@@ -146,8 +146,8 @@ class EnhancedMemoryAdapter:
|
||||
return {"success": False, "error": str(e)}
|
||||
|
||||
async def retrieve_relevant_memories(
|
||||
self, query: str, user_id: str, context: Optional[Dict[str, Any]] = None, limit: Optional[int] = None
|
||||
) -> List[MemoryChunk]:
|
||||
self, query: str, user_id: str, context: dict[str, Any] | None = None, limit: int | None = None
|
||||
) -> list[MemoryChunk]:
|
||||
"""检索相关记忆"""
|
||||
if not self._initialized or not self.config.enable_enhanced_memory:
|
||||
return []
|
||||
@@ -166,7 +166,7 @@ class EnhancedMemoryAdapter:
|
||||
return []
|
||||
|
||||
async def get_memory_context_for_prompt(
|
||||
self, query: str, user_id: str, context: Optional[Dict[str, Any]] = None, max_memories: int = 5
|
||||
self, query: str, user_id: str, context: dict[str, Any] | None = None, max_memories: int = 5
|
||||
) -> str:
|
||||
"""获取用于提示词的记忆上下文"""
|
||||
memories = await self.retrieve_relevant_memories(query, user_id, context, max_memories)
|
||||
@@ -186,7 +186,7 @@ class EnhancedMemoryAdapter:
|
||||
|
||||
return format_memories_for_llm(memories=memories, query_context=query, config=formatter_config)
|
||||
|
||||
async def get_enhanced_memory_summary(self, user_id: str) -> Dict[str, Any]:
|
||||
async def get_enhanced_memory_summary(self, user_id: str) -> dict[str, Any]:
|
||||
"""获取增强记忆系统摘要"""
|
||||
if not self._initialized or not self.config.enable_enhanced_memory:
|
||||
return {"available": False, "reason": "Not initialized or disabled"}
|
||||
@@ -222,7 +222,7 @@ class EnhancedMemoryAdapter:
|
||||
new_avg = (current_avg * (total_processed - 1) + processing_time) / total_processed
|
||||
self.adapter_stats["average_processing_time"] = new_avg
|
||||
|
||||
def get_adapter_stats(self) -> Dict[str, Any]:
|
||||
def get_adapter_stats(self) -> dict[str, Any]:
|
||||
"""获取适配器统计信息"""
|
||||
return self.adapter_stats.copy()
|
||||
|
||||
@@ -253,7 +253,7 @@ class EnhancedMemoryAdapter:
|
||||
|
||||
|
||||
# 全局适配器实例
|
||||
_enhanced_memory_adapter: Optional[EnhancedMemoryAdapter] = None
|
||||
_enhanced_memory_adapter: EnhancedMemoryAdapter | None = None
|
||||
|
||||
|
||||
async def get_enhanced_memory_adapter(llm_model: LLMRequest) -> EnhancedMemoryAdapter:
|
||||
@@ -292,8 +292,8 @@ async def initialize_enhanced_memory_system(llm_model: LLMRequest):
|
||||
|
||||
|
||||
async def process_conversation_with_enhanced_memory(
|
||||
context: Dict[str, Any], llm_model: Optional[LLMRequest] = None
|
||||
) -> Dict[str, Any]:
|
||||
context: dict[str, Any], llm_model: LLMRequest | None = None
|
||||
) -> dict[str, Any]:
|
||||
"""使用增强记忆系统处理对话,上下文需包含 conversation_text 等信息"""
|
||||
if not llm_model:
|
||||
# 获取默认的LLM模型
|
||||
@@ -323,10 +323,10 @@ async def process_conversation_with_enhanced_memory(
|
||||
async def retrieve_memories_with_enhanced_system(
|
||||
query: str,
|
||||
user_id: str,
|
||||
context: Optional[Dict[str, Any]] = None,
|
||||
context: dict[str, Any] | None = None,
|
||||
limit: int = 10,
|
||||
llm_model: Optional[LLMRequest] = None,
|
||||
) -> List[MemoryChunk]:
|
||||
llm_model: LLMRequest | None = None,
|
||||
) -> list[MemoryChunk]:
|
||||
"""使用增强记忆系统检索记忆"""
|
||||
if not llm_model:
|
||||
# 获取默认的LLM模型
|
||||
@@ -345,9 +345,9 @@ async def retrieve_memories_with_enhanced_system(
|
||||
async def get_memory_context_for_prompt(
|
||||
query: str,
|
||||
user_id: str,
|
||||
context: Optional[Dict[str, Any]] = None,
|
||||
context: dict[str, Any] | None = None,
|
||||
max_memories: int = 5,
|
||||
llm_model: Optional[LLMRequest] = None,
|
||||
llm_model: LLMRequest | None = None,
|
||||
) -> str:
|
||||
"""获取用于提示词的记忆上下文"""
|
||||
if not llm_model:
|
||||
|
||||
@@ -1,15 +1,15 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
增强记忆系统钩子
|
||||
用于在消息处理过程中自动构建和检索记忆
|
||||
"""
|
||||
|
||||
from typing import Dict, List, Any, Optional
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from src.chat.memory_system.enhanced_memory_manager import enhanced_memory_manager
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.chat.memory_system.enhanced_memory_manager import enhanced_memory_manager
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -27,7 +27,7 @@ class EnhancedMemoryHooks:
|
||||
user_id: str,
|
||||
chat_id: str,
|
||||
message_id: str,
|
||||
context: Optional[Dict[str, Any]] = None,
|
||||
context: dict[str, Any] | None = None,
|
||||
) -> bool:
|
||||
"""
|
||||
处理消息并构建记忆
|
||||
@@ -106,8 +106,8 @@ class EnhancedMemoryHooks:
|
||||
user_id: str,
|
||||
chat_id: str,
|
||||
limit: int = 5,
|
||||
extra_context: Optional[Dict[str, Any]] = None,
|
||||
) -> List[Dict[str, Any]]:
|
||||
extra_context: dict[str, Any] | None = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""
|
||||
为回复获取相关记忆
|
||||
|
||||
|
||||
@@ -1,19 +1,19 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
增强记忆系统集成脚本
|
||||
用于在现有系统中无缝集成增强记忆功能
|
||||
"""
|
||||
|
||||
from typing import Dict, Any, Optional
|
||||
from typing import Any
|
||||
|
||||
from src.chat.memory_system.enhanced_memory_hooks import enhanced_memory_hooks
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.memory_system.enhanced_memory_hooks import enhanced_memory_hooks
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
async def process_user_message_memory(
|
||||
message_content: str, user_id: str, chat_id: str, message_id: str, context: Optional[Dict[str, Any]] = None
|
||||
message_content: str, user_id: str, chat_id: str, message_id: str, context: dict[str, Any] | None = None
|
||||
) -> bool:
|
||||
"""
|
||||
处理用户消息并构建记忆
|
||||
@@ -44,8 +44,8 @@ async def process_user_message_memory(
|
||||
|
||||
|
||||
async def get_relevant_memories_for_response(
|
||||
query_text: str, user_id: str, chat_id: str, limit: int = 5, extra_context: Optional[Dict[str, Any]] = None
|
||||
) -> Dict[str, Any]:
|
||||
query_text: str, user_id: str, chat_id: str, limit: int = 5, extra_context: dict[str, Any] | None = None
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
为回复获取相关记忆
|
||||
|
||||
@@ -74,7 +74,7 @@ async def get_relevant_memories_for_response(
|
||||
return {"has_memories": False, "memories": [], "memory_count": 0}
|
||||
|
||||
|
||||
def format_memories_for_prompt(memories: Dict[str, Any]) -> str:
|
||||
def format_memories_for_prompt(memories: dict[str, Any]) -> str:
|
||||
"""
|
||||
格式化记忆信息用于Prompt
|
||||
|
||||
@@ -114,7 +114,7 @@ async def cleanup_memory_system():
|
||||
logger.error(f"记忆系统清理失败: {e}")
|
||||
|
||||
|
||||
def get_memory_system_status() -> Dict[str, Any]:
|
||||
def get_memory_system_status() -> dict[str, Any]:
|
||||
"""
|
||||
获取记忆系统状态
|
||||
|
||||
@@ -133,7 +133,7 @@ def get_memory_system_status() -> Dict[str, Any]:
|
||||
|
||||
# 便捷函数
|
||||
async def remember_message(
|
||||
message: str, user_id: str = "default_user", chat_id: str = "default_chat", context: Optional[Dict[str, Any]] = None
|
||||
message: str, user_id: str = "default_user", chat_id: str = "default_chat", context: dict[str, Any] | None = None
|
||||
) -> bool:
|
||||
"""
|
||||
便捷的记忆构建函数
|
||||
@@ -159,8 +159,8 @@ async def recall_memories(
|
||||
user_id: str = "default_user",
|
||||
chat_id: str = "default_chat",
|
||||
limit: int = 5,
|
||||
context: Optional[Dict[str, Any]] = None,
|
||||
) -> Dict[str, Any]:
|
||||
context: dict[str, Any] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
便捷的记忆检索函数
|
||||
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
增强重排序器
|
||||
实现文档设计的多维度评分模型
|
||||
@@ -6,12 +5,12 @@
|
||||
|
||||
import math
|
||||
import time
|
||||
from typing import Dict, List, Optional, Any, Tuple
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.memory_system.memory_chunk import MemoryChunk, MemoryType
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -44,7 +43,7 @@ class ReRankingConfig:
|
||||
freq_max_score: float = 5.0 # 最大频率得分
|
||||
|
||||
# 类型匹配权重映射
|
||||
type_match_weights: Dict[str, Dict[str, float]] = None
|
||||
type_match_weights: dict[str, dict[str, float]] = None
|
||||
|
||||
def __post_init__(self):
|
||||
"""初始化类型匹配权重"""
|
||||
@@ -157,7 +156,7 @@ class IntentClassifier:
|
||||
],
|
||||
}
|
||||
|
||||
def classify_intent(self, query: str, context: Dict[str, Any]) -> IntentType:
|
||||
def classify_intent(self, query: str, context: dict[str, Any]) -> IntentType:
|
||||
"""识别对话意图"""
|
||||
if not query:
|
||||
return IntentType.UNKNOWN
|
||||
@@ -165,7 +164,7 @@ class IntentClassifier:
|
||||
query_lower = query.lower()
|
||||
|
||||
# 统计各意图的匹配分数
|
||||
intent_scores = {intent: 0 for intent in IntentType}
|
||||
intent_scores = dict.fromkeys(IntentType, 0)
|
||||
|
||||
for intent, patterns in self.patterns.items():
|
||||
for pattern in patterns:
|
||||
@@ -187,7 +186,7 @@ class IntentClassifier:
|
||||
class EnhancedReRanker:
|
||||
"""增强重排序器 - 实现文档设计的多维度评分模型"""
|
||||
|
||||
def __init__(self, config: Optional[ReRankingConfig] = None):
|
||||
def __init__(self, config: ReRankingConfig | None = None):
|
||||
self.config = config or ReRankingConfig()
|
||||
self.intent_classifier = IntentClassifier()
|
||||
|
||||
@@ -210,10 +209,10 @@ class EnhancedReRanker:
|
||||
def rerank_memories(
|
||||
self,
|
||||
query: str,
|
||||
candidate_memories: List[Tuple[str, MemoryChunk, float]], # (memory_id, memory, vector_similarity)
|
||||
context: Dict[str, Any],
|
||||
candidate_memories: list[tuple[str, MemoryChunk, float]], # (memory_id, memory, vector_similarity)
|
||||
context: dict[str, Any],
|
||||
limit: int = 10,
|
||||
) -> List[Tuple[str, MemoryChunk, float]]:
|
||||
) -> list[tuple[str, MemoryChunk, float]]:
|
||||
"""
|
||||
对候选记忆进行重排序
|
||||
|
||||
@@ -341,11 +340,11 @@ default_reranker = EnhancedReRanker()
|
||||
|
||||
def rerank_candidate_memories(
|
||||
query: str,
|
||||
candidate_memories: List[Tuple[str, MemoryChunk, float]],
|
||||
context: Dict[str, Any],
|
||||
candidate_memories: list[tuple[str, MemoryChunk, float]],
|
||||
context: dict[str, Any],
|
||||
limit: int = 10,
|
||||
config: Optional[ReRankingConfig] = None,
|
||||
) -> List[Tuple[str, MemoryChunk, float]]:
|
||||
config: ReRankingConfig | None = None,
|
||||
) -> list[tuple[str, MemoryChunk, float]]:
|
||||
"""
|
||||
便捷函数:对候选记忆进行重排序
|
||||
"""
|
||||
|
||||
@@ -1,18 +1,18 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
增强记忆系统集成层
|
||||
现在只管理新的增强记忆系统,旧系统已被完全移除
|
||||
"""
|
||||
|
||||
import time
|
||||
import asyncio
|
||||
from typing import Dict, List, Optional, Any
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.memory_system.enhanced_memory_core import EnhancedMemorySystem
|
||||
|
||||
from src.chat.memory_system.memory_chunk import MemoryChunk
|
||||
from src.common.logger import get_logger
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
|
||||
logger = get_logger(__name__)
|
||||
@@ -40,12 +40,12 @@ class IntegrationConfig:
|
||||
class MemoryIntegrationLayer:
|
||||
"""记忆系统集成层 - 现在只管理增强记忆系统"""
|
||||
|
||||
def __init__(self, llm_model: LLMRequest, config: Optional[IntegrationConfig] = None):
|
||||
def __init__(self, llm_model: LLMRequest, config: IntegrationConfig | None = None):
|
||||
self.llm_model = llm_model
|
||||
self.config = config or IntegrationConfig()
|
||||
|
||||
# 只初始化增强记忆系统
|
||||
self.enhanced_memory: Optional[EnhancedMemorySystem] = None
|
||||
self.enhanced_memory: EnhancedMemorySystem | None = None
|
||||
|
||||
# 集成统计
|
||||
self.integration_stats = {
|
||||
@@ -113,7 +113,7 @@ class MemoryIntegrationLayer:
|
||||
logger.error(f"❌ 增强记忆系统初始化失败: {e}", exc_info=True)
|
||||
raise
|
||||
|
||||
async def process_conversation(self, context: Dict[str, Any]) -> Dict[str, Any]:
|
||||
async def process_conversation(self, context: dict[str, Any]) -> dict[str, Any]:
|
||||
"""处理对话记忆,仅使用上下文信息"""
|
||||
if not self._initialized or not self.enhanced_memory:
|
||||
return {"success": False, "error": "Memory system not available"}
|
||||
@@ -150,10 +150,10 @@ class MemoryIntegrationLayer:
|
||||
async def retrieve_relevant_memories(
|
||||
self,
|
||||
query: str,
|
||||
user_id: Optional[str] = None,
|
||||
context: Optional[Dict[str, Any]] = None,
|
||||
limit: Optional[int] = None,
|
||||
) -> List[MemoryChunk]:
|
||||
user_id: str | None = None,
|
||||
context: dict[str, Any] | None = None,
|
||||
limit: int | None = None,
|
||||
) -> list[MemoryChunk]:
|
||||
"""检索相关记忆"""
|
||||
if not self._initialized or not self.enhanced_memory:
|
||||
return []
|
||||
@@ -172,7 +172,7 @@ class MemoryIntegrationLayer:
|
||||
logger.error(f"检索相关记忆失败: {e}", exc_info=True)
|
||||
return []
|
||||
|
||||
async def get_system_status(self) -> Dict[str, Any]:
|
||||
async def get_system_status(self) -> dict[str, Any]:
|
||||
"""获取系统状态"""
|
||||
if not self._initialized:
|
||||
return {"status": "not_initialized"}
|
||||
@@ -193,7 +193,7 @@ class MemoryIntegrationLayer:
|
||||
logger.error(f"获取系统状态失败: {e}", exc_info=True)
|
||||
return {"status": "error", "error": str(e)}
|
||||
|
||||
def get_integration_stats(self) -> Dict[str, Any]:
|
||||
def get_integration_stats(self) -> dict[str, Any]:
|
||||
"""获取集成统计信息"""
|
||||
return self.integration_stats.copy()
|
||||
|
||||
|
||||
@@ -1,20 +1,20 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
记忆系统集成钩子
|
||||
提供与现有MoFox Bot系统的无缝集成点
|
||||
"""
|
||||
|
||||
import time
|
||||
from typing import Dict, Optional, Any
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.memory_system.enhanced_memory_adapter import (
|
||||
get_memory_context_for_prompt,
|
||||
process_conversation_with_enhanced_memory,
|
||||
retrieve_memories_with_enhanced_system,
|
||||
get_memory_context_for_prompt,
|
||||
)
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -24,7 +24,7 @@ class HookResult:
|
||||
|
||||
success: bool
|
||||
data: Any = None
|
||||
error: Optional[str] = None
|
||||
error: str | None = None
|
||||
processing_time: float = 0.0
|
||||
|
||||
|
||||
@@ -125,8 +125,8 @@ class MemoryIntegrationHooks:
|
||||
|
||||
# 尝试注册到事件系统
|
||||
try:
|
||||
from src.plugin_system.core.event_manager import event_manager
|
||||
from src.plugin_system.base.component_types import EventType
|
||||
from src.plugin_system.core.event_manager import event_manager
|
||||
|
||||
# 注册消息后处理事件
|
||||
event_manager.subscribe(EventType.MESSAGE_PROCESSED, self._on_message_processed_handler)
|
||||
@@ -238,11 +238,11 @@ class MemoryIntegrationHooks:
|
||||
|
||||
# 钩子处理器方法
|
||||
|
||||
async def _on_message_processed_handler(self, event_data: Dict[str, Any]) -> HookResult:
|
||||
async def _on_message_processed_handler(self, event_data: dict[str, Any]) -> HookResult:
|
||||
"""事件系统的消息处理处理器"""
|
||||
return await self._on_message_processed_hook(event_data)
|
||||
|
||||
async def _on_message_processed_hook(self, message_data: Dict[str, Any]) -> HookResult:
|
||||
async def _on_message_processed_hook(self, message_data: dict[str, Any]) -> HookResult:
|
||||
"""消息后处理钩子"""
|
||||
start_time = time.time()
|
||||
|
||||
@@ -289,7 +289,7 @@ class MemoryIntegrationHooks:
|
||||
logger.error(f"消息处理钩子执行异常: {e}", exc_info=True)
|
||||
return HookResult(success=False, error=str(e), processing_time=processing_time)
|
||||
|
||||
async def _on_chat_stream_save_hook(self, chat_stream_data: Dict[str, Any]) -> HookResult:
|
||||
async def _on_chat_stream_save_hook(self, chat_stream_data: dict[str, Any]) -> HookResult:
|
||||
"""聊天流保存钩子"""
|
||||
start_time = time.time()
|
||||
|
||||
@@ -345,7 +345,7 @@ class MemoryIntegrationHooks:
|
||||
logger.error(f"聊天流保存钩子执行异常: {e}", exc_info=True)
|
||||
return HookResult(success=False, error=str(e), processing_time=processing_time)
|
||||
|
||||
async def _on_pre_response_hook(self, response_data: Dict[str, Any]) -> HookResult:
|
||||
async def _on_pre_response_hook(self, response_data: dict[str, Any]) -> HookResult:
|
||||
"""回复前钩子"""
|
||||
start_time = time.time()
|
||||
|
||||
@@ -380,7 +380,7 @@ class MemoryIntegrationHooks:
|
||||
logger.error(f"回复前钩子执行异常: {e}", exc_info=True)
|
||||
return HookResult(success=False, error=str(e), processing_time=processing_time)
|
||||
|
||||
async def _on_knowledge_query_hook(self, query_data: Dict[str, Any]) -> HookResult:
|
||||
async def _on_knowledge_query_hook(self, query_data: dict[str, Any]) -> HookResult:
|
||||
"""知识库查询钩子"""
|
||||
start_time = time.time()
|
||||
|
||||
@@ -411,7 +411,7 @@ class MemoryIntegrationHooks:
|
||||
logger.error(f"知识库查询钩子执行异常: {e}", exc_info=True)
|
||||
return HookResult(success=False, error=str(e), processing_time=processing_time)
|
||||
|
||||
async def _on_prompt_building_hook(self, prompt_data: Dict[str, Any]) -> HookResult:
|
||||
async def _on_prompt_building_hook(self, prompt_data: dict[str, Any]) -> HookResult:
|
||||
"""提示词构建钩子"""
|
||||
start_time = time.time()
|
||||
|
||||
@@ -459,7 +459,7 @@ class MemoryIntegrationHooks:
|
||||
new_avg = (current_avg * (total_executions - 1) + processing_time) / total_executions
|
||||
self.hook_stats["average_hook_time"] = new_avg
|
||||
|
||||
def get_hook_stats(self) -> Dict[str, Any]:
|
||||
def get_hook_stats(self) -> dict[str, Any]:
|
||||
"""获取钩子统计信息"""
|
||||
return self.hook_stats.copy()
|
||||
|
||||
@@ -501,7 +501,7 @@ class MemoryMaintenanceTask:
|
||||
|
||||
|
||||
# 全局钩子实例
|
||||
_memory_hooks: Optional[MemoryIntegrationHooks] = None
|
||||
_memory_hooks: MemoryIntegrationHooks | None = None
|
||||
|
||||
|
||||
async def get_memory_integration_hooks() -> MemoryIntegrationHooks:
|
||||
|
||||
@@ -1,20 +1,20 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
元数据索引系统
|
||||
为记忆系统提供多维度的精准过滤和查询能力
|
||||
"""
|
||||
|
||||
import threading
|
||||
import time
|
||||
import orjson
|
||||
from typing import Dict, List, Optional, Tuple, Set, Any, Union
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
import threading
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import orjson
|
||||
|
||||
from src.chat.memory_system.memory_chunk import ConfidenceLevel, ImportanceLevel, MemoryChunk, MemoryType
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.memory_system.memory_chunk import MemoryChunk, MemoryType, ConfidenceLevel, ImportanceLevel
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -40,21 +40,21 @@ class IndexType(Enum):
|
||||
class IndexQuery:
|
||||
"""索引查询条件"""
|
||||
|
||||
user_ids: Optional[List[str]] = None
|
||||
memory_types: Optional[List[MemoryType]] = None
|
||||
subjects: Optional[List[str]] = None
|
||||
keywords: Optional[List[str]] = None
|
||||
tags: Optional[List[str]] = None
|
||||
categories: Optional[List[str]] = None
|
||||
time_range: Optional[Tuple[float, float]] = None
|
||||
confidence_levels: Optional[List[ConfidenceLevel]] = None
|
||||
importance_levels: Optional[List[ImportanceLevel]] = None
|
||||
min_relationship_score: Optional[float] = None
|
||||
max_relationship_score: Optional[float] = None
|
||||
min_access_count: Optional[int] = None
|
||||
semantic_hashes: Optional[List[str]] = None
|
||||
limit: Optional[int] = None
|
||||
sort_by: Optional[str] = None # "created_at", "access_count", "relevance_score"
|
||||
user_ids: list[str] | None = None
|
||||
memory_types: list[MemoryType] | None = None
|
||||
subjects: list[str] | None = None
|
||||
keywords: list[str] | None = None
|
||||
tags: list[str] | None = None
|
||||
categories: list[str] | None = None
|
||||
time_range: tuple[float, float] | None = None
|
||||
confidence_levels: list[ConfidenceLevel] | None = None
|
||||
importance_levels: list[ImportanceLevel] | None = None
|
||||
min_relationship_score: float | None = None
|
||||
max_relationship_score: float | None = None
|
||||
min_access_count: int | None = None
|
||||
semantic_hashes: list[str] | None = None
|
||||
limit: int | None = None
|
||||
sort_by: str | None = None # "created_at", "access_count", "relevance_score"
|
||||
sort_order: str = "desc" # "asc", "desc"
|
||||
|
||||
|
||||
@@ -62,10 +62,10 @@ class IndexQuery:
|
||||
class IndexResult:
|
||||
"""索引结果"""
|
||||
|
||||
memory_ids: List[str]
|
||||
memory_ids: list[str]
|
||||
total_count: int
|
||||
query_time: float
|
||||
filtered_by: List[str]
|
||||
filtered_by: list[str]
|
||||
|
||||
|
||||
class MetadataIndexManager:
|
||||
@@ -94,7 +94,7 @@ class MetadataIndexManager:
|
||||
self.access_frequency_index = [] # [(access_count, memory_id), ...]
|
||||
|
||||
# 内存缓存
|
||||
self.memory_metadata_cache: Dict[str, Dict[str, Any]] = {}
|
||||
self.memory_metadata_cache: dict[str, dict[str, Any]] = {}
|
||||
|
||||
# 统计信息
|
||||
self.index_stats = {
|
||||
@@ -140,7 +140,7 @@ class MetadataIndexManager:
|
||||
return key
|
||||
|
||||
@staticmethod
|
||||
def _serialize_metadata_entry(metadata: Dict[str, Any]) -> Dict[str, Any]:
|
||||
def _serialize_metadata_entry(metadata: dict[str, Any]) -> dict[str, Any]:
|
||||
serialized = {}
|
||||
for field_name, value in metadata.items():
|
||||
if isinstance(value, Enum):
|
||||
@@ -149,7 +149,7 @@ class MetadataIndexManager:
|
||||
serialized[field_name] = value
|
||||
return serialized
|
||||
|
||||
async def index_memories(self, memories: List[MemoryChunk]):
|
||||
async def index_memories(self, memories: list[MemoryChunk]):
|
||||
"""为记忆建立索引"""
|
||||
if not memories:
|
||||
return
|
||||
@@ -375,7 +375,7 @@ class MetadataIndexManager:
|
||||
logger.error(f"❌ 元数据查询失败: {e}", exc_info=True)
|
||||
return IndexResult(memory_ids=[], total_count=0, query_time=0.0, filtered_by=[])
|
||||
|
||||
def _get_candidate_memories(self, query: IndexQuery) -> Set[str]:
|
||||
def _get_candidate_memories(self, query: IndexQuery) -> set[str]:
|
||||
"""获取候选记忆ID集合"""
|
||||
candidate_ids = set()
|
||||
|
||||
@@ -444,7 +444,7 @@ class MetadataIndexManager:
|
||||
|
||||
return candidate_ids
|
||||
|
||||
def _collect_index_matches(self, index_type: IndexType, token: Optional[Union[str, Enum]]) -> Set[str]:
|
||||
def _collect_index_matches(self, index_type: IndexType, token: str | Enum | None) -> set[str]:
|
||||
"""根据给定token收集索引匹配,支持部分匹配"""
|
||||
mapping = self.indices.get(index_type)
|
||||
if mapping is None:
|
||||
@@ -461,7 +461,7 @@ class MetadataIndexManager:
|
||||
if not key:
|
||||
return set()
|
||||
|
||||
matches: Set[str] = set(mapping.get(key, set()))
|
||||
matches: set[str] = set(mapping.get(key, set()))
|
||||
|
||||
if matches:
|
||||
return set(matches)
|
||||
@@ -477,7 +477,7 @@ class MetadataIndexManager:
|
||||
|
||||
return matches
|
||||
|
||||
def _apply_filters(self, candidate_ids: Set[str], query: IndexQuery) -> List[str]:
|
||||
def _apply_filters(self, candidate_ids: set[str], query: IndexQuery) -> list[str]:
|
||||
"""应用过滤条件"""
|
||||
filtered_ids = list(candidate_ids)
|
||||
|
||||
@@ -545,7 +545,7 @@ class MetadataIndexManager:
|
||||
created_at = self.memory_metadata_cache[memory_id]["created_at"]
|
||||
return start_time <= created_at <= end_time
|
||||
|
||||
def _sort_memories(self, memory_ids: List[str], sort_by: str, sort_order: str) -> List[str]:
|
||||
def _sort_memories(self, memory_ids: list[str], sort_by: str, sort_order: str) -> list[str]:
|
||||
"""对记忆进行排序"""
|
||||
if sort_by == "created_at":
|
||||
# 使用时间索引(已经有序)
|
||||
@@ -582,7 +582,7 @@ class MetadataIndexManager:
|
||||
|
||||
return memory_ids
|
||||
|
||||
def _get_applied_filters(self, query: IndexQuery) -> List[str]:
|
||||
def _get_applied_filters(self, query: IndexQuery) -> list[str]:
|
||||
"""获取应用的过滤器列表"""
|
||||
filters = []
|
||||
if query.memory_types:
|
||||
@@ -686,11 +686,11 @@ class MetadataIndexManager:
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 移除记忆索引失败: {e}")
|
||||
|
||||
async def get_memory_metadata(self, memory_id: str) -> Optional[Dict[str, Any]]:
|
||||
async def get_memory_metadata(self, memory_id: str) -> dict[str, Any] | None:
|
||||
"""获取记忆元数据"""
|
||||
return self.memory_metadata_cache.get(memory_id)
|
||||
|
||||
async def get_user_memory_ids(self, user_id: str, limit: Optional[int] = None) -> List[str]:
|
||||
async def get_user_memory_ids(self, user_id: str, limit: int | None = None) -> list[str]:
|
||||
"""获取用户的所有记忆ID"""
|
||||
user_memory_ids = list(self.indices[IndexType.USER_ID].get(user_id, set()))
|
||||
|
||||
@@ -699,7 +699,7 @@ class MetadataIndexManager:
|
||||
|
||||
return user_memory_ids
|
||||
|
||||
async def get_memory_statistics(self, user_id: Optional[str] = None) -> Dict[str, Any]:
|
||||
async def get_memory_statistics(self, user_id: str | None = None) -> dict[str, Any]:
|
||||
"""获取记忆统计信息"""
|
||||
stats = {
|
||||
"total_memories": self.index_stats["total_memories"],
|
||||
@@ -784,7 +784,7 @@ class MetadataIndexManager:
|
||||
logger.info("正在保存元数据索引...")
|
||||
|
||||
# 保存各类索引
|
||||
indices_data: Dict[str, Dict[str, List[str]]] = {}
|
||||
indices_data: dict[str, dict[str, list[str]]] = {}
|
||||
for index_type, index_data in self.indices.items():
|
||||
serialized_index = {}
|
||||
for key, values in index_data.items():
|
||||
@@ -839,7 +839,7 @@ class MetadataIndexManager:
|
||||
# 加载各类索引
|
||||
indices_file = self.index_path / "indices.json"
|
||||
if indices_file.exists():
|
||||
with open(indices_file, "r", encoding="utf-8") as f:
|
||||
with open(indices_file, encoding="utf-8") as f:
|
||||
indices_data = orjson.loads(f.read())
|
||||
|
||||
for index_type_value, index_data in indices_data.items():
|
||||
@@ -853,25 +853,25 @@ class MetadataIndexManager:
|
||||
# 加载时间索引
|
||||
time_index_file = self.index_path / "time_index.json"
|
||||
if time_index_file.exists():
|
||||
with open(time_index_file, "r", encoding="utf-8") as f:
|
||||
with open(time_index_file, encoding="utf-8") as f:
|
||||
self.time_index = orjson.loads(f.read())
|
||||
|
||||
# 加载关系分索引
|
||||
relationship_index_file = self.index_path / "relationship_index.json"
|
||||
if relationship_index_file.exists():
|
||||
with open(relationship_index_file, "r", encoding="utf-8") as f:
|
||||
with open(relationship_index_file, encoding="utf-8") as f:
|
||||
self.relationship_index = orjson.loads(f.read())
|
||||
|
||||
# 加载访问频率索引
|
||||
access_frequency_index_file = self.index_path / "access_frequency_index.json"
|
||||
if access_frequency_index_file.exists():
|
||||
with open(access_frequency_index_file, "r", encoding="utf-8") as f:
|
||||
with open(access_frequency_index_file, encoding="utf-8") as f:
|
||||
self.access_frequency_index = orjson.loads(f.read())
|
||||
|
||||
# 加载元数据缓存
|
||||
metadata_cache_file = self.index_path / "metadata_cache.json"
|
||||
if metadata_cache_file.exists():
|
||||
with open(metadata_cache_file, "r", encoding="utf-8") as f:
|
||||
with open(metadata_cache_file, encoding="utf-8") as f:
|
||||
cache_data = orjson.loads(f.read())
|
||||
|
||||
# 转换置信度和重要性为枚举类型
|
||||
@@ -914,7 +914,7 @@ class MetadataIndexManager:
|
||||
# 加载统计信息
|
||||
stats_file = self.index_path / "index_stats.json"
|
||||
if stats_file.exists():
|
||||
with open(stats_file, "r", encoding="utf-8") as f:
|
||||
with open(stats_file, encoding="utf-8") as f:
|
||||
self.index_stats = orjson.loads(f.read())
|
||||
|
||||
# 更新记忆计数
|
||||
@@ -1004,7 +1004,7 @@ class MetadataIndexManager:
|
||||
if len(self.indices[IndexType.CATEGORY][category]) < min_frequency:
|
||||
del self.indices[IndexType.CATEGORY][category]
|
||||
|
||||
def get_index_stats(self) -> Dict[str, Any]:
|
||||
def get_index_stats(self) -> dict[str, Any]:
|
||||
"""获取索引统计信息"""
|
||||
stats = self.index_stats.copy()
|
||||
if stats["total_queries"] > 0:
|
||||
|
||||
@@ -1,19 +1,19 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
多阶段召回机制
|
||||
实现粗粒度到细粒度的记忆检索优化
|
||||
"""
|
||||
|
||||
import time
|
||||
from typing import Dict, List, Optional, Set, Any
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
import orjson
|
||||
from typing import Any
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.memory_system.memory_chunk import MemoryChunk, MemoryType
|
||||
import orjson
|
||||
from src.chat.memory_system.enhanced_reranker import EnhancedReRanker, ReRankingConfig
|
||||
|
||||
from src.chat.memory_system.memory_chunk import MemoryChunk, MemoryType
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -73,11 +73,11 @@ class StageResult:
|
||||
"""阶段结果"""
|
||||
|
||||
stage: RetrievalStage
|
||||
memory_ids: List[str]
|
||||
memory_ids: list[str]
|
||||
processing_time: float
|
||||
filtered_count: int
|
||||
score_threshold: float
|
||||
details: List[Dict[str, Any]] = field(default_factory=list)
|
||||
details: list[dict[str, Any]] = field(default_factory=list)
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -86,17 +86,17 @@ class RetrievalResult:
|
||||
|
||||
query: str
|
||||
user_id: str
|
||||
final_memories: List[MemoryChunk]
|
||||
stage_results: List[StageResult]
|
||||
final_memories: list[MemoryChunk]
|
||||
stage_results: list[StageResult]
|
||||
total_processing_time: float
|
||||
total_filtered: int
|
||||
retrieval_stats: Dict[str, Any]
|
||||
retrieval_stats: dict[str, Any]
|
||||
|
||||
|
||||
class MultiStageRetrieval:
|
||||
"""多阶段召回系统"""
|
||||
|
||||
def __init__(self, config: Optional[RetrievalConfig] = None):
|
||||
def __init__(self, config: RetrievalConfig | None = None):
|
||||
self.config = config or RetrievalConfig.from_global_config()
|
||||
|
||||
# 初始化增强重排序器
|
||||
@@ -124,11 +124,11 @@ class MultiStageRetrieval:
|
||||
self,
|
||||
query: str,
|
||||
user_id: str,
|
||||
context: Dict[str, Any],
|
||||
context: dict[str, Any],
|
||||
metadata_index,
|
||||
vector_storage,
|
||||
all_memories_cache: Dict[str, MemoryChunk],
|
||||
limit: Optional[int] = None,
|
||||
all_memories_cache: dict[str, MemoryChunk],
|
||||
limit: int | None = None,
|
||||
) -> RetrievalResult:
|
||||
"""多阶段记忆检索"""
|
||||
start_time = time.time()
|
||||
@@ -136,7 +136,7 @@ class MultiStageRetrieval:
|
||||
|
||||
stage_results = []
|
||||
current_memory_ids = set()
|
||||
memory_debug_info: Dict[str, Dict[str, Any]] = {}
|
||||
memory_debug_info: dict[str, dict[str, Any]] = {}
|
||||
|
||||
try:
|
||||
logger.debug(f"开始多阶段检索:query='{query}', user_id='{user_id}'")
|
||||
@@ -311,11 +311,11 @@ class MultiStageRetrieval:
|
||||
self,
|
||||
query: str,
|
||||
user_id: str,
|
||||
context: Dict[str, Any],
|
||||
context: dict[str, Any],
|
||||
metadata_index,
|
||||
all_memories_cache: Dict[str, MemoryChunk],
|
||||
all_memories_cache: dict[str, MemoryChunk],
|
||||
*,
|
||||
debug_log: Optional[Dict[str, Dict[str, Any]]] = None,
|
||||
debug_log: dict[str, dict[str, Any]] | None = None,
|
||||
) -> StageResult:
|
||||
"""阶段1:元数据过滤"""
|
||||
start_time = time.time()
|
||||
@@ -345,7 +345,7 @@ class MultiStageRetrieval:
|
||||
result = await metadata_index.query_memories(index_query)
|
||||
result_ids = list(result.memory_ids)
|
||||
filtered_count = max(0, len(all_memories_cache) - len(result_ids))
|
||||
details: List[Dict[str, Any]] = []
|
||||
details: list[dict[str, Any]] = []
|
||||
|
||||
# 如果未命中任何索引且未指定所有者过滤,则回退到最近访问的记忆
|
||||
if not result_ids:
|
||||
@@ -440,12 +440,12 @@ class MultiStageRetrieval:
|
||||
self,
|
||||
query: str,
|
||||
user_id: str,
|
||||
context: Dict[str, Any],
|
||||
context: dict[str, Any],
|
||||
vector_storage,
|
||||
candidate_ids: Set[str],
|
||||
all_memories_cache: Dict[str, MemoryChunk],
|
||||
candidate_ids: set[str],
|
||||
all_memories_cache: dict[str, MemoryChunk],
|
||||
*,
|
||||
debug_log: Optional[Dict[str, Dict[str, Any]]] = None,
|
||||
debug_log: dict[str, dict[str, Any]] | None = None,
|
||||
) -> StageResult:
|
||||
"""阶段2:向量搜索"""
|
||||
start_time = time.time()
|
||||
@@ -479,8 +479,8 @@ class MultiStageRetrieval:
|
||||
|
||||
# 过滤候选记忆
|
||||
filtered_memories = []
|
||||
details: List[Dict[str, Any]] = []
|
||||
raw_details: List[Dict[str, Any]] = []
|
||||
details: list[dict[str, Any]] = []
|
||||
raw_details: list[dict[str, Any]] = []
|
||||
threshold = self.config.vector_similarity_threshold
|
||||
|
||||
for memory_id, similarity in search_result:
|
||||
@@ -561,7 +561,7 @@ class MultiStageRetrieval:
|
||||
)
|
||||
|
||||
def _create_text_search_fallback(
|
||||
self, candidate_ids: Set[str], all_memories_cache: Dict[str, MemoryChunk], query_text: str, start_time: float
|
||||
self, candidate_ids: set[str], all_memories_cache: dict[str, MemoryChunk], query_text: str, start_time: float
|
||||
) -> StageResult:
|
||||
"""当向量搜索失败时,使用文本搜索作为回退策略"""
|
||||
try:
|
||||
@@ -618,18 +618,18 @@ class MultiStageRetrieval:
|
||||
self,
|
||||
query: str,
|
||||
user_id: str,
|
||||
context: Dict[str, Any],
|
||||
candidate_ids: Set[str],
|
||||
all_memories_cache: Dict[str, MemoryChunk],
|
||||
context: dict[str, Any],
|
||||
candidate_ids: set[str],
|
||||
all_memories_cache: dict[str, MemoryChunk],
|
||||
*,
|
||||
debug_log: Optional[Dict[str, Dict[str, Any]]] = None,
|
||||
debug_log: dict[str, dict[str, Any]] | None = None,
|
||||
) -> StageResult:
|
||||
"""阶段3:语义重排序"""
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
reranked_memories = []
|
||||
details: List[Dict[str, Any]] = []
|
||||
details: list[dict[str, Any]] = []
|
||||
threshold = self.config.semantic_similarity_threshold
|
||||
|
||||
for memory_id in candidate_ids:
|
||||
@@ -704,19 +704,19 @@ class MultiStageRetrieval:
|
||||
self,
|
||||
query: str,
|
||||
user_id: str,
|
||||
context: Dict[str, Any],
|
||||
candidate_ids: List[str],
|
||||
all_memories_cache: Dict[str, MemoryChunk],
|
||||
context: dict[str, Any],
|
||||
candidate_ids: list[str],
|
||||
all_memories_cache: dict[str, MemoryChunk],
|
||||
limit: int,
|
||||
*,
|
||||
debug_log: Optional[Dict[str, Dict[str, Any]]] = None,
|
||||
debug_log: dict[str, dict[str, Any]] | None = None,
|
||||
) -> StageResult:
|
||||
"""阶段4:上下文过滤"""
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
final_memories = []
|
||||
details: List[Dict[str, Any]] = []
|
||||
details: list[dict[str, Any]] = []
|
||||
|
||||
for memory_id in candidate_ids:
|
||||
if memory_id not in all_memories_cache:
|
||||
@@ -793,12 +793,12 @@ class MultiStageRetrieval:
|
||||
self,
|
||||
query: str,
|
||||
user_id: str,
|
||||
context: Dict[str, Any],
|
||||
all_memories_cache: Dict[str, MemoryChunk],
|
||||
context: dict[str, Any],
|
||||
all_memories_cache: dict[str, MemoryChunk],
|
||||
limit: int,
|
||||
*,
|
||||
excluded_ids: Optional[Set[str]] = None,
|
||||
debug_log: Optional[Dict[str, Dict[str, Any]]] = None,
|
||||
excluded_ids: set[str] | None = None,
|
||||
debug_log: dict[str, dict[str, Any]] | None = None,
|
||||
) -> StageResult:
|
||||
"""回退检索阶段 - 当主检索失败时使用更宽松的策略"""
|
||||
start_time = time.time()
|
||||
@@ -881,8 +881,8 @@ class MultiStageRetrieval:
|
||||
)
|
||||
|
||||
async def _generate_query_embedding(
|
||||
self, query: str, context: Dict[str, Any], vector_storage
|
||||
) -> Optional[List[float]]:
|
||||
self, query: str, context: dict[str, Any], vector_storage
|
||||
) -> list[float] | None:
|
||||
"""生成查询向量"""
|
||||
try:
|
||||
query_plan = context.get("query_plan")
|
||||
@@ -916,7 +916,7 @@ class MultiStageRetrieval:
|
||||
logger.error(f"生成查询向量时发生异常: {e}", exc_info=True)
|
||||
return None
|
||||
|
||||
async def _calculate_semantic_similarity(self, query: str, memory: MemoryChunk, context: Dict[str, Any]) -> float:
|
||||
async def _calculate_semantic_similarity(self, query: str, memory: MemoryChunk, context: dict[str, Any]) -> float:
|
||||
"""计算语义相似度 - 简化优化版本,提升召回率"""
|
||||
try:
|
||||
query_plan = context.get("query_plan")
|
||||
@@ -947,9 +947,10 @@ class MultiStageRetrieval:
|
||||
# 核心匹配策略2:词汇匹配
|
||||
word_score = 0.0
|
||||
try:
|
||||
import jieba
|
||||
import re
|
||||
|
||||
import jieba
|
||||
|
||||
# 分词处理
|
||||
query_words = list(jieba.cut(query_text)) + re.findall(r"[a-zA-Z]+", query_text)
|
||||
memory_words = list(jieba.cut(memory_text)) + re.findall(r"[a-zA-Z]+", memory_text)
|
||||
@@ -1059,7 +1060,7 @@ class MultiStageRetrieval:
|
||||
logger.warning(f"计算语义相似度失败: {e}")
|
||||
return 0.0
|
||||
|
||||
async def _calculate_context_relevance(self, query: str, memory: MemoryChunk, context: Dict[str, Any]) -> float:
|
||||
async def _calculate_context_relevance(self, query: str, memory: MemoryChunk, context: dict[str, Any]) -> float:
|
||||
"""计算上下文相关度"""
|
||||
try:
|
||||
score = 0.0
|
||||
@@ -1132,7 +1133,7 @@ class MultiStageRetrieval:
|
||||
return 0.0
|
||||
|
||||
async def _calculate_final_score(
|
||||
self, query: str, memory: MemoryChunk, context: Dict[str, Any], context_score: float
|
||||
self, query: str, memory: MemoryChunk, context: dict[str, Any], context_score: float
|
||||
) -> float:
|
||||
"""计算最终评分"""
|
||||
try:
|
||||
@@ -1184,7 +1185,7 @@ class MultiStageRetrieval:
|
||||
logger.warning(f"计算最终评分失败: {e}")
|
||||
return 0.0
|
||||
|
||||
def _calculate_subject_overlap(self, memory: MemoryChunk, required_subjects: Optional[List[str]]) -> float:
|
||||
def _calculate_subject_overlap(self, memory: MemoryChunk, required_subjects: list[str] | None) -> float:
|
||||
if not required_subjects:
|
||||
return 0.0
|
||||
|
||||
@@ -1229,7 +1230,7 @@ class MultiStageRetrieval:
|
||||
except Exception:
|
||||
return 0.5
|
||||
|
||||
def _extract_memory_types_from_context(self, context: Dict[str, Any]) -> List[MemoryType]:
|
||||
def _extract_memory_types_from_context(self, context: dict[str, Any]) -> list[MemoryType]:
|
||||
"""从上下文中提取记忆类型"""
|
||||
try:
|
||||
query_plan = context.get("query_plan")
|
||||
@@ -1256,10 +1257,10 @@ class MultiStageRetrieval:
|
||||
except Exception:
|
||||
return []
|
||||
|
||||
def _extract_keywords_from_query(self, query: str, query_plan: Optional[Any] = None) -> List[str]:
|
||||
def _extract_keywords_from_query(self, query: str, query_plan: Any | None = None) -> list[str]:
|
||||
"""从查询中提取关键词"""
|
||||
try:
|
||||
extracted: List[str] = []
|
||||
extracted: list[str] = []
|
||||
|
||||
if query_plan and getattr(query_plan, "required_keywords", None):
|
||||
extracted.extend([kw.lower() for kw in query_plan.required_keywords if isinstance(kw, str)])
|
||||
@@ -1283,7 +1284,7 @@ class MultiStageRetrieval:
|
||||
except Exception:
|
||||
return []
|
||||
|
||||
def _update_retrieval_stats(self, total_time: float, stage_results: List[StageResult]):
|
||||
def _update_retrieval_stats(self, total_time: float, stage_results: list[StageResult]):
|
||||
"""更新检索统计"""
|
||||
self.retrieval_stats["total_queries"] += 1
|
||||
|
||||
@@ -1306,7 +1307,7 @@ class MultiStageRetrieval:
|
||||
]
|
||||
stage_stat["avg_time"] = new_stage_avg
|
||||
|
||||
def get_retrieval_stats(self) -> Dict[str, Any]:
|
||||
def get_retrieval_stats(self) -> dict[str, Any]:
|
||||
"""获取检索统计信息"""
|
||||
return self.retrieval_stats.copy()
|
||||
|
||||
@@ -1328,12 +1329,12 @@ class MultiStageRetrieval:
|
||||
self,
|
||||
query: str,
|
||||
user_id: str,
|
||||
context: Dict[str, Any],
|
||||
candidate_ids: List[str],
|
||||
all_memories_cache: Dict[str, MemoryChunk],
|
||||
context: dict[str, Any],
|
||||
candidate_ids: list[str],
|
||||
all_memories_cache: dict[str, MemoryChunk],
|
||||
limit: int,
|
||||
*,
|
||||
debug_log: Optional[Dict[str, Dict[str, Any]]] = None,
|
||||
debug_log: dict[str, dict[str, Any]] | None = None,
|
||||
) -> StageResult:
|
||||
"""阶段5:增强重排序 - 使用多维度评分模型"""
|
||||
start_time = time.time()
|
||||
|
||||
@@ -1,24 +1,23 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
向量数据库存储接口
|
||||
为记忆系统提供高效的向量存储和语义搜索能力
|
||||
"""
|
||||
|
||||
import time
|
||||
import orjson
|
||||
import asyncio
|
||||
from typing import Dict, List, Optional, Tuple, Any
|
||||
from dataclasses import dataclass
|
||||
import threading
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
import orjson
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import model_config
|
||||
from src.common.config_helpers import resolve_embedding_dimension
|
||||
from src.chat.memory_system.memory_chunk import MemoryChunk
|
||||
from src.common.config_helpers import resolve_embedding_dimension
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import model_config
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -48,7 +47,7 @@ class VectorStorageConfig:
|
||||
class VectorStorageManager:
|
||||
"""向量存储管理器"""
|
||||
|
||||
def __init__(self, config: Optional[VectorStorageConfig] = None):
|
||||
def __init__(self, config: VectorStorageConfig | None = None):
|
||||
self.config = config or VectorStorageConfig()
|
||||
|
||||
resolved_dimension = resolve_embedding_dimension(self.config.dimension)
|
||||
@@ -68,8 +67,8 @@ class VectorStorageManager:
|
||||
self.index_to_memory_id = {} # vector index -> memory_id
|
||||
|
||||
# 内存缓存
|
||||
self.memory_cache: Dict[str, MemoryChunk] = {}
|
||||
self.vector_cache: Dict[str, List[float]] = {}
|
||||
self.memory_cache: dict[str, MemoryChunk] = {}
|
||||
self.vector_cache: dict[str, list[float]] = {}
|
||||
|
||||
# 统计信息
|
||||
self.storage_stats = {
|
||||
@@ -125,7 +124,7 @@ class VectorStorageManager:
|
||||
)
|
||||
logger.info("✅ 嵌入模型初始化完成")
|
||||
|
||||
async def generate_query_embedding(self, query_text: str) -> Optional[List[float]]:
|
||||
async def generate_query_embedding(self, query_text: str) -> list[float] | None:
|
||||
"""生成查询向量,用于记忆召回"""
|
||||
if not query_text:
|
||||
logger.warning("查询文本为空,无法生成向量")
|
||||
@@ -155,7 +154,7 @@ class VectorStorageManager:
|
||||
logger.error(f"❌ 生成查询向量失败: {exc}", exc_info=True)
|
||||
return None
|
||||
|
||||
async def store_memories(self, memories: List[MemoryChunk]):
|
||||
async def store_memories(self, memories: list[MemoryChunk]):
|
||||
"""存储记忆向量"""
|
||||
if not memories:
|
||||
return
|
||||
@@ -231,7 +230,7 @@ class VectorStorageManager:
|
||||
logger.debug("记忆 %s 缺少可用展示文本,使用占位符生成嵌入输入", memory.memory_id)
|
||||
return memory.memory_id
|
||||
|
||||
async def _batch_generate_and_store_embeddings(self, memory_texts: List[Tuple[str, str]]):
|
||||
async def _batch_generate_and_store_embeddings(self, memory_texts: list[tuple[str, str]]):
|
||||
"""批量生成和存储嵌入向量"""
|
||||
if not memory_texts:
|
||||
return
|
||||
@@ -253,12 +252,12 @@ class VectorStorageManager:
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 批量生成嵌入向量失败: {e}")
|
||||
|
||||
async def _batch_generate_embeddings(self, memory_ids: List[str], texts: List[str]) -> Dict[str, List[float]]:
|
||||
async def _batch_generate_embeddings(self, memory_ids: list[str], texts: list[str]) -> dict[str, list[float]]:
|
||||
"""批量生成嵌入向量"""
|
||||
if not texts:
|
||||
return {}
|
||||
|
||||
results: Dict[str, List[float]] = {}
|
||||
results: dict[str, list[float]] = {}
|
||||
|
||||
try:
|
||||
semaphore = asyncio.Semaphore(min(4, max(1, len(texts))))
|
||||
@@ -281,7 +280,9 @@ class VectorStorageManager:
|
||||
logger.warning("生成记忆 %s 的嵌入向量失败: %s", memory_id, exc)
|
||||
results[memory_id] = []
|
||||
|
||||
tasks = [asyncio.create_task(generate_embedding(mid, text)) for mid, text in zip(memory_ids, texts, strict=False)]
|
||||
tasks = [
|
||||
asyncio.create_task(generate_embedding(mid, text)) for mid, text in zip(memory_ids, texts, strict=False)
|
||||
]
|
||||
await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
except Exception as e:
|
||||
@@ -291,7 +292,7 @@ class VectorStorageManager:
|
||||
|
||||
return results
|
||||
|
||||
async def _add_single_memory(self, memory: MemoryChunk, embedding: List[float]):
|
||||
async def _add_single_memory(self, memory: MemoryChunk, embedding: list[float]):
|
||||
"""添加单个记忆到向量存储"""
|
||||
with self._lock:
|
||||
try:
|
||||
@@ -337,7 +338,7 @@ class VectorStorageManager:
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 添加记忆到向量存储失败: {e}")
|
||||
|
||||
def _normalize_vector(self, vector: List[float]) -> List[float]:
|
||||
def _normalize_vector(self, vector: list[float]) -> list[float]:
|
||||
"""L2归一化向量"""
|
||||
if not vector:
|
||||
return vector
|
||||
@@ -357,12 +358,12 @@ class VectorStorageManager:
|
||||
|
||||
async def search_similar_memories(
|
||||
self,
|
||||
query_vector: Optional[List[float]] = None,
|
||||
query_vector: list[float] | None = None,
|
||||
*,
|
||||
query_text: Optional[str] = None,
|
||||
query_text: str | None = None,
|
||||
limit: int = 10,
|
||||
scope_id: Optional[str] = None,
|
||||
) -> List[Tuple[str, float]]:
|
||||
scope_id: str | None = None,
|
||||
) -> list[tuple[str, float]]:
|
||||
"""搜索相似记忆"""
|
||||
start_time = time.time()
|
||||
|
||||
@@ -379,7 +380,7 @@ class VectorStorageManager:
|
||||
logger.warning("查询向量生成失败")
|
||||
return []
|
||||
|
||||
scope_filter: Optional[str] = None
|
||||
scope_filter: str | None = None
|
||||
if isinstance(scope_id, str):
|
||||
normalized_scope = scope_id.strip().lower()
|
||||
if normalized_scope and normalized_scope not in {"global", "global_memory"}:
|
||||
@@ -491,7 +492,7 @@ class VectorStorageManager:
|
||||
logger.error(f"❌ 向量搜索失败: {e}", exc_info=True)
|
||||
return []
|
||||
|
||||
async def get_memory_by_id(self, memory_id: str) -> Optional[MemoryChunk]:
|
||||
async def get_memory_by_id(self, memory_id: str) -> MemoryChunk | None:
|
||||
"""根据ID获取记忆"""
|
||||
# 先检查缓存
|
||||
if memory_id in self.memory_cache:
|
||||
@@ -501,7 +502,7 @@ class VectorStorageManager:
|
||||
self.storage_stats["total_searches"] += 1
|
||||
return None
|
||||
|
||||
async def update_memory_embedding(self, memory_id: str, new_embedding: List[float]):
|
||||
async def update_memory_embedding(self, memory_id: str, new_embedding: list[float]):
|
||||
"""更新记忆的嵌入向量"""
|
||||
with self._lock:
|
||||
try:
|
||||
@@ -636,7 +637,7 @@ class VectorStorageManager:
|
||||
# 加载记忆缓存
|
||||
cache_file = self.storage_path / "memory_cache.json"
|
||||
if cache_file.exists():
|
||||
with open(cache_file, "r", encoding="utf-8") as f:
|
||||
with open(cache_file, encoding="utf-8") as f:
|
||||
cache_data = orjson.loads(f.read())
|
||||
|
||||
self.memory_cache = {
|
||||
@@ -646,13 +647,13 @@ class VectorStorageManager:
|
||||
# 加载向量缓存
|
||||
vector_cache_file = self.storage_path / "vector_cache.json"
|
||||
if vector_cache_file.exists():
|
||||
with open(vector_cache_file, "r", encoding="utf-8") as f:
|
||||
with open(vector_cache_file, encoding="utf-8") as f:
|
||||
self.vector_cache = orjson.loads(f.read())
|
||||
|
||||
# 加载映射关系
|
||||
mapping_file = self.storage_path / "id_mapping.json"
|
||||
if mapping_file.exists():
|
||||
with open(mapping_file, "r", encoding="utf-8") as f:
|
||||
with open(mapping_file, encoding="utf-8") as f:
|
||||
mapping_data = orjson.loads(f.read())
|
||||
raw_memory_to_index = mapping_data.get("memory_id_to_index", {})
|
||||
self.memory_id_to_index = {
|
||||
@@ -689,7 +690,7 @@ class VectorStorageManager:
|
||||
# 加载统计信息
|
||||
stats_file = self.storage_path / "storage_stats.json"
|
||||
if stats_file.exists():
|
||||
with open(stats_file, "r", encoding="utf-8") as f:
|
||||
with open(stats_file, encoding="utf-8") as f:
|
||||
self.storage_stats = orjson.loads(f.read())
|
||||
|
||||
# 更新向量计数
|
||||
@@ -806,7 +807,7 @@ class VectorStorageManager:
|
||||
if invalid_memory_ids:
|
||||
logger.info(f"清理了 {len(invalid_memory_ids)} 个无效引用")
|
||||
|
||||
def get_storage_stats(self) -> Dict[str, Any]:
|
||||
def get_storage_stats(self) -> dict[str, Any]:
|
||||
"""获取存储统计信息"""
|
||||
stats = self.storage_stats.copy()
|
||||
if stats["total_searches"] > 0:
|
||||
@@ -821,11 +822,11 @@ class SimpleVectorIndex:
|
||||
|
||||
def __init__(self, dimension: int):
|
||||
self.dimension = dimension
|
||||
self.vectors: List[List[float]] = []
|
||||
self.vector_ids: List[int] = []
|
||||
self.vectors: list[list[float]] = []
|
||||
self.vector_ids: list[int] = []
|
||||
self.next_id = 0
|
||||
|
||||
def add_vector(self, vector: List[float]) -> int:
|
||||
def add_vector(self, vector: list[float]) -> int:
|
||||
"""添加向量"""
|
||||
if len(vector) != self.dimension:
|
||||
raise ValueError(f"向量维度不匹配,期望 {self.dimension},实际 {len(vector)}")
|
||||
@@ -837,7 +838,7 @@ class SimpleVectorIndex:
|
||||
|
||||
return vector_id
|
||||
|
||||
def search(self, query_vector: List[float], limit: int) -> List[Tuple[int, float]]:
|
||||
def search(self, query_vector: list[float], limit: int) -> list[tuple[int, float]]:
|
||||
"""搜索相似向量"""
|
||||
if len(query_vector) != self.dimension:
|
||||
raise ValueError(f"查询向量维度不匹配,期望 {self.dimension},实际 {len(query_vector)}")
|
||||
@@ -853,7 +854,7 @@ class SimpleVectorIndex:
|
||||
|
||||
return results[:limit]
|
||||
|
||||
def _calculate_cosine_similarity(self, v1: List[float], v2: List[float]) -> float:
|
||||
def _calculate_cosine_similarity(self, v1: list[float], v2: list[float]) -> float:
|
||||
"""计算余弦相似度"""
|
||||
try:
|
||||
dot_product = sum(x * y for x, y in zip(v1, v2, strict=False))
|
||||
|
||||
@@ -1,25 +1,24 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
记忆激活器
|
||||
记忆系统的激活器组件
|
||||
"""
|
||||
|
||||
import difflib
|
||||
import orjson
|
||||
from typing import List, Dict, Optional
|
||||
from datetime import datetime
|
||||
|
||||
import orjson
|
||||
from json_repair import repair_json
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import global_config, model_config
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.utils.prompt import Prompt, global_prompt_manager
|
||||
|
||||
from src.chat.memory_system.memory_manager import MemoryResult
|
||||
from src.chat.utils.prompt import Prompt, global_prompt_manager
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config, model_config
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
|
||||
logger = get_logger("memory_activator")
|
||||
|
||||
|
||||
def get_keywords_from_json(json_str) -> List:
|
||||
def get_keywords_from_json(json_str) -> list:
|
||||
"""
|
||||
从JSON字符串中提取关键词列表
|
||||
|
||||
@@ -81,7 +80,7 @@ class MemoryActivator:
|
||||
self.cached_keywords = set() # 用于缓存历史关键词
|
||||
self.last_memory_query_time = 0 # 上次查询记忆的时间
|
||||
|
||||
async def activate_memory_with_chat_history(self, target_message, chat_history_prompt) -> List[Dict]:
|
||||
async def activate_memory_with_chat_history(self, target_message, chat_history_prompt) -> list[dict]:
|
||||
"""
|
||||
激活记忆
|
||||
"""
|
||||
@@ -155,7 +154,7 @@ class MemoryActivator:
|
||||
|
||||
return self.running_memory
|
||||
|
||||
async def _query_unified_memory(self, keywords: List[str], query_text: str) -> List[MemoryResult]:
|
||||
async def _query_unified_memory(self, keywords: list[str], query_text: str) -> list[MemoryResult]:
|
||||
"""查询统一记忆系统"""
|
||||
try:
|
||||
# 使用记忆系统
|
||||
@@ -198,7 +197,7 @@ class MemoryActivator:
|
||||
logger.error(f"查询统一记忆失败: {e}")
|
||||
return []
|
||||
|
||||
async def get_instant_memory(self, target_message: str, chat_id: str) -> Optional[str]:
|
||||
async def get_instant_memory(self, target_message: str, chat_id: str) -> str | None:
|
||||
"""
|
||||
获取即时记忆 - 兼容原有接口(使用统一存储)
|
||||
"""
|
||||
|
||||
@@ -1,25 +1,24 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
记忆激活器
|
||||
记忆系统的激活器组件
|
||||
"""
|
||||
|
||||
import difflib
|
||||
import orjson
|
||||
from typing import List, Dict, Optional
|
||||
from datetime import datetime
|
||||
|
||||
import orjson
|
||||
from json_repair import repair_json
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import global_config, model_config
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.utils.prompt import Prompt, global_prompt_manager
|
||||
|
||||
from src.chat.memory_system.memory_manager import MemoryResult
|
||||
from src.chat.utils.prompt import Prompt, global_prompt_manager
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config, model_config
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
|
||||
logger = get_logger("memory_activator")
|
||||
|
||||
|
||||
def get_keywords_from_json(json_str) -> List:
|
||||
def get_keywords_from_json(json_str) -> list:
|
||||
"""
|
||||
从JSON字符串中提取关键词列表
|
||||
|
||||
@@ -81,7 +80,7 @@ class MemoryActivator:
|
||||
self.cached_keywords = set() # 用于缓存历史关键词
|
||||
self.last_memory_query_time = 0 # 上次查询记忆的时间
|
||||
|
||||
async def activate_memory_with_chat_history(self, target_message, chat_history_prompt) -> List[Dict]:
|
||||
async def activate_memory_with_chat_history(self, target_message, chat_history_prompt) -> list[dict]:
|
||||
"""
|
||||
激活记忆
|
||||
"""
|
||||
@@ -155,7 +154,7 @@ class MemoryActivator:
|
||||
|
||||
return self.running_memory
|
||||
|
||||
async def _query_unified_memory(self, keywords: List[str], query_text: str) -> List[MemoryResult]:
|
||||
async def _query_unified_memory(self, keywords: list[str], query_text: str) -> list[MemoryResult]:
|
||||
"""查询统一记忆系统"""
|
||||
try:
|
||||
# 使用记忆系统
|
||||
@@ -198,7 +197,7 @@ class MemoryActivator:
|
||||
logger.error(f"查询统一记忆失败: {e}")
|
||||
return []
|
||||
|
||||
async def get_instant_memory(self, target_message: str, chat_id: str) -> Optional[str]:
|
||||
async def get_instant_memory(self, target_message: str, chat_id: str) -> str | None:
|
||||
"""
|
||||
获取即时记忆 - 兼容原有接口(使用统一存储)
|
||||
"""
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
记忆构建模块
|
||||
从对话流中提取高质量、结构化记忆单元
|
||||
@@ -33,19 +32,19 @@ import time
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Optional, Union, Type
|
||||
from typing import Any
|
||||
|
||||
import orjson
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.chat.memory_system.memory_chunk import (
|
||||
MemoryChunk,
|
||||
MemoryType,
|
||||
ConfidenceLevel,
|
||||
ImportanceLevel,
|
||||
MemoryChunk,
|
||||
MemoryType,
|
||||
create_memory_chunk,
|
||||
)
|
||||
from src.common.logger import get_logger
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -62,8 +61,8 @@ class ExtractionStrategy(Enum):
|
||||
class ExtractionResult:
|
||||
"""提取结果"""
|
||||
|
||||
memories: List[MemoryChunk]
|
||||
confidence_scores: List[float]
|
||||
memories: list[MemoryChunk]
|
||||
confidence_scores: list[float]
|
||||
extraction_time: float
|
||||
strategy_used: ExtractionStrategy
|
||||
|
||||
@@ -85,8 +84,8 @@ class MemoryBuilder:
|
||||
}
|
||||
|
||||
async def build_memories(
|
||||
self, conversation_text: str, context: Dict[str, Any], user_id: str, timestamp: float
|
||||
) -> List[MemoryChunk]:
|
||||
self, conversation_text: str, context: dict[str, Any], user_id: str, timestamp: float
|
||||
) -> list[MemoryChunk]:
|
||||
"""从对话中构建记忆"""
|
||||
start_time = time.time()
|
||||
|
||||
@@ -116,8 +115,8 @@ class MemoryBuilder:
|
||||
raise
|
||||
|
||||
async def _extract_with_llm(
|
||||
self, text: str, context: Dict[str, Any], user_id: str, timestamp: float
|
||||
) -> List[MemoryChunk]:
|
||||
self, text: str, context: dict[str, Any], user_id: str, timestamp: float
|
||||
) -> list[MemoryChunk]:
|
||||
"""使用LLM提取记忆"""
|
||||
try:
|
||||
prompt = self._build_llm_extraction_prompt(text, context)
|
||||
@@ -135,7 +134,7 @@ class MemoryBuilder:
|
||||
logger.error(f"LLM提取失败: {e}")
|
||||
raise MemoryExtractionError(str(e)) from e
|
||||
|
||||
def _build_llm_extraction_prompt(self, text: str, context: Dict[str, Any]) -> str:
|
||||
def _build_llm_extraction_prompt(self, text: str, context: dict[str, Any]) -> str:
|
||||
"""构建LLM提取提示"""
|
||||
current_date = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
message_type = context.get("message_type", "normal")
|
||||
@@ -315,7 +314,7 @@ class MemoryBuilder:
|
||||
|
||||
return prompt
|
||||
|
||||
def _extract_json_payload(self, response: str) -> Optional[str]:
|
||||
def _extract_json_payload(self, response: str) -> str | None:
|
||||
"""从模型响应中提取JSON部分,兼容Markdown代码块等格式"""
|
||||
if not response:
|
||||
return None
|
||||
@@ -338,8 +337,8 @@ class MemoryBuilder:
|
||||
return stripped if stripped.startswith("{") and stripped.endswith("}") else None
|
||||
|
||||
def _parse_llm_response(
|
||||
self, response: str, user_id: str, timestamp: float, context: Dict[str, Any]
|
||||
) -> List[MemoryChunk]:
|
||||
self, response: str, user_id: str, timestamp: float, context: dict[str, Any]
|
||||
) -> list[MemoryChunk]:
|
||||
"""解析LLM响应"""
|
||||
if not response:
|
||||
raise MemoryExtractionError("LLM未返回任何响应")
|
||||
@@ -385,7 +384,7 @@ class MemoryBuilder:
|
||||
|
||||
bot_display = self._clean_subject_text(bot_display)
|
||||
|
||||
memories: List[MemoryChunk] = []
|
||||
memories: list[MemoryChunk] = []
|
||||
|
||||
for mem_data in memory_list:
|
||||
try:
|
||||
@@ -460,7 +459,7 @@ class MemoryBuilder:
|
||||
|
||||
return memories
|
||||
|
||||
def _parse_enum_value(self, enum_cls: Type[Enum], raw_value: Any, default: Enum, field_name: str) -> Enum:
|
||||
def _parse_enum_value(self, enum_cls: type[Enum], raw_value: Any, default: Enum, field_name: str) -> Enum:
|
||||
"""解析枚举值,兼容数字/字符串表示"""
|
||||
if isinstance(raw_value, enum_cls):
|
||||
return raw_value
|
||||
@@ -514,7 +513,7 @@ class MemoryBuilder:
|
||||
)
|
||||
return default
|
||||
|
||||
def _collect_bot_identifiers(self, context: Optional[Dict[str, Any]]) -> set[str]:
|
||||
def _collect_bot_identifiers(self, context: dict[str, Any] | None) -> set[str]:
|
||||
identifiers: set[str] = {"bot", "机器人", "ai助手"}
|
||||
if not context:
|
||||
return identifiers
|
||||
@@ -540,7 +539,7 @@ class MemoryBuilder:
|
||||
|
||||
return identifiers
|
||||
|
||||
def _collect_system_identifiers(self, context: Optional[Dict[str, Any]]) -> set[str]:
|
||||
def _collect_system_identifiers(self, context: dict[str, Any] | None) -> set[str]:
|
||||
identifiers: set[str] = set()
|
||||
if not context:
|
||||
return identifiers
|
||||
@@ -568,8 +567,8 @@ class MemoryBuilder:
|
||||
|
||||
return identifiers
|
||||
|
||||
def _resolve_conversation_participants(self, context: Optional[Dict[str, Any]], user_id: str) -> List[str]:
|
||||
participants: List[str] = []
|
||||
def _resolve_conversation_participants(self, context: dict[str, Any] | None, user_id: str) -> list[str]:
|
||||
participants: list[str] = []
|
||||
|
||||
if context:
|
||||
candidate_keys = [
|
||||
@@ -609,7 +608,7 @@ class MemoryBuilder:
|
||||
if not participants:
|
||||
participants = ["对话参与者"]
|
||||
|
||||
deduplicated: List[str] = []
|
||||
deduplicated: list[str] = []
|
||||
seen = set()
|
||||
for name in participants:
|
||||
key = name.lower()
|
||||
@@ -620,7 +619,7 @@ class MemoryBuilder:
|
||||
|
||||
return deduplicated
|
||||
|
||||
def _resolve_user_display(self, context: Optional[Dict[str, Any]], user_id: str) -> str:
|
||||
def _resolve_user_display(self, context: dict[str, Any] | None, user_id: str) -> str:
|
||||
candidate_keys = [
|
||||
"user_display_name",
|
||||
"user_name",
|
||||
@@ -683,7 +682,7 @@ class MemoryBuilder:
|
||||
|
||||
return False
|
||||
|
||||
def _split_subject_string(self, value: str) -> List[str]:
|
||||
def _split_subject_string(self, value: str) -> list[str]:
|
||||
if not value:
|
||||
return []
|
||||
|
||||
@@ -699,12 +698,12 @@ class MemoryBuilder:
|
||||
subject: Any,
|
||||
bot_identifiers: set[str],
|
||||
system_identifiers: set[str],
|
||||
default_subjects: List[str],
|
||||
bot_display: Optional[str] = None,
|
||||
) -> List[str]:
|
||||
default_subjects: list[str],
|
||||
bot_display: str | None = None,
|
||||
) -> list[str]:
|
||||
defaults = default_subjects or ["对话参与者"]
|
||||
|
||||
raw_candidates: List[str] = []
|
||||
raw_candidates: list[str] = []
|
||||
if isinstance(subject, list):
|
||||
for item in subject:
|
||||
if isinstance(item, str):
|
||||
@@ -716,7 +715,7 @@ class MemoryBuilder:
|
||||
elif subject is not None:
|
||||
raw_candidates.extend(self._split_subject_string(str(subject)))
|
||||
|
||||
normalized: List[str] = []
|
||||
normalized: list[str] = []
|
||||
bot_primary = self._clean_subject_text(bot_display or "")
|
||||
|
||||
for candidate in raw_candidates:
|
||||
@@ -741,7 +740,7 @@ class MemoryBuilder:
|
||||
if not normalized:
|
||||
normalized = list(defaults)
|
||||
|
||||
deduplicated: List[str] = []
|
||||
deduplicated: list[str] = []
|
||||
seen = set()
|
||||
for name in normalized:
|
||||
key = name.lower()
|
||||
@@ -752,7 +751,7 @@ class MemoryBuilder:
|
||||
|
||||
return deduplicated
|
||||
|
||||
def _extract_value_from_object(self, obj: Union[str, Dict[str, Any], List[Any]], keys: List[str]) -> Optional[str]:
|
||||
def _extract_value_from_object(self, obj: str | dict[str, Any] | list[Any], keys: list[str]) -> str | None:
|
||||
if isinstance(obj, dict):
|
||||
for key in keys:
|
||||
value = obj.get(key)
|
||||
@@ -773,9 +772,7 @@ class MemoryBuilder:
|
||||
return obj.strip() or None
|
||||
return None
|
||||
|
||||
def _compose_display_text(
|
||||
self, subjects: List[str], predicate: str, obj: Union[str, Dict[str, Any], List[Any]]
|
||||
) -> str:
|
||||
def _compose_display_text(self, subjects: list[str], predicate: str, obj: str | dict[str, Any] | list[Any]) -> str:
|
||||
subject_phrase = "、".join(subjects) if subjects else "对话参与者"
|
||||
predicate = (predicate or "").strip()
|
||||
|
||||
@@ -841,7 +838,7 @@ class MemoryBuilder:
|
||||
return f"{subject_phrase}{predicate}".strip()
|
||||
return subject_phrase
|
||||
|
||||
def _validate_and_enhance_memories(self, memories: List[MemoryChunk], context: Dict[str, Any]) -> List[MemoryChunk]:
|
||||
def _validate_and_enhance_memories(self, memories: list[MemoryChunk], context: dict[str, Any]) -> list[MemoryChunk]:
|
||||
"""验证和增强记忆"""
|
||||
validated_memories = []
|
||||
|
||||
@@ -876,7 +873,7 @@ class MemoryBuilder:
|
||||
|
||||
return True
|
||||
|
||||
def _enhance_memory(self, memory: MemoryChunk, context: Dict[str, Any]) -> MemoryChunk:
|
||||
def _enhance_memory(self, memory: MemoryChunk, context: dict[str, Any]) -> MemoryChunk:
|
||||
"""增强记忆块"""
|
||||
# 时间规范化处理
|
||||
self._normalize_time_in_memory(memory)
|
||||
@@ -985,7 +982,7 @@ class MemoryBuilder:
|
||||
total_confidence / self.extraction_stats["successful_extractions"]
|
||||
)
|
||||
|
||||
def get_extraction_stats(self) -> Dict[str, Any]:
|
||||
def get_extraction_stats(self) -> dict[str, Any]:
|
||||
"""获取提取统计信息"""
|
||||
return self.extraction_stats.copy()
|
||||
|
||||
|
||||
@@ -1,18 +1,19 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
结构化记忆单元设计
|
||||
实现高质量、结构化的记忆单元,符合文档设计规范
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import time
|
||||
import uuid
|
||||
import orjson
|
||||
from typing import Dict, List, Optional, Any, Union, Iterable
|
||||
from collections.abc import Iterable
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
import hashlib
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import orjson
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
@@ -56,17 +57,17 @@ class ImportanceLevel(Enum):
|
||||
class ContentStructure:
|
||||
"""主谓宾结构,包含自然语言描述"""
|
||||
|
||||
subject: Union[str, List[str]]
|
||||
subject: str | list[str]
|
||||
predicate: str
|
||||
object: Union[str, Dict]
|
||||
object: str | dict
|
||||
display: str = ""
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""转换为字典格式"""
|
||||
return {"subject": self.subject, "predicate": self.predicate, "object": self.object, "display": self.display}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> "ContentStructure":
|
||||
def from_dict(cls, data: dict[str, Any]) -> "ContentStructure":
|
||||
"""从字典创建实例"""
|
||||
return cls(
|
||||
subject=data.get("subject", ""),
|
||||
@@ -75,7 +76,7 @@ class ContentStructure:
|
||||
display=data.get("display", ""),
|
||||
)
|
||||
|
||||
def to_subject_list(self) -> List[str]:
|
||||
def to_subject_list(self) -> list[str]:
|
||||
"""将主语转换为列表形式"""
|
||||
if isinstance(self.subject, list):
|
||||
return [s for s in self.subject if isinstance(s, str) and s.strip()]
|
||||
@@ -99,7 +100,7 @@ class MemoryMetadata:
|
||||
# 基础信息
|
||||
memory_id: str # 唯一标识符
|
||||
user_id: str # 用户ID
|
||||
chat_id: Optional[str] = None # 聊天ID(群聊或私聊)
|
||||
chat_id: str | None = None # 聊天ID(群聊或私聊)
|
||||
|
||||
# 时间信息
|
||||
created_at: float = 0.0 # 创建时间戳
|
||||
@@ -124,9 +125,9 @@ class MemoryMetadata:
|
||||
last_forgetting_check: float = 0.0 # 上次遗忘检查时间
|
||||
|
||||
# 来源信息
|
||||
source_context: Optional[str] = None # 来源上下文片段
|
||||
source_context: str | None = None # 来源上下文片段
|
||||
# 兼容旧字段: 一些代码或旧版本可能直接访问 metadata.source
|
||||
source: Optional[str] = None
|
||||
source: str | None = None
|
||||
|
||||
def __post_init__(self):
|
||||
"""后初始化处理"""
|
||||
@@ -209,7 +210,7 @@ class MemoryMetadata:
|
||||
# 设置最小和最大阈值
|
||||
return max(7.0, min(threshold, 365.0)) # 7天到1年之间
|
||||
|
||||
def should_forget(self, current_time: Optional[float] = None) -> bool:
|
||||
def should_forget(self, current_time: float | None = None) -> bool:
|
||||
"""判断是否应该遗忘"""
|
||||
if current_time is None:
|
||||
current_time = time.time()
|
||||
@@ -222,7 +223,7 @@ class MemoryMetadata:
|
||||
|
||||
return days_since_activation > self.forgetting_threshold
|
||||
|
||||
def is_dormant(self, current_time: Optional[float] = None, inactive_days: int = 90) -> bool:
|
||||
def is_dormant(self, current_time: float | None = None, inactive_days: int = 90) -> bool:
|
||||
"""判断是否处于休眠状态(长期未激活)"""
|
||||
if current_time is None:
|
||||
current_time = time.time()
|
||||
@@ -230,7 +231,7 @@ class MemoryMetadata:
|
||||
days_since_last_access = (current_time - self.last_accessed) / 86400
|
||||
return days_since_last_access > inactive_days
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""转换为字典格式"""
|
||||
return {
|
||||
"memory_id": self.memory_id,
|
||||
@@ -252,7 +253,7 @@ class MemoryMetadata:
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> "MemoryMetadata":
|
||||
def from_dict(cls, data: dict[str, Any]) -> "MemoryMetadata":
|
||||
"""从字典创建实例"""
|
||||
return cls(
|
||||
memory_id=data.get("memory_id", ""),
|
||||
@@ -286,17 +287,17 @@ class MemoryChunk:
|
||||
memory_type: MemoryType # 记忆类型
|
||||
|
||||
# 扩展信息
|
||||
keywords: List[str] = field(default_factory=list) # 关键词列表
|
||||
tags: List[str] = field(default_factory=list) # 标签列表
|
||||
categories: List[str] = field(default_factory=list) # 分类列表
|
||||
keywords: list[str] = field(default_factory=list) # 关键词列表
|
||||
tags: list[str] = field(default_factory=list) # 标签列表
|
||||
categories: list[str] = field(default_factory=list) # 分类列表
|
||||
|
||||
# 语义信息
|
||||
embedding: Optional[List[float]] = None # 语义向量
|
||||
semantic_hash: Optional[str] = None # 语义哈希值
|
||||
embedding: list[float] | None = None # 语义向量
|
||||
semantic_hash: str | None = None # 语义哈希值
|
||||
|
||||
# 关联信息
|
||||
related_memories: List[str] = field(default_factory=list) # 关联记忆ID列表
|
||||
temporal_context: Optional[Dict[str, Any]] = None # 时间上下文
|
||||
related_memories: list[str] = field(default_factory=list) # 关联记忆ID列表
|
||||
temporal_context: dict[str, Any] | None = None # 时间上下文
|
||||
|
||||
def __post_init__(self):
|
||||
"""后初始化处理"""
|
||||
@@ -310,7 +311,7 @@ class MemoryChunk:
|
||||
|
||||
try:
|
||||
# 使用向量和内容生成稳定的哈希
|
||||
content_str = f"{self.content.subject}:{self.content.predicate}:{str(self.content.object)}"
|
||||
content_str = f"{self.content.subject}:{self.content.predicate}:{self.content.object!s}"
|
||||
embedding_str = ",".join(map(str, [round(x, 6) for x in self.embedding]))
|
||||
|
||||
hash_input = f"{content_str}|{embedding_str}"
|
||||
@@ -342,7 +343,7 @@ class MemoryChunk:
|
||||
return self.content.display or str(self.content)
|
||||
|
||||
@property
|
||||
def subjects(self) -> List[str]:
|
||||
def subjects(self) -> list[str]:
|
||||
"""获取主语列表"""
|
||||
return self.content.to_subject_list()
|
||||
|
||||
@@ -354,11 +355,11 @@ class MemoryChunk:
|
||||
"""更新相关度评分"""
|
||||
self.metadata.update_relevance(new_score)
|
||||
|
||||
def should_forget(self, current_time: Optional[float] = None) -> bool:
|
||||
def should_forget(self, current_time: float | None = None) -> bool:
|
||||
"""判断是否应该遗忘"""
|
||||
return self.metadata.should_forget(current_time)
|
||||
|
||||
def is_dormant(self, current_time: Optional[float] = None, inactive_days: int = 90) -> bool:
|
||||
def is_dormant(self, current_time: float | None = None, inactive_days: int = 90) -> bool:
|
||||
"""判断是否处于休眠状态(长期未激活)"""
|
||||
return self.metadata.is_dormant(current_time, inactive_days)
|
||||
|
||||
@@ -386,7 +387,7 @@ class MemoryChunk:
|
||||
if memory_id and memory_id not in self.related_memories:
|
||||
self.related_memories.append(memory_id)
|
||||
|
||||
def set_embedding(self, embedding: List[float]):
|
||||
def set_embedding(self, embedding: list[float]):
|
||||
"""设置语义向量"""
|
||||
self.embedding = embedding
|
||||
self._generate_semantic_hash()
|
||||
@@ -415,7 +416,7 @@ class MemoryChunk:
|
||||
logger.warning(f"计算记忆相似度失败: {e}")
|
||||
return 0.0
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""转换为完整的字典格式"""
|
||||
return {
|
||||
"metadata": self.metadata.to_dict(),
|
||||
@@ -431,7 +432,7 @@ class MemoryChunk:
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> "MemoryChunk":
|
||||
def from_dict(cls, data: dict[str, Any]) -> "MemoryChunk":
|
||||
"""从字典创建实例"""
|
||||
metadata = MemoryMetadata.from_dict(data.get("metadata", {}))
|
||||
content = ContentStructure.from_dict(data.get("content", {}))
|
||||
@@ -541,7 +542,7 @@ class MemoryChunk:
|
||||
return f"MemoryChunk(id={self.memory_id[:8]}..., type={self.memory_type.value}, user={self.user_id})"
|
||||
|
||||
|
||||
def _build_display_text(subjects: Iterable[str], predicate: str, obj: Union[str, Dict]) -> str:
|
||||
def _build_display_text(subjects: Iterable[str], predicate: str, obj: str | dict) -> str:
|
||||
"""根据主谓宾生成自然语言描述"""
|
||||
subjects_clean = [s.strip() for s in subjects if s and isinstance(s, str)]
|
||||
subject_part = "、".join(subjects_clean) if subjects_clean else "对话参与者"
|
||||
@@ -569,15 +570,15 @@ def _build_display_text(subjects: Iterable[str], predicate: str, obj: Union[str,
|
||||
|
||||
def create_memory_chunk(
|
||||
user_id: str,
|
||||
subject: Union[str, List[str]],
|
||||
subject: str | list[str],
|
||||
predicate: str,
|
||||
obj: Union[str, Dict],
|
||||
obj: str | dict,
|
||||
memory_type: MemoryType,
|
||||
chat_id: Optional[str] = None,
|
||||
source_context: Optional[str] = None,
|
||||
chat_id: str | None = None,
|
||||
source_context: str | None = None,
|
||||
importance: ImportanceLevel = ImportanceLevel.NORMAL,
|
||||
confidence: ConfidenceLevel = ConfidenceLevel.MEDIUM,
|
||||
display: Optional[str] = None,
|
||||
display: str | None = None,
|
||||
**kwargs,
|
||||
) -> MemoryChunk:
|
||||
"""便捷的内存块创建函数"""
|
||||
@@ -593,10 +594,10 @@ def create_memory_chunk(
|
||||
source_context=source_context,
|
||||
)
|
||||
|
||||
subjects: List[str]
|
||||
subjects: list[str]
|
||||
if isinstance(subject, list):
|
||||
subjects = [s for s in subject if isinstance(s, str) and s.strip()]
|
||||
subject_payload: Union[str, List[str]] = subjects
|
||||
subject_payload: str | list[str] = subjects
|
||||
else:
|
||||
cleaned = subject.strip() if isinstance(subject, str) else ""
|
||||
subjects = [cleaned] if cleaned else []
|
||||
|
||||
@@ -1,17 +1,15 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
智能记忆遗忘引擎
|
||||
基于重要程度、置信度和激活频率的智能遗忘机制
|
||||
"""
|
||||
|
||||
import time
|
||||
import asyncio
|
||||
from typing import List, Dict, Optional, Tuple
|
||||
from datetime import datetime
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
|
||||
from src.chat.memory_system.memory_chunk import ConfidenceLevel, ImportanceLevel, MemoryChunk
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.memory_system.memory_chunk import MemoryChunk, ImportanceLevel, ConfidenceLevel
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -65,7 +63,7 @@ class ForgettingConfig:
|
||||
class MemoryForgettingEngine:
|
||||
"""智能记忆遗忘引擎"""
|
||||
|
||||
def __init__(self, config: Optional[ForgettingConfig] = None):
|
||||
def __init__(self, config: ForgettingConfig | None = None):
|
||||
self.config = config or ForgettingConfig()
|
||||
self.stats = ForgettingStats()
|
||||
self._last_forgetting_check = 0.0
|
||||
@@ -116,7 +114,7 @@ class MemoryForgettingEngine:
|
||||
# 确保在合理范围内
|
||||
return max(self.config.min_forgetting_days, min(threshold, self.config.max_forgetting_days))
|
||||
|
||||
def should_forget_memory(self, memory: MemoryChunk, current_time: Optional[float] = None) -> bool:
|
||||
def should_forget_memory(self, memory: MemoryChunk, current_time: float | None = None) -> bool:
|
||||
"""
|
||||
判断记忆是否应该被遗忘
|
||||
|
||||
@@ -155,7 +153,7 @@ class MemoryForgettingEngine:
|
||||
|
||||
return should_forget
|
||||
|
||||
def is_dormant_memory(self, memory: MemoryChunk, current_time: Optional[float] = None) -> bool:
|
||||
def is_dormant_memory(self, memory: MemoryChunk, current_time: float | None = None) -> bool:
|
||||
"""
|
||||
判断记忆是否处于休眠状态
|
||||
|
||||
@@ -168,7 +166,7 @@ class MemoryForgettingEngine:
|
||||
"""
|
||||
return memory.is_dormant(current_time, self.config.dormant_threshold_days)
|
||||
|
||||
def should_force_forget_dormant(self, memory: MemoryChunk, current_time: Optional[float] = None) -> bool:
|
||||
def should_force_forget_dormant(self, memory: MemoryChunk, current_time: float | None = None) -> bool:
|
||||
"""
|
||||
判断是否应该强制遗忘休眠记忆
|
||||
|
||||
@@ -189,7 +187,7 @@ class MemoryForgettingEngine:
|
||||
days_since_last_access = (current_time - memory.metadata.last_accessed) / 86400
|
||||
return days_since_last_access > self.config.force_forget_dormant_days
|
||||
|
||||
async def check_memories_for_forgetting(self, memories: List[MemoryChunk]) -> Tuple[List[str], List[str]]:
|
||||
async def check_memories_for_forgetting(self, memories: list[MemoryChunk]) -> tuple[list[str], list[str]]:
|
||||
"""
|
||||
检查记忆列表,识别需要遗忘的记忆
|
||||
|
||||
@@ -241,7 +239,7 @@ class MemoryForgettingEngine:
|
||||
|
||||
return normal_forgetting_ids, force_forgetting_ids
|
||||
|
||||
async def perform_forgetting_check(self, memories: List[MemoryChunk]) -> Dict[str, any]:
|
||||
async def perform_forgetting_check(self, memories: list[MemoryChunk]) -> dict[str, any]:
|
||||
"""
|
||||
执行完整的遗忘检查流程
|
||||
|
||||
@@ -314,7 +312,7 @@ class MemoryForgettingEngine:
|
||||
except Exception as e:
|
||||
logger.error(f"定期遗忘检查失败: {e}", exc_info=True)
|
||||
|
||||
def get_forgetting_stats(self) -> Dict[str, any]:
|
||||
def get_forgetting_stats(self) -> dict[str, any]:
|
||||
"""获取遗忘统计信息"""
|
||||
return {
|
||||
"total_checked": self.stats.total_checked,
|
||||
|
||||
@@ -1,16 +1,14 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
记忆融合与去重机制
|
||||
避免记忆碎片化,确保长期记忆库的高质量
|
||||
"""
|
||||
|
||||
import time
|
||||
from typing import Dict, List, Optional, Tuple, Any
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
|
||||
from src.chat.memory_system.memory_chunk import ConfidenceLevel, ImportanceLevel, MemoryChunk
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.memory_system.memory_chunk import MemoryChunk, ConfidenceLevel, ImportanceLevel
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -22,9 +20,9 @@ class FusionResult:
|
||||
original_count: int
|
||||
fused_count: int
|
||||
removed_duplicates: int
|
||||
merged_memories: List[MemoryChunk]
|
||||
merged_memories: list[MemoryChunk]
|
||||
fusion_time: float
|
||||
details: List[str]
|
||||
details: list[str]
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -32,9 +30,9 @@ class DuplicateGroup:
|
||||
"""重复记忆组"""
|
||||
|
||||
group_id: str
|
||||
memories: List[MemoryChunk]
|
||||
similarity_matrix: List[List[float]]
|
||||
representative_memory: Optional[MemoryChunk] = None
|
||||
memories: list[MemoryChunk]
|
||||
similarity_matrix: list[list[float]]
|
||||
representative_memory: MemoryChunk | None = None
|
||||
|
||||
|
||||
class MemoryFusionEngine:
|
||||
@@ -59,8 +57,8 @@ class MemoryFusionEngine:
|
||||
}
|
||||
|
||||
async def fuse_memories(
|
||||
self, new_memories: List[MemoryChunk], existing_memories: Optional[List[MemoryChunk]] = None
|
||||
) -> List[MemoryChunk]:
|
||||
self, new_memories: list[MemoryChunk], existing_memories: list[MemoryChunk] | None = None
|
||||
) -> list[MemoryChunk]:
|
||||
"""融合记忆列表"""
|
||||
start_time = time.time()
|
||||
|
||||
@@ -106,8 +104,8 @@ class MemoryFusionEngine:
|
||||
return new_memories # 失败时返回原始记忆
|
||||
|
||||
async def _detect_duplicate_groups(
|
||||
self, new_memories: List[MemoryChunk], existing_memories: List[MemoryChunk]
|
||||
) -> List[DuplicateGroup]:
|
||||
self, new_memories: list[MemoryChunk], existing_memories: list[MemoryChunk]
|
||||
) -> list[DuplicateGroup]:
|
||||
"""检测重复记忆组"""
|
||||
all_memories = new_memories + existing_memories
|
||||
new_memory_ids = {memory.memory_id for memory in new_memories}
|
||||
@@ -212,7 +210,7 @@ class MemoryFusionEngine:
|
||||
jaccard_similarity = len(intersection) / len(union)
|
||||
return jaccard_similarity
|
||||
|
||||
def _calculate_keyword_similarity(self, keywords1: List[str], keywords2: List[str]) -> float:
|
||||
def _calculate_keyword_similarity(self, keywords1: list[str], keywords2: list[str]) -> float:
|
||||
"""计算关键词相似度"""
|
||||
if not keywords1 or not keywords2:
|
||||
return 0.0
|
||||
@@ -302,7 +300,7 @@ class MemoryFusionEngine:
|
||||
|
||||
return best_memory
|
||||
|
||||
async def _fuse_memory_group(self, group: DuplicateGroup) -> Optional[MemoryChunk]:
|
||||
async def _fuse_memory_group(self, group: DuplicateGroup) -> MemoryChunk | None:
|
||||
"""融合记忆组"""
|
||||
if not group.memories:
|
||||
return None
|
||||
@@ -328,7 +326,7 @@ class MemoryFusionEngine:
|
||||
# 返回置信度最高的记忆
|
||||
return max(group.memories, key=lambda m: m.metadata.confidence.value)
|
||||
|
||||
async def _merge_memory_attributes(self, base_memory: MemoryChunk, memories: List[MemoryChunk]) -> MemoryChunk:
|
||||
async def _merge_memory_attributes(self, base_memory: MemoryChunk, memories: list[MemoryChunk]) -> MemoryChunk:
|
||||
"""合并记忆属性"""
|
||||
# 创建基础记忆的深拷贝
|
||||
fused_memory = MemoryChunk.from_dict(base_memory.to_dict())
|
||||
@@ -395,7 +393,7 @@ class MemoryFusionEngine:
|
||||
source_ids = [m.memory_id[:8] for m in group.memories]
|
||||
fused_memory.metadata.source_context = f"Fused from {len(group.memories)} memories: {', '.join(source_ids)}"
|
||||
|
||||
def _merge_temporal_context(self, memories: List[MemoryChunk]) -> Dict[str, Any]:
|
||||
def _merge_temporal_context(self, memories: list[MemoryChunk]) -> dict[str, Any]:
|
||||
"""合并时间上下文"""
|
||||
contexts = [m.temporal_context for m in memories if m.temporal_context]
|
||||
|
||||
@@ -426,8 +424,8 @@ class MemoryFusionEngine:
|
||||
return merged_context
|
||||
|
||||
async def incremental_fusion(
|
||||
self, new_memory: MemoryChunk, existing_memories: List[MemoryChunk]
|
||||
) -> Tuple[MemoryChunk, List[MemoryChunk]]:
|
||||
self, new_memory: MemoryChunk, existing_memories: list[MemoryChunk]
|
||||
) -> tuple[MemoryChunk, list[MemoryChunk]]:
|
||||
"""增量融合(单个新记忆与现有记忆融合)"""
|
||||
# 寻找相似记忆
|
||||
similar_memories = []
|
||||
@@ -493,7 +491,7 @@ class MemoryFusionEngine:
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 记忆融合引擎维护失败: {e}", exc_info=True)
|
||||
|
||||
def get_fusion_stats(self) -> Dict[str, Any]:
|
||||
def get_fusion_stats(self) -> dict[str, Any]:
|
||||
"""获取融合统计信息"""
|
||||
return self.fusion_stats.copy()
|
||||
|
||||
|
||||
@@ -1,17 +1,15 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
记忆系统管理器
|
||||
替代原有的 Hippocampus 和 instant_memory 系统
|
||||
"""
|
||||
|
||||
import re
|
||||
from typing import Dict, List, Optional, Any, Tuple
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.memory_system.memory_system import MemorySystem
|
||||
from src.chat.memory_system.memory_chunk import MemoryChunk, MemoryType
|
||||
from src.chat.memory_system.memory_system import initialize_memory_system
|
||||
from src.chat.memory_system.memory_system import MemorySystem, initialize_memory_system
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -27,14 +25,14 @@ class MemoryResult:
|
||||
timestamp: float
|
||||
source: str = "memory"
|
||||
relevance_score: float = 0.0
|
||||
structure: Dict[str, Any] | None = None
|
||||
structure: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class MemoryManager:
|
||||
"""记忆系统管理器 - 替代原有的 HippocampusManager"""
|
||||
|
||||
def __init__(self):
|
||||
self.memory_system: Optional[MemorySystem] = None
|
||||
self.memory_system: MemorySystem | None = None
|
||||
self.is_initialized = False
|
||||
self.user_cache = {} # 用户记忆缓存
|
||||
|
||||
@@ -63,8 +61,8 @@ class MemoryManager:
|
||||
logger.info("正在初始化记忆系统...")
|
||||
|
||||
# 获取LLM模型
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import model_config
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
|
||||
llm_model = LLMRequest(model_set=model_config.model_task_config.utils, request_type="memory")
|
||||
|
||||
@@ -121,7 +119,7 @@ class MemoryManager:
|
||||
max_memory_length: int = 2,
|
||||
time_weight: float = 1.0,
|
||||
keyword_weight: float = 1.0,
|
||||
) -> List[Tuple[str, str]]:
|
||||
) -> list[tuple[str, str]]:
|
||||
"""从文本获取相关记忆 - 兼容原有接口"""
|
||||
if not self.is_initialized or not self.memory_system:
|
||||
return []
|
||||
@@ -152,8 +150,8 @@ class MemoryManager:
|
||||
return []
|
||||
|
||||
async def get_memory_from_topic(
|
||||
self, valid_keywords: List[str], max_memory_num: int = 3, max_memory_length: int = 2, max_depth: int = 3
|
||||
) -> List[Tuple[str, str]]:
|
||||
self, valid_keywords: list[str], max_memory_num: int = 3, max_memory_length: int = 2, max_depth: int = 3
|
||||
) -> list[tuple[str, str]]:
|
||||
"""从关键词获取记忆 - 兼容原有接口"""
|
||||
if not self.is_initialized or not self.memory_system:
|
||||
return []
|
||||
@@ -208,8 +206,8 @@ class MemoryManager:
|
||||
return []
|
||||
|
||||
async def process_conversation(
|
||||
self, conversation_text: str, context: Dict[str, Any], user_id: str, timestamp: Optional[float] = None
|
||||
) -> List[MemoryChunk]:
|
||||
self, conversation_text: str, context: dict[str, Any], user_id: str, timestamp: float | None = None
|
||||
) -> list[MemoryChunk]:
|
||||
"""处理对话并构建记忆 - 新增功能"""
|
||||
if not self.is_initialized or not self.memory_system:
|
||||
return []
|
||||
@@ -235,8 +233,8 @@ class MemoryManager:
|
||||
return []
|
||||
|
||||
async def get_enhanced_memory_context(
|
||||
self, query_text: str, user_id: str, context: Optional[Dict[str, Any]] = None, limit: int = 5
|
||||
) -> List[MemoryResult]:
|
||||
self, query_text: str, user_id: str, context: dict[str, Any] | None = None, limit: int = 5
|
||||
) -> list[MemoryResult]:
|
||||
"""获取增强记忆上下文 - 新增功能"""
|
||||
if not self.is_initialized or not self.memory_system:
|
||||
return []
|
||||
@@ -267,7 +265,7 @@ class MemoryManager:
|
||||
logger.error(f"get_enhanced_memory_context 失败: {e}")
|
||||
return []
|
||||
|
||||
def _format_memory_chunk(self, memory: MemoryChunk) -> Tuple[str, Dict[str, Any]]:
|
||||
def _format_memory_chunk(self, memory: MemoryChunk) -> tuple[str, dict[str, Any]]:
|
||||
"""将记忆块转换为更易读的文本描述"""
|
||||
structure = memory.content.to_dict()
|
||||
if memory.display:
|
||||
@@ -289,7 +287,7 @@ class MemoryManager:
|
||||
|
||||
return formatted, structure
|
||||
|
||||
def _format_subject(self, subject: Optional[str], memory: MemoryChunk) -> str:
|
||||
def _format_subject(self, subject: str | None, memory: MemoryChunk) -> str:
|
||||
if not subject:
|
||||
return "该用户"
|
||||
|
||||
@@ -299,7 +297,7 @@ class MemoryManager:
|
||||
return "该聊天"
|
||||
return self._clean_text(subject)
|
||||
|
||||
def _apply_predicate_format(self, subject: str, predicate: str, obj: Any) -> Optional[str]:
|
||||
def _apply_predicate_format(self, subject: str, predicate: str, obj: Any) -> str | None:
|
||||
predicate = (predicate or "").strip()
|
||||
obj_value = obj
|
||||
|
||||
@@ -446,10 +444,10 @@ class MemoryManager:
|
||||
text = self._truncate(str(obj).strip())
|
||||
return self._clean_text(text)
|
||||
|
||||
def _extract_from_object(self, obj: Any, keys: List[str]) -> Optional[str]:
|
||||
def _extract_from_object(self, obj: Any, keys: list[str]) -> str | None:
|
||||
if isinstance(obj, dict):
|
||||
for key in keys:
|
||||
if key in obj and obj[key]:
|
||||
if obj.get(key):
|
||||
value = obj[key]
|
||||
if isinstance(value, (dict, list)):
|
||||
return self._clean_text(self._format_object(value))
|
||||
|
||||
@@ -1,15 +1,15 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
记忆元数据索引管理器
|
||||
使用JSON文件存储记忆元数据,支持快速模糊搜索和过滤
|
||||
"""
|
||||
|
||||
import orjson
|
||||
import threading
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Set, Any
|
||||
from dataclasses import dataclass, asdict
|
||||
from dataclasses import asdict, dataclass
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import orjson
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
@@ -25,10 +25,10 @@ class MemoryMetadataIndexEntry:
|
||||
|
||||
# 分类信息
|
||||
memory_type: str # MemoryType.value
|
||||
subjects: List[str] # 主语列表
|
||||
objects: List[str] # 宾语列表
|
||||
keywords: List[str] # 关键词列表
|
||||
tags: List[str] # 标签列表
|
||||
subjects: list[str] # 主语列表
|
||||
objects: list[str] # 宾语列表
|
||||
keywords: list[str] # 关键词列表
|
||||
tags: list[str] # 标签列表
|
||||
|
||||
# 数值字段(用于范围过滤)
|
||||
importance: int # ImportanceLevel.value (1-4)
|
||||
@@ -37,8 +37,8 @@ class MemoryMetadataIndexEntry:
|
||||
access_count: int # 访问次数
|
||||
|
||||
# 可选字段
|
||||
chat_id: Optional[str] = None
|
||||
content_preview: Optional[str] = None # 内容预览(前100字符)
|
||||
chat_id: str | None = None
|
||||
content_preview: str | None = None # 内容预览(前100字符)
|
||||
|
||||
|
||||
class MemoryMetadataIndex:
|
||||
@@ -46,13 +46,13 @@ class MemoryMetadataIndex:
|
||||
|
||||
def __init__(self, index_file: str = "data/memory_metadata_index.json"):
|
||||
self.index_file = Path(index_file)
|
||||
self.index: Dict[str, MemoryMetadataIndexEntry] = {} # memory_id -> entry
|
||||
self.index: dict[str, MemoryMetadataIndexEntry] = {} # memory_id -> entry
|
||||
|
||||
# 倒排索引(用于快速查找)
|
||||
self.type_index: Dict[str, Set[str]] = {} # type -> {memory_ids}
|
||||
self.subject_index: Dict[str, Set[str]] = {} # subject -> {memory_ids}
|
||||
self.keyword_index: Dict[str, Set[str]] = {} # keyword -> {memory_ids}
|
||||
self.tag_index: Dict[str, Set[str]] = {} # tag -> {memory_ids}
|
||||
self.type_index: dict[str, set[str]] = {} # type -> {memory_ids}
|
||||
self.subject_index: dict[str, set[str]] = {} # subject -> {memory_ids}
|
||||
self.keyword_index: dict[str, set[str]] = {} # keyword -> {memory_ids}
|
||||
self.tag_index: dict[str, set[str]] = {} # tag -> {memory_ids}
|
||||
|
||||
self.lock = threading.RLock()
|
||||
|
||||
@@ -178,7 +178,7 @@ class MemoryMetadataIndex:
|
||||
self._remove_from_inverted_indices(memory_id)
|
||||
del self.index[memory_id]
|
||||
|
||||
def batch_add_or_update(self, entries: List[MemoryMetadataIndexEntry]):
|
||||
def batch_add_or_update(self, entries: list[MemoryMetadataIndexEntry]):
|
||||
"""批量添加或更新"""
|
||||
with self.lock:
|
||||
for entry in entries:
|
||||
@@ -191,18 +191,18 @@ class MemoryMetadataIndex:
|
||||
|
||||
def search(
|
||||
self,
|
||||
memory_types: Optional[List[str]] = None,
|
||||
subjects: Optional[List[str]] = None,
|
||||
keywords: Optional[List[str]] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
importance_min: Optional[int] = None,
|
||||
importance_max: Optional[int] = None,
|
||||
created_after: Optional[float] = None,
|
||||
created_before: Optional[float] = None,
|
||||
user_id: Optional[str] = None,
|
||||
limit: Optional[int] = None,
|
||||
memory_types: list[str] | None = None,
|
||||
subjects: list[str] | None = None,
|
||||
keywords: list[str] | None = None,
|
||||
tags: list[str] | None = None,
|
||||
importance_min: int | None = None,
|
||||
importance_max: int | None = None,
|
||||
created_after: float | None = None,
|
||||
created_before: float | None = None,
|
||||
user_id: str | None = None,
|
||||
limit: int | None = None,
|
||||
flexible_mode: bool = True, # 新增:灵活匹配模式
|
||||
) -> List[str]:
|
||||
) -> list[str]:
|
||||
"""
|
||||
搜索符合条件的记忆ID列表(支持模糊匹配)
|
||||
|
||||
@@ -237,14 +237,14 @@ class MemoryMetadataIndex:
|
||||
|
||||
def _search_flexible(
|
||||
self,
|
||||
memory_types: Optional[List[str]] = None,
|
||||
subjects: Optional[List[str]] = None,
|
||||
created_after: Optional[float] = None,
|
||||
created_before: Optional[float] = None,
|
||||
user_id: Optional[str] = None,
|
||||
limit: Optional[int] = None,
|
||||
memory_types: list[str] | None = None,
|
||||
subjects: list[str] | None = None,
|
||||
created_after: float | None = None,
|
||||
created_before: float | None = None,
|
||||
user_id: str | None = None,
|
||||
limit: int | None = None,
|
||||
**kwargs, # 接受但不使用的参数
|
||||
) -> List[str]:
|
||||
) -> list[str]:
|
||||
"""
|
||||
灵活搜索模式:2/4项匹配即可,支持部分匹配
|
||||
|
||||
@@ -374,20 +374,20 @@ class MemoryMetadataIndex:
|
||||
|
||||
def _search_strict(
|
||||
self,
|
||||
memory_types: Optional[List[str]] = None,
|
||||
subjects: Optional[List[str]] = None,
|
||||
keywords: Optional[List[str]] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
importance_min: Optional[int] = None,
|
||||
importance_max: Optional[int] = None,
|
||||
created_after: Optional[float] = None,
|
||||
created_before: Optional[float] = None,
|
||||
user_id: Optional[str] = None,
|
||||
limit: Optional[int] = None,
|
||||
) -> List[str]:
|
||||
memory_types: list[str] | None = None,
|
||||
subjects: list[str] | None = None,
|
||||
keywords: list[str] | None = None,
|
||||
tags: list[str] | None = None,
|
||||
importance_min: int | None = None,
|
||||
importance_max: int | None = None,
|
||||
created_after: float | None = None,
|
||||
created_before: float | None = None,
|
||||
user_id: str | None = None,
|
||||
limit: int | None = None,
|
||||
) -> list[str]:
|
||||
"""严格搜索模式(原有逻辑)"""
|
||||
# 初始候选集(所有记忆)
|
||||
candidate_ids: Optional[Set[str]] = None
|
||||
candidate_ids: set[str] | None = None
|
||||
|
||||
# 用户过滤(必选)
|
||||
if user_id:
|
||||
@@ -471,11 +471,11 @@ class MemoryMetadataIndex:
|
||||
|
||||
return result_ids
|
||||
|
||||
def get_entry(self, memory_id: str) -> Optional[MemoryMetadataIndexEntry]:
|
||||
def get_entry(self, memory_id: str) -> MemoryMetadataIndexEntry | None:
|
||||
"""获取单个索引条目"""
|
||||
return self.index.get(memory_id)
|
||||
|
||||
def get_stats(self) -> Dict[str, Any]:
|
||||
def get_stats(self) -> dict[str, Any]:
|
||||
"""获取索引统计信息"""
|
||||
with self.lock:
|
||||
return {
|
||||
|
||||
@@ -1,11 +1,10 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""记忆检索查询规划器"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any
|
||||
|
||||
import orjson
|
||||
|
||||
@@ -21,16 +20,16 @@ class MemoryQueryPlan:
|
||||
"""查询规划结果"""
|
||||
|
||||
semantic_query: str
|
||||
memory_types: List[MemoryType] = field(default_factory=list)
|
||||
subject_includes: List[str] = field(default_factory=list)
|
||||
object_includes: List[str] = field(default_factory=list)
|
||||
required_keywords: List[str] = field(default_factory=list)
|
||||
optional_keywords: List[str] = field(default_factory=list)
|
||||
owner_filters: List[str] = field(default_factory=list)
|
||||
memory_types: list[MemoryType] = field(default_factory=list)
|
||||
subject_includes: list[str] = field(default_factory=list)
|
||||
object_includes: list[str] = field(default_factory=list)
|
||||
required_keywords: list[str] = field(default_factory=list)
|
||||
optional_keywords: list[str] = field(default_factory=list)
|
||||
owner_filters: list[str] = field(default_factory=list)
|
||||
recency_preference: str = "any"
|
||||
limit: int = 10
|
||||
emphasis: Optional[str] = None
|
||||
raw_plan: Dict[str, Any] = field(default_factory=dict)
|
||||
emphasis: str | None = None
|
||||
raw_plan: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
def ensure_defaults(self, fallback_query: str, default_limit: int) -> None:
|
||||
if not self.semantic_query:
|
||||
@@ -46,11 +45,11 @@ class MemoryQueryPlan:
|
||||
class MemoryQueryPlanner:
|
||||
"""基于小模型的记忆检索查询规划器"""
|
||||
|
||||
def __init__(self, planner_model: Optional[LLMRequest], default_limit: int = 10):
|
||||
def __init__(self, planner_model: LLMRequest | None, default_limit: int = 10):
|
||||
self.model = planner_model
|
||||
self.default_limit = default_limit
|
||||
|
||||
async def plan_query(self, query_text: str, context: Dict[str, Any]) -> MemoryQueryPlan:
|
||||
async def plan_query(self, query_text: str, context: dict[str, Any]) -> MemoryQueryPlan:
|
||||
if not self.model:
|
||||
logger.debug("未提供查询规划模型,使用默认规划")
|
||||
return self._default_plan(query_text)
|
||||
@@ -82,10 +81,10 @@ class MemoryQueryPlanner:
|
||||
def _default_plan(self, query_text: str) -> MemoryQueryPlan:
|
||||
return MemoryQueryPlan(semantic_query=query_text, limit=self.default_limit)
|
||||
|
||||
def _parse_plan_dict(self, data: Dict[str, Any], fallback_query: str) -> MemoryQueryPlan:
|
||||
def _parse_plan_dict(self, data: dict[str, Any], fallback_query: str) -> MemoryQueryPlan:
|
||||
semantic_query = self._safe_str(data.get("semantic_query")) or fallback_query
|
||||
|
||||
def _collect_list(key: str) -> List[str]:
|
||||
def _collect_list(key: str) -> list[str]:
|
||||
value = data.get(key)
|
||||
if isinstance(value, str):
|
||||
return [value]
|
||||
@@ -94,7 +93,7 @@ class MemoryQueryPlanner:
|
||||
return []
|
||||
|
||||
memory_type_values = _collect_list("memory_types")
|
||||
memory_types: List[MemoryType] = []
|
||||
memory_types: list[MemoryType] = []
|
||||
for item in memory_type_values:
|
||||
if not item:
|
||||
continue
|
||||
@@ -123,7 +122,7 @@ class MemoryQueryPlanner:
|
||||
)
|
||||
return plan
|
||||
|
||||
def _build_prompt(self, query_text: str, context: Dict[str, Any]) -> str:
|
||||
def _build_prompt(self, query_text: str, context: dict[str, Any]) -> str:
|
||||
participants = context.get("participants") or context.get("speaker_names") or []
|
||||
if isinstance(participants, str):
|
||||
participants = [participants]
|
||||
@@ -206,7 +205,7 @@ class MemoryQueryPlanner:
|
||||
请直接输出符合要求的 JSON 对象,禁止添加额外文本或 Markdown 代码块。
|
||||
"""
|
||||
|
||||
def _extract_json_payload(self, response: str) -> Optional[str]:
|
||||
def _extract_json_payload(self, response: str) -> str | None:
|
||||
if not response:
|
||||
return None
|
||||
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
精准记忆系统核心模块
|
||||
1. 基于文档设计的高效记忆构建、存储与召回优化系统,覆盖构建、向量化与多阶段检索全流程。
|
||||
@@ -6,26 +5,27 @@
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
import orjson
|
||||
import re
|
||||
import hashlib
|
||||
from typing import Dict, List, Optional, Set, Any, TYPE_CHECKING
|
||||
import re
|
||||
import time
|
||||
from dataclasses import asdict, dataclass
|
||||
from datetime import datetime, timedelta
|
||||
from dataclasses import dataclass, asdict
|
||||
from enum import Enum
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import orjson
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import model_config, global_config
|
||||
from src.chat.memory_system.memory_chunk import MemoryChunk
|
||||
from src.chat.memory_system.memory_builder import MemoryBuilder, MemoryExtractionError
|
||||
from src.chat.memory_system.memory_chunk import MemoryChunk
|
||||
from src.chat.memory_system.memory_fusion import MemoryFusionEngine
|
||||
from src.chat.memory_system.memory_query_planner import MemoryQueryPlanner
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config, model_config
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from src.common.data_models.database_data_model import DatabaseMessages
|
||||
from src.chat.memory_system.memory_forgetting_engine import MemoryForgettingEngine
|
||||
from src.common.data_models.database_data_model import DatabaseMessages
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -121,7 +121,7 @@ class MemorySystemConfig:
|
||||
class MemorySystem:
|
||||
"""精准记忆系统核心类"""
|
||||
|
||||
def __init__(self, llm_model: Optional[LLMRequest] = None, config: Optional[MemorySystemConfig] = None):
|
||||
def __init__(self, llm_model: LLMRequest | None = None, config: MemorySystemConfig | None = None):
|
||||
self.config = config or MemorySystemConfig.from_global_config()
|
||||
self.llm_model = llm_model
|
||||
self.status = MemorySystemStatus.INITIALIZING
|
||||
@@ -131,7 +131,7 @@ class MemorySystem:
|
||||
self.fusion_engine: MemoryFusionEngine = None
|
||||
self.unified_storage = None # 统一存储系统
|
||||
self.query_planner: MemoryQueryPlanner = None
|
||||
self.forgetting_engine: Optional[MemoryForgettingEngine] = None
|
||||
self.forgetting_engine: MemoryForgettingEngine | None = None
|
||||
|
||||
# LLM模型
|
||||
self.value_assessment_model: LLMRequest = None
|
||||
@@ -143,10 +143,10 @@ class MemorySystem:
|
||||
self.last_retrieval_time = None
|
||||
|
||||
# 构建节流记录
|
||||
self._last_memory_build_times: Dict[str, float] = {}
|
||||
self._last_memory_build_times: dict[str, float] = {}
|
||||
|
||||
# 记忆指纹缓存,用于快速检测重复记忆
|
||||
self._memory_fingerprints: Dict[str, str] = {}
|
||||
self._memory_fingerprints: dict[str, str] = {}
|
||||
|
||||
logger.info("MemorySystem 初始化开始")
|
||||
|
||||
@@ -210,7 +210,7 @@ class MemorySystem:
|
||||
raise
|
||||
|
||||
# 初始化遗忘引擎
|
||||
from src.chat.memory_system.memory_forgetting_engine import MemoryForgettingEngine, ForgettingConfig
|
||||
from src.chat.memory_system.memory_forgetting_engine import ForgettingConfig, MemoryForgettingEngine
|
||||
|
||||
# 从全局配置创建遗忘引擎配置
|
||||
forgetting_config = ForgettingConfig(
|
||||
@@ -241,7 +241,7 @@ class MemorySystem:
|
||||
self.forgetting_engine = MemoryForgettingEngine(forgetting_config)
|
||||
|
||||
planner_task_config = getattr(model_config.model_task_config, "utils_small", None)
|
||||
planner_model: Optional[LLMRequest] = None
|
||||
planner_model: LLMRequest | None = None
|
||||
try:
|
||||
planner_model = LLMRequest(model_set=planner_task_config, request_type="memory.query_planner")
|
||||
except Exception as planner_exc:
|
||||
@@ -261,8 +261,8 @@ class MemorySystem:
|
||||
raise
|
||||
|
||||
async def retrieve_memories_for_building(
|
||||
self, query_text: str, user_id: Optional[str] = None, context: Optional[Dict[str, Any]] = None, limit: int = 5
|
||||
) -> List[MemoryChunk]:
|
||||
self, query_text: str, user_id: str | None = None, context: dict[str, Any] | None = None, limit: int = 5
|
||||
) -> list[MemoryChunk]:
|
||||
"""在构建记忆时检索相关记忆,使用统一存储系统
|
||||
|
||||
Args:
|
||||
@@ -302,8 +302,8 @@ class MemorySystem:
|
||||
return []
|
||||
|
||||
async def build_memory_from_conversation(
|
||||
self, conversation_text: str, context: Dict[str, Any], timestamp: Optional[float] = None
|
||||
) -> List[MemoryChunk]:
|
||||
self, conversation_text: str, context: dict[str, Any], timestamp: float | None = None
|
||||
) -> list[MemoryChunk]:
|
||||
"""从对话中构建记忆
|
||||
|
||||
Args:
|
||||
@@ -318,8 +318,8 @@ class MemorySystem:
|
||||
self.status = MemorySystemStatus.BUILDING
|
||||
start_time = time.time()
|
||||
|
||||
build_scope_key: Optional[str] = None
|
||||
build_marker_time: Optional[float] = None
|
||||
build_scope_key: str | None = None
|
||||
build_marker_time: float | None = None
|
||||
|
||||
try:
|
||||
normalized_context = self._normalize_context(context, GLOBAL_MEMORY_SCOPE, timestamp)
|
||||
@@ -408,7 +408,7 @@ class MemorySystem:
|
||||
logger.error(f"❌ 记忆构建失败: {e}", exc_info=True)
|
||||
raise
|
||||
|
||||
def _log_memory_preview(self, memories: List[MemoryChunk]) -> None:
|
||||
def _log_memory_preview(self, memories: list[MemoryChunk]) -> None:
|
||||
"""在控制台输出记忆预览,便于人工检查"""
|
||||
if not memories:
|
||||
logger.info("📝 本次未生成新的记忆")
|
||||
@@ -425,12 +425,12 @@ class MemorySystem:
|
||||
f"置信度={memory.metadata.confidence.name} | 内容={text}"
|
||||
)
|
||||
|
||||
async def _collect_fusion_candidates(self, new_memories: List[MemoryChunk]) -> List[MemoryChunk]:
|
||||
async def _collect_fusion_candidates(self, new_memories: list[MemoryChunk]) -> list[MemoryChunk]:
|
||||
"""收集与新记忆相似的现有记忆,便于融合去重"""
|
||||
if not new_memories:
|
||||
return []
|
||||
|
||||
candidate_ids: Set[str] = set()
|
||||
candidate_ids: set[str] = set()
|
||||
new_memory_ids = {memory.memory_id for memory in new_memories if memory and getattr(memory, "memory_id", None)}
|
||||
|
||||
# 基于指纹的直接匹配
|
||||
@@ -493,7 +493,7 @@ class MemorySystem:
|
||||
continue
|
||||
candidate_ids.add(memory_id)
|
||||
|
||||
existing_candidates: List[MemoryChunk] = []
|
||||
existing_candidates: list[MemoryChunk] = []
|
||||
cache = self.unified_storage.memory_cache if self.unified_storage else {}
|
||||
for candidate_id in candidate_ids:
|
||||
if candidate_id in new_memory_ids:
|
||||
@@ -511,7 +511,7 @@ class MemorySystem:
|
||||
|
||||
return existing_candidates
|
||||
|
||||
async def process_conversation_memory(self, context: Dict[str, Any]) -> Dict[str, Any]:
|
||||
async def process_conversation_memory(self, context: dict[str, Any]) -> dict[str, Any]:
|
||||
"""对外暴露的对话记忆处理接口,仅依赖上下文信息"""
|
||||
start_time = time.time()
|
||||
|
||||
@@ -559,12 +559,12 @@ class MemorySystem:
|
||||
|
||||
async def retrieve_relevant_memories(
|
||||
self,
|
||||
query_text: Optional[str] = None,
|
||||
user_id: Optional[str] = None,
|
||||
context: Optional[Dict[str, Any]] = None,
|
||||
query_text: str | None = None,
|
||||
user_id: str | None = None,
|
||||
context: dict[str, Any] | None = None,
|
||||
limit: int = 5,
|
||||
**kwargs,
|
||||
) -> List[MemoryChunk]:
|
||||
) -> list[MemoryChunk]:
|
||||
"""检索相关记忆(三阶段召回:元数据粗筛 → 向量精筛 → 综合重排)"""
|
||||
raw_query = query_text or kwargs.get("query")
|
||||
if not raw_query:
|
||||
@@ -750,7 +750,7 @@ class MemorySystem:
|
||||
raise
|
||||
|
||||
@staticmethod
|
||||
def _extract_json_payload(response: str) -> Optional[str]:
|
||||
def _extract_json_payload(response: str) -> str | None:
|
||||
"""从模型响应中提取JSON部分,兼容Markdown代码块等格式"""
|
||||
if not response:
|
||||
return None
|
||||
@@ -773,10 +773,10 @@ class MemorySystem:
|
||||
return stripped if stripped.startswith("{") and stripped.endswith("}") else None
|
||||
|
||||
def _normalize_context(
|
||||
self, raw_context: Optional[Dict[str, Any]], user_id: Optional[str], timestamp: Optional[float]
|
||||
) -> Dict[str, Any]:
|
||||
self, raw_context: dict[str, Any] | None, user_id: str | None, timestamp: float | None
|
||||
) -> dict[str, Any]:
|
||||
"""标准化上下文,确保必备字段存在且格式正确"""
|
||||
context: Dict[str, Any] = {}
|
||||
context: dict[str, Any] = {}
|
||||
if raw_context:
|
||||
try:
|
||||
context = dict(raw_context)
|
||||
@@ -822,7 +822,7 @@ class MemorySystem:
|
||||
|
||||
return context
|
||||
|
||||
async def _build_enhanced_query_context(self, raw_query: str, normalized_context: Dict[str, Any]) -> Dict[str, Any]:
|
||||
async def _build_enhanced_query_context(self, raw_query: str, normalized_context: dict[str, Any]) -> dict[str, Any]:
|
||||
"""构建包含未读消息综合上下文的增强查询上下文
|
||||
|
||||
Args:
|
||||
@@ -861,7 +861,7 @@ class MemorySystem:
|
||||
|
||||
return enhanced_context
|
||||
|
||||
async def _collect_unread_messages_context(self, stream_id: str) -> Optional[Dict[str, Any]]:
|
||||
async def _collect_unread_messages_context(self, stream_id: str) -> dict[str, Any] | None:
|
||||
"""收集未读消息的综合上下文信息
|
||||
|
||||
Args:
|
||||
@@ -953,7 +953,7 @@ class MemorySystem:
|
||||
logger.warning(f"收集未读消息上下文失败: {e}", exc_info=True)
|
||||
return None
|
||||
|
||||
def _build_unread_context_summary(self, messages_summary: List[Dict[str, Any]]) -> str:
|
||||
def _build_unread_context_summary(self, messages_summary: list[dict[str, Any]]) -> str:
|
||||
"""构建未读消息的文本摘要
|
||||
|
||||
Args:
|
||||
@@ -974,7 +974,7 @@ class MemorySystem:
|
||||
|
||||
return " | ".join(summary_parts)
|
||||
|
||||
async def _resolve_conversation_context(self, fallback_text: str, context: Optional[Dict[str, Any]]) -> str:
|
||||
async def _resolve_conversation_context(self, fallback_text: str, context: dict[str, Any] | None) -> str:
|
||||
"""使用 stream_id 历史消息和相关记忆充实对话文本,默认回退到传入文本"""
|
||||
if not context:
|
||||
return fallback_text
|
||||
@@ -1043,11 +1043,11 @@ class MemorySystem:
|
||||
# 回退到传入文本
|
||||
return fallback_text
|
||||
|
||||
def _get_build_scope_key(self, context: Dict[str, Any], user_id: Optional[str]) -> Optional[str]:
|
||||
def _get_build_scope_key(self, context: dict[str, Any], user_id: str | None) -> str | None:
|
||||
"""确定用于节流控制的记忆构建作用域"""
|
||||
return "global_scope"
|
||||
|
||||
def _determine_history_limit(self, context: Dict[str, Any]) -> int:
|
||||
def _determine_history_limit(self, context: dict[str, Any]) -> int:
|
||||
"""确定历史消息获取数量,限制在30-50之间"""
|
||||
default_limit = 40
|
||||
candidate = context.get("history_limit") or context.get("history_window") or context.get("memory_history_limit")
|
||||
@@ -1065,12 +1065,12 @@ class MemorySystem:
|
||||
|
||||
return history_limit
|
||||
|
||||
def _format_history_messages(self, messages: List["DatabaseMessages"]) -> Optional[str]:
|
||||
def _format_history_messages(self, messages: list["DatabaseMessages"]) -> str | None:
|
||||
"""将历史消息格式化为可供LLM处理的多轮对话文本"""
|
||||
if not messages:
|
||||
return None
|
||||
|
||||
lines: List[str] = []
|
||||
lines: list[str] = []
|
||||
for msg in messages:
|
||||
try:
|
||||
content = getattr(msg, "processed_plain_text", None) or getattr(msg, "display_message", None)
|
||||
@@ -1105,7 +1105,7 @@ class MemorySystem:
|
||||
|
||||
return "\n".join(lines) if lines else None
|
||||
|
||||
async def _assess_information_value(self, text: str, context: Dict[str, Any]) -> float:
|
||||
async def _assess_information_value(self, text: str, context: dict[str, Any]) -> float:
|
||||
"""评估信息价值
|
||||
|
||||
Args:
|
||||
@@ -1201,7 +1201,7 @@ class MemorySystem:
|
||||
logger.error(f"信息价值评估失败: {e}", exc_info=True)
|
||||
return 0.5 # 默认中等价值
|
||||
|
||||
async def _store_memories_unified(self, memory_chunks: List[MemoryChunk]) -> int:
|
||||
async def _store_memories_unified(self, memory_chunks: list[MemoryChunk]) -> int:
|
||||
"""使用统一存储系统存储记忆块"""
|
||||
if not memory_chunks or not self.unified_storage:
|
||||
return 0
|
||||
@@ -1222,7 +1222,7 @@ class MemorySystem:
|
||||
return 0
|
||||
|
||||
# 保留原有方法以兼容旧代码
|
||||
async def _store_memories(self, memory_chunks: List[MemoryChunk]) -> int:
|
||||
async def _store_memories(self, memory_chunks: list[MemoryChunk]) -> int:
|
||||
"""兼容性方法:重定向到统一存储"""
|
||||
return await self._store_memories_unified(memory_chunks)
|
||||
|
||||
@@ -1271,7 +1271,7 @@ class MemorySystem:
|
||||
key = self._fingerprint_key(memory.user_id, fingerprint)
|
||||
self._memory_fingerprints[key] = memory.memory_id
|
||||
|
||||
def _register_memory_fingerprints(self, memories: List[MemoryChunk]) -> None:
|
||||
def _register_memory_fingerprints(self, memories: list[MemoryChunk]) -> None:
|
||||
for memory in memories:
|
||||
fingerprint = self._build_memory_fingerprint(memory)
|
||||
key = self._fingerprint_key(memory.user_id, fingerprint)
|
||||
@@ -1302,9 +1302,9 @@ class MemorySystem:
|
||||
|
||||
@staticmethod
|
||||
def _fingerprint_key(user_id: str, fingerprint: str) -> str:
|
||||
return f"{str(user_id)}:{fingerprint}"
|
||||
return f"{user_id!s}:{fingerprint}"
|
||||
|
||||
def get_system_stats(self) -> Dict[str, Any]:
|
||||
def get_system_stats(self) -> dict[str, Any]:
|
||||
"""获取系统统计信息"""
|
||||
return {
|
||||
"status": self.status.value,
|
||||
@@ -1314,7 +1314,7 @@ class MemorySystem:
|
||||
"config": asdict(self.config),
|
||||
}
|
||||
|
||||
def _compute_memory_score(self, query_text: str, memory: MemoryChunk, context: Dict[str, Any]) -> float:
|
||||
def _compute_memory_score(self, query_text: str, memory: MemoryChunk, context: dict[str, Any]) -> float:
|
||||
"""根据查询和上下文为记忆计算匹配分数"""
|
||||
tokens_query = self._tokenize_text(query_text)
|
||||
tokens_memory = self._tokenize_text(memory.text_content)
|
||||
@@ -1338,7 +1338,7 @@ class MemorySystem:
|
||||
final_score = base_score * 0.7 + keyword_overlap * 0.15 + importance_boost + confidence_boost
|
||||
return max(0.0, min(1.0, final_score))
|
||||
|
||||
def _tokenize_text(self, text: str) -> Set[str]:
|
||||
def _tokenize_text(self, text: str) -> set[str]:
|
||||
"""简单分词,兼容中英文"""
|
||||
if not text:
|
||||
return set()
|
||||
@@ -1450,7 +1450,7 @@ def get_memory_system() -> MemorySystem:
|
||||
return memory_system
|
||||
|
||||
|
||||
async def initialize_memory_system(llm_model: Optional[LLMRequest] = None):
|
||||
async def initialize_memory_system(llm_model: LLMRequest | None = None):
|
||||
"""初始化全局记忆系统"""
|
||||
global memory_system
|
||||
if memory_system is None:
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
基于Vector DB的统一记忆存储系统 V2
|
||||
使用ChromaDB作为底层存储,替代JSON存储方式
|
||||
@@ -11,20 +10,21 @@
|
||||
- 自动清理过期记忆
|
||||
"""
|
||||
|
||||
import time
|
||||
import orjson
|
||||
import asyncio
|
||||
import threading
|
||||
from typing import Dict, List, Optional, Tuple, Any
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.common.vector_db import vector_db_service
|
||||
from src.chat.utils.utils import get_embedding
|
||||
from src.chat.memory_system.memory_chunk import MemoryChunk, ConfidenceLevel, ImportanceLevel
|
||||
import orjson
|
||||
|
||||
from src.chat.memory_system.memory_chunk import ConfidenceLevel, ImportanceLevel, MemoryChunk
|
||||
from src.chat.memory_system.memory_forgetting_engine import MemoryForgettingEngine
|
||||
from src.chat.memory_system.memory_metadata_index import MemoryMetadataIndex, MemoryMetadataIndexEntry
|
||||
from src.chat.utils.utils import get_embedding
|
||||
from src.common.logger import get_logger
|
||||
from src.common.vector_db import vector_db_service
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -32,7 +32,7 @@ logger = get_logger(__name__)
|
||||
_ENUM_MAPPINGS_CACHE = {}
|
||||
|
||||
|
||||
def _build_enum_mapping(enum_class: type) -> Dict[str, Any]:
|
||||
def _build_enum_mapping(enum_class: type) -> dict[str, Any]:
|
||||
"""构建枚举类的完整映射表
|
||||
|
||||
Args:
|
||||
@@ -145,7 +145,7 @@ class VectorMemoryStorage:
|
||||
|
||||
"""基于Vector DB的记忆存储系统"""
|
||||
|
||||
def __init__(self, config: Optional[VectorStorageConfig] = None):
|
||||
def __init__(self, config: VectorStorageConfig | None = None):
|
||||
# 默认从全局配置读取,如果没有传入config
|
||||
if config is None:
|
||||
try:
|
||||
@@ -163,15 +163,15 @@ class VectorMemoryStorage:
|
||||
self.vector_db_service = vector_db_service
|
||||
|
||||
# 内存缓存
|
||||
self.memory_cache: Dict[str, MemoryChunk] = {}
|
||||
self.cache_timestamps: Dict[str, float] = {}
|
||||
self.memory_cache: dict[str, MemoryChunk] = {}
|
||||
self.cache_timestamps: dict[str, float] = {}
|
||||
self._cache = self.memory_cache # 别名,兼容旧代码
|
||||
|
||||
# 元数据索引管理器(JSON文件索引)
|
||||
self.metadata_index = MemoryMetadataIndex()
|
||||
|
||||
# 遗忘引擎
|
||||
self.forgetting_engine: Optional[MemoryForgettingEngine] = None
|
||||
self.forgetting_engine: MemoryForgettingEngine | None = None
|
||||
if self.config.enable_forgetting:
|
||||
self.forgetting_engine = MemoryForgettingEngine()
|
||||
|
||||
@@ -267,7 +267,7 @@ class VectorMemoryStorage:
|
||||
except Exception as e:
|
||||
logger.error(f"自动清理失败: {e}")
|
||||
|
||||
def _memory_to_vector_format(self, memory: MemoryChunk) -> Dict[str, Any]:
|
||||
def _memory_to_vector_format(self, memory: MemoryChunk) -> dict[str, Any]:
|
||||
"""将MemoryChunk转换为向量存储格式"""
|
||||
try:
|
||||
# 获取memory_id
|
||||
@@ -323,7 +323,7 @@ class VectorMemoryStorage:
|
||||
logger.error(f"转换记忆 {memory_id} 到向量格式失败: {e}", exc_info=True)
|
||||
raise
|
||||
|
||||
def _vector_result_to_memory(self, document: str, metadata: Dict[str, Any]) -> Optional[MemoryChunk]:
|
||||
def _vector_result_to_memory(self, document: str, metadata: dict[str, Any]) -> MemoryChunk | None:
|
||||
"""将Vector DB结果转换为MemoryChunk"""
|
||||
try:
|
||||
# 从元数据中恢复完整记忆
|
||||
@@ -440,7 +440,7 @@ class VectorMemoryStorage:
|
||||
logger.warning(f"不支持的{enum_class.__name__}值类型: {type(value)},使用默认值")
|
||||
return default
|
||||
|
||||
def _get_from_cache(self, memory_id: str) -> Optional[MemoryChunk]:
|
||||
def _get_from_cache(self, memory_id: str) -> MemoryChunk | None:
|
||||
"""从缓存获取记忆"""
|
||||
if not self.config.enable_caching:
|
||||
return None
|
||||
@@ -472,7 +472,7 @@ class VectorMemoryStorage:
|
||||
self.memory_cache[memory_id] = memory
|
||||
self.cache_timestamps[memory_id] = time.time()
|
||||
|
||||
async def store_memories(self, memories: List[MemoryChunk]) -> int:
|
||||
async def store_memories(self, memories: list[MemoryChunk]) -> int:
|
||||
"""批量存储记忆"""
|
||||
if not memories:
|
||||
return 0
|
||||
@@ -603,11 +603,11 @@ class VectorMemoryStorage:
|
||||
self,
|
||||
query_text: str,
|
||||
limit: int = 10,
|
||||
similarity_threshold: Optional[float] = None,
|
||||
filters: Optional[Dict[str, Any]] = None,
|
||||
similarity_threshold: float | None = None,
|
||||
filters: dict[str, Any] | None = None,
|
||||
# 新增:元数据过滤参数(用于JSON索引粗筛)
|
||||
metadata_filters: Optional[Dict[str, Any]] = None,
|
||||
) -> List[Tuple[MemoryChunk, float]]:
|
||||
metadata_filters: dict[str, Any] | None = None,
|
||||
) -> list[tuple[MemoryChunk, float]]:
|
||||
"""
|
||||
搜索相似记忆(混合索引模式)
|
||||
|
||||
@@ -632,7 +632,7 @@ class VectorMemoryStorage:
|
||||
|
||||
try:
|
||||
# === 阶段一:JSON元数据粗筛(可选) ===
|
||||
candidate_ids: Optional[List[str]] = None
|
||||
candidate_ids: list[str] | None = None
|
||||
if metadata_filters:
|
||||
logger.debug(f"[JSON元数据粗筛] 开始,过滤条件: {metadata_filters}")
|
||||
candidate_ids = self.metadata_index.search(
|
||||
@@ -746,7 +746,7 @@ class VectorMemoryStorage:
|
||||
logger.error(f"搜索相似记忆失败: {e}")
|
||||
return []
|
||||
|
||||
async def get_memory_by_id(self, memory_id: str) -> Optional[MemoryChunk]:
|
||||
async def get_memory_by_id(self, memory_id: str) -> MemoryChunk | None:
|
||||
"""根据ID获取记忆"""
|
||||
# 首先尝试从缓存获取
|
||||
memory = self._get_from_cache(memory_id)
|
||||
@@ -772,7 +772,7 @@ class VectorMemoryStorage:
|
||||
|
||||
return None
|
||||
|
||||
async def get_memories_by_filters(self, filters: Dict[str, Any], limit: int = 100) -> List[MemoryChunk]:
|
||||
async def get_memories_by_filters(self, filters: dict[str, Any], limit: int = 100) -> list[MemoryChunk]:
|
||||
"""根据过滤条件获取记忆"""
|
||||
try:
|
||||
results = vector_db_service.get(collection_name=self.config.memory_collection, where=filters, limit=limit)
|
||||
@@ -848,7 +848,7 @@ class VectorMemoryStorage:
|
||||
logger.error(f"删除记忆 {memory_id} 失败: {e}")
|
||||
return False
|
||||
|
||||
async def delete_memories_by_filters(self, filters: Dict[str, Any]) -> int:
|
||||
async def delete_memories_by_filters(self, filters: dict[str, Any]) -> int:
|
||||
"""根据过滤条件批量删除记忆"""
|
||||
try:
|
||||
# 先获取要删除的记忆ID
|
||||
@@ -880,7 +880,7 @@ class VectorMemoryStorage:
|
||||
logger.error(f"批量删除记忆失败: {e}")
|
||||
return 0
|
||||
|
||||
async def perform_forgetting_check(self) -> Dict[str, Any]:
|
||||
async def perform_forgetting_check(self) -> dict[str, Any]:
|
||||
"""执行遗忘检查"""
|
||||
if not self.forgetting_engine:
|
||||
return {"error": "遗忘引擎未启用"}
|
||||
@@ -925,7 +925,7 @@ class VectorMemoryStorage:
|
||||
logger.error(f"执行遗忘检查失败: {e}")
|
||||
return {"error": str(e)}
|
||||
|
||||
def get_storage_stats(self) -> Dict[str, Any]:
|
||||
def get_storage_stats(self) -> dict[str, Any]:
|
||||
"""获取存储统计信息"""
|
||||
try:
|
||||
current_total = vector_db_service.count(self.config.memory_collection)
|
||||
@@ -960,7 +960,7 @@ class VectorMemoryStorage:
|
||||
_global_vector_storage = None
|
||||
|
||||
|
||||
def get_vector_memory_storage(config: Optional[VectorStorageConfig] = None) -> VectorMemoryStorage:
|
||||
def get_vector_memory_storage(config: VectorStorageConfig | None = None) -> VectorMemoryStorage:
|
||||
"""获取全局Vector记忆存储实例"""
|
||||
global _global_vector_storage
|
||||
|
||||
@@ -974,15 +974,15 @@ def get_vector_memory_storage(config: Optional[VectorStorageConfig] = None) -> V
|
||||
class VectorMemoryStorageAdapter:
|
||||
"""适配器类,提供与原UnifiedMemoryStorage兼容的接口"""
|
||||
|
||||
def __init__(self, config: Optional[VectorStorageConfig] = None):
|
||||
def __init__(self, config: VectorStorageConfig | None = None):
|
||||
self.storage = VectorMemoryStorage(config)
|
||||
|
||||
async def store_memories(self, memories: List[MemoryChunk]) -> int:
|
||||
async def store_memories(self, memories: list[MemoryChunk]) -> int:
|
||||
return await self.storage.store_memories(memories)
|
||||
|
||||
async def search_similar_memories(
|
||||
self, query_text: str, limit: int = 10, scope_id: Optional[str] = None, filters: Optional[Dict[str, Any]] = None
|
||||
) -> List[Tuple[str, float]]:
|
||||
self, query_text: str, limit: int = 10, scope_id: str | None = None, filters: dict[str, Any] | None = None
|
||||
) -> list[tuple[str, float]]:
|
||||
results = await self.storage.search_similar_memories(query_text, limit, filters=filters)
|
||||
# 转换为原格式:(memory_id, similarity)
|
||||
return [
|
||||
@@ -990,7 +990,7 @@ class VectorMemoryStorageAdapter:
|
||||
for memory, similarity in results
|
||||
]
|
||||
|
||||
def get_stats(self) -> Dict[str, Any]:
|
||||
def get_stats(self) -> dict[str, Any]:
|
||||
return self.storage.get_storage_stats()
|
||||
|
||||
|
||||
|
||||
@@ -3,14 +3,14 @@
|
||||
提供统一的消息管理、上下文管理和流循环调度功能
|
||||
"""
|
||||
|
||||
from .message_manager import MessageManager, message_manager
|
||||
from .context_manager import SingleStreamContextManager
|
||||
from .distribution_manager import StreamLoopManager, stream_loop_manager
|
||||
from .message_manager import MessageManager, message_manager
|
||||
|
||||
__all__ = [
|
||||
"MessageManager",
|
||||
"message_manager",
|
||||
"SingleStreamContextManager",
|
||||
"StreamLoopManager",
|
||||
"message_manager",
|
||||
"stream_loop_manager",
|
||||
]
|
||||
|
||||
@@ -6,13 +6,14 @@
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from typing import Dict, List, Optional, Any
|
||||
from typing import Any
|
||||
|
||||
from src.chat.energy_system import energy_manager
|
||||
from src.common.data_models.database_data_model import DatabaseMessages
|
||||
from src.common.data_models.message_manager_data_model import StreamContext
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.common.data_models.database_data_model import DatabaseMessages
|
||||
from src.chat.energy_system import energy_manager
|
||||
|
||||
from .distribution_manager import stream_loop_manager
|
||||
|
||||
logger = get_logger("context_manager")
|
||||
@@ -21,7 +22,7 @@ logger = get_logger("context_manager")
|
||||
class SingleStreamContextManager:
|
||||
"""单流上下文管理器 - 每个实例只管理一个 stream 的上下文"""
|
||||
|
||||
def __init__(self, stream_id: str, context: StreamContext, max_context_size: Optional[int] = None):
|
||||
def __init__(self, stream_id: str, context: StreamContext, max_context_size: int | None = None):
|
||||
self.stream_id = stream_id
|
||||
self.context = context
|
||||
|
||||
@@ -66,7 +67,7 @@ class SingleStreamContextManager:
|
||||
logger.error(f"添加消息到单流上下文失败 {self.stream_id}: {e}", exc_info=True)
|
||||
return False
|
||||
|
||||
async def update_message(self, message_id: str, updates: Dict[str, Any]) -> bool:
|
||||
async def update_message(self, message_id: str, updates: dict[str, Any]) -> bool:
|
||||
"""更新上下文中的消息
|
||||
|
||||
Args:
|
||||
@@ -84,7 +85,7 @@ class SingleStreamContextManager:
|
||||
logger.error(f"更新单流上下文消息失败 {self.stream_id}/{message_id}: {e}", exc_info=True)
|
||||
return False
|
||||
|
||||
def get_messages(self, limit: Optional[int] = None, include_unread: bool = True) -> List[DatabaseMessages]:
|
||||
def get_messages(self, limit: int | None = None, include_unread: bool = True) -> list[DatabaseMessages]:
|
||||
"""获取上下文消息
|
||||
|
||||
Args:
|
||||
@@ -117,7 +118,7 @@ class SingleStreamContextManager:
|
||||
logger.error(f"获取单流上下文消息失败 {self.stream_id}: {e}", exc_info=True)
|
||||
return []
|
||||
|
||||
def get_unread_messages(self) -> List[DatabaseMessages]:
|
||||
def get_unread_messages(self) -> list[DatabaseMessages]:
|
||||
"""获取未读消息"""
|
||||
try:
|
||||
return self.context.get_unread_messages()
|
||||
@@ -125,7 +126,7 @@ class SingleStreamContextManager:
|
||||
logger.error(f"获取单流未读消息失败 {self.stream_id}: {e}", exc_info=True)
|
||||
return []
|
||||
|
||||
def mark_messages_as_read(self, message_ids: List[str]) -> bool:
|
||||
def mark_messages_as_read(self, message_ids: list[str]) -> bool:
|
||||
"""标记消息为已读"""
|
||||
try:
|
||||
if not hasattr(self.context, "mark_message_as_read"):
|
||||
@@ -168,7 +169,7 @@ class SingleStreamContextManager:
|
||||
logger.error(f"清空单流上下文失败 {self.stream_id}: {e}", exc_info=True)
|
||||
return False
|
||||
|
||||
def get_statistics(self) -> Dict[str, Any]:
|
||||
def get_statistics(self) -> dict[str, Any]:
|
||||
"""获取流统计信息"""
|
||||
try:
|
||||
current_time = time.time()
|
||||
@@ -285,7 +286,7 @@ class SingleStreamContextManager:
|
||||
logger.error(f"添加消息到单流上下文失败 (async) {self.stream_id}: {e}", exc_info=True)
|
||||
return False
|
||||
|
||||
async def update_message_async(self, message_id: str, updates: Dict[str, Any]) -> bool:
|
||||
async def update_message_async(self, message_id: str, updates: dict[str, Any]) -> bool:
|
||||
"""异步实现的 update_message:更新消息并在需要时 await 能量更新。"""
|
||||
try:
|
||||
self.context.update_message_info(message_id, **updates)
|
||||
@@ -327,7 +328,7 @@ class SingleStreamContextManager:
|
||||
"""更新流能量"""
|
||||
try:
|
||||
history_messages = self.context.get_history_messages(limit=self.max_context_size)
|
||||
messages: List[DatabaseMessages] = list(history_messages)
|
||||
messages: list[DatabaseMessages] = list(history_messages)
|
||||
|
||||
if include_unread:
|
||||
messages.extend(self.get_unread_messages())
|
||||
|
||||
@@ -5,12 +5,12 @@
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from typing import Dict, Optional, Any
|
||||
from typing import Any
|
||||
|
||||
from src.chat.chatter_manager import ChatterManager
|
||||
from src.chat.energy_system import energy_manager
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.chat.energy_system import energy_manager
|
||||
from src.chat.chatter_manager import ChatterManager
|
||||
from src.plugin_system.apis.chat_api import get_chat_manager
|
||||
|
||||
logger = get_logger("stream_loop_manager")
|
||||
@@ -19,13 +19,13 @@ logger = get_logger("stream_loop_manager")
|
||||
class StreamLoopManager:
|
||||
"""流循环管理器 - 每个流一个独立的无限循环任务"""
|
||||
|
||||
def __init__(self, max_concurrent_streams: Optional[int] = None):
|
||||
def __init__(self, max_concurrent_streams: int | None = None):
|
||||
# 流循环任务管理
|
||||
self.stream_loops: Dict[str, asyncio.Task] = {}
|
||||
self.stream_loops: dict[str, asyncio.Task] = {}
|
||||
self.loop_lock = asyncio.Lock()
|
||||
|
||||
# 统计信息
|
||||
self.stats: Dict[str, Any] = {
|
||||
self.stats: dict[str, Any] = {
|
||||
"active_streams": 0,
|
||||
"total_loops": 0,
|
||||
"total_process_cycles": 0,
|
||||
@@ -37,13 +37,13 @@ class StreamLoopManager:
|
||||
self.max_concurrent_streams = max_concurrent_streams or global_config.chat.max_concurrent_distributions
|
||||
|
||||
# 强制分发策略
|
||||
self.force_dispatch_unread_threshold: Optional[int] = getattr(
|
||||
self.force_dispatch_unread_threshold: int | None = getattr(
|
||||
global_config.chat, "force_dispatch_unread_threshold", 20
|
||||
)
|
||||
self.force_dispatch_min_interval: float = getattr(global_config.chat, "force_dispatch_min_interval", 0.1)
|
||||
|
||||
# Chatter管理器
|
||||
self.chatter_manager: Optional[ChatterManager] = None
|
||||
self.chatter_manager: ChatterManager | None = None
|
||||
|
||||
# 状态控制
|
||||
self.is_running = False
|
||||
@@ -212,7 +212,7 @@ class StreamLoopManager:
|
||||
|
||||
logger.info(f"流循环结束: {stream_id}")
|
||||
|
||||
async def _get_stream_context(self, stream_id: str) -> Optional[Any]:
|
||||
async def _get_stream_context(self, stream_id: str) -> Any | None:
|
||||
"""获取流上下文
|
||||
|
||||
Args:
|
||||
@@ -320,7 +320,7 @@ class StreamLoopManager:
|
||||
logger.debug(f"流 {stream_id} 使用默认间隔: {base_interval:.2f}s ({e})")
|
||||
return base_interval
|
||||
|
||||
def get_queue_status(self) -> Dict[str, Any]:
|
||||
def get_queue_status(self) -> dict[str, Any]:
|
||||
"""获取队列状态
|
||||
|
||||
Returns:
|
||||
@@ -374,14 +374,14 @@ class StreamLoopManager:
|
||||
except Exception:
|
||||
return 0
|
||||
|
||||
def _needs_force_dispatch_for_context(self, context: Any, unread_count: Optional[int] = None) -> bool:
|
||||
def _needs_force_dispatch_for_context(self, context: Any, unread_count: int | None = None) -> bool:
|
||||
if not self.force_dispatch_unread_threshold or self.force_dispatch_unread_threshold <= 0:
|
||||
return False
|
||||
|
||||
count = unread_count if unread_count is not None else self._get_unread_count(context)
|
||||
return count > self.force_dispatch_unread_threshold
|
||||
|
||||
def get_performance_summary(self) -> Dict[str, Any]:
|
||||
def get_performance_summary(self) -> dict[str, Any]:
|
||||
"""获取性能摘要
|
||||
|
||||
Returns:
|
||||
|
||||
@@ -6,19 +6,20 @@
|
||||
import asyncio
|
||||
import random
|
||||
import time
|
||||
from typing import Dict, Optional, Any, TYPE_CHECKING, List
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from src.chat.chatter_manager import ChatterManager
|
||||
from src.chat.message_receive.chat_stream import ChatStream
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.planner_actions.action_manager import ChatterActionManager
|
||||
from src.common.data_models.database_data_model import DatabaseMessages
|
||||
from src.common.data_models.message_manager_data_model import MessageManagerStats, StreamStats
|
||||
from src.chat.chatter_manager import ChatterManager
|
||||
from src.chat.planner_actions.action_manager import ChatterActionManager
|
||||
from .sleep_manager.sleep_manager import SleepManager
|
||||
from .sleep_manager.wakeup_manager import WakeUpManager
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.plugin_system.apis.chat_api import get_chat_manager
|
||||
|
||||
from .distribution_manager import stream_loop_manager
|
||||
from .sleep_manager.sleep_manager import SleepManager
|
||||
from .sleep_manager.wakeup_manager import WakeUpManager
|
||||
|
||||
if TYPE_CHECKING:
|
||||
pass
|
||||
@@ -32,7 +33,7 @@ class MessageManager:
|
||||
def __init__(self, check_interval: float = 5.0):
|
||||
self.check_interval = check_interval # 检查间隔(秒)
|
||||
self.is_running = False
|
||||
self.manager_task: Optional[asyncio.Task] = None
|
||||
self.manager_task: asyncio.Task | None = None
|
||||
|
||||
# 统计信息
|
||||
self.stats = MessageManagerStats()
|
||||
@@ -125,7 +126,7 @@ class MessageManager:
|
||||
except Exception as e:
|
||||
logger.error(f"更新消息 {message_id} 时发生错误: {e}")
|
||||
|
||||
async def bulk_update_messages(self, stream_id: str, updates: List[Dict[str, Any]]) -> int:
|
||||
async def bulk_update_messages(self, stream_id: str, updates: list[dict[str, Any]]) -> int:
|
||||
"""批量更新消息信息,降低更新频率"""
|
||||
if not updates:
|
||||
return 0
|
||||
@@ -214,7 +215,7 @@ class MessageManager:
|
||||
except Exception as e:
|
||||
logger.error(f"激活聊天流 {stream_id} 时发生错误: {e}")
|
||||
|
||||
def get_stream_stats(self, stream_id: str) -> Optional[StreamStats]:
|
||||
def get_stream_stats(self, stream_id: str) -> StreamStats | None:
|
||||
"""获取聊天流统计"""
|
||||
try:
|
||||
# 通过 ChatManager 获取 ChatStream
|
||||
@@ -243,7 +244,7 @@ class MessageManager:
|
||||
logger.error(f"获取聊天流 {stream_id} 统计时发生错误: {e}")
|
||||
return None
|
||||
|
||||
def get_manager_stats(self) -> Dict[str, Any]:
|
||||
def get_manager_stats(self) -> dict[str, Any]:
|
||||
"""获取管理器统计"""
|
||||
return {
|
||||
"total_streams": self.stats.total_streams,
|
||||
@@ -278,7 +279,7 @@ class MessageManager:
|
||||
except Exception as e:
|
||||
logger.error(f"清理不活跃聊天流时发生错误: {e}")
|
||||
|
||||
async def _check_and_handle_interruption(self, chat_stream: Optional[ChatStream] = None):
|
||||
async def _check_and_handle_interruption(self, chat_stream: ChatStream | None = None):
|
||||
"""检查并处理消息打断"""
|
||||
if not global_config.chat.interruption_enabled:
|
||||
return
|
||||
|
||||
@@ -1,12 +1,13 @@
|
||||
import asyncio
|
||||
import random
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Optional, TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
|
||||
from .notification_sender import NotificationSender
|
||||
from .sleep_state import SleepState, SleepContext
|
||||
from .sleep_state import SleepContext, SleepState
|
||||
from .time_checker import TimeChecker
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -92,7 +93,7 @@ class SleepManager:
|
||||
elif current_state == SleepState.WOKEN_UP:
|
||||
self._handle_woken_up(now, is_in_theoretical_sleep, wakeup_manager)
|
||||
|
||||
def _handle_awake_to_sleep(self, now: datetime, activity: Optional[str], wakeup_manager: Optional["WakeUpManager"]):
|
||||
def _handle_awake_to_sleep(self, now: datetime, activity: str | None, wakeup_manager: Optional["WakeUpManager"]):
|
||||
"""处理从“清醒”到“准备入睡”的状态转换。"""
|
||||
if activity:
|
||||
logger.info(f"进入理论休眠时间 '{activity}',开始进行睡眠决策...")
|
||||
@@ -181,7 +182,7 @@ class SleepManager:
|
||||
self,
|
||||
now: datetime,
|
||||
is_in_theoretical_sleep: bool,
|
||||
activity: Optional[str],
|
||||
activity: str | None,
|
||||
wakeup_manager: Optional["WakeUpManager"],
|
||||
):
|
||||
"""处理“正在睡觉”状态下的逻辑。"""
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
from datetime import date, datetime
|
||||
from enum import Enum, auto
|
||||
from datetime import datetime, date
|
||||
from typing import Optional
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.manager.local_store_manager import local_storage
|
||||
@@ -29,10 +28,10 @@ class SleepContext:
|
||||
def __init__(self):
|
||||
"""初始化睡眠上下文,并从本地存储加载初始状态。"""
|
||||
self.current_state: SleepState = SleepState.AWAKE
|
||||
self.sleep_buffer_end_time: Optional[datetime] = None
|
||||
self.sleep_buffer_end_time: datetime | None = None
|
||||
self.total_delayed_minutes_today: float = 0.0
|
||||
self.last_sleep_check_date: Optional[date] = None
|
||||
self.re_sleep_attempt_time: Optional[datetime] = None
|
||||
self.last_sleep_check_date: date | None = None
|
||||
self.re_sleep_attempt_time: datetime | None = None
|
||||
self.load()
|
||||
|
||||
def save(self):
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from datetime import datetime, time, timedelta
|
||||
from typing import Optional, List, Dict, Any
|
||||
import random
|
||||
from datetime import datetime, time, timedelta
|
||||
from typing import Any
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
@@ -37,11 +37,11 @@ class TimeChecker:
|
||||
return self._daily_sleep_offset, self._daily_wake_offset
|
||||
|
||||
@staticmethod
|
||||
def get_today_schedule() -> Optional[List[Dict[str, Any]]]:
|
||||
def get_today_schedule() -> list[dict[str, Any]] | None:
|
||||
"""从全局 ScheduleManager 获取今天的日程安排。"""
|
||||
return schedule_manager.today_schedule
|
||||
|
||||
def is_in_theoretical_sleep_time(self, now_time: time) -> tuple[bool, Optional[str]]:
|
||||
def is_in_theoretical_sleep_time(self, now_time: time) -> tuple[bool, str | None]:
|
||||
if global_config.sleep_system.sleep_by_schedule:
|
||||
if self.get_today_schedule():
|
||||
return self._is_in_schedule_sleep_time(now_time)
|
||||
@@ -50,7 +50,7 @@ class TimeChecker:
|
||||
else:
|
||||
return self._is_in_sleep_time(now_time)
|
||||
|
||||
def _is_in_schedule_sleep_time(self, now_time: time) -> tuple[bool, Optional[str]]:
|
||||
def _is_in_schedule_sleep_time(self, now_time: time) -> tuple[bool, str | None]:
|
||||
"""检查当前时间是否落在日程表的任何一个睡眠活动中"""
|
||||
sleep_keywords = ["休眠", "睡觉", "梦乡"]
|
||||
today_schedule = self.get_today_schedule()
|
||||
@@ -79,7 +79,7 @@ class TimeChecker:
|
||||
continue
|
||||
return False, None
|
||||
|
||||
def _is_in_sleep_time(self, now_time: time) -> tuple[bool, Optional[str]]:
|
||||
def _is_in_sleep_time(self, now_time: time) -> tuple[bool, str | None]:
|
||||
"""检查当前时间是否在固定的睡眠时间内(应用偏移量)"""
|
||||
try:
|
||||
start_time_str = global_config.sleep_system.fixed_sleep_time
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
import asyncio
|
||||
import time
|
||||
from typing import Optional, TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from src.chat.message_manager.sleep_manager.wakeup_context import WakeUpContext
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.chat.message_manager.sleep_manager.wakeup_context import WakeUpContext
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .sleep_manager import SleepManager
|
||||
@@ -27,9 +28,9 @@ class WakeUpManager:
|
||||
"""
|
||||
self.sleep_manager = sleep_manager
|
||||
self.context = WakeUpContext() # 使用新的上下文管理器
|
||||
self.angry_chat_id: Optional[str] = None
|
||||
self.angry_chat_id: str | None = None
|
||||
self.last_decay_time = time.time()
|
||||
self._decay_task: Optional[asyncio.Task] = None
|
||||
self._decay_task: asyncio.Task | None = None
|
||||
self.is_running = False
|
||||
self.last_log_time = 0
|
||||
self.log_interval = 30
|
||||
@@ -104,9 +105,7 @@ class WakeUpManager:
|
||||
logger.debug(f"唤醒度衰减: {old_value:.1f} -> {self.context.wakeup_value:.1f}")
|
||||
self.context.save()
|
||||
|
||||
def add_wakeup_value(
|
||||
self, is_private_chat: bool, is_mentioned: bool = False, chat_id: Optional[str] = None
|
||||
) -> bool:
|
||||
def add_wakeup_value(self, is_private_chat: bool, is_mentioned: bool = False, chat_id: str | None = None) -> bool:
|
||||
"""
|
||||
增加唤醒度值
|
||||
|
||||
|
||||
@@ -2,9 +2,8 @@ from src.chat.emoji_system.emoji_manager import get_emoji_manager
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from src.chat.message_receive.storage import MessageStorage
|
||||
|
||||
|
||||
__all__ = [
|
||||
"get_emoji_manager",
|
||||
"get_chat_manager",
|
||||
"MessageStorage",
|
||||
"get_chat_manager",
|
||||
"get_emoji_manager",
|
||||
]
|
||||
|
||||
@@ -1,25 +1,24 @@
|
||||
import traceback
|
||||
import os
|
||||
import re
|
||||
import traceback
|
||||
from typing import Any
|
||||
|
||||
from typing import Dict, Any, Optional
|
||||
from maim_message import UserInfo
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.mood.mood_manager import mood_manager # 导入情绪管理器
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager, ChatStream
|
||||
from src.chat.message_receive.message import MessageRecv, MessageRecvS4U
|
||||
from src.chat.message_receive.storage import MessageStorage
|
||||
from src.chat.message_manager import message_manager
|
||||
from src.chat.utils.prompt import Prompt, global_prompt_manager
|
||||
from src.plugin_system.core import component_registry, event_manager, global_announcement_manager
|
||||
from src.plugin_system.base import BaseCommand, EventType
|
||||
from src.mais4u.mais4u_chat.s4u_msg_processor import S4UMessageProcessor
|
||||
from src.chat.utils.utils import is_mentioned_bot_in_message
|
||||
|
||||
# 导入反注入系统
|
||||
from src.chat.antipromptinjector import initialize_anti_injector
|
||||
from src.chat.message_manager import message_manager
|
||||
from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager
|
||||
from src.chat.message_receive.message import MessageRecv, MessageRecvS4U
|
||||
from src.chat.message_receive.storage import MessageStorage
|
||||
from src.chat.utils.prompt import Prompt, global_prompt_manager
|
||||
from src.chat.utils.utils import is_mentioned_bot_in_message
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.mais4u.mais4u_chat.s4u_msg_processor import S4UMessageProcessor
|
||||
from src.mood.mood_manager import mood_manager # 导入情绪管理器
|
||||
from src.plugin_system.base import BaseCommand, EventType
|
||||
from src.plugin_system.core import component_registry, event_manager, global_announcement_manager
|
||||
|
||||
# 获取项目根目录(假设本文件在src/chat/message_receive/下,根目录为上上上级目录)
|
||||
PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../.."))
|
||||
@@ -219,7 +218,7 @@ class ChatBot:
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
try:
|
||||
await plus_command_instance.send_text(f"命令执行出错: {str(e)}")
|
||||
await plus_command_instance.send_text(f"命令执行出错: {e!s}")
|
||||
except Exception as send_error:
|
||||
logger.error(f"发送错误消息失败: {send_error}")
|
||||
|
||||
@@ -286,7 +285,7 @@ class ChatBot:
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
try:
|
||||
await command_instance.send_text(f"命令执行出错: {str(e)}")
|
||||
await command_instance.send_text(f"命令执行出错: {e!s}")
|
||||
except Exception as send_error:
|
||||
logger.error(f"发送错误消息失败: {send_error}")
|
||||
|
||||
@@ -338,7 +337,7 @@ class ChatBot:
|
||||
except Exception as e:
|
||||
logger.error(f"处理适配器响应时出错: {e}")
|
||||
|
||||
async def do_s4u(self, message_data: Dict[str, Any]):
|
||||
async def do_s4u(self, message_data: dict[str, Any]):
|
||||
message = MessageRecvS4U(message_data)
|
||||
group_info = message.message_info.group_info
|
||||
user_info = message.message_info.user_info
|
||||
@@ -359,7 +358,7 @@ class ChatBot:
|
||||
|
||||
return
|
||||
|
||||
async def message_process(self, message_data: Dict[str, Any]) -> None:
|
||||
async def message_process(self, message_data: dict[str, Any]) -> None:
|
||||
"""处理转化后的统一格式消息"""
|
||||
try:
|
||||
# 首先处理可能的切片消息重组
|
||||
@@ -458,7 +457,7 @@ class ChatBot:
|
||||
# TODO:暂不可用
|
||||
# 确认从接口发来的message是否有自定义的prompt模板信息
|
||||
if message.message_info.template_info and not message.message_info.template_info.template_default:
|
||||
template_group_name: Optional[str] = message.message_info.template_info.template_name # type: ignore
|
||||
template_group_name: str | None = message.message_info.template_info.template_name # type: ignore
|
||||
template_items = message.message_info.template_info.template_items
|
||||
async with global_prompt_manager.async_message_scope(template_group_name):
|
||||
if isinstance(template_items, dict):
|
||||
|
||||
@@ -1,17 +1,18 @@
|
||||
import asyncio
|
||||
import copy
|
||||
import hashlib
|
||||
import time
|
||||
import copy
|
||||
from typing import Dict, Optional, TYPE_CHECKING
|
||||
from rich.traceback import install
|
||||
from maim_message import GroupInfo, UserInfo
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from maim_message import GroupInfo, UserInfo
|
||||
from rich.traceback import install
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.dialects.sqlite import insert as sqlite_insert
|
||||
from sqlalchemy.dialects.mysql import insert as mysql_insert
|
||||
from src.common.database.sqlalchemy_models import ChatStreams # 新增导入
|
||||
from sqlalchemy.dialects.sqlite import insert as sqlite_insert
|
||||
|
||||
from src.common.database.sqlalchemy_database_api import get_db_session
|
||||
from src.common.database.sqlalchemy_models import ChatStreams # 新增导入
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config # 新增导入
|
||||
|
||||
# 避免循环导入,使用TYPE_CHECKING进行类型提示
|
||||
@@ -33,8 +34,8 @@ class ChatStream:
|
||||
stream_id: str,
|
||||
platform: str,
|
||||
user_info: UserInfo,
|
||||
group_info: Optional[GroupInfo] = None,
|
||||
data: Optional[dict] = None,
|
||||
group_info: GroupInfo | None = None,
|
||||
data: dict | None = None,
|
||||
):
|
||||
self.stream_id = stream_id
|
||||
self.platform = platform
|
||||
@@ -47,7 +48,7 @@ class ChatStream:
|
||||
|
||||
# 使用StreamContext替代ChatMessageContext
|
||||
from src.common.data_models.message_manager_data_model import StreamContext
|
||||
from src.plugin_system.base.component_types import ChatType, ChatMode
|
||||
from src.plugin_system.base.component_types import ChatMode, ChatType
|
||||
|
||||
# 创建StreamContext
|
||||
self.stream_context: StreamContext = StreamContext(
|
||||
@@ -133,11 +134,11 @@ class ChatStream:
|
||||
|
||||
# 恢复stream_context信息
|
||||
if "stream_context_chat_type" in data:
|
||||
from src.plugin_system.base.component_types import ChatType, ChatMode
|
||||
from src.plugin_system.base.component_types import ChatMode, ChatType
|
||||
|
||||
instance.stream_context.chat_type = ChatType(data["stream_context_chat_type"])
|
||||
if "stream_context_chat_mode" in data:
|
||||
from src.plugin_system.base.component_types import ChatType, ChatMode
|
||||
from src.plugin_system.base.component_types import ChatMode, ChatType
|
||||
|
||||
instance.stream_context.chat_mode = ChatMode(data["stream_context_chat_mode"])
|
||||
|
||||
@@ -163,9 +164,10 @@ class ChatStream:
|
||||
def set_context(self, message: "MessageRecv"):
|
||||
"""设置聊天消息上下文"""
|
||||
# 将MessageRecv转换为DatabaseMessages并设置到stream_context
|
||||
from src.common.data_models.database_data_model import DatabaseMessages
|
||||
import json
|
||||
|
||||
from src.common.data_models.database_data_model import DatabaseMessages
|
||||
|
||||
# 安全获取message_info中的数据
|
||||
message_info = getattr(message, "message_info", {})
|
||||
user_info = getattr(message_info, "user_info", {})
|
||||
@@ -248,7 +250,7 @@ class ChatStream:
|
||||
f"interest_value: {db_message.interest_value}"
|
||||
)
|
||||
|
||||
def _safe_get_actions(self, message: "MessageRecv") -> Optional[list]:
|
||||
def _safe_get_actions(self, message: "MessageRecv") -> list | None:
|
||||
"""安全获取消息的actions字段"""
|
||||
try:
|
||||
actions = getattr(message, "actions", None)
|
||||
@@ -278,7 +280,7 @@ class ChatStream:
|
||||
logger.warning(f"获取actions字段失败: {e}")
|
||||
return None
|
||||
|
||||
def _extract_reply_from_segment(self, segment) -> Optional[str]:
|
||||
def _extract_reply_from_segment(self, segment) -> str | None:
|
||||
"""从消息段中提取reply_to信息"""
|
||||
try:
|
||||
if hasattr(segment, "type") and segment.type == "seglist":
|
||||
@@ -391,8 +393,8 @@ class ChatManager:
|
||||
|
||||
def __init__(self):
|
||||
if not self._initialized:
|
||||
self.streams: Dict[str, ChatStream] = {} # stream_id -> ChatStream
|
||||
self.last_messages: Dict[str, "MessageRecv"] = {} # stream_id -> last_message
|
||||
self.streams: dict[str, ChatStream] = {} # stream_id -> ChatStream
|
||||
self.last_messages: dict[str, "MessageRecv"] = {} # stream_id -> last_message
|
||||
# try:
|
||||
# async with get_db_session() as session:
|
||||
# db.connect(reuse_if_open=True)
|
||||
@@ -414,7 +416,7 @@ class ChatManager:
|
||||
await self.load_all_streams()
|
||||
logger.info(f"聊天管理器已启动,已加载 {len(self.streams)} 个聊天流")
|
||||
except Exception as e:
|
||||
logger.error(f"聊天管理器启动失败: {str(e)}")
|
||||
logger.error(f"聊天管理器启动失败: {e!s}")
|
||||
|
||||
async def _auto_save_task(self):
|
||||
"""定期自动保存所有聊天流"""
|
||||
@@ -424,7 +426,7 @@ class ChatManager:
|
||||
await self._save_all_streams()
|
||||
logger.info("聊天流自动保存完成")
|
||||
except Exception as e:
|
||||
logger.error(f"聊天流自动保存失败: {str(e)}")
|
||||
logger.error(f"聊天流自动保存失败: {e!s}")
|
||||
|
||||
def register_message(self, message: "MessageRecv"):
|
||||
"""注册消息到聊天流"""
|
||||
@@ -437,9 +439,7 @@ class ChatManager:
|
||||
# logger.debug(f"注册消息到聊天流: {stream_id}")
|
||||
|
||||
@staticmethod
|
||||
def _generate_stream_id(
|
||||
platform: str, user_info: Optional[UserInfo], group_info: Optional[GroupInfo] = None
|
||||
) -> str:
|
||||
def _generate_stream_id(platform: str, user_info: UserInfo | None, group_info: GroupInfo | None = None) -> str:
|
||||
"""生成聊天流唯一ID"""
|
||||
if not user_info and not group_info:
|
||||
raise ValueError("用户信息或群组信息必须提供")
|
||||
@@ -462,7 +462,7 @@ class ChatManager:
|
||||
return hashlib.md5(key.encode()).hexdigest()
|
||||
|
||||
async def get_or_create_stream(
|
||||
self, platform: str, user_info: UserInfo, group_info: Optional[GroupInfo] = None
|
||||
self, platform: str, user_info: UserInfo, group_info: GroupInfo | None = None
|
||||
) -> ChatStream:
|
||||
"""获取或创建聊天流
|
||||
|
||||
@@ -572,7 +572,7 @@ class ChatManager:
|
||||
await self._save_stream(stream)
|
||||
return stream
|
||||
|
||||
def get_stream(self, stream_id: str) -> Optional[ChatStream]:
|
||||
def get_stream(self, stream_id: str) -> ChatStream | None:
|
||||
"""通过stream_id获取聊天流"""
|
||||
stream = self.streams.get(stream_id)
|
||||
if not stream:
|
||||
@@ -582,13 +582,13 @@ class ChatManager:
|
||||
return stream
|
||||
|
||||
def get_stream_by_info(
|
||||
self, platform: str, user_info: UserInfo, group_info: Optional[GroupInfo] = None
|
||||
) -> Optional[ChatStream]:
|
||||
self, platform: str, user_info: UserInfo, group_info: GroupInfo | None = None
|
||||
) -> ChatStream | None:
|
||||
"""通过信息获取聊天流"""
|
||||
stream_id = self._generate_stream_id(platform, user_info, group_info)
|
||||
return self.streams.get(stream_id)
|
||||
|
||||
def get_stream_name(self, stream_id: str) -> Optional[str]:
|
||||
def get_stream_name(self, stream_id: str) -> str | None:
|
||||
"""根据 stream_id 获取聊天流名称"""
|
||||
stream = self.get_stream(stream_id)
|
||||
if not stream:
|
||||
|
||||
@@ -1,20 +1,19 @@
|
||||
import base64
|
||||
import time
|
||||
from abc import abstractmethod, ABCMeta
|
||||
from abc import ABCMeta, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Any
|
||||
from typing import Any, Optional
|
||||
|
||||
import urllib3
|
||||
from maim_message import Seg, UserInfo, BaseMessageInfo, MessageBase
|
||||
from maim_message import BaseMessageInfo, MessageBase, Seg, UserInfo
|
||||
from rich.traceback import install
|
||||
|
||||
from src.chat.message_receive.chat_stream import ChatStream
|
||||
from src.chat.utils.utils_image import get_image_manager
|
||||
from src.chat.utils.utils_video import get_video_analyzer, is_video_analysis_available
|
||||
from src.chat.utils.utils_voice import get_voice_text
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.chat.message_receive.chat_stream import ChatStream
|
||||
|
||||
|
||||
install(extra_lines=3)
|
||||
|
||||
@@ -41,8 +40,8 @@ class Message(MessageBase, metaclass=ABCMeta):
|
||||
message_id: str,
|
||||
chat_stream: "ChatStream",
|
||||
user_info: UserInfo,
|
||||
message_segment: Optional[Seg] = None,
|
||||
timestamp: Optional[float] = None,
|
||||
message_segment: Seg | None = None,
|
||||
timestamp: float | None = None,
|
||||
reply: Optional["MessageRecv"] = None,
|
||||
processed_plain_text: str = "",
|
||||
):
|
||||
@@ -264,7 +263,7 @@ class MessageRecv(Message):
|
||||
logger.warning("视频消息中没有base64数据")
|
||||
return "[收到视频消息,但数据异常]"
|
||||
except Exception as e:
|
||||
logger.error(f"视频处理失败: {str(e)}")
|
||||
logger.error(f"视频处理失败: {e!s}")
|
||||
import traceback
|
||||
|
||||
logger.error(f"错误详情: {traceback.format_exc()}")
|
||||
@@ -278,7 +277,7 @@ class MessageRecv(Message):
|
||||
logger.info("未启用视频识别")
|
||||
return "[视频]"
|
||||
except Exception as e:
|
||||
logger.error(f"处理消息段失败: {str(e)}, 类型: {segment.type}, 数据: {segment.data}")
|
||||
logger.error(f"处理消息段失败: {e!s}, 类型: {segment.type}, 数据: {segment.data}")
|
||||
return f"[处理失败的{segment.type}消息]"
|
||||
|
||||
|
||||
@@ -291,7 +290,7 @@ class MessageRecvS4U(MessageRecv):
|
||||
self.is_superchat = False
|
||||
self.gift_info = None
|
||||
self.gift_name = None
|
||||
self.gift_count: Optional[str] = None
|
||||
self.gift_count: str | None = None
|
||||
self.superchat_info = None
|
||||
self.superchat_price = None
|
||||
self.superchat_message_text = None
|
||||
@@ -444,7 +443,7 @@ class MessageRecvS4U(MessageRecv):
|
||||
logger.warning("视频消息中没有base64数据")
|
||||
return "[收到视频消息,但数据异常]"
|
||||
except Exception as e:
|
||||
logger.error(f"视频处理失败: {str(e)}")
|
||||
logger.error(f"视频处理失败: {e!s}")
|
||||
import traceback
|
||||
|
||||
logger.error(f"错误详情: {traceback.format_exc()}")
|
||||
@@ -458,7 +457,7 @@ class MessageRecvS4U(MessageRecv):
|
||||
logger.info("未启用视频识别")
|
||||
return "[视频]"
|
||||
except Exception as e:
|
||||
logger.error(f"处理消息段失败: {str(e)}, 类型: {segment.type}, 数据: {segment.data}")
|
||||
logger.error(f"处理消息段失败: {e!s}, 类型: {segment.type}, 数据: {segment.data}")
|
||||
return f"[处理失败的{segment.type}消息]"
|
||||
|
||||
|
||||
@@ -471,10 +470,10 @@ class MessageProcessBase(Message):
|
||||
message_id: str,
|
||||
chat_stream: "ChatStream",
|
||||
bot_user_info: UserInfo,
|
||||
message_segment: Optional[Seg] = None,
|
||||
message_segment: Seg | None = None,
|
||||
reply: Optional["MessageRecv"] = None,
|
||||
thinking_start_time: float = 0,
|
||||
timestamp: Optional[float] = None,
|
||||
timestamp: float | None = None,
|
||||
):
|
||||
# 调用父类初始化,传递时间戳
|
||||
super().__init__(
|
||||
@@ -533,9 +532,9 @@ class MessageProcessBase(Message):
|
||||
return f"[回复<{self.reply.message_info.user_info.user_nickname}> 的消息:{self.reply.processed_plain_text}]" # type: ignore
|
||||
return None
|
||||
else:
|
||||
return f"[{seg.type}:{str(seg.data)}]"
|
||||
return f"[{seg.type}:{seg.data!s}]"
|
||||
except Exception as e:
|
||||
logger.error(f"处理消息段失败: {str(e)}, 类型: {seg.type}, 数据: {seg.data}")
|
||||
logger.error(f"处理消息段失败: {e!s}, 类型: {seg.type}, 数据: {seg.data}")
|
||||
return f"[处理失败的{seg.type}消息]"
|
||||
|
||||
def _generate_detailed_text(self) -> str:
|
||||
@@ -565,7 +564,7 @@ class MessageSending(MessageProcessBase):
|
||||
is_emoji: bool = False,
|
||||
thinking_start_time: float = 0,
|
||||
apply_set_reply_logic: bool = False,
|
||||
reply_to: Optional[str] = None,
|
||||
reply_to: str | None = None,
|
||||
):
|
||||
# 调用父类初始化
|
||||
super().__init__(
|
||||
@@ -635,11 +634,11 @@ class MessageSet:
|
||||
self.messages.append(message)
|
||||
self.messages.sort(key=lambda x: x.message_info.time) # type: ignore
|
||||
|
||||
def get_message_by_index(self, index: int) -> Optional[MessageSending]:
|
||||
def get_message_by_index(self, index: int) -> MessageSending | None:
|
||||
"""通过索引获取消息"""
|
||||
return self.messages[index] if 0 <= index < len(self.messages) else None
|
||||
|
||||
def get_message_by_time(self, target_time: float) -> Optional[MessageSending]:
|
||||
def get_message_by_time(self, target_time: float) -> MessageSending | None:
|
||||
"""获取最接近指定时间的消息"""
|
||||
if not self.messages:
|
||||
return None
|
||||
|
||||
@@ -1,14 +1,15 @@
|
||||
import re
|
||||
import traceback
|
||||
import orjson
|
||||
from typing import Union
|
||||
|
||||
from src.common.database.sqlalchemy_models import Messages, Images
|
||||
from src.common.logger import get_logger
|
||||
from .chat_stream import ChatStream
|
||||
from .message import MessageSending, MessageRecv
|
||||
import orjson
|
||||
from sqlalchemy import desc, select, update
|
||||
|
||||
from src.common.database.sqlalchemy_database_api import get_db_session
|
||||
from sqlalchemy import select, update, desc
|
||||
from src.common.database.sqlalchemy_models import Images, Messages
|
||||
from src.common.logger import get_logger
|
||||
|
||||
from .chat_stream import ChatStream
|
||||
from .message import MessageRecv, MessageSending
|
||||
|
||||
logger = get_logger("message_storage")
|
||||
|
||||
@@ -32,7 +33,7 @@ class MessageStorage:
|
||||
return []
|
||||
|
||||
@staticmethod
|
||||
async def store_message(message: Union[MessageSending, MessageRecv], chat_stream: ChatStream) -> None:
|
||||
async def store_message(message: MessageSending | MessageRecv, chat_stream: ChatStream) -> None:
|
||||
"""存储消息到数据库"""
|
||||
try:
|
||||
# 过滤敏感信息的正则模式
|
||||
@@ -299,6 +300,7 @@ class MessageStorage:
|
||||
try:
|
||||
async with get_db_session() as session:
|
||||
from sqlalchemy import select, update
|
||||
|
||||
from src.common.database.sqlalchemy_models import Messages
|
||||
|
||||
# 查找需要修复的记录:interest_value为0、null或很小的值
|
||||
|
||||
@@ -3,12 +3,11 @@ import traceback
|
||||
|
||||
from rich.traceback import install
|
||||
|
||||
from src.common.message.api import get_global_api
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.message_receive.message import MessageSending
|
||||
from src.chat.message_receive.storage import MessageStorage
|
||||
from src.chat.utils.utils import truncate_message
|
||||
from src.chat.utils.utils import calculate_typing_time
|
||||
from src.chat.utils.utils import calculate_typing_time, truncate_message
|
||||
from src.common.logger import get_logger
|
||||
from src.common.message.api import get_global_api
|
||||
|
||||
install(extra_lines=3)
|
||||
|
||||
@@ -27,7 +26,7 @@ async def send_message(message: MessageSending, show_log=True) -> bool:
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"发送消息 '{message_preview}' 发往平台'{message.message_info.platform}' 失败: {str(e)}")
|
||||
logger.error(f"发送消息 '{message_preview}' 发往平台'{message.message_info.platform}' 失败: {e!s}")
|
||||
traceback.print_exc()
|
||||
raise e # 重新抛出其他异常
|
||||
|
||||
|
||||
@@ -1,19 +1,17 @@
|
||||
import asyncio
|
||||
import traceback
|
||||
import time
|
||||
from typing import Dict, Optional, Type, Any, Tuple
|
||||
import traceback
|
||||
from typing import Any
|
||||
|
||||
|
||||
from src.chat.utils.timer_calculator import Timer
|
||||
from src.person_info.person_info import get_person_info_manager
|
||||
from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager
|
||||
from src.chat.utils.timer_calculator import Timer
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.plugin_system.core.component_registry import component_registry
|
||||
from src.plugin_system.base.component_types import ComponentType, ActionInfo
|
||||
from src.person_info.person_info import get_person_info_manager
|
||||
from src.plugin_system.apis import database_api, generator_api, message_api, send_api
|
||||
from src.plugin_system.base.base_action import BaseAction
|
||||
from src.plugin_system.apis import generator_api, database_api, send_api, message_api
|
||||
|
||||
from src.plugin_system.base.component_types import ActionInfo, ComponentType
|
||||
from src.plugin_system.core.component_registry import component_registry
|
||||
|
||||
logger = get_logger("action_manager")
|
||||
|
||||
@@ -29,7 +27,7 @@ class ChatterActionManager:
|
||||
"""初始化动作管理器"""
|
||||
|
||||
# 当前正在使用的动作集合,默认加载默认动作
|
||||
self._using_actions: Dict[str, ActionInfo] = {}
|
||||
self._using_actions: dict[str, ActionInfo] = {}
|
||||
|
||||
# 初始化时将默认动作加载到使用中的动作
|
||||
self._using_actions = component_registry.get_default_actions()
|
||||
@@ -48,8 +46,8 @@ class ChatterActionManager:
|
||||
chat_stream: ChatStream,
|
||||
log_prefix: str,
|
||||
shutting_down: bool = False,
|
||||
action_message: Optional[dict] = None,
|
||||
) -> Optional[BaseAction]:
|
||||
action_message: dict | None = None,
|
||||
) -> BaseAction | None:
|
||||
"""
|
||||
创建动作处理器实例
|
||||
|
||||
@@ -68,7 +66,7 @@ class ChatterActionManager:
|
||||
"""
|
||||
try:
|
||||
# 获取组件类 - 明确指定查询Action类型
|
||||
component_class: Type[BaseAction] = component_registry.get_component_class(
|
||||
component_class: type[BaseAction] = component_registry.get_component_class(
|
||||
action_name, ComponentType.ACTION
|
||||
) # type: ignore
|
||||
if not component_class:
|
||||
@@ -107,7 +105,7 @@ class ChatterActionManager:
|
||||
logger.error(traceback.format_exc())
|
||||
return None
|
||||
|
||||
def get_using_actions(self) -> Dict[str, ActionInfo]:
|
||||
def get_using_actions(self) -> dict[str, ActionInfo]:
|
||||
"""获取当前正在使用的动作集合"""
|
||||
return self._using_actions.copy()
|
||||
|
||||
@@ -140,10 +138,10 @@ class ChatterActionManager:
|
||||
self,
|
||||
action_name: str,
|
||||
chat_id: str,
|
||||
target_message: Optional[dict] = None,
|
||||
target_message: dict | None = None,
|
||||
reasoning: str = "",
|
||||
action_data: Optional[dict] = None,
|
||||
thinking_id: Optional[str] = None,
|
||||
action_data: dict | None = None,
|
||||
thinking_id: str | None = None,
|
||||
log_prefix: str = "",
|
||||
clear_unread_messages: bool = True,
|
||||
) -> Any:
|
||||
@@ -437,10 +435,10 @@ class ChatterActionManager:
|
||||
response_set,
|
||||
loop_start_time,
|
||||
action_message,
|
||||
cycle_timers: Dict[str, float],
|
||||
cycle_timers: dict[str, float],
|
||||
thinking_id,
|
||||
actions,
|
||||
) -> Tuple[Dict[str, Any], str, Dict[str, float]]:
|
||||
) -> tuple[dict[str, Any], str, dict[str, float]]:
|
||||
"""
|
||||
发送并存储回复信息
|
||||
|
||||
@@ -488,7 +486,7 @@ class ChatterActionManager:
|
||||
)
|
||||
|
||||
# 构建循环信息
|
||||
loop_info: Dict[str, Any] = {
|
||||
loop_info: dict[str, Any] = {
|
||||
"loop_plan_info": {
|
||||
"action_result": actions,
|
||||
},
|
||||
|
||||
@@ -1,17 +1,17 @@
|
||||
import random
|
||||
import asyncio
|
||||
import hashlib
|
||||
import random
|
||||
import time
|
||||
from typing import List, Any, Dict, TYPE_CHECKING, Tuple
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from src.chat.planner_actions.action_manager import ChatterActionManager
|
||||
from src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_before_timestamp_with_chat
|
||||
from src.common.data_models.message_manager_data_model import StreamContext
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config, model_config
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from src.common.data_models.message_manager_data_model import StreamContext
|
||||
from src.chat.planner_actions.action_manager import ChatterActionManager
|
||||
from src.chat.utils.chat_message_builder import get_raw_msg_before_timestamp_with_chat, build_readable_messages
|
||||
from src.plugin_system.base.component_types import ActionInfo, ActionActivationType
|
||||
from src.plugin_system.base.component_types import ActionActivationType, ActionInfo
|
||||
from src.plugin_system.core.global_announcement_manager import global_announcement_manager
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -59,18 +59,17 @@ class ActionModifier:
|
||||
"""
|
||||
logger.debug(f"{self.log_prefix}开始完整动作修改流程")
|
||||
|
||||
removals_s1: List[Tuple[str, str]] = []
|
||||
removals_s2: List[Tuple[str, str]] = []
|
||||
removals_s3: List[Tuple[str, str]] = []
|
||||
removals_s1: list[tuple[str, str]] = []
|
||||
removals_s2: list[tuple[str, str]] = []
|
||||
removals_s3: list[tuple[str, str]] = []
|
||||
|
||||
self.action_manager.restore_actions()
|
||||
all_actions = self.action_manager.get_using_actions()
|
||||
|
||||
# === 第0阶段:根据聊天类型过滤动作 ===
|
||||
from src.plugin_system.base.component_types import ChatType
|
||||
from src.plugin_system.core.component_registry import component_registry
|
||||
from src.plugin_system.base.component_types import ComponentType
|
||||
from src.chat.utils.utils import get_chat_type_and_target_info
|
||||
from src.plugin_system.base.component_types import ChatType, ComponentType
|
||||
from src.plugin_system.core.component_registry import component_registry
|
||||
|
||||
# 获取聊天类型
|
||||
is_group_chat, _ = get_chat_type_and_target_info(self.chat_id)
|
||||
@@ -167,8 +166,8 @@ class ActionModifier:
|
||||
|
||||
logger.info(f"{self.log_prefix} 当前可用动作: {available_actions_text}||移除: {removals_summary}")
|
||||
|
||||
def _check_action_associated_types(self, all_actions: Dict[str, ActionInfo], chat_context: StreamContext):
|
||||
type_mismatched_actions: List[Tuple[str, str]] = []
|
||||
def _check_action_associated_types(self, all_actions: dict[str, ActionInfo], chat_context: StreamContext):
|
||||
type_mismatched_actions: list[tuple[str, str]] = []
|
||||
for action_name, action_info in all_actions.items():
|
||||
if action_info.associated_types and not chat_context.check_types(action_info.associated_types):
|
||||
associated_types_str = ", ".join(action_info.associated_types)
|
||||
@@ -179,9 +178,9 @@ class ActionModifier:
|
||||
|
||||
async def _get_deactivated_actions_by_type(
|
||||
self,
|
||||
actions_with_info: Dict[str, ActionInfo],
|
||||
actions_with_info: dict[str, ActionInfo],
|
||||
chat_content: str = "",
|
||||
) -> List[tuple[str, str]]:
|
||||
) -> list[tuple[str, str]]:
|
||||
"""
|
||||
根据激活类型过滤,返回需要停用的动作列表及原因
|
||||
|
||||
@@ -254,9 +253,9 @@ class ActionModifier:
|
||||
|
||||
async def _process_llm_judge_actions_parallel(
|
||||
self,
|
||||
llm_judge_actions: Dict[str, Any],
|
||||
llm_judge_actions: dict[str, Any],
|
||||
chat_content: str = "",
|
||||
) -> Dict[str, bool]:
|
||||
) -> dict[str, bool]:
|
||||
"""
|
||||
并行处理LLM判定actions,支持智能缓存
|
||||
|
||||
|
||||
@@ -3,42 +3,41 @@
|
||||
使用重构后的统一Prompt系统替换原有的复杂提示词构建逻辑
|
||||
"""
|
||||
|
||||
import traceback
|
||||
import time
|
||||
import asyncio
|
||||
import random
|
||||
import re
|
||||
|
||||
from typing import List, Optional, Dict, Any, Tuple
|
||||
import time
|
||||
import traceback
|
||||
from datetime import datetime
|
||||
from src.mais4u.mai_think import mai_thinking_manager
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config, model_config
|
||||
from src.individuality.individuality import get_individuality
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.chat.message_receive.message import UserInfo, Seg, MessageRecv, MessageSending
|
||||
from typing import Any
|
||||
|
||||
from src.chat.express.expression_selector import expression_selector
|
||||
from src.chat.message_receive.chat_stream import ChatStream
|
||||
from src.chat.utils.memory_mappings import get_memory_type_chinese_label
|
||||
from src.chat.message_receive.message import MessageRecv, MessageSending, Seg, UserInfo
|
||||
from src.chat.message_receive.uni_message_sender import HeartFCSender
|
||||
from src.chat.utils.timer_calculator import Timer
|
||||
from src.chat.utils.utils import get_chat_type_and_target_info
|
||||
from src.chat.utils.prompt import Prompt, global_prompt_manager
|
||||
from src.chat.utils.chat_message_builder import (
|
||||
build_readable_messages,
|
||||
get_raw_msg_before_timestamp_with_chat,
|
||||
replace_user_references_sync,
|
||||
)
|
||||
from src.chat.express.expression_selector import expression_selector
|
||||
from src.chat.utils.memory_mappings import get_memory_type_chinese_label
|
||||
|
||||
# 导入新的统一Prompt系统
|
||||
from src.chat.utils.prompt import Prompt, PromptParameters, global_prompt_manager
|
||||
from src.chat.utils.timer_calculator import Timer
|
||||
from src.chat.utils.utils import get_chat_type_and_target_info
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config, model_config
|
||||
from src.individuality.individuality import get_individuality
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.mais4u.mai_think import mai_thinking_manager
|
||||
|
||||
# 旧记忆系统已被移除
|
||||
# 旧记忆系统已被移除
|
||||
from src.mood.mood_manager import mood_manager
|
||||
from src.person_info.person_info import get_person_info_manager
|
||||
from src.plugin_system.base.component_types import ActionInfo, EventType
|
||||
from src.plugin_system.apis import llm_api
|
||||
|
||||
# 导入新的统一Prompt系统
|
||||
from src.chat.utils.prompt import PromptParameters
|
||||
from src.plugin_system.base.component_types import ActionInfo, EventType
|
||||
|
||||
logger = get_logger("replyer")
|
||||
|
||||
@@ -248,12 +247,12 @@ class DefaultReplyer:
|
||||
self,
|
||||
reply_to: str = "",
|
||||
extra_info: str = "",
|
||||
available_actions: Optional[Dict[str, ActionInfo]] = None,
|
||||
available_actions: dict[str, ActionInfo] | None = None,
|
||||
enable_tool: bool = True,
|
||||
from_plugin: bool = True,
|
||||
stream_id: Optional[str] = None,
|
||||
reply_message: Optional[Dict[str, Any]] = None,
|
||||
) -> Tuple[bool, Optional[Dict[str, Any]], Optional[str]]:
|
||||
stream_id: str | None = None,
|
||||
reply_message: dict[str, Any] | None = None,
|
||||
) -> tuple[bool, dict[str, Any] | None, str | None]:
|
||||
# sourcery skip: merge-nested-ifs
|
||||
"""
|
||||
回复器 (Replier): 负责生成回复文本的核心逻辑。
|
||||
@@ -353,7 +352,7 @@ class DefaultReplyer:
|
||||
reason: str = "",
|
||||
reply_to: str = "",
|
||||
return_prompt: bool = False,
|
||||
) -> Tuple[bool, Optional[str], Optional[str]]:
|
||||
) -> tuple[bool, str | None, str | None]:
|
||||
"""
|
||||
表达器 (Expressor): 负责重写和优化回复文本。
|
||||
|
||||
@@ -722,7 +721,7 @@ class DefaultReplyer:
|
||||
logger.error(f"工具信息获取失败: {e}")
|
||||
return ""
|
||||
|
||||
def _parse_reply_target(self, target_message: str) -> Tuple[str, str]:
|
||||
def _parse_reply_target(self, target_message: str) -> tuple[str, str]:
|
||||
"""解析回复目标消息 - 使用共享工具"""
|
||||
from src.chat.utils.prompt import Prompt
|
||||
|
||||
@@ -731,7 +730,7 @@ class DefaultReplyer:
|
||||
return "未知用户", "(无消息内容)"
|
||||
return Prompt.parse_reply_target(target_message)
|
||||
|
||||
async def build_keywords_reaction_prompt(self, target: Optional[str]) -> str:
|
||||
async def build_keywords_reaction_prompt(self, target: str | None) -> str:
|
||||
"""构建关键词反应提示
|
||||
|
||||
Args:
|
||||
@@ -766,14 +765,14 @@ class DefaultReplyer:
|
||||
keywords_reaction_prompt += f"{reaction},"
|
||||
break
|
||||
except re.error as e:
|
||||
logger.error(f"正则表达式编译错误: {pattern_str}, 错误信息: {str(e)}")
|
||||
logger.error(f"正则表达式编译错误: {pattern_str}, 错误信息: {e!s}")
|
||||
continue
|
||||
except Exception as e:
|
||||
logger.error(f"关键词检测与反应时发生异常: {str(e)}", exc_info=True)
|
||||
logger.error(f"关键词检测与反应时发生异常: {e!s}", exc_info=True)
|
||||
|
||||
return keywords_reaction_prompt
|
||||
|
||||
async def _time_and_run_task(self, coroutine, name: str) -> Tuple[str, Any, float]:
|
||||
async def _time_and_run_task(self, coroutine, name: str) -> tuple[str, Any, float]:
|
||||
"""计时并运行异步任务的辅助函数
|
||||
|
||||
Args:
|
||||
@@ -790,8 +789,8 @@ class DefaultReplyer:
|
||||
return name, result, duration
|
||||
|
||||
async def build_s4u_chat_history_prompts(
|
||||
self, message_list_before_now: List[Dict[str, Any]], target_user_id: str, sender: str, chat_id: str
|
||||
) -> Tuple[str, str]:
|
||||
self, message_list_before_now: list[dict[str, Any]], target_user_id: str, sender: str, chat_id: str
|
||||
) -> tuple[str, str]:
|
||||
"""
|
||||
构建 s4u 风格的已读/未读历史消息 prompt
|
||||
|
||||
@@ -907,8 +906,8 @@ class DefaultReplyer:
|
||||
return await self._fallback_build_chat_history_prompts(message_list_before_now, target_user_id, sender)
|
||||
|
||||
async def _fallback_build_chat_history_prompts(
|
||||
self, message_list_before_now: List[Dict[str, Any]], target_user_id: str, sender: str
|
||||
) -> Tuple[str, str]:
|
||||
self, message_list_before_now: list[dict[str, Any]], target_user_id: str, sender: str
|
||||
) -> tuple[str, str]:
|
||||
"""
|
||||
回退的已读/未读历史消息构建方法
|
||||
"""
|
||||
@@ -1000,15 +999,15 @@ class DefaultReplyer:
|
||||
|
||||
return read_history_prompt, unread_history_prompt
|
||||
|
||||
async def _get_interest_scores_for_messages(self, messages: List[dict]) -> dict[str, float]:
|
||||
async def _get_interest_scores_for_messages(self, messages: list[dict]) -> dict[str, float]:
|
||||
"""为消息获取兴趣度评分"""
|
||||
interest_scores = {}
|
||||
|
||||
try:
|
||||
from src.common.data_models.database_data_model import DatabaseMessages
|
||||
from src.plugins.built_in.affinity_flow_chatter.interest_scoring import (
|
||||
chatter_interest_scoring_system as interest_scoring_system,
|
||||
)
|
||||
from src.common.data_models.database_data_model import DatabaseMessages
|
||||
|
||||
# 转换消息格式
|
||||
db_messages = []
|
||||
@@ -1094,9 +1093,9 @@ class DefaultReplyer:
|
||||
self,
|
||||
reply_to: str,
|
||||
extra_info: str = "",
|
||||
available_actions: Optional[Dict[str, ActionInfo]] = None,
|
||||
available_actions: dict[str, ActionInfo] | None = None,
|
||||
enable_tool: bool = True,
|
||||
reply_message: Optional[Dict[str, Any]] = None,
|
||||
reply_message: dict[str, Any] | None = None,
|
||||
) -> str:
|
||||
"""
|
||||
构建回复器上下文
|
||||
@@ -1417,7 +1416,7 @@ class DefaultReplyer:
|
||||
raw_reply: str,
|
||||
reason: str,
|
||||
reply_to: str,
|
||||
reply_message: Optional[Dict[str, Any]] = None,
|
||||
reply_message: dict[str, Any] | None = None,
|
||||
) -> str: # sourcery skip: merge-else-if-into-elif, remove-redundant-if
|
||||
chat_stream = self.chat_stream
|
||||
chat_id = chat_stream.stream_id
|
||||
@@ -1553,7 +1552,7 @@ class DefaultReplyer:
|
||||
is_emoji: bool,
|
||||
thinking_start_time: float,
|
||||
display_message: str,
|
||||
anchor_message: Optional[MessageRecv] = None,
|
||||
anchor_message: MessageRecv | None = None,
|
||||
) -> MessageSending:
|
||||
"""构建单个发送消息"""
|
||||
|
||||
@@ -1644,7 +1643,7 @@ class DefaultReplyer:
|
||||
logger.debug("从LPMM知识库获取知识失败,可能是从未导入过知识,返回空知识...")
|
||||
return ""
|
||||
except Exception as e:
|
||||
logger.error(f"获取知识库内容时发生异常: {str(e)}")
|
||||
logger.error(f"获取知识库内容时发生异常: {e!s}")
|
||||
return ""
|
||||
|
||||
async def build_relation_info(self, sender: str, target: str):
|
||||
@@ -1660,10 +1659,9 @@ class DefaultReplyer:
|
||||
|
||||
# 使用AFC关系追踪器获取关系信息
|
||||
try:
|
||||
from src.plugins.built_in.affinity_flow_chatter.relationship_tracker import ChatterRelationshipTracker
|
||||
|
||||
# 创建关系追踪器实例
|
||||
from src.plugins.built_in.affinity_flow_chatter.interest_scoring import chatter_interest_scoring_system
|
||||
from src.plugins.built_in.affinity_flow_chatter.relationship_tracker import ChatterRelationshipTracker
|
||||
|
||||
relationship_tracker = ChatterRelationshipTracker(chatter_interest_scoring_system)
|
||||
if relationship_tracker:
|
||||
@@ -1704,7 +1702,7 @@ class DefaultReplyer:
|
||||
logger.error(f"获取AFC关系信息失败: {e}")
|
||||
return f"你与{sender}是普通朋友关系。"
|
||||
|
||||
async def _store_chat_memory_async(self, reply_to: str, reply_message: Optional[Dict[str, Any]] = None):
|
||||
async def _store_chat_memory_async(self, reply_to: str, reply_message: dict[str, Any] | None = None):
|
||||
"""
|
||||
异步存储聊天记忆(从build_memory_block迁移而来)
|
||||
|
||||
|
||||
@@ -1,22 +1,20 @@
|
||||
from typing import Dict, Optional
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager
|
||||
from src.chat.replyer.default_generator import DefaultReplyer
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("ReplyerManager")
|
||||
|
||||
|
||||
class ReplyerManager:
|
||||
def __init__(self):
|
||||
self._repliers: Dict[str, DefaultReplyer] = {}
|
||||
self._repliers: dict[str, DefaultReplyer] = {}
|
||||
|
||||
def get_replyer(
|
||||
self,
|
||||
chat_stream: Optional[ChatStream] = None,
|
||||
chat_id: Optional[str] = None,
|
||||
chat_stream: ChatStream | None = None,
|
||||
chat_id: str | None = None,
|
||||
request_type: str = "replyer",
|
||||
) -> Optional[DefaultReplyer]:
|
||||
) -> DefaultReplyer | None:
|
||||
"""
|
||||
获取或创建回复器实例。
|
||||
|
||||
|
||||
@@ -1,18 +1,19 @@
|
||||
import time # 导入 time 模块以获取当前时间
|
||||
import random
|
||||
import re
|
||||
import time # 导入 time 模块以获取当前时间
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
from typing import List, Dict, Any, Tuple, Optional, Callable
|
||||
from rich.traceback import install
|
||||
from sqlalchemy import and_, select
|
||||
|
||||
from src.config.config import global_config
|
||||
from src.common.message_repository import find_messages, count_messages
|
||||
from src.common.database.sqlalchemy_models import ActionRecords, Images
|
||||
from src.person_info.person_info import PersonInfoManager, get_person_info_manager
|
||||
from src.chat.utils.utils import translate_timestamp_to_human_readable, assign_message_ids
|
||||
from src.chat.utils.utils import assign_message_ids, translate_timestamp_to_human_readable
|
||||
from src.common.database.sqlalchemy_database_api import get_db_session
|
||||
from sqlalchemy import select, and_
|
||||
from src.common.database.sqlalchemy_models import ActionRecords, Images
|
||||
from src.common.logger import get_logger
|
||||
from src.common.message_repository import count_messages, find_messages
|
||||
from src.config.config import global_config
|
||||
from src.person_info.person_info import PersonInfoManager, get_person_info_manager
|
||||
|
||||
logger = get_logger("chat_message_builder")
|
||||
|
||||
@@ -22,7 +23,7 @@ install(extra_lines=3)
|
||||
def replace_user_references_sync(
|
||||
content: str,
|
||||
platform: str,
|
||||
name_resolver: Optional[Callable[[str, str], str]] = None,
|
||||
name_resolver: Callable[[str, str], str] | None = None,
|
||||
replace_bot_name: bool = True,
|
||||
) -> str:
|
||||
"""
|
||||
@@ -100,7 +101,7 @@ def replace_user_references_sync(
|
||||
async def replace_user_references_async(
|
||||
content: str,
|
||||
platform: str,
|
||||
name_resolver: Optional[Callable[[str, str], Any]] = None,
|
||||
name_resolver: Callable[[str, str], Any] | None = None,
|
||||
replace_bot_name: bool = True,
|
||||
) -> str:
|
||||
"""
|
||||
@@ -174,7 +175,7 @@ async def replace_user_references_async(
|
||||
|
||||
async def get_raw_msg_by_timestamp(
|
||||
timestamp_start: float, timestamp_end: float, limit: int = 0, limit_mode: str = "latest"
|
||||
) -> List[Dict[str, Any]]:
|
||||
) -> list[dict[str, Any]]:
|
||||
"""
|
||||
获取从指定时间戳到指定时间戳的消息,按时间升序排序,返回消息列表
|
||||
limit: 限制返回的消息数量,0为不限制
|
||||
@@ -194,7 +195,7 @@ async def get_raw_msg_by_timestamp_with_chat(
|
||||
limit_mode: str = "latest",
|
||||
filter_bot=False,
|
||||
filter_command=False,
|
||||
) -> List[Dict[str, Any]]:
|
||||
) -> list[dict[str, Any]]:
|
||||
"""获取在特定聊天从指定时间戳到指定时间戳的消息,按时间升序排序,返回消息列表
|
||||
limit: 限制返回的消息数量,0为不限制
|
||||
limit_mode: 当 limit > 0 时生效。 'earliest' 表示获取最早的记录, 'latest' 表示获取最新的记录。默认为 'latest'。
|
||||
@@ -220,7 +221,7 @@ async def get_raw_msg_by_timestamp_with_chat_inclusive(
|
||||
limit: int = 0,
|
||||
limit_mode: str = "latest",
|
||||
filter_bot=False,
|
||||
) -> List[Dict[str, Any]]:
|
||||
) -> list[dict[str, Any]]:
|
||||
"""获取在特定聊天从指定时间戳到指定时间戳的消息(包含边界),按时间升序排序,返回消息列表
|
||||
limit: 限制返回的消息数量,0为不限制
|
||||
limit_mode: 当 limit > 0 时生效。 'earliest' 表示获取最早的记录, 'latest' 表示获取最新的记录。默认为 'latest'。
|
||||
@@ -239,10 +240,10 @@ async def get_raw_msg_by_timestamp_with_chat_users(
|
||||
chat_id: str,
|
||||
timestamp_start: float,
|
||||
timestamp_end: float,
|
||||
person_ids: List[str],
|
||||
person_ids: list[str],
|
||||
limit: int = 0,
|
||||
limit_mode: str = "latest",
|
||||
) -> List[Dict[str, Any]]:
|
||||
) -> list[dict[str, Any]]:
|
||||
"""获取某些特定用户在特定聊天从指定时间戳到指定时间戳的消息,按时间升序排序,返回消息列表
|
||||
limit: 限制返回的消息数量,0为不限制
|
||||
limit_mode: 当 limit > 0 时生效。 'earliest' 表示获取最早的记录, 'latest' 表示获取最新的记录。默认为 'latest'。
|
||||
@@ -263,7 +264,7 @@ async def get_actions_by_timestamp_with_chat(
|
||||
timestamp_end: float = time.time(),
|
||||
limit: int = 0,
|
||||
limit_mode: str = "latest",
|
||||
) -> List[Dict[str, Any]]:
|
||||
) -> list[dict[str, Any]]:
|
||||
"""获取在特定聊天从指定时间戳到指定时间戳的动作记录,按时间升序排序,返回动作记录列表"""
|
||||
from src.common.logger import get_logger
|
||||
|
||||
@@ -372,7 +373,7 @@ async def get_actions_by_timestamp_with_chat(
|
||||
|
||||
async def get_actions_by_timestamp_with_chat_inclusive(
|
||||
chat_id: str, timestamp_start: float, timestamp_end: float, limit: int = 0, limit_mode: str = "latest"
|
||||
) -> List[Dict[str, Any]]:
|
||||
) -> list[dict[str, Any]]:
|
||||
"""获取在特定聊天从指定时间戳到指定时间戳的动作记录(包含边界),按时间升序排序,返回动作记录列表"""
|
||||
async with get_db_session() as session:
|
||||
if limit > 0:
|
||||
@@ -423,7 +424,7 @@ async def get_actions_by_timestamp_with_chat_inclusive(
|
||||
|
||||
async def get_raw_msg_by_timestamp_random(
|
||||
timestamp_start: float, timestamp_end: float, limit: int = 0, limit_mode: str = "latest"
|
||||
) -> List[Dict[str, Any]]:
|
||||
) -> list[dict[str, Any]]:
|
||||
"""
|
||||
先在范围时间戳内随机选择一条消息,取得消息的chat_id,然后根据chat_id获取该聊天在指定时间戳范围内的消息
|
||||
"""
|
||||
@@ -441,7 +442,7 @@ async def get_raw_msg_by_timestamp_random(
|
||||
|
||||
async def get_raw_msg_by_timestamp_with_users(
|
||||
timestamp_start: float, timestamp_end: float, person_ids: list, limit: int = 0, limit_mode: str = "latest"
|
||||
) -> List[Dict[str, Any]]:
|
||||
) -> list[dict[str, Any]]:
|
||||
"""获取某些特定用户在 *所有聊天* 中从指定时间戳到指定时间戳的消息,按时间升序排序,返回消息列表
|
||||
limit: 限制返回的消息数量,0为不限制
|
||||
limit_mode: 当 limit > 0 时生效。 'earliest' 表示获取最早的记录, 'latest' 表示获取最新的记录。默认为 'latest'。
|
||||
@@ -452,7 +453,7 @@ async def get_raw_msg_by_timestamp_with_users(
|
||||
return await find_messages(message_filter=filter_query, sort=sort_order, limit=limit, limit_mode=limit_mode)
|
||||
|
||||
|
||||
async def get_raw_msg_before_timestamp(timestamp: float, limit: int = 0) -> List[Dict[str, Any]]:
|
||||
async def get_raw_msg_before_timestamp(timestamp: float, limit: int = 0) -> list[dict[str, Any]]:
|
||||
"""获取指定时间戳之前的消息,按时间升序排序,返回消息列表
|
||||
limit: 限制返回的消息数量,0为不限制
|
||||
"""
|
||||
@@ -463,7 +464,7 @@ async def get_raw_msg_before_timestamp(timestamp: float, limit: int = 0) -> List
|
||||
|
||||
async def get_raw_msg_before_timestamp_with_chat(
|
||||
chat_id: str, timestamp: float, limit: int = 0
|
||||
) -> List[Dict[str, Any]]:
|
||||
) -> list[dict[str, Any]]:
|
||||
"""获取指定时间戳之前的消息,按时间升序排序,返回消息列表
|
||||
limit: 限制返回的消息数量,0为不限制
|
||||
"""
|
||||
@@ -474,7 +475,7 @@ async def get_raw_msg_before_timestamp_with_chat(
|
||||
|
||||
async def get_raw_msg_before_timestamp_with_users(
|
||||
timestamp: float, person_ids: list, limit: int = 0
|
||||
) -> List[Dict[str, Any]]:
|
||||
) -> list[dict[str, Any]]:
|
||||
"""获取指定时间戳之前的消息,按时间升序排序,返回消息列表
|
||||
limit: 限制返回的消息数量,0为不限制
|
||||
"""
|
||||
@@ -483,9 +484,7 @@ async def get_raw_msg_before_timestamp_with_users(
|
||||
return await find_messages(message_filter=filter_query, sort=sort_order, limit=limit)
|
||||
|
||||
|
||||
async def num_new_messages_since(
|
||||
chat_id: str, timestamp_start: float = 0.0, timestamp_end: Optional[float] = None
|
||||
) -> int:
|
||||
async def num_new_messages_since(chat_id: str, timestamp_start: float = 0.0, timestamp_end: float | None = None) -> int:
|
||||
"""
|
||||
检查特定聊天从 timestamp_start (不含) 到 timestamp_end (不含) 之间有多少新消息。
|
||||
如果 timestamp_end 为 None,则检查从 timestamp_start (不含) 到当前时间的消息。
|
||||
@@ -517,16 +516,16 @@ async def num_new_messages_since_with_users(
|
||||
|
||||
|
||||
async def _build_readable_messages_internal(
|
||||
messages: List[Dict[str, Any]],
|
||||
messages: list[dict[str, Any]],
|
||||
replace_bot_name: bool = True,
|
||||
merge_messages: bool = False,
|
||||
timestamp_mode: str = "relative",
|
||||
truncate: bool = False,
|
||||
pic_id_mapping: Optional[Dict[str, str]] = None,
|
||||
pic_id_mapping: dict[str, str] | None = None,
|
||||
pic_counter: int = 1,
|
||||
show_pic: bool = True,
|
||||
message_id_list: Optional[List[Dict[str, Any]]] = None,
|
||||
) -> Tuple[str, List[Tuple[float, str, str]], Dict[str, str], int]:
|
||||
message_id_list: list[dict[str, Any]] | None = None,
|
||||
) -> tuple[str, list[tuple[float, str, str]], dict[str, str], int]:
|
||||
"""
|
||||
内部辅助函数,构建可读消息字符串和原始消息详情列表。
|
||||
|
||||
@@ -545,7 +544,7 @@ async def _build_readable_messages_internal(
|
||||
if not messages:
|
||||
return "", [], pic_id_mapping or {}, pic_counter
|
||||
|
||||
message_details_raw: List[Tuple[float, str, str, bool]] = []
|
||||
message_details_raw: list[tuple[float, str, str, bool]] = []
|
||||
|
||||
# 使用传入的映射字典,如果没有则创建新的
|
||||
if pic_id_mapping is None:
|
||||
@@ -672,7 +671,7 @@ async def _build_readable_messages_internal(
|
||||
message_details_with_flags.append((timestamp, name, content, is_action))
|
||||
|
||||
# 应用截断逻辑 (如果 truncate 为 True)
|
||||
message_details: List[Tuple[float, str, str, bool]] = []
|
||||
message_details: list[tuple[float, str, str, bool]] = []
|
||||
n_messages = len(message_details_with_flags)
|
||||
if truncate and n_messages > 0:
|
||||
for i, (timestamp, name, content, is_action) in enumerate(message_details_with_flags):
|
||||
@@ -809,7 +808,7 @@ async def _build_readable_messages_internal(
|
||||
)
|
||||
|
||||
|
||||
async def build_pic_mapping_info(pic_id_mapping: Dict[str, str]) -> str:
|
||||
async def build_pic_mapping_info(pic_id_mapping: dict[str, str]) -> str:
|
||||
# sourcery skip: use-contextlib-suppress
|
||||
"""
|
||||
构建图片映射信息字符串,显示图片的具体描述内容
|
||||
@@ -847,7 +846,7 @@ async def build_pic_mapping_info(pic_id_mapping: Dict[str, str]) -> str:
|
||||
return "\n".join(mapping_lines)
|
||||
|
||||
|
||||
def build_readable_actions(actions: List[Dict[str, Any]]) -> str:
|
||||
def build_readable_actions(actions: list[dict[str, Any]]) -> str:
|
||||
"""
|
||||
将动作列表转换为可读的文本格式。
|
||||
格式: 在()分钟前,你使用了(action_name),具体内容是:(action_prompt_display)
|
||||
@@ -922,12 +921,12 @@ def build_readable_actions(actions: List[Dict[str, Any]]) -> str:
|
||||
|
||||
|
||||
async def build_readable_messages_with_list(
|
||||
messages: List[Dict[str, Any]],
|
||||
messages: list[dict[str, Any]],
|
||||
replace_bot_name: bool = True,
|
||||
merge_messages: bool = False,
|
||||
timestamp_mode: str = "relative",
|
||||
truncate: bool = False,
|
||||
) -> Tuple[str, List[Tuple[float, str, str]]]:
|
||||
) -> tuple[str, list[tuple[float, str, str]]]:
|
||||
"""
|
||||
将消息列表转换为可读的文本格式,并返回原始(时间戳, 昵称, 内容)列表。
|
||||
允许通过参数控制格式化行为。
|
||||
@@ -943,7 +942,7 @@ async def build_readable_messages_with_list(
|
||||
|
||||
|
||||
async def build_readable_messages_with_id(
|
||||
messages: List[Dict[str, Any]],
|
||||
messages: list[dict[str, Any]],
|
||||
replace_bot_name: bool = True,
|
||||
merge_messages: bool = False,
|
||||
timestamp_mode: str = "relative",
|
||||
@@ -951,7 +950,7 @@ async def build_readable_messages_with_id(
|
||||
truncate: bool = False,
|
||||
show_actions: bool = False,
|
||||
show_pic: bool = True,
|
||||
) -> Tuple[str, List[Dict[str, Any]]]:
|
||||
) -> tuple[str, list[dict[str, Any]]]:
|
||||
"""
|
||||
将消息列表转换为可读的文本格式,并返回原始(时间戳, 昵称, 内容)列表。
|
||||
允许通过参数控制格式化行为。
|
||||
@@ -980,7 +979,7 @@ async def build_readable_messages_with_id(
|
||||
|
||||
|
||||
async def build_readable_messages(
|
||||
messages: List[Dict[str, Any]],
|
||||
messages: list[dict[str, Any]],
|
||||
replace_bot_name: bool = True,
|
||||
merge_messages: bool = False,
|
||||
timestamp_mode: str = "relative",
|
||||
@@ -988,7 +987,7 @@ async def build_readable_messages(
|
||||
truncate: bool = False,
|
||||
show_actions: bool = True,
|
||||
show_pic: bool = True,
|
||||
message_id_list: Optional[List[Dict[str, Any]]] = None,
|
||||
message_id_list: list[dict[str, Any]] | None = None,
|
||||
) -> str: # sourcery skip: extract-method
|
||||
"""
|
||||
将消息列表转换为可读的文本格式。
|
||||
@@ -1148,7 +1147,7 @@ async def build_readable_messages(
|
||||
return "".join(result_parts)
|
||||
|
||||
|
||||
async def build_anonymous_messages(messages: List[Dict[str, Any]]) -> str:
|
||||
async def build_anonymous_messages(messages: list[dict[str, Any]]) -> str:
|
||||
"""
|
||||
构建匿名可读消息,将不同人的名称转为唯一占位符(A、B、C...),bot自己用SELF。
|
||||
处理 回复<aaa:bbb> 和 @<aaa:bbb> 字段,将bbb映射为匿名占位符。
|
||||
@@ -1261,7 +1260,7 @@ async def build_anonymous_messages(messages: List[Dict[str, Any]]) -> str:
|
||||
return formatted_string
|
||||
|
||||
|
||||
async def get_person_id_list(messages: List[Dict[str, Any]]) -> List[str]:
|
||||
async def get_person_id_list(messages: list[dict[str, Any]]) -> list[str]:
|
||||
"""
|
||||
从消息列表中提取不重复的 person_id 列表 (忽略机器人自身)。
|
||||
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
记忆系统相关的映射表和工具函数
|
||||
提供记忆类型、置信度、重要性等的中文标签映射
|
||||
|
||||
@@ -3,19 +3,20 @@
|
||||
将原有的Prompt类和SmartPrompt功能整合为一个真正的Prompt类
|
||||
"""
|
||||
|
||||
import re
|
||||
import asyncio
|
||||
import time
|
||||
import contextvars
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Dict, Any, Optional, List, Literal, Tuple
|
||||
import re
|
||||
import time
|
||||
from contextlib import asynccontextmanager
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Literal, Optional
|
||||
|
||||
from rich.traceback import install
|
||||
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from src.chat.utils.chat_message_builder import build_readable_messages
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.chat.utils.chat_message_builder import build_readable_messages
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from src.person_info.person_info import get_person_info_manager
|
||||
|
||||
install(extra_lines=3)
|
||||
@@ -50,11 +51,11 @@ class PromptParameters:
|
||||
debug_mode: bool = False
|
||||
|
||||
# 聊天历史和上下文
|
||||
chat_target_info: Optional[Dict[str, Any]] = None
|
||||
message_list_before_now_long: List[Dict[str, Any]] = field(default_factory=list)
|
||||
message_list_before_short: List[Dict[str, Any]] = field(default_factory=list)
|
||||
chat_target_info: dict[str, Any] | None = None
|
||||
message_list_before_now_long: list[dict[str, Any]] = field(default_factory=list)
|
||||
message_list_before_short: list[dict[str, Any]] = field(default_factory=list)
|
||||
chat_talking_prompt_short: str = ""
|
||||
target_user_info: Optional[Dict[str, Any]] = None
|
||||
target_user_info: dict[str, Any] | None = None
|
||||
|
||||
# 已构建的内容块
|
||||
expression_habits_block: str = ""
|
||||
@@ -77,12 +78,12 @@ class PromptParameters:
|
||||
action_descriptions: str = ""
|
||||
|
||||
# 可用动作信息
|
||||
available_actions: Optional[Dict[str, Any]] = None
|
||||
available_actions: dict[str, Any] | None = None
|
||||
|
||||
# 动态生成的聊天场景提示
|
||||
chat_scene: str = ""
|
||||
|
||||
def validate(self) -> List[str]:
|
||||
def validate(self) -> list[str]:
|
||||
"""参数验证"""
|
||||
errors = []
|
||||
if not self.chat_id:
|
||||
@@ -98,22 +99,22 @@ class PromptContext:
|
||||
"""提示词上下文管理器"""
|
||||
|
||||
def __init__(self):
|
||||
self._context_prompts: Dict[str, Dict[str, "Prompt"]] = {}
|
||||
self._context_prompts: dict[str, dict[str, "Prompt"]] = {}
|
||||
self._current_context_var = contextvars.ContextVar("current_context", default=None)
|
||||
self._context_lock = asyncio.Lock()
|
||||
|
||||
@property
|
||||
def _current_context(self) -> Optional[str]:
|
||||
def _current_context(self) -> str | None:
|
||||
"""获取当前协程的上下文ID"""
|
||||
return self._current_context_var.get()
|
||||
|
||||
@_current_context.setter
|
||||
def _current_context(self, value: Optional[str]):
|
||||
def _current_context(self, value: str | None):
|
||||
"""设置当前协程的上下文ID"""
|
||||
self._current_context_var.set(value) # type: ignore
|
||||
|
||||
@asynccontextmanager
|
||||
async def async_scope(self, context_id: Optional[str] = None):
|
||||
async def async_scope(self, context_id: str | None = None):
|
||||
"""创建一个异步的临时提示模板作用域"""
|
||||
if context_id is not None:
|
||||
try:
|
||||
@@ -159,7 +160,7 @@ class PromptContext:
|
||||
return self._context_prompts[current_context][name]
|
||||
return None
|
||||
|
||||
async def register_async(self, prompt: "Prompt", context_id: Optional[str] = None) -> None:
|
||||
async def register_async(self, prompt: "Prompt", context_id: str | None = None) -> None:
|
||||
"""异步注册提示模板到指定作用域"""
|
||||
async with self._context_lock:
|
||||
if target_context := context_id or self._current_context:
|
||||
@@ -177,7 +178,7 @@ class PromptManager:
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
@asynccontextmanager
|
||||
async def async_message_scope(self, message_id: Optional[str] = None):
|
||||
async def async_message_scope(self, message_id: str | None = None):
|
||||
"""为消息处理创建异步临时作用域"""
|
||||
async with self._context.async_scope(message_id):
|
||||
yield self
|
||||
@@ -236,8 +237,8 @@ class Prompt:
|
||||
def __init__(
|
||||
self,
|
||||
template: str,
|
||||
name: Optional[str] = None,
|
||||
parameters: Optional[PromptParameters] = None,
|
||||
name: str | None = None,
|
||||
parameters: PromptParameters | None = None,
|
||||
should_register: bool = True,
|
||||
):
|
||||
"""
|
||||
@@ -277,7 +278,7 @@ class Prompt:
|
||||
"""将临时标记还原为实际的花括号字符"""
|
||||
return template.replace(Prompt._TEMP_LEFT_BRACE, "{").replace(Prompt._TEMP_RIGHT_BRACE, "}")
|
||||
|
||||
def _parse_template_args(self, template: str) -> List[str]:
|
||||
def _parse_template_args(self, template: str) -> list[str]:
|
||||
"""解析模板参数"""
|
||||
template_args = []
|
||||
processed_template = self._process_escaped_braces(template)
|
||||
@@ -321,7 +322,7 @@ class Prompt:
|
||||
logger.error(f"构建Prompt失败: {e}")
|
||||
raise RuntimeError(f"构建Prompt失败: {e}") from e
|
||||
|
||||
async def _build_context_data(self) -> Dict[str, Any]:
|
||||
async def _build_context_data(self) -> dict[str, Any]:
|
||||
"""构建智能上下文数据"""
|
||||
# 并行执行所有构建任务
|
||||
start_time = time.time()
|
||||
@@ -401,7 +402,7 @@ class Prompt:
|
||||
default_result = self._get_default_result_for_task(task_name)
|
||||
results.append(default_result)
|
||||
except Exception as e:
|
||||
logger.error(f"构建任务{task_name}失败: {str(e)}")
|
||||
logger.error(f"构建任务{task_name}失败: {e!s}")
|
||||
default_result = self._get_default_result_for_task(task_name)
|
||||
results.append(default_result)
|
||||
|
||||
@@ -411,7 +412,7 @@ class Prompt:
|
||||
task_name = task_names[i] if i < len(task_names) else f"task_{i}"
|
||||
|
||||
if isinstance(result, Exception):
|
||||
logger.error(f"构建任务{task_name}失败: {str(result)}")
|
||||
logger.error(f"构建任务{task_name}失败: {result!s}")
|
||||
elif isinstance(result, dict):
|
||||
context_data.update(result)
|
||||
|
||||
@@ -453,7 +454,7 @@ class Prompt:
|
||||
|
||||
return context_data
|
||||
|
||||
async def _build_s4u_chat_context(self, context_data: Dict[str, Any]) -> None:
|
||||
async def _build_s4u_chat_context(self, context_data: dict[str, Any]) -> None:
|
||||
"""构建S4U模式的聊天上下文"""
|
||||
if not self.parameters.message_list_before_now_long:
|
||||
return
|
||||
@@ -468,7 +469,7 @@ class Prompt:
|
||||
context_data["read_history_prompt"] = read_history_prompt
|
||||
context_data["unread_history_prompt"] = unread_history_prompt
|
||||
|
||||
async def _build_normal_chat_context(self, context_data: Dict[str, Any]) -> None:
|
||||
async def _build_normal_chat_context(self, context_data: dict[str, Any]) -> None:
|
||||
"""构建normal模式的聊天上下文"""
|
||||
if not self.parameters.chat_talking_prompt_short:
|
||||
return
|
||||
@@ -477,8 +478,8 @@ class Prompt:
|
||||
{self.parameters.chat_talking_prompt_short}"""
|
||||
|
||||
async def _build_s4u_chat_history_prompts(
|
||||
self, message_list_before_now: List[Dict[str, Any]], target_user_id: str, sender: str, chat_id: str
|
||||
) -> Tuple[str, str]:
|
||||
self, message_list_before_now: list[dict[str, Any]], target_user_id: str, sender: str, chat_id: str
|
||||
) -> tuple[str, str]:
|
||||
"""构建S4U风格的已读/未读历史消息prompt"""
|
||||
try:
|
||||
# 动态导入default_generator以避免循环导入
|
||||
@@ -492,7 +493,7 @@ class Prompt:
|
||||
except Exception as e:
|
||||
logger.error(f"构建S4U历史消息prompt失败: {e}")
|
||||
|
||||
async def _build_expression_habits(self) -> Dict[str, Any]:
|
||||
async def _build_expression_habits(self) -> dict[str, Any]:
|
||||
"""构建表达习惯"""
|
||||
use_expression, _, _ = global_config.expression.get_expression_config_for_chat(self.parameters.chat_id)
|
||||
if not use_expression:
|
||||
@@ -533,7 +534,7 @@ class Prompt:
|
||||
logger.error(f"构建表达习惯失败: {e}")
|
||||
return {"expression_habits_block": ""}
|
||||
|
||||
async def _build_memory_block(self) -> Dict[str, Any]:
|
||||
async def _build_memory_block(self) -> dict[str, Any]:
|
||||
"""构建记忆块"""
|
||||
if not global_config.memory.enable_memory:
|
||||
return {"memory_block": ""}
|
||||
@@ -653,7 +654,7 @@ class Prompt:
|
||||
logger.error(f"构建记忆块失败: {e}")
|
||||
return {"memory_block": ""}
|
||||
|
||||
async def _build_memory_block_fast(self) -> Dict[str, Any]:
|
||||
async def _build_memory_block_fast(self) -> dict[str, Any]:
|
||||
"""快速构建记忆块(简化版本,用于未预构建时的后备方案)"""
|
||||
if not global_config.memory.enable_memory:
|
||||
return {"memory_block": ""}
|
||||
@@ -677,7 +678,7 @@ class Prompt:
|
||||
logger.warning(f"快速构建记忆块失败: {e}")
|
||||
return {"memory_block": ""}
|
||||
|
||||
async def _build_relation_info(self) -> Dict[str, Any]:
|
||||
async def _build_relation_info(self) -> dict[str, Any]:
|
||||
"""构建关系信息"""
|
||||
try:
|
||||
relation_info = await Prompt.build_relation_info(self.parameters.chat_id, self.parameters.reply_to)
|
||||
@@ -686,7 +687,7 @@ class Prompt:
|
||||
logger.error(f"构建关系信息失败: {e}")
|
||||
return {"relation_info_block": ""}
|
||||
|
||||
async def _build_tool_info(self) -> Dict[str, Any]:
|
||||
async def _build_tool_info(self) -> dict[str, Any]:
|
||||
"""构建工具信息"""
|
||||
if not global_config.tool.enable_tool:
|
||||
return {"tool_info_block": ""}
|
||||
@@ -734,7 +735,7 @@ class Prompt:
|
||||
logger.error(f"构建工具信息失败: {e}")
|
||||
return {"tool_info_block": ""}
|
||||
|
||||
async def _build_knowledge_info(self) -> Dict[str, Any]:
|
||||
async def _build_knowledge_info(self) -> dict[str, Any]:
|
||||
"""构建知识信息"""
|
||||
if not global_config.lpmm_knowledge.enable:
|
||||
return {"knowledge_prompt": ""}
|
||||
@@ -783,7 +784,7 @@ class Prompt:
|
||||
logger.error(f"构建知识信息失败: {e}")
|
||||
return {"knowledge_prompt": ""}
|
||||
|
||||
async def _build_cross_context(self) -> Dict[str, Any]:
|
||||
async def _build_cross_context(self) -> dict[str, Any]:
|
||||
"""构建跨群上下文"""
|
||||
try:
|
||||
cross_context = await Prompt.build_cross_context(
|
||||
@@ -794,7 +795,7 @@ class Prompt:
|
||||
logger.error(f"构建跨群上下文失败: {e}")
|
||||
return {"cross_context_block": ""}
|
||||
|
||||
async def _format_with_context(self, context_data: Dict[str, Any]) -> str:
|
||||
async def _format_with_context(self, context_data: dict[str, Any]) -> str:
|
||||
"""使用上下文数据格式化模板"""
|
||||
if self.parameters.prompt_mode == "s4u":
|
||||
params = self._prepare_s4u_params(context_data)
|
||||
@@ -805,7 +806,7 @@ class Prompt:
|
||||
|
||||
return await global_prompt_manager.format_prompt(self.name, **params) if self.name else self.format(**params)
|
||||
|
||||
def _prepare_s4u_params(self, context_data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
def _prepare_s4u_params(self, context_data: dict[str, Any]) -> dict[str, Any]:
|
||||
"""准备S4U模式的参数"""
|
||||
return {
|
||||
**context_data,
|
||||
@@ -834,7 +835,7 @@ class Prompt:
|
||||
or "你正在一个QQ群里聊天,你需要理解整个群的聊天动态和话题走向,并做出自然的回应。",
|
||||
}
|
||||
|
||||
def _prepare_normal_params(self, context_data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
def _prepare_normal_params(self, context_data: dict[str, Any]) -> dict[str, Any]:
|
||||
"""准备Normal模式的参数"""
|
||||
return {
|
||||
**context_data,
|
||||
@@ -862,7 +863,7 @@ class Prompt:
|
||||
or "你正在一个QQ群里聊天,你需要理解整个群的聊天动态和话题走向,并做出自然的回应。",
|
||||
}
|
||||
|
||||
def _prepare_default_params(self, context_data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
def _prepare_default_params(self, context_data: dict[str, Any]) -> dict[str, Any]:
|
||||
"""准备默认模式的参数"""
|
||||
return {
|
||||
"expression_habits_block": context_data.get("expression_habits_block", ""),
|
||||
@@ -905,7 +906,7 @@ class Prompt:
|
||||
result = self._restore_escaped_braces(processed_template)
|
||||
return result
|
||||
except (IndexError, KeyError) as e:
|
||||
raise ValueError(f"格式化模板失败: {self.template}, args={args}, kwargs={kwargs} {str(e)}") from e
|
||||
raise ValueError(f"格式化模板失败: {self.template}, args={args}, kwargs={kwargs} {e!s}") from e
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""返回格式化后的结果或原始模板"""
|
||||
@@ -922,7 +923,7 @@ class Prompt:
|
||||
# =============================================================================
|
||||
|
||||
@staticmethod
|
||||
def parse_reply_target(target_message: str) -> Tuple[str, str]:
|
||||
def parse_reply_target(target_message: str) -> tuple[str, str]:
|
||||
"""
|
||||
解析回复目标消息 - 统一实现
|
||||
|
||||
@@ -981,7 +982,7 @@ class Prompt:
|
||||
|
||||
return await relationship_fetcher.build_relation_info(person_id, points_num=5)
|
||||
|
||||
def _get_default_result_for_task(self, task_name: str) -> Dict[str, Any]:
|
||||
def _get_default_result_for_task(self, task_name: str) -> dict[str, Any]:
|
||||
"""
|
||||
为超时的任务提供默认结果
|
||||
|
||||
@@ -1008,7 +1009,7 @@ class Prompt:
|
||||
return {}
|
||||
|
||||
@staticmethod
|
||||
async def build_cross_context(chat_id: str, prompt_mode: str, target_user_info: Optional[Dict[str, Any]]) -> str:
|
||||
async def build_cross_context(chat_id: str, prompt_mode: str, target_user_info: dict[str, Any] | None) -> str:
|
||||
"""
|
||||
构建跨群聊上下文 - 统一实现
|
||||
|
||||
@@ -1071,7 +1072,7 @@ class Prompt:
|
||||
|
||||
# 工厂函数
|
||||
def create_prompt(
|
||||
template: str, name: Optional[str] = None, parameters: Optional[PromptParameters] = None, **kwargs
|
||||
template: str, name: str | None = None, parameters: PromptParameters | None = None, **kwargs
|
||||
) -> Prompt:
|
||||
"""快速创建Prompt实例的工厂函数"""
|
||||
if parameters is None:
|
||||
@@ -1080,7 +1081,7 @@ def create_prompt(
|
||||
|
||||
|
||||
async def create_prompt_async(
|
||||
template: str, name: Optional[str] = None, parameters: Optional[PromptParameters] = None, **kwargs
|
||||
template: str, name: str | None = None, parameters: PromptParameters | None = None, **kwargs
|
||||
) -> Prompt:
|
||||
"""异步创建Prompt实例"""
|
||||
prompt = create_prompt(template, name, parameters, **kwargs)
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
import asyncio
|
||||
from collections import defaultdict
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any, Dict, Tuple, List
|
||||
from typing import Any
|
||||
|
||||
from src.common.database.sqlalchemy_database_api import db_get, db_query, db_save
|
||||
from src.common.database.sqlalchemy_models import LLMUsage, Messages, OnlineTime
|
||||
from src.common.logger import get_logger
|
||||
from src.common.database.sqlalchemy_models import OnlineTime, LLMUsage, Messages
|
||||
from src.common.database.sqlalchemy_database_api import db_query, db_save, db_get
|
||||
from src.manager.async_task_manager import AsyncTask
|
||||
from src.manager.local_store_manager import local_storage
|
||||
|
||||
@@ -150,7 +150,7 @@ class StatisticOutputTask(AsyncTask):
|
||||
# 延迟300秒启动,运行间隔300秒
|
||||
super().__init__(task_name="Statistics Data Output Task", wait_before_start=0, run_interval=300)
|
||||
|
||||
self.name_mapping: Dict[str, Tuple[str, float]] = {}
|
||||
self.name_mapping: dict[str, tuple[str, float]] = {}
|
||||
"""
|
||||
联系人/群聊名称映射 {聊天ID: (联系人/群聊名称, 记录时间(timestamp))}
|
||||
注:设计记录时间的目的是方便更新名称,使联系人/群聊名称保持最新
|
||||
@@ -170,7 +170,7 @@ class StatisticOutputTask(AsyncTask):
|
||||
deploy_time = datetime(2000, 1, 1)
|
||||
local_storage["deploy_time"] = now.timestamp()
|
||||
|
||||
self.stat_period: List[Tuple[str, timedelta, str]] = [
|
||||
self.stat_period: list[tuple[str, timedelta, str]] = [
|
||||
("all_time", now - deploy_time, "自部署以来"), # 必须保留"all_time"
|
||||
("last_7_days", timedelta(days=7), "最近7天"),
|
||||
("last_24_hours", timedelta(days=1), "最近24小时"),
|
||||
@@ -181,7 +181,7 @@ class StatisticOutputTask(AsyncTask):
|
||||
统计时间段 [(统计名称, 统计时间段, 统计描述), ...]
|
||||
"""
|
||||
|
||||
def _statistic_console_output(self, stats: Dict[str, Any], now: datetime):
|
||||
def _statistic_console_output(self, stats: dict[str, Any], now: datetime):
|
||||
"""
|
||||
输出统计数据到控制台
|
||||
:param stats: 统计数据
|
||||
@@ -239,7 +239,7 @@ class StatisticOutputTask(AsyncTask):
|
||||
# -- 以下为统计数据收集方法 --
|
||||
|
||||
@staticmethod
|
||||
async def _collect_model_request_for_period(collect_period: List[Tuple[str, datetime]]) -> Dict[str, Any]:
|
||||
async def _collect_model_request_for_period(collect_period: list[tuple[str, datetime]]) -> dict[str, Any]:
|
||||
"""
|
||||
收集指定时间段的LLM请求统计数据
|
||||
|
||||
@@ -393,8 +393,8 @@ class StatisticOutputTask(AsyncTask):
|
||||
|
||||
@staticmethod
|
||||
async def _collect_online_time_for_period(
|
||||
collect_period: List[Tuple[str, datetime]], now: datetime
|
||||
) -> Dict[str, Any]:
|
||||
collect_period: list[tuple[str, datetime]], now: datetime
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
收集指定时间段的在线时间统计数据
|
||||
|
||||
@@ -452,7 +452,7 @@ class StatisticOutputTask(AsyncTask):
|
||||
break
|
||||
return stats
|
||||
|
||||
async def _collect_message_count_for_period(self, collect_period: List[Tuple[str, datetime]]) -> Dict[str, Any]:
|
||||
async def _collect_message_count_for_period(self, collect_period: list[tuple[str, datetime]]) -> dict[str, Any]:
|
||||
"""
|
||||
收集指定时间段的消息统计数据
|
||||
|
||||
@@ -523,7 +523,7 @@ class StatisticOutputTask(AsyncTask):
|
||||
break
|
||||
return stats
|
||||
|
||||
async def _collect_all_statistics(self, now: datetime) -> Dict[str, Dict[str, Any]]:
|
||||
async def _collect_all_statistics(self, now: datetime) -> dict[str, dict[str, Any]]:
|
||||
"""
|
||||
收集各时间段的统计数据
|
||||
:param now: 基准当前时间
|
||||
@@ -533,7 +533,7 @@ class StatisticOutputTask(AsyncTask):
|
||||
|
||||
if "last_full_statistics" in local_storage:
|
||||
# 如果存在上次完整统计数据,则使用该数据进行增量统计
|
||||
last_stat: Dict[str, Any] = local_storage["last_full_statistics"] # 上次完整统计数据 # type: ignore
|
||||
last_stat: dict[str, Any] = local_storage["last_full_statistics"] # 上次完整统计数据 # type: ignore
|
||||
|
||||
self.name_mapping = last_stat["name_mapping"] # 上次完整统计数据的名称映射
|
||||
last_all_time_stat = last_stat["stat_data"] # 上次完整统计的统计数据
|
||||
@@ -620,7 +620,7 @@ class StatisticOutputTask(AsyncTask):
|
||||
# -- 以下为统计数据格式化方法 --
|
||||
|
||||
@staticmethod
|
||||
def _format_total_stat(stats: Dict[str, Any]) -> str:
|
||||
def _format_total_stat(stats: dict[str, Any]) -> str:
|
||||
"""
|
||||
格式化总统计数据
|
||||
"""
|
||||
@@ -636,7 +636,7 @@ class StatisticOutputTask(AsyncTask):
|
||||
return "\n".join(output)
|
||||
|
||||
@staticmethod
|
||||
def _format_model_classified_stat(stats: Dict[str, Any]) -> str:
|
||||
def _format_model_classified_stat(stats: dict[str, Any]) -> str:
|
||||
"""
|
||||
格式化按模型分类的统计数据
|
||||
"""
|
||||
@@ -662,7 +662,7 @@ class StatisticOutputTask(AsyncTask):
|
||||
output.append("")
|
||||
return "\n".join(output)
|
||||
|
||||
def _format_chat_stat(self, stats: Dict[str, Any]) -> str:
|
||||
def _format_chat_stat(self, stats: dict[str, Any]) -> str:
|
||||
"""
|
||||
格式化聊天统计数据
|
||||
"""
|
||||
@@ -1007,7 +1007,7 @@ class StatisticOutputTask(AsyncTask):
|
||||
async def _generate_chart_data(self, stat: dict[str, Any]) -> dict:
|
||||
"""生成图表数据 (异步)"""
|
||||
now = datetime.now()
|
||||
chart_data: Dict[str, Any] = {}
|
||||
chart_data: dict[str, Any] = {}
|
||||
|
||||
time_ranges = [
|
||||
("6h", 6, 10),
|
||||
@@ -1023,16 +1023,16 @@ class StatisticOutputTask(AsyncTask):
|
||||
|
||||
async def _collect_interval_data(self, now: datetime, hours: int, interval_minutes: int) -> dict:
|
||||
start_time = now - timedelta(hours=hours)
|
||||
time_points: List[datetime] = []
|
||||
time_points: list[datetime] = []
|
||||
current_time = start_time
|
||||
while current_time <= now:
|
||||
time_points.append(current_time)
|
||||
current_time += timedelta(minutes=interval_minutes)
|
||||
|
||||
total_cost_data = [0.0] * len(time_points)
|
||||
cost_by_model: Dict[str, List[float]] = {}
|
||||
cost_by_module: Dict[str, List[float]] = {}
|
||||
message_by_chat: Dict[str, List[int]] = {}
|
||||
cost_by_model: dict[str, list[float]] = {}
|
||||
cost_by_module: dict[str, list[float]] = {}
|
||||
message_by_chat: dict[str, list[int]] = {}
|
||||
time_labels = [t.strftime("%H:%M") for t in time_points]
|
||||
interval_seconds = interval_minutes * 60
|
||||
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
import asyncio
|
||||
|
||||
from time import perf_counter
|
||||
from collections.abc import Callable
|
||||
from functools import wraps
|
||||
from typing import Optional, Dict, Callable
|
||||
from time import perf_counter
|
||||
|
||||
from rich.traceback import install
|
||||
|
||||
install(extra_lines=3)
|
||||
@@ -75,12 +75,12 @@ class Timer:
|
||||
3. 直接实例化:如果不调用 __enter__,打印对象时将显示当前 perf_counter 的值
|
||||
"""
|
||||
|
||||
__slots__ = ("name", "storage", "elapsed", "auto_unit", "start")
|
||||
__slots__ = ("auto_unit", "elapsed", "name", "start", "storage")
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: Optional[str] = None,
|
||||
storage: Optional[Dict[str, float]] = None,
|
||||
name: str | None = None,
|
||||
storage: dict[str, float] | None = None,
|
||||
auto_unit: bool = True,
|
||||
do_type_check: bool = False,
|
||||
):
|
||||
@@ -103,7 +103,7 @@ class Timer:
|
||||
if storage is not None and not isinstance(storage, dict):
|
||||
raise TimerTypeError("storage", "Optional[dict]", type(storage))
|
||||
|
||||
def __call__(self, func: Optional[Callable] = None) -> Callable:
|
||||
def __call__(self, func: Callable | None = None) -> Callable:
|
||||
"""装饰器模式"""
|
||||
if func is None:
|
||||
return lambda f: Timer(name=self.name or f.__name__, storage=self.storage, auto_unit=self.auto_unit)(f)
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user