re-style: 格式化代码

This commit is contained in:
John Richard
2025-10-02 20:26:01 +08:00
parent ecb02cae31
commit 7923eafef3
263 changed files with 3103 additions and 3123 deletions

View File

@@ -12,7 +12,7 @@ if __name__ == "__main__":
# 执行bot.py的代码
bot_file = current_dir / "bot.py"
with open(bot_file, "r", encoding="utf-8") as f:
with open(bot_file, encoding="utf-8") as f:
exec(f.read())

24
bot.py
View File

@@ -1,30 +1,30 @@
# import asyncio
import asyncio
import os
import platform
import sys
import time
import platform
import traceback
from pathlib import Path
from rich.traceback import install
from colorama import init, Fore
from colorama import Fore, init
from dotenv import load_dotenv # 处理.env文件
from rich.traceback import install
# maim_message imports for console input
# 最早期初始化日志系统,确保所有后续模块都使用正确的日志格式
from src.common.logger import initialize_logging, get_logger, shutdown_logging
from src.common.logger import get_logger, initialize_logging, shutdown_logging
# UI日志适配器
initialize_logging()
from src.main import MainSystem # noqa
from src import BaseMain # noqa
from src.manager.async_task_manager import async_task_manager # noqa
from src.chat.knowledge.knowledge_lib import initialize_lpmm_knowledge # noqa
from src.config.config import global_config # noqa
from src.common.database.database import initialize_sql_database # noqa
from src.common.database.sqlalchemy_models import initialize_database as init_db # noqa
from src import BaseMain
from src.manager.async_task_manager import async_task_manager
from src.chat.knowledge.knowledge_lib import initialize_lpmm_knowledge
from src.config.config import global_config
from src.common.database.database import initialize_sql_database
from src.common.database.sqlalchemy_models import initialize_database as init_db
logger = get_logger("main")
@@ -247,7 +247,7 @@ if __name__ == "__main__":
# The actual shutdown logic is now in the finally block.
except Exception as e:
logger.error(f"主程序发生异常: {str(e)} {str(traceback.format_exc())}")
logger.error(f"主程序发生异常: {e!s} {traceback.format_exc()!s}")
exit_code = 1 # 标记发生错误
finally:
# 确保 loop 在任何情况下都尝试关闭(如果存在且未关闭)

View File

@@ -1,5 +1,4 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Bilibili 插件包
提供B站视频观看体验功能像真实用户一样浏览和评价视频

View File

@@ -1,16 +1,17 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Bilibili 工具基础模块
提供 B 站视频信息获取和视频分析功能
"""
import re
import aiohttp
import asyncio
from typing import Optional, Dict, Any
from src.common.logger import get_logger
import re
from typing import Any
import aiohttp
from src.chat.utils.utils_video import get_video_analyzer
from src.common.logger import get_logger
logger = get_logger("bilibili_tool")
@@ -25,7 +26,7 @@ class BilibiliVideoAnalyzer:
"Referer": "https://www.bilibili.com/",
}
def extract_bilibili_url(self, text: str) -> Optional[str]:
def extract_bilibili_url(self, text: str) -> str | None:
"""从文本中提取哔哩哔哩视频链接"""
# 哔哩哔哩短链接模式
short_pattern = re.compile(r"https?://b23\.tv/[\w]+", re.IGNORECASE)
@@ -44,7 +45,7 @@ class BilibiliVideoAnalyzer:
return None
async def get_video_info(self, url: str) -> Optional[Dict[str, Any]]:
async def get_video_info(self, url: str) -> dict[str, Any] | None:
"""获取哔哩哔哩视频基本信息"""
try:
logger.info(f"🔍 解析视频URL: {url}")
@@ -127,7 +128,7 @@ class BilibiliVideoAnalyzer:
logger.exception("详细错误信息:")
return None
async def get_video_stream_url(self, aid: int, cid: int) -> Optional[str]:
async def get_video_stream_url(self, aid: int, cid: int) -> str | None:
"""获取视频流URL"""
try:
logger.info(f"🎥 获取视频流URL: aid={aid}, cid={cid}")
@@ -164,7 +165,7 @@ class BilibiliVideoAnalyzer:
return stream_url
# 降级到FLV格式
if "durl" in play_data and play_data["durl"]:
if play_data.get("durl"):
logger.info("📹 使用FLV格式视频流")
stream_url = play_data["durl"][0].get("url")
if stream_url:
@@ -185,7 +186,7 @@ class BilibiliVideoAnalyzer:
logger.exception("详细错误信息:")
return None
async def download_video_bytes(self, stream_url: str, max_size_mb: int = 100) -> Optional[bytes]:
async def download_video_bytes(self, stream_url: str, max_size_mb: int = 100) -> bytes | None:
"""下载视频字节数据
Args:
@@ -244,7 +245,7 @@ class BilibiliVideoAnalyzer:
logger.exception("详细错误信息:")
return None
async def analyze_bilibili_video(self, url: str, prompt: str = None) -> Dict[str, Any]:
async def analyze_bilibili_video(self, url: str, prompt: str = None) -> dict[str, Any]:
"""分析哔哩哔哩视频并返回详细信息和AI分析结果"""
try:
logger.info(f"🎬 开始分析哔哩哔哩视频: {url}")
@@ -322,10 +323,10 @@ class BilibiliVideoAnalyzer:
return result
except Exception as e:
error_msg = f"分析哔哩哔哩视频时发生异常: {str(e)}"
error_msg = f"分析哔哩哔哩视频时发生异常: {e!s}"
logger.error(f"{error_msg}")
logger.exception("详细错误信息:") # 记录完整的异常堆栈
return {"error": f"分析失败: {str(e)}"}
return {"error": f"分析失败: {e!s}"}
# 全局实例

View File

@@ -1,14 +1,15 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Bilibili 视频观看体验工具
支持哔哩哔哩视频链接解析和AI视频内容分析
"""
from typing import Dict, Any, List, Tuple, Type
from src.plugin_system import BaseTool, ToolParamType, BasePlugin, register_plugin, ComponentInfo, ConfigField
from .bilibli_base import get_bilibili_analyzer
from typing import Any
from src.common.logger import get_logger
from src.plugin_system import BasePlugin, BaseTool, ComponentInfo, ConfigField, ToolParamType, register_plugin
from .bilibli_base import get_bilibili_analyzer
logger = get_logger("bilibili_tool")
@@ -41,7 +42,7 @@ class BilibiliTool(BaseTool):
super().__init__(plugin_config)
self.analyzer = get_bilibili_analyzer()
async def execute(self, function_args: Dict[str, Any]) -> Dict[str, Any]:
async def execute(self, function_args: dict[str, Any]) -> dict[str, Any]:
"""执行哔哩哔哩视频观看体验"""
try:
url = function_args.get("url", "").strip()
@@ -83,7 +84,7 @@ class BilibiliTool(BaseTool):
return {"name": self.name, "content": content.strip()}
except Exception as e:
error_msg = f"😅 看视频的时候出了点问题: {str(e)}"
error_msg = f"😅 看视频的时候出了点问题: {e!s}"
logger.error(error_msg)
return {"name": self.name, "content": error_msg}
@@ -104,7 +105,7 @@ class BilibiliTool(BaseTool):
return base_prompt
def _format_watch_experience(self, video_info: Dict, ai_analysis: str, interest_focus: str = None) -> str:
def _format_watch_experience(self, video_info: dict, ai_analysis: str, interest_focus: str = None) -> str:
"""格式化观看体验报告"""
# 根据播放量生成热度评价
@@ -191,8 +192,8 @@ class BilibiliPlugin(BasePlugin):
# 插件基本信息
plugin_name: str = "bilibili_video_watcher"
enable_plugin: bool = True
dependencies: List[str] = []
python_dependencies: List[str] = []
dependencies: list[str] = []
python_dependencies: list[str] = []
config_file_name: str = "config.toml"
# 配置节描述
@@ -220,6 +221,6 @@ class BilibiliPlugin(BasePlugin):
},
}
def get_plugin_components(self) -> List[Tuple[ComponentInfo, Type]]:
def get_plugin_components(self) -> list[tuple[ComponentInfo, type]]:
"""返回插件包含的工具组件"""
return [(BilibiliTool.get_tool_info(), BilibiliTool)]

View File

@@ -4,14 +4,15 @@ Echo 示例插件
展示增强命令系统的使用方法
"""
from typing import List, Tuple, Type, Optional, Union
from typing import Union
from src.plugin_system import (
BasePlugin,
PlusCommand,
CommandArgs,
PlusCommandInfo,
ConfigField,
ChatType,
CommandArgs,
ConfigField,
PlusCommand,
PlusCommandInfo,
register_plugin,
)
from src.plugin_system.base.component_types import PythonDependency
@@ -27,7 +28,7 @@ class EchoCommand(PlusCommand):
chat_type_allow = ChatType.ALL
intercept_message = True
async def execute(self, args: CommandArgs) -> Tuple[bool, Optional[str], bool]:
async def execute(self, args: CommandArgs) -> tuple[bool, str | None, bool]:
"""执行echo命令"""
if args.is_empty():
await self.send_text("❓ 请提供要回显的内容\n用法: /echo <内容>")
@@ -56,7 +57,7 @@ class HelloCommand(PlusCommand):
chat_type_allow = ChatType.ALL
intercept_message = True
async def execute(self, args: CommandArgs) -> Tuple[bool, Optional[str], bool]:
async def execute(self, args: CommandArgs) -> tuple[bool, str | None, bool]:
"""执行hello命令"""
if args.is_empty():
await self.send_text("👋 Hello! 很高兴见到你!")
@@ -77,7 +78,7 @@ class InfoCommand(PlusCommand):
chat_type_allow = ChatType.ALL
intercept_message = True
async def execute(self, args: CommandArgs) -> Tuple[bool, Optional[str], bool]:
async def execute(self, args: CommandArgs) -> tuple[bool, str | None, bool]:
"""执行info命令"""
info_text = (
"📋 Echo 示例插件信息\n"
@@ -105,7 +106,7 @@ class TestCommand(PlusCommand):
chat_type_allow = ChatType.ALL
intercept_message = True
async def execute(self, args: CommandArgs) -> Tuple[bool, Optional[str], bool]:
async def execute(self, args: CommandArgs) -> tuple[bool, str | None, bool]:
"""执行test命令"""
if args.is_empty():
help_text = (
@@ -166,8 +167,8 @@ class EchoExamplePlugin(BasePlugin):
plugin_name: str = "echo_example_plugin"
enable_plugin: bool = True
dependencies: List[str] = []
python_dependencies: List[Union[str, "PythonDependency"]] = []
dependencies: list[str] = []
python_dependencies: list[Union[str, "PythonDependency"]] = []
config_file_name: str = "config.toml"
config_schema = {
@@ -187,7 +188,7 @@ class EchoExamplePlugin(BasePlugin):
"commands": "命令相关配置",
}
def get_plugin_components(self) -> List[Tuple[PlusCommandInfo, Type]]:
def get_plugin_components(self) -> list[tuple[PlusCommandInfo, type]]:
"""获取插件组件"""
components = []

View File

@@ -1,20 +1,20 @@
from typing import List, Tuple, Type, Dict, Any, Optional
import logging
import random
from typing import Any
from src.plugin_system import (
BasePlugin,
register_plugin,
ComponentInfo,
BaseEventHandler,
EventType,
BaseTool,
PlusCommand,
CommandArgs,
ChatType,
BaseAction,
ActionActivationType,
BaseAction,
BaseEventHandler,
BasePlugin,
BaseTool,
ChatType,
CommandArgs,
ComponentInfo,
ConfigField,
EventType,
PlusCommand,
register_plugin,
)
from src.plugin_system.base.base_event import HandlerResult
@@ -39,7 +39,7 @@ class GetSystemInfoTool(BaseTool):
available_for_llm = True
parameters = []
async def execute(self, function_args: Dict[str, Any]) -> Dict[str, Any]:
async def execute(self, function_args: dict[str, Any]) -> dict[str, Any]:
return {"name": self.name, "content": "系统版本: 1.0.1, 状态: 运行正常"}
@@ -51,7 +51,7 @@ class HelloCommand(PlusCommand):
command_aliases = ["hi", "你好"]
chat_type_allow = ChatType.ALL
async def execute(self, args: CommandArgs) -> Tuple[bool, Optional[str], bool]:
async def execute(self, args: CommandArgs) -> tuple[bool, str | None, bool]:
greeting = str(self.get_config("greeting.message", "Hello, World! 我是一个由 MoFox_Bot 驱动的插件。"))
await self.send_text(greeting)
return True, "成功发送问候", True
@@ -67,7 +67,7 @@ class RandomEmojiAction(BaseAction):
action_require = ["当对话气氛轻松时", "可以用来回应简单的情感表达"]
associated_types = ["text"]
async def execute(self) -> Tuple[bool, str]:
async def execute(self) -> tuple[bool, str]:
emojis = ["😊", "😂", "👍", "🎉", "🤔", "🤖"]
await self.send_text(random.choice(emojis))
return True, "成功发送了一个随机表情"
@@ -99,9 +99,9 @@ class HelloWorldPlugin(BasePlugin):
},
}
def get_plugin_components(self) -> List[Tuple[ComponentInfo, Type]]:
def get_plugin_components(self) -> list[tuple[ComponentInfo, type]]:
"""根据配置文件动态注册插件的功能组件。"""
components: List[Tuple[ComponentInfo, Type]] = []
components: list[tuple[ComponentInfo, type]] = []
components.append((StartupMessageHandler.get_handler_info(), StartupMessageHandler))
components.append((GetSystemInfoTool.get_tool_info(), GetSystemInfoTool))

View File

@@ -70,6 +70,7 @@ dependencies = [
"tqdm>=4.67.1",
"urllib3>=2.5.0",
"uvicorn>=0.35.0",
"watchdog>=6.0.0",
"websockets>=15.0.1",
"aiomysql>=0.2.0",
"aiosqlite>=0.21.0",
@@ -80,29 +81,41 @@ dependencies = [
url = "https://pypi.tuna.tsinghua.edu.cn/simple"
default = true
[tool.uv.sources]
amrita = { workspace = true }
[tool.ruff]
include = ["*.py"]
# 行长度设置
line-length = 120
target-version = "py310"
[tool.ruff.lint]
fixable = ["ALL"]
unfixable = []
select = [
"F", # Pyflakes
"W", # pycodestyle warnings
"E", # pycodestyle errors
"UP", # pyupgrade
"ASYNC", # flake8-async
"C4", # flake8-comprehensions
"T10", # flake8-debugger
"PYI", # flake8-pyi
"PT", # flake8-pytest-style
"Q", # flake8-quotes
"RUF", # Ruff-specific rules
"I", # isort
"PERF", # pylint-performance
]
ignore = [
"E402", # module-import-not-at-top-of-file
"E501", # line-too-long
"UP037", # quoted-annotation
"RUF001", # ambiguous-unicode-character-string
"RUF002", # ambiguous-unicode-character-docstring
"RUF003", # ambiguous-unicode-character-comment
]
# 如果一个变量的名称以下划线开头,即使它未被使用,也不应该被视为错误或警告。
dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$"
# 启用的规则
select = [
"E", # pycodestyle 错误
"F", # pyflakes
"B", # flake8-bugbear
]
ignore = ["E711","E501"]
[tool.ruff.format]
docstring-code-format = true
indent-style = "space"
@@ -124,6 +137,4 @@ skip-magic-trailing-comma = false
line-ending = "auto"
[dependency-groups]
lint = [
"loguru>=0.7.3",
]
lint = ["loguru>=0.7.3"]

View File

@@ -1,10 +1,9 @@
import time
import sys
import os
from typing import Dict, List
import sys
import time
# Add project root to Python path
from src.common.database.database_model import Expression, ChatStreams
from src.common.database.database_model import ChatStreams, Expression
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, project_root)
@@ -30,7 +29,7 @@ def get_chat_name(chat_id: str) -> str:
return f"查询失败 ({chat_id})"
def calculate_time_distribution(expressions) -> Dict[str, int]:
def calculate_time_distribution(expressions) -> dict[str, int]:
"""Calculate distribution of last active time in days"""
now = time.time()
distribution = {
@@ -64,7 +63,7 @@ def calculate_time_distribution(expressions) -> Dict[str, int]:
return distribution
def calculate_count_distribution(expressions) -> Dict[str, int]:
def calculate_count_distribution(expressions) -> dict[str, int]:
"""Calculate distribution of count values"""
distribution = {"0-1": 0, "1-2": 0, "2-3": 0, "3-4": 0, "4-5": 0, "5-10": 0, "10+": 0}
for expr in expressions:
@@ -86,7 +85,7 @@ def calculate_count_distribution(expressions) -> Dict[str, int]:
return distribution
def get_top_expressions_by_chat(chat_id: str, top_n: int = 5) -> List[Expression]:
def get_top_expressions_by_chat(chat_id: str, top_n: int = 5) -> list[Expression]:
"""Get top N most used expressions for a specific chat_id"""
return Expression.select().where(Expression.chat_id == chat_id).order_by(Expression.count.desc()).limit(top_n)

View File

@@ -1,7 +1,6 @@
import time
import sys
import os
from typing import Dict, List, Tuple, Optional
import sys
import time
from datetime import datetime
# Add project root to Python path
@@ -35,7 +34,7 @@ def format_timestamp(timestamp: float) -> str:
return "未知时间"
def calculate_interest_value_distribution(messages) -> Dict[str, int]:
def calculate_interest_value_distribution(messages) -> dict[str, int]:
"""Calculate distribution of interest_value"""
distribution = {
"0.000-0.010": 0,
@@ -76,7 +75,7 @@ def calculate_interest_value_distribution(messages) -> Dict[str, int]:
return distribution
def get_interest_value_stats(messages) -> Dict[str, float]:
def get_interest_value_stats(messages) -> dict[str, float]:
"""Calculate basic statistics for interest_value"""
values = [
float(msg.interest_value) for msg in messages if msg.interest_value is not None and msg.interest_value != 0.0
@@ -97,7 +96,7 @@ def get_interest_value_stats(messages) -> Dict[str, float]:
}
def get_available_chats() -> List[Tuple[str, str, int]]:
def get_available_chats() -> list[tuple[str, str, int]]:
"""Get all available chats with message counts"""
try:
# 获取所有有消息的chat_id
@@ -130,7 +129,7 @@ def get_available_chats() -> List[Tuple[str, str, int]]:
return []
def get_time_range_input() -> Tuple[Optional[float], Optional[float]]:
def get_time_range_input() -> tuple[float | None, float | None]:
"""Get time range input from user"""
print("\n时间范围选择:")
print("1. 最近1天")
@@ -170,7 +169,7 @@ def get_time_range_input() -> Tuple[Optional[float], Optional[float]]:
def analyze_interest_values(
chat_id: Optional[str] = None, start_time: Optional[float] = None, end_time: Optional[float] = None
chat_id: str | None = None, start_time: float | None = None, end_time: float | None = None
) -> None:
"""Analyze interest values with optional filters"""

View File

@@ -1,13 +1,14 @@
import tkinter as tk
from tkinter import ttk, messagebox, filedialog, colorchooser
import orjson
from pathlib import Path
import threading
import toml
from datetime import datetime
from collections import defaultdict
import os
import threading
import time
import tkinter as tk
from collections import defaultdict
from datetime import datetime
from pathlib import Path
from tkinter import colorchooser, filedialog, messagebox, ttk
import orjson
import toml
class LogIndex:
@@ -409,7 +410,7 @@ class AsyncLogLoader:
file_size = os.path.getsize(file_path)
processed_size = 0
with open(file_path, "r", encoding="utf-8") as f:
with open(file_path, encoding="utf-8") as f:
line_count = 0
batch_size = 1000 # 批量处理
@@ -561,7 +562,7 @@ class LogViewer:
try:
if config_path.exists():
with open(config_path, "r", encoding="utf-8") as f:
with open(config_path, encoding="utf-8") as f:
bot_config = toml.load(f)
if "log" in bot_config:
self.log_config.update(bot_config["log"])
@@ -575,7 +576,7 @@ class LogViewer:
try:
if viewer_config_path.exists():
with open(viewer_config_path, "r", encoding="utf-8") as f:
with open(viewer_config_path, encoding="utf-8") as f:
viewer_config = toml.load(f)
if "viewer" in viewer_config:
self.viewer_config.update(viewer_config["viewer"])
@@ -843,7 +844,7 @@ class LogViewer:
mapping_file = Path("config/module_mapping.json")
if mapping_file.exists():
try:
with open(mapping_file, "r", encoding="utf-8") as f:
with open(mapping_file, encoding="utf-8") as f:
custom_mapping = orjson.loads(f.read())
self.module_name_mapping.update(custom_mapping)
except Exception as e:
@@ -1172,7 +1173,7 @@ class LogViewer:
"""读取新的日志条目并返回它们"""
new_entries = []
new_modules = set() # 收集新发现的模块
with open(self.current_log_file, "r", encoding="utf-8") as f:
with open(self.current_log_file, encoding="utf-8") as f:
f.seek(from_position)
line_count = self.log_index.total_entries
for line in f:

View File

@@ -1,36 +1,37 @@
import asyncio
import datetime
import os
import shutil
import sys
import orjson
import datetime
from pathlib import Path
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path
from threading import Lock
from typing import Optional
import orjson
from json_repair import repair_json
# 将项目根目录添加到 sys.path
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
from src.common.logger import get_logger
from src.chat.knowledge.utils.hash import get_sha256
from src.llm_models.utils_model import LLMRequest
from src.config.config import model_config
from src.chat.knowledge.open_ie import OpenIE
from src.chat.knowledge.embedding_store import EmbeddingManager
from src.chat.knowledge.kg_manager import KGManager
from rich.progress import (
Progress,
BarColumn,
MofNCompleteColumn,
Progress,
SpinnerColumn,
TaskProgressColumn,
TextColumn,
TimeElapsedColumn,
TimeRemainingColumn,
TaskProgressColumn,
MofNCompleteColumn,
SpinnerColumn,
TextColumn,
)
from src.chat.knowledge.embedding_store import EmbeddingManager
from src.chat.knowledge.kg_manager import KGManager
from src.chat.knowledge.open_ie import OpenIE
from src.chat.knowledge.utils.hash import get_sha256
from src.common.logger import get_logger
from src.config.config import model_config
from src.llm_models.utils_model import LLMRequest
logger = get_logger("LPMM_LearningTool")
ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
RAW_DATA_PATH = os.path.join(ROOT_PATH, "data", "lpmm_raw_data")
@@ -59,7 +60,7 @@ def clear_cache():
def process_text_file(file_path):
with open(file_path, "r", encoding="utf-8") as f:
with open(file_path, encoding="utf-8") as f:
raw = f.read()
return [p.strip() for p in raw.split("\n\n") if p.strip()]
@@ -86,7 +87,7 @@ def preprocess_raw_data():
# --- 模块二:信息提取 ---
def _parse_and_repair_json(json_string: str) -> Optional[dict]:
def _parse_and_repair_json(json_string: str) -> dict | None:
"""
尝试解析JSON字符串如果失败则尝试修复并重新解析。
@@ -249,7 +250,7 @@ def extract_information(paragraphs_dict, model_set):
# --- 模块三:数据导入 ---
async def import_data(openie_obj: Optional[OpenIE] = None):
async def import_data(openie_obj: OpenIE | None = None):
"""
将OpenIE数据导入知识库Embedding Store 和 KG

View File

@@ -4,11 +4,13 @@
提供插件manifest文件的创建、验证和管理功能
"""
import argparse
import os
import sys
import argparse
import orjson
from pathlib import Path
import orjson
from src.common.logger import get_logger
from src.plugin_system.utils.manifest_utils import (
ManifestValidator,
@@ -124,7 +126,7 @@ def validate_manifest_file(plugin_dir: str) -> bool:
return False
try:
with open(manifest_path, "r", encoding="utf-8") as f:
with open(manifest_path, encoding="utf-8") as f:
manifest_data = orjson.loads(f.read())
validator = ManifestValidator()

View File

@@ -1,46 +1,48 @@
import os
import orjson
import sys # 新增系统模块导入
# import time
import pickle
import sys # 新增系统模块导入
from pathlib import Path
import orjson
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
from typing import Dict, Any, List, Optional, Type
from dataclasses import dataclass, field
from datetime import datetime
from typing import Any
from peewee import Field, IntegrityError, Model
from pymongo import MongoClient
from pymongo.errors import ConnectionFailure
from peewee import Model, Field, IntegrityError
# Rich 进度条和显示组件
from rich.console import Console
from rich.panel import Panel
from rich.progress import (
Progress,
TextColumn,
BarColumn,
TaskProgressColumn,
TimeRemainingColumn,
TimeElapsedColumn,
Progress,
SpinnerColumn,
TaskProgressColumn,
TextColumn,
TimeElapsedColumn,
TimeRemainingColumn,
)
from rich.table import Table
from rich.panel import Panel
# from rich.text import Text
# from rich.text import Text
from src.common.database.database import db
from src.common.database.sqlalchemy_models import (
ChatStreams,
Emoji,
Messages,
Images,
ImageDescriptions,
PersonInfo,
Knowledges,
ThinkingLog,
GraphNodes,
GraphEdges,
GraphNodes,
ImageDescriptions,
Images,
Knowledges,
Messages,
PersonInfo,
ThinkingLog,
)
from src.common.logger import get_logger
@@ -54,12 +56,12 @@ class MigrationConfig:
"""迁移配置类"""
mongo_collection: str
target_model: Type[Model]
field_mapping: Dict[str, str]
target_model: type[Model]
field_mapping: dict[str, str]
batch_size: int = 500
enable_validation: bool = True
skip_duplicates: bool = True
unique_fields: List[str] = field(default_factory=list) # 用于重复检查的字段
unique_fields: list[str] = field(default_factory=list) # 用于重复检查的字段
# 数据验证相关类已移除 - 用户要求不要数据验证
@@ -73,7 +75,7 @@ class MigrationCheckpoint:
processed_count: int
last_processed_id: Any
timestamp: datetime
batch_errors: List[Dict[str, Any]] = field(default_factory=list)
batch_errors: list[dict[str, Any]] = field(default_factory=list)
@dataclass
@@ -88,11 +90,11 @@ class MigrationStats:
duplicate_count: int = 0
validation_errors: int = 0
batch_insert_count: int = 0
errors: List[Dict[str, Any]] = field(default_factory=list)
start_time: Optional[datetime] = None
end_time: Optional[datetime] = None
errors: list[dict[str, Any]] = field(default_factory=list)
start_time: datetime | None = None
end_time: datetime | None = None
def add_error(self, doc_id: Any, error: str, doc_data: Optional[Dict] = None):
def add_error(self, doc_id: Any, error: str, doc_data: dict | None = None):
"""添加错误记录"""
self.errors.append(
{"doc_id": str(doc_id), "error": error, "timestamp": datetime.now().isoformat(), "doc_data": doc_data}
@@ -108,10 +110,10 @@ class MigrationStats:
class MongoToSQLiteMigrator:
"""MongoDB到SQLite数据迁移器 - 使用Peewee ORM"""
def __init__(self, mongo_uri: Optional[str] = None, database_name: Optional[str] = None):
def __init__(self, mongo_uri: str | None = None, database_name: str | None = None):
self.database_name = database_name or os.getenv("DATABASE_NAME", "MegBot")
self.mongo_uri = mongo_uri or self._build_mongo_uri()
self.mongo_client: Optional[MongoClient] = None
self.mongo_client: MongoClient | None = None
self.mongo_db = None
# 迁移配置
@@ -142,7 +144,7 @@ class MongoToSQLiteMigrator:
else:
return f"mongodb://{host}:{port}/{self.database_name}"
def _initialize_migration_configs(self) -> List[MigrationConfig]:
def _initialize_migration_configs(self) -> list[MigrationConfig]:
"""初始化迁移配置"""
return [ # 表情包迁移配置
MigrationConfig(
@@ -306,7 +308,7 @@ class MongoToSQLiteMigrator:
),
]
def _initialize_validation_rules(self) -> Dict[str, Any]:
def _initialize_validation_rules(self) -> dict[str, Any]:
"""数据验证已禁用 - 返回空字典"""
return {}
@@ -337,7 +339,7 @@ class MongoToSQLiteMigrator:
self.mongo_client.close()
logger.info("MongoDB连接已关闭")
def _get_nested_value(self, document: Dict[str, Any], field_path: str) -> Any:
def _get_nested_value(self, document: dict[str, Any], field_path: str) -> Any:
"""获取嵌套字段的值"""
if "." not in field_path:
return document.get(field_path)
@@ -434,7 +436,7 @@ class MongoToSQLiteMigrator:
return None
def _validate_data(self, collection_name: str, data: Dict[str, Any], doc_id: Any, stats: MigrationStats) -> bool:
def _validate_data(self, collection_name: str, data: dict[str, Any], doc_id: Any, stats: MigrationStats) -> bool:
"""数据验证已禁用 - 始终返回True"""
return True
@@ -454,7 +456,7 @@ class MongoToSQLiteMigrator:
except Exception as e:
logger.warning(f"保存断点失败: {e}")
def _load_checkpoint(self, collection_name: str) -> Optional[MigrationCheckpoint]:
def _load_checkpoint(self, collection_name: str) -> MigrationCheckpoint | None:
"""加载迁移断点"""
checkpoint_file = self.checkpoint_dir / f"{collection_name}_checkpoint.pkl"
if not checkpoint_file.exists():
@@ -467,7 +469,7 @@ class MongoToSQLiteMigrator:
logger.warning(f"加载断点失败: {e}")
return None
def _batch_insert(self, model: Type[Model], data_list: List[Dict[str, Any]]) -> int:
def _batch_insert(self, model: type[Model], data_list: list[dict[str, Any]]) -> int:
"""批量插入数据"""
if not data_list:
return 0
@@ -494,7 +496,7 @@ class MongoToSQLiteMigrator:
return success_count
def _check_duplicate_by_unique_fields(
self, model: Type[Model], data: Dict[str, Any], unique_fields: List[str]
self, model: type[Model], data: dict[str, Any], unique_fields: list[str]
) -> bool:
"""根据唯一字段检查重复"""
if not unique_fields:
@@ -512,7 +514,7 @@ class MongoToSQLiteMigrator:
logger.debug(f"重复检查失败: {e}")
return False
def _create_model_instance(self, model: Type[Model], data: Dict[str, Any]) -> Optional[Model]:
def _create_model_instance(self, model: type[Model], data: dict[str, Any]) -> Model | None:
"""使用ORM创建模型实例"""
try:
# 过滤掉不存在的字段
@@ -669,7 +671,7 @@ class MongoToSQLiteMigrator:
return stats
def migrate_all(self) -> Dict[str, MigrationStats]:
def migrate_all(self) -> dict[str, MigrationStats]:
"""执行所有迁移任务"""
logger.info("开始执行数据库迁移...")
@@ -730,7 +732,7 @@ class MongoToSQLiteMigrator:
self._print_migration_summary(all_stats)
return all_stats
def _print_migration_summary(self, all_stats: Dict[str, MigrationStats]):
def _print_migration_summary(self, all_stats: dict[str, MigrationStats]):
"""使用Rich打印美观的迁移汇总信息"""
# 计算总体统计
total_processed = sum(stats.processed_count for stats in all_stats.values())
@@ -857,7 +859,7 @@ class MongoToSQLiteMigrator:
"""添加新的迁移配置"""
self.migration_configs.append(config)
def migrate_single_collection(self, collection_name: str) -> Optional[MigrationStats]:
def migrate_single_collection(self, collection_name: str) -> MigrationStats | None:
"""迁移单个指定的集合"""
config = next((c for c in self.migration_configs if c.mongo_collection == collection_name), None)
if not config:
@@ -875,7 +877,7 @@ class MongoToSQLiteMigrator:
finally:
self.disconnect_mongodb()
def export_error_report(self, all_stats: Dict[str, MigrationStats], filepath: str):
def export_error_report(self, all_stats: dict[str, MigrationStats], filepath: str):
"""导出错误报告"""
error_report = {
"timestamp": datetime.now().isoformat(),

View File

@@ -1,17 +1,16 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
从现有ChromaDB数据重建JSON元数据索引
"""
import asyncio
import sys
import os
import sys
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from src.chat.memory_system.memory_system import MemorySystem
from src.chat.memory_system.memory_metadata_index import MemoryMetadataIndexEntry
from src.chat.memory_system.memory_system import MemorySystem
from src.common.logger import get_logger
logger = get_logger(__name__)

View File

@@ -1,12 +1,11 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
轻量烟雾测试:初始化 MemorySystem 并运行一次检索,验证 MemoryMetadata.source 访问不再报错
"""
import asyncio
import sys
import os
import sys
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

View File

@@ -1,8 +1,7 @@
import time
import sys
import os
import re
from typing import Dict, List, Tuple, Optional
import sys
import time
from datetime import datetime
# Add project root to Python path
@@ -63,7 +62,7 @@ def format_timestamp(timestamp: float) -> str:
return "未知时间"
def calculate_text_length_distribution(messages) -> Dict[str, int]:
def calculate_text_length_distribution(messages) -> dict[str, int]:
"""Calculate distribution of processed_plain_text length"""
distribution = {
"0": 0, # 空文本
@@ -126,7 +125,7 @@ def calculate_text_length_distribution(messages) -> Dict[str, int]:
return distribution
def get_text_length_stats(messages) -> Dict[str, float]:
def get_text_length_stats(messages) -> dict[str, float]:
"""Calculate basic statistics for processed_plain_text length"""
lengths = []
null_count = 0
@@ -168,7 +167,7 @@ def get_text_length_stats(messages) -> Dict[str, float]:
}
def get_available_chats() -> List[Tuple[str, str, int]]:
def get_available_chats() -> list[tuple[str, str, int]]:
"""Get all available chats with message counts"""
try:
# 获取所有有消息的chat_id排除特殊类型消息
@@ -202,7 +201,7 @@ def get_available_chats() -> List[Tuple[str, str, int]]:
return []
def get_time_range_input() -> Tuple[Optional[float], Optional[float]]:
def get_time_range_input() -> tuple[float | None, float | None]:
"""Get time range input from user"""
print("\n时间范围选择:")
print("1. 最近1天")
@@ -241,7 +240,7 @@ def get_time_range_input() -> Tuple[Optional[float], Optional[float]]:
return None, None
def get_top_longest_messages(messages, top_n: int = 10) -> List[Tuple[str, int, str, str]]:
def get_top_longest_messages(messages, top_n: int = 10) -> list[tuple[str, int, str, str]]:
"""Get top N longest messages"""
message_lengths = []
@@ -266,7 +265,7 @@ def get_top_longest_messages(messages, top_n: int = 10) -> List[Tuple[str, int,
def analyze_text_lengths(
chat_id: Optional[str] = None, start_time: Optional[float] = None, end_time: Optional[float] = None
chat_id: str | None = None, start_time: float | None = None, end_time: float | None = None
) -> None:
"""Analyze processed_plain_text lengths with optional filters"""

View File

@@ -30,7 +30,7 @@ def update_prompt_imports(file_path):
print(f"文件不存在: {file_path}")
return False
with open(file_path, "r", encoding="utf-8") as f:
with open(file_path, encoding="utf-8") as f:
content = f.read()
# 替换导入语句

View File

@@ -1,13 +1,15 @@
import random
from typing import List, Optional, Sequence
from colorama import init, Fore
from collections.abc import Sequence
from typing import List, Optional
from colorama import Fore, init
from src.common.logger import get_logger
egg = get_logger("小彩蛋")
def weighted_choice(data: Sequence[str], weights: Optional[List[float]] = None) -> str:
def weighted_choice(data: Sequence[str], weights: list[float] | None = None) -> str:
"""
从 data 中按权重随机返回一条。
若 weights 为 None则所有元素权重默认为 1。

View File

@@ -3,8 +3,8 @@ MaiBot模块系统
包含聊天、情绪、记忆、日程等功能模块
"""
from src.chat.message_receive.chat_stream import get_chat_manager
from src.chat.emoji_system.emoji_manager import get_emoji_manager
from src.chat.message_receive.chat_stream import get_chat_manager
# 导出主要组件供外部使用
__all__ = [

View File

@@ -1,4 +1,3 @@
# -*- coding: utf-8 -*-
"""
MaiBot 反注入系统模块
@@ -14,25 +13,25 @@ MaiBot 反注入系统模块
"""
from .anti_injector import AntiPromptInjector, get_anti_injector, initialize_anti_injector
from .types import DetectionResult, ProcessResult
from .core import PromptInjectionDetector, MessageShield
from .processors.message_processor import MessageProcessor
from .management import AntiInjectionStatistics, UserBanManager
from .core import MessageShield, PromptInjectionDetector
from .decision import CounterAttackGenerator, ProcessingDecisionMaker
from .management import AntiInjectionStatistics, UserBanManager
from .processors.message_processor import MessageProcessor
from .types import DetectionResult, ProcessResult
__all__ = [
"AntiInjectionStatistics",
"AntiPromptInjector",
"CounterAttackGenerator",
"DetectionResult",
"MessageProcessor",
"MessageShield",
"ProcessResult",
"ProcessingDecisionMaker",
"PromptInjectionDetector",
"UserBanManager",
"get_anti_injector",
"initialize_anti_injector",
"DetectionResult",
"ProcessResult",
"PromptInjectionDetector",
"MessageShield",
"MessageProcessor",
"AntiInjectionStatistics",
"UserBanManager",
"CounterAttackGenerator",
"ProcessingDecisionMaker",
]

View File

@@ -1,4 +1,3 @@
# -*- coding: utf-8 -*-
"""
LLM反注入系统主模块
@@ -12,15 +11,16 @@ LLM反注入系统主模块
"""
import time
from typing import Optional, Tuple, Dict, Any
from typing import Any
from src.common.logger import get_logger
from src.config.config import global_config
from .types import ProcessResult
from .core import PromptInjectionDetector, MessageShield
from .processors.message_processor import MessageProcessor
from .management import AntiInjectionStatistics, UserBanManager
from .core import MessageShield, PromptInjectionDetector
from .decision import CounterAttackGenerator, ProcessingDecisionMaker
from .management import AntiInjectionStatistics, UserBanManager
from .processors.message_processor import MessageProcessor
from .types import ProcessResult
logger = get_logger("anti_injector")
@@ -43,7 +43,7 @@ class AntiPromptInjector:
async def process_message(
self, message_data: dict, chat_stream=None
) -> Tuple[ProcessResult, Optional[str], Optional[str]]:
) -> tuple[ProcessResult, str | None, str | None]:
"""处理字典格式的消息并返回结果
Args:
@@ -102,7 +102,7 @@ class AntiPromptInjector:
await self.statistics.update_stats(error_count=1)
# 异常情况下直接阻止消息
return ProcessResult.BLOCKED_INJECTION, None, f"反注入系统异常,消息已阻止: {str(e)}"
return ProcessResult.BLOCKED_INJECTION, None, f"反注入系统异常,消息已阻止: {e!s}"
finally:
# 更新处理时间统计
@@ -111,7 +111,7 @@ class AntiPromptInjector:
async def _process_message_internal(
self, text_to_detect: str, user_id: str, platform: str, processed_plain_text: str, start_time: float
) -> Tuple[ProcessResult, Optional[str], Optional[str]]:
) -> tuple[ProcessResult, str | None, str | None]:
"""内部消息处理逻辑(共用的检测核心)"""
# 如果是纯引用消息,直接允许通过
@@ -218,7 +218,7 @@ class AntiPromptInjector:
return ProcessResult.ALLOWED, None, "消息检查通过"
async def handle_message_storage(
self, result: ProcessResult, modified_content: Optional[str], reason: str, message_data: dict
self, result: ProcessResult, modified_content: str | None, reason: str, message_data: dict
) -> None:
"""处理违禁消息的数据库存储,根据处理模式决定如何处理"""
if result == ProcessResult.BLOCKED_INJECTION or result == ProcessResult.COUNTER_ATTACK:
@@ -253,9 +253,10 @@ class AntiPromptInjector:
async def _delete_message_from_storage(message_data: dict) -> None:
"""从数据库中删除违禁消息记录"""
try:
from src.common.database.sqlalchemy_models import Messages, get_db_session
from sqlalchemy import delete
from src.common.database.sqlalchemy_models import Messages, get_db_session
message_id = message_data.get("message_id")
if not message_id:
logger.warning("无法删除消息缺少message_id")
@@ -279,9 +280,10 @@ class AntiPromptInjector:
async def _update_message_in_storage(message_data: dict, new_content: str) -> None:
"""更新数据库中的消息内容为加盾版本"""
try:
from src.common.database.sqlalchemy_models import Messages, get_db_session
from sqlalchemy import update
from src.common.database.sqlalchemy_models import Messages, get_db_session
message_id = message_data.get("message_id")
if not message_id:
logger.warning("无法更新消息缺少message_id")
@@ -305,7 +307,7 @@ class AntiPromptInjector:
except Exception as e:
logger.error(f"更新消息内容失败: {e}")
async def get_stats(self) -> Dict[str, Any]:
async def get_stats(self) -> dict[str, Any]:
"""获取统计信息"""
return await self.statistics.get_stats()
@@ -315,7 +317,7 @@ class AntiPromptInjector:
# 全局反注入器实例
_global_injector: Optional[AntiPromptInjector] = None
_global_injector: AntiPromptInjector | None = None
def get_anti_injector() -> AntiPromptInjector:

View File

@@ -1,4 +1,3 @@
# -*- coding: utf-8 -*-
"""
反注入系统核心检测模块
@@ -10,4 +9,4 @@
from .detector import PromptInjectionDetector
from .shield import MessageShield
__all__ = ["PromptInjectionDetector", "MessageShield"]
__all__ = ["MessageShield", "PromptInjectionDetector"]

View File

@@ -1,4 +1,3 @@
# -*- coding: utf-8 -*-
"""
提示词注入检测器模块
@@ -8,19 +7,19 @@
3. 缓存机制优化性能
"""
import hashlib
import re
import time
import hashlib
from typing import Dict, List
from dataclasses import asdict
from src.common.logger import get_logger
from src.config.config import global_config
from ..types import DetectionResult
# 导入LLM API
from src.plugin_system.apis import llm_api
from ..types import DetectionResult
logger = get_logger("anti_injector.detector")
@@ -30,8 +29,8 @@ class PromptInjectionDetector:
def __init__(self):
"""初始化检测器"""
self.config = global_config.anti_prompt_injection
self._cache: Dict[str, DetectionResult] = {}
self._compiled_patterns: List[re.Pattern] = []
self._cache: dict[str, DetectionResult] = {}
self._compiled_patterns: list[re.Pattern] = []
self._compile_patterns()
def _compile_patterns(self):
@@ -224,7 +223,7 @@ class PromptInjectionDetector:
matched_patterns=[],
processing_time=processing_time,
detection_method="llm",
reason=f"LLM检测出错: {str(e)}",
reason=f"LLM检测出错: {e!s}",
)
@staticmethod
@@ -250,7 +249,7 @@ class PromptInjectionDetector:
请客观分析,避免误判正常对话。"""
@staticmethod
def _parse_llm_response(response: str) -> Dict:
def _parse_llm_response(response: str) -> dict:
"""解析LLM响应"""
try:
lines = response.strip().split("\n")
@@ -280,7 +279,7 @@ class PromptInjectionDetector:
except Exception as e:
logger.error(f"解析LLM响应失败: {e}")
return {"is_injection": False, "confidence": 0.0, "reasoning": f"解析失败: {str(e)}"}
return {"is_injection": False, "confidence": 0.0, "reasoning": f"解析失败: {e!s}"}
async def detect(self, message: str) -> DetectionResult:
"""执行检测"""
@@ -331,7 +330,7 @@ class PromptInjectionDetector:
return final_result
def _merge_results(self, results: List[DetectionResult]) -> DetectionResult:
def _merge_results(self, results: list[DetectionResult]) -> DetectionResult:
"""合并多个检测结果"""
if not results:
return DetectionResult(reason="无检测结果")
@@ -384,7 +383,7 @@ class PromptInjectionDetector:
if expired_keys:
logger.debug(f"清理了{len(expired_keys)}个过期缓存项")
def get_cache_stats(self) -> Dict:
def get_cache_stats(self) -> dict:
"""获取缓存统计信息"""
return {
"cache_size": len(self._cache),

View File

@@ -1,4 +1,3 @@
# -*- coding: utf-8 -*-
"""
消息加盾模块
@@ -6,8 +5,6 @@
主要通过注入系统提示词来指导AI安全响应。
"""
from typing import List
from src.common.logger import get_logger
from src.config.config import global_config
@@ -35,7 +32,7 @@ class MessageShield:
return SAFETY_SYSTEM_PROMPT
@staticmethod
def is_shield_needed(confidence: float, matched_patterns: List[str]) -> bool:
def is_shield_needed(confidence: float, matched_patterns: list[str]) -> bool:
"""判断是否需要加盾
Args:
@@ -60,7 +57,7 @@ class MessageShield:
return False
@staticmethod
def create_safety_summary(confidence: float, matched_patterns: List[str]) -> str:
def create_safety_summary(confidence: float, matched_patterns: list[str]) -> str:
"""创建安全处理摘要
Args:

View File

@@ -1,15 +1,13 @@
# -*- coding: utf-8 -*-
"""
反击消息生成模块
负责生成个性化的反击消息回应提示词注入攻击
"""
from typing import Optional
from src.common.logger import get_logger
from src.config.config import global_config
from src.plugin_system.apis import llm_api
from .types import DetectionResult
logger = get_logger("anti_injector.counter_attack")
@@ -55,7 +53,7 @@ class CounterAttackGenerator:
async def generate_counter_attack_message(
self, original_message: str, detection_result: DetectionResult
) -> Optional[str]:
) -> str | None:
"""生成反击消息
Args:

View File

@@ -1,4 +1,3 @@
# -*- coding: utf-8 -*-
"""
反注入系统决策模块
@@ -7,7 +6,7 @@
- counter_attack: 反击消息生成器
"""
from .decision_maker import ProcessingDecisionMaker
from .counter_attack import CounterAttackGenerator
from .decision_maker import ProcessingDecisionMaker
__all__ = ["ProcessingDecisionMaker", "CounterAttackGenerator"]
__all__ = ["CounterAttackGenerator", "ProcessingDecisionMaker"]

View File

@@ -1,15 +1,13 @@
# -*- coding: utf-8 -*-
"""
反击消息生成模块
负责生成个性化的反击消息回应提示词注入攻击
"""
from typing import Optional
from src.common.logger import get_logger
from src.config.config import global_config
from src.plugin_system.apis import llm_api
from ..types import DetectionResult
logger = get_logger("anti_injector.counter_attack")
@@ -55,7 +53,7 @@ class CounterAttackGenerator:
async def generate_counter_attack_message(
self, original_message: str, detection_result: DetectionResult
) -> Optional[str]:
) -> str | None:
"""生成反击消息
Args:

View File

@@ -1,4 +1,3 @@
# -*- coding: utf-8 -*-
"""
处理决策器模块
@@ -6,6 +5,7 @@
"""
from src.common.logger import get_logger
from ..types import DetectionResult
logger = get_logger("anti_injector.decision_maker")

View File

@@ -1,4 +1,3 @@
# -*- coding: utf-8 -*-
"""
处理决策器模块
@@ -6,6 +5,7 @@
"""
from src.common.logger import get_logger
from .types import DetectionResult
logger = get_logger("anti_injector.decision_maker")

View File

@@ -1,4 +1,3 @@
# -*- coding: utf-8 -*-
"""
提示词注入检测器模块
@@ -8,19 +7,19 @@
3. 缓存机制优化性能
"""
import hashlib
import re
import time
import hashlib
from typing import Dict, List
from dataclasses import asdict
from src.common.logger import get_logger
from src.config.config import global_config
from .types import DetectionResult
# 导入LLM API
from src.plugin_system.apis import llm_api
from .types import DetectionResult
logger = get_logger("anti_injector.detector")
@@ -30,8 +29,8 @@ class PromptInjectionDetector:
def __init__(self):
"""初始化检测器"""
self.config = global_config.anti_prompt_injection
self._cache: Dict[str, DetectionResult] = {}
self._compiled_patterns: List[re.Pattern] = []
self._cache: dict[str, DetectionResult] = {}
self._compiled_patterns: list[re.Pattern] = []
self._compile_patterns()
def _compile_patterns(self):
@@ -221,7 +220,7 @@ class PromptInjectionDetector:
matched_patterns=[],
processing_time=processing_time,
detection_method="llm",
reason=f"LLM检测出错: {str(e)}",
reason=f"LLM检测出错: {e!s}",
)
@staticmethod
@@ -247,7 +246,7 @@ class PromptInjectionDetector:
请客观分析,避免误判正常对话。"""
@staticmethod
def _parse_llm_response(response: str) -> Dict:
def _parse_llm_response(response: str) -> dict:
"""解析LLM响应"""
try:
lines = response.strip().split("\n")
@@ -277,7 +276,7 @@ class PromptInjectionDetector:
except Exception as e:
logger.error(f"解析LLM响应失败: {e}")
return {"is_injection": False, "confidence": 0.0, "reasoning": f"解析失败: {str(e)}"}
return {"is_injection": False, "confidence": 0.0, "reasoning": f"解析失败: {e!s}"}
async def detect(self, message: str) -> DetectionResult:
"""执行检测"""
@@ -328,7 +327,7 @@ class PromptInjectionDetector:
return final_result
def _merge_results(self, results: List[DetectionResult]) -> DetectionResult:
def _merge_results(self, results: list[DetectionResult]) -> DetectionResult:
"""合并多个检测结果"""
if not results:
return DetectionResult(reason="无检测结果")
@@ -381,7 +380,7 @@ class PromptInjectionDetector:
if expired_keys:
logger.debug(f"清理了{len(expired_keys)}个过期缓存项")
def get_cache_stats(self) -> Dict:
def get_cache_stats(self) -> dict:
"""获取缓存统计信息"""
return {
"cache_size": len(self._cache),

View File

@@ -1,4 +1,3 @@
# -*- coding: utf-8 -*-
"""
反注入系统管理模块

View File

@@ -1,4 +1,3 @@
# -*- coding: utf-8 -*-
"""
反注入系统统计模块
@@ -6,12 +5,12 @@
"""
import datetime
from typing import Dict, Any
from typing import Any
from sqlalchemy import select
from src.common.logger import get_logger
from src.common.database.sqlalchemy_models import AntiInjectionStats, get_db_session
from src.common.logger import get_logger
from src.config.config import global_config
logger = get_logger("anti_injector.statistics")
@@ -94,7 +93,7 @@ class AntiInjectionStatistics:
except Exception as e:
logger.error(f"更新统计数据失败: {e}")
async def get_stats(self) -> Dict[str, Any]:
async def get_stats(self) -> dict[str, Any]:
"""获取统计信息"""
try:
# 检查反注入系统是否启用

View File

@@ -1,4 +1,3 @@
# -*- coding: utf-8 -*-
"""
用户封禁管理模块
@@ -6,12 +5,12 @@
"""
import datetime
from typing import Optional, Tuple
from sqlalchemy import select
from src.common.logger import get_logger
from src.common.database.sqlalchemy_models import BanUser, get_db_session
from src.common.logger import get_logger
from ..types import DetectionResult
logger = get_logger("anti_injector.user_ban")
@@ -28,7 +27,7 @@ class UserBanManager:
"""
self.config = config
async def check_user_ban(self, user_id: str, platform: str) -> Optional[Tuple[bool, Optional[str], str]]:
async def check_user_ban(self, user_id: str, platform: str) -> tuple[bool, str | None, str] | None:
"""检查用户是否被封禁
Args:

View File

@@ -1,4 +1,3 @@
# -*- coding: utf-8 -*-
"""
反注入系统消息处理模块

View File

@@ -1,4 +1,3 @@
# -*- coding: utf-8 -*-
"""
消息内容处理模块
@@ -6,10 +5,9 @@
"""
import re
from typing import Optional
from src.common.logger import get_logger
from src.chat.message_receive.message import MessageRecv
from src.common.logger import get_logger
logger = get_logger("anti_injector.message_processor")
@@ -66,7 +64,7 @@ class MessageProcessor:
return new_content
@staticmethod
def check_whitelist(message: MessageRecv, whitelist: list) -> Optional[tuple]:
def check_whitelist(message: MessageRecv, whitelist: list) -> tuple | None:
"""检查用户白名单
Args:

View File

@@ -1,4 +1,3 @@
# -*- coding: utf-8 -*-
"""
反注入系统数据类型定义模块
@@ -10,7 +9,6 @@
"""
import time
from typing import List, Optional
from dataclasses import dataclass, field
from enum import Enum
@@ -31,8 +29,8 @@ class DetectionResult:
is_injection: bool = False
confidence: float = 0.0
matched_patterns: List[str] = field(default_factory=list)
llm_analysis: Optional[str] = None
matched_patterns: list[str] = field(default_factory=list)
llm_analysis: str | None = None
processing_time: float = 0.0
detection_method: str = "unknown"
reason: str = ""

View File

@@ -1,10 +1,11 @@
from typing import Dict, List, Optional, Any
import time
from src.plugin_system.base.base_chatter import BaseChatter
from src.common.data_models.message_manager_data_model import StreamContext
from typing import Any
from src.chat.planner_actions.action_manager import ChatterActionManager
from src.plugin_system.base.component_types import ChatType
from src.common.data_models.message_manager_data_model import StreamContext
from src.common.logger import get_logger
from src.plugin_system.base.base_chatter import BaseChatter
from src.plugin_system.base.component_types import ChatType
logger = get_logger("chatter_manager")
@@ -12,8 +13,8 @@ logger = get_logger("chatter_manager")
class ChatterManager:
def __init__(self, action_manager: ChatterActionManager):
self.action_manager = action_manager
self.chatter_classes: Dict[ChatType, List[type]] = {}
self.instances: Dict[str, BaseChatter] = {}
self.chatter_classes: dict[ChatType, list[type]] = {}
self.instances: dict[str, BaseChatter] = {}
# 管理器统计
self.stats = {
@@ -46,21 +47,21 @@ class ChatterManager:
self.stats["chatters_registered"] += 1
def get_chatter_class(self, chat_type: ChatType) -> Optional[type]:
def get_chatter_class(self, chat_type: ChatType) -> type | None:
"""获取指定聊天类型的聊天处理器类"""
if chat_type in self.chatter_classes:
return self.chatter_classes[chat_type][0]
return None
def get_supported_chat_types(self) -> List[ChatType]:
def get_supported_chat_types(self) -> list[ChatType]:
"""获取支持的聊天类型列表"""
return list(self.chatter_classes.keys())
def get_registered_chatters(self) -> Dict[ChatType, List[type]]:
def get_registered_chatters(self) -> dict[ChatType, list[type]]:
"""获取已注册的聊天处理器"""
return self.chatter_classes.copy()
def get_stream_instance(self, stream_id: str) -> Optional[BaseChatter]:
def get_stream_instance(self, stream_id: str) -> BaseChatter | None:
"""获取指定流的聊天处理器实例"""
return self.instances.get(stream_id)
@@ -139,7 +140,7 @@ class ChatterManager:
logger.error(f"处理流 {stream_id} 时发生错误: {e}")
raise
def get_stats(self) -> Dict[str, Any]:
def get_stats(self) -> dict[str, Any]:
"""获取管理器统计信息"""
stats = self.stats.copy()
stats["active_instances"] = len(self.instances)

View File

@@ -1,9 +1,7 @@
# -*- coding: utf-8 -*-
"""
表情包发送历史记录模块
"""
from typing import List, Dict
from collections import deque
from src.common.logger import get_logger
@@ -14,7 +12,7 @@ MAX_HISTORY_SIZE = 5 # 每个聊天会话最多保留最近5条表情历史
# 使用一个全局字典在内存中存储历史记录
# 键是 chat_id值是一个 deque 对象
_history_cache: Dict[str, deque] = {}
_history_cache: dict[str, deque] = {}
def add_emoji_to_history(chat_id: str, emoji_description: str):
@@ -38,7 +36,7 @@ def add_emoji_to_history(chat_id: str, emoji_description: str):
logger.debug(f"已将表情 '{emoji_description}' 添加到聊天 {chat_id} 的内存历史中")
def get_recent_emojis(chat_id: str, limit: int = 5) -> List[str]:
def get_recent_emojis(chat_id: str, limit: int = 5) -> list[str]:
"""
从内存中获取最近发送的表情包描述列表。

View File

@@ -1,23 +1,24 @@
import asyncio
import base64
import binascii
import hashlib
import io
import os
import random
import re
import time
import traceback
import io
import re
import binascii
from typing import Any, Optional
from typing import Optional, Tuple, List, Any
from PIL import Image
from rich.traceback import install
from sqlalchemy import select
from src.chat.utils.utils_image import get_image_manager, image_path_to_base64
from src.common.database.sqlalchemy_database_api import get_db_session
from src.common.database.sqlalchemy_models import Emoji, Images
from src.common.logger import get_logger
from src.config.config import global_config, model_config
from src.chat.utils.utils_image import image_path_to_base64, get_image_manager
from src.llm_models.utils_model import LLMRequest
install(extra_lines=3)
@@ -47,14 +48,14 @@ class MaiEmoji:
self.embedding = []
self.hash = "" # 初始为空,在创建实例时会计算
self.description = ""
self.emotion: List[str] = []
self.emotion: list[str] = []
self.usage_count = 0
self.last_used_time = time.time()
self.register_time = time.time()
self.is_deleted = False # 标记是否已被删除
self.format = ""
async def initialize_hash_format(self) -> Optional[bool]:
async def initialize_hash_format(self) -> bool | None:
"""从文件创建表情包实例, 计算哈希值和格式"""
try:
# 使用 full_path 检查文件是否存在
@@ -105,7 +106,7 @@ class MaiEmoji:
self.is_deleted = True
return None
except Exception as e:
logger.error(f"[初始化错误] 初始化表情包时发生未预期错误 ({self.filename}): {str(e)}")
logger.error(f"[初始化错误] 初始化表情包时发生未预期错误 ({self.filename}): {e!s}")
logger.error(traceback.format_exc())
self.is_deleted = True
return None
@@ -142,7 +143,7 @@ class MaiEmoji:
self.path = EMOJI_REGISTERED_DIR
# self.filename 保持不变
except Exception as move_error:
logger.error(f"[错误] 移动文件失败: {str(move_error)}")
logger.error(f"[错误] 移动文件失败: {move_error!s}")
# 如果移动失败,尝试将实例状态恢复?暂时不处理,仅返回失败
return False
@@ -174,11 +175,11 @@ class MaiEmoji:
return True
except Exception as db_error:
logger.error(f"[错误] 保存数据库失败 ({self.filename}): {str(db_error)}")
logger.error(f"[错误] 保存数据库失败 ({self.filename}): {db_error!s}")
return False
except Exception as e:
logger.error(f"[错误] 注册表情包失败 ({self.filename}): {str(e)}")
logger.error(f"[错误] 注册表情包失败 ({self.filename}): {e!s}")
logger.error(traceback.format_exc())
return False
@@ -198,7 +199,7 @@ class MaiEmoji:
os.remove(file_to_delete)
logger.debug(f"[删除] 文件: {file_to_delete}")
except Exception as e:
logger.error(f"[错误] 删除文件失败 {file_to_delete}: {str(e)}")
logger.error(f"[错误] 删除文件失败 {file_to_delete}: {e!s}")
# 文件删除失败,但仍然尝试删除数据库记录
# 2. 删除数据库记录
@@ -214,7 +215,7 @@ class MaiEmoji:
result = 1 # Successfully deleted one record
await session.commit()
except Exception as e:
logger.error(f"[错误] 删除数据库记录时出错: {str(e)}")
logger.error(f"[错误] 删除数据库记录时出错: {e!s}")
result = 0
if result > 0:
@@ -233,11 +234,11 @@ class MaiEmoji:
return False
except Exception as e:
logger.error(f"[错误] 删除表情包失败 ({self.filename}): {str(e)}")
logger.error(f"[错误] 删除表情包失败 ({self.filename}): {e!s}")
return False
def _emoji_objects_to_readable_list(emoji_objects: List["MaiEmoji"]) -> List[str]:
def _emoji_objects_to_readable_list(emoji_objects: list["MaiEmoji"]) -> list[str]:
"""将表情包对象列表转换为可读的字符串列表
参数:
@@ -256,7 +257,7 @@ def _emoji_objects_to_readable_list(emoji_objects: List["MaiEmoji"]) -> List[str
return emoji_info_list
def _to_emoji_objects(data: Any) -> Tuple[List["MaiEmoji"], int]:
def _to_emoji_objects(data: Any) -> tuple[list["MaiEmoji"], int]:
emoji_objects = []
load_errors = 0
emoji_data_list = list(data)
@@ -300,7 +301,7 @@ def _to_emoji_objects(data: Any) -> Tuple[List["MaiEmoji"], int]:
logger.error(f"[加载错误] 初始化 MaiEmoji 失败 ({full_path}): {ve}")
load_errors += 1
except Exception as e:
logger.error(f"[加载错误] 处理数据库记录时出错 ({full_path}): {str(e)}")
logger.error(f"[加载错误] 处理数据库记录时出错 ({full_path}): {e!s}")
load_errors += 1
return emoji_objects, load_errors
@@ -335,7 +336,7 @@ async def clear_temp_emoji() -> None:
logger.debug(f"[清理] 删除: {filename}")
async def clean_unused_emojis(emoji_dir: str, emoji_objects: List["MaiEmoji"], removed_count: int) -> int:
async def clean_unused_emojis(emoji_dir: str, emoji_objects: list["MaiEmoji"], removed_count: int) -> int:
"""清理指定目录中未被 emoji_objects 追踪的表情包文件"""
if not os.path.exists(emoji_dir):
logger.warning(f"[清理] 目标目录不存在,跳过清理: {emoji_dir}")
@@ -361,7 +362,7 @@ async def clean_unused_emojis(emoji_dir: str, emoji_objects: List["MaiEmoji"], r
logger.info(f"[清理] 删除未追踪的表情包文件: {file_full_path}")
cleaned_count += 1
except Exception as e:
logger.error(f"[错误] 删除文件时出错 ({file_full_path}): {str(e)}")
logger.error(f"[错误] 删除文件时出错 ({file_full_path}): {e!s}")
if cleaned_count > 0:
logger.info(f"[清理] 在目录 {emoji_dir} 中清理了 {cleaned_count} 个破损表情包。")
@@ -369,7 +370,7 @@ async def clean_unused_emojis(emoji_dir: str, emoji_objects: List["MaiEmoji"], r
logger.info(f"[清理] 目录 {emoji_dir} 中没有需要清理的。")
except Exception as e:
logger.error(f"[错误] 清理未使用表情包文件时出错 ({emoji_dir}): {str(e)}")
logger.error(f"[错误] 清理未使用表情包文件时出错 ({emoji_dir}): {e!s}")
return removed_count + cleaned_count
@@ -437,9 +438,9 @@ class EmojiManager:
emoji_update.last_used_time = time.time() # Update last used time
await session.commit()
except Exception as e:
logger.error(f"记录表情使用失败: {str(e)}")
logger.error(f"记录表情使用失败: {e!s}")
async def get_emoji_for_text(self, text_emotion: str) -> Optional[Tuple[str, str, str]]:
async def get_emoji_for_text(self, text_emotion: str) -> tuple[str, str, str] | None:
"""
根据文本内容使用LLM选择一个合适的表情包。
@@ -531,7 +532,7 @@ class EmojiManager:
return selected_emoji.full_path, f"[表情包:{selected_emoji.description}]", text_emotion
except Exception as e:
logger.error(f"使用LLM获取表情包时发生错误: {str(e)}")
logger.error(f"使用LLM获取表情包时发生错误: {e!s}")
logger.error(traceback.format_exc())
return None
@@ -578,7 +579,7 @@ class EmojiManager:
continue
except Exception as item_error:
logger.error(f"[错误] 处理表情包记录时出错 ({emoji.filename}): {str(item_error)}")
logger.error(f"[错误] 处理表情包记录时出错 ({emoji.filename}): {item_error!s}")
# 即使出错,也尝试继续检查下一个
continue
@@ -597,7 +598,7 @@ class EmojiManager:
logger.info(f"[检查] 已检查 {total_count} 个表情包记录,全部完好")
except Exception as e:
logger.error(f"[错误] 检查表情包完整性失败: {str(e)}")
logger.error(f"[错误] 检查表情包完整性失败: {e!s}")
logger.error(traceback.format_exc())
async def start_periodic_check_register(self) -> None:
@@ -651,7 +652,7 @@ class EmojiManager:
os.remove(file_path)
logger.warning(f"[清理] 删除注册失败的表情包文件: {filename}")
except Exception as e:
logger.error(f"[错误] 扫描表情包目录失败: {str(e)}")
logger.error(f"[错误] 扫描表情包目录失败: {e!s}")
await asyncio.sleep(global_config.emoji.check_interval * 60)
@@ -674,11 +675,11 @@ class EmojiManager:
logger.warning(f"[数据库] 加载过程中出现 {load_errors} 个错误。")
except Exception as e:
logger.error(f"[错误] 从数据库加载所有表情包对象失败: {str(e)}")
logger.error(f"[错误] 从数据库加载所有表情包对象失败: {e!s}")
self.emoji_objects = [] # 加载失败则清空列表
self.emoji_num = 0
async def get_emoji_from_db(self, emoji_hash: Optional[str] = None) -> List["MaiEmoji"]:
async def get_emoji_from_db(self, emoji_hash: str | None = None) -> list["MaiEmoji"]:
"""获取指定哈希值的表情包并初始化为MaiEmoji类对象列表 (主要用于调试或特定查找)
参数:
@@ -707,7 +708,7 @@ class EmojiManager:
return emoji_objects
except Exception as e:
logger.error(f"[错误] 从数据库获取表情包对象失败: {str(e)}")
logger.error(f"[错误] 从数据库获取表情包对象失败: {e!s}")
return []
async def get_emoji_from_manager(self, emoji_hash: str) -> Optional["MaiEmoji"]:
@@ -725,7 +726,7 @@ class EmojiManager:
return emoji
return None # 如果循环结束还没找到,则返回 None
async def get_emoji_tag_by_hash(self, emoji_hash: str) -> Optional[str]:
async def get_emoji_tag_by_hash(self, emoji_hash: str) -> str | None:
"""根据哈希值获取已注册表情包的描述
Args:
@@ -753,10 +754,10 @@ class EmojiManager:
return None
except Exception as e:
logger.error(f"获取表情包描述失败 (Hash: {emoji_hash}): {str(e)}")
logger.error(f"获取表情包描述失败 (Hash: {emoji_hash}): {e!s}")
return None
async def get_emoji_description_by_hash(self, emoji_hash: str) -> Optional[str]:
async def get_emoji_description_by_hash(self, emoji_hash: str) -> str | None:
"""根据哈希值获取已注册表情包的描述
Args:
@@ -787,7 +788,7 @@ class EmojiManager:
return None
except Exception as e:
logger.error(f"获取表情包描述失败 (Hash: {emoji_hash}): {str(e)}")
logger.error(f"获取表情包描述失败 (Hash: {emoji_hash}): {e!s}")
return None
async def delete_emoji(self, emoji_hash: str) -> bool:
@@ -823,7 +824,7 @@ class EmojiManager:
return False
except Exception as e:
logger.error(f"[错误] 删除表情包失败: {str(e)}")
logger.error(f"[错误] 删除表情包失败: {e!s}")
logger.error(traceback.format_exc())
return False
@@ -909,11 +910,11 @@ class EmojiManager:
return False
except Exception as e:
logger.error(f"[错误] 替换表情包失败: {str(e)}")
logger.error(f"[错误] 替换表情包失败: {e!s}")
logger.error(traceback.format_exc())
return False
async def build_emoji_description(self, image_base64: str) -> Tuple[str, List[str]]:
async def build_emoji_description(self, image_base64: str) -> tuple[str, list[str]]:
"""
获取表情包的详细描述和情感关键词列表。
@@ -976,14 +977,14 @@ class EmojiManager:
# 4. 内容审核,确保表情包符合规定
if global_config.emoji.content_filtration:
prompt = f'''
prompt = f"""
请根据以下标准审核这个表情包:
1. 主题必须符合:"{global_config.emoji.filtration_prompt}"
2. 内容健康,不含色情、暴力、政治敏感等元素。
3. 必须是表情包,而不是普通的聊天截图或视频截图。
4. 表情包中的文字数量如果有不能超过5个。
这个表情包是否完全满足以上所有要求?请只回答“是”或“否”。
'''
"""
content, _ = await self.vlm.generate_response_for_image(
prompt, image_base64, image_format, temperature=0.1, max_tokens=10
)
@@ -1023,7 +1024,7 @@ class EmojiManager:
return final_description, emotions
except Exception as e:
logger.error(f"构建表情包描述时发生严重错误: {str(e)}")
logger.error(f"构建表情包描述时发生严重错误: {e!s}")
logger.error(traceback.format_exc())
return "", []
@@ -1058,7 +1059,7 @@ class EmojiManager:
os.remove(file_full_path)
logger.info(f"[清理] 删除重复的待注册文件: {filename}")
except Exception as e:
logger.error(f"[错误] 删除重复文件失败: {str(e)}")
logger.error(f"[错误] 删除重复文件失败: {e!s}")
return False # 返回 False 表示未注册新表情
# 3. 构建描述和情感
@@ -1075,7 +1076,7 @@ class EmojiManager:
os.remove(file_full_path)
logger.info(f"[清理] 删除描述生成失败的文件: {filename}")
except Exception as e:
logger.error(f"[错误] 删除描述生成失败文件时出错: {str(e)}")
logger.error(f"[错误] 删除描述生成失败文件时出错: {e!s}")
return False
new_emoji.description = description
new_emoji.emotion = emotions
@@ -1086,7 +1087,7 @@ class EmojiManager:
os.remove(file_full_path)
logger.info(f"[清理] 删除描述生成异常的文件: {filename}")
except Exception as e:
logger.error(f"[错误] 删除描述生成异常文件时出错: {str(e)}")
logger.error(f"[错误] 删除描述生成异常文件时出错: {e!s}")
return False
# 4. 检查容量并决定是否替换或直接注册
@@ -1100,7 +1101,7 @@ class EmojiManager:
os.remove(file_full_path) # new_emoji 的 full_path 此时还是源路径
logger.info(f"[清理] 删除替换失败的新表情文件: {filename}")
except Exception as e:
logger.error(f"[错误] 删除替换失败文件时出错: {str(e)}")
logger.error(f"[错误] 删除替换失败文件时出错: {e!s}")
return False
# 替换成功时replace_a_emoji 内部已处理 new_emoji 的注册和添加到列表
return True
@@ -1122,11 +1123,11 @@ class EmojiManager:
os.remove(file_full_path)
logger.info(f"[清理] 删除注册失败的源文件: {filename}")
except Exception as e:
logger.error(f"[错误] 删除注册失败源文件时出错: {str(e)}")
logger.error(f"[错误] 删除注册失败源文件时出错: {e!s}")
return False
except Exception as e:
logger.error(f"[错误] 注册表情包时发生未预期错误 ({filename}): {str(e)}")
logger.error(f"[错误] 注册表情包时发生未预期错误 ({filename}): {e!s}")
logger.error(traceback.format_exc())
# 尝试删除源文件以避免循环处理
if os.path.exists(file_full_path):

View File

@@ -4,24 +4,24 @@
"""
from .energy_manager import (
EnergyManager,
EnergyLevel,
EnergyComponent,
EnergyCalculator,
InterestEnergyCalculator,
ActivityEnergyCalculator,
EnergyCalculator,
EnergyComponent,
EnergyLevel,
EnergyManager,
InterestEnergyCalculator,
RecencyEnergyCalculator,
RelationshipEnergyCalculator,
energy_manager,
)
__all__ = [
"EnergyManager",
"EnergyLevel",
"EnergyComponent",
"EnergyCalculator",
"InterestEnergyCalculator",
"ActivityEnergyCalculator",
"EnergyCalculator",
"EnergyComponent",
"EnergyLevel",
"EnergyManager",
"InterestEnergyCalculator",
"RecencyEnergyCalculator",
"RelationshipEnergyCalculator",
"energy_manager",

View File

@@ -4,10 +4,10 @@
"""
import time
from typing import Dict, List, Optional, Tuple, Any, Union, TypedDict
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from enum import Enum
from abc import ABC, abstractmethod
from typing import Any, TypedDict
from src.common.logger import get_logger
from src.config.config import global_config
@@ -51,8 +51,8 @@ class EnergyContext(TypedDict):
"""能量计算上下文"""
stream_id: str
messages: List[Any]
user_id: Optional[str]
messages: list[Any]
user_id: str | None
class EnergyResult(TypedDict):
@@ -61,7 +61,7 @@ class EnergyResult(TypedDict):
energy: float
level: EnergyLevel
distribution_interval: float
component_scores: Dict[str, float]
component_scores: dict[str, float]
cached: bool
@@ -69,7 +69,7 @@ class EnergyCalculator(ABC):
"""能量计算器抽象基类"""
@abstractmethod
def calculate(self, context: Dict[str, Any]) -> float:
def calculate(self, context: dict[str, Any]) -> float:
"""计算能量值"""
pass
@@ -82,7 +82,7 @@ class EnergyCalculator(ABC):
class InterestEnergyCalculator(EnergyCalculator):
"""兴趣度能量计算器"""
def calculate(self, context: Dict[str, Any]) -> float:
def calculate(self, context: dict[str, Any]) -> float:
"""基于消息兴趣度计算能量"""
messages = context.get("messages", [])
if not messages:
@@ -120,7 +120,7 @@ class ActivityEnergyCalculator(EnergyCalculator):
def __init__(self):
self.action_weights = {"reply": 0.4, "react": 0.3, "mention": 0.2, "other": 0.1}
def calculate(self, context: Dict[str, Any]) -> float:
def calculate(self, context: dict[str, Any]) -> float:
"""基于活跃度计算能量"""
messages = context.get("messages", [])
if not messages:
@@ -150,7 +150,7 @@ class ActivityEnergyCalculator(EnergyCalculator):
class RecencyEnergyCalculator(EnergyCalculator):
"""最近性能量计算器"""
def calculate(self, context: Dict[str, Any]) -> float:
def calculate(self, context: dict[str, Any]) -> float:
"""基于最近性计算能量"""
messages = context.get("messages", [])
if not messages:
@@ -197,7 +197,7 @@ class RecencyEnergyCalculator(EnergyCalculator):
class RelationshipEnergyCalculator(EnergyCalculator):
"""关系能量计算器"""
async def calculate(self, context: Dict[str, Any]) -> float:
async def calculate(self, context: dict[str, Any]) -> float:
"""基于关系计算能量"""
user_id = context.get("user_id")
if not user_id:
@@ -223,7 +223,7 @@ class EnergyManager:
"""能量管理器 - 统一管理所有能量计算"""
def __init__(self) -> None:
self.calculators: List[EnergyCalculator] = [
self.calculators: list[EnergyCalculator] = [
InterestEnergyCalculator(),
ActivityEnergyCalculator(),
RecencyEnergyCalculator(),
@@ -231,14 +231,14 @@ class EnergyManager:
]
# 能量缓存
self.energy_cache: Dict[str, Tuple[float, float]] = {} # stream_id -> (energy, timestamp)
self.energy_cache: dict[str, tuple[float, float]] = {} # stream_id -> (energy, timestamp)
self.cache_ttl: int = 60 # 1分钟缓存
# AFC阈值配置
self.thresholds: Dict[str, float] = {"high_match": 0.8, "reply": 0.4, "non_reply": 0.2}
self.thresholds: dict[str, float] = {"high_match": 0.8, "reply": 0.4, "non_reply": 0.2}
# 统计信息
self.stats: Dict[str, Union[int, float, str]] = {
self.stats: dict[str, int | float | str] = {
"total_calculations": 0,
"cache_hits": 0,
"cache_misses": 0,
@@ -272,7 +272,7 @@ class EnergyManager:
except Exception as e:
logger.warning(f"加载AFC阈值失败使用默认值: {e}")
async def calculate_focus_energy(self, stream_id: str, messages: List[Any], user_id: Optional[str] = None) -> float:
async def calculate_focus_energy(self, stream_id: str, messages: list[Any], user_id: str | None = None) -> float:
"""计算聊天流的focus_energy"""
start_time = time.time()
@@ -297,7 +297,7 @@ class EnergyManager:
}
# 计算各组件能量
component_scores: Dict[str, float] = {}
component_scores: dict[str, float] = {}
total_weight = 0.0
for calculator in self.calculators:
@@ -437,7 +437,7 @@ class EnergyManager:
if expired_keys:
logger.debug(f"清理了 {len(expired_keys)} 个过期能量缓存")
def get_statistics(self) -> Dict[str, Any]:
def get_statistics(self) -> dict[str, Any]:
"""获取统计信息"""
return {
"cache_size": len(self.energy_cache),
@@ -446,7 +446,7 @@ class EnergyManager:
"performance_stats": self.stats.copy(),
}
def update_thresholds(self, new_thresholds: Dict[str, float]) -> None:
def update_thresholds(self, new_thresholds: dict[str, float]) -> None:
"""更新阈值"""
self.thresholds.update(new_thresholds)

View File

@@ -1,21 +1,20 @@
import time
import random
import orjson
import os
import random
import time
from datetime import datetime
from typing import Any
from typing import List, Dict, Optional, Any, Tuple
from src.common.logger import get_logger
from src.common.database.sqlalchemy_database_api import get_db_session
import orjson
from sqlalchemy import select
from src.common.database.sqlalchemy_models import Expression
from src.llm_models.utils_model import LLMRequest
from src.config.config import model_config, global_config
from src.chat.utils.chat_message_builder import get_raw_msg_by_timestamp_with_chat_inclusive, build_anonymous_messages
from src.chat.utils.prompt import Prompt, global_prompt_manager
from src.chat.message_receive.chat_stream import get_chat_manager
from src.chat.message_receive.chat_stream import get_chat_manager
from src.chat.utils.chat_message_builder import build_anonymous_messages, get_raw_msg_by_timestamp_with_chat_inclusive
from src.chat.utils.prompt import Prompt, global_prompt_manager
from src.common.database.sqlalchemy_database_api import get_db_session
from src.common.database.sqlalchemy_models import Expression
from src.common.logger import get_logger
from src.config.config import global_config, model_config
from src.llm_models.utils_model import LLMRequest
MAX_EXPRESSION_COUNT = 300
DECAY_DAYS = 30 # 30天衰减到0.01
@@ -193,7 +192,7 @@ class ExpressionLearner:
logger.error(f"为聊天流 {self.chat_name} 触发学习失败: {e}")
return False
async def get_expression_by_chat_id(self) -> Tuple[List[Dict[str, float]], List[Dict[str, float]]]:
async def get_expression_by_chat_id(self) -> tuple[list[dict[str, float]], list[dict[str, float]]]:
"""
获取指定chat_id的style和grammar表达方式
返回的每个表达方式字典中都包含了source_id, 用于后续的更新操作
@@ -341,7 +340,7 @@ class ExpressionLearner:
return []
# 按chat_id分组
chat_dict: Dict[str, List[Dict[str, Any]]] = {}
chat_dict: dict[str, list[dict[str, Any]]] = {}
for chat_id, situation, style in learnt_expressions:
if chat_id not in chat_dict:
chat_dict[chat_id] = []
@@ -398,7 +397,7 @@ class ExpressionLearner:
return learnt_expressions
return None
async def learn_expression(self, type: str, num: int = 10) -> Optional[Tuple[List[Tuple[str, str, str]], str]]:
async def learn_expression(self, type: str, num: int = 10) -> tuple[list[tuple[str, str, str]], str] | None:
"""从指定聊天流学习表达方式
Args:
@@ -416,7 +415,7 @@ class ExpressionLearner:
current_time = time.time()
# 获取上次学习时间
random_msg: Optional[List[Dict[str, Any]]] = await get_raw_msg_by_timestamp_with_chat_inclusive(
random_msg: list[dict[str, Any]] | None = await get_raw_msg_by_timestamp_with_chat_inclusive(
chat_id=self.chat_id,
timestamp_start=self.last_learning_time,
timestamp_end=current_time,
@@ -447,16 +446,16 @@ class ExpressionLearner:
logger.debug(f"学习{type_str}的response: {response}")
expressions: List[Tuple[str, str, str]] = self.parse_expression_response(response, chat_id)
expressions: list[tuple[str, str, str]] = self.parse_expression_response(response, chat_id)
return expressions, chat_id
@staticmethod
def parse_expression_response(response: str, chat_id: str) -> List[Tuple[str, str, str]]:
def parse_expression_response(response: str, chat_id: str) -> list[tuple[str, str, str]]:
"""
解析LLM返回的表达风格总结每一行提取"""使用"之间的内容,存储为(situation, style)元组
"""
expressions: List[Tuple[str, str, str]] = []
expressions: list[tuple[str, str, str]] = []
for line in response.splitlines():
line = line.strip()
if not line:
@@ -562,7 +561,7 @@ class ExpressionLearnerManager:
if not os.path.exists(expr_file):
continue
try:
with open(expr_file, "r", encoding="utf-8") as f:
with open(expr_file, encoding="utf-8") as f:
expressions = orjson.loads(f.read())
if not isinstance(expressions, list):

View File

@@ -1,18 +1,18 @@
import orjson
import time
import random
import hashlib
import random
import time
from typing import Any
from typing import List, Dict, Tuple, Optional, Any
import orjson
from json_repair import repair_json
from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config, model_config
from src.common.logger import get_logger
from sqlalchemy import select
from src.common.database.sqlalchemy_models import Expression
from src.chat.utils.prompt import Prompt, global_prompt_manager
from src.common.database.sqlalchemy_database_api import get_db_session
from src.common.database.sqlalchemy_models import Expression
from src.common.logger import get_logger
from src.config.config import global_config, model_config
from src.llm_models.utils_model import LLMRequest
logger = get_logger("expression_selector")
@@ -45,7 +45,7 @@ def init_prompt():
Prompt(expression_evaluation_prompt, "expression_evaluation_prompt")
def weighted_sample(population: List[Dict], weights: List[float], k: int) -> List[Dict]:
def weighted_sample(population: list[dict], weights: list[float], k: int) -> list[dict]:
"""按权重随机抽样"""
if not population or not weights or k <= 0:
return []
@@ -95,7 +95,7 @@ class ExpressionSelector:
return False
@staticmethod
def _parse_stream_config_to_chat_id(stream_config_str: str) -> Optional[str]:
def _parse_stream_config_to_chat_id(stream_config_str: str) -> str | None:
"""解析'platform:id:type'为chat_id与get_stream_id一致"""
try:
parts = stream_config_str.split(":")
@@ -114,7 +114,7 @@ class ExpressionSelector:
except Exception:
return None
def get_related_chat_ids(self, chat_id: str) -> List[str]:
def get_related_chat_ids(self, chat_id: str) -> list[str]:
"""根据expression.rules配置获取与当前chat_id相关的所有chat_id包括自身"""
rules = global_config.expression.rules
current_group = None
@@ -139,7 +139,7 @@ class ExpressionSelector:
async def get_random_expressions(
self, chat_id: str, total_num: int, style_percentage: float, grammar_percentage: float
) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]:
# sourcery skip: extract-duplicate-method, move-assign
# 支持多chat_id合并抽选
related_chat_ids = self.get_related_chat_ids(chat_id)
@@ -195,7 +195,7 @@ class ExpressionSelector:
return selected_style, selected_grammar
@staticmethod
async def update_expressions_count_batch(expressions_to_update: List[Dict[str, Any]], increment: float = 0.1):
async def update_expressions_count_batch(expressions_to_update: list[dict[str, Any]], increment: float = 0.1):
"""对一批表达方式更新count值按chat_id+type分组后一次性写入数据库"""
if not expressions_to_update:
return
@@ -240,8 +240,8 @@ class ExpressionSelector:
chat_info: str,
max_num: int = 10,
min_num: int = 5,
target_message: Optional[str] = None,
) -> List[Dict[str, Any]]:
target_message: str | None = None,
) -> list[dict[str, Any]]:
# sourcery skip: inline-variable, list-comprehension
"""使用LLM选择适合的表达方式"""

View File

@@ -16,8 +16,7 @@ Chat Frequency Analyzer
"""
import time as time_module
from datetime import datetime, timedelta, time
from typing import List, Tuple, Optional
from datetime import datetime, time, timedelta
from .tracker import chat_frequency_tracker
@@ -42,7 +41,7 @@ class ChatFrequencyAnalyzer:
self._cache_ttl_seconds = 60 * 30 # 缓存30分钟
@staticmethod
def _find_peak_windows(timestamps: List[float]) -> List[Tuple[datetime, datetime]]:
def _find_peak_windows(timestamps: list[float]) -> list[tuple[datetime, datetime]]:
"""
使用滑动窗口算法来识别时间戳列表中的高峰时段。
@@ -59,7 +58,7 @@ class ChatFrequencyAnalyzer:
datetimes = [datetime.fromtimestamp(ts) for ts in timestamps]
datetimes.sort()
peak_windows: List[Tuple[datetime, datetime]] = []
peak_windows: list[tuple[datetime, datetime]] = []
window_start_idx = 0
for i in range(len(datetimes)):
@@ -83,7 +82,7 @@ class ChatFrequencyAnalyzer:
return peak_windows
def get_peak_chat_times(self, chat_id: str) -> List[Tuple[time, time]]:
def get_peak_chat_times(self, chat_id: str) -> list[tuple[time, time]]:
"""
获取指定用户的高峰聊天时间段。
@@ -116,7 +115,7 @@ class ChatFrequencyAnalyzer:
return peak_time_windows
def is_in_peak_time(self, chat_id: str, now: Optional[datetime] = None) -> bool:
def is_in_peak_time(self, chat_id: str, now: datetime | None = None) -> bool:
"""
检查当前时间是否处于用户的高峰聊天时段内。

View File

@@ -1,8 +1,8 @@
import orjson
import time
from typing import Dict, List, Optional
from pathlib import Path
import orjson
from src.common.logger import get_logger
# 数据存储路径
@@ -19,10 +19,10 @@ class ChatFrequencyTracker:
"""
def __init__(self):
self._timestamps: Dict[str, List[float]] = self._load_timestamps()
self._timestamps: dict[str, list[float]] = self._load_timestamps()
@staticmethod
def _load_timestamps() -> Dict[str, List[float]]:
def _load_timestamps() -> dict[str, list[float]]:
"""从本地文件加载时间戳数据。"""
if not TRACKER_FILE.exists():
return {}
@@ -61,7 +61,7 @@ class ChatFrequencyTracker:
logger.debug(f"为 chat_id '{chat_id}' 记录了新的聊天时间: {now}")
self._save_timestamps()
def get_timestamps_for_chat(self, chat_id: str) -> Optional[List[float]]:
def get_timestamps_for_chat(self, chat_id: str) -> list[float] | None:
"""
获取指定聊天的所有时间戳记录。

View File

@@ -18,11 +18,10 @@ Frequency-Based Proactive Trigger
import asyncio
import time
from datetime import datetime
from typing import Dict, Optional
from src.common.logger import get_logger
# AFC manager has been moved to chatter plugin
# AFC manager has been moved to chatter plugin
# TODO: 需要重新实现主动思考和睡眠管理功能
from .analyzer import chat_frequency_analyzer
@@ -42,10 +41,10 @@ class FrequencyBasedTrigger:
def __init__(self):
# TODO: 需要重新实现睡眠管理器
self._task: Optional[asyncio.Task] = None
self._task: asyncio.Task | None = None
# 记录上次为用户触发的时间,用于冷却控制
# 格式: { "chat_id": timestamp }
self._last_triggered: Dict[str, float] = {}
self._last_triggered: dict[str, float] = {}
async def _run_trigger_cycle(self):
"""触发器的主要循环逻辑。"""

View File

@@ -3,13 +3,14 @@
提供机器人兴趣标签和智能匹配功能
"""
from .bot_interest_manager import BotInterestManager, bot_interest_manager
from src.common.data_models.bot_interest_data_model import BotInterestTag, BotPersonalityInterests, InterestMatchResult
from .bot_interest_manager import BotInterestManager, bot_interest_manager
__all__ = [
"BotInterestManager",
"bot_interest_manager",
"BotInterestTag",
"BotPersonalityInterests",
"InterestMatchResult",
"bot_interest_manager",
]

View File

@@ -3,17 +3,18 @@
基于人设生成兴趣标签并使用embedding计算匹配度
"""
import orjson
import traceback
from typing import List, Dict, Optional, Any
from datetime import datetime
from typing import Any
import numpy as np
import orjson
from sqlalchemy import select
from src.common.config_helpers import resolve_embedding_dimension
from src.common.data_models.bot_interest_data_model import BotInterestTag, BotPersonalityInterests, InterestMatchResult
from src.common.logger import get_logger
from src.config.config import global_config
from src.common.config_helpers import resolve_embedding_dimension
from src.common.data_models.bot_interest_data_model import BotPersonalityInterests, BotInterestTag, InterestMatchResult
logger = get_logger("bot_interest_manager")
@@ -22,8 +23,8 @@ class BotInterestManager:
"""机器人兴趣标签管理器"""
def __init__(self):
self.current_interests: Optional[BotPersonalityInterests] = None
self.embedding_cache: Dict[str, List[float]] = {} # embedding缓存
self.current_interests: BotPersonalityInterests | None = None
self.embedding_cache: dict[str, list[float]] = {} # embedding缓存
self._initialized = False
# Embedding客户端配置
@@ -31,7 +32,7 @@ class BotInterestManager:
self.embedding_config = None
configured_dim = resolve_embedding_dimension()
self.embedding_dimension = int(configured_dim) if configured_dim else 0
self._detected_embedding_dimension: Optional[int] = None
self._detected_embedding_dimension: int | None = None
@property
def is_initialized(self) -> bool:
@@ -145,7 +146,7 @@ class BotInterestManager:
async def _generate_interests_from_personality(
self, personality_description: str, personality_id: str
) -> Optional[BotPersonalityInterests]:
) -> BotPersonalityInterests | None:
"""根据人设生成兴趣标签"""
try:
logger.info("🎨 开始根据人设生成兴趣标签...")
@@ -226,14 +227,14 @@ class BotInterestManager:
traceback.print_exc()
raise
async def _call_llm_for_interest_generation(self, prompt: str) -> Optional[str]:
async def _call_llm_for_interest_generation(self, prompt: str) -> str | None:
"""调用LLM生成兴趣标签"""
try:
logger.info("🔧 配置LLM客户端...")
# 使用llm_api来处理请求
from src.plugin_system.apis import llm_api
from src.config.config import model_config
from src.plugin_system.apis import llm_api
# 构建完整的提示词明确要求只返回纯JSON
full_prompt = f"""你是一个专业的机器人人设分析师,擅长根据人设描述生成合适的兴趣标签。
@@ -342,7 +343,7 @@ class BotInterestManager:
logger.info(f"🗃️ 总缓存大小: {len(self.embedding_cache)}")
logger.info("=" * 50)
async def _get_embedding(self, text: str) -> List[float]:
async def _get_embedding(self, text: str) -> list[float]:
"""获取文本的embedding向量"""
if not hasattr(self, "embedding_request"):
raise RuntimeError("❌ Embedding请求客户端未初始化")
@@ -383,7 +384,7 @@ class BotInterestManager:
else:
raise RuntimeError(f"❌ 返回的embedding为空: {embedding}")
async def _generate_message_embedding(self, message_text: str, keywords: List[str]) -> List[float]:
async def _generate_message_embedding(self, message_text: str, keywords: list[str]) -> list[float]:
"""为消息生成embedding向量"""
# 组合消息文本和关键词作为embedding输入
if keywords:
@@ -399,7 +400,7 @@ class BotInterestManager:
return embedding
async def _calculate_similarity_scores(
self, result: InterestMatchResult, message_embedding: List[float], keywords: List[str]
self, result: InterestMatchResult, message_embedding: list[float], keywords: list[str]
):
"""计算消息与兴趣标签的相似度分数"""
try:
@@ -428,7 +429,7 @@ class BotInterestManager:
except Exception as e:
logger.error(f"❌ 计算相似度分数失败: {e}")
async def calculate_interest_match(self, message_text: str, keywords: List[str] = None) -> InterestMatchResult:
async def calculate_interest_match(self, message_text: str, keywords: list[str] = None) -> InterestMatchResult:
"""计算消息与机器人兴趣的匹配度"""
if not self.current_interests or not self._initialized:
raise RuntimeError("❌ 兴趣标签系统未初始化")
@@ -528,7 +529,7 @@ class BotInterestManager:
)
return result
def _calculate_keyword_match_bonus(self, keywords: List[str], matched_tags: List[str]) -> Dict[str, float]:
def _calculate_keyword_match_bonus(self, keywords: list[str], matched_tags: list[str]) -> dict[str, float]:
"""计算关键词直接匹配奖励"""
if not keywords or not matched_tags:
return {}
@@ -610,7 +611,7 @@ class BotInterestManager:
return previous_row[-1]
def _calculate_cosine_similarity(self, vec1: List[float], vec2: List[float]) -> float:
def _calculate_cosine_similarity(self, vec1: list[float], vec2: list[float]) -> float:
"""计算余弦相似度"""
try:
vec1 = np.array(vec1)
@@ -629,16 +630,17 @@ class BotInterestManager:
logger.error(f"计算余弦相似度失败: {e}")
return 0.0
async def _load_interests_from_database(self, personality_id: str) -> Optional[BotPersonalityInterests]:
async def _load_interests_from_database(self, personality_id: str) -> BotPersonalityInterests | None:
"""从数据库加载兴趣标签"""
try:
logger.debug(f"从数据库加载兴趣标签, personality_id: {personality_id}")
# 导入SQLAlchemy相关模块
from src.common.database.sqlalchemy_models import BotPersonalityInterests as DBBotPersonalityInterests
from src.common.database.sqlalchemy_database_api import get_db_session
import orjson
from src.common.database.sqlalchemy_database_api import get_db_session
from src.common.database.sqlalchemy_models import BotPersonalityInterests as DBBotPersonalityInterests
async with get_db_session() as session:
# 查询最新的兴趣标签配置
db_interests = (
@@ -716,10 +718,11 @@ class BotInterestManager:
logger.info(f"🔄 版本: {interests.version}")
# 导入SQLAlchemy相关模块
from src.common.database.sqlalchemy_models import BotPersonalityInterests as DBBotPersonalityInterests
from src.common.database.sqlalchemy_database_api import get_db_session
import orjson
from src.common.database.sqlalchemy_database_api import get_db_session
from src.common.database.sqlalchemy_models import BotPersonalityInterests as DBBotPersonalityInterests
# 将兴趣标签转换为JSON格式
tags_data = []
for tag in interests.interest_tags:
@@ -803,11 +806,11 @@ class BotInterestManager:
logger.error("🔍 错误详情:")
traceback.print_exc()
def get_current_interests(self) -> Optional[BotPersonalityInterests]:
def get_current_interests(self) -> BotPersonalityInterests | None:
"""获取当前的兴趣标签配置"""
return self.current_interests
def get_interest_stats(self) -> Dict[str, Any]:
def get_interest_stats(self) -> dict[str, Any]:
"""获取兴趣系统统计信息"""
if not self.current_interests:
return {"initialized": False}

View File

@@ -1,33 +1,31 @@
from dataclasses import dataclass
import orjson
import os
import math
import asyncio
import math
import os
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Dict, List, Tuple
import numpy as np
import pandas as pd
from dataclasses import dataclass
# import tqdm
import faiss
from .utils.hash import get_sha256
from .global_logger import logger
from rich.traceback import install
import numpy as np
import orjson
import pandas as pd
from rich.progress import (
Progress,
BarColumn,
MofNCompleteColumn,
Progress,
SpinnerColumn,
TaskProgressColumn,
TextColumn,
TimeElapsedColumn,
TimeRemainingColumn,
TaskProgressColumn,
MofNCompleteColumn,
SpinnerColumn,
TextColumn,
)
from src.config.config import global_config
from src.common.config_helpers import resolve_embedding_dimension
from rich.traceback import install
from src.common.config_helpers import resolve_embedding_dimension
from src.config.config import global_config
from .global_logger import logger
from .utils.hash import get_sha256
install(extra_lines=3)
@@ -79,7 +77,7 @@ def cosine_similarity(a, b):
class EmbeddingStoreItem:
"""嵌入库中的项"""
def __init__(self, item_hash: str, embedding: List[float], content: str):
def __init__(self, item_hash: str, embedding: list[float], content: str):
self.hash = item_hash
self.embedding = embedding
self.str = content
@@ -127,7 +125,7 @@ class EmbeddingStore:
self.idx2hash = None
@staticmethod
def _get_embedding(s: str) -> List[float]:
def _get_embedding(s: str) -> list[float]:
"""获取字符串的嵌入向量,使用完全同步的方式避免事件循环问题"""
# 创建新的事件循环并在完成后立即关闭
loop = asyncio.new_event_loop()
@@ -135,8 +133,8 @@ class EmbeddingStore:
try:
# 创建新的LLMRequest实例
from src.llm_models.utils_model import LLMRequest
from src.config.config import model_config
from src.llm_models.utils_model import LLMRequest
llm = LLMRequest(model_set=model_config.model_task_config.embedding, request_type="embedding")
@@ -161,8 +159,8 @@ class EmbeddingStore:
@staticmethod
def _get_embeddings_batch_threaded(
strs: List[str], chunk_size: int = 10, max_workers: int = 10, progress_callback=None
) -> List[Tuple[str, List[float]]]:
strs: list[str], chunk_size: int = 10, max_workers: int = 10, progress_callback=None
) -> list[tuple[str, list[float]]]:
"""使用多线程批量获取嵌入向量
Args:
@@ -192,8 +190,8 @@ class EmbeddingStore:
chunk_results = []
# 为每个线程创建独立的LLMRequest实例
from src.llm_models.utils_model import LLMRequest
from src.config.config import model_config
from src.llm_models.utils_model import LLMRequest
try:
# 创建线程专用的LLM实例
@@ -303,7 +301,7 @@ class EmbeddingStore:
path = self.get_test_file_path()
if not os.path.exists(path):
return None
with open(path, "r", encoding="utf-8") as f:
with open(path, encoding="utf-8") as f:
return orjson.loads(f.read())
def check_embedding_model_consistency(self):
@@ -345,7 +343,7 @@ class EmbeddingStore:
logger.info("嵌入模型一致性校验通过。")
return True
def batch_insert_strs(self, strs: List[str], times: int) -> None:
def batch_insert_strs(self, strs: list[str], times: int) -> None:
"""向库中存入字符串(使用多线程优化)"""
if not strs:
return
@@ -481,7 +479,7 @@ class EmbeddingStore:
if os.path.exists(self.idx2hash_file_path):
logger.info(f"正在加载{self.namespace}嵌入库的idx2hash映射...")
logger.debug(f"正在从文件{self.idx2hash_file_path}中加载{self.namespace}嵌入库的idx2hash映射")
with open(self.idx2hash_file_path, "r") as f:
with open(self.idx2hash_file_path) as f:
self.idx2hash = orjson.loads(f.read())
logger.info(f"{self.namespace}嵌入库的idx2hash映射加载成功")
else:
@@ -511,7 +509,7 @@ class EmbeddingStore:
self.faiss_index = faiss.IndexFlatIP(embedding_dim)
self.faiss_index.add(embeddings)
def search_top_k(self, query: List[float], k: int) -> List[Tuple[str, float]]:
def search_top_k(self, query: list[float], k: int) -> list[tuple[str, float]]:
"""搜索最相似的k个项以余弦相似度为度量
Args:
query: 查询的embedding
@@ -575,11 +573,11 @@ class EmbeddingManager:
"""对所有嵌入库做模型一致性校验"""
return self.paragraphs_embedding_store.check_embedding_model_consistency()
def _store_pg_into_embedding(self, raw_paragraphs: Dict[str, str]):
def _store_pg_into_embedding(self, raw_paragraphs: dict[str, str]):
"""将段落编码存入Embedding库"""
self.paragraphs_embedding_store.batch_insert_strs(list(raw_paragraphs.values()), times=1)
def _store_ent_into_embedding(self, triple_list_data: Dict[str, List[List[str]]]):
def _store_ent_into_embedding(self, triple_list_data: dict[str, list[list[str]]]):
"""将实体编码存入Embedding库"""
entities = set()
for triple_list in triple_list_data.values():
@@ -588,7 +586,7 @@ class EmbeddingManager:
entities.add(triple[2])
self.entities_embedding_store.batch_insert_strs(list(entities), times=2)
def _store_rel_into_embedding(self, triple_list_data: Dict[str, List[List[str]]]):
def _store_rel_into_embedding(self, triple_list_data: dict[str, list[list[str]]]):
"""将关系编码存入Embedding库"""
graph_triples = [] # a list of unique relation triple (in tuple) from all chunks
for triples in triple_list_data.values():
@@ -606,8 +604,8 @@ class EmbeddingManager:
def store_new_data_set(
self,
raw_paragraphs: Dict[str, str],
triple_list_data: Dict[str, List[List[str]]],
raw_paragraphs: dict[str, str],
triple_list_data: dict[str, list[list[str]]],
):
if not self.check_all_embedding_model_consistency():
raise Exception("嵌入模型与本地存储不一致,请检查模型设置或清空嵌入库后重试。")

View File

@@ -1,14 +1,15 @@
import asyncio
import orjson
import time
from typing import List, Union
from .global_logger import logger
from . import prompt_template
from .knowledge_lib import INVALID_ENTITY
from src.llm_models.utils_model import LLMRequest
import orjson
from json_repair import repair_json
from src.llm_models.utils_model import LLMRequest
from . import prompt_template
from .global_logger import logger
from .knowledge_lib import INVALID_ENTITY
def _extract_json_from_text(text: str):
# sourcery skip: assign-if-exp, extract-method
@@ -46,7 +47,7 @@ def _extract_json_from_text(text: str):
return []
def _entity_extract(llm_req: LLMRequest, paragraph: str) -> List[str]:
def _entity_extract(llm_req: LLMRequest, paragraph: str) -> list[str]:
# sourcery skip: reintroduce-else, swap-if-else-branches, use-named-expression
"""对段落进行实体提取返回提取出的实体列表JSON格式"""
entity_extract_context = prompt_template.build_entity_extract_context(paragraph)
@@ -92,7 +93,7 @@ def _entity_extract(llm_req: LLMRequest, paragraph: str) -> List[str]:
return entity_extract_result
def _rdf_triple_extract(llm_req: LLMRequest, paragraph: str, entities: list) -> List[List[str]]:
def _rdf_triple_extract(llm_req: LLMRequest, paragraph: str, entities: list) -> list[list[str]]:
"""对段落进行实体提取返回提取出的实体列表JSON格式"""
rdf_extract_context = prompt_template.build_rdf_triple_extract_context(
paragraph, entities=orjson.dumps(entities).decode("utf-8")
@@ -141,7 +142,7 @@ def _rdf_triple_extract(llm_req: LLMRequest, paragraph: str, entities: list) ->
def info_extract_from_str(
llm_client_for_ner: LLMRequest, llm_client_for_rdf: LLMRequest, paragraph: str
) -> Union[tuple[None, None], tuple[list[str], list[list[str]]]]:
) -> tuple[None, None] | tuple[list[str], list[list[str]]]:
try_count = 0
while True:
try:

View File

@@ -1,28 +1,26 @@
import orjson
import os
import time
from typing import Dict, List, Tuple
import numpy as np
import orjson
import pandas as pd
from quick_algo import di_graph, pagerank
from rich.progress import (
Progress,
BarColumn,
MofNCompleteColumn,
Progress,
SpinnerColumn,
TaskProgressColumn,
TextColumn,
TimeElapsedColumn,
TimeRemainingColumn,
TaskProgressColumn,
MofNCompleteColumn,
SpinnerColumn,
TextColumn,
)
from quick_algo import di_graph, pagerank
from .utils.hash import get_sha256
from .embedding_store import EmbeddingManager, EmbeddingStoreItem
from src.config.config import global_config
from .embedding_store import EmbeddingManager, EmbeddingStoreItem
from .global_logger import logger
from .utils.hash import get_sha256
def _get_kg_dir():
@@ -87,7 +85,7 @@ class KGManager:
raise FileNotFoundError(f"KG图文件{self.graph_data_path}不存在")
# 加载段落hash
with open(self.pg_hash_file_path, "r", encoding="utf-8") as f:
with open(self.pg_hash_file_path, encoding="utf-8") as f:
data = orjson.loads(f.read())
self.stored_paragraph_hashes = set(data["stored_paragraph_hashes"])
@@ -100,8 +98,8 @@ class KGManager:
def _build_edges_between_ent(
self,
node_to_node: Dict[Tuple[str, str], float],
triple_list_data: Dict[str, List[List[str]]],
node_to_node: dict[tuple[str, str], float],
triple_list_data: dict[str, list[list[str]]],
):
"""构建实体节点之间的关系,同时统计实体出现次数"""
for triple_list in triple_list_data.values():
@@ -124,8 +122,8 @@ class KGManager:
@staticmethod
def _build_edges_between_ent_pg(
node_to_node: Dict[Tuple[str, str], float],
triple_list_data: Dict[str, List[List[str]]],
node_to_node: dict[tuple[str, str], float],
triple_list_data: dict[str, list[list[str]]],
):
"""构建实体节点与文段节点之间的关系"""
for idx in triple_list_data:
@@ -136,8 +134,8 @@ class KGManager:
@staticmethod
def _synonym_connect(
node_to_node: Dict[Tuple[str, str], float],
triple_list_data: Dict[str, List[List[str]]],
node_to_node: dict[tuple[str, str], float],
triple_list_data: dict[str, list[list[str]]],
embedding_manager: EmbeddingManager,
) -> int:
"""同义词连接"""
@@ -208,7 +206,7 @@ class KGManager:
def _update_graph(
self,
node_to_node: Dict[Tuple[str, str], float],
node_to_node: dict[tuple[str, str], float],
embedding_manager: EmbeddingManager,
):
"""更新KG图结构
@@ -280,7 +278,7 @@ class KGManager:
def build_kg(
self,
triple_list_data: Dict[str, List[List[str]]],
triple_list_data: dict[str, list[list[str]]],
embedding_manager: EmbeddingManager,
):
"""增量式构建KG
@@ -317,8 +315,8 @@ class KGManager:
def kg_search(
self,
relation_search_result: List[Tuple[Tuple[str, str, str], float]],
paragraph_search_result: List[Tuple[str, float]],
relation_search_result: list[tuple[tuple[str, str, str], float]],
paragraph_search_result: list[tuple[str, float]],
embed_manager: EmbeddingManager,
):
"""RAG搜索与PageRank

View File

@@ -1,10 +1,11 @@
from src.chat.knowledge.embedding_store import EmbeddingManager
from src.chat.knowledge.qa_manager import QAManager
from src.chat.knowledge.kg_manager import KGManager
from src.chat.knowledge.global_logger import logger
from src.config.config import global_config
import os
from src.chat.knowledge.embedding_store import EmbeddingManager
from src.chat.knowledge.global_logger import logger
from src.chat.knowledge.kg_manager import KGManager
from src.chat.knowledge.qa_manager import QAManager
from src.config.config import global_config
INVALID_ENTITY = [
"",
"",

View File

@@ -1,14 +1,15 @@
import orjson
import os
import glob
from typing import Any, Dict, List
import os
from typing import Any
import orjson
from .knowledge_lib import DATA_PATH, INVALID_ENTITY, ROOT_PATH
from .knowledge_lib import INVALID_ENTITY, ROOT_PATH, DATA_PATH
# from src.manager.local_store_manager import local_storage
def _filter_invalid_entities(entities: List[str]) -> List[str]:
def _filter_invalid_entities(entities: list[str]) -> list[str]:
"""过滤无效的实体"""
valid_entities = set()
for entity in entities:
@@ -20,7 +21,7 @@ def _filter_invalid_entities(entities: List[str]) -> List[str]:
return list(valid_entities)
def _filter_invalid_triples(triples: List[List[str]]) -> List[List[str]]:
def _filter_invalid_triples(triples: list[list[str]]) -> list[list[str]]:
"""过滤无效的三元组"""
unique_triples = set()
valid_triples = []
@@ -62,7 +63,7 @@ class OpenIE:
def __init__(
self,
docs: List[Dict[str, Any]],
docs: list[dict[str, Any]],
avg_ent_chars,
avg_ent_words,
):
@@ -112,7 +113,7 @@ class OpenIE:
json_files = sorted(glob.glob(os.path.join(openie_dir, "*.json")))
data_list = []
for file in json_files:
with open(file, "r", encoding="utf-8") as f:
with open(file, encoding="utf-8") as f:
data = orjson.loads(f.read())
data_list.append(data)
if not data_list:

View File

@@ -1,15 +1,16 @@
import time
from typing import Tuple, List, Dict, Optional, Any
from typing import Any
from src.chat.utils.utils import get_embedding
from src.config.config import global_config, model_config
from src.llm_models.utils_model import LLMRequest
from .global_logger import logger
from .embedding_store import EmbeddingManager
from .global_logger import logger
from .kg_manager import KGManager
# from .lpmmconfig import global_config
from .utils.dyn_topk import dyn_select_top_k
from src.llm_models.utils_model import LLMRequest
from src.chat.utils.utils import get_embedding
from src.config.config import global_config, model_config
MAX_KNOWLEDGE_LENGTH = 10000 # 最大知识长度
@@ -26,7 +27,7 @@ class QAManager:
async def process_query(
self, question: str
) -> Optional[Tuple[List[Tuple[str, float, float]], Optional[Dict[str, float]]]]:
) -> tuple[list[tuple[str, float, float]], dict[str, float] | None] | None:
"""处理查询"""
# 生成问题的Embedding
@@ -98,7 +99,7 @@ class QAManager:
return result, ppr_node_weights
async def get_knowledge(self, question: str) -> Optional[Dict[str, Any]]:
async def get_knowledge(self, question: str) -> dict[str, Any] | None:
"""
获取知识,返回结构化字典

View File

@@ -1,9 +1,9 @@
from typing import List, Any, Tuple
from typing import Any
def dyn_select_top_k(
score: List[Tuple[Any, float]], jmp_factor: float, var_factor: float
) -> List[Tuple[Any, float, float]]:
score: list[tuple[Any, float]], jmp_factor: float, var_factor: float
) -> list[tuple[Any, float, float]]:
"""动态TopK选择"""
# 检查输入列表是否为空
if not score:

View File

@@ -1,37 +1,35 @@
# -*- coding: utf-8 -*-
"""
简化记忆系统模块
移除即时记忆和长期记忆分类,实现统一记忆架构和智能遗忘机制
"""
# 核心数据结构
# 激活器
from .enhanced_memory_activator import MemoryActivator, enhanced_memory_activator, memory_activator
from .memory_chunk import (
ConfidenceLevel,
ContentStructure,
ImportanceLevel,
MemoryChunk,
MemoryMetadata,
ContentStructure,
MemoryType,
ImportanceLevel,
ConfidenceLevel,
create_memory_chunk,
)
# 兼容性别名
from .memory_chunk import MemoryChunk as Memory
# 遗忘引擎
from .memory_forgetting_engine import MemoryForgettingEngine, ForgettingConfig, get_memory_forgetting_engine
# Vector DB存储系统
from .vector_memory_storage_v2 import VectorMemoryStorage, VectorStorageConfig, get_vector_memory_storage
# 记忆核心系统
from .memory_system import MemorySystem, MemorySystemConfig, get_memory_system, initialize_memory_system
from .memory_forgetting_engine import ForgettingConfig, MemoryForgettingEngine, get_memory_forgetting_engine
# 记忆管理器
from .memory_manager import MemoryManager, MemoryResult, memory_manager
# 激活器
from .enhanced_memory_activator import MemoryActivator, memory_activator, enhanced_memory_activator
# 记忆核心系统
from .memory_system import MemorySystem, MemorySystemConfig, get_memory_system, initialize_memory_system
# 兼容性别名
from .memory_chunk import MemoryChunk as Memory
# Vector DB存储系统
from .vector_memory_storage_v2 import VectorMemoryStorage, VectorStorageConfig, get_vector_memory_storage
__all__ = [
# 核心数据结构

View File

@@ -1,17 +1,17 @@
# -*- coding: utf-8 -*-
"""
增强记忆系统适配器
将增强记忆系统集成到现有MoFox Bot架构中
"""
import time
from typing import Dict, List, Optional, Any
from dataclasses import dataclass
from typing import Any
from src.common.logger import get_logger
from src.chat.memory_system.integration_layer import MemoryIntegrationLayer, IntegrationConfig, IntegrationMode
from src.chat.memory_system.memory_chunk import MemoryChunk, MemoryType
from src.chat.memory_system.integration_layer import IntegrationConfig, IntegrationMode, MemoryIntegrationLayer
from src.chat.memory_system.memory_formatter import FormatterConfig, format_memories_for_llm
from src.chat.memory_system.memory_chunk import MemoryChunk, MemoryType
from src.common.logger import get_logger
from src.llm_models.utils_model import LLMRequest
logger = get_logger(__name__)
@@ -47,10 +47,10 @@ class AdapterConfig:
class EnhancedMemoryAdapter:
"""增强记忆系统适配器"""
def __init__(self, llm_model: LLMRequest, config: Optional[AdapterConfig] = None):
def __init__(self, llm_model: LLMRequest, config: AdapterConfig | None = None):
self.llm_model = llm_model
self.config = config or AdapterConfig()
self.integration_layer: Optional[MemoryIntegrationLayer] = None
self.integration_layer: MemoryIntegrationLayer | None = None
self._initialized = False
# 统计信息
@@ -96,7 +96,7 @@ class EnhancedMemoryAdapter:
# 如果初始化失败,禁用增强记忆功能
self.config.enable_enhanced_memory = False
async def process_conversation_memory(self, context: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
async def process_conversation_memory(self, context: dict[str, Any] | None = None) -> dict[str, Any]:
"""处理对话记忆,以上下文为唯一输入"""
if not self._initialized or not self.config.enable_enhanced_memory:
return {"success": False, "error": "Enhanced memory not available"}
@@ -105,7 +105,7 @@ class EnhancedMemoryAdapter:
self.adapter_stats["total_processed"] += 1
try:
payload_context: Dict[str, Any] = dict(context or {})
payload_context: dict[str, Any] = dict(context or {})
conversation_text = payload_context.get("conversation_text")
if not conversation_text:
@@ -146,8 +146,8 @@ class EnhancedMemoryAdapter:
return {"success": False, "error": str(e)}
async def retrieve_relevant_memories(
self, query: str, user_id: str, context: Optional[Dict[str, Any]] = None, limit: Optional[int] = None
) -> List[MemoryChunk]:
self, query: str, user_id: str, context: dict[str, Any] | None = None, limit: int | None = None
) -> list[MemoryChunk]:
"""检索相关记忆"""
if not self._initialized or not self.config.enable_enhanced_memory:
return []
@@ -166,7 +166,7 @@ class EnhancedMemoryAdapter:
return []
async def get_memory_context_for_prompt(
self, query: str, user_id: str, context: Optional[Dict[str, Any]] = None, max_memories: int = 5
self, query: str, user_id: str, context: dict[str, Any] | None = None, max_memories: int = 5
) -> str:
"""获取用于提示词的记忆上下文"""
memories = await self.retrieve_relevant_memories(query, user_id, context, max_memories)
@@ -186,7 +186,7 @@ class EnhancedMemoryAdapter:
return format_memories_for_llm(memories=memories, query_context=query, config=formatter_config)
async def get_enhanced_memory_summary(self, user_id: str) -> Dict[str, Any]:
async def get_enhanced_memory_summary(self, user_id: str) -> dict[str, Any]:
"""获取增强记忆系统摘要"""
if not self._initialized or not self.config.enable_enhanced_memory:
return {"available": False, "reason": "Not initialized or disabled"}
@@ -222,7 +222,7 @@ class EnhancedMemoryAdapter:
new_avg = (current_avg * (total_processed - 1) + processing_time) / total_processed
self.adapter_stats["average_processing_time"] = new_avg
def get_adapter_stats(self) -> Dict[str, Any]:
def get_adapter_stats(self) -> dict[str, Any]:
"""获取适配器统计信息"""
return self.adapter_stats.copy()
@@ -253,7 +253,7 @@ class EnhancedMemoryAdapter:
# 全局适配器实例
_enhanced_memory_adapter: Optional[EnhancedMemoryAdapter] = None
_enhanced_memory_adapter: EnhancedMemoryAdapter | None = None
async def get_enhanced_memory_adapter(llm_model: LLMRequest) -> EnhancedMemoryAdapter:
@@ -292,8 +292,8 @@ async def initialize_enhanced_memory_system(llm_model: LLMRequest):
async def process_conversation_with_enhanced_memory(
context: Dict[str, Any], llm_model: Optional[LLMRequest] = None
) -> Dict[str, Any]:
context: dict[str, Any], llm_model: LLMRequest | None = None
) -> dict[str, Any]:
"""使用增强记忆系统处理对话,上下文需包含 conversation_text 等信息"""
if not llm_model:
# 获取默认的LLM模型
@@ -323,10 +323,10 @@ async def process_conversation_with_enhanced_memory(
async def retrieve_memories_with_enhanced_system(
query: str,
user_id: str,
context: Optional[Dict[str, Any]] = None,
context: dict[str, Any] | None = None,
limit: int = 10,
llm_model: Optional[LLMRequest] = None,
) -> List[MemoryChunk]:
llm_model: LLMRequest | None = None,
) -> list[MemoryChunk]:
"""使用增强记忆系统检索记忆"""
if not llm_model:
# 获取默认的LLM模型
@@ -345,9 +345,9 @@ async def retrieve_memories_with_enhanced_system(
async def get_memory_context_for_prompt(
query: str,
user_id: str,
context: Optional[Dict[str, Any]] = None,
context: dict[str, Any] | None = None,
max_memories: int = 5,
llm_model: Optional[LLMRequest] = None,
llm_model: LLMRequest | None = None,
) -> str:
"""获取用于提示词的记忆上下文"""
if not llm_model:

View File

@@ -1,15 +1,15 @@
# -*- coding: utf-8 -*-
"""
增强记忆系统钩子
用于在消息处理过程中自动构建和检索记忆
"""
from typing import Dict, List, Any, Optional
from datetime import datetime
from typing import Any
from src.chat.memory_system.enhanced_memory_manager import enhanced_memory_manager
from src.common.logger import get_logger
from src.config.config import global_config
from src.chat.memory_system.enhanced_memory_manager import enhanced_memory_manager
logger = get_logger(__name__)
@@ -27,7 +27,7 @@ class EnhancedMemoryHooks:
user_id: str,
chat_id: str,
message_id: str,
context: Optional[Dict[str, Any]] = None,
context: dict[str, Any] | None = None,
) -> bool:
"""
处理消息并构建记忆
@@ -106,8 +106,8 @@ class EnhancedMemoryHooks:
user_id: str,
chat_id: str,
limit: int = 5,
extra_context: Optional[Dict[str, Any]] = None,
) -> List[Dict[str, Any]]:
extra_context: dict[str, Any] | None = None,
) -> list[dict[str, Any]]:
"""
为回复获取相关记忆

View File

@@ -1,19 +1,19 @@
# -*- coding: utf-8 -*-
"""
增强记忆系统集成脚本
用于在现有系统中无缝集成增强记忆功能
"""
from typing import Dict, Any, Optional
from typing import Any
from src.chat.memory_system.enhanced_memory_hooks import enhanced_memory_hooks
from src.common.logger import get_logger
from src.chat.memory_system.enhanced_memory_hooks import enhanced_memory_hooks
logger = get_logger(__name__)
async def process_user_message_memory(
message_content: str, user_id: str, chat_id: str, message_id: str, context: Optional[Dict[str, Any]] = None
message_content: str, user_id: str, chat_id: str, message_id: str, context: dict[str, Any] | None = None
) -> bool:
"""
处理用户消息并构建记忆
@@ -44,8 +44,8 @@ async def process_user_message_memory(
async def get_relevant_memories_for_response(
query_text: str, user_id: str, chat_id: str, limit: int = 5, extra_context: Optional[Dict[str, Any]] = None
) -> Dict[str, Any]:
query_text: str, user_id: str, chat_id: str, limit: int = 5, extra_context: dict[str, Any] | None = None
) -> dict[str, Any]:
"""
为回复获取相关记忆
@@ -74,7 +74,7 @@ async def get_relevant_memories_for_response(
return {"has_memories": False, "memories": [], "memory_count": 0}
def format_memories_for_prompt(memories: Dict[str, Any]) -> str:
def format_memories_for_prompt(memories: dict[str, Any]) -> str:
"""
格式化记忆信息用于Prompt
@@ -114,7 +114,7 @@ async def cleanup_memory_system():
logger.error(f"记忆系统清理失败: {e}")
def get_memory_system_status() -> Dict[str, Any]:
def get_memory_system_status() -> dict[str, Any]:
"""
获取记忆系统状态
@@ -133,7 +133,7 @@ def get_memory_system_status() -> Dict[str, Any]:
# 便捷函数
async def remember_message(
message: str, user_id: str = "default_user", chat_id: str = "default_chat", context: Optional[Dict[str, Any]] = None
message: str, user_id: str = "default_user", chat_id: str = "default_chat", context: dict[str, Any] | None = None
) -> bool:
"""
便捷的记忆构建函数
@@ -159,8 +159,8 @@ async def recall_memories(
user_id: str = "default_user",
chat_id: str = "default_chat",
limit: int = 5,
context: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
context: dict[str, Any] | None = None,
) -> dict[str, Any]:
"""
便捷的记忆检索函数

View File

@@ -1,4 +1,3 @@
# -*- coding: utf-8 -*-
"""
增强重排序器
实现文档设计的多维度评分模型
@@ -6,12 +5,12 @@
import math
import time
from typing import Dict, List, Optional, Any, Tuple
from dataclasses import dataclass
from enum import Enum
from typing import Any
from src.common.logger import get_logger
from src.chat.memory_system.memory_chunk import MemoryChunk, MemoryType
from src.common.logger import get_logger
logger = get_logger(__name__)
@@ -44,7 +43,7 @@ class ReRankingConfig:
freq_max_score: float = 5.0 # 最大频率得分
# 类型匹配权重映射
type_match_weights: Dict[str, Dict[str, float]] = None
type_match_weights: dict[str, dict[str, float]] = None
def __post_init__(self):
"""初始化类型匹配权重"""
@@ -157,7 +156,7 @@ class IntentClassifier:
],
}
def classify_intent(self, query: str, context: Dict[str, Any]) -> IntentType:
def classify_intent(self, query: str, context: dict[str, Any]) -> IntentType:
"""识别对话意图"""
if not query:
return IntentType.UNKNOWN
@@ -165,7 +164,7 @@ class IntentClassifier:
query_lower = query.lower()
# 统计各意图的匹配分数
intent_scores = {intent: 0 for intent in IntentType}
intent_scores = dict.fromkeys(IntentType, 0)
for intent, patterns in self.patterns.items():
for pattern in patterns:
@@ -187,7 +186,7 @@ class IntentClassifier:
class EnhancedReRanker:
"""增强重排序器 - 实现文档设计的多维度评分模型"""
def __init__(self, config: Optional[ReRankingConfig] = None):
def __init__(self, config: ReRankingConfig | None = None):
self.config = config or ReRankingConfig()
self.intent_classifier = IntentClassifier()
@@ -210,10 +209,10 @@ class EnhancedReRanker:
def rerank_memories(
self,
query: str,
candidate_memories: List[Tuple[str, MemoryChunk, float]], # (memory_id, memory, vector_similarity)
context: Dict[str, Any],
candidate_memories: list[tuple[str, MemoryChunk, float]], # (memory_id, memory, vector_similarity)
context: dict[str, Any],
limit: int = 10,
) -> List[Tuple[str, MemoryChunk, float]]:
) -> list[tuple[str, MemoryChunk, float]]:
"""
对候选记忆进行重排序
@@ -341,11 +340,11 @@ default_reranker = EnhancedReRanker()
def rerank_candidate_memories(
query: str,
candidate_memories: List[Tuple[str, MemoryChunk, float]],
context: Dict[str, Any],
candidate_memories: list[tuple[str, MemoryChunk, float]],
context: dict[str, Any],
limit: int = 10,
config: Optional[ReRankingConfig] = None,
) -> List[Tuple[str, MemoryChunk, float]]:
config: ReRankingConfig | None = None,
) -> list[tuple[str, MemoryChunk, float]]:
"""
便捷函数:对候选记忆进行重排序
"""

View File

@@ -1,18 +1,18 @@
# -*- coding: utf-8 -*-
"""
增强记忆系统集成层
现在只管理新的增强记忆系统,旧系统已被完全移除
"""
import time
import asyncio
from typing import Dict, List, Optional, Any
import time
from dataclasses import dataclass
from enum import Enum
from typing import Any
from src.common.logger import get_logger
from src.chat.memory_system.enhanced_memory_core import EnhancedMemorySystem
from src.chat.memory_system.memory_chunk import MemoryChunk
from src.common.logger import get_logger
from src.llm_models.utils_model import LLMRequest
logger = get_logger(__name__)
@@ -40,12 +40,12 @@ class IntegrationConfig:
class MemoryIntegrationLayer:
"""记忆系统集成层 - 现在只管理增强记忆系统"""
def __init__(self, llm_model: LLMRequest, config: Optional[IntegrationConfig] = None):
def __init__(self, llm_model: LLMRequest, config: IntegrationConfig | None = None):
self.llm_model = llm_model
self.config = config or IntegrationConfig()
# 只初始化增强记忆系统
self.enhanced_memory: Optional[EnhancedMemorySystem] = None
self.enhanced_memory: EnhancedMemorySystem | None = None
# 集成统计
self.integration_stats = {
@@ -113,7 +113,7 @@ class MemoryIntegrationLayer:
logger.error(f"❌ 增强记忆系统初始化失败: {e}", exc_info=True)
raise
async def process_conversation(self, context: Dict[str, Any]) -> Dict[str, Any]:
async def process_conversation(self, context: dict[str, Any]) -> dict[str, Any]:
"""处理对话记忆,仅使用上下文信息"""
if not self._initialized or not self.enhanced_memory:
return {"success": False, "error": "Memory system not available"}
@@ -150,10 +150,10 @@ class MemoryIntegrationLayer:
async def retrieve_relevant_memories(
self,
query: str,
user_id: Optional[str] = None,
context: Optional[Dict[str, Any]] = None,
limit: Optional[int] = None,
) -> List[MemoryChunk]:
user_id: str | None = None,
context: dict[str, Any] | None = None,
limit: int | None = None,
) -> list[MemoryChunk]:
"""检索相关记忆"""
if not self._initialized or not self.enhanced_memory:
return []
@@ -172,7 +172,7 @@ class MemoryIntegrationLayer:
logger.error(f"检索相关记忆失败: {e}", exc_info=True)
return []
async def get_system_status(self) -> Dict[str, Any]:
async def get_system_status(self) -> dict[str, Any]:
"""获取系统状态"""
if not self._initialized:
return {"status": "not_initialized"}
@@ -193,7 +193,7 @@ class MemoryIntegrationLayer:
logger.error(f"获取系统状态失败: {e}", exc_info=True)
return {"status": "error", "error": str(e)}
def get_integration_stats(self) -> Dict[str, Any]:
def get_integration_stats(self) -> dict[str, Any]:
"""获取集成统计信息"""
return self.integration_stats.copy()

View File

@@ -1,20 +1,20 @@
# -*- coding: utf-8 -*-
"""
记忆系统集成钩子
提供与现有MoFox Bot系统的无缝集成点
"""
import time
from typing import Dict, Optional, Any
from dataclasses import dataclass
from typing import Any
from src.common.logger import get_logger
from src.chat.memory_system.enhanced_memory_adapter import (
get_memory_context_for_prompt,
process_conversation_with_enhanced_memory,
retrieve_memories_with_enhanced_system,
get_memory_context_for_prompt,
)
from src.common.logger import get_logger
logger = get_logger(__name__)
@@ -24,7 +24,7 @@ class HookResult:
success: bool
data: Any = None
error: Optional[str] = None
error: str | None = None
processing_time: float = 0.0
@@ -125,8 +125,8 @@ class MemoryIntegrationHooks:
# 尝试注册到事件系统
try:
from src.plugin_system.core.event_manager import event_manager
from src.plugin_system.base.component_types import EventType
from src.plugin_system.core.event_manager import event_manager
# 注册消息后处理事件
event_manager.subscribe(EventType.MESSAGE_PROCESSED, self._on_message_processed_handler)
@@ -238,11 +238,11 @@ class MemoryIntegrationHooks:
# 钩子处理器方法
async def _on_message_processed_handler(self, event_data: Dict[str, Any]) -> HookResult:
async def _on_message_processed_handler(self, event_data: dict[str, Any]) -> HookResult:
"""事件系统的消息处理处理器"""
return await self._on_message_processed_hook(event_data)
async def _on_message_processed_hook(self, message_data: Dict[str, Any]) -> HookResult:
async def _on_message_processed_hook(self, message_data: dict[str, Any]) -> HookResult:
"""消息后处理钩子"""
start_time = time.time()
@@ -289,7 +289,7 @@ class MemoryIntegrationHooks:
logger.error(f"消息处理钩子执行异常: {e}", exc_info=True)
return HookResult(success=False, error=str(e), processing_time=processing_time)
async def _on_chat_stream_save_hook(self, chat_stream_data: Dict[str, Any]) -> HookResult:
async def _on_chat_stream_save_hook(self, chat_stream_data: dict[str, Any]) -> HookResult:
"""聊天流保存钩子"""
start_time = time.time()
@@ -345,7 +345,7 @@ class MemoryIntegrationHooks:
logger.error(f"聊天流保存钩子执行异常: {e}", exc_info=True)
return HookResult(success=False, error=str(e), processing_time=processing_time)
async def _on_pre_response_hook(self, response_data: Dict[str, Any]) -> HookResult:
async def _on_pre_response_hook(self, response_data: dict[str, Any]) -> HookResult:
"""回复前钩子"""
start_time = time.time()
@@ -380,7 +380,7 @@ class MemoryIntegrationHooks:
logger.error(f"回复前钩子执行异常: {e}", exc_info=True)
return HookResult(success=False, error=str(e), processing_time=processing_time)
async def _on_knowledge_query_hook(self, query_data: Dict[str, Any]) -> HookResult:
async def _on_knowledge_query_hook(self, query_data: dict[str, Any]) -> HookResult:
"""知识库查询钩子"""
start_time = time.time()
@@ -411,7 +411,7 @@ class MemoryIntegrationHooks:
logger.error(f"知识库查询钩子执行异常: {e}", exc_info=True)
return HookResult(success=False, error=str(e), processing_time=processing_time)
async def _on_prompt_building_hook(self, prompt_data: Dict[str, Any]) -> HookResult:
async def _on_prompt_building_hook(self, prompt_data: dict[str, Any]) -> HookResult:
"""提示词构建钩子"""
start_time = time.time()
@@ -459,7 +459,7 @@ class MemoryIntegrationHooks:
new_avg = (current_avg * (total_executions - 1) + processing_time) / total_executions
self.hook_stats["average_hook_time"] = new_avg
def get_hook_stats(self) -> Dict[str, Any]:
def get_hook_stats(self) -> dict[str, Any]:
"""获取钩子统计信息"""
return self.hook_stats.copy()
@@ -501,7 +501,7 @@ class MemoryMaintenanceTask:
# 全局钩子实例
_memory_hooks: Optional[MemoryIntegrationHooks] = None
_memory_hooks: MemoryIntegrationHooks | None = None
async def get_memory_integration_hooks() -> MemoryIntegrationHooks:

View File

@@ -1,20 +1,20 @@
# -*- coding: utf-8 -*-
"""
元数据索引系统
为记忆系统提供多维度的精准过滤和查询能力
"""
import threading
import time
import orjson
from typing import Dict, List, Optional, Tuple, Set, Any, Union
from collections import defaultdict
from dataclasses import dataclass
from enum import Enum
import threading
from collections import defaultdict
from pathlib import Path
from typing import Any
import orjson
from src.chat.memory_system.memory_chunk import ConfidenceLevel, ImportanceLevel, MemoryChunk, MemoryType
from src.common.logger import get_logger
from src.chat.memory_system.memory_chunk import MemoryChunk, MemoryType, ConfidenceLevel, ImportanceLevel
logger = get_logger(__name__)
@@ -40,21 +40,21 @@ class IndexType(Enum):
class IndexQuery:
"""索引查询条件"""
user_ids: Optional[List[str]] = None
memory_types: Optional[List[MemoryType]] = None
subjects: Optional[List[str]] = None
keywords: Optional[List[str]] = None
tags: Optional[List[str]] = None
categories: Optional[List[str]] = None
time_range: Optional[Tuple[float, float]] = None
confidence_levels: Optional[List[ConfidenceLevel]] = None
importance_levels: Optional[List[ImportanceLevel]] = None
min_relationship_score: Optional[float] = None
max_relationship_score: Optional[float] = None
min_access_count: Optional[int] = None
semantic_hashes: Optional[List[str]] = None
limit: Optional[int] = None
sort_by: Optional[str] = None # "created_at", "access_count", "relevance_score"
user_ids: list[str] | None = None
memory_types: list[MemoryType] | None = None
subjects: list[str] | None = None
keywords: list[str] | None = None
tags: list[str] | None = None
categories: list[str] | None = None
time_range: tuple[float, float] | None = None
confidence_levels: list[ConfidenceLevel] | None = None
importance_levels: list[ImportanceLevel] | None = None
min_relationship_score: float | None = None
max_relationship_score: float | None = None
min_access_count: int | None = None
semantic_hashes: list[str] | None = None
limit: int | None = None
sort_by: str | None = None # "created_at", "access_count", "relevance_score"
sort_order: str = "desc" # "asc", "desc"
@@ -62,10 +62,10 @@ class IndexQuery:
class IndexResult:
"""索引结果"""
memory_ids: List[str]
memory_ids: list[str]
total_count: int
query_time: float
filtered_by: List[str]
filtered_by: list[str]
class MetadataIndexManager:
@@ -94,7 +94,7 @@ class MetadataIndexManager:
self.access_frequency_index = [] # [(access_count, memory_id), ...]
# 内存缓存
self.memory_metadata_cache: Dict[str, Dict[str, Any]] = {}
self.memory_metadata_cache: dict[str, dict[str, Any]] = {}
# 统计信息
self.index_stats = {
@@ -140,7 +140,7 @@ class MetadataIndexManager:
return key
@staticmethod
def _serialize_metadata_entry(metadata: Dict[str, Any]) -> Dict[str, Any]:
def _serialize_metadata_entry(metadata: dict[str, Any]) -> dict[str, Any]:
serialized = {}
for field_name, value in metadata.items():
if isinstance(value, Enum):
@@ -149,7 +149,7 @@ class MetadataIndexManager:
serialized[field_name] = value
return serialized
async def index_memories(self, memories: List[MemoryChunk]):
async def index_memories(self, memories: list[MemoryChunk]):
"""为记忆建立索引"""
if not memories:
return
@@ -375,7 +375,7 @@ class MetadataIndexManager:
logger.error(f"❌ 元数据查询失败: {e}", exc_info=True)
return IndexResult(memory_ids=[], total_count=0, query_time=0.0, filtered_by=[])
def _get_candidate_memories(self, query: IndexQuery) -> Set[str]:
def _get_candidate_memories(self, query: IndexQuery) -> set[str]:
"""获取候选记忆ID集合"""
candidate_ids = set()
@@ -444,7 +444,7 @@ class MetadataIndexManager:
return candidate_ids
def _collect_index_matches(self, index_type: IndexType, token: Optional[Union[str, Enum]]) -> Set[str]:
def _collect_index_matches(self, index_type: IndexType, token: str | Enum | None) -> set[str]:
"""根据给定token收集索引匹配支持部分匹配"""
mapping = self.indices.get(index_type)
if mapping is None:
@@ -461,7 +461,7 @@ class MetadataIndexManager:
if not key:
return set()
matches: Set[str] = set(mapping.get(key, set()))
matches: set[str] = set(mapping.get(key, set()))
if matches:
return set(matches)
@@ -477,7 +477,7 @@ class MetadataIndexManager:
return matches
def _apply_filters(self, candidate_ids: Set[str], query: IndexQuery) -> List[str]:
def _apply_filters(self, candidate_ids: set[str], query: IndexQuery) -> list[str]:
"""应用过滤条件"""
filtered_ids = list(candidate_ids)
@@ -545,7 +545,7 @@ class MetadataIndexManager:
created_at = self.memory_metadata_cache[memory_id]["created_at"]
return start_time <= created_at <= end_time
def _sort_memories(self, memory_ids: List[str], sort_by: str, sort_order: str) -> List[str]:
def _sort_memories(self, memory_ids: list[str], sort_by: str, sort_order: str) -> list[str]:
"""对记忆进行排序"""
if sort_by == "created_at":
# 使用时间索引(已经有序)
@@ -582,7 +582,7 @@ class MetadataIndexManager:
return memory_ids
def _get_applied_filters(self, query: IndexQuery) -> List[str]:
def _get_applied_filters(self, query: IndexQuery) -> list[str]:
"""获取应用的过滤器列表"""
filters = []
if query.memory_types:
@@ -686,11 +686,11 @@ class MetadataIndexManager:
except Exception as e:
logger.error(f"❌ 移除记忆索引失败: {e}")
async def get_memory_metadata(self, memory_id: str) -> Optional[Dict[str, Any]]:
async def get_memory_metadata(self, memory_id: str) -> dict[str, Any] | None:
"""获取记忆元数据"""
return self.memory_metadata_cache.get(memory_id)
async def get_user_memory_ids(self, user_id: str, limit: Optional[int] = None) -> List[str]:
async def get_user_memory_ids(self, user_id: str, limit: int | None = None) -> list[str]:
"""获取用户的所有记忆ID"""
user_memory_ids = list(self.indices[IndexType.USER_ID].get(user_id, set()))
@@ -699,7 +699,7 @@ class MetadataIndexManager:
return user_memory_ids
async def get_memory_statistics(self, user_id: Optional[str] = None) -> Dict[str, Any]:
async def get_memory_statistics(self, user_id: str | None = None) -> dict[str, Any]:
"""获取记忆统计信息"""
stats = {
"total_memories": self.index_stats["total_memories"],
@@ -784,7 +784,7 @@ class MetadataIndexManager:
logger.info("正在保存元数据索引...")
# 保存各类索引
indices_data: Dict[str, Dict[str, List[str]]] = {}
indices_data: dict[str, dict[str, list[str]]] = {}
for index_type, index_data in self.indices.items():
serialized_index = {}
for key, values in index_data.items():
@@ -839,7 +839,7 @@ class MetadataIndexManager:
# 加载各类索引
indices_file = self.index_path / "indices.json"
if indices_file.exists():
with open(indices_file, "r", encoding="utf-8") as f:
with open(indices_file, encoding="utf-8") as f:
indices_data = orjson.loads(f.read())
for index_type_value, index_data in indices_data.items():
@@ -853,25 +853,25 @@ class MetadataIndexManager:
# 加载时间索引
time_index_file = self.index_path / "time_index.json"
if time_index_file.exists():
with open(time_index_file, "r", encoding="utf-8") as f:
with open(time_index_file, encoding="utf-8") as f:
self.time_index = orjson.loads(f.read())
# 加载关系分索引
relationship_index_file = self.index_path / "relationship_index.json"
if relationship_index_file.exists():
with open(relationship_index_file, "r", encoding="utf-8") as f:
with open(relationship_index_file, encoding="utf-8") as f:
self.relationship_index = orjson.loads(f.read())
# 加载访问频率索引
access_frequency_index_file = self.index_path / "access_frequency_index.json"
if access_frequency_index_file.exists():
with open(access_frequency_index_file, "r", encoding="utf-8") as f:
with open(access_frequency_index_file, encoding="utf-8") as f:
self.access_frequency_index = orjson.loads(f.read())
# 加载元数据缓存
metadata_cache_file = self.index_path / "metadata_cache.json"
if metadata_cache_file.exists():
with open(metadata_cache_file, "r", encoding="utf-8") as f:
with open(metadata_cache_file, encoding="utf-8") as f:
cache_data = orjson.loads(f.read())
# 转换置信度和重要性为枚举类型
@@ -914,7 +914,7 @@ class MetadataIndexManager:
# 加载统计信息
stats_file = self.index_path / "index_stats.json"
if stats_file.exists():
with open(stats_file, "r", encoding="utf-8") as f:
with open(stats_file, encoding="utf-8") as f:
self.index_stats = orjson.loads(f.read())
# 更新记忆计数
@@ -1004,7 +1004,7 @@ class MetadataIndexManager:
if len(self.indices[IndexType.CATEGORY][category]) < min_frequency:
del self.indices[IndexType.CATEGORY][category]
def get_index_stats(self) -> Dict[str, Any]:
def get_index_stats(self) -> dict[str, Any]:
"""获取索引统计信息"""
stats = self.index_stats.copy()
if stats["total_queries"] > 0:

View File

@@ -1,19 +1,19 @@
# -*- coding: utf-8 -*-
"""
多阶段召回机制
实现粗粒度到细粒度的记忆检索优化
"""
import time
from typing import Dict, List, Optional, Set, Any
from dataclasses import dataclass, field
from enum import Enum
import orjson
from typing import Any
from src.common.logger import get_logger
from src.chat.memory_system.memory_chunk import MemoryChunk, MemoryType
import orjson
from src.chat.memory_system.enhanced_reranker import EnhancedReRanker, ReRankingConfig
from src.chat.memory_system.memory_chunk import MemoryChunk, MemoryType
from src.common.logger import get_logger
logger = get_logger(__name__)
@@ -73,11 +73,11 @@ class StageResult:
"""阶段结果"""
stage: RetrievalStage
memory_ids: List[str]
memory_ids: list[str]
processing_time: float
filtered_count: int
score_threshold: float
details: List[Dict[str, Any]] = field(default_factory=list)
details: list[dict[str, Any]] = field(default_factory=list)
@dataclass
@@ -86,17 +86,17 @@ class RetrievalResult:
query: str
user_id: str
final_memories: List[MemoryChunk]
stage_results: List[StageResult]
final_memories: list[MemoryChunk]
stage_results: list[StageResult]
total_processing_time: float
total_filtered: int
retrieval_stats: Dict[str, Any]
retrieval_stats: dict[str, Any]
class MultiStageRetrieval:
"""多阶段召回系统"""
def __init__(self, config: Optional[RetrievalConfig] = None):
def __init__(self, config: RetrievalConfig | None = None):
self.config = config or RetrievalConfig.from_global_config()
# 初始化增强重排序器
@@ -124,11 +124,11 @@ class MultiStageRetrieval:
self,
query: str,
user_id: str,
context: Dict[str, Any],
context: dict[str, Any],
metadata_index,
vector_storage,
all_memories_cache: Dict[str, MemoryChunk],
limit: Optional[int] = None,
all_memories_cache: dict[str, MemoryChunk],
limit: int | None = None,
) -> RetrievalResult:
"""多阶段记忆检索"""
start_time = time.time()
@@ -136,7 +136,7 @@ class MultiStageRetrieval:
stage_results = []
current_memory_ids = set()
memory_debug_info: Dict[str, Dict[str, Any]] = {}
memory_debug_info: dict[str, dict[str, Any]] = {}
try:
logger.debug(f"开始多阶段检索query='{query}', user_id='{user_id}'")
@@ -311,11 +311,11 @@ class MultiStageRetrieval:
self,
query: str,
user_id: str,
context: Dict[str, Any],
context: dict[str, Any],
metadata_index,
all_memories_cache: Dict[str, MemoryChunk],
all_memories_cache: dict[str, MemoryChunk],
*,
debug_log: Optional[Dict[str, Dict[str, Any]]] = None,
debug_log: dict[str, dict[str, Any]] | None = None,
) -> StageResult:
"""阶段1元数据过滤"""
start_time = time.time()
@@ -345,7 +345,7 @@ class MultiStageRetrieval:
result = await metadata_index.query_memories(index_query)
result_ids = list(result.memory_ids)
filtered_count = max(0, len(all_memories_cache) - len(result_ids))
details: List[Dict[str, Any]] = []
details: list[dict[str, Any]] = []
# 如果未命中任何索引且未指定所有者过滤,则回退到最近访问的记忆
if not result_ids:
@@ -440,12 +440,12 @@ class MultiStageRetrieval:
self,
query: str,
user_id: str,
context: Dict[str, Any],
context: dict[str, Any],
vector_storage,
candidate_ids: Set[str],
all_memories_cache: Dict[str, MemoryChunk],
candidate_ids: set[str],
all_memories_cache: dict[str, MemoryChunk],
*,
debug_log: Optional[Dict[str, Dict[str, Any]]] = None,
debug_log: dict[str, dict[str, Any]] | None = None,
) -> StageResult:
"""阶段2向量搜索"""
start_time = time.time()
@@ -479,8 +479,8 @@ class MultiStageRetrieval:
# 过滤候选记忆
filtered_memories = []
details: List[Dict[str, Any]] = []
raw_details: List[Dict[str, Any]] = []
details: list[dict[str, Any]] = []
raw_details: list[dict[str, Any]] = []
threshold = self.config.vector_similarity_threshold
for memory_id, similarity in search_result:
@@ -561,7 +561,7 @@ class MultiStageRetrieval:
)
def _create_text_search_fallback(
self, candidate_ids: Set[str], all_memories_cache: Dict[str, MemoryChunk], query_text: str, start_time: float
self, candidate_ids: set[str], all_memories_cache: dict[str, MemoryChunk], query_text: str, start_time: float
) -> StageResult:
"""当向量搜索失败时,使用文本搜索作为回退策略"""
try:
@@ -618,18 +618,18 @@ class MultiStageRetrieval:
self,
query: str,
user_id: str,
context: Dict[str, Any],
candidate_ids: Set[str],
all_memories_cache: Dict[str, MemoryChunk],
context: dict[str, Any],
candidate_ids: set[str],
all_memories_cache: dict[str, MemoryChunk],
*,
debug_log: Optional[Dict[str, Dict[str, Any]]] = None,
debug_log: dict[str, dict[str, Any]] | None = None,
) -> StageResult:
"""阶段3语义重排序"""
start_time = time.time()
try:
reranked_memories = []
details: List[Dict[str, Any]] = []
details: list[dict[str, Any]] = []
threshold = self.config.semantic_similarity_threshold
for memory_id in candidate_ids:
@@ -704,19 +704,19 @@ class MultiStageRetrieval:
self,
query: str,
user_id: str,
context: Dict[str, Any],
candidate_ids: List[str],
all_memories_cache: Dict[str, MemoryChunk],
context: dict[str, Any],
candidate_ids: list[str],
all_memories_cache: dict[str, MemoryChunk],
limit: int,
*,
debug_log: Optional[Dict[str, Dict[str, Any]]] = None,
debug_log: dict[str, dict[str, Any]] | None = None,
) -> StageResult:
"""阶段4上下文过滤"""
start_time = time.time()
try:
final_memories = []
details: List[Dict[str, Any]] = []
details: list[dict[str, Any]] = []
for memory_id in candidate_ids:
if memory_id not in all_memories_cache:
@@ -793,12 +793,12 @@ class MultiStageRetrieval:
self,
query: str,
user_id: str,
context: Dict[str, Any],
all_memories_cache: Dict[str, MemoryChunk],
context: dict[str, Any],
all_memories_cache: dict[str, MemoryChunk],
limit: int,
*,
excluded_ids: Optional[Set[str]] = None,
debug_log: Optional[Dict[str, Dict[str, Any]]] = None,
excluded_ids: set[str] | None = None,
debug_log: dict[str, dict[str, Any]] | None = None,
) -> StageResult:
"""回退检索阶段 - 当主检索失败时使用更宽松的策略"""
start_time = time.time()
@@ -881,8 +881,8 @@ class MultiStageRetrieval:
)
async def _generate_query_embedding(
self, query: str, context: Dict[str, Any], vector_storage
) -> Optional[List[float]]:
self, query: str, context: dict[str, Any], vector_storage
) -> list[float] | None:
"""生成查询向量"""
try:
query_plan = context.get("query_plan")
@@ -916,7 +916,7 @@ class MultiStageRetrieval:
logger.error(f"生成查询向量时发生异常: {e}", exc_info=True)
return None
async def _calculate_semantic_similarity(self, query: str, memory: MemoryChunk, context: Dict[str, Any]) -> float:
async def _calculate_semantic_similarity(self, query: str, memory: MemoryChunk, context: dict[str, Any]) -> float:
"""计算语义相似度 - 简化优化版本,提升召回率"""
try:
query_plan = context.get("query_plan")
@@ -947,9 +947,10 @@ class MultiStageRetrieval:
# 核心匹配策略2词汇匹配
word_score = 0.0
try:
import jieba
import re
import jieba
# 分词处理
query_words = list(jieba.cut(query_text)) + re.findall(r"[a-zA-Z]+", query_text)
memory_words = list(jieba.cut(memory_text)) + re.findall(r"[a-zA-Z]+", memory_text)
@@ -1059,7 +1060,7 @@ class MultiStageRetrieval:
logger.warning(f"计算语义相似度失败: {e}")
return 0.0
async def _calculate_context_relevance(self, query: str, memory: MemoryChunk, context: Dict[str, Any]) -> float:
async def _calculate_context_relevance(self, query: str, memory: MemoryChunk, context: dict[str, Any]) -> float:
"""计算上下文相关度"""
try:
score = 0.0
@@ -1132,7 +1133,7 @@ class MultiStageRetrieval:
return 0.0
async def _calculate_final_score(
self, query: str, memory: MemoryChunk, context: Dict[str, Any], context_score: float
self, query: str, memory: MemoryChunk, context: dict[str, Any], context_score: float
) -> float:
"""计算最终评分"""
try:
@@ -1184,7 +1185,7 @@ class MultiStageRetrieval:
logger.warning(f"计算最终评分失败: {e}")
return 0.0
def _calculate_subject_overlap(self, memory: MemoryChunk, required_subjects: Optional[List[str]]) -> float:
def _calculate_subject_overlap(self, memory: MemoryChunk, required_subjects: list[str] | None) -> float:
if not required_subjects:
return 0.0
@@ -1229,7 +1230,7 @@ class MultiStageRetrieval:
except Exception:
return 0.5
def _extract_memory_types_from_context(self, context: Dict[str, Any]) -> List[MemoryType]:
def _extract_memory_types_from_context(self, context: dict[str, Any]) -> list[MemoryType]:
"""从上下文中提取记忆类型"""
try:
query_plan = context.get("query_plan")
@@ -1256,10 +1257,10 @@ class MultiStageRetrieval:
except Exception:
return []
def _extract_keywords_from_query(self, query: str, query_plan: Optional[Any] = None) -> List[str]:
def _extract_keywords_from_query(self, query: str, query_plan: Any | None = None) -> list[str]:
"""从查询中提取关键词"""
try:
extracted: List[str] = []
extracted: list[str] = []
if query_plan and getattr(query_plan, "required_keywords", None):
extracted.extend([kw.lower() for kw in query_plan.required_keywords if isinstance(kw, str)])
@@ -1283,7 +1284,7 @@ class MultiStageRetrieval:
except Exception:
return []
def _update_retrieval_stats(self, total_time: float, stage_results: List[StageResult]):
def _update_retrieval_stats(self, total_time: float, stage_results: list[StageResult]):
"""更新检索统计"""
self.retrieval_stats["total_queries"] += 1
@@ -1306,7 +1307,7 @@ class MultiStageRetrieval:
]
stage_stat["avg_time"] = new_stage_avg
def get_retrieval_stats(self) -> Dict[str, Any]:
def get_retrieval_stats(self) -> dict[str, Any]:
"""获取检索统计信息"""
return self.retrieval_stats.copy()
@@ -1328,12 +1329,12 @@ class MultiStageRetrieval:
self,
query: str,
user_id: str,
context: Dict[str, Any],
candidate_ids: List[str],
all_memories_cache: Dict[str, MemoryChunk],
context: dict[str, Any],
candidate_ids: list[str],
all_memories_cache: dict[str, MemoryChunk],
limit: int,
*,
debug_log: Optional[Dict[str, Dict[str, Any]]] = None,
debug_log: dict[str, dict[str, Any]] | None = None,
) -> StageResult:
"""阶段5增强重排序 - 使用多维度评分模型"""
start_time = time.time()

View File

@@ -1,24 +1,23 @@
# -*- coding: utf-8 -*-
"""
向量数据库存储接口
为记忆系统提供高效的向量存储和语义搜索能力
"""
import time
import orjson
import asyncio
from typing import Dict, List, Optional, Tuple, Any
from dataclasses import dataclass
import threading
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Any
import numpy as np
from pathlib import Path
import orjson
from src.common.logger import get_logger
from src.llm_models.utils_model import LLMRequest
from src.config.config import model_config
from src.common.config_helpers import resolve_embedding_dimension
from src.chat.memory_system.memory_chunk import MemoryChunk
from src.common.config_helpers import resolve_embedding_dimension
from src.common.logger import get_logger
from src.config.config import model_config
from src.llm_models.utils_model import LLMRequest
logger = get_logger(__name__)
@@ -48,7 +47,7 @@ class VectorStorageConfig:
class VectorStorageManager:
"""向量存储管理器"""
def __init__(self, config: Optional[VectorStorageConfig] = None):
def __init__(self, config: VectorStorageConfig | None = None):
self.config = config or VectorStorageConfig()
resolved_dimension = resolve_embedding_dimension(self.config.dimension)
@@ -68,8 +67,8 @@ class VectorStorageManager:
self.index_to_memory_id = {} # vector index -> memory_id
# 内存缓存
self.memory_cache: Dict[str, MemoryChunk] = {}
self.vector_cache: Dict[str, List[float]] = {}
self.memory_cache: dict[str, MemoryChunk] = {}
self.vector_cache: dict[str, list[float]] = {}
# 统计信息
self.storage_stats = {
@@ -125,7 +124,7 @@ class VectorStorageManager:
)
logger.info("✅ 嵌入模型初始化完成")
async def generate_query_embedding(self, query_text: str) -> Optional[List[float]]:
async def generate_query_embedding(self, query_text: str) -> list[float] | None:
"""生成查询向量,用于记忆召回"""
if not query_text:
logger.warning("查询文本为空,无法生成向量")
@@ -155,7 +154,7 @@ class VectorStorageManager:
logger.error(f"❌ 生成查询向量失败: {exc}", exc_info=True)
return None
async def store_memories(self, memories: List[MemoryChunk]):
async def store_memories(self, memories: list[MemoryChunk]):
"""存储记忆向量"""
if not memories:
return
@@ -231,7 +230,7 @@ class VectorStorageManager:
logger.debug("记忆 %s 缺少可用展示文本,使用占位符生成嵌入输入", memory.memory_id)
return memory.memory_id
async def _batch_generate_and_store_embeddings(self, memory_texts: List[Tuple[str, str]]):
async def _batch_generate_and_store_embeddings(self, memory_texts: list[tuple[str, str]]):
"""批量生成和存储嵌入向量"""
if not memory_texts:
return
@@ -253,12 +252,12 @@ class VectorStorageManager:
except Exception as e:
logger.error(f"❌ 批量生成嵌入向量失败: {e}")
async def _batch_generate_embeddings(self, memory_ids: List[str], texts: List[str]) -> Dict[str, List[float]]:
async def _batch_generate_embeddings(self, memory_ids: list[str], texts: list[str]) -> dict[str, list[float]]:
"""批量生成嵌入向量"""
if not texts:
return {}
results: Dict[str, List[float]] = {}
results: dict[str, list[float]] = {}
try:
semaphore = asyncio.Semaphore(min(4, max(1, len(texts))))
@@ -281,7 +280,9 @@ class VectorStorageManager:
logger.warning("生成记忆 %s 的嵌入向量失败: %s", memory_id, exc)
results[memory_id] = []
tasks = [asyncio.create_task(generate_embedding(mid, text)) for mid, text in zip(memory_ids, texts, strict=False)]
tasks = [
asyncio.create_task(generate_embedding(mid, text)) for mid, text in zip(memory_ids, texts, strict=False)
]
await asyncio.gather(*tasks, return_exceptions=True)
except Exception as e:
@@ -291,7 +292,7 @@ class VectorStorageManager:
return results
async def _add_single_memory(self, memory: MemoryChunk, embedding: List[float]):
async def _add_single_memory(self, memory: MemoryChunk, embedding: list[float]):
"""添加单个记忆到向量存储"""
with self._lock:
try:
@@ -337,7 +338,7 @@ class VectorStorageManager:
except Exception as e:
logger.error(f"❌ 添加记忆到向量存储失败: {e}")
def _normalize_vector(self, vector: List[float]) -> List[float]:
def _normalize_vector(self, vector: list[float]) -> list[float]:
"""L2归一化向量"""
if not vector:
return vector
@@ -357,12 +358,12 @@ class VectorStorageManager:
async def search_similar_memories(
self,
query_vector: Optional[List[float]] = None,
query_vector: list[float] | None = None,
*,
query_text: Optional[str] = None,
query_text: str | None = None,
limit: int = 10,
scope_id: Optional[str] = None,
) -> List[Tuple[str, float]]:
scope_id: str | None = None,
) -> list[tuple[str, float]]:
"""搜索相似记忆"""
start_time = time.time()
@@ -379,7 +380,7 @@ class VectorStorageManager:
logger.warning("查询向量生成失败")
return []
scope_filter: Optional[str] = None
scope_filter: str | None = None
if isinstance(scope_id, str):
normalized_scope = scope_id.strip().lower()
if normalized_scope and normalized_scope not in {"global", "global_memory"}:
@@ -491,7 +492,7 @@ class VectorStorageManager:
logger.error(f"❌ 向量搜索失败: {e}", exc_info=True)
return []
async def get_memory_by_id(self, memory_id: str) -> Optional[MemoryChunk]:
async def get_memory_by_id(self, memory_id: str) -> MemoryChunk | None:
"""根据ID获取记忆"""
# 先检查缓存
if memory_id in self.memory_cache:
@@ -501,7 +502,7 @@ class VectorStorageManager:
self.storage_stats["total_searches"] += 1
return None
async def update_memory_embedding(self, memory_id: str, new_embedding: List[float]):
async def update_memory_embedding(self, memory_id: str, new_embedding: list[float]):
"""更新记忆的嵌入向量"""
with self._lock:
try:
@@ -636,7 +637,7 @@ class VectorStorageManager:
# 加载记忆缓存
cache_file = self.storage_path / "memory_cache.json"
if cache_file.exists():
with open(cache_file, "r", encoding="utf-8") as f:
with open(cache_file, encoding="utf-8") as f:
cache_data = orjson.loads(f.read())
self.memory_cache = {
@@ -646,13 +647,13 @@ class VectorStorageManager:
# 加载向量缓存
vector_cache_file = self.storage_path / "vector_cache.json"
if vector_cache_file.exists():
with open(vector_cache_file, "r", encoding="utf-8") as f:
with open(vector_cache_file, encoding="utf-8") as f:
self.vector_cache = orjson.loads(f.read())
# 加载映射关系
mapping_file = self.storage_path / "id_mapping.json"
if mapping_file.exists():
with open(mapping_file, "r", encoding="utf-8") as f:
with open(mapping_file, encoding="utf-8") as f:
mapping_data = orjson.loads(f.read())
raw_memory_to_index = mapping_data.get("memory_id_to_index", {})
self.memory_id_to_index = {
@@ -689,7 +690,7 @@ class VectorStorageManager:
# 加载统计信息
stats_file = self.storage_path / "storage_stats.json"
if stats_file.exists():
with open(stats_file, "r", encoding="utf-8") as f:
with open(stats_file, encoding="utf-8") as f:
self.storage_stats = orjson.loads(f.read())
# 更新向量计数
@@ -806,7 +807,7 @@ class VectorStorageManager:
if invalid_memory_ids:
logger.info(f"清理了 {len(invalid_memory_ids)} 个无效引用")
def get_storage_stats(self) -> Dict[str, Any]:
def get_storage_stats(self) -> dict[str, Any]:
"""获取存储统计信息"""
stats = self.storage_stats.copy()
if stats["total_searches"] > 0:
@@ -821,11 +822,11 @@ class SimpleVectorIndex:
def __init__(self, dimension: int):
self.dimension = dimension
self.vectors: List[List[float]] = []
self.vector_ids: List[int] = []
self.vectors: list[list[float]] = []
self.vector_ids: list[int] = []
self.next_id = 0
def add_vector(self, vector: List[float]) -> int:
def add_vector(self, vector: list[float]) -> int:
"""添加向量"""
if len(vector) != self.dimension:
raise ValueError(f"向量维度不匹配,期望 {self.dimension},实际 {len(vector)}")
@@ -837,7 +838,7 @@ class SimpleVectorIndex:
return vector_id
def search(self, query_vector: List[float], limit: int) -> List[Tuple[int, float]]:
def search(self, query_vector: list[float], limit: int) -> list[tuple[int, float]]:
"""搜索相似向量"""
if len(query_vector) != self.dimension:
raise ValueError(f"查询向量维度不匹配,期望 {self.dimension},实际 {len(query_vector)}")
@@ -853,7 +854,7 @@ class SimpleVectorIndex:
return results[:limit]
def _calculate_cosine_similarity(self, v1: List[float], v2: List[float]) -> float:
def _calculate_cosine_similarity(self, v1: list[float], v2: list[float]) -> float:
"""计算余弦相似度"""
try:
dot_product = sum(x * y for x, y in zip(v1, v2, strict=False))

View File

@@ -1,25 +1,24 @@
# -*- coding: utf-8 -*-
"""
记忆激活器
记忆系统的激活器组件
"""
import difflib
import orjson
from typing import List, Dict, Optional
from datetime import datetime
import orjson
from json_repair import repair_json
from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config, model_config
from src.common.logger import get_logger
from src.chat.utils.prompt import Prompt, global_prompt_manager
from src.chat.memory_system.memory_manager import MemoryResult
from src.chat.utils.prompt import Prompt, global_prompt_manager
from src.common.logger import get_logger
from src.config.config import global_config, model_config
from src.llm_models.utils_model import LLMRequest
logger = get_logger("memory_activator")
def get_keywords_from_json(json_str) -> List:
def get_keywords_from_json(json_str) -> list:
"""
从JSON字符串中提取关键词列表
@@ -81,7 +80,7 @@ class MemoryActivator:
self.cached_keywords = set() # 用于缓存历史关键词
self.last_memory_query_time = 0 # 上次查询记忆的时间
async def activate_memory_with_chat_history(self, target_message, chat_history_prompt) -> List[Dict]:
async def activate_memory_with_chat_history(self, target_message, chat_history_prompt) -> list[dict]:
"""
激活记忆
"""
@@ -155,7 +154,7 @@ class MemoryActivator:
return self.running_memory
async def _query_unified_memory(self, keywords: List[str], query_text: str) -> List[MemoryResult]:
async def _query_unified_memory(self, keywords: list[str], query_text: str) -> list[MemoryResult]:
"""查询统一记忆系统"""
try:
# 使用记忆系统
@@ -198,7 +197,7 @@ class MemoryActivator:
logger.error(f"查询统一记忆失败: {e}")
return []
async def get_instant_memory(self, target_message: str, chat_id: str) -> Optional[str]:
async def get_instant_memory(self, target_message: str, chat_id: str) -> str | None:
"""
获取即时记忆 - 兼容原有接口(使用统一存储)
"""

View File

@@ -1,25 +1,24 @@
# -*- coding: utf-8 -*-
"""
记忆激活器
记忆系统的激活器组件
"""
import difflib
import orjson
from typing import List, Dict, Optional
from datetime import datetime
import orjson
from json_repair import repair_json
from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config, model_config
from src.common.logger import get_logger
from src.chat.utils.prompt import Prompt, global_prompt_manager
from src.chat.memory_system.memory_manager import MemoryResult
from src.chat.utils.prompt import Prompt, global_prompt_manager
from src.common.logger import get_logger
from src.config.config import global_config, model_config
from src.llm_models.utils_model import LLMRequest
logger = get_logger("memory_activator")
def get_keywords_from_json(json_str) -> List:
def get_keywords_from_json(json_str) -> list:
"""
从JSON字符串中提取关键词列表
@@ -81,7 +80,7 @@ class MemoryActivator:
self.cached_keywords = set() # 用于缓存历史关键词
self.last_memory_query_time = 0 # 上次查询记忆的时间
async def activate_memory_with_chat_history(self, target_message, chat_history_prompt) -> List[Dict]:
async def activate_memory_with_chat_history(self, target_message, chat_history_prompt) -> list[dict]:
"""
激活记忆
"""
@@ -155,7 +154,7 @@ class MemoryActivator:
return self.running_memory
async def _query_unified_memory(self, keywords: List[str], query_text: str) -> List[MemoryResult]:
async def _query_unified_memory(self, keywords: list[str], query_text: str) -> list[MemoryResult]:
"""查询统一记忆系统"""
try:
# 使用记忆系统
@@ -198,7 +197,7 @@ class MemoryActivator:
logger.error(f"查询统一记忆失败: {e}")
return []
async def get_instant_memory(self, target_message: str, chat_id: str) -> Optional[str]:
async def get_instant_memory(self, target_message: str, chat_id: str) -> str | None:
"""
获取即时记忆 - 兼容原有接口(使用统一存储)
"""

View File

@@ -1,4 +1,3 @@
# -*- coding: utf-8 -*-
"""
记忆构建模块
从对话流中提取高质量、结构化记忆单元
@@ -33,19 +32,19 @@ import time
from dataclasses import dataclass
from datetime import datetime
from enum import Enum
from typing import Any, Dict, List, Optional, Union, Type
from typing import Any
import orjson
from src.common.logger import get_logger
from src.llm_models.utils_model import LLMRequest
from src.chat.memory_system.memory_chunk import (
MemoryChunk,
MemoryType,
ConfidenceLevel,
ImportanceLevel,
MemoryChunk,
MemoryType,
create_memory_chunk,
)
from src.common.logger import get_logger
from src.llm_models.utils_model import LLMRequest
logger = get_logger(__name__)
@@ -62,8 +61,8 @@ class ExtractionStrategy(Enum):
class ExtractionResult:
"""提取结果"""
memories: List[MemoryChunk]
confidence_scores: List[float]
memories: list[MemoryChunk]
confidence_scores: list[float]
extraction_time: float
strategy_used: ExtractionStrategy
@@ -85,8 +84,8 @@ class MemoryBuilder:
}
async def build_memories(
self, conversation_text: str, context: Dict[str, Any], user_id: str, timestamp: float
) -> List[MemoryChunk]:
self, conversation_text: str, context: dict[str, Any], user_id: str, timestamp: float
) -> list[MemoryChunk]:
"""从对话中构建记忆"""
start_time = time.time()
@@ -116,8 +115,8 @@ class MemoryBuilder:
raise
async def _extract_with_llm(
self, text: str, context: Dict[str, Any], user_id: str, timestamp: float
) -> List[MemoryChunk]:
self, text: str, context: dict[str, Any], user_id: str, timestamp: float
) -> list[MemoryChunk]:
"""使用LLM提取记忆"""
try:
prompt = self._build_llm_extraction_prompt(text, context)
@@ -135,7 +134,7 @@ class MemoryBuilder:
logger.error(f"LLM提取失败: {e}")
raise MemoryExtractionError(str(e)) from e
def _build_llm_extraction_prompt(self, text: str, context: Dict[str, Any]) -> str:
def _build_llm_extraction_prompt(self, text: str, context: dict[str, Any]) -> str:
"""构建LLM提取提示"""
current_date = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
message_type = context.get("message_type", "normal")
@@ -315,7 +314,7 @@ class MemoryBuilder:
return prompt
def _extract_json_payload(self, response: str) -> Optional[str]:
def _extract_json_payload(self, response: str) -> str | None:
"""从模型响应中提取JSON部分兼容Markdown代码块等格式"""
if not response:
return None
@@ -338,8 +337,8 @@ class MemoryBuilder:
return stripped if stripped.startswith("{") and stripped.endswith("}") else None
def _parse_llm_response(
self, response: str, user_id: str, timestamp: float, context: Dict[str, Any]
) -> List[MemoryChunk]:
self, response: str, user_id: str, timestamp: float, context: dict[str, Any]
) -> list[MemoryChunk]:
"""解析LLM响应"""
if not response:
raise MemoryExtractionError("LLM未返回任何响应")
@@ -385,7 +384,7 @@ class MemoryBuilder:
bot_display = self._clean_subject_text(bot_display)
memories: List[MemoryChunk] = []
memories: list[MemoryChunk] = []
for mem_data in memory_list:
try:
@@ -460,7 +459,7 @@ class MemoryBuilder:
return memories
def _parse_enum_value(self, enum_cls: Type[Enum], raw_value: Any, default: Enum, field_name: str) -> Enum:
def _parse_enum_value(self, enum_cls: type[Enum], raw_value: Any, default: Enum, field_name: str) -> Enum:
"""解析枚举值,兼容数字/字符串表示"""
if isinstance(raw_value, enum_cls):
return raw_value
@@ -514,7 +513,7 @@ class MemoryBuilder:
)
return default
def _collect_bot_identifiers(self, context: Optional[Dict[str, Any]]) -> set[str]:
def _collect_bot_identifiers(self, context: dict[str, Any] | None) -> set[str]:
identifiers: set[str] = {"bot", "机器人", "ai助手"}
if not context:
return identifiers
@@ -540,7 +539,7 @@ class MemoryBuilder:
return identifiers
def _collect_system_identifiers(self, context: Optional[Dict[str, Any]]) -> set[str]:
def _collect_system_identifiers(self, context: dict[str, Any] | None) -> set[str]:
identifiers: set[str] = set()
if not context:
return identifiers
@@ -568,8 +567,8 @@ class MemoryBuilder:
return identifiers
def _resolve_conversation_participants(self, context: Optional[Dict[str, Any]], user_id: str) -> List[str]:
participants: List[str] = []
def _resolve_conversation_participants(self, context: dict[str, Any] | None, user_id: str) -> list[str]:
participants: list[str] = []
if context:
candidate_keys = [
@@ -609,7 +608,7 @@ class MemoryBuilder:
if not participants:
participants = ["对话参与者"]
deduplicated: List[str] = []
deduplicated: list[str] = []
seen = set()
for name in participants:
key = name.lower()
@@ -620,7 +619,7 @@ class MemoryBuilder:
return deduplicated
def _resolve_user_display(self, context: Optional[Dict[str, Any]], user_id: str) -> str:
def _resolve_user_display(self, context: dict[str, Any] | None, user_id: str) -> str:
candidate_keys = [
"user_display_name",
"user_name",
@@ -683,7 +682,7 @@ class MemoryBuilder:
return False
def _split_subject_string(self, value: str) -> List[str]:
def _split_subject_string(self, value: str) -> list[str]:
if not value:
return []
@@ -699,12 +698,12 @@ class MemoryBuilder:
subject: Any,
bot_identifiers: set[str],
system_identifiers: set[str],
default_subjects: List[str],
bot_display: Optional[str] = None,
) -> List[str]:
default_subjects: list[str],
bot_display: str | None = None,
) -> list[str]:
defaults = default_subjects or ["对话参与者"]
raw_candidates: List[str] = []
raw_candidates: list[str] = []
if isinstance(subject, list):
for item in subject:
if isinstance(item, str):
@@ -716,7 +715,7 @@ class MemoryBuilder:
elif subject is not None:
raw_candidates.extend(self._split_subject_string(str(subject)))
normalized: List[str] = []
normalized: list[str] = []
bot_primary = self._clean_subject_text(bot_display or "")
for candidate in raw_candidates:
@@ -741,7 +740,7 @@ class MemoryBuilder:
if not normalized:
normalized = list(defaults)
deduplicated: List[str] = []
deduplicated: list[str] = []
seen = set()
for name in normalized:
key = name.lower()
@@ -752,7 +751,7 @@ class MemoryBuilder:
return deduplicated
def _extract_value_from_object(self, obj: Union[str, Dict[str, Any], List[Any]], keys: List[str]) -> Optional[str]:
def _extract_value_from_object(self, obj: str | dict[str, Any] | list[Any], keys: list[str]) -> str | None:
if isinstance(obj, dict):
for key in keys:
value = obj.get(key)
@@ -773,9 +772,7 @@ class MemoryBuilder:
return obj.strip() or None
return None
def _compose_display_text(
self, subjects: List[str], predicate: str, obj: Union[str, Dict[str, Any], List[Any]]
) -> str:
def _compose_display_text(self, subjects: list[str], predicate: str, obj: str | dict[str, Any] | list[Any]) -> str:
subject_phrase = "".join(subjects) if subjects else "对话参与者"
predicate = (predicate or "").strip()
@@ -841,7 +838,7 @@ class MemoryBuilder:
return f"{subject_phrase}{predicate}".strip()
return subject_phrase
def _validate_and_enhance_memories(self, memories: List[MemoryChunk], context: Dict[str, Any]) -> List[MemoryChunk]:
def _validate_and_enhance_memories(self, memories: list[MemoryChunk], context: dict[str, Any]) -> list[MemoryChunk]:
"""验证和增强记忆"""
validated_memories = []
@@ -876,7 +873,7 @@ class MemoryBuilder:
return True
def _enhance_memory(self, memory: MemoryChunk, context: Dict[str, Any]) -> MemoryChunk:
def _enhance_memory(self, memory: MemoryChunk, context: dict[str, Any]) -> MemoryChunk:
"""增强记忆块"""
# 时间规范化处理
self._normalize_time_in_memory(memory)
@@ -985,7 +982,7 @@ class MemoryBuilder:
total_confidence / self.extraction_stats["successful_extractions"]
)
def get_extraction_stats(self) -> Dict[str, Any]:
def get_extraction_stats(self) -> dict[str, Any]:
"""获取提取统计信息"""
return self.extraction_stats.copy()

View File

@@ -1,18 +1,19 @@
# -*- coding: utf-8 -*-
"""
结构化记忆单元设计
实现高质量、结构化的记忆单元,符合文档设计规范
"""
import hashlib
import time
import uuid
import orjson
from typing import Dict, List, Optional, Any, Union, Iterable
from collections.abc import Iterable
from dataclasses import dataclass, field
from enum import Enum
import hashlib
from typing import Any
import numpy as np
import orjson
from src.common.logger import get_logger
logger = get_logger(__name__)
@@ -56,17 +57,17 @@ class ImportanceLevel(Enum):
class ContentStructure:
"""主谓宾结构,包含自然语言描述"""
subject: Union[str, List[str]]
subject: str | list[str]
predicate: str
object: Union[str, Dict]
object: str | dict
display: str = ""
def to_dict(self) -> Dict[str, Any]:
def to_dict(self) -> dict[str, Any]:
"""转换为字典格式"""
return {"subject": self.subject, "predicate": self.predicate, "object": self.object, "display": self.display}
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "ContentStructure":
def from_dict(cls, data: dict[str, Any]) -> "ContentStructure":
"""从字典创建实例"""
return cls(
subject=data.get("subject", ""),
@@ -75,7 +76,7 @@ class ContentStructure:
display=data.get("display", ""),
)
def to_subject_list(self) -> List[str]:
def to_subject_list(self) -> list[str]:
"""将主语转换为列表形式"""
if isinstance(self.subject, list):
return [s for s in self.subject if isinstance(s, str) and s.strip()]
@@ -99,7 +100,7 @@ class MemoryMetadata:
# 基础信息
memory_id: str # 唯一标识符
user_id: str # 用户ID
chat_id: Optional[str] = None # 聊天ID群聊或私聊
chat_id: str | None = None # 聊天ID群聊或私聊
# 时间信息
created_at: float = 0.0 # 创建时间戳
@@ -124,9 +125,9 @@ class MemoryMetadata:
last_forgetting_check: float = 0.0 # 上次遗忘检查时间
# 来源信息
source_context: Optional[str] = None # 来源上下文片段
source_context: str | None = None # 来源上下文片段
# 兼容旧字段: 一些代码或旧版本可能直接访问 metadata.source
source: Optional[str] = None
source: str | None = None
def __post_init__(self):
"""后初始化处理"""
@@ -209,7 +210,7 @@ class MemoryMetadata:
# 设置最小和最大阈值
return max(7.0, min(threshold, 365.0)) # 7天到1年之间
def should_forget(self, current_time: Optional[float] = None) -> bool:
def should_forget(self, current_time: float | None = None) -> bool:
"""判断是否应该遗忘"""
if current_time is None:
current_time = time.time()
@@ -222,7 +223,7 @@ class MemoryMetadata:
return days_since_activation > self.forgetting_threshold
def is_dormant(self, current_time: Optional[float] = None, inactive_days: int = 90) -> bool:
def is_dormant(self, current_time: float | None = None, inactive_days: int = 90) -> bool:
"""判断是否处于休眠状态(长期未激活)"""
if current_time is None:
current_time = time.time()
@@ -230,7 +231,7 @@ class MemoryMetadata:
days_since_last_access = (current_time - self.last_accessed) / 86400
return days_since_last_access > inactive_days
def to_dict(self) -> Dict[str, Any]:
def to_dict(self) -> dict[str, Any]:
"""转换为字典格式"""
return {
"memory_id": self.memory_id,
@@ -252,7 +253,7 @@ class MemoryMetadata:
}
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "MemoryMetadata":
def from_dict(cls, data: dict[str, Any]) -> "MemoryMetadata":
"""从字典创建实例"""
return cls(
memory_id=data.get("memory_id", ""),
@@ -286,17 +287,17 @@ class MemoryChunk:
memory_type: MemoryType # 记忆类型
# 扩展信息
keywords: List[str] = field(default_factory=list) # 关键词列表
tags: List[str] = field(default_factory=list) # 标签列表
categories: List[str] = field(default_factory=list) # 分类列表
keywords: list[str] = field(default_factory=list) # 关键词列表
tags: list[str] = field(default_factory=list) # 标签列表
categories: list[str] = field(default_factory=list) # 分类列表
# 语义信息
embedding: Optional[List[float]] = None # 语义向量
semantic_hash: Optional[str] = None # 语义哈希值
embedding: list[float] | None = None # 语义向量
semantic_hash: str | None = None # 语义哈希值
# 关联信息
related_memories: List[str] = field(default_factory=list) # 关联记忆ID列表
temporal_context: Optional[Dict[str, Any]] = None # 时间上下文
related_memories: list[str] = field(default_factory=list) # 关联记忆ID列表
temporal_context: dict[str, Any] | None = None # 时间上下文
def __post_init__(self):
"""后初始化处理"""
@@ -310,7 +311,7 @@ class MemoryChunk:
try:
# 使用向量和内容生成稳定的哈希
content_str = f"{self.content.subject}:{self.content.predicate}:{str(self.content.object)}"
content_str = f"{self.content.subject}:{self.content.predicate}:{self.content.object!s}"
embedding_str = ",".join(map(str, [round(x, 6) for x in self.embedding]))
hash_input = f"{content_str}|{embedding_str}"
@@ -342,7 +343,7 @@ class MemoryChunk:
return self.content.display or str(self.content)
@property
def subjects(self) -> List[str]:
def subjects(self) -> list[str]:
"""获取主语列表"""
return self.content.to_subject_list()
@@ -354,11 +355,11 @@ class MemoryChunk:
"""更新相关度评分"""
self.metadata.update_relevance(new_score)
def should_forget(self, current_time: Optional[float] = None) -> bool:
def should_forget(self, current_time: float | None = None) -> bool:
"""判断是否应该遗忘"""
return self.metadata.should_forget(current_time)
def is_dormant(self, current_time: Optional[float] = None, inactive_days: int = 90) -> bool:
def is_dormant(self, current_time: float | None = None, inactive_days: int = 90) -> bool:
"""判断是否处于休眠状态(长期未激活)"""
return self.metadata.is_dormant(current_time, inactive_days)
@@ -386,7 +387,7 @@ class MemoryChunk:
if memory_id and memory_id not in self.related_memories:
self.related_memories.append(memory_id)
def set_embedding(self, embedding: List[float]):
def set_embedding(self, embedding: list[float]):
"""设置语义向量"""
self.embedding = embedding
self._generate_semantic_hash()
@@ -415,7 +416,7 @@ class MemoryChunk:
logger.warning(f"计算记忆相似度失败: {e}")
return 0.0
def to_dict(self) -> Dict[str, Any]:
def to_dict(self) -> dict[str, Any]:
"""转换为完整的字典格式"""
return {
"metadata": self.metadata.to_dict(),
@@ -431,7 +432,7 @@ class MemoryChunk:
}
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "MemoryChunk":
def from_dict(cls, data: dict[str, Any]) -> "MemoryChunk":
"""从字典创建实例"""
metadata = MemoryMetadata.from_dict(data.get("metadata", {}))
content = ContentStructure.from_dict(data.get("content", {}))
@@ -541,7 +542,7 @@ class MemoryChunk:
return f"MemoryChunk(id={self.memory_id[:8]}..., type={self.memory_type.value}, user={self.user_id})"
def _build_display_text(subjects: Iterable[str], predicate: str, obj: Union[str, Dict]) -> str:
def _build_display_text(subjects: Iterable[str], predicate: str, obj: str | dict) -> str:
"""根据主谓宾生成自然语言描述"""
subjects_clean = [s.strip() for s in subjects if s and isinstance(s, str)]
subject_part = "".join(subjects_clean) if subjects_clean else "对话参与者"
@@ -569,15 +570,15 @@ def _build_display_text(subjects: Iterable[str], predicate: str, obj: Union[str,
def create_memory_chunk(
user_id: str,
subject: Union[str, List[str]],
subject: str | list[str],
predicate: str,
obj: Union[str, Dict],
obj: str | dict,
memory_type: MemoryType,
chat_id: Optional[str] = None,
source_context: Optional[str] = None,
chat_id: str | None = None,
source_context: str | None = None,
importance: ImportanceLevel = ImportanceLevel.NORMAL,
confidence: ConfidenceLevel = ConfidenceLevel.MEDIUM,
display: Optional[str] = None,
display: str | None = None,
**kwargs,
) -> MemoryChunk:
"""便捷的内存块创建函数"""
@@ -593,10 +594,10 @@ def create_memory_chunk(
source_context=source_context,
)
subjects: List[str]
subjects: list[str]
if isinstance(subject, list):
subjects = [s for s in subject if isinstance(s, str) and s.strip()]
subject_payload: Union[str, List[str]] = subjects
subject_payload: str | list[str] = subjects
else:
cleaned = subject.strip() if isinstance(subject, str) else ""
subjects = [cleaned] if cleaned else []

View File

@@ -1,17 +1,15 @@
# -*- coding: utf-8 -*-
"""
智能记忆遗忘引擎
基于重要程度、置信度和激活频率的智能遗忘机制
"""
import time
import asyncio
from typing import List, Dict, Optional, Tuple
from datetime import datetime
import time
from dataclasses import dataclass
from datetime import datetime
from src.chat.memory_system.memory_chunk import ConfidenceLevel, ImportanceLevel, MemoryChunk
from src.common.logger import get_logger
from src.chat.memory_system.memory_chunk import MemoryChunk, ImportanceLevel, ConfidenceLevel
logger = get_logger(__name__)
@@ -65,7 +63,7 @@ class ForgettingConfig:
class MemoryForgettingEngine:
"""智能记忆遗忘引擎"""
def __init__(self, config: Optional[ForgettingConfig] = None):
def __init__(self, config: ForgettingConfig | None = None):
self.config = config or ForgettingConfig()
self.stats = ForgettingStats()
self._last_forgetting_check = 0.0
@@ -116,7 +114,7 @@ class MemoryForgettingEngine:
# 确保在合理范围内
return max(self.config.min_forgetting_days, min(threshold, self.config.max_forgetting_days))
def should_forget_memory(self, memory: MemoryChunk, current_time: Optional[float] = None) -> bool:
def should_forget_memory(self, memory: MemoryChunk, current_time: float | None = None) -> bool:
"""
判断记忆是否应该被遗忘
@@ -155,7 +153,7 @@ class MemoryForgettingEngine:
return should_forget
def is_dormant_memory(self, memory: MemoryChunk, current_time: Optional[float] = None) -> bool:
def is_dormant_memory(self, memory: MemoryChunk, current_time: float | None = None) -> bool:
"""
判断记忆是否处于休眠状态
@@ -168,7 +166,7 @@ class MemoryForgettingEngine:
"""
return memory.is_dormant(current_time, self.config.dormant_threshold_days)
def should_force_forget_dormant(self, memory: MemoryChunk, current_time: Optional[float] = None) -> bool:
def should_force_forget_dormant(self, memory: MemoryChunk, current_time: float | None = None) -> bool:
"""
判断是否应该强制遗忘休眠记忆
@@ -189,7 +187,7 @@ class MemoryForgettingEngine:
days_since_last_access = (current_time - memory.metadata.last_accessed) / 86400
return days_since_last_access > self.config.force_forget_dormant_days
async def check_memories_for_forgetting(self, memories: List[MemoryChunk]) -> Tuple[List[str], List[str]]:
async def check_memories_for_forgetting(self, memories: list[MemoryChunk]) -> tuple[list[str], list[str]]:
"""
检查记忆列表,识别需要遗忘的记忆
@@ -241,7 +239,7 @@ class MemoryForgettingEngine:
return normal_forgetting_ids, force_forgetting_ids
async def perform_forgetting_check(self, memories: List[MemoryChunk]) -> Dict[str, any]:
async def perform_forgetting_check(self, memories: list[MemoryChunk]) -> dict[str, any]:
"""
执行完整的遗忘检查流程
@@ -314,7 +312,7 @@ class MemoryForgettingEngine:
except Exception as e:
logger.error(f"定期遗忘检查失败: {e}", exc_info=True)
def get_forgetting_stats(self) -> Dict[str, any]:
def get_forgetting_stats(self) -> dict[str, any]:
"""获取遗忘统计信息"""
return {
"total_checked": self.stats.total_checked,

View File

@@ -1,16 +1,14 @@
# -*- coding: utf-8 -*-
"""
记忆融合与去重机制
避免记忆碎片化,确保长期记忆库的高质量
"""
import time
from typing import Dict, List, Optional, Tuple, Any
from dataclasses import dataclass
from typing import Any
from src.chat.memory_system.memory_chunk import ConfidenceLevel, ImportanceLevel, MemoryChunk
from src.common.logger import get_logger
from src.chat.memory_system.memory_chunk import MemoryChunk, ConfidenceLevel, ImportanceLevel
logger = get_logger(__name__)
@@ -22,9 +20,9 @@ class FusionResult:
original_count: int
fused_count: int
removed_duplicates: int
merged_memories: List[MemoryChunk]
merged_memories: list[MemoryChunk]
fusion_time: float
details: List[str]
details: list[str]
@dataclass
@@ -32,9 +30,9 @@ class DuplicateGroup:
"""重复记忆组"""
group_id: str
memories: List[MemoryChunk]
similarity_matrix: List[List[float]]
representative_memory: Optional[MemoryChunk] = None
memories: list[MemoryChunk]
similarity_matrix: list[list[float]]
representative_memory: MemoryChunk | None = None
class MemoryFusionEngine:
@@ -59,8 +57,8 @@ class MemoryFusionEngine:
}
async def fuse_memories(
self, new_memories: List[MemoryChunk], existing_memories: Optional[List[MemoryChunk]] = None
) -> List[MemoryChunk]:
self, new_memories: list[MemoryChunk], existing_memories: list[MemoryChunk] | None = None
) -> list[MemoryChunk]:
"""融合记忆列表"""
start_time = time.time()
@@ -106,8 +104,8 @@ class MemoryFusionEngine:
return new_memories # 失败时返回原始记忆
async def _detect_duplicate_groups(
self, new_memories: List[MemoryChunk], existing_memories: List[MemoryChunk]
) -> List[DuplicateGroup]:
self, new_memories: list[MemoryChunk], existing_memories: list[MemoryChunk]
) -> list[DuplicateGroup]:
"""检测重复记忆组"""
all_memories = new_memories + existing_memories
new_memory_ids = {memory.memory_id for memory in new_memories}
@@ -212,7 +210,7 @@ class MemoryFusionEngine:
jaccard_similarity = len(intersection) / len(union)
return jaccard_similarity
def _calculate_keyword_similarity(self, keywords1: List[str], keywords2: List[str]) -> float:
def _calculate_keyword_similarity(self, keywords1: list[str], keywords2: list[str]) -> float:
"""计算关键词相似度"""
if not keywords1 or not keywords2:
return 0.0
@@ -302,7 +300,7 @@ class MemoryFusionEngine:
return best_memory
async def _fuse_memory_group(self, group: DuplicateGroup) -> Optional[MemoryChunk]:
async def _fuse_memory_group(self, group: DuplicateGroup) -> MemoryChunk | None:
"""融合记忆组"""
if not group.memories:
return None
@@ -328,7 +326,7 @@ class MemoryFusionEngine:
# 返回置信度最高的记忆
return max(group.memories, key=lambda m: m.metadata.confidence.value)
async def _merge_memory_attributes(self, base_memory: MemoryChunk, memories: List[MemoryChunk]) -> MemoryChunk:
async def _merge_memory_attributes(self, base_memory: MemoryChunk, memories: list[MemoryChunk]) -> MemoryChunk:
"""合并记忆属性"""
# 创建基础记忆的深拷贝
fused_memory = MemoryChunk.from_dict(base_memory.to_dict())
@@ -395,7 +393,7 @@ class MemoryFusionEngine:
source_ids = [m.memory_id[:8] for m in group.memories]
fused_memory.metadata.source_context = f"Fused from {len(group.memories)} memories: {', '.join(source_ids)}"
def _merge_temporal_context(self, memories: List[MemoryChunk]) -> Dict[str, Any]:
def _merge_temporal_context(self, memories: list[MemoryChunk]) -> dict[str, Any]:
"""合并时间上下文"""
contexts = [m.temporal_context for m in memories if m.temporal_context]
@@ -426,8 +424,8 @@ class MemoryFusionEngine:
return merged_context
async def incremental_fusion(
self, new_memory: MemoryChunk, existing_memories: List[MemoryChunk]
) -> Tuple[MemoryChunk, List[MemoryChunk]]:
self, new_memory: MemoryChunk, existing_memories: list[MemoryChunk]
) -> tuple[MemoryChunk, list[MemoryChunk]]:
"""增量融合(单个新记忆与现有记忆融合)"""
# 寻找相似记忆
similar_memories = []
@@ -493,7 +491,7 @@ class MemoryFusionEngine:
except Exception as e:
logger.error(f"❌ 记忆融合引擎维护失败: {e}", exc_info=True)
def get_fusion_stats(self) -> Dict[str, Any]:
def get_fusion_stats(self) -> dict[str, Any]:
"""获取融合统计信息"""
return self.fusion_stats.copy()

View File

@@ -1,17 +1,15 @@
# -*- coding: utf-8 -*-
"""
记忆系统管理器
替代原有的 Hippocampus 和 instant_memory 系统
"""
import re
from typing import Dict, List, Optional, Any, Tuple
from dataclasses import dataclass
from typing import Any
from src.common.logger import get_logger
from src.chat.memory_system.memory_system import MemorySystem
from src.chat.memory_system.memory_chunk import MemoryChunk, MemoryType
from src.chat.memory_system.memory_system import initialize_memory_system
from src.chat.memory_system.memory_system import MemorySystem, initialize_memory_system
from src.common.logger import get_logger
logger = get_logger(__name__)
@@ -27,14 +25,14 @@ class MemoryResult:
timestamp: float
source: str = "memory"
relevance_score: float = 0.0
structure: Dict[str, Any] | None = None
structure: dict[str, Any] | None = None
class MemoryManager:
"""记忆系统管理器 - 替代原有的 HippocampusManager"""
def __init__(self):
self.memory_system: Optional[MemorySystem] = None
self.memory_system: MemorySystem | None = None
self.is_initialized = False
self.user_cache = {} # 用户记忆缓存
@@ -63,8 +61,8 @@ class MemoryManager:
logger.info("正在初始化记忆系统...")
# 获取LLM模型
from src.llm_models.utils_model import LLMRequest
from src.config.config import model_config
from src.llm_models.utils_model import LLMRequest
llm_model = LLMRequest(model_set=model_config.model_task_config.utils, request_type="memory")
@@ -121,7 +119,7 @@ class MemoryManager:
max_memory_length: int = 2,
time_weight: float = 1.0,
keyword_weight: float = 1.0,
) -> List[Tuple[str, str]]:
) -> list[tuple[str, str]]:
"""从文本获取相关记忆 - 兼容原有接口"""
if not self.is_initialized or not self.memory_system:
return []
@@ -152,8 +150,8 @@ class MemoryManager:
return []
async def get_memory_from_topic(
self, valid_keywords: List[str], max_memory_num: int = 3, max_memory_length: int = 2, max_depth: int = 3
) -> List[Tuple[str, str]]:
self, valid_keywords: list[str], max_memory_num: int = 3, max_memory_length: int = 2, max_depth: int = 3
) -> list[tuple[str, str]]:
"""从关键词获取记忆 - 兼容原有接口"""
if not self.is_initialized or not self.memory_system:
return []
@@ -208,8 +206,8 @@ class MemoryManager:
return []
async def process_conversation(
self, conversation_text: str, context: Dict[str, Any], user_id: str, timestamp: Optional[float] = None
) -> List[MemoryChunk]:
self, conversation_text: str, context: dict[str, Any], user_id: str, timestamp: float | None = None
) -> list[MemoryChunk]:
"""处理对话并构建记忆 - 新增功能"""
if not self.is_initialized or not self.memory_system:
return []
@@ -235,8 +233,8 @@ class MemoryManager:
return []
async def get_enhanced_memory_context(
self, query_text: str, user_id: str, context: Optional[Dict[str, Any]] = None, limit: int = 5
) -> List[MemoryResult]:
self, query_text: str, user_id: str, context: dict[str, Any] | None = None, limit: int = 5
) -> list[MemoryResult]:
"""获取增强记忆上下文 - 新增功能"""
if not self.is_initialized or not self.memory_system:
return []
@@ -267,7 +265,7 @@ class MemoryManager:
logger.error(f"get_enhanced_memory_context 失败: {e}")
return []
def _format_memory_chunk(self, memory: MemoryChunk) -> Tuple[str, Dict[str, Any]]:
def _format_memory_chunk(self, memory: MemoryChunk) -> tuple[str, dict[str, Any]]:
"""将记忆块转换为更易读的文本描述"""
structure = memory.content.to_dict()
if memory.display:
@@ -289,7 +287,7 @@ class MemoryManager:
return formatted, structure
def _format_subject(self, subject: Optional[str], memory: MemoryChunk) -> str:
def _format_subject(self, subject: str | None, memory: MemoryChunk) -> str:
if not subject:
return "该用户"
@@ -299,7 +297,7 @@ class MemoryManager:
return "该聊天"
return self._clean_text(subject)
def _apply_predicate_format(self, subject: str, predicate: str, obj: Any) -> Optional[str]:
def _apply_predicate_format(self, subject: str, predicate: str, obj: Any) -> str | None:
predicate = (predicate or "").strip()
obj_value = obj
@@ -446,10 +444,10 @@ class MemoryManager:
text = self._truncate(str(obj).strip())
return self._clean_text(text)
def _extract_from_object(self, obj: Any, keys: List[str]) -> Optional[str]:
def _extract_from_object(self, obj: Any, keys: list[str]) -> str | None:
if isinstance(obj, dict):
for key in keys:
if key in obj and obj[key]:
if obj.get(key):
value = obj[key]
if isinstance(value, (dict, list)):
return self._clean_text(self._format_object(value))

View File

@@ -1,15 +1,15 @@
# -*- coding: utf-8 -*-
"""
记忆元数据索引管理器
使用JSON文件存储记忆元数据支持快速模糊搜索和过滤
"""
import orjson
import threading
from pathlib import Path
from typing import Dict, List, Optional, Set, Any
from dataclasses import dataclass, asdict
from dataclasses import asdict, dataclass
from datetime import datetime
from pathlib import Path
from typing import Any
import orjson
from src.common.logger import get_logger
@@ -25,10 +25,10 @@ class MemoryMetadataIndexEntry:
# 分类信息
memory_type: str # MemoryType.value
subjects: List[str] # 主语列表
objects: List[str] # 宾语列表
keywords: List[str] # 关键词列表
tags: List[str] # 标签列表
subjects: list[str] # 主语列表
objects: list[str] # 宾语列表
keywords: list[str] # 关键词列表
tags: list[str] # 标签列表
# 数值字段(用于范围过滤)
importance: int # ImportanceLevel.value (1-4)
@@ -37,8 +37,8 @@ class MemoryMetadataIndexEntry:
access_count: int # 访问次数
# 可选字段
chat_id: Optional[str] = None
content_preview: Optional[str] = None # 内容预览前100字符
chat_id: str | None = None
content_preview: str | None = None # 内容预览前100字符
class MemoryMetadataIndex:
@@ -46,13 +46,13 @@ class MemoryMetadataIndex:
def __init__(self, index_file: str = "data/memory_metadata_index.json"):
self.index_file = Path(index_file)
self.index: Dict[str, MemoryMetadataIndexEntry] = {} # memory_id -> entry
self.index: dict[str, MemoryMetadataIndexEntry] = {} # memory_id -> entry
# 倒排索引(用于快速查找)
self.type_index: Dict[str, Set[str]] = {} # type -> {memory_ids}
self.subject_index: Dict[str, Set[str]] = {} # subject -> {memory_ids}
self.keyword_index: Dict[str, Set[str]] = {} # keyword -> {memory_ids}
self.tag_index: Dict[str, Set[str]] = {} # tag -> {memory_ids}
self.type_index: dict[str, set[str]] = {} # type -> {memory_ids}
self.subject_index: dict[str, set[str]] = {} # subject -> {memory_ids}
self.keyword_index: dict[str, set[str]] = {} # keyword -> {memory_ids}
self.tag_index: dict[str, set[str]] = {} # tag -> {memory_ids}
self.lock = threading.RLock()
@@ -178,7 +178,7 @@ class MemoryMetadataIndex:
self._remove_from_inverted_indices(memory_id)
del self.index[memory_id]
def batch_add_or_update(self, entries: List[MemoryMetadataIndexEntry]):
def batch_add_or_update(self, entries: list[MemoryMetadataIndexEntry]):
"""批量添加或更新"""
with self.lock:
for entry in entries:
@@ -191,18 +191,18 @@ class MemoryMetadataIndex:
def search(
self,
memory_types: Optional[List[str]] = None,
subjects: Optional[List[str]] = None,
keywords: Optional[List[str]] = None,
tags: Optional[List[str]] = None,
importance_min: Optional[int] = None,
importance_max: Optional[int] = None,
created_after: Optional[float] = None,
created_before: Optional[float] = None,
user_id: Optional[str] = None,
limit: Optional[int] = None,
memory_types: list[str] | None = None,
subjects: list[str] | None = None,
keywords: list[str] | None = None,
tags: list[str] | None = None,
importance_min: int | None = None,
importance_max: int | None = None,
created_after: float | None = None,
created_before: float | None = None,
user_id: str | None = None,
limit: int | None = None,
flexible_mode: bool = True, # 新增:灵活匹配模式
) -> List[str]:
) -> list[str]:
"""
搜索符合条件的记忆ID列表支持模糊匹配
@@ -237,14 +237,14 @@ class MemoryMetadataIndex:
def _search_flexible(
self,
memory_types: Optional[List[str]] = None,
subjects: Optional[List[str]] = None,
created_after: Optional[float] = None,
created_before: Optional[float] = None,
user_id: Optional[str] = None,
limit: Optional[int] = None,
memory_types: list[str] | None = None,
subjects: list[str] | None = None,
created_after: float | None = None,
created_before: float | None = None,
user_id: str | None = None,
limit: int | None = None,
**kwargs, # 接受但不使用的参数
) -> List[str]:
) -> list[str]:
"""
灵活搜索模式2/4项匹配即可支持部分匹配
@@ -374,20 +374,20 @@ class MemoryMetadataIndex:
def _search_strict(
self,
memory_types: Optional[List[str]] = None,
subjects: Optional[List[str]] = None,
keywords: Optional[List[str]] = None,
tags: Optional[List[str]] = None,
importance_min: Optional[int] = None,
importance_max: Optional[int] = None,
created_after: Optional[float] = None,
created_before: Optional[float] = None,
user_id: Optional[str] = None,
limit: Optional[int] = None,
) -> List[str]:
memory_types: list[str] | None = None,
subjects: list[str] | None = None,
keywords: list[str] | None = None,
tags: list[str] | None = None,
importance_min: int | None = None,
importance_max: int | None = None,
created_after: float | None = None,
created_before: float | None = None,
user_id: str | None = None,
limit: int | None = None,
) -> list[str]:
"""严格搜索模式(原有逻辑)"""
# 初始候选集(所有记忆)
candidate_ids: Optional[Set[str]] = None
candidate_ids: set[str] | None = None
# 用户过滤(必选)
if user_id:
@@ -471,11 +471,11 @@ class MemoryMetadataIndex:
return result_ids
def get_entry(self, memory_id: str) -> Optional[MemoryMetadataIndexEntry]:
def get_entry(self, memory_id: str) -> MemoryMetadataIndexEntry | None:
"""获取单个索引条目"""
return self.index.get(memory_id)
def get_stats(self) -> Dict[str, Any]:
def get_stats(self) -> dict[str, Any]:
"""获取索引统计信息"""
with self.lock:
return {

View File

@@ -1,11 +1,10 @@
# -*- coding: utf-8 -*-
"""记忆检索查询规划器"""
from __future__ import annotations
import re
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional
from typing import Any
import orjson
@@ -21,16 +20,16 @@ class MemoryQueryPlan:
"""查询规划结果"""
semantic_query: str
memory_types: List[MemoryType] = field(default_factory=list)
subject_includes: List[str] = field(default_factory=list)
object_includes: List[str] = field(default_factory=list)
required_keywords: List[str] = field(default_factory=list)
optional_keywords: List[str] = field(default_factory=list)
owner_filters: List[str] = field(default_factory=list)
memory_types: list[MemoryType] = field(default_factory=list)
subject_includes: list[str] = field(default_factory=list)
object_includes: list[str] = field(default_factory=list)
required_keywords: list[str] = field(default_factory=list)
optional_keywords: list[str] = field(default_factory=list)
owner_filters: list[str] = field(default_factory=list)
recency_preference: str = "any"
limit: int = 10
emphasis: Optional[str] = None
raw_plan: Dict[str, Any] = field(default_factory=dict)
emphasis: str | None = None
raw_plan: dict[str, Any] = field(default_factory=dict)
def ensure_defaults(self, fallback_query: str, default_limit: int) -> None:
if not self.semantic_query:
@@ -46,11 +45,11 @@ class MemoryQueryPlan:
class MemoryQueryPlanner:
"""基于小模型的记忆检索查询规划器"""
def __init__(self, planner_model: Optional[LLMRequest], default_limit: int = 10):
def __init__(self, planner_model: LLMRequest | None, default_limit: int = 10):
self.model = planner_model
self.default_limit = default_limit
async def plan_query(self, query_text: str, context: Dict[str, Any]) -> MemoryQueryPlan:
async def plan_query(self, query_text: str, context: dict[str, Any]) -> MemoryQueryPlan:
if not self.model:
logger.debug("未提供查询规划模型,使用默认规划")
return self._default_plan(query_text)
@@ -82,10 +81,10 @@ class MemoryQueryPlanner:
def _default_plan(self, query_text: str) -> MemoryQueryPlan:
return MemoryQueryPlan(semantic_query=query_text, limit=self.default_limit)
def _parse_plan_dict(self, data: Dict[str, Any], fallback_query: str) -> MemoryQueryPlan:
def _parse_plan_dict(self, data: dict[str, Any], fallback_query: str) -> MemoryQueryPlan:
semantic_query = self._safe_str(data.get("semantic_query")) or fallback_query
def _collect_list(key: str) -> List[str]:
def _collect_list(key: str) -> list[str]:
value = data.get(key)
if isinstance(value, str):
return [value]
@@ -94,7 +93,7 @@ class MemoryQueryPlanner:
return []
memory_type_values = _collect_list("memory_types")
memory_types: List[MemoryType] = []
memory_types: list[MemoryType] = []
for item in memory_type_values:
if not item:
continue
@@ -123,7 +122,7 @@ class MemoryQueryPlanner:
)
return plan
def _build_prompt(self, query_text: str, context: Dict[str, Any]) -> str:
def _build_prompt(self, query_text: str, context: dict[str, Any]) -> str:
participants = context.get("participants") or context.get("speaker_names") or []
if isinstance(participants, str):
participants = [participants]
@@ -206,7 +205,7 @@ class MemoryQueryPlanner:
请直接输出符合要求的 JSON 对象,禁止添加额外文本或 Markdown 代码块。
"""
def _extract_json_payload(self, response: str) -> Optional[str]:
def _extract_json_payload(self, response: str) -> str | None:
if not response:
return None

View File

@@ -1,4 +1,3 @@
# -*- coding: utf-8 -*-
"""
精准记忆系统核心模块
1. 基于文档设计的高效记忆构建、存储与召回优化系统,覆盖构建、向量化与多阶段检索全流程。
@@ -6,26 +5,27 @@
"""
import asyncio
import time
import orjson
import re
import hashlib
from typing import Dict, List, Optional, Set, Any, TYPE_CHECKING
import re
import time
from dataclasses import asdict, dataclass
from datetime import datetime, timedelta
from dataclasses import dataclass, asdict
from enum import Enum
from typing import TYPE_CHECKING, Any
import orjson
from src.common.logger import get_logger
from src.llm_models.utils_model import LLMRequest
from src.config.config import model_config, global_config
from src.chat.memory_system.memory_chunk import MemoryChunk
from src.chat.memory_system.memory_builder import MemoryBuilder, MemoryExtractionError
from src.chat.memory_system.memory_chunk import MemoryChunk
from src.chat.memory_system.memory_fusion import MemoryFusionEngine
from src.chat.memory_system.memory_query_planner import MemoryQueryPlanner
from src.common.logger import get_logger
from src.config.config import global_config, model_config
from src.llm_models.utils_model import LLMRequest
if TYPE_CHECKING:
from src.common.data_models.database_data_model import DatabaseMessages
from src.chat.memory_system.memory_forgetting_engine import MemoryForgettingEngine
from src.common.data_models.database_data_model import DatabaseMessages
logger = get_logger(__name__)
@@ -121,7 +121,7 @@ class MemorySystemConfig:
class MemorySystem:
"""精准记忆系统核心类"""
def __init__(self, llm_model: Optional[LLMRequest] = None, config: Optional[MemorySystemConfig] = None):
def __init__(self, llm_model: LLMRequest | None = None, config: MemorySystemConfig | None = None):
self.config = config or MemorySystemConfig.from_global_config()
self.llm_model = llm_model
self.status = MemorySystemStatus.INITIALIZING
@@ -131,7 +131,7 @@ class MemorySystem:
self.fusion_engine: MemoryFusionEngine = None
self.unified_storage = None # 统一存储系统
self.query_planner: MemoryQueryPlanner = None
self.forgetting_engine: Optional[MemoryForgettingEngine] = None
self.forgetting_engine: MemoryForgettingEngine | None = None
# LLM模型
self.value_assessment_model: LLMRequest = None
@@ -143,10 +143,10 @@ class MemorySystem:
self.last_retrieval_time = None
# 构建节流记录
self._last_memory_build_times: Dict[str, float] = {}
self._last_memory_build_times: dict[str, float] = {}
# 记忆指纹缓存,用于快速检测重复记忆
self._memory_fingerprints: Dict[str, str] = {}
self._memory_fingerprints: dict[str, str] = {}
logger.info("MemorySystem 初始化开始")
@@ -210,7 +210,7 @@ class MemorySystem:
raise
# 初始化遗忘引擎
from src.chat.memory_system.memory_forgetting_engine import MemoryForgettingEngine, ForgettingConfig
from src.chat.memory_system.memory_forgetting_engine import ForgettingConfig, MemoryForgettingEngine
# 从全局配置创建遗忘引擎配置
forgetting_config = ForgettingConfig(
@@ -241,7 +241,7 @@ class MemorySystem:
self.forgetting_engine = MemoryForgettingEngine(forgetting_config)
planner_task_config = getattr(model_config.model_task_config, "utils_small", None)
planner_model: Optional[LLMRequest] = None
planner_model: LLMRequest | None = None
try:
planner_model = LLMRequest(model_set=planner_task_config, request_type="memory.query_planner")
except Exception as planner_exc:
@@ -261,8 +261,8 @@ class MemorySystem:
raise
async def retrieve_memories_for_building(
self, query_text: str, user_id: Optional[str] = None, context: Optional[Dict[str, Any]] = None, limit: int = 5
) -> List[MemoryChunk]:
self, query_text: str, user_id: str | None = None, context: dict[str, Any] | None = None, limit: int = 5
) -> list[MemoryChunk]:
"""在构建记忆时检索相关记忆,使用统一存储系统
Args:
@@ -302,8 +302,8 @@ class MemorySystem:
return []
async def build_memory_from_conversation(
self, conversation_text: str, context: Dict[str, Any], timestamp: Optional[float] = None
) -> List[MemoryChunk]:
self, conversation_text: str, context: dict[str, Any], timestamp: float | None = None
) -> list[MemoryChunk]:
"""从对话中构建记忆
Args:
@@ -318,8 +318,8 @@ class MemorySystem:
self.status = MemorySystemStatus.BUILDING
start_time = time.time()
build_scope_key: Optional[str] = None
build_marker_time: Optional[float] = None
build_scope_key: str | None = None
build_marker_time: float | None = None
try:
normalized_context = self._normalize_context(context, GLOBAL_MEMORY_SCOPE, timestamp)
@@ -408,7 +408,7 @@ class MemorySystem:
logger.error(f"❌ 记忆构建失败: {e}", exc_info=True)
raise
def _log_memory_preview(self, memories: List[MemoryChunk]) -> None:
def _log_memory_preview(self, memories: list[MemoryChunk]) -> None:
"""在控制台输出记忆预览,便于人工检查"""
if not memories:
logger.info("📝 本次未生成新的记忆")
@@ -425,12 +425,12 @@ class MemorySystem:
f"置信度={memory.metadata.confidence.name} | 内容={text}"
)
async def _collect_fusion_candidates(self, new_memories: List[MemoryChunk]) -> List[MemoryChunk]:
async def _collect_fusion_candidates(self, new_memories: list[MemoryChunk]) -> list[MemoryChunk]:
"""收集与新记忆相似的现有记忆,便于融合去重"""
if not new_memories:
return []
candidate_ids: Set[str] = set()
candidate_ids: set[str] = set()
new_memory_ids = {memory.memory_id for memory in new_memories if memory and getattr(memory, "memory_id", None)}
# 基于指纹的直接匹配
@@ -493,7 +493,7 @@ class MemorySystem:
continue
candidate_ids.add(memory_id)
existing_candidates: List[MemoryChunk] = []
existing_candidates: list[MemoryChunk] = []
cache = self.unified_storage.memory_cache if self.unified_storage else {}
for candidate_id in candidate_ids:
if candidate_id in new_memory_ids:
@@ -511,7 +511,7 @@ class MemorySystem:
return existing_candidates
async def process_conversation_memory(self, context: Dict[str, Any]) -> Dict[str, Any]:
async def process_conversation_memory(self, context: dict[str, Any]) -> dict[str, Any]:
"""对外暴露的对话记忆处理接口,仅依赖上下文信息"""
start_time = time.time()
@@ -559,12 +559,12 @@ class MemorySystem:
async def retrieve_relevant_memories(
self,
query_text: Optional[str] = None,
user_id: Optional[str] = None,
context: Optional[Dict[str, Any]] = None,
query_text: str | None = None,
user_id: str | None = None,
context: dict[str, Any] | None = None,
limit: int = 5,
**kwargs,
) -> List[MemoryChunk]:
) -> list[MemoryChunk]:
"""检索相关记忆(三阶段召回:元数据粗筛 → 向量精筛 → 综合重排)"""
raw_query = query_text or kwargs.get("query")
if not raw_query:
@@ -750,7 +750,7 @@ class MemorySystem:
raise
@staticmethod
def _extract_json_payload(response: str) -> Optional[str]:
def _extract_json_payload(response: str) -> str | None:
"""从模型响应中提取JSON部分兼容Markdown代码块等格式"""
if not response:
return None
@@ -773,10 +773,10 @@ class MemorySystem:
return stripped if stripped.startswith("{") and stripped.endswith("}") else None
def _normalize_context(
self, raw_context: Optional[Dict[str, Any]], user_id: Optional[str], timestamp: Optional[float]
) -> Dict[str, Any]:
self, raw_context: dict[str, Any] | None, user_id: str | None, timestamp: float | None
) -> dict[str, Any]:
"""标准化上下文,确保必备字段存在且格式正确"""
context: Dict[str, Any] = {}
context: dict[str, Any] = {}
if raw_context:
try:
context = dict(raw_context)
@@ -822,7 +822,7 @@ class MemorySystem:
return context
async def _build_enhanced_query_context(self, raw_query: str, normalized_context: Dict[str, Any]) -> Dict[str, Any]:
async def _build_enhanced_query_context(self, raw_query: str, normalized_context: dict[str, Any]) -> dict[str, Any]:
"""构建包含未读消息综合上下文的增强查询上下文
Args:
@@ -861,7 +861,7 @@ class MemorySystem:
return enhanced_context
async def _collect_unread_messages_context(self, stream_id: str) -> Optional[Dict[str, Any]]:
async def _collect_unread_messages_context(self, stream_id: str) -> dict[str, Any] | None:
"""收集未读消息的综合上下文信息
Args:
@@ -953,7 +953,7 @@ class MemorySystem:
logger.warning(f"收集未读消息上下文失败: {e}", exc_info=True)
return None
def _build_unread_context_summary(self, messages_summary: List[Dict[str, Any]]) -> str:
def _build_unread_context_summary(self, messages_summary: list[dict[str, Any]]) -> str:
"""构建未读消息的文本摘要
Args:
@@ -974,7 +974,7 @@ class MemorySystem:
return " | ".join(summary_parts)
async def _resolve_conversation_context(self, fallback_text: str, context: Optional[Dict[str, Any]]) -> str:
async def _resolve_conversation_context(self, fallback_text: str, context: dict[str, Any] | None) -> str:
"""使用 stream_id 历史消息和相关记忆充实对话文本,默认回退到传入文本"""
if not context:
return fallback_text
@@ -1043,11 +1043,11 @@ class MemorySystem:
# 回退到传入文本
return fallback_text
def _get_build_scope_key(self, context: Dict[str, Any], user_id: Optional[str]) -> Optional[str]:
def _get_build_scope_key(self, context: dict[str, Any], user_id: str | None) -> str | None:
"""确定用于节流控制的记忆构建作用域"""
return "global_scope"
def _determine_history_limit(self, context: Dict[str, Any]) -> int:
def _determine_history_limit(self, context: dict[str, Any]) -> int:
"""确定历史消息获取数量限制在30-50之间"""
default_limit = 40
candidate = context.get("history_limit") or context.get("history_window") or context.get("memory_history_limit")
@@ -1065,12 +1065,12 @@ class MemorySystem:
return history_limit
def _format_history_messages(self, messages: List["DatabaseMessages"]) -> Optional[str]:
def _format_history_messages(self, messages: list["DatabaseMessages"]) -> str | None:
"""将历史消息格式化为可供LLM处理的多轮对话文本"""
if not messages:
return None
lines: List[str] = []
lines: list[str] = []
for msg in messages:
try:
content = getattr(msg, "processed_plain_text", None) or getattr(msg, "display_message", None)
@@ -1105,7 +1105,7 @@ class MemorySystem:
return "\n".join(lines) if lines else None
async def _assess_information_value(self, text: str, context: Dict[str, Any]) -> float:
async def _assess_information_value(self, text: str, context: dict[str, Any]) -> float:
"""评估信息价值
Args:
@@ -1201,7 +1201,7 @@ class MemorySystem:
logger.error(f"信息价值评估失败: {e}", exc_info=True)
return 0.5 # 默认中等价值
async def _store_memories_unified(self, memory_chunks: List[MemoryChunk]) -> int:
async def _store_memories_unified(self, memory_chunks: list[MemoryChunk]) -> int:
"""使用统一存储系统存储记忆块"""
if not memory_chunks or not self.unified_storage:
return 0
@@ -1222,7 +1222,7 @@ class MemorySystem:
return 0
# 保留原有方法以兼容旧代码
async def _store_memories(self, memory_chunks: List[MemoryChunk]) -> int:
async def _store_memories(self, memory_chunks: list[MemoryChunk]) -> int:
"""兼容性方法:重定向到统一存储"""
return await self._store_memories_unified(memory_chunks)
@@ -1271,7 +1271,7 @@ class MemorySystem:
key = self._fingerprint_key(memory.user_id, fingerprint)
self._memory_fingerprints[key] = memory.memory_id
def _register_memory_fingerprints(self, memories: List[MemoryChunk]) -> None:
def _register_memory_fingerprints(self, memories: list[MemoryChunk]) -> None:
for memory in memories:
fingerprint = self._build_memory_fingerprint(memory)
key = self._fingerprint_key(memory.user_id, fingerprint)
@@ -1302,9 +1302,9 @@ class MemorySystem:
@staticmethod
def _fingerprint_key(user_id: str, fingerprint: str) -> str:
return f"{str(user_id)}:{fingerprint}"
return f"{user_id!s}:{fingerprint}"
def get_system_stats(self) -> Dict[str, Any]:
def get_system_stats(self) -> dict[str, Any]:
"""获取系统统计信息"""
return {
"status": self.status.value,
@@ -1314,7 +1314,7 @@ class MemorySystem:
"config": asdict(self.config),
}
def _compute_memory_score(self, query_text: str, memory: MemoryChunk, context: Dict[str, Any]) -> float:
def _compute_memory_score(self, query_text: str, memory: MemoryChunk, context: dict[str, Any]) -> float:
"""根据查询和上下文为记忆计算匹配分数"""
tokens_query = self._tokenize_text(query_text)
tokens_memory = self._tokenize_text(memory.text_content)
@@ -1338,7 +1338,7 @@ class MemorySystem:
final_score = base_score * 0.7 + keyword_overlap * 0.15 + importance_boost + confidence_boost
return max(0.0, min(1.0, final_score))
def _tokenize_text(self, text: str) -> Set[str]:
def _tokenize_text(self, text: str) -> set[str]:
"""简单分词,兼容中英文"""
if not text:
return set()
@@ -1450,7 +1450,7 @@ def get_memory_system() -> MemorySystem:
return memory_system
async def initialize_memory_system(llm_model: Optional[LLMRequest] = None):
async def initialize_memory_system(llm_model: LLMRequest | None = None):
"""初始化全局记忆系统"""
global memory_system
if memory_system is None:

View File

@@ -1,4 +1,3 @@
# -*- coding: utf-8 -*-
"""
基于Vector DB的统一记忆存储系统 V2
使用ChromaDB作为底层存储替代JSON存储方式
@@ -11,20 +10,21 @@
- 自动清理过期记忆
"""
import time
import orjson
import asyncio
import threading
from typing import Dict, List, Optional, Tuple, Any
import time
from dataclasses import dataclass
from datetime import datetime
from typing import Any
from src.common.logger import get_logger
from src.common.vector_db import vector_db_service
from src.chat.utils.utils import get_embedding
from src.chat.memory_system.memory_chunk import MemoryChunk, ConfidenceLevel, ImportanceLevel
import orjson
from src.chat.memory_system.memory_chunk import ConfidenceLevel, ImportanceLevel, MemoryChunk
from src.chat.memory_system.memory_forgetting_engine import MemoryForgettingEngine
from src.chat.memory_system.memory_metadata_index import MemoryMetadataIndex, MemoryMetadataIndexEntry
from src.chat.utils.utils import get_embedding
from src.common.logger import get_logger
from src.common.vector_db import vector_db_service
logger = get_logger(__name__)
@@ -32,7 +32,7 @@ logger = get_logger(__name__)
_ENUM_MAPPINGS_CACHE = {}
def _build_enum_mapping(enum_class: type) -> Dict[str, Any]:
def _build_enum_mapping(enum_class: type) -> dict[str, Any]:
"""构建枚举类的完整映射表
Args:
@@ -145,7 +145,7 @@ class VectorMemoryStorage:
"""基于Vector DB的记忆存储系统"""
def __init__(self, config: Optional[VectorStorageConfig] = None):
def __init__(self, config: VectorStorageConfig | None = None):
# 默认从全局配置读取如果没有传入config
if config is None:
try:
@@ -163,15 +163,15 @@ class VectorMemoryStorage:
self.vector_db_service = vector_db_service
# 内存缓存
self.memory_cache: Dict[str, MemoryChunk] = {}
self.cache_timestamps: Dict[str, float] = {}
self.memory_cache: dict[str, MemoryChunk] = {}
self.cache_timestamps: dict[str, float] = {}
self._cache = self.memory_cache # 别名,兼容旧代码
# 元数据索引管理器JSON文件索引
self.metadata_index = MemoryMetadataIndex()
# 遗忘引擎
self.forgetting_engine: Optional[MemoryForgettingEngine] = None
self.forgetting_engine: MemoryForgettingEngine | None = None
if self.config.enable_forgetting:
self.forgetting_engine = MemoryForgettingEngine()
@@ -267,7 +267,7 @@ class VectorMemoryStorage:
except Exception as e:
logger.error(f"自动清理失败: {e}")
def _memory_to_vector_format(self, memory: MemoryChunk) -> Dict[str, Any]:
def _memory_to_vector_format(self, memory: MemoryChunk) -> dict[str, Any]:
"""将MemoryChunk转换为向量存储格式"""
try:
# 获取memory_id
@@ -323,7 +323,7 @@ class VectorMemoryStorage:
logger.error(f"转换记忆 {memory_id} 到向量格式失败: {e}", exc_info=True)
raise
def _vector_result_to_memory(self, document: str, metadata: Dict[str, Any]) -> Optional[MemoryChunk]:
def _vector_result_to_memory(self, document: str, metadata: dict[str, Any]) -> MemoryChunk | None:
"""将Vector DB结果转换为MemoryChunk"""
try:
# 从元数据中恢复完整记忆
@@ -440,7 +440,7 @@ class VectorMemoryStorage:
logger.warning(f"不支持的{enum_class.__name__}值类型: {type(value)},使用默认值")
return default
def _get_from_cache(self, memory_id: str) -> Optional[MemoryChunk]:
def _get_from_cache(self, memory_id: str) -> MemoryChunk | None:
"""从缓存获取记忆"""
if not self.config.enable_caching:
return None
@@ -472,7 +472,7 @@ class VectorMemoryStorage:
self.memory_cache[memory_id] = memory
self.cache_timestamps[memory_id] = time.time()
async def store_memories(self, memories: List[MemoryChunk]) -> int:
async def store_memories(self, memories: list[MemoryChunk]) -> int:
"""批量存储记忆"""
if not memories:
return 0
@@ -603,11 +603,11 @@ class VectorMemoryStorage:
self,
query_text: str,
limit: int = 10,
similarity_threshold: Optional[float] = None,
filters: Optional[Dict[str, Any]] = None,
similarity_threshold: float | None = None,
filters: dict[str, Any] | None = None,
# 新增元数据过滤参数用于JSON索引粗筛
metadata_filters: Optional[Dict[str, Any]] = None,
) -> List[Tuple[MemoryChunk, float]]:
metadata_filters: dict[str, Any] | None = None,
) -> list[tuple[MemoryChunk, float]]:
"""
搜索相似记忆(混合索引模式)
@@ -632,7 +632,7 @@ class VectorMemoryStorage:
try:
# === 阶段一JSON元数据粗筛可选 ===
candidate_ids: Optional[List[str]] = None
candidate_ids: list[str] | None = None
if metadata_filters:
logger.debug(f"[JSON元数据粗筛] 开始,过滤条件: {metadata_filters}")
candidate_ids = self.metadata_index.search(
@@ -746,7 +746,7 @@ class VectorMemoryStorage:
logger.error(f"搜索相似记忆失败: {e}")
return []
async def get_memory_by_id(self, memory_id: str) -> Optional[MemoryChunk]:
async def get_memory_by_id(self, memory_id: str) -> MemoryChunk | None:
"""根据ID获取记忆"""
# 首先尝试从缓存获取
memory = self._get_from_cache(memory_id)
@@ -772,7 +772,7 @@ class VectorMemoryStorage:
return None
async def get_memories_by_filters(self, filters: Dict[str, Any], limit: int = 100) -> List[MemoryChunk]:
async def get_memories_by_filters(self, filters: dict[str, Any], limit: int = 100) -> list[MemoryChunk]:
"""根据过滤条件获取记忆"""
try:
results = vector_db_service.get(collection_name=self.config.memory_collection, where=filters, limit=limit)
@@ -848,7 +848,7 @@ class VectorMemoryStorage:
logger.error(f"删除记忆 {memory_id} 失败: {e}")
return False
async def delete_memories_by_filters(self, filters: Dict[str, Any]) -> int:
async def delete_memories_by_filters(self, filters: dict[str, Any]) -> int:
"""根据过滤条件批量删除记忆"""
try:
# 先获取要删除的记忆ID
@@ -880,7 +880,7 @@ class VectorMemoryStorage:
logger.error(f"批量删除记忆失败: {e}")
return 0
async def perform_forgetting_check(self) -> Dict[str, Any]:
async def perform_forgetting_check(self) -> dict[str, Any]:
"""执行遗忘检查"""
if not self.forgetting_engine:
return {"error": "遗忘引擎未启用"}
@@ -925,7 +925,7 @@ class VectorMemoryStorage:
logger.error(f"执行遗忘检查失败: {e}")
return {"error": str(e)}
def get_storage_stats(self) -> Dict[str, Any]:
def get_storage_stats(self) -> dict[str, Any]:
"""获取存储统计信息"""
try:
current_total = vector_db_service.count(self.config.memory_collection)
@@ -960,7 +960,7 @@ class VectorMemoryStorage:
_global_vector_storage = None
def get_vector_memory_storage(config: Optional[VectorStorageConfig] = None) -> VectorMemoryStorage:
def get_vector_memory_storage(config: VectorStorageConfig | None = None) -> VectorMemoryStorage:
"""获取全局Vector记忆存储实例"""
global _global_vector_storage
@@ -974,15 +974,15 @@ def get_vector_memory_storage(config: Optional[VectorStorageConfig] = None) -> V
class VectorMemoryStorageAdapter:
"""适配器类提供与原UnifiedMemoryStorage兼容的接口"""
def __init__(self, config: Optional[VectorStorageConfig] = None):
def __init__(self, config: VectorStorageConfig | None = None):
self.storage = VectorMemoryStorage(config)
async def store_memories(self, memories: List[MemoryChunk]) -> int:
async def store_memories(self, memories: list[MemoryChunk]) -> int:
return await self.storage.store_memories(memories)
async def search_similar_memories(
self, query_text: str, limit: int = 10, scope_id: Optional[str] = None, filters: Optional[Dict[str, Any]] = None
) -> List[Tuple[str, float]]:
self, query_text: str, limit: int = 10, scope_id: str | None = None, filters: dict[str, Any] | None = None
) -> list[tuple[str, float]]:
results = await self.storage.search_similar_memories(query_text, limit, filters=filters)
# 转换为原格式:(memory_id, similarity)
return [
@@ -990,7 +990,7 @@ class VectorMemoryStorageAdapter:
for memory, similarity in results
]
def get_stats(self) -> Dict[str, Any]:
def get_stats(self) -> dict[str, Any]:
return self.storage.get_storage_stats()

View File

@@ -3,14 +3,14 @@
提供统一的消息管理、上下文管理和流循环调度功能
"""
from .message_manager import MessageManager, message_manager
from .context_manager import SingleStreamContextManager
from .distribution_manager import StreamLoopManager, stream_loop_manager
from .message_manager import MessageManager, message_manager
__all__ = [
"MessageManager",
"message_manager",
"SingleStreamContextManager",
"StreamLoopManager",
"message_manager",
"stream_loop_manager",
]

View File

@@ -6,13 +6,14 @@
import asyncio
import time
from typing import Dict, List, Optional, Any
from typing import Any
from src.chat.energy_system import energy_manager
from src.common.data_models.database_data_model import DatabaseMessages
from src.common.data_models.message_manager_data_model import StreamContext
from src.common.logger import get_logger
from src.config.config import global_config
from src.common.data_models.database_data_model import DatabaseMessages
from src.chat.energy_system import energy_manager
from .distribution_manager import stream_loop_manager
logger = get_logger("context_manager")
@@ -21,7 +22,7 @@ logger = get_logger("context_manager")
class SingleStreamContextManager:
"""单流上下文管理器 - 每个实例只管理一个 stream 的上下文"""
def __init__(self, stream_id: str, context: StreamContext, max_context_size: Optional[int] = None):
def __init__(self, stream_id: str, context: StreamContext, max_context_size: int | None = None):
self.stream_id = stream_id
self.context = context
@@ -66,7 +67,7 @@ class SingleStreamContextManager:
logger.error(f"添加消息到单流上下文失败 {self.stream_id}: {e}", exc_info=True)
return False
async def update_message(self, message_id: str, updates: Dict[str, Any]) -> bool:
async def update_message(self, message_id: str, updates: dict[str, Any]) -> bool:
"""更新上下文中的消息
Args:
@@ -84,7 +85,7 @@ class SingleStreamContextManager:
logger.error(f"更新单流上下文消息失败 {self.stream_id}/{message_id}: {e}", exc_info=True)
return False
def get_messages(self, limit: Optional[int] = None, include_unread: bool = True) -> List[DatabaseMessages]:
def get_messages(self, limit: int | None = None, include_unread: bool = True) -> list[DatabaseMessages]:
"""获取上下文消息
Args:
@@ -117,7 +118,7 @@ class SingleStreamContextManager:
logger.error(f"获取单流上下文消息失败 {self.stream_id}: {e}", exc_info=True)
return []
def get_unread_messages(self) -> List[DatabaseMessages]:
def get_unread_messages(self) -> list[DatabaseMessages]:
"""获取未读消息"""
try:
return self.context.get_unread_messages()
@@ -125,7 +126,7 @@ class SingleStreamContextManager:
logger.error(f"获取单流未读消息失败 {self.stream_id}: {e}", exc_info=True)
return []
def mark_messages_as_read(self, message_ids: List[str]) -> bool:
def mark_messages_as_read(self, message_ids: list[str]) -> bool:
"""标记消息为已读"""
try:
if not hasattr(self.context, "mark_message_as_read"):
@@ -168,7 +169,7 @@ class SingleStreamContextManager:
logger.error(f"清空单流上下文失败 {self.stream_id}: {e}", exc_info=True)
return False
def get_statistics(self) -> Dict[str, Any]:
def get_statistics(self) -> dict[str, Any]:
"""获取流统计信息"""
try:
current_time = time.time()
@@ -285,7 +286,7 @@ class SingleStreamContextManager:
logger.error(f"添加消息到单流上下文失败 (async) {self.stream_id}: {e}", exc_info=True)
return False
async def update_message_async(self, message_id: str, updates: Dict[str, Any]) -> bool:
async def update_message_async(self, message_id: str, updates: dict[str, Any]) -> bool:
"""异步实现的 update_message更新消息并在需要时 await 能量更新。"""
try:
self.context.update_message_info(message_id, **updates)
@@ -327,7 +328,7 @@ class SingleStreamContextManager:
"""更新流能量"""
try:
history_messages = self.context.get_history_messages(limit=self.max_context_size)
messages: List[DatabaseMessages] = list(history_messages)
messages: list[DatabaseMessages] = list(history_messages)
if include_unread:
messages.extend(self.get_unread_messages())

View File

@@ -5,12 +5,12 @@
import asyncio
import time
from typing import Dict, Optional, Any
from typing import Any
from src.chat.chatter_manager import ChatterManager
from src.chat.energy_system import energy_manager
from src.common.logger import get_logger
from src.config.config import global_config
from src.chat.energy_system import energy_manager
from src.chat.chatter_manager import ChatterManager
from src.plugin_system.apis.chat_api import get_chat_manager
logger = get_logger("stream_loop_manager")
@@ -19,13 +19,13 @@ logger = get_logger("stream_loop_manager")
class StreamLoopManager:
"""流循环管理器 - 每个流一个独立的无限循环任务"""
def __init__(self, max_concurrent_streams: Optional[int] = None):
def __init__(self, max_concurrent_streams: int | None = None):
# 流循环任务管理
self.stream_loops: Dict[str, asyncio.Task] = {}
self.stream_loops: dict[str, asyncio.Task] = {}
self.loop_lock = asyncio.Lock()
# 统计信息
self.stats: Dict[str, Any] = {
self.stats: dict[str, Any] = {
"active_streams": 0,
"total_loops": 0,
"total_process_cycles": 0,
@@ -37,13 +37,13 @@ class StreamLoopManager:
self.max_concurrent_streams = max_concurrent_streams or global_config.chat.max_concurrent_distributions
# 强制分发策略
self.force_dispatch_unread_threshold: Optional[int] = getattr(
self.force_dispatch_unread_threshold: int | None = getattr(
global_config.chat, "force_dispatch_unread_threshold", 20
)
self.force_dispatch_min_interval: float = getattr(global_config.chat, "force_dispatch_min_interval", 0.1)
# Chatter管理器
self.chatter_manager: Optional[ChatterManager] = None
self.chatter_manager: ChatterManager | None = None
# 状态控制
self.is_running = False
@@ -212,7 +212,7 @@ class StreamLoopManager:
logger.info(f"流循环结束: {stream_id}")
async def _get_stream_context(self, stream_id: str) -> Optional[Any]:
async def _get_stream_context(self, stream_id: str) -> Any | None:
"""获取流上下文
Args:
@@ -320,7 +320,7 @@ class StreamLoopManager:
logger.debug(f"{stream_id} 使用默认间隔: {base_interval:.2f}s ({e})")
return base_interval
def get_queue_status(self) -> Dict[str, Any]:
def get_queue_status(self) -> dict[str, Any]:
"""获取队列状态
Returns:
@@ -374,14 +374,14 @@ class StreamLoopManager:
except Exception:
return 0
def _needs_force_dispatch_for_context(self, context: Any, unread_count: Optional[int] = None) -> bool:
def _needs_force_dispatch_for_context(self, context: Any, unread_count: int | None = None) -> bool:
if not self.force_dispatch_unread_threshold or self.force_dispatch_unread_threshold <= 0:
return False
count = unread_count if unread_count is not None else self._get_unread_count(context)
return count > self.force_dispatch_unread_threshold
def get_performance_summary(self) -> Dict[str, Any]:
def get_performance_summary(self) -> dict[str, Any]:
"""获取性能摘要
Returns:

View File

@@ -6,19 +6,20 @@
import asyncio
import random
import time
from typing import Dict, Optional, Any, TYPE_CHECKING, List
from typing import TYPE_CHECKING, Any
from src.chat.chatter_manager import ChatterManager
from src.chat.message_receive.chat_stream import ChatStream
from src.common.logger import get_logger
from src.chat.planner_actions.action_manager import ChatterActionManager
from src.common.data_models.database_data_model import DatabaseMessages
from src.common.data_models.message_manager_data_model import MessageManagerStats, StreamStats
from src.chat.chatter_manager import ChatterManager
from src.chat.planner_actions.action_manager import ChatterActionManager
from .sleep_manager.sleep_manager import SleepManager
from .sleep_manager.wakeup_manager import WakeUpManager
from src.common.logger import get_logger
from src.config.config import global_config
from src.plugin_system.apis.chat_api import get_chat_manager
from .distribution_manager import stream_loop_manager
from .sleep_manager.sleep_manager import SleepManager
from .sleep_manager.wakeup_manager import WakeUpManager
if TYPE_CHECKING:
pass
@@ -32,7 +33,7 @@ class MessageManager:
def __init__(self, check_interval: float = 5.0):
self.check_interval = check_interval # 检查间隔(秒)
self.is_running = False
self.manager_task: Optional[asyncio.Task] = None
self.manager_task: asyncio.Task | None = None
# 统计信息
self.stats = MessageManagerStats()
@@ -125,7 +126,7 @@ class MessageManager:
except Exception as e:
logger.error(f"更新消息 {message_id} 时发生错误: {e}")
async def bulk_update_messages(self, stream_id: str, updates: List[Dict[str, Any]]) -> int:
async def bulk_update_messages(self, stream_id: str, updates: list[dict[str, Any]]) -> int:
"""批量更新消息信息,降低更新频率"""
if not updates:
return 0
@@ -214,7 +215,7 @@ class MessageManager:
except Exception as e:
logger.error(f"激活聊天流 {stream_id} 时发生错误: {e}")
def get_stream_stats(self, stream_id: str) -> Optional[StreamStats]:
def get_stream_stats(self, stream_id: str) -> StreamStats | None:
"""获取聊天流统计"""
try:
# 通过 ChatManager 获取 ChatStream
@@ -243,7 +244,7 @@ class MessageManager:
logger.error(f"获取聊天流 {stream_id} 统计时发生错误: {e}")
return None
def get_manager_stats(self) -> Dict[str, Any]:
def get_manager_stats(self) -> dict[str, Any]:
"""获取管理器统计"""
return {
"total_streams": self.stats.total_streams,
@@ -278,7 +279,7 @@ class MessageManager:
except Exception as e:
logger.error(f"清理不活跃聊天流时发生错误: {e}")
async def _check_and_handle_interruption(self, chat_stream: Optional[ChatStream] = None):
async def _check_and_handle_interruption(self, chat_stream: ChatStream | None = None):
"""检查并处理消息打断"""
if not global_config.chat.interruption_enabled:
return

View File

@@ -1,12 +1,13 @@
import asyncio
import random
from datetime import datetime, timedelta
from typing import Optional, TYPE_CHECKING
from typing import TYPE_CHECKING, Optional
from src.common.logger import get_logger
from src.config.config import global_config
from .notification_sender import NotificationSender
from .sleep_state import SleepState, SleepContext
from .sleep_state import SleepContext, SleepState
from .time_checker import TimeChecker
if TYPE_CHECKING:
@@ -92,7 +93,7 @@ class SleepManager:
elif current_state == SleepState.WOKEN_UP:
self._handle_woken_up(now, is_in_theoretical_sleep, wakeup_manager)
def _handle_awake_to_sleep(self, now: datetime, activity: Optional[str], wakeup_manager: Optional["WakeUpManager"]):
def _handle_awake_to_sleep(self, now: datetime, activity: str | None, wakeup_manager: Optional["WakeUpManager"]):
"""处理从“清醒”到“准备入睡”的状态转换。"""
if activity:
logger.info(f"进入理论休眠时间 '{activity}',开始进行睡眠决策...")
@@ -181,7 +182,7 @@ class SleepManager:
self,
now: datetime,
is_in_theoretical_sleep: bool,
activity: Optional[str],
activity: str | None,
wakeup_manager: Optional["WakeUpManager"],
):
"""处理“正在睡觉”状态下的逻辑。"""

View File

@@ -1,6 +1,5 @@
from datetime import date, datetime
from enum import Enum, auto
from datetime import datetime, date
from typing import Optional
from src.common.logger import get_logger
from src.manager.local_store_manager import local_storage
@@ -29,10 +28,10 @@ class SleepContext:
def __init__(self):
"""初始化睡眠上下文,并从本地存储加载初始状态。"""
self.current_state: SleepState = SleepState.AWAKE
self.sleep_buffer_end_time: Optional[datetime] = None
self.sleep_buffer_end_time: datetime | None = None
self.total_delayed_minutes_today: float = 0.0
self.last_sleep_check_date: Optional[date] = None
self.re_sleep_attempt_time: Optional[datetime] = None
self.last_sleep_check_date: date | None = None
self.re_sleep_attempt_time: datetime | None = None
self.load()
def save(self):

View File

@@ -1,6 +1,6 @@
from datetime import datetime, time, timedelta
from typing import Optional, List, Dict, Any
import random
from datetime import datetime, time, timedelta
from typing import Any
from src.common.logger import get_logger
from src.config.config import global_config
@@ -37,11 +37,11 @@ class TimeChecker:
return self._daily_sleep_offset, self._daily_wake_offset
@staticmethod
def get_today_schedule() -> Optional[List[Dict[str, Any]]]:
def get_today_schedule() -> list[dict[str, Any]] | None:
"""从全局 ScheduleManager 获取今天的日程安排。"""
return schedule_manager.today_schedule
def is_in_theoretical_sleep_time(self, now_time: time) -> tuple[bool, Optional[str]]:
def is_in_theoretical_sleep_time(self, now_time: time) -> tuple[bool, str | None]:
if global_config.sleep_system.sleep_by_schedule:
if self.get_today_schedule():
return self._is_in_schedule_sleep_time(now_time)
@@ -50,7 +50,7 @@ class TimeChecker:
else:
return self._is_in_sleep_time(now_time)
def _is_in_schedule_sleep_time(self, now_time: time) -> tuple[bool, Optional[str]]:
def _is_in_schedule_sleep_time(self, now_time: time) -> tuple[bool, str | None]:
"""检查当前时间是否落在日程表的任何一个睡眠活动中"""
sleep_keywords = ["休眠", "睡觉", "梦乡"]
today_schedule = self.get_today_schedule()
@@ -79,7 +79,7 @@ class TimeChecker:
continue
return False, None
def _is_in_sleep_time(self, now_time: time) -> tuple[bool, Optional[str]]:
def _is_in_sleep_time(self, now_time: time) -> tuple[bool, str | None]:
"""检查当前时间是否在固定的睡眠时间内(应用偏移量)"""
try:
start_time_str = global_config.sleep_system.fixed_sleep_time

View File

@@ -1,9 +1,10 @@
import asyncio
import time
from typing import Optional, TYPE_CHECKING
from typing import TYPE_CHECKING
from src.chat.message_manager.sleep_manager.wakeup_context import WakeUpContext
from src.common.logger import get_logger
from src.config.config import global_config
from src.chat.message_manager.sleep_manager.wakeup_context import WakeUpContext
if TYPE_CHECKING:
from .sleep_manager import SleepManager
@@ -27,9 +28,9 @@ class WakeUpManager:
"""
self.sleep_manager = sleep_manager
self.context = WakeUpContext() # 使用新的上下文管理器
self.angry_chat_id: Optional[str] = None
self.angry_chat_id: str | None = None
self.last_decay_time = time.time()
self._decay_task: Optional[asyncio.Task] = None
self._decay_task: asyncio.Task | None = None
self.is_running = False
self.last_log_time = 0
self.log_interval = 30
@@ -104,9 +105,7 @@ class WakeUpManager:
logger.debug(f"唤醒度衰减: {old_value:.1f} -> {self.context.wakeup_value:.1f}")
self.context.save()
def add_wakeup_value(
self, is_private_chat: bool, is_mentioned: bool = False, chat_id: Optional[str] = None
) -> bool:
def add_wakeup_value(self, is_private_chat: bool, is_mentioned: bool = False, chat_id: str | None = None) -> bool:
"""
增加唤醒度值

View File

@@ -2,9 +2,8 @@ from src.chat.emoji_system.emoji_manager import get_emoji_manager
from src.chat.message_receive.chat_stream import get_chat_manager
from src.chat.message_receive.storage import MessageStorage
__all__ = [
"get_emoji_manager",
"get_chat_manager",
"MessageStorage",
"get_chat_manager",
"get_emoji_manager",
]

View File

@@ -1,25 +1,24 @@
import traceback
import os
import re
import traceback
from typing import Any
from typing import Dict, Any, Optional
from maim_message import UserInfo
from src.common.logger import get_logger
from src.config.config import global_config
from src.mood.mood_manager import mood_manager # 导入情绪管理器
from src.chat.message_receive.chat_stream import get_chat_manager, ChatStream
from src.chat.message_receive.message import MessageRecv, MessageRecvS4U
from src.chat.message_receive.storage import MessageStorage
from src.chat.message_manager import message_manager
from src.chat.utils.prompt import Prompt, global_prompt_manager
from src.plugin_system.core import component_registry, event_manager, global_announcement_manager
from src.plugin_system.base import BaseCommand, EventType
from src.mais4u.mais4u_chat.s4u_msg_processor import S4UMessageProcessor
from src.chat.utils.utils import is_mentioned_bot_in_message
# 导入反注入系统
from src.chat.antipromptinjector import initialize_anti_injector
from src.chat.message_manager import message_manager
from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager
from src.chat.message_receive.message import MessageRecv, MessageRecvS4U
from src.chat.message_receive.storage import MessageStorage
from src.chat.utils.prompt import Prompt, global_prompt_manager
from src.chat.utils.utils import is_mentioned_bot_in_message
from src.common.logger import get_logger
from src.config.config import global_config
from src.mais4u.mais4u_chat.s4u_msg_processor import S4UMessageProcessor
from src.mood.mood_manager import mood_manager # 导入情绪管理器
from src.plugin_system.base import BaseCommand, EventType
from src.plugin_system.core import component_registry, event_manager, global_announcement_manager
# 获取项目根目录假设本文件在src/chat/message_receive/下,根目录为上上上级目录)
PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../.."))
@@ -219,7 +218,7 @@ class ChatBot:
logger.error(traceback.format_exc())
try:
await plus_command_instance.send_text(f"命令执行出错: {str(e)}")
await plus_command_instance.send_text(f"命令执行出错: {e!s}")
except Exception as send_error:
logger.error(f"发送错误消息失败: {send_error}")
@@ -286,7 +285,7 @@ class ChatBot:
logger.error(traceback.format_exc())
try:
await command_instance.send_text(f"命令执行出错: {str(e)}")
await command_instance.send_text(f"命令执行出错: {e!s}")
except Exception as send_error:
logger.error(f"发送错误消息失败: {send_error}")
@@ -338,7 +337,7 @@ class ChatBot:
except Exception as e:
logger.error(f"处理适配器响应时出错: {e}")
async def do_s4u(self, message_data: Dict[str, Any]):
async def do_s4u(self, message_data: dict[str, Any]):
message = MessageRecvS4U(message_data)
group_info = message.message_info.group_info
user_info = message.message_info.user_info
@@ -359,7 +358,7 @@ class ChatBot:
return
async def message_process(self, message_data: Dict[str, Any]) -> None:
async def message_process(self, message_data: dict[str, Any]) -> None:
"""处理转化后的统一格式消息"""
try:
# 首先处理可能的切片消息重组
@@ -458,7 +457,7 @@ class ChatBot:
# TODO:暂不可用
# 确认从接口发来的message是否有自定义的prompt模板信息
if message.message_info.template_info and not message.message_info.template_info.template_default:
template_group_name: Optional[str] = message.message_info.template_info.template_name # type: ignore
template_group_name: str | None = message.message_info.template_info.template_name # type: ignore
template_items = message.message_info.template_info.template_items
async with global_prompt_manager.async_message_scope(template_group_name):
if isinstance(template_items, dict):

View File

@@ -1,17 +1,18 @@
import asyncio
import copy
import hashlib
import time
import copy
from typing import Dict, Optional, TYPE_CHECKING
from rich.traceback import install
from maim_message import GroupInfo, UserInfo
from typing import TYPE_CHECKING
from src.common.logger import get_logger
from maim_message import GroupInfo, UserInfo
from rich.traceback import install
from sqlalchemy import select
from sqlalchemy.dialects.sqlite import insert as sqlite_insert
from sqlalchemy.dialects.mysql import insert as mysql_insert
from src.common.database.sqlalchemy_models import ChatStreams # 新增导入
from sqlalchemy.dialects.sqlite import insert as sqlite_insert
from src.common.database.sqlalchemy_database_api import get_db_session
from src.common.database.sqlalchemy_models import ChatStreams # 新增导入
from src.common.logger import get_logger
from src.config.config import global_config # 新增导入
# 避免循环导入使用TYPE_CHECKING进行类型提示
@@ -33,8 +34,8 @@ class ChatStream:
stream_id: str,
platform: str,
user_info: UserInfo,
group_info: Optional[GroupInfo] = None,
data: Optional[dict] = None,
group_info: GroupInfo | None = None,
data: dict | None = None,
):
self.stream_id = stream_id
self.platform = platform
@@ -47,7 +48,7 @@ class ChatStream:
# 使用StreamContext替代ChatMessageContext
from src.common.data_models.message_manager_data_model import StreamContext
from src.plugin_system.base.component_types import ChatType, ChatMode
from src.plugin_system.base.component_types import ChatMode, ChatType
# 创建StreamContext
self.stream_context: StreamContext = StreamContext(
@@ -133,11 +134,11 @@ class ChatStream:
# 恢复stream_context信息
if "stream_context_chat_type" in data:
from src.plugin_system.base.component_types import ChatType, ChatMode
from src.plugin_system.base.component_types import ChatMode, ChatType
instance.stream_context.chat_type = ChatType(data["stream_context_chat_type"])
if "stream_context_chat_mode" in data:
from src.plugin_system.base.component_types import ChatType, ChatMode
from src.plugin_system.base.component_types import ChatMode, ChatType
instance.stream_context.chat_mode = ChatMode(data["stream_context_chat_mode"])
@@ -163,9 +164,10 @@ class ChatStream:
def set_context(self, message: "MessageRecv"):
"""设置聊天消息上下文"""
# 将MessageRecv转换为DatabaseMessages并设置到stream_context
from src.common.data_models.database_data_model import DatabaseMessages
import json
from src.common.data_models.database_data_model import DatabaseMessages
# 安全获取message_info中的数据
message_info = getattr(message, "message_info", {})
user_info = getattr(message_info, "user_info", {})
@@ -248,7 +250,7 @@ class ChatStream:
f"interest_value: {db_message.interest_value}"
)
def _safe_get_actions(self, message: "MessageRecv") -> Optional[list]:
def _safe_get_actions(self, message: "MessageRecv") -> list | None:
"""安全获取消息的actions字段"""
try:
actions = getattr(message, "actions", None)
@@ -278,7 +280,7 @@ class ChatStream:
logger.warning(f"获取actions字段失败: {e}")
return None
def _extract_reply_from_segment(self, segment) -> Optional[str]:
def _extract_reply_from_segment(self, segment) -> str | None:
"""从消息段中提取reply_to信息"""
try:
if hasattr(segment, "type") and segment.type == "seglist":
@@ -391,8 +393,8 @@ class ChatManager:
def __init__(self):
if not self._initialized:
self.streams: Dict[str, ChatStream] = {} # stream_id -> ChatStream
self.last_messages: Dict[str, "MessageRecv"] = {} # stream_id -> last_message
self.streams: dict[str, ChatStream] = {} # stream_id -> ChatStream
self.last_messages: dict[str, "MessageRecv"] = {} # stream_id -> last_message
# try:
# async with get_db_session() as session:
# db.connect(reuse_if_open=True)
@@ -414,7 +416,7 @@ class ChatManager:
await self.load_all_streams()
logger.info(f"聊天管理器已启动,已加载 {len(self.streams)} 个聊天流")
except Exception as e:
logger.error(f"聊天管理器启动失败: {str(e)}")
logger.error(f"聊天管理器启动失败: {e!s}")
async def _auto_save_task(self):
"""定期自动保存所有聊天流"""
@@ -424,7 +426,7 @@ class ChatManager:
await self._save_all_streams()
logger.info("聊天流自动保存完成")
except Exception as e:
logger.error(f"聊天流自动保存失败: {str(e)}")
logger.error(f"聊天流自动保存失败: {e!s}")
def register_message(self, message: "MessageRecv"):
"""注册消息到聊天流"""
@@ -437,9 +439,7 @@ class ChatManager:
# logger.debug(f"注册消息到聊天流: {stream_id}")
@staticmethod
def _generate_stream_id(
platform: str, user_info: Optional[UserInfo], group_info: Optional[GroupInfo] = None
) -> str:
def _generate_stream_id(platform: str, user_info: UserInfo | None, group_info: GroupInfo | None = None) -> str:
"""生成聊天流唯一ID"""
if not user_info and not group_info:
raise ValueError("用户信息或群组信息必须提供")
@@ -462,7 +462,7 @@ class ChatManager:
return hashlib.md5(key.encode()).hexdigest()
async def get_or_create_stream(
self, platform: str, user_info: UserInfo, group_info: Optional[GroupInfo] = None
self, platform: str, user_info: UserInfo, group_info: GroupInfo | None = None
) -> ChatStream:
"""获取或创建聊天流
@@ -572,7 +572,7 @@ class ChatManager:
await self._save_stream(stream)
return stream
def get_stream(self, stream_id: str) -> Optional[ChatStream]:
def get_stream(self, stream_id: str) -> ChatStream | None:
"""通过stream_id获取聊天流"""
stream = self.streams.get(stream_id)
if not stream:
@@ -582,13 +582,13 @@ class ChatManager:
return stream
def get_stream_by_info(
self, platform: str, user_info: UserInfo, group_info: Optional[GroupInfo] = None
) -> Optional[ChatStream]:
self, platform: str, user_info: UserInfo, group_info: GroupInfo | None = None
) -> ChatStream | None:
"""通过信息获取聊天流"""
stream_id = self._generate_stream_id(platform, user_info, group_info)
return self.streams.get(stream_id)
def get_stream_name(self, stream_id: str) -> Optional[str]:
def get_stream_name(self, stream_id: str) -> str | None:
"""根据 stream_id 获取聊天流名称"""
stream = self.get_stream(stream_id)
if not stream:

View File

@@ -1,20 +1,19 @@
import base64
import time
from abc import abstractmethod, ABCMeta
from abc import ABCMeta, abstractmethod
from dataclasses import dataclass
from typing import Optional, Any
from typing import Any, Optional
import urllib3
from maim_message import Seg, UserInfo, BaseMessageInfo, MessageBase
from maim_message import BaseMessageInfo, MessageBase, Seg, UserInfo
from rich.traceback import install
from src.chat.message_receive.chat_stream import ChatStream
from src.chat.utils.utils_image import get_image_manager
from src.chat.utils.utils_video import get_video_analyzer, is_video_analysis_available
from src.chat.utils.utils_voice import get_voice_text
from src.common.logger import get_logger
from src.config.config import global_config
from src.chat.message_receive.chat_stream import ChatStream
install(extra_lines=3)
@@ -41,8 +40,8 @@ class Message(MessageBase, metaclass=ABCMeta):
message_id: str,
chat_stream: "ChatStream",
user_info: UserInfo,
message_segment: Optional[Seg] = None,
timestamp: Optional[float] = None,
message_segment: Seg | None = None,
timestamp: float | None = None,
reply: Optional["MessageRecv"] = None,
processed_plain_text: str = "",
):
@@ -264,7 +263,7 @@ class MessageRecv(Message):
logger.warning("视频消息中没有base64数据")
return "[收到视频消息,但数据异常]"
except Exception as e:
logger.error(f"视频处理失败: {str(e)}")
logger.error(f"视频处理失败: {e!s}")
import traceback
logger.error(f"错误详情: {traceback.format_exc()}")
@@ -278,7 +277,7 @@ class MessageRecv(Message):
logger.info("未启用视频识别")
return "[视频]"
except Exception as e:
logger.error(f"处理消息段失败: {str(e)}, 类型: {segment.type}, 数据: {segment.data}")
logger.error(f"处理消息段失败: {e!s}, 类型: {segment.type}, 数据: {segment.data}")
return f"[处理失败的{segment.type}消息]"
@@ -291,7 +290,7 @@ class MessageRecvS4U(MessageRecv):
self.is_superchat = False
self.gift_info = None
self.gift_name = None
self.gift_count: Optional[str] = None
self.gift_count: str | None = None
self.superchat_info = None
self.superchat_price = None
self.superchat_message_text = None
@@ -444,7 +443,7 @@ class MessageRecvS4U(MessageRecv):
logger.warning("视频消息中没有base64数据")
return "[收到视频消息,但数据异常]"
except Exception as e:
logger.error(f"视频处理失败: {str(e)}")
logger.error(f"视频处理失败: {e!s}")
import traceback
logger.error(f"错误详情: {traceback.format_exc()}")
@@ -458,7 +457,7 @@ class MessageRecvS4U(MessageRecv):
logger.info("未启用视频识别")
return "[视频]"
except Exception as e:
logger.error(f"处理消息段失败: {str(e)}, 类型: {segment.type}, 数据: {segment.data}")
logger.error(f"处理消息段失败: {e!s}, 类型: {segment.type}, 数据: {segment.data}")
return f"[处理失败的{segment.type}消息]"
@@ -471,10 +470,10 @@ class MessageProcessBase(Message):
message_id: str,
chat_stream: "ChatStream",
bot_user_info: UserInfo,
message_segment: Optional[Seg] = None,
message_segment: Seg | None = None,
reply: Optional["MessageRecv"] = None,
thinking_start_time: float = 0,
timestamp: Optional[float] = None,
timestamp: float | None = None,
):
# 调用父类初始化,传递时间戳
super().__init__(
@@ -533,9 +532,9 @@ class MessageProcessBase(Message):
return f"[回复<{self.reply.message_info.user_info.user_nickname}> 的消息:{self.reply.processed_plain_text}]" # type: ignore
return None
else:
return f"[{seg.type}:{str(seg.data)}]"
return f"[{seg.type}:{seg.data!s}]"
except Exception as e:
logger.error(f"处理消息段失败: {str(e)}, 类型: {seg.type}, 数据: {seg.data}")
logger.error(f"处理消息段失败: {e!s}, 类型: {seg.type}, 数据: {seg.data}")
return f"[处理失败的{seg.type}消息]"
def _generate_detailed_text(self) -> str:
@@ -565,7 +564,7 @@ class MessageSending(MessageProcessBase):
is_emoji: bool = False,
thinking_start_time: float = 0,
apply_set_reply_logic: bool = False,
reply_to: Optional[str] = None,
reply_to: str | None = None,
):
# 调用父类初始化
super().__init__(
@@ -635,11 +634,11 @@ class MessageSet:
self.messages.append(message)
self.messages.sort(key=lambda x: x.message_info.time) # type: ignore
def get_message_by_index(self, index: int) -> Optional[MessageSending]:
def get_message_by_index(self, index: int) -> MessageSending | None:
"""通过索引获取消息"""
return self.messages[index] if 0 <= index < len(self.messages) else None
def get_message_by_time(self, target_time: float) -> Optional[MessageSending]:
def get_message_by_time(self, target_time: float) -> MessageSending | None:
"""获取最接近指定时间的消息"""
if not self.messages:
return None

View File

@@ -1,14 +1,15 @@
import re
import traceback
import orjson
from typing import Union
from src.common.database.sqlalchemy_models import Messages, Images
from src.common.logger import get_logger
from .chat_stream import ChatStream
from .message import MessageSending, MessageRecv
import orjson
from sqlalchemy import desc, select, update
from src.common.database.sqlalchemy_database_api import get_db_session
from sqlalchemy import select, update, desc
from src.common.database.sqlalchemy_models import Images, Messages
from src.common.logger import get_logger
from .chat_stream import ChatStream
from .message import MessageRecv, MessageSending
logger = get_logger("message_storage")
@@ -32,7 +33,7 @@ class MessageStorage:
return []
@staticmethod
async def store_message(message: Union[MessageSending, MessageRecv], chat_stream: ChatStream) -> None:
async def store_message(message: MessageSending | MessageRecv, chat_stream: ChatStream) -> None:
"""存储消息到数据库"""
try:
# 过滤敏感信息的正则模式
@@ -299,6 +300,7 @@ class MessageStorage:
try:
async with get_db_session() as session:
from sqlalchemy import select, update
from src.common.database.sqlalchemy_models import Messages
# 查找需要修复的记录interest_value为0、null或很小的值

View File

@@ -3,12 +3,11 @@ import traceback
from rich.traceback import install
from src.common.message.api import get_global_api
from src.common.logger import get_logger
from src.chat.message_receive.message import MessageSending
from src.chat.message_receive.storage import MessageStorage
from src.chat.utils.utils import truncate_message
from src.chat.utils.utils import calculate_typing_time
from src.chat.utils.utils import calculate_typing_time, truncate_message
from src.common.logger import get_logger
from src.common.message.api import get_global_api
install(extra_lines=3)
@@ -27,7 +26,7 @@ async def send_message(message: MessageSending, show_log=True) -> bool:
return True
except Exception as e:
logger.error(f"发送消息 '{message_preview}' 发往平台'{message.message_info.platform}' 失败: {str(e)}")
logger.error(f"发送消息 '{message_preview}' 发往平台'{message.message_info.platform}' 失败: {e!s}")
traceback.print_exc()
raise e # 重新抛出其他异常

View File

@@ -1,19 +1,17 @@
import asyncio
import traceback
import time
from typing import Dict, Optional, Type, Any, Tuple
import traceback
from typing import Any
from src.chat.utils.timer_calculator import Timer
from src.person_info.person_info import get_person_info_manager
from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager
from src.chat.utils.timer_calculator import Timer
from src.common.logger import get_logger
from src.config.config import global_config
from src.plugin_system.core.component_registry import component_registry
from src.plugin_system.base.component_types import ComponentType, ActionInfo
from src.person_info.person_info import get_person_info_manager
from src.plugin_system.apis import database_api, generator_api, message_api, send_api
from src.plugin_system.base.base_action import BaseAction
from src.plugin_system.apis import generator_api, database_api, send_api, message_api
from src.plugin_system.base.component_types import ActionInfo, ComponentType
from src.plugin_system.core.component_registry import component_registry
logger = get_logger("action_manager")
@@ -29,7 +27,7 @@ class ChatterActionManager:
"""初始化动作管理器"""
# 当前正在使用的动作集合,默认加载默认动作
self._using_actions: Dict[str, ActionInfo] = {}
self._using_actions: dict[str, ActionInfo] = {}
# 初始化时将默认动作加载到使用中的动作
self._using_actions = component_registry.get_default_actions()
@@ -48,8 +46,8 @@ class ChatterActionManager:
chat_stream: ChatStream,
log_prefix: str,
shutting_down: bool = False,
action_message: Optional[dict] = None,
) -> Optional[BaseAction]:
action_message: dict | None = None,
) -> BaseAction | None:
"""
创建动作处理器实例
@@ -68,7 +66,7 @@ class ChatterActionManager:
"""
try:
# 获取组件类 - 明确指定查询Action类型
component_class: Type[BaseAction] = component_registry.get_component_class(
component_class: type[BaseAction] = component_registry.get_component_class(
action_name, ComponentType.ACTION
) # type: ignore
if not component_class:
@@ -107,7 +105,7 @@ class ChatterActionManager:
logger.error(traceback.format_exc())
return None
def get_using_actions(self) -> Dict[str, ActionInfo]:
def get_using_actions(self) -> dict[str, ActionInfo]:
"""获取当前正在使用的动作集合"""
return self._using_actions.copy()
@@ -140,10 +138,10 @@ class ChatterActionManager:
self,
action_name: str,
chat_id: str,
target_message: Optional[dict] = None,
target_message: dict | None = None,
reasoning: str = "",
action_data: Optional[dict] = None,
thinking_id: Optional[str] = None,
action_data: dict | None = None,
thinking_id: str | None = None,
log_prefix: str = "",
clear_unread_messages: bool = True,
) -> Any:
@@ -437,10 +435,10 @@ class ChatterActionManager:
response_set,
loop_start_time,
action_message,
cycle_timers: Dict[str, float],
cycle_timers: dict[str, float],
thinking_id,
actions,
) -> Tuple[Dict[str, Any], str, Dict[str, float]]:
) -> tuple[dict[str, Any], str, dict[str, float]]:
"""
发送并存储回复信息
@@ -488,7 +486,7 @@ class ChatterActionManager:
)
# 构建循环信息
loop_info: Dict[str, Any] = {
loop_info: dict[str, Any] = {
"loop_plan_info": {
"action_result": actions,
},

View File

@@ -1,17 +1,17 @@
import random
import asyncio
import hashlib
import random
import time
from typing import List, Any, Dict, TYPE_CHECKING, Tuple
from typing import TYPE_CHECKING, Any
from src.chat.message_receive.chat_stream import get_chat_manager
from src.chat.planner_actions.action_manager import ChatterActionManager
from src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_before_timestamp_with_chat
from src.common.data_models.message_manager_data_model import StreamContext
from src.common.logger import get_logger
from src.config.config import global_config, model_config
from src.llm_models.utils_model import LLMRequest
from src.chat.message_receive.chat_stream import get_chat_manager
from src.common.data_models.message_manager_data_model import StreamContext
from src.chat.planner_actions.action_manager import ChatterActionManager
from src.chat.utils.chat_message_builder import get_raw_msg_before_timestamp_with_chat, build_readable_messages
from src.plugin_system.base.component_types import ActionInfo, ActionActivationType
from src.plugin_system.base.component_types import ActionActivationType, ActionInfo
from src.plugin_system.core.global_announcement_manager import global_announcement_manager
if TYPE_CHECKING:
@@ -59,18 +59,17 @@ class ActionModifier:
"""
logger.debug(f"{self.log_prefix}开始完整动作修改流程")
removals_s1: List[Tuple[str, str]] = []
removals_s2: List[Tuple[str, str]] = []
removals_s3: List[Tuple[str, str]] = []
removals_s1: list[tuple[str, str]] = []
removals_s2: list[tuple[str, str]] = []
removals_s3: list[tuple[str, str]] = []
self.action_manager.restore_actions()
all_actions = self.action_manager.get_using_actions()
# === 第0阶段根据聊天类型过滤动作 ===
from src.plugin_system.base.component_types import ChatType
from src.plugin_system.core.component_registry import component_registry
from src.plugin_system.base.component_types import ComponentType
from src.chat.utils.utils import get_chat_type_and_target_info
from src.plugin_system.base.component_types import ChatType, ComponentType
from src.plugin_system.core.component_registry import component_registry
# 获取聊天类型
is_group_chat, _ = get_chat_type_and_target_info(self.chat_id)
@@ -167,8 +166,8 @@ class ActionModifier:
logger.info(f"{self.log_prefix} 当前可用动作: {available_actions_text}||移除: {removals_summary}")
def _check_action_associated_types(self, all_actions: Dict[str, ActionInfo], chat_context: StreamContext):
type_mismatched_actions: List[Tuple[str, str]] = []
def _check_action_associated_types(self, all_actions: dict[str, ActionInfo], chat_context: StreamContext):
type_mismatched_actions: list[tuple[str, str]] = []
for action_name, action_info in all_actions.items():
if action_info.associated_types and not chat_context.check_types(action_info.associated_types):
associated_types_str = ", ".join(action_info.associated_types)
@@ -179,9 +178,9 @@ class ActionModifier:
async def _get_deactivated_actions_by_type(
self,
actions_with_info: Dict[str, ActionInfo],
actions_with_info: dict[str, ActionInfo],
chat_content: str = "",
) -> List[tuple[str, str]]:
) -> list[tuple[str, str]]:
"""
根据激活类型过滤,返回需要停用的动作列表及原因
@@ -254,9 +253,9 @@ class ActionModifier:
async def _process_llm_judge_actions_parallel(
self,
llm_judge_actions: Dict[str, Any],
llm_judge_actions: dict[str, Any],
chat_content: str = "",
) -> Dict[str, bool]:
) -> dict[str, bool]:
"""
并行处理LLM判定actions支持智能缓存

View File

@@ -3,42 +3,41 @@
使用重构后的统一Prompt系统替换原有的复杂提示词构建逻辑
"""
import traceback
import time
import asyncio
import random
import re
from typing import List, Optional, Dict, Any, Tuple
import time
import traceback
from datetime import datetime
from src.mais4u.mai_think import mai_thinking_manager
from src.common.logger import get_logger
from src.config.config import global_config, model_config
from src.individuality.individuality import get_individuality
from src.llm_models.utils_model import LLMRequest
from src.chat.message_receive.message import UserInfo, Seg, MessageRecv, MessageSending
from typing import Any
from src.chat.express.expression_selector import expression_selector
from src.chat.message_receive.chat_stream import ChatStream
from src.chat.utils.memory_mappings import get_memory_type_chinese_label
from src.chat.message_receive.message import MessageRecv, MessageSending, Seg, UserInfo
from src.chat.message_receive.uni_message_sender import HeartFCSender
from src.chat.utils.timer_calculator import Timer
from src.chat.utils.utils import get_chat_type_and_target_info
from src.chat.utils.prompt import Prompt, global_prompt_manager
from src.chat.utils.chat_message_builder import (
build_readable_messages,
get_raw_msg_before_timestamp_with_chat,
replace_user_references_sync,
)
from src.chat.express.expression_selector import expression_selector
from src.chat.utils.memory_mappings import get_memory_type_chinese_label
# 导入新的统一Prompt系统
from src.chat.utils.prompt import Prompt, PromptParameters, global_prompt_manager
from src.chat.utils.timer_calculator import Timer
from src.chat.utils.utils import get_chat_type_and_target_info
from src.common.logger import get_logger
from src.config.config import global_config, model_config
from src.individuality.individuality import get_individuality
from src.llm_models.utils_model import LLMRequest
from src.mais4u.mai_think import mai_thinking_manager
# 旧记忆系统已被移除
# 旧记忆系统已被移除
from src.mood.mood_manager import mood_manager
from src.person_info.person_info import get_person_info_manager
from src.plugin_system.base.component_types import ActionInfo, EventType
from src.plugin_system.apis import llm_api
# 导入新的统一Prompt系统
from src.chat.utils.prompt import PromptParameters
from src.plugin_system.base.component_types import ActionInfo, EventType
logger = get_logger("replyer")
@@ -248,12 +247,12 @@ class DefaultReplyer:
self,
reply_to: str = "",
extra_info: str = "",
available_actions: Optional[Dict[str, ActionInfo]] = None,
available_actions: dict[str, ActionInfo] | None = None,
enable_tool: bool = True,
from_plugin: bool = True,
stream_id: Optional[str] = None,
reply_message: Optional[Dict[str, Any]] = None,
) -> Tuple[bool, Optional[Dict[str, Any]], Optional[str]]:
stream_id: str | None = None,
reply_message: dict[str, Any] | None = None,
) -> tuple[bool, dict[str, Any] | None, str | None]:
# sourcery skip: merge-nested-ifs
"""
回复器 (Replier): 负责生成回复文本的核心逻辑。
@@ -353,7 +352,7 @@ class DefaultReplyer:
reason: str = "",
reply_to: str = "",
return_prompt: bool = False,
) -> Tuple[bool, Optional[str], Optional[str]]:
) -> tuple[bool, str | None, str | None]:
"""
表达器 (Expressor): 负责重写和优化回复文本。
@@ -722,7 +721,7 @@ class DefaultReplyer:
logger.error(f"工具信息获取失败: {e}")
return ""
def _parse_reply_target(self, target_message: str) -> Tuple[str, str]:
def _parse_reply_target(self, target_message: str) -> tuple[str, str]:
"""解析回复目标消息 - 使用共享工具"""
from src.chat.utils.prompt import Prompt
@@ -731,7 +730,7 @@ class DefaultReplyer:
return "未知用户", "(无消息内容)"
return Prompt.parse_reply_target(target_message)
async def build_keywords_reaction_prompt(self, target: Optional[str]) -> str:
async def build_keywords_reaction_prompt(self, target: str | None) -> str:
"""构建关键词反应提示
Args:
@@ -766,14 +765,14 @@ class DefaultReplyer:
keywords_reaction_prompt += f"{reaction}"
break
except re.error as e:
logger.error(f"正则表达式编译错误: {pattern_str}, 错误信息: {str(e)}")
logger.error(f"正则表达式编译错误: {pattern_str}, 错误信息: {e!s}")
continue
except Exception as e:
logger.error(f"关键词检测与反应时发生异常: {str(e)}", exc_info=True)
logger.error(f"关键词检测与反应时发生异常: {e!s}", exc_info=True)
return keywords_reaction_prompt
async def _time_and_run_task(self, coroutine, name: str) -> Tuple[str, Any, float]:
async def _time_and_run_task(self, coroutine, name: str) -> tuple[str, Any, float]:
"""计时并运行异步任务的辅助函数
Args:
@@ -790,8 +789,8 @@ class DefaultReplyer:
return name, result, duration
async def build_s4u_chat_history_prompts(
self, message_list_before_now: List[Dict[str, Any]], target_user_id: str, sender: str, chat_id: str
) -> Tuple[str, str]:
self, message_list_before_now: list[dict[str, Any]], target_user_id: str, sender: str, chat_id: str
) -> tuple[str, str]:
"""
构建 s4u 风格的已读/未读历史消息 prompt
@@ -907,8 +906,8 @@ class DefaultReplyer:
return await self._fallback_build_chat_history_prompts(message_list_before_now, target_user_id, sender)
async def _fallback_build_chat_history_prompts(
self, message_list_before_now: List[Dict[str, Any]], target_user_id: str, sender: str
) -> Tuple[str, str]:
self, message_list_before_now: list[dict[str, Any]], target_user_id: str, sender: str
) -> tuple[str, str]:
"""
回退的已读/未读历史消息构建方法
"""
@@ -1000,15 +999,15 @@ class DefaultReplyer:
return read_history_prompt, unread_history_prompt
async def _get_interest_scores_for_messages(self, messages: List[dict]) -> dict[str, float]:
async def _get_interest_scores_for_messages(self, messages: list[dict]) -> dict[str, float]:
"""为消息获取兴趣度评分"""
interest_scores = {}
try:
from src.common.data_models.database_data_model import DatabaseMessages
from src.plugins.built_in.affinity_flow_chatter.interest_scoring import (
chatter_interest_scoring_system as interest_scoring_system,
)
from src.common.data_models.database_data_model import DatabaseMessages
# 转换消息格式
db_messages = []
@@ -1094,9 +1093,9 @@ class DefaultReplyer:
self,
reply_to: str,
extra_info: str = "",
available_actions: Optional[Dict[str, ActionInfo]] = None,
available_actions: dict[str, ActionInfo] | None = None,
enable_tool: bool = True,
reply_message: Optional[Dict[str, Any]] = None,
reply_message: dict[str, Any] | None = None,
) -> str:
"""
构建回复器上下文
@@ -1417,7 +1416,7 @@ class DefaultReplyer:
raw_reply: str,
reason: str,
reply_to: str,
reply_message: Optional[Dict[str, Any]] = None,
reply_message: dict[str, Any] | None = None,
) -> str: # sourcery skip: merge-else-if-into-elif, remove-redundant-if
chat_stream = self.chat_stream
chat_id = chat_stream.stream_id
@@ -1553,7 +1552,7 @@ class DefaultReplyer:
is_emoji: bool,
thinking_start_time: float,
display_message: str,
anchor_message: Optional[MessageRecv] = None,
anchor_message: MessageRecv | None = None,
) -> MessageSending:
"""构建单个发送消息"""
@@ -1644,7 +1643,7 @@ class DefaultReplyer:
logger.debug("从LPMM知识库获取知识失败可能是从未导入过知识返回空知识...")
return ""
except Exception as e:
logger.error(f"获取知识库内容时发生异常: {str(e)}")
logger.error(f"获取知识库内容时发生异常: {e!s}")
return ""
async def build_relation_info(self, sender: str, target: str):
@@ -1660,10 +1659,9 @@ class DefaultReplyer:
# 使用AFC关系追踪器获取关系信息
try:
from src.plugins.built_in.affinity_flow_chatter.relationship_tracker import ChatterRelationshipTracker
# 创建关系追踪器实例
from src.plugins.built_in.affinity_flow_chatter.interest_scoring import chatter_interest_scoring_system
from src.plugins.built_in.affinity_flow_chatter.relationship_tracker import ChatterRelationshipTracker
relationship_tracker = ChatterRelationshipTracker(chatter_interest_scoring_system)
if relationship_tracker:
@@ -1704,7 +1702,7 @@ class DefaultReplyer:
logger.error(f"获取AFC关系信息失败: {e}")
return f"你与{sender}是普通朋友关系。"
async def _store_chat_memory_async(self, reply_to: str, reply_message: Optional[Dict[str, Any]] = None):
async def _store_chat_memory_async(self, reply_to: str, reply_message: dict[str, Any] | None = None):
"""
异步存储聊天记忆从build_memory_block迁移而来

View File

@@ -1,22 +1,20 @@
from typing import Dict, Optional
from src.common.logger import get_logger
from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager
from src.chat.replyer.default_generator import DefaultReplyer
from src.common.logger import get_logger
logger = get_logger("ReplyerManager")
class ReplyerManager:
def __init__(self):
self._repliers: Dict[str, DefaultReplyer] = {}
self._repliers: dict[str, DefaultReplyer] = {}
def get_replyer(
self,
chat_stream: Optional[ChatStream] = None,
chat_id: Optional[str] = None,
chat_stream: ChatStream | None = None,
chat_id: str | None = None,
request_type: str = "replyer",
) -> Optional[DefaultReplyer]:
) -> DefaultReplyer | None:
"""
获取或创建回复器实例。

View File

@@ -1,18 +1,19 @@
import time # 导入 time 模块以获取当前时间
import random
import re
import time # 导入 time 模块以获取当前时间
from collections.abc import Callable
from typing import Any
from typing import List, Dict, Any, Tuple, Optional, Callable
from rich.traceback import install
from sqlalchemy import and_, select
from src.config.config import global_config
from src.common.message_repository import find_messages, count_messages
from src.common.database.sqlalchemy_models import ActionRecords, Images
from src.person_info.person_info import PersonInfoManager, get_person_info_manager
from src.chat.utils.utils import translate_timestamp_to_human_readable, assign_message_ids
from src.chat.utils.utils import assign_message_ids, translate_timestamp_to_human_readable
from src.common.database.sqlalchemy_database_api import get_db_session
from sqlalchemy import select, and_
from src.common.database.sqlalchemy_models import ActionRecords, Images
from src.common.logger import get_logger
from src.common.message_repository import count_messages, find_messages
from src.config.config import global_config
from src.person_info.person_info import PersonInfoManager, get_person_info_manager
logger = get_logger("chat_message_builder")
@@ -22,7 +23,7 @@ install(extra_lines=3)
def replace_user_references_sync(
content: str,
platform: str,
name_resolver: Optional[Callable[[str, str], str]] = None,
name_resolver: Callable[[str, str], str] | None = None,
replace_bot_name: bool = True,
) -> str:
"""
@@ -100,7 +101,7 @@ def replace_user_references_sync(
async def replace_user_references_async(
content: str,
platform: str,
name_resolver: Optional[Callable[[str, str], Any]] = None,
name_resolver: Callable[[str, str], Any] | None = None,
replace_bot_name: bool = True,
) -> str:
"""
@@ -174,7 +175,7 @@ async def replace_user_references_async(
async def get_raw_msg_by_timestamp(
timestamp_start: float, timestamp_end: float, limit: int = 0, limit_mode: str = "latest"
) -> List[Dict[str, Any]]:
) -> list[dict[str, Any]]:
"""
获取从指定时间戳到指定时间戳的消息,按时间升序排序,返回消息列表
limit: 限制返回的消息数量0为不限制
@@ -194,7 +195,7 @@ async def get_raw_msg_by_timestamp_with_chat(
limit_mode: str = "latest",
filter_bot=False,
filter_command=False,
) -> List[Dict[str, Any]]:
) -> list[dict[str, Any]]:
"""获取在特定聊天从指定时间戳到指定时间戳的消息,按时间升序排序,返回消息列表
limit: 限制返回的消息数量0为不限制
limit_mode: 当 limit > 0 时生效。 'earliest' 表示获取最早的记录, 'latest' 表示获取最新的记录。默认为 'latest'
@@ -220,7 +221,7 @@ async def get_raw_msg_by_timestamp_with_chat_inclusive(
limit: int = 0,
limit_mode: str = "latest",
filter_bot=False,
) -> List[Dict[str, Any]]:
) -> list[dict[str, Any]]:
"""获取在特定聊天从指定时间戳到指定时间戳的消息(包含边界),按时间升序排序,返回消息列表
limit: 限制返回的消息数量0为不限制
limit_mode: 当 limit > 0 时生效。 'earliest' 表示获取最早的记录, 'latest' 表示获取最新的记录。默认为 'latest'
@@ -239,10 +240,10 @@ async def get_raw_msg_by_timestamp_with_chat_users(
chat_id: str,
timestamp_start: float,
timestamp_end: float,
person_ids: List[str],
person_ids: list[str],
limit: int = 0,
limit_mode: str = "latest",
) -> List[Dict[str, Any]]:
) -> list[dict[str, Any]]:
"""获取某些特定用户在特定聊天从指定时间戳到指定时间戳的消息,按时间升序排序,返回消息列表
limit: 限制返回的消息数量0为不限制
limit_mode: 当 limit > 0 时生效。 'earliest' 表示获取最早的记录, 'latest' 表示获取最新的记录。默认为 'latest'
@@ -263,7 +264,7 @@ async def get_actions_by_timestamp_with_chat(
timestamp_end: float = time.time(),
limit: int = 0,
limit_mode: str = "latest",
) -> List[Dict[str, Any]]:
) -> list[dict[str, Any]]:
"""获取在特定聊天从指定时间戳到指定时间戳的动作记录,按时间升序排序,返回动作记录列表"""
from src.common.logger import get_logger
@@ -372,7 +373,7 @@ async def get_actions_by_timestamp_with_chat(
async def get_actions_by_timestamp_with_chat_inclusive(
chat_id: str, timestamp_start: float, timestamp_end: float, limit: int = 0, limit_mode: str = "latest"
) -> List[Dict[str, Any]]:
) -> list[dict[str, Any]]:
"""获取在特定聊天从指定时间戳到指定时间戳的动作记录(包含边界),按时间升序排序,返回动作记录列表"""
async with get_db_session() as session:
if limit > 0:
@@ -423,7 +424,7 @@ async def get_actions_by_timestamp_with_chat_inclusive(
async def get_raw_msg_by_timestamp_random(
timestamp_start: float, timestamp_end: float, limit: int = 0, limit_mode: str = "latest"
) -> List[Dict[str, Any]]:
) -> list[dict[str, Any]]:
"""
先在范围时间戳内随机选择一条消息取得消息的chat_id然后根据chat_id获取该聊天在指定时间戳范围内的消息
"""
@@ -441,7 +442,7 @@ async def get_raw_msg_by_timestamp_random(
async def get_raw_msg_by_timestamp_with_users(
timestamp_start: float, timestamp_end: float, person_ids: list, limit: int = 0, limit_mode: str = "latest"
) -> List[Dict[str, Any]]:
) -> list[dict[str, Any]]:
"""获取某些特定用户在 *所有聊天* 中从指定时间戳到指定时间戳的消息,按时间升序排序,返回消息列表
limit: 限制返回的消息数量0为不限制
limit_mode: 当 limit > 0 时生效。 'earliest' 表示获取最早的记录, 'latest' 表示获取最新的记录。默认为 'latest'
@@ -452,7 +453,7 @@ async def get_raw_msg_by_timestamp_with_users(
return await find_messages(message_filter=filter_query, sort=sort_order, limit=limit, limit_mode=limit_mode)
async def get_raw_msg_before_timestamp(timestamp: float, limit: int = 0) -> List[Dict[str, Any]]:
async def get_raw_msg_before_timestamp(timestamp: float, limit: int = 0) -> list[dict[str, Any]]:
"""获取指定时间戳之前的消息,按时间升序排序,返回消息列表
limit: 限制返回的消息数量0为不限制
"""
@@ -463,7 +464,7 @@ async def get_raw_msg_before_timestamp(timestamp: float, limit: int = 0) -> List
async def get_raw_msg_before_timestamp_with_chat(
chat_id: str, timestamp: float, limit: int = 0
) -> List[Dict[str, Any]]:
) -> list[dict[str, Any]]:
"""获取指定时间戳之前的消息,按时间升序排序,返回消息列表
limit: 限制返回的消息数量0为不限制
"""
@@ -474,7 +475,7 @@ async def get_raw_msg_before_timestamp_with_chat(
async def get_raw_msg_before_timestamp_with_users(
timestamp: float, person_ids: list, limit: int = 0
) -> List[Dict[str, Any]]:
) -> list[dict[str, Any]]:
"""获取指定时间戳之前的消息,按时间升序排序,返回消息列表
limit: 限制返回的消息数量0为不限制
"""
@@ -483,9 +484,7 @@ async def get_raw_msg_before_timestamp_with_users(
return await find_messages(message_filter=filter_query, sort=sort_order, limit=limit)
async def num_new_messages_since(
chat_id: str, timestamp_start: float = 0.0, timestamp_end: Optional[float] = None
) -> int:
async def num_new_messages_since(chat_id: str, timestamp_start: float = 0.0, timestamp_end: float | None = None) -> int:
"""
检查特定聊天从 timestamp_start (不含) 到 timestamp_end (不含) 之间有多少新消息。
如果 timestamp_end 为 None则检查从 timestamp_start (不含) 到当前时间的消息。
@@ -517,16 +516,16 @@ async def num_new_messages_since_with_users(
async def _build_readable_messages_internal(
messages: List[Dict[str, Any]],
messages: list[dict[str, Any]],
replace_bot_name: bool = True,
merge_messages: bool = False,
timestamp_mode: str = "relative",
truncate: bool = False,
pic_id_mapping: Optional[Dict[str, str]] = None,
pic_id_mapping: dict[str, str] | None = None,
pic_counter: int = 1,
show_pic: bool = True,
message_id_list: Optional[List[Dict[str, Any]]] = None,
) -> Tuple[str, List[Tuple[float, str, str]], Dict[str, str], int]:
message_id_list: list[dict[str, Any]] | None = None,
) -> tuple[str, list[tuple[float, str, str]], dict[str, str], int]:
"""
内部辅助函数,构建可读消息字符串和原始消息详情列表。
@@ -545,7 +544,7 @@ async def _build_readable_messages_internal(
if not messages:
return "", [], pic_id_mapping or {}, pic_counter
message_details_raw: List[Tuple[float, str, str, bool]] = []
message_details_raw: list[tuple[float, str, str, bool]] = []
# 使用传入的映射字典,如果没有则创建新的
if pic_id_mapping is None:
@@ -672,7 +671,7 @@ async def _build_readable_messages_internal(
message_details_with_flags.append((timestamp, name, content, is_action))
# 应用截断逻辑 (如果 truncate 为 True)
message_details: List[Tuple[float, str, str, bool]] = []
message_details: list[tuple[float, str, str, bool]] = []
n_messages = len(message_details_with_flags)
if truncate and n_messages > 0:
for i, (timestamp, name, content, is_action) in enumerate(message_details_with_flags):
@@ -809,7 +808,7 @@ async def _build_readable_messages_internal(
)
async def build_pic_mapping_info(pic_id_mapping: Dict[str, str]) -> str:
async def build_pic_mapping_info(pic_id_mapping: dict[str, str]) -> str:
# sourcery skip: use-contextlib-suppress
"""
构建图片映射信息字符串,显示图片的具体描述内容
@@ -847,7 +846,7 @@ async def build_pic_mapping_info(pic_id_mapping: Dict[str, str]) -> str:
return "\n".join(mapping_lines)
def build_readable_actions(actions: List[Dict[str, Any]]) -> str:
def build_readable_actions(actions: list[dict[str, Any]]) -> str:
"""
将动作列表转换为可读的文本格式。
格式: 在()分钟前,你使用了(action_name)具体内容是action_prompt_display
@@ -922,12 +921,12 @@ def build_readable_actions(actions: List[Dict[str, Any]]) -> str:
async def build_readable_messages_with_list(
messages: List[Dict[str, Any]],
messages: list[dict[str, Any]],
replace_bot_name: bool = True,
merge_messages: bool = False,
timestamp_mode: str = "relative",
truncate: bool = False,
) -> Tuple[str, List[Tuple[float, str, str]]]:
) -> tuple[str, list[tuple[float, str, str]]]:
"""
将消息列表转换为可读的文本格式,并返回原始(时间戳, 昵称, 内容)列表。
允许通过参数控制格式化行为。
@@ -943,7 +942,7 @@ async def build_readable_messages_with_list(
async def build_readable_messages_with_id(
messages: List[Dict[str, Any]],
messages: list[dict[str, Any]],
replace_bot_name: bool = True,
merge_messages: bool = False,
timestamp_mode: str = "relative",
@@ -951,7 +950,7 @@ async def build_readable_messages_with_id(
truncate: bool = False,
show_actions: bool = False,
show_pic: bool = True,
) -> Tuple[str, List[Dict[str, Any]]]:
) -> tuple[str, list[dict[str, Any]]]:
"""
将消息列表转换为可读的文本格式,并返回原始(时间戳, 昵称, 内容)列表。
允许通过参数控制格式化行为。
@@ -980,7 +979,7 @@ async def build_readable_messages_with_id(
async def build_readable_messages(
messages: List[Dict[str, Any]],
messages: list[dict[str, Any]],
replace_bot_name: bool = True,
merge_messages: bool = False,
timestamp_mode: str = "relative",
@@ -988,7 +987,7 @@ async def build_readable_messages(
truncate: bool = False,
show_actions: bool = True,
show_pic: bool = True,
message_id_list: Optional[List[Dict[str, Any]]] = None,
message_id_list: list[dict[str, Any]] | None = None,
) -> str: # sourcery skip: extract-method
"""
将消息列表转换为可读的文本格式。
@@ -1148,7 +1147,7 @@ async def build_readable_messages(
return "".join(result_parts)
async def build_anonymous_messages(messages: List[Dict[str, Any]]) -> str:
async def build_anonymous_messages(messages: list[dict[str, Any]]) -> str:
"""
构建匿名可读消息将不同人的名称转为唯一占位符A、B、C...bot自己用SELF。
处理 回复<aaa:bbb> 和 @<aaa:bbb> 字段将bbb映射为匿名占位符。
@@ -1261,7 +1260,7 @@ async def build_anonymous_messages(messages: List[Dict[str, Any]]) -> str:
return formatted_string
async def get_person_id_list(messages: List[Dict[str, Any]]) -> List[str]:
async def get_person_id_list(messages: list[dict[str, Any]]) -> list[str]:
"""
从消息列表中提取不重复的 person_id 列表 (忽略机器人自身)。

View File

@@ -1,4 +1,3 @@
# -*- coding: utf-8 -*-
"""
记忆系统相关的映射表和工具函数
提供记忆类型、置信度、重要性等的中文标签映射

View File

@@ -3,19 +3,20 @@
将原有的Prompt类和SmartPrompt功能整合为一个真正的Prompt类
"""
import re
import asyncio
import time
import contextvars
from dataclasses import dataclass, field
from typing import Dict, Any, Optional, List, Literal, Tuple
import re
import time
from contextlib import asynccontextmanager
from dataclasses import dataclass, field
from typing import Any, Literal, Optional
from rich.traceback import install
from src.chat.message_receive.chat_stream import get_chat_manager
from src.chat.utils.chat_message_builder import build_readable_messages
from src.common.logger import get_logger
from src.config.config import global_config
from src.chat.utils.chat_message_builder import build_readable_messages
from src.chat.message_receive.chat_stream import get_chat_manager
from src.person_info.person_info import get_person_info_manager
install(extra_lines=3)
@@ -50,11 +51,11 @@ class PromptParameters:
debug_mode: bool = False
# 聊天历史和上下文
chat_target_info: Optional[Dict[str, Any]] = None
message_list_before_now_long: List[Dict[str, Any]] = field(default_factory=list)
message_list_before_short: List[Dict[str, Any]] = field(default_factory=list)
chat_target_info: dict[str, Any] | None = None
message_list_before_now_long: list[dict[str, Any]] = field(default_factory=list)
message_list_before_short: list[dict[str, Any]] = field(default_factory=list)
chat_talking_prompt_short: str = ""
target_user_info: Optional[Dict[str, Any]] = None
target_user_info: dict[str, Any] | None = None
# 已构建的内容块
expression_habits_block: str = ""
@@ -77,12 +78,12 @@ class PromptParameters:
action_descriptions: str = ""
# 可用动作信息
available_actions: Optional[Dict[str, Any]] = None
available_actions: dict[str, Any] | None = None
# 动态生成的聊天场景提示
chat_scene: str = ""
def validate(self) -> List[str]:
def validate(self) -> list[str]:
"""参数验证"""
errors = []
if not self.chat_id:
@@ -98,22 +99,22 @@ class PromptContext:
"""提示词上下文管理器"""
def __init__(self):
self._context_prompts: Dict[str, Dict[str, "Prompt"]] = {}
self._context_prompts: dict[str, dict[str, "Prompt"]] = {}
self._current_context_var = contextvars.ContextVar("current_context", default=None)
self._context_lock = asyncio.Lock()
@property
def _current_context(self) -> Optional[str]:
def _current_context(self) -> str | None:
"""获取当前协程的上下文ID"""
return self._current_context_var.get()
@_current_context.setter
def _current_context(self, value: Optional[str]):
def _current_context(self, value: str | None):
"""设置当前协程的上下文ID"""
self._current_context_var.set(value) # type: ignore
@asynccontextmanager
async def async_scope(self, context_id: Optional[str] = None):
async def async_scope(self, context_id: str | None = None):
"""创建一个异步的临时提示模板作用域"""
if context_id is not None:
try:
@@ -159,7 +160,7 @@ class PromptContext:
return self._context_prompts[current_context][name]
return None
async def register_async(self, prompt: "Prompt", context_id: Optional[str] = None) -> None:
async def register_async(self, prompt: "Prompt", context_id: str | None = None) -> None:
"""异步注册提示模板到指定作用域"""
async with self._context_lock:
if target_context := context_id or self._current_context:
@@ -177,7 +178,7 @@ class PromptManager:
self._lock = asyncio.Lock()
@asynccontextmanager
async def async_message_scope(self, message_id: Optional[str] = None):
async def async_message_scope(self, message_id: str | None = None):
"""为消息处理创建异步临时作用域"""
async with self._context.async_scope(message_id):
yield self
@@ -236,8 +237,8 @@ class Prompt:
def __init__(
self,
template: str,
name: Optional[str] = None,
parameters: Optional[PromptParameters] = None,
name: str | None = None,
parameters: PromptParameters | None = None,
should_register: bool = True,
):
"""
@@ -277,7 +278,7 @@ class Prompt:
"""将临时标记还原为实际的花括号字符"""
return template.replace(Prompt._TEMP_LEFT_BRACE, "{").replace(Prompt._TEMP_RIGHT_BRACE, "}")
def _parse_template_args(self, template: str) -> List[str]:
def _parse_template_args(self, template: str) -> list[str]:
"""解析模板参数"""
template_args = []
processed_template = self._process_escaped_braces(template)
@@ -321,7 +322,7 @@ class Prompt:
logger.error(f"构建Prompt失败: {e}")
raise RuntimeError(f"构建Prompt失败: {e}") from e
async def _build_context_data(self) -> Dict[str, Any]:
async def _build_context_data(self) -> dict[str, Any]:
"""构建智能上下文数据"""
# 并行执行所有构建任务
start_time = time.time()
@@ -401,7 +402,7 @@ class Prompt:
default_result = self._get_default_result_for_task(task_name)
results.append(default_result)
except Exception as e:
logger.error(f"构建任务{task_name}失败: {str(e)}")
logger.error(f"构建任务{task_name}失败: {e!s}")
default_result = self._get_default_result_for_task(task_name)
results.append(default_result)
@@ -411,7 +412,7 @@ class Prompt:
task_name = task_names[i] if i < len(task_names) else f"task_{i}"
if isinstance(result, Exception):
logger.error(f"构建任务{task_name}失败: {str(result)}")
logger.error(f"构建任务{task_name}失败: {result!s}")
elif isinstance(result, dict):
context_data.update(result)
@@ -453,7 +454,7 @@ class Prompt:
return context_data
async def _build_s4u_chat_context(self, context_data: Dict[str, Any]) -> None:
async def _build_s4u_chat_context(self, context_data: dict[str, Any]) -> None:
"""构建S4U模式的聊天上下文"""
if not self.parameters.message_list_before_now_long:
return
@@ -468,7 +469,7 @@ class Prompt:
context_data["read_history_prompt"] = read_history_prompt
context_data["unread_history_prompt"] = unread_history_prompt
async def _build_normal_chat_context(self, context_data: Dict[str, Any]) -> None:
async def _build_normal_chat_context(self, context_data: dict[str, Any]) -> None:
"""构建normal模式的聊天上下文"""
if not self.parameters.chat_talking_prompt_short:
return
@@ -477,8 +478,8 @@ class Prompt:
{self.parameters.chat_talking_prompt_short}"""
async def _build_s4u_chat_history_prompts(
self, message_list_before_now: List[Dict[str, Any]], target_user_id: str, sender: str, chat_id: str
) -> Tuple[str, str]:
self, message_list_before_now: list[dict[str, Any]], target_user_id: str, sender: str, chat_id: str
) -> tuple[str, str]:
"""构建S4U风格的已读/未读历史消息prompt"""
try:
# 动态导入default_generator以避免循环导入
@@ -492,7 +493,7 @@ class Prompt:
except Exception as e:
logger.error(f"构建S4U历史消息prompt失败: {e}")
async def _build_expression_habits(self) -> Dict[str, Any]:
async def _build_expression_habits(self) -> dict[str, Any]:
"""构建表达习惯"""
use_expression, _, _ = global_config.expression.get_expression_config_for_chat(self.parameters.chat_id)
if not use_expression:
@@ -533,7 +534,7 @@ class Prompt:
logger.error(f"构建表达习惯失败: {e}")
return {"expression_habits_block": ""}
async def _build_memory_block(self) -> Dict[str, Any]:
async def _build_memory_block(self) -> dict[str, Any]:
"""构建记忆块"""
if not global_config.memory.enable_memory:
return {"memory_block": ""}
@@ -653,7 +654,7 @@ class Prompt:
logger.error(f"构建记忆块失败: {e}")
return {"memory_block": ""}
async def _build_memory_block_fast(self) -> Dict[str, Any]:
async def _build_memory_block_fast(self) -> dict[str, Any]:
"""快速构建记忆块(简化版本,用于未预构建时的后备方案)"""
if not global_config.memory.enable_memory:
return {"memory_block": ""}
@@ -677,7 +678,7 @@ class Prompt:
logger.warning(f"快速构建记忆块失败: {e}")
return {"memory_block": ""}
async def _build_relation_info(self) -> Dict[str, Any]:
async def _build_relation_info(self) -> dict[str, Any]:
"""构建关系信息"""
try:
relation_info = await Prompt.build_relation_info(self.parameters.chat_id, self.parameters.reply_to)
@@ -686,7 +687,7 @@ class Prompt:
logger.error(f"构建关系信息失败: {e}")
return {"relation_info_block": ""}
async def _build_tool_info(self) -> Dict[str, Any]:
async def _build_tool_info(self) -> dict[str, Any]:
"""构建工具信息"""
if not global_config.tool.enable_tool:
return {"tool_info_block": ""}
@@ -734,7 +735,7 @@ class Prompt:
logger.error(f"构建工具信息失败: {e}")
return {"tool_info_block": ""}
async def _build_knowledge_info(self) -> Dict[str, Any]:
async def _build_knowledge_info(self) -> dict[str, Any]:
"""构建知识信息"""
if not global_config.lpmm_knowledge.enable:
return {"knowledge_prompt": ""}
@@ -783,7 +784,7 @@ class Prompt:
logger.error(f"构建知识信息失败: {e}")
return {"knowledge_prompt": ""}
async def _build_cross_context(self) -> Dict[str, Any]:
async def _build_cross_context(self) -> dict[str, Any]:
"""构建跨群上下文"""
try:
cross_context = await Prompt.build_cross_context(
@@ -794,7 +795,7 @@ class Prompt:
logger.error(f"构建跨群上下文失败: {e}")
return {"cross_context_block": ""}
async def _format_with_context(self, context_data: Dict[str, Any]) -> str:
async def _format_with_context(self, context_data: dict[str, Any]) -> str:
"""使用上下文数据格式化模板"""
if self.parameters.prompt_mode == "s4u":
params = self._prepare_s4u_params(context_data)
@@ -805,7 +806,7 @@ class Prompt:
return await global_prompt_manager.format_prompt(self.name, **params) if self.name else self.format(**params)
def _prepare_s4u_params(self, context_data: Dict[str, Any]) -> Dict[str, Any]:
def _prepare_s4u_params(self, context_data: dict[str, Any]) -> dict[str, Any]:
"""准备S4U模式的参数"""
return {
**context_data,
@@ -834,7 +835,7 @@ class Prompt:
or "你正在一个QQ群里聊天你需要理解整个群的聊天动态和话题走向并做出自然的回应。",
}
def _prepare_normal_params(self, context_data: Dict[str, Any]) -> Dict[str, Any]:
def _prepare_normal_params(self, context_data: dict[str, Any]) -> dict[str, Any]:
"""准备Normal模式的参数"""
return {
**context_data,
@@ -862,7 +863,7 @@ class Prompt:
or "你正在一个QQ群里聊天你需要理解整个群的聊天动态和话题走向并做出自然的回应。",
}
def _prepare_default_params(self, context_data: Dict[str, Any]) -> Dict[str, Any]:
def _prepare_default_params(self, context_data: dict[str, Any]) -> dict[str, Any]:
"""准备默认模式的参数"""
return {
"expression_habits_block": context_data.get("expression_habits_block", ""),
@@ -905,7 +906,7 @@ class Prompt:
result = self._restore_escaped_braces(processed_template)
return result
except (IndexError, KeyError) as e:
raise ValueError(f"格式化模板失败: {self.template}, args={args}, kwargs={kwargs} {str(e)}") from e
raise ValueError(f"格式化模板失败: {self.template}, args={args}, kwargs={kwargs} {e!s}") from e
def __str__(self) -> str:
"""返回格式化后的结果或原始模板"""
@@ -922,7 +923,7 @@ class Prompt:
# =============================================================================
@staticmethod
def parse_reply_target(target_message: str) -> Tuple[str, str]:
def parse_reply_target(target_message: str) -> tuple[str, str]:
"""
解析回复目标消息 - 统一实现
@@ -981,7 +982,7 @@ class Prompt:
return await relationship_fetcher.build_relation_info(person_id, points_num=5)
def _get_default_result_for_task(self, task_name: str) -> Dict[str, Any]:
def _get_default_result_for_task(self, task_name: str) -> dict[str, Any]:
"""
为超时的任务提供默认结果
@@ -1008,7 +1009,7 @@ class Prompt:
return {}
@staticmethod
async def build_cross_context(chat_id: str, prompt_mode: str, target_user_info: Optional[Dict[str, Any]]) -> str:
async def build_cross_context(chat_id: str, prompt_mode: str, target_user_info: dict[str, Any] | None) -> str:
"""
构建跨群聊上下文 - 统一实现
@@ -1071,7 +1072,7 @@ class Prompt:
# 工厂函数
def create_prompt(
template: str, name: Optional[str] = None, parameters: Optional[PromptParameters] = None, **kwargs
template: str, name: str | None = None, parameters: PromptParameters | None = None, **kwargs
) -> Prompt:
"""快速创建Prompt实例的工厂函数"""
if parameters is None:
@@ -1080,7 +1081,7 @@ def create_prompt(
async def create_prompt_async(
template: str, name: Optional[str] = None, parameters: Optional[PromptParameters] = None, **kwargs
template: str, name: str | None = None, parameters: PromptParameters | None = None, **kwargs
) -> Prompt:
"""异步创建Prompt实例"""
prompt = create_prompt(template, name, parameters, **kwargs)

View File

@@ -1,11 +1,11 @@
import asyncio
from collections import defaultdict
from datetime import datetime, timedelta
from typing import Any, Dict, Tuple, List
from typing import Any
from src.common.database.sqlalchemy_database_api import db_get, db_query, db_save
from src.common.database.sqlalchemy_models import LLMUsage, Messages, OnlineTime
from src.common.logger import get_logger
from src.common.database.sqlalchemy_models import OnlineTime, LLMUsage, Messages
from src.common.database.sqlalchemy_database_api import db_query, db_save, db_get
from src.manager.async_task_manager import AsyncTask
from src.manager.local_store_manager import local_storage
@@ -150,7 +150,7 @@ class StatisticOutputTask(AsyncTask):
# 延迟300秒启动运行间隔300秒
super().__init__(task_name="Statistics Data Output Task", wait_before_start=0, run_interval=300)
self.name_mapping: Dict[str, Tuple[str, float]] = {}
self.name_mapping: dict[str, tuple[str, float]] = {}
"""
联系人/群聊名称映射 {聊天ID: (联系人/群聊名称, 记录时间timestamp)}
注:设计记录时间的目的是方便更新名称,使联系人/群聊名称保持最新
@@ -170,7 +170,7 @@ class StatisticOutputTask(AsyncTask):
deploy_time = datetime(2000, 1, 1)
local_storage["deploy_time"] = now.timestamp()
self.stat_period: List[Tuple[str, timedelta, str]] = [
self.stat_period: list[tuple[str, timedelta, str]] = [
("all_time", now - deploy_time, "自部署以来"), # 必须保留"all_time"
("last_7_days", timedelta(days=7), "最近7天"),
("last_24_hours", timedelta(days=1), "最近24小时"),
@@ -181,7 +181,7 @@ class StatisticOutputTask(AsyncTask):
统计时间段 [(统计名称, 统计时间段, 统计描述), ...]
"""
def _statistic_console_output(self, stats: Dict[str, Any], now: datetime):
def _statistic_console_output(self, stats: dict[str, Any], now: datetime):
"""
输出统计数据到控制台
:param stats: 统计数据
@@ -239,7 +239,7 @@ class StatisticOutputTask(AsyncTask):
# -- 以下为统计数据收集方法 --
@staticmethod
async def _collect_model_request_for_period(collect_period: List[Tuple[str, datetime]]) -> Dict[str, Any]:
async def _collect_model_request_for_period(collect_period: list[tuple[str, datetime]]) -> dict[str, Any]:
"""
收集指定时间段的LLM请求统计数据
@@ -393,8 +393,8 @@ class StatisticOutputTask(AsyncTask):
@staticmethod
async def _collect_online_time_for_period(
collect_period: List[Tuple[str, datetime]], now: datetime
) -> Dict[str, Any]:
collect_period: list[tuple[str, datetime]], now: datetime
) -> dict[str, Any]:
"""
收集指定时间段的在线时间统计数据
@@ -452,7 +452,7 @@ class StatisticOutputTask(AsyncTask):
break
return stats
async def _collect_message_count_for_period(self, collect_period: List[Tuple[str, datetime]]) -> Dict[str, Any]:
async def _collect_message_count_for_period(self, collect_period: list[tuple[str, datetime]]) -> dict[str, Any]:
"""
收集指定时间段的消息统计数据
@@ -523,7 +523,7 @@ class StatisticOutputTask(AsyncTask):
break
return stats
async def _collect_all_statistics(self, now: datetime) -> Dict[str, Dict[str, Any]]:
async def _collect_all_statistics(self, now: datetime) -> dict[str, dict[str, Any]]:
"""
收集各时间段的统计数据
:param now: 基准当前时间
@@ -533,7 +533,7 @@ class StatisticOutputTask(AsyncTask):
if "last_full_statistics" in local_storage:
# 如果存在上次完整统计数据,则使用该数据进行增量统计
last_stat: Dict[str, Any] = local_storage["last_full_statistics"] # 上次完整统计数据 # type: ignore
last_stat: dict[str, Any] = local_storage["last_full_statistics"] # 上次完整统计数据 # type: ignore
self.name_mapping = last_stat["name_mapping"] # 上次完整统计数据的名称映射
last_all_time_stat = last_stat["stat_data"] # 上次完整统计的统计数据
@@ -620,7 +620,7 @@ class StatisticOutputTask(AsyncTask):
# -- 以下为统计数据格式化方法 --
@staticmethod
def _format_total_stat(stats: Dict[str, Any]) -> str:
def _format_total_stat(stats: dict[str, Any]) -> str:
"""
格式化总统计数据
"""
@@ -636,7 +636,7 @@ class StatisticOutputTask(AsyncTask):
return "\n".join(output)
@staticmethod
def _format_model_classified_stat(stats: Dict[str, Any]) -> str:
def _format_model_classified_stat(stats: dict[str, Any]) -> str:
"""
格式化按模型分类的统计数据
"""
@@ -662,7 +662,7 @@ class StatisticOutputTask(AsyncTask):
output.append("")
return "\n".join(output)
def _format_chat_stat(self, stats: Dict[str, Any]) -> str:
def _format_chat_stat(self, stats: dict[str, Any]) -> str:
"""
格式化聊天统计数据
"""
@@ -1007,7 +1007,7 @@ class StatisticOutputTask(AsyncTask):
async def _generate_chart_data(self, stat: dict[str, Any]) -> dict:
"""生成图表数据 (异步)"""
now = datetime.now()
chart_data: Dict[str, Any] = {}
chart_data: dict[str, Any] = {}
time_ranges = [
("6h", 6, 10),
@@ -1023,16 +1023,16 @@ class StatisticOutputTask(AsyncTask):
async def _collect_interval_data(self, now: datetime, hours: int, interval_minutes: int) -> dict:
start_time = now - timedelta(hours=hours)
time_points: List[datetime] = []
time_points: list[datetime] = []
current_time = start_time
while current_time <= now:
time_points.append(current_time)
current_time += timedelta(minutes=interval_minutes)
total_cost_data = [0.0] * len(time_points)
cost_by_model: Dict[str, List[float]] = {}
cost_by_module: Dict[str, List[float]] = {}
message_by_chat: Dict[str, List[int]] = {}
cost_by_model: dict[str, list[float]] = {}
cost_by_module: dict[str, list[float]] = {}
message_by_chat: dict[str, list[int]] = {}
time_labels = [t.strftime("%H:%M") for t in time_points]
interval_seconds = interval_minutes * 60

View File

@@ -1,8 +1,8 @@
import asyncio
from time import perf_counter
from collections.abc import Callable
from functools import wraps
from typing import Optional, Dict, Callable
from time import perf_counter
from rich.traceback import install
install(extra_lines=3)
@@ -75,12 +75,12 @@ class Timer:
3. 直接实例化:如果不调用 __enter__打印对象时将显示当前 perf_counter 的值
"""
__slots__ = ("name", "storage", "elapsed", "auto_unit", "start")
__slots__ = ("auto_unit", "elapsed", "name", "start", "storage")
def __init__(
self,
name: Optional[str] = None,
storage: Optional[Dict[str, float]] = None,
name: str | None = None,
storage: dict[str, float] | None = None,
auto_unit: bool = True,
do_type_check: bool = False,
):
@@ -103,7 +103,7 @@ class Timer:
if storage is not None and not isinstance(storage, dict):
raise TimerTypeError("storage", "Optional[dict]", type(storage))
def __call__(self, func: Optional[Callable] = None) -> Callable:
def __call__(self, func: Callable | None = None) -> Callable:
"""装饰器模式"""
if func is None:
return lambda f: Timer(name=self.name or f.__name__, storage=self.storage, auto_unit=self.auto_unit)(f)

Some files were not shown because too many files have changed in this diff Show More