re-style: 格式化代码
This commit is contained in:
committed by
Windpicker-owo
parent
00ba07e0e1
commit
a79253c714
@@ -3,8 +3,8 @@ MaiBot模块系统
|
||||
包含聊天、情绪、记忆、日程等功能模块
|
||||
"""
|
||||
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from src.chat.emoji_system.emoji_manager import get_emoji_manager
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
|
||||
# 导出主要组件供外部使用
|
||||
__all__ = [
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
MaiBot 反注入系统模块
|
||||
|
||||
@@ -14,25 +13,25 @@ MaiBot 反注入系统模块
|
||||
"""
|
||||
|
||||
from .anti_injector import AntiPromptInjector, get_anti_injector, initialize_anti_injector
|
||||
from .types import DetectionResult, ProcessResult
|
||||
from .core import PromptInjectionDetector, MessageShield
|
||||
from .processors.message_processor import MessageProcessor
|
||||
from .management import AntiInjectionStatistics, UserBanManager
|
||||
from .core import MessageShield, PromptInjectionDetector
|
||||
from .decision import CounterAttackGenerator, ProcessingDecisionMaker
|
||||
from .management import AntiInjectionStatistics, UserBanManager
|
||||
from .processors.message_processor import MessageProcessor
|
||||
from .types import DetectionResult, ProcessResult
|
||||
|
||||
__all__ = [
|
||||
"AntiInjectionStatistics",
|
||||
"AntiPromptInjector",
|
||||
"CounterAttackGenerator",
|
||||
"DetectionResult",
|
||||
"MessageProcessor",
|
||||
"MessageShield",
|
||||
"ProcessResult",
|
||||
"ProcessingDecisionMaker",
|
||||
"PromptInjectionDetector",
|
||||
"UserBanManager",
|
||||
"get_anti_injector",
|
||||
"initialize_anti_injector",
|
||||
"DetectionResult",
|
||||
"ProcessResult",
|
||||
"PromptInjectionDetector",
|
||||
"MessageShield",
|
||||
"MessageProcessor",
|
||||
"AntiInjectionStatistics",
|
||||
"UserBanManager",
|
||||
"CounterAttackGenerator",
|
||||
"ProcessingDecisionMaker",
|
||||
]
|
||||
|
||||
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
LLM反注入系统主模块
|
||||
|
||||
@@ -12,15 +11,16 @@ LLM反注入系统主模块
|
||||
"""
|
||||
|
||||
import time
|
||||
from typing import Optional, Tuple, Dict, Any
|
||||
from typing import Any
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from .types import ProcessResult
|
||||
from .core import PromptInjectionDetector, MessageShield
|
||||
from .processors.message_processor import MessageProcessor
|
||||
from .management import AntiInjectionStatistics, UserBanManager
|
||||
|
||||
from .core import MessageShield, PromptInjectionDetector
|
||||
from .decision import CounterAttackGenerator, ProcessingDecisionMaker
|
||||
from .management import AntiInjectionStatistics, UserBanManager
|
||||
from .processors.message_processor import MessageProcessor
|
||||
from .types import ProcessResult
|
||||
|
||||
logger = get_logger("anti_injector")
|
||||
|
||||
@@ -43,7 +43,7 @@ class AntiPromptInjector:
|
||||
|
||||
async def process_message(
|
||||
self, message_data: dict, chat_stream=None
|
||||
) -> Tuple[ProcessResult, Optional[str], Optional[str]]:
|
||||
) -> tuple[ProcessResult, str | None, str | None]:
|
||||
"""处理字典格式的消息并返回结果
|
||||
|
||||
Args:
|
||||
@@ -102,7 +102,7 @@ class AntiPromptInjector:
|
||||
await self.statistics.update_stats(error_count=1)
|
||||
|
||||
# 异常情况下直接阻止消息
|
||||
return ProcessResult.BLOCKED_INJECTION, None, f"反注入系统异常,消息已阻止: {str(e)}"
|
||||
return ProcessResult.BLOCKED_INJECTION, None, f"反注入系统异常,消息已阻止: {e!s}"
|
||||
|
||||
finally:
|
||||
# 更新处理时间统计
|
||||
@@ -111,7 +111,7 @@ class AntiPromptInjector:
|
||||
|
||||
async def _process_message_internal(
|
||||
self, text_to_detect: str, user_id: str, platform: str, processed_plain_text: str, start_time: float
|
||||
) -> Tuple[ProcessResult, Optional[str], Optional[str]]:
|
||||
) -> tuple[ProcessResult, str | None, str | None]:
|
||||
"""内部消息处理逻辑(共用的检测核心)"""
|
||||
|
||||
# 如果是纯引用消息,直接允许通过
|
||||
@@ -218,7 +218,7 @@ class AntiPromptInjector:
|
||||
return ProcessResult.ALLOWED, None, "消息检查通过"
|
||||
|
||||
async def handle_message_storage(
|
||||
self, result: ProcessResult, modified_content: Optional[str], reason: str, message_data: dict
|
||||
self, result: ProcessResult, modified_content: str | None, reason: str, message_data: dict
|
||||
) -> None:
|
||||
"""处理违禁消息的数据库存储,根据处理模式决定如何处理"""
|
||||
if result == ProcessResult.BLOCKED_INJECTION or result == ProcessResult.COUNTER_ATTACK:
|
||||
@@ -253,9 +253,10 @@ class AntiPromptInjector:
|
||||
async def _delete_message_from_storage(message_data: dict) -> None:
|
||||
"""从数据库中删除违禁消息记录"""
|
||||
try:
|
||||
from src.common.database.sqlalchemy_models import Messages, get_db_session
|
||||
from sqlalchemy import delete
|
||||
|
||||
from src.common.database.sqlalchemy_models import Messages, get_db_session
|
||||
|
||||
message_id = message_data.get("message_id")
|
||||
if not message_id:
|
||||
logger.warning("无法删除消息:缺少message_id")
|
||||
@@ -279,9 +280,10 @@ class AntiPromptInjector:
|
||||
async def _update_message_in_storage(message_data: dict, new_content: str) -> None:
|
||||
"""更新数据库中的消息内容为加盾版本"""
|
||||
try:
|
||||
from src.common.database.sqlalchemy_models import Messages, get_db_session
|
||||
from sqlalchemy import update
|
||||
|
||||
from src.common.database.sqlalchemy_models import Messages, get_db_session
|
||||
|
||||
message_id = message_data.get("message_id")
|
||||
if not message_id:
|
||||
logger.warning("无法更新消息:缺少message_id")
|
||||
@@ -305,7 +307,7 @@ class AntiPromptInjector:
|
||||
except Exception as e:
|
||||
logger.error(f"更新消息内容失败: {e}")
|
||||
|
||||
async def get_stats(self) -> Dict[str, Any]:
|
||||
async def get_stats(self) -> dict[str, Any]:
|
||||
"""获取统计信息"""
|
||||
return await self.statistics.get_stats()
|
||||
|
||||
@@ -315,7 +317,7 @@ class AntiPromptInjector:
|
||||
|
||||
|
||||
# 全局反注入器实例
|
||||
_global_injector: Optional[AntiPromptInjector] = None
|
||||
_global_injector: AntiPromptInjector | None = None
|
||||
|
||||
|
||||
def get_anti_injector() -> AntiPromptInjector:
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
反注入系统核心检测模块
|
||||
|
||||
@@ -10,4 +9,4 @@
|
||||
from .detector import PromptInjectionDetector
|
||||
from .shield import MessageShield
|
||||
|
||||
__all__ = ["PromptInjectionDetector", "MessageShield"]
|
||||
__all__ = ["MessageShield", "PromptInjectionDetector"]
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
提示词注入检测器模块
|
||||
|
||||
@@ -8,19 +7,19 @@
|
||||
3. 缓存机制优化性能
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import re
|
||||
import time
|
||||
import hashlib
|
||||
from typing import Dict, List
|
||||
from dataclasses import asdict
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from ..types import DetectionResult
|
||||
|
||||
# 导入LLM API
|
||||
from src.plugin_system.apis import llm_api
|
||||
|
||||
from ..types import DetectionResult
|
||||
|
||||
logger = get_logger("anti_injector.detector")
|
||||
|
||||
|
||||
@@ -30,8 +29,8 @@ class PromptInjectionDetector:
|
||||
def __init__(self):
|
||||
"""初始化检测器"""
|
||||
self.config = global_config.anti_prompt_injection
|
||||
self._cache: Dict[str, DetectionResult] = {}
|
||||
self._compiled_patterns: List[re.Pattern] = []
|
||||
self._cache: dict[str, DetectionResult] = {}
|
||||
self._compiled_patterns: list[re.Pattern] = []
|
||||
self._compile_patterns()
|
||||
|
||||
def _compile_patterns(self):
|
||||
@@ -224,7 +223,7 @@ class PromptInjectionDetector:
|
||||
matched_patterns=[],
|
||||
processing_time=processing_time,
|
||||
detection_method="llm",
|
||||
reason=f"LLM检测出错: {str(e)}",
|
||||
reason=f"LLM检测出错: {e!s}",
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@@ -250,7 +249,7 @@ class PromptInjectionDetector:
|
||||
请客观分析,避免误判正常对话。"""
|
||||
|
||||
@staticmethod
|
||||
def _parse_llm_response(response: str) -> Dict:
|
||||
def _parse_llm_response(response: str) -> dict:
|
||||
"""解析LLM响应"""
|
||||
try:
|
||||
lines = response.strip().split("\n")
|
||||
@@ -280,7 +279,7 @@ class PromptInjectionDetector:
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"解析LLM响应失败: {e}")
|
||||
return {"is_injection": False, "confidence": 0.0, "reasoning": f"解析失败: {str(e)}"}
|
||||
return {"is_injection": False, "confidence": 0.0, "reasoning": f"解析失败: {e!s}"}
|
||||
|
||||
async def detect(self, message: str) -> DetectionResult:
|
||||
"""执行检测"""
|
||||
@@ -331,7 +330,7 @@ class PromptInjectionDetector:
|
||||
|
||||
return final_result
|
||||
|
||||
def _merge_results(self, results: List[DetectionResult]) -> DetectionResult:
|
||||
def _merge_results(self, results: list[DetectionResult]) -> DetectionResult:
|
||||
"""合并多个检测结果"""
|
||||
if not results:
|
||||
return DetectionResult(reason="无检测结果")
|
||||
@@ -384,7 +383,7 @@ class PromptInjectionDetector:
|
||||
if expired_keys:
|
||||
logger.debug(f"清理了{len(expired_keys)}个过期缓存项")
|
||||
|
||||
def get_cache_stats(self) -> Dict:
|
||||
def get_cache_stats(self) -> dict:
|
||||
"""获取缓存统计信息"""
|
||||
return {
|
||||
"cache_size": len(self._cache),
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
消息加盾模块
|
||||
|
||||
@@ -6,8 +5,6 @@
|
||||
主要通过注入系统提示词来指导AI安全响应。
|
||||
"""
|
||||
|
||||
from typing import List
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
|
||||
@@ -35,7 +32,7 @@ class MessageShield:
|
||||
return SAFETY_SYSTEM_PROMPT
|
||||
|
||||
@staticmethod
|
||||
def is_shield_needed(confidence: float, matched_patterns: List[str]) -> bool:
|
||||
def is_shield_needed(confidence: float, matched_patterns: list[str]) -> bool:
|
||||
"""判断是否需要加盾
|
||||
|
||||
Args:
|
||||
@@ -60,7 +57,7 @@ class MessageShield:
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def create_safety_summary(confidence: float, matched_patterns: List[str]) -> str:
|
||||
def create_safety_summary(confidence: float, matched_patterns: list[str]) -> str:
|
||||
"""创建安全处理摘要
|
||||
|
||||
Args:
|
||||
|
||||
@@ -1,15 +1,13 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
反击消息生成模块
|
||||
|
||||
负责生成个性化的反击消息回应提示词注入攻击
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.plugin_system.apis import llm_api
|
||||
|
||||
from .types import DetectionResult
|
||||
|
||||
logger = get_logger("anti_injector.counter_attack")
|
||||
@@ -55,7 +53,7 @@ class CounterAttackGenerator:
|
||||
|
||||
async def generate_counter_attack_message(
|
||||
self, original_message: str, detection_result: DetectionResult
|
||||
) -> Optional[str]:
|
||||
) -> str | None:
|
||||
"""生成反击消息
|
||||
|
||||
Args:
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
反注入系统决策模块
|
||||
|
||||
@@ -7,7 +6,7 @@
|
||||
- counter_attack: 反击消息生成器
|
||||
"""
|
||||
|
||||
from .decision_maker import ProcessingDecisionMaker
|
||||
from .counter_attack import CounterAttackGenerator
|
||||
from .decision_maker import ProcessingDecisionMaker
|
||||
|
||||
__all__ = ["ProcessingDecisionMaker", "CounterAttackGenerator"]
|
||||
__all__ = ["CounterAttackGenerator", "ProcessingDecisionMaker"]
|
||||
|
||||
@@ -1,15 +1,13 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
反击消息生成模块
|
||||
|
||||
负责生成个性化的反击消息回应提示词注入攻击
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.plugin_system.apis import llm_api
|
||||
|
||||
from ..types import DetectionResult
|
||||
|
||||
logger = get_logger("anti_injector.counter_attack")
|
||||
@@ -55,7 +53,7 @@ class CounterAttackGenerator:
|
||||
|
||||
async def generate_counter_attack_message(
|
||||
self, original_message: str, detection_result: DetectionResult
|
||||
) -> Optional[str]:
|
||||
) -> str | None:
|
||||
"""生成反击消息
|
||||
|
||||
Args:
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
处理决策器模块
|
||||
|
||||
@@ -6,6 +5,7 @@
|
||||
"""
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
from ..types import DetectionResult
|
||||
|
||||
logger = get_logger("anti_injector.decision_maker")
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
处理决策器模块
|
||||
|
||||
@@ -6,6 +5,7 @@
|
||||
"""
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
from .types import DetectionResult
|
||||
|
||||
logger = get_logger("anti_injector.decision_maker")
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
提示词注入检测器模块
|
||||
|
||||
@@ -8,19 +7,19 @@
|
||||
3. 缓存机制优化性能
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import re
|
||||
import time
|
||||
import hashlib
|
||||
from typing import Dict, List
|
||||
from dataclasses import asdict
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from .types import DetectionResult
|
||||
|
||||
# 导入LLM API
|
||||
from src.plugin_system.apis import llm_api
|
||||
|
||||
from .types import DetectionResult
|
||||
|
||||
logger = get_logger("anti_injector.detector")
|
||||
|
||||
|
||||
@@ -30,8 +29,8 @@ class PromptInjectionDetector:
|
||||
def __init__(self):
|
||||
"""初始化检测器"""
|
||||
self.config = global_config.anti_prompt_injection
|
||||
self._cache: Dict[str, DetectionResult] = {}
|
||||
self._compiled_patterns: List[re.Pattern] = []
|
||||
self._cache: dict[str, DetectionResult] = {}
|
||||
self._compiled_patterns: list[re.Pattern] = []
|
||||
self._compile_patterns()
|
||||
|
||||
def _compile_patterns(self):
|
||||
@@ -221,7 +220,7 @@ class PromptInjectionDetector:
|
||||
matched_patterns=[],
|
||||
processing_time=processing_time,
|
||||
detection_method="llm",
|
||||
reason=f"LLM检测出错: {str(e)}",
|
||||
reason=f"LLM检测出错: {e!s}",
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@@ -247,7 +246,7 @@ class PromptInjectionDetector:
|
||||
请客观分析,避免误判正常对话。"""
|
||||
|
||||
@staticmethod
|
||||
def _parse_llm_response(response: str) -> Dict:
|
||||
def _parse_llm_response(response: str) -> dict:
|
||||
"""解析LLM响应"""
|
||||
try:
|
||||
lines = response.strip().split("\n")
|
||||
@@ -277,7 +276,7 @@ class PromptInjectionDetector:
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"解析LLM响应失败: {e}")
|
||||
return {"is_injection": False, "confidence": 0.0, "reasoning": f"解析失败: {str(e)}"}
|
||||
return {"is_injection": False, "confidence": 0.0, "reasoning": f"解析失败: {e!s}"}
|
||||
|
||||
async def detect(self, message: str) -> DetectionResult:
|
||||
"""执行检测"""
|
||||
@@ -328,7 +327,7 @@ class PromptInjectionDetector:
|
||||
|
||||
return final_result
|
||||
|
||||
def _merge_results(self, results: List[DetectionResult]) -> DetectionResult:
|
||||
def _merge_results(self, results: list[DetectionResult]) -> DetectionResult:
|
||||
"""合并多个检测结果"""
|
||||
if not results:
|
||||
return DetectionResult(reason="无检测结果")
|
||||
@@ -381,7 +380,7 @@ class PromptInjectionDetector:
|
||||
if expired_keys:
|
||||
logger.debug(f"清理了{len(expired_keys)}个过期缓存项")
|
||||
|
||||
def get_cache_stats(self) -> Dict:
|
||||
def get_cache_stats(self) -> dict:
|
||||
"""获取缓存统计信息"""
|
||||
return {
|
||||
"cache_size": len(self._cache),
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
反注入系统管理模块
|
||||
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
反注入系统统计模块
|
||||
|
||||
@@ -6,12 +5,12 @@
|
||||
"""
|
||||
|
||||
import datetime
|
||||
from typing import Dict, Any
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.common.database.sqlalchemy_models import AntiInjectionStats, get_db_session
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
|
||||
logger = get_logger("anti_injector.statistics")
|
||||
@@ -94,7 +93,7 @@ class AntiInjectionStatistics:
|
||||
except Exception as e:
|
||||
logger.error(f"更新统计数据失败: {e}")
|
||||
|
||||
async def get_stats(self) -> Dict[str, Any]:
|
||||
async def get_stats(self) -> dict[str, Any]:
|
||||
"""获取统计信息"""
|
||||
try:
|
||||
# 检查反注入系统是否启用
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
用户封禁管理模块
|
||||
|
||||
@@ -6,12 +5,12 @@
|
||||
"""
|
||||
|
||||
import datetime
|
||||
from typing import Optional, Tuple
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.common.database.sqlalchemy_models import BanUser, get_db_session
|
||||
from src.common.logger import get_logger
|
||||
|
||||
from ..types import DetectionResult
|
||||
|
||||
logger = get_logger("anti_injector.user_ban")
|
||||
@@ -28,7 +27,7 @@ class UserBanManager:
|
||||
"""
|
||||
self.config = config
|
||||
|
||||
async def check_user_ban(self, user_id: str, platform: str) -> Optional[Tuple[bool, Optional[str], str]]:
|
||||
async def check_user_ban(self, user_id: str, platform: str) -> tuple[bool, str | None, str] | None:
|
||||
"""检查用户是否被封禁
|
||||
|
||||
Args:
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
反注入系统消息处理模块
|
||||
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
消息内容处理模块
|
||||
|
||||
@@ -6,10 +5,9 @@
|
||||
"""
|
||||
|
||||
import re
|
||||
from typing import Optional
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.message_receive.message import MessageRecv
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("anti_injector.message_processor")
|
||||
|
||||
@@ -66,7 +64,7 @@ class MessageProcessor:
|
||||
return new_content
|
||||
|
||||
@staticmethod
|
||||
def check_whitelist(message: MessageRecv, whitelist: list) -> Optional[tuple]:
|
||||
def check_whitelist(message: MessageRecv, whitelist: list) -> tuple | None:
|
||||
"""检查用户白名单
|
||||
|
||||
Args:
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
反注入系统数据类型定义模块
|
||||
|
||||
@@ -10,7 +9,6 @@
|
||||
"""
|
||||
|
||||
import time
|
||||
from typing import List, Optional
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
|
||||
@@ -31,8 +29,8 @@ class DetectionResult:
|
||||
|
||||
is_injection: bool = False
|
||||
confidence: float = 0.0
|
||||
matched_patterns: List[str] = field(default_factory=list)
|
||||
llm_analysis: Optional[str] = None
|
||||
matched_patterns: list[str] = field(default_factory=list)
|
||||
llm_analysis: str | None = None
|
||||
processing_time: float = 0.0
|
||||
detection_method: str = "unknown"
|
||||
reason: str = ""
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
from typing import Dict, List, Optional, Any
|
||||
import time
|
||||
from src.plugin_system.base.base_chatter import BaseChatter
|
||||
from src.common.data_models.message_manager_data_model import StreamContext
|
||||
from typing import Any
|
||||
|
||||
from src.chat.planner_actions.action_manager import ChatterActionManager
|
||||
from src.plugin_system.base.component_types import ChatType
|
||||
from src.common.data_models.message_manager_data_model import StreamContext
|
||||
from src.common.logger import get_logger
|
||||
from src.plugin_system.base.base_chatter import BaseChatter
|
||||
from src.plugin_system.base.component_types import ChatType
|
||||
|
||||
logger = get_logger("chatter_manager")
|
||||
|
||||
@@ -12,8 +13,8 @@ logger = get_logger("chatter_manager")
|
||||
class ChatterManager:
|
||||
def __init__(self, action_manager: ChatterActionManager):
|
||||
self.action_manager = action_manager
|
||||
self.chatter_classes: Dict[ChatType, List[type]] = {}
|
||||
self.instances: Dict[str, BaseChatter] = {}
|
||||
self.chatter_classes: dict[ChatType, list[type]] = {}
|
||||
self.instances: dict[str, BaseChatter] = {}
|
||||
|
||||
# 管理器统计
|
||||
self.stats = {
|
||||
@@ -46,21 +47,21 @@ class ChatterManager:
|
||||
|
||||
self.stats["chatters_registered"] += 1
|
||||
|
||||
def get_chatter_class(self, chat_type: ChatType) -> Optional[type]:
|
||||
def get_chatter_class(self, chat_type: ChatType) -> type | None:
|
||||
"""获取指定聊天类型的聊天处理器类"""
|
||||
if chat_type in self.chatter_classes:
|
||||
return self.chatter_classes[chat_type][0]
|
||||
return None
|
||||
|
||||
def get_supported_chat_types(self) -> List[ChatType]:
|
||||
def get_supported_chat_types(self) -> list[ChatType]:
|
||||
"""获取支持的聊天类型列表"""
|
||||
return list(self.chatter_classes.keys())
|
||||
|
||||
def get_registered_chatters(self) -> Dict[ChatType, List[type]]:
|
||||
def get_registered_chatters(self) -> dict[ChatType, list[type]]:
|
||||
"""获取已注册的聊天处理器"""
|
||||
return self.chatter_classes.copy()
|
||||
|
||||
def get_stream_instance(self, stream_id: str) -> Optional[BaseChatter]:
|
||||
def get_stream_instance(self, stream_id: str) -> BaseChatter | None:
|
||||
"""获取指定流的聊天处理器实例"""
|
||||
return self.instances.get(stream_id)
|
||||
|
||||
@@ -139,7 +140,7 @@ class ChatterManager:
|
||||
logger.error(f"处理流 {stream_id} 时发生错误: {e}")
|
||||
raise
|
||||
|
||||
def get_stats(self) -> Dict[str, Any]:
|
||||
def get_stats(self) -> dict[str, Any]:
|
||||
"""获取管理器统计信息"""
|
||||
stats = self.stats.copy()
|
||||
stats["active_instances"] = len(self.instances)
|
||||
|
||||
@@ -1,9 +1,7 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
表情包发送历史记录模块
|
||||
"""
|
||||
|
||||
from typing import List, Dict
|
||||
from collections import deque
|
||||
from typing import List, Dict
|
||||
|
||||
@@ -15,7 +13,7 @@ MAX_HISTORY_SIZE = 5 # 每个聊天会话最多保留最近5条表情历史
|
||||
|
||||
# 使用一个全局字典在内存中存储历史记录
|
||||
# 键是 chat_id,值是一个 deque 对象
|
||||
_history_cache: Dict[str, deque] = {}
|
||||
_history_cache: dict[str, deque] = {}
|
||||
|
||||
|
||||
def add_emoji_to_history(chat_id: str, emoji_description: str):
|
||||
@@ -39,7 +37,7 @@ def add_emoji_to_history(chat_id: str, emoji_description: str):
|
||||
logger.debug(f"已将表情 '{emoji_description}' 添加到聊天 {chat_id} 的内存历史中")
|
||||
|
||||
|
||||
def get_recent_emojis(chat_id: str, limit: int = 5) -> List[str]:
|
||||
def get_recent_emojis(chat_id: str, limit: int = 5) -> list[str]:
|
||||
"""
|
||||
从内存中获取最近发送的表情包描述列表。
|
||||
|
||||
|
||||
@@ -1,23 +1,24 @@
|
||||
import asyncio
|
||||
import base64
|
||||
import binascii
|
||||
import hashlib
|
||||
import io
|
||||
import os
|
||||
import random
|
||||
import re
|
||||
import time
|
||||
import traceback
|
||||
import io
|
||||
import re
|
||||
import binascii
|
||||
from typing import Any, Optional
|
||||
|
||||
from typing import Optional, Tuple, List, Any
|
||||
from PIL import Image
|
||||
from rich.traceback import install
|
||||
from sqlalchemy import select
|
||||
|
||||
from src.chat.utils.utils_image import get_image_manager, image_path_to_base64
|
||||
from src.common.database.sqlalchemy_database_api import get_db_session
|
||||
from src.common.database.sqlalchemy_models import Emoji, Images
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config, model_config
|
||||
from src.chat.utils.utils_image import image_path_to_base64, get_image_manager
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
|
||||
install(extra_lines=3)
|
||||
@@ -47,14 +48,14 @@ class MaiEmoji:
|
||||
self.embedding = []
|
||||
self.hash = "" # 初始为空,在创建实例时会计算
|
||||
self.description = ""
|
||||
self.emotion: List[str] = []
|
||||
self.emotion: list[str] = []
|
||||
self.usage_count = 0
|
||||
self.last_used_time = time.time()
|
||||
self.register_time = time.time()
|
||||
self.is_deleted = False # 标记是否已被删除
|
||||
self.format = ""
|
||||
|
||||
async def initialize_hash_format(self) -> Optional[bool]:
|
||||
async def initialize_hash_format(self) -> bool | None:
|
||||
"""从文件创建表情包实例, 计算哈希值和格式"""
|
||||
try:
|
||||
# 使用 full_path 检查文件是否存在
|
||||
@@ -105,7 +106,7 @@ class MaiEmoji:
|
||||
self.is_deleted = True
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"[初始化错误] 初始化表情包时发生未预期错误 ({self.filename}): {str(e)}")
|
||||
logger.error(f"[初始化错误] 初始化表情包时发生未预期错误 ({self.filename}): {e!s}")
|
||||
logger.error(traceback.format_exc())
|
||||
self.is_deleted = True
|
||||
return None
|
||||
@@ -142,7 +143,7 @@ class MaiEmoji:
|
||||
self.path = EMOJI_REGISTERED_DIR
|
||||
# self.filename 保持不变
|
||||
except Exception as move_error:
|
||||
logger.error(f"[错误] 移动文件失败: {str(move_error)}")
|
||||
logger.error(f"[错误] 移动文件失败: {move_error!s}")
|
||||
# 如果移动失败,尝试将实例状态恢复?暂时不处理,仅返回失败
|
||||
return False
|
||||
|
||||
@@ -174,11 +175,11 @@ class MaiEmoji:
|
||||
return True
|
||||
|
||||
except Exception as db_error:
|
||||
logger.error(f"[错误] 保存数据库失败 ({self.filename}): {str(db_error)}")
|
||||
logger.error(f"[错误] 保存数据库失败 ({self.filename}): {db_error!s}")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[错误] 注册表情包失败 ({self.filename}): {str(e)}")
|
||||
logger.error(f"[错误] 注册表情包失败 ({self.filename}): {e!s}")
|
||||
logger.error(traceback.format_exc())
|
||||
return False
|
||||
|
||||
@@ -198,7 +199,7 @@ class MaiEmoji:
|
||||
os.remove(file_to_delete)
|
||||
logger.debug(f"[删除] 文件: {file_to_delete}")
|
||||
except Exception as e:
|
||||
logger.error(f"[错误] 删除文件失败 {file_to_delete}: {str(e)}")
|
||||
logger.error(f"[错误] 删除文件失败 {file_to_delete}: {e!s}")
|
||||
# 文件删除失败,但仍然尝试删除数据库记录
|
||||
|
||||
# 2. 删除数据库记录
|
||||
@@ -214,7 +215,7 @@ class MaiEmoji:
|
||||
result = 1 # Successfully deleted one record
|
||||
await session.commit()
|
||||
except Exception as e:
|
||||
logger.error(f"[错误] 删除数据库记录时出错: {str(e)}")
|
||||
logger.error(f"[错误] 删除数据库记录时出错: {e!s}")
|
||||
result = 0
|
||||
|
||||
if result > 0:
|
||||
@@ -233,11 +234,11 @@ class MaiEmoji:
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[错误] 删除表情包失败 ({self.filename}): {str(e)}")
|
||||
logger.error(f"[错误] 删除表情包失败 ({self.filename}): {e!s}")
|
||||
return False
|
||||
|
||||
|
||||
def _emoji_objects_to_readable_list(emoji_objects: List["MaiEmoji"]) -> List[str]:
|
||||
def _emoji_objects_to_readable_list(emoji_objects: list["MaiEmoji"]) -> list[str]:
|
||||
"""将表情包对象列表转换为可读的字符串列表
|
||||
|
||||
参数:
|
||||
@@ -256,7 +257,7 @@ def _emoji_objects_to_readable_list(emoji_objects: List["MaiEmoji"]) -> List[str
|
||||
return emoji_info_list
|
||||
|
||||
|
||||
def _to_emoji_objects(data: Any) -> Tuple[List["MaiEmoji"], int]:
|
||||
def _to_emoji_objects(data: Any) -> tuple[list["MaiEmoji"], int]:
|
||||
emoji_objects = []
|
||||
load_errors = 0
|
||||
emoji_data_list = list(data)
|
||||
@@ -300,7 +301,7 @@ def _to_emoji_objects(data: Any) -> Tuple[List["MaiEmoji"], int]:
|
||||
logger.error(f"[加载错误] 初始化 MaiEmoji 失败 ({full_path}): {ve}")
|
||||
load_errors += 1
|
||||
except Exception as e:
|
||||
logger.error(f"[加载错误] 处理数据库记录时出错 ({full_path}): {str(e)}")
|
||||
logger.error(f"[加载错误] 处理数据库记录时出错 ({full_path}): {e!s}")
|
||||
load_errors += 1
|
||||
return emoji_objects, load_errors
|
||||
|
||||
@@ -335,7 +336,7 @@ async def clear_temp_emoji() -> None:
|
||||
logger.debug(f"[清理] 删除: {filename}")
|
||||
|
||||
|
||||
async def clean_unused_emojis(emoji_dir: str, emoji_objects: List["MaiEmoji"], removed_count: int) -> int:
|
||||
async def clean_unused_emojis(emoji_dir: str, emoji_objects: list["MaiEmoji"], removed_count: int) -> int:
|
||||
"""清理指定目录中未被 emoji_objects 追踪的表情包文件"""
|
||||
if not os.path.exists(emoji_dir):
|
||||
logger.warning(f"[清理] 目标目录不存在,跳过清理: {emoji_dir}")
|
||||
@@ -361,7 +362,7 @@ async def clean_unused_emojis(emoji_dir: str, emoji_objects: List["MaiEmoji"], r
|
||||
logger.info(f"[清理] 删除未追踪的表情包文件: {file_full_path}")
|
||||
cleaned_count += 1
|
||||
except Exception as e:
|
||||
logger.error(f"[错误] 删除文件时出错 ({file_full_path}): {str(e)}")
|
||||
logger.error(f"[错误] 删除文件时出错 ({file_full_path}): {e!s}")
|
||||
|
||||
if cleaned_count > 0:
|
||||
logger.info(f"[清理] 在目录 {emoji_dir} 中清理了 {cleaned_count} 个破损表情包。")
|
||||
@@ -369,7 +370,7 @@ async def clean_unused_emojis(emoji_dir: str, emoji_objects: List["MaiEmoji"], r
|
||||
logger.info(f"[清理] 目录 {emoji_dir} 中没有需要清理的。")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[错误] 清理未使用表情包文件时出错 ({emoji_dir}): {str(e)}")
|
||||
logger.error(f"[错误] 清理未使用表情包文件时出错 ({emoji_dir}): {e!s}")
|
||||
|
||||
return removed_count + cleaned_count
|
||||
|
||||
@@ -437,9 +438,9 @@ class EmojiManager:
|
||||
emoji_update.last_used_time = time.time() # Update last used time
|
||||
await session.commit()
|
||||
except Exception as e:
|
||||
logger.error(f"记录表情使用失败: {str(e)}")
|
||||
logger.error(f"记录表情使用失败: {e!s}")
|
||||
|
||||
async def get_emoji_for_text(self, text_emotion: str) -> Optional[Tuple[str, str, str]]:
|
||||
async def get_emoji_for_text(self, text_emotion: str) -> tuple[str, str, str] | None:
|
||||
"""
|
||||
根据文本内容,使用LLM选择一个合适的表情包。
|
||||
|
||||
@@ -531,7 +532,7 @@ class EmojiManager:
|
||||
return selected_emoji.full_path, f"[表情包:{selected_emoji.description}]", text_emotion
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"使用LLM获取表情包时发生错误: {str(e)}")
|
||||
logger.error(f"使用LLM获取表情包时发生错误: {e!s}")
|
||||
logger.error(traceback.format_exc())
|
||||
return None
|
||||
|
||||
@@ -578,7 +579,7 @@ class EmojiManager:
|
||||
continue
|
||||
|
||||
except Exception as item_error:
|
||||
logger.error(f"[错误] 处理表情包记录时出错 ({emoji.filename}): {str(item_error)}")
|
||||
logger.error(f"[错误] 处理表情包记录时出错 ({emoji.filename}): {item_error!s}")
|
||||
# 即使出错,也尝试继续检查下一个
|
||||
continue
|
||||
|
||||
@@ -597,7 +598,7 @@ class EmojiManager:
|
||||
logger.info(f"[检查] 已检查 {total_count} 个表情包记录,全部完好")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[错误] 检查表情包完整性失败: {str(e)}")
|
||||
logger.error(f"[错误] 检查表情包完整性失败: {e!s}")
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
async def start_periodic_check_register(self) -> None:
|
||||
@@ -651,7 +652,7 @@ class EmojiManager:
|
||||
os.remove(file_path)
|
||||
logger.warning(f"[清理] 删除注册失败的表情包文件: {filename}")
|
||||
except Exception as e:
|
||||
logger.error(f"[错误] 扫描表情包目录失败: {str(e)}")
|
||||
logger.error(f"[错误] 扫描表情包目录失败: {e!s}")
|
||||
|
||||
await asyncio.sleep(global_config.emoji.check_interval * 60)
|
||||
|
||||
@@ -674,12 +675,11 @@ class EmojiManager:
|
||||
logger.warning(f"[数据库] 加载过程中出现 {load_errors} 个错误。")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[错误] 从数据库加载所有表情包对象失败: {str(e)}")
|
||||
logger.error(f"[错误] 从数据库加载所有表情包对象失败: {e!s}")
|
||||
self.emoji_objects = [] # 加载失败则清空列表
|
||||
self.emoji_num = 0
|
||||
|
||||
@staticmethod
|
||||
async def get_emoji_from_db(emoji_hash: Optional[str] = None) -> List["MaiEmoji"]:
|
||||
async def get_emoji_from_db(self, emoji_hash: str | None = None) -> list["MaiEmoji"]:
|
||||
"""获取指定哈希值的表情包并初始化为MaiEmoji类对象列表 (主要用于调试或特定查找)
|
||||
|
||||
参数:
|
||||
@@ -708,7 +708,7 @@ class EmojiManager:
|
||||
return emoji_objects
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[错误] 从数据库获取表情包对象失败: {str(e)}")
|
||||
logger.error(f"[错误] 从数据库获取表情包对象失败: {e!s}")
|
||||
return []
|
||||
|
||||
async def get_emoji_from_manager(self, emoji_hash: str) -> Optional["MaiEmoji"]:
|
||||
@@ -726,7 +726,7 @@ class EmojiManager:
|
||||
return emoji
|
||||
return None # 如果循环结束还没找到,则返回 None
|
||||
|
||||
async def get_emoji_tag_by_hash(self, emoji_hash: str) -> Optional[str]:
|
||||
async def get_emoji_tag_by_hash(self, emoji_hash: str) -> str | None:
|
||||
"""根据哈希值获取已注册表情包的描述
|
||||
|
||||
Args:
|
||||
@@ -754,10 +754,10 @@ class EmojiManager:
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取表情包描述失败 (Hash: {emoji_hash}): {str(e)}")
|
||||
logger.error(f"获取表情包描述失败 (Hash: {emoji_hash}): {e!s}")
|
||||
return None
|
||||
|
||||
async def get_emoji_description_by_hash(self, emoji_hash: str) -> Optional[str]:
|
||||
async def get_emoji_description_by_hash(self, emoji_hash: str) -> str | None:
|
||||
"""根据哈希值获取已注册表情包的描述
|
||||
|
||||
Args:
|
||||
@@ -788,7 +788,7 @@ class EmojiManager:
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取表情包描述失败 (Hash: {emoji_hash}): {str(e)}")
|
||||
logger.error(f"获取表情包描述失败 (Hash: {emoji_hash}): {e!s}")
|
||||
return None
|
||||
|
||||
async def delete_emoji(self, emoji_hash: str) -> bool:
|
||||
@@ -824,7 +824,7 @@ class EmojiManager:
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[错误] 删除表情包失败: {str(e)}")
|
||||
logger.error(f"[错误] 删除表情包失败: {e!s}")
|
||||
logger.error(traceback.format_exc())
|
||||
return False
|
||||
|
||||
@@ -910,11 +910,11 @@ class EmojiManager:
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[错误] 替换表情包失败: {str(e)}")
|
||||
logger.error(f"[错误] 替换表情包失败: {e!s}")
|
||||
logger.error(traceback.format_exc())
|
||||
return False
|
||||
|
||||
async def build_emoji_description(self, image_base64: str) -> Tuple[str, List[str]]:
|
||||
async def build_emoji_description(self, image_base64: str) -> tuple[str, list[str]]:
|
||||
"""
|
||||
获取表情包的详细描述和情感关键词列表。
|
||||
|
||||
@@ -977,14 +977,14 @@ class EmojiManager:
|
||||
|
||||
# 4. 内容审核,确保表情包符合规定
|
||||
if global_config.emoji.content_filtration:
|
||||
prompt = f'''
|
||||
prompt = f"""
|
||||
请根据以下标准审核这个表情包:
|
||||
1. 主题必须符合:"{global_config.emoji.filtration_prompt}"。
|
||||
2. 内容健康,不含色情、暴力、政治敏感等元素。
|
||||
3. 必须是表情包,而不是普通的聊天截图或视频截图。
|
||||
4. 表情包中的文字数量(如果有)不能超过5个。
|
||||
这个表情包是否完全满足以上所有要求?请只回答“是”或“否”。
|
||||
'''
|
||||
"""
|
||||
content, _ = await self.vlm.generate_response_for_image(
|
||||
prompt, image_base64, image_format, temperature=0.1, max_tokens=10
|
||||
)
|
||||
@@ -1024,7 +1024,7 @@ class EmojiManager:
|
||||
return final_description, emotions
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"构建表情包描述时发生严重错误: {str(e)}")
|
||||
logger.error(f"构建表情包描述时发生严重错误: {e!s}")
|
||||
logger.error(traceback.format_exc())
|
||||
return "", []
|
||||
|
||||
@@ -1059,7 +1059,7 @@ class EmojiManager:
|
||||
os.remove(file_full_path)
|
||||
logger.info(f"[清理] 删除重复的待注册文件: {filename}")
|
||||
except Exception as e:
|
||||
logger.error(f"[错误] 删除重复文件失败: {str(e)}")
|
||||
logger.error(f"[错误] 删除重复文件失败: {e!s}")
|
||||
return False # 返回 False 表示未注册新表情
|
||||
|
||||
# 3. 构建描述和情感
|
||||
@@ -1076,7 +1076,7 @@ class EmojiManager:
|
||||
os.remove(file_full_path)
|
||||
logger.info(f"[清理] 删除描述生成失败的文件: {filename}")
|
||||
except Exception as e:
|
||||
logger.error(f"[错误] 删除描述生成失败文件时出错: {str(e)}")
|
||||
logger.error(f"[错误] 删除描述生成失败文件时出错: {e!s}")
|
||||
return False
|
||||
new_emoji.description = description
|
||||
new_emoji.emotion = emotions
|
||||
@@ -1087,7 +1087,7 @@ class EmojiManager:
|
||||
os.remove(file_full_path)
|
||||
logger.info(f"[清理] 删除描述生成异常的文件: {filename}")
|
||||
except Exception as e:
|
||||
logger.error(f"[错误] 删除描述生成异常文件时出错: {str(e)}")
|
||||
logger.error(f"[错误] 删除描述生成异常文件时出错: {e!s}")
|
||||
return False
|
||||
|
||||
# 4. 检查容量并决定是否替换或直接注册
|
||||
@@ -1101,7 +1101,7 @@ class EmojiManager:
|
||||
os.remove(file_full_path) # new_emoji 的 full_path 此时还是源路径
|
||||
logger.info(f"[清理] 删除替换失败的新表情文件: {filename}")
|
||||
except Exception as e:
|
||||
logger.error(f"[错误] 删除替换失败文件时出错: {str(e)}")
|
||||
logger.error(f"[错误] 删除替换失败文件时出错: {e!s}")
|
||||
return False
|
||||
# 替换成功时,replace_a_emoji 内部已处理 new_emoji 的注册和添加到列表
|
||||
return True
|
||||
@@ -1123,11 +1123,11 @@ class EmojiManager:
|
||||
os.remove(file_full_path)
|
||||
logger.info(f"[清理] 删除注册失败的源文件: {filename}")
|
||||
except Exception as e:
|
||||
logger.error(f"[错误] 删除注册失败源文件时出错: {str(e)}")
|
||||
logger.error(f"[错误] 删除注册失败源文件时出错: {e!s}")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[错误] 注册表情包时发生未预期错误 ({filename}): {str(e)}")
|
||||
logger.error(f"[错误] 注册表情包时发生未预期错误 ({filename}): {e!s}")
|
||||
logger.error(traceback.format_exc())
|
||||
# 尝试删除源文件以避免循环处理
|
||||
if os.path.exists(file_full_path):
|
||||
|
||||
@@ -4,24 +4,24 @@
|
||||
"""
|
||||
|
||||
from .energy_manager import (
|
||||
EnergyManager,
|
||||
EnergyLevel,
|
||||
EnergyComponent,
|
||||
EnergyCalculator,
|
||||
InterestEnergyCalculator,
|
||||
ActivityEnergyCalculator,
|
||||
EnergyCalculator,
|
||||
EnergyComponent,
|
||||
EnergyLevel,
|
||||
EnergyManager,
|
||||
InterestEnergyCalculator,
|
||||
RecencyEnergyCalculator,
|
||||
RelationshipEnergyCalculator,
|
||||
energy_manager,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"EnergyManager",
|
||||
"EnergyLevel",
|
||||
"EnergyComponent",
|
||||
"EnergyCalculator",
|
||||
"InterestEnergyCalculator",
|
||||
"ActivityEnergyCalculator",
|
||||
"EnergyCalculator",
|
||||
"EnergyComponent",
|
||||
"EnergyLevel",
|
||||
"EnergyManager",
|
||||
"InterestEnergyCalculator",
|
||||
"RecencyEnergyCalculator",
|
||||
"RelationshipEnergyCalculator",
|
||||
"energy_manager",
|
||||
|
||||
@@ -4,10 +4,10 @@
|
||||
"""
|
||||
|
||||
import time
|
||||
from typing import Dict, List, Optional, Tuple, Any, Union, TypedDict
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, TypedDict
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
@@ -51,8 +51,8 @@ class EnergyContext(TypedDict):
|
||||
"""能量计算上下文"""
|
||||
|
||||
stream_id: str
|
||||
messages: List[Any]
|
||||
user_id: Optional[str]
|
||||
messages: list[Any]
|
||||
user_id: str | None
|
||||
|
||||
|
||||
class EnergyResult(TypedDict):
|
||||
@@ -61,7 +61,7 @@ class EnergyResult(TypedDict):
|
||||
energy: float
|
||||
level: EnergyLevel
|
||||
distribution_interval: float
|
||||
component_scores: Dict[str, float]
|
||||
component_scores: dict[str, float]
|
||||
cached: bool
|
||||
|
||||
|
||||
@@ -69,7 +69,7 @@ class EnergyCalculator(ABC):
|
||||
"""能量计算器抽象基类"""
|
||||
|
||||
@abstractmethod
|
||||
def calculate(self, context: Dict[str, Any]) -> float:
|
||||
def calculate(self, context: dict[str, Any]) -> float:
|
||||
"""计算能量值"""
|
||||
pass
|
||||
|
||||
@@ -82,7 +82,7 @@ class EnergyCalculator(ABC):
|
||||
class InterestEnergyCalculator(EnergyCalculator):
|
||||
"""兴趣度能量计算器"""
|
||||
|
||||
def calculate(self, context: Dict[str, Any]) -> float:
|
||||
def calculate(self, context: dict[str, Any]) -> float:
|
||||
"""基于消息兴趣度计算能量"""
|
||||
messages = context.get("messages", [])
|
||||
if not messages:
|
||||
@@ -120,7 +120,7 @@ class ActivityEnergyCalculator(EnergyCalculator):
|
||||
def __init__(self):
|
||||
self.action_weights = {"reply": 0.4, "react": 0.3, "mention": 0.2, "other": 0.1}
|
||||
|
||||
def calculate(self, context: Dict[str, Any]) -> float:
|
||||
def calculate(self, context: dict[str, Any]) -> float:
|
||||
"""基于活跃度计算能量"""
|
||||
messages = context.get("messages", [])
|
||||
if not messages:
|
||||
@@ -150,7 +150,7 @@ class ActivityEnergyCalculator(EnergyCalculator):
|
||||
class RecencyEnergyCalculator(EnergyCalculator):
|
||||
"""最近性能量计算器"""
|
||||
|
||||
def calculate(self, context: Dict[str, Any]) -> float:
|
||||
def calculate(self, context: dict[str, Any]) -> float:
|
||||
"""基于最近性计算能量"""
|
||||
messages = context.get("messages", [])
|
||||
if not messages:
|
||||
@@ -197,7 +197,7 @@ class RecencyEnergyCalculator(EnergyCalculator):
|
||||
class RelationshipEnergyCalculator(EnergyCalculator):
|
||||
"""关系能量计算器"""
|
||||
|
||||
async def calculate(self, context: Dict[str, Any]) -> float:
|
||||
async def calculate(self, context: dict[str, Any]) -> float:
|
||||
"""基于关系计算能量"""
|
||||
user_id = context.get("user_id")
|
||||
if not user_id:
|
||||
@@ -223,7 +223,7 @@ class EnergyManager:
|
||||
"""能量管理器 - 统一管理所有能量计算"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.calculators: List[EnergyCalculator] = [
|
||||
self.calculators: list[EnergyCalculator] = [
|
||||
InterestEnergyCalculator(),
|
||||
ActivityEnergyCalculator(),
|
||||
RecencyEnergyCalculator(),
|
||||
@@ -231,14 +231,14 @@ class EnergyManager:
|
||||
]
|
||||
|
||||
# 能量缓存
|
||||
self.energy_cache: Dict[str, Tuple[float, float]] = {} # stream_id -> (energy, timestamp)
|
||||
self.energy_cache: dict[str, tuple[float, float]] = {} # stream_id -> (energy, timestamp)
|
||||
self.cache_ttl: int = 60 # 1分钟缓存
|
||||
|
||||
# AFC阈值配置
|
||||
self.thresholds: Dict[str, float] = {"high_match": 0.8, "reply": 0.4, "non_reply": 0.2}
|
||||
self.thresholds: dict[str, float] = {"high_match": 0.8, "reply": 0.4, "non_reply": 0.2}
|
||||
|
||||
# 统计信息
|
||||
self.stats: Dict[str, Union[int, float, str]] = {
|
||||
self.stats: dict[str, int | float | str] = {
|
||||
"total_calculations": 0,
|
||||
"cache_hits": 0,
|
||||
"cache_misses": 0,
|
||||
@@ -272,7 +272,7 @@ class EnergyManager:
|
||||
except Exception as e:
|
||||
logger.warning(f"加载AFC阈值失败,使用默认值: {e}")
|
||||
|
||||
async def calculate_focus_energy(self, stream_id: str, messages: List[Any], user_id: Optional[str] = None) -> float:
|
||||
async def calculate_focus_energy(self, stream_id: str, messages: list[Any], user_id: str | None = None) -> float:
|
||||
"""计算聊天流的focus_energy"""
|
||||
start_time = time.time()
|
||||
|
||||
@@ -297,7 +297,7 @@ class EnergyManager:
|
||||
}
|
||||
|
||||
# 计算各组件能量
|
||||
component_scores: Dict[str, float] = {}
|
||||
component_scores: dict[str, float] = {}
|
||||
total_weight = 0.0
|
||||
|
||||
for calculator in self.calculators:
|
||||
@@ -437,7 +437,7 @@ class EnergyManager:
|
||||
if expired_keys:
|
||||
logger.debug(f"清理了 {len(expired_keys)} 个过期能量缓存")
|
||||
|
||||
def get_statistics(self) -> Dict[str, Any]:
|
||||
def get_statistics(self) -> dict[str, Any]:
|
||||
"""获取统计信息"""
|
||||
return {
|
||||
"cache_size": len(self.energy_cache),
|
||||
@@ -446,7 +446,7 @@ class EnergyManager:
|
||||
"performance_stats": self.stats.copy(),
|
||||
}
|
||||
|
||||
def update_thresholds(self, new_thresholds: Dict[str, float]) -> None:
|
||||
def update_thresholds(self, new_thresholds: dict[str, float]) -> None:
|
||||
"""更新阈值"""
|
||||
self.thresholds.update(new_thresholds)
|
||||
|
||||
|
||||
@@ -1,21 +1,20 @@
|
||||
import time
|
||||
import random
|
||||
import orjson
|
||||
import os
|
||||
import random
|
||||
import time
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from typing import List, Dict, Optional, Any, Tuple
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.common.database.sqlalchemy_database_api import get_db_session
|
||||
import orjson
|
||||
from sqlalchemy import select
|
||||
from src.common.database.sqlalchemy_models import Expression
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import model_config, global_config
|
||||
from src.chat.utils.chat_message_builder import get_raw_msg_by_timestamp_with_chat_inclusive, build_anonymous_messages
|
||||
from src.chat.utils.prompt import Prompt, global_prompt_manager
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from src.chat.utils.chat_message_builder import build_anonymous_messages, get_raw_msg_by_timestamp_with_chat_inclusive
|
||||
from src.chat.utils.prompt import Prompt, global_prompt_manager
|
||||
from src.common.database.sqlalchemy_database_api import get_db_session
|
||||
from src.common.database.sqlalchemy_models import Expression
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config, model_config
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
|
||||
MAX_EXPRESSION_COUNT = 300
|
||||
DECAY_DAYS = 30 # 30天衰减到0.01
|
||||
@@ -171,7 +170,7 @@ class ExpressionLearner:
|
||||
logger.error(f"为聊天流 {self.chat_name} 触发学习失败: {e}")
|
||||
return False
|
||||
|
||||
async def get_expression_by_chat_id(self) -> Tuple[List[Dict[str, float]], List[Dict[str, float]]]:
|
||||
async def get_expression_by_chat_id(self) -> tuple[list[dict[str, float]], list[dict[str, float]]]:
|
||||
"""
|
||||
获取指定chat_id的style和grammar表达方式
|
||||
返回的每个表达方式字典中都包含了source_id, 用于后续的更新操作
|
||||
@@ -311,7 +310,7 @@ class ExpressionLearner:
|
||||
return []
|
||||
|
||||
# 按chat_id分组
|
||||
chat_dict: Dict[str, List[Dict[str, Any]]] = {}
|
||||
chat_dict: dict[str, list[dict[str, Any]]] = {}
|
||||
for chat_id, situation, style in learnt_expressions:
|
||||
if chat_id not in chat_dict:
|
||||
chat_dict[chat_id] = []
|
||||
@@ -368,7 +367,7 @@ class ExpressionLearner:
|
||||
return learnt_expressions
|
||||
return None
|
||||
|
||||
async def learn_expression(self, num: int = 10) -> Optional[Tuple[List[Tuple[str, str, str]], str]]:
|
||||
async def learn_expression(self, type: str, num: int = 10) -> tuple[list[tuple[str, str, str]], str] | None:
|
||||
"""从指定聊天流学习表达方式
|
||||
|
||||
Args:
|
||||
@@ -380,7 +379,7 @@ class ExpressionLearner:
|
||||
current_time = time.time()
|
||||
|
||||
# 获取上次学习时间
|
||||
random_msg: Optional[List[Dict[str, Any]]] = await get_raw_msg_by_timestamp_with_chat_inclusive(
|
||||
random_msg: list[dict[str, Any]] | None = await get_raw_msg_by_timestamp_with_chat_inclusive(
|
||||
chat_id=self.chat_id,
|
||||
timestamp_start=self.last_learning_time,
|
||||
timestamp_end=current_time,
|
||||
@@ -411,16 +410,16 @@ class ExpressionLearner:
|
||||
|
||||
logger.debug(f"学习{type_str}的response: {response}")
|
||||
|
||||
expressions: List[Tuple[str, str, str]] = self.parse_expression_response(response, chat_id)
|
||||
expressions: list[tuple[str, str, str]] = self.parse_expression_response(response, chat_id)
|
||||
|
||||
return expressions, chat_id
|
||||
|
||||
@staticmethod
|
||||
def parse_expression_response(response: str, chat_id: str) -> List[Tuple[str, str, str]]:
|
||||
def parse_expression_response(response: str, chat_id: str) -> list[tuple[str, str, str]]:
|
||||
"""
|
||||
解析LLM返回的表达风格总结,每一行提取"当"和"使用"之间的内容,存储为(situation, style)元组
|
||||
"""
|
||||
expressions: List[Tuple[str, str, str]] = []
|
||||
expressions: list[tuple[str, str, str]] = []
|
||||
for line in response.splitlines():
|
||||
line = line.strip()
|
||||
if not line:
|
||||
@@ -516,7 +515,7 @@ class ExpressionLearnerManager:
|
||||
continue
|
||||
|
||||
try:
|
||||
with open(expr_file, "r", encoding="utf-8") as f:
|
||||
with open(expr_file, encoding="utf-8") as f:
|
||||
expressions = orjson.loads(f.read())
|
||||
|
||||
for chat_id in chat_ids:
|
||||
|
||||
@@ -1,18 +1,18 @@
|
||||
import orjson
|
||||
import time
|
||||
import random
|
||||
import hashlib
|
||||
import random
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
from typing import List, Dict, Optional, Any, Tuple
|
||||
import orjson
|
||||
from json_repair import repair_json
|
||||
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import global_config, model_config
|
||||
from src.common.logger import get_logger
|
||||
from sqlalchemy import select
|
||||
from src.common.database.sqlalchemy_models import Expression
|
||||
|
||||
from src.chat.utils.prompt import Prompt, global_prompt_manager
|
||||
from src.common.database.sqlalchemy_database_api import get_db_session
|
||||
from src.common.database.sqlalchemy_models import Expression
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config, model_config
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
|
||||
logger = get_logger("expression_selector")
|
||||
|
||||
@@ -45,7 +45,7 @@ def init_prompt():
|
||||
Prompt(expression_evaluation_prompt, "expression_evaluation_prompt")
|
||||
|
||||
|
||||
def weighted_sample(population: List[Dict], weights: List[float], k: int) -> List[Dict]:
|
||||
def weighted_sample(population: list[dict], weights: list[float], k: int) -> list[dict]:
|
||||
"""按权重随机抽样"""
|
||||
if not population or not weights or k <= 0:
|
||||
return []
|
||||
@@ -95,7 +95,7 @@ class ExpressionSelector:
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def _parse_stream_config_to_chat_id(stream_config_str: str) -> Optional[str]:
|
||||
def _parse_stream_config_to_chat_id(stream_config_str: str) -> str | None:
|
||||
"""解析'platform:id:type'为chat_id(与get_stream_id一致)"""
|
||||
try:
|
||||
parts = stream_config_str.split(":")
|
||||
@@ -114,7 +114,7 @@ class ExpressionSelector:
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def get_related_chat_ids(self, chat_id: str) -> List[str]:
|
||||
def get_related_chat_ids(self, chat_id: str) -> list[str]:
|
||||
"""根据expression.rules配置,获取与当前chat_id相关的所有chat_id(包括自身)"""
|
||||
rules = global_config.expression.rules
|
||||
current_group = None
|
||||
@@ -139,7 +139,7 @@ class ExpressionSelector:
|
||||
|
||||
async def get_random_expressions(
|
||||
self, chat_id: str, total_num: int, style_percentage: float, grammar_percentage: float
|
||||
) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
|
||||
) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]:
|
||||
# sourcery skip: extract-duplicate-method, move-assign
|
||||
# 支持多chat_id合并抽选
|
||||
related_chat_ids = self.get_related_chat_ids(chat_id)
|
||||
@@ -195,7 +195,7 @@ class ExpressionSelector:
|
||||
return selected_style, selected_grammar
|
||||
|
||||
@staticmethod
|
||||
async def update_expressions_count_batch(expressions_to_update: List[Dict[str, Any]], increment: float = 0.1):
|
||||
async def update_expressions_count_batch(expressions_to_update: list[dict[str, Any]], increment: float = 0.1):
|
||||
"""对一批表达方式更新count值,按chat_id+type分组后一次性写入数据库"""
|
||||
if not expressions_to_update:
|
||||
return
|
||||
@@ -239,8 +239,9 @@ class ExpressionSelector:
|
||||
chat_id: str,
|
||||
chat_info: str,
|
||||
max_num: int = 10,
|
||||
target_message: Optional[str] = None,
|
||||
) -> Tuple[List[Dict[str, Any]], List[int]]:
|
||||
min_num: int = 5,
|
||||
target_message: str | None = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
# sourcery skip: inline-variable, list-comprehension
|
||||
"""使用LLM选择适合的表达方式"""
|
||||
|
||||
|
||||
@@ -16,8 +16,7 @@ Chat Frequency Analyzer
|
||||
"""
|
||||
|
||||
import time as time_module
|
||||
from datetime import datetime, timedelta, time
|
||||
from typing import List, Tuple, Optional
|
||||
from datetime import datetime, time, timedelta
|
||||
|
||||
from .tracker import chat_frequency_tracker
|
||||
|
||||
@@ -42,7 +41,7 @@ class ChatFrequencyAnalyzer:
|
||||
self._cache_ttl_seconds = 60 * 30 # 缓存30分钟
|
||||
|
||||
@staticmethod
|
||||
def _find_peak_windows(timestamps: List[float]) -> List[Tuple[datetime, datetime]]:
|
||||
def _find_peak_windows(timestamps: list[float]) -> list[tuple[datetime, datetime]]:
|
||||
"""
|
||||
使用滑动窗口算法来识别时间戳列表中的高峰时段。
|
||||
|
||||
@@ -59,7 +58,7 @@ class ChatFrequencyAnalyzer:
|
||||
datetimes = [datetime.fromtimestamp(ts) for ts in timestamps]
|
||||
datetimes.sort()
|
||||
|
||||
peak_windows: List[Tuple[datetime, datetime]] = []
|
||||
peak_windows: list[tuple[datetime, datetime]] = []
|
||||
window_start_idx = 0
|
||||
|
||||
for i in range(len(datetimes)):
|
||||
@@ -83,7 +82,7 @@ class ChatFrequencyAnalyzer:
|
||||
|
||||
return peak_windows
|
||||
|
||||
def get_peak_chat_times(self, chat_id: str) -> List[Tuple[time, time]]:
|
||||
def get_peak_chat_times(self, chat_id: str) -> list[tuple[time, time]]:
|
||||
"""
|
||||
获取指定用户的高峰聊天时间段。
|
||||
|
||||
@@ -116,7 +115,7 @@ class ChatFrequencyAnalyzer:
|
||||
|
||||
return peak_time_windows
|
||||
|
||||
def is_in_peak_time(self, chat_id: str, now: Optional[datetime] = None) -> bool:
|
||||
def is_in_peak_time(self, chat_id: str, now: datetime | None = None) -> bool:
|
||||
"""
|
||||
检查当前时间是否处于用户的高峰聊天时段内。
|
||||
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
import orjson
|
||||
import time
|
||||
from typing import Dict, List, Optional
|
||||
from pathlib import Path
|
||||
|
||||
import orjson
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
# 数据存储路径
|
||||
@@ -19,10 +19,10 @@ class ChatFrequencyTracker:
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._timestamps: Dict[str, List[float]] = self._load_timestamps()
|
||||
self._timestamps: dict[str, list[float]] = self._load_timestamps()
|
||||
|
||||
@staticmethod
|
||||
def _load_timestamps() -> Dict[str, List[float]]:
|
||||
def _load_timestamps() -> dict[str, list[float]]:
|
||||
"""从本地文件加载时间戳数据。"""
|
||||
if not TRACKER_FILE.exists():
|
||||
return {}
|
||||
@@ -61,7 +61,7 @@ class ChatFrequencyTracker:
|
||||
logger.debug(f"为 chat_id '{chat_id}' 记录了新的聊天时间: {now}")
|
||||
self._save_timestamps()
|
||||
|
||||
def get_timestamps_for_chat(self, chat_id: str) -> Optional[List[float]]:
|
||||
def get_timestamps_for_chat(self, chat_id: str) -> list[float] | None:
|
||||
"""
|
||||
获取指定聊天的所有时间戳记录。
|
||||
|
||||
|
||||
@@ -18,11 +18,10 @@ Frequency-Based Proactive Trigger
|
||||
import asyncio
|
||||
import time
|
||||
from datetime import datetime
|
||||
from typing import Dict, Optional
|
||||
|
||||
from src.common.logger import get_logger
|
||||
# AFC manager has been moved to chatter plugin
|
||||
|
||||
# AFC manager has been moved to chatter plugin
|
||||
# TODO: 需要重新实现主动思考和睡眠管理功能
|
||||
from .analyzer import chat_frequency_analyzer
|
||||
|
||||
@@ -42,10 +41,10 @@ class FrequencyBasedTrigger:
|
||||
|
||||
def __init__(self):
|
||||
# TODO: 需要重新实现睡眠管理器
|
||||
self._task: Optional[asyncio.Task] = None
|
||||
self._task: asyncio.Task | None = None
|
||||
# 记录上次为用户触发的时间,用于冷却控制
|
||||
# 格式: { "chat_id": timestamp }
|
||||
self._last_triggered: Dict[str, float] = {}
|
||||
self._last_triggered: dict[str, float] = {}
|
||||
|
||||
async def _run_trigger_cycle(self):
|
||||
"""触发器的主要循环逻辑。"""
|
||||
|
||||
@@ -3,13 +3,14 @@
|
||||
提供机器人兴趣标签和智能匹配功能
|
||||
"""
|
||||
|
||||
from .bot_interest_manager import BotInterestManager, bot_interest_manager
|
||||
from src.common.data_models.bot_interest_data_model import BotInterestTag, BotPersonalityInterests, InterestMatchResult
|
||||
|
||||
from .bot_interest_manager import BotInterestManager, bot_interest_manager
|
||||
|
||||
__all__ = [
|
||||
"BotInterestManager",
|
||||
"bot_interest_manager",
|
||||
"BotInterestTag",
|
||||
"BotPersonalityInterests",
|
||||
"InterestMatchResult",
|
||||
"bot_interest_manager",
|
||||
]
|
||||
|
||||
@@ -3,17 +3,18 @@
|
||||
基于人设生成兴趣标签,并使用embedding计算匹配度
|
||||
"""
|
||||
|
||||
import orjson
|
||||
import traceback
|
||||
from typing import List, Dict, Optional, Any
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import orjson
|
||||
from sqlalchemy import select
|
||||
|
||||
from src.common.config_helpers import resolve_embedding_dimension
|
||||
from src.common.data_models.bot_interest_data_model import BotInterestTag, BotPersonalityInterests, InterestMatchResult
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.common.config_helpers import resolve_embedding_dimension
|
||||
from src.common.data_models.bot_interest_data_model import BotPersonalityInterests, BotInterestTag, InterestMatchResult
|
||||
|
||||
logger = get_logger("bot_interest_manager")
|
||||
|
||||
@@ -22,8 +23,8 @@ class BotInterestManager:
|
||||
"""机器人兴趣标签管理器"""
|
||||
|
||||
def __init__(self):
|
||||
self.current_interests: Optional[BotPersonalityInterests] = None
|
||||
self.embedding_cache: Dict[str, List[float]] = {} # embedding缓存
|
||||
self.current_interests: BotPersonalityInterests | None = None
|
||||
self.embedding_cache: dict[str, list[float]] = {} # embedding缓存
|
||||
self._initialized = False
|
||||
|
||||
# Embedding客户端配置
|
||||
@@ -31,7 +32,7 @@ class BotInterestManager:
|
||||
self.embedding_config = None
|
||||
configured_dim = resolve_embedding_dimension()
|
||||
self.embedding_dimension = int(configured_dim) if configured_dim else 0
|
||||
self._detected_embedding_dimension: Optional[int] = None
|
||||
self._detected_embedding_dimension: int | None = None
|
||||
|
||||
@property
|
||||
def is_initialized(self) -> bool:
|
||||
@@ -145,7 +146,7 @@ class BotInterestManager:
|
||||
|
||||
async def _generate_interests_from_personality(
|
||||
self, personality_description: str, personality_id: str
|
||||
) -> Optional[BotPersonalityInterests]:
|
||||
) -> BotPersonalityInterests | None:
|
||||
"""根据人设生成兴趣标签"""
|
||||
try:
|
||||
logger.info("🎨 开始根据人设生成兴趣标签...")
|
||||
@@ -226,14 +227,14 @@ class BotInterestManager:
|
||||
traceback.print_exc()
|
||||
raise
|
||||
|
||||
async def _call_llm_for_interest_generation(self, prompt: str) -> Optional[str]:
|
||||
async def _call_llm_for_interest_generation(self, prompt: str) -> str | None:
|
||||
"""调用LLM生成兴趣标签"""
|
||||
try:
|
||||
logger.info("🔧 配置LLM客户端...")
|
||||
|
||||
# 使用llm_api来处理请求
|
||||
from src.plugin_system.apis import llm_api
|
||||
from src.config.config import model_config
|
||||
from src.plugin_system.apis import llm_api
|
||||
|
||||
# 构建完整的提示词,明确要求只返回纯JSON
|
||||
full_prompt = f"""你是一个专业的机器人人设分析师,擅长根据人设描述生成合适的兴趣标签。
|
||||
@@ -342,7 +343,7 @@ class BotInterestManager:
|
||||
logger.info(f"🗃️ 总缓存大小: {len(self.embedding_cache)}")
|
||||
logger.info("=" * 50)
|
||||
|
||||
async def _get_embedding(self, text: str) -> List[float]:
|
||||
async def _get_embedding(self, text: str) -> list[float]:
|
||||
"""获取文本的embedding向量"""
|
||||
if not hasattr(self, "embedding_request"):
|
||||
raise RuntimeError("❌ Embedding请求客户端未初始化")
|
||||
@@ -383,7 +384,7 @@ class BotInterestManager:
|
||||
else:
|
||||
raise RuntimeError(f"❌ 返回的embedding为空: {embedding}")
|
||||
|
||||
async def _generate_message_embedding(self, message_text: str, keywords: List[str]) -> List[float]:
|
||||
async def _generate_message_embedding(self, message_text: str, keywords: list[str]) -> list[float]:
|
||||
"""为消息生成embedding向量"""
|
||||
# 组合消息文本和关键词作为embedding输入
|
||||
if keywords:
|
||||
@@ -399,7 +400,7 @@ class BotInterestManager:
|
||||
return embedding
|
||||
|
||||
async def _calculate_similarity_scores(
|
||||
self, result: InterestMatchResult, message_embedding: List[float], keywords: List[str]
|
||||
self, result: InterestMatchResult, message_embedding: list[float], keywords: list[str]
|
||||
):
|
||||
"""计算消息与兴趣标签的相似度分数"""
|
||||
try:
|
||||
@@ -428,7 +429,7 @@ class BotInterestManager:
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 计算相似度分数失败: {e}")
|
||||
|
||||
async def calculate_interest_match(self, message_text: str, keywords: List[str] = None) -> InterestMatchResult:
|
||||
async def calculate_interest_match(self, message_text: str, keywords: list[str] = None) -> InterestMatchResult:
|
||||
"""计算消息与机器人兴趣的匹配度"""
|
||||
if not self.current_interests or not self._initialized:
|
||||
raise RuntimeError("❌ 兴趣标签系统未初始化")
|
||||
@@ -528,7 +529,7 @@ class BotInterestManager:
|
||||
)
|
||||
return result
|
||||
|
||||
def _calculate_keyword_match_bonus(self, keywords: List[str], matched_tags: List[str]) -> Dict[str, float]:
|
||||
def _calculate_keyword_match_bonus(self, keywords: list[str], matched_tags: list[str]) -> dict[str, float]:
|
||||
"""计算关键词直接匹配奖励"""
|
||||
if not keywords or not matched_tags:
|
||||
return {}
|
||||
@@ -610,7 +611,7 @@ class BotInterestManager:
|
||||
|
||||
return previous_row[-1]
|
||||
|
||||
def _calculate_cosine_similarity(self, vec1: List[float], vec2: List[float]) -> float:
|
||||
def _calculate_cosine_similarity(self, vec1: list[float], vec2: list[float]) -> float:
|
||||
"""计算余弦相似度"""
|
||||
try:
|
||||
vec1 = np.array(vec1)
|
||||
@@ -629,16 +630,17 @@ class BotInterestManager:
|
||||
logger.error(f"计算余弦相似度失败: {e}")
|
||||
return 0.0
|
||||
|
||||
async def _load_interests_from_database(self, personality_id: str) -> Optional[BotPersonalityInterests]:
|
||||
async def _load_interests_from_database(self, personality_id: str) -> BotPersonalityInterests | None:
|
||||
"""从数据库加载兴趣标签"""
|
||||
try:
|
||||
logger.debug(f"从数据库加载兴趣标签, personality_id: {personality_id}")
|
||||
|
||||
# 导入SQLAlchemy相关模块
|
||||
from src.common.database.sqlalchemy_models import BotPersonalityInterests as DBBotPersonalityInterests
|
||||
from src.common.database.sqlalchemy_database_api import get_db_session
|
||||
import orjson
|
||||
|
||||
from src.common.database.sqlalchemy_database_api import get_db_session
|
||||
from src.common.database.sqlalchemy_models import BotPersonalityInterests as DBBotPersonalityInterests
|
||||
|
||||
async with get_db_session() as session:
|
||||
# 查询最新的兴趣标签配置
|
||||
db_interests = (
|
||||
@@ -716,10 +718,11 @@ class BotInterestManager:
|
||||
logger.info(f"🔄 版本: {interests.version}")
|
||||
|
||||
# 导入SQLAlchemy相关模块
|
||||
from src.common.database.sqlalchemy_models import BotPersonalityInterests as DBBotPersonalityInterests
|
||||
from src.common.database.sqlalchemy_database_api import get_db_session
|
||||
import orjson
|
||||
|
||||
from src.common.database.sqlalchemy_database_api import get_db_session
|
||||
from src.common.database.sqlalchemy_models import BotPersonalityInterests as DBBotPersonalityInterests
|
||||
|
||||
# 将兴趣标签转换为JSON格式
|
||||
tags_data = []
|
||||
for tag in interests.interest_tags:
|
||||
@@ -803,11 +806,11 @@ class BotInterestManager:
|
||||
logger.error("🔍 错误详情:")
|
||||
traceback.print_exc()
|
||||
|
||||
def get_current_interests(self) -> Optional[BotPersonalityInterests]:
|
||||
def get_current_interests(self) -> BotPersonalityInterests | None:
|
||||
"""获取当前的兴趣标签配置"""
|
||||
return self.current_interests
|
||||
|
||||
def get_interest_stats(self) -> Dict[str, Any]:
|
||||
def get_interest_stats(self) -> dict[str, Any]:
|
||||
"""获取兴趣系统统计信息"""
|
||||
if not self.current_interests:
|
||||
return {"initialized": False}
|
||||
|
||||
@@ -1,33 +1,31 @@
|
||||
from dataclasses import dataclass
|
||||
import orjson
|
||||
import os
|
||||
import math
|
||||
import asyncio
|
||||
import math
|
||||
import os
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from dataclasses import dataclass
|
||||
|
||||
# import tqdm
|
||||
import faiss
|
||||
|
||||
from .utils.hash import get_sha256
|
||||
from .global_logger import logger
|
||||
from rich.traceback import install
|
||||
import numpy as np
|
||||
import orjson
|
||||
import pandas as pd
|
||||
from rich.progress import (
|
||||
Progress,
|
||||
BarColumn,
|
||||
MofNCompleteColumn,
|
||||
Progress,
|
||||
SpinnerColumn,
|
||||
TaskProgressColumn,
|
||||
TextColumn,
|
||||
TimeElapsedColumn,
|
||||
TimeRemainingColumn,
|
||||
TaskProgressColumn,
|
||||
MofNCompleteColumn,
|
||||
SpinnerColumn,
|
||||
TextColumn,
|
||||
)
|
||||
from src.config.config import global_config
|
||||
from src.common.config_helpers import resolve_embedding_dimension
|
||||
from rich.traceback import install
|
||||
|
||||
from src.common.config_helpers import resolve_embedding_dimension
|
||||
from src.config.config import global_config
|
||||
|
||||
from .global_logger import logger
|
||||
from .utils.hash import get_sha256
|
||||
|
||||
install(extra_lines=3)
|
||||
|
||||
@@ -79,7 +77,7 @@ def cosine_similarity(a, b):
|
||||
class EmbeddingStoreItem:
|
||||
"""嵌入库中的项"""
|
||||
|
||||
def __init__(self, item_hash: str, embedding: List[float], content: str):
|
||||
def __init__(self, item_hash: str, embedding: list[float], content: str):
|
||||
self.hash = item_hash
|
||||
self.embedding = embedding
|
||||
self.str = content
|
||||
@@ -127,7 +125,7 @@ class EmbeddingStore:
|
||||
self.idx2hash = None
|
||||
|
||||
@staticmethod
|
||||
def _get_embedding(s: str) -> List[float]:
|
||||
def _get_embedding(s: str) -> list[float]:
|
||||
"""获取字符串的嵌入向量,使用完全同步的方式避免事件循环问题"""
|
||||
# 创建新的事件循环并在完成后立即关闭
|
||||
loop = asyncio.new_event_loop()
|
||||
@@ -135,8 +133,8 @@ class EmbeddingStore:
|
||||
|
||||
try:
|
||||
# 创建新的LLMRequest实例
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import model_config
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
|
||||
llm = LLMRequest(model_set=model_config.model_task_config.embedding, request_type="embedding")
|
||||
|
||||
@@ -161,8 +159,8 @@ class EmbeddingStore:
|
||||
|
||||
@staticmethod
|
||||
def _get_embeddings_batch_threaded(
|
||||
strs: List[str], chunk_size: int = 10, max_workers: int = 10, progress_callback=None
|
||||
) -> List[Tuple[str, List[float]]]:
|
||||
strs: list[str], chunk_size: int = 10, max_workers: int = 10, progress_callback=None
|
||||
) -> list[tuple[str, list[float]]]:
|
||||
"""使用多线程批量获取嵌入向量
|
||||
|
||||
Args:
|
||||
@@ -192,8 +190,8 @@ class EmbeddingStore:
|
||||
chunk_results = []
|
||||
|
||||
# 为每个线程创建独立的LLMRequest实例
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import model_config
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
|
||||
try:
|
||||
# 创建线程专用的LLM实例
|
||||
@@ -303,7 +301,7 @@ class EmbeddingStore:
|
||||
path = self.get_test_file_path()
|
||||
if not os.path.exists(path):
|
||||
return None
|
||||
with open(path, "r", encoding="utf-8") as f:
|
||||
with open(path, encoding="utf-8") as f:
|
||||
return orjson.loads(f.read())
|
||||
|
||||
def check_embedding_model_consistency(self):
|
||||
@@ -345,7 +343,7 @@ class EmbeddingStore:
|
||||
logger.info("嵌入模型一致性校验通过。")
|
||||
return True
|
||||
|
||||
def batch_insert_strs(self, strs: List[str], times: int) -> None:
|
||||
def batch_insert_strs(self, strs: list[str], times: int) -> None:
|
||||
"""向库中存入字符串(使用多线程优化)"""
|
||||
if not strs:
|
||||
return
|
||||
@@ -481,7 +479,7 @@ class EmbeddingStore:
|
||||
if os.path.exists(self.idx2hash_file_path):
|
||||
logger.info(f"正在加载{self.namespace}嵌入库的idx2hash映射...")
|
||||
logger.debug(f"正在从文件{self.idx2hash_file_path}中加载{self.namespace}嵌入库的idx2hash映射")
|
||||
with open(self.idx2hash_file_path, "r") as f:
|
||||
with open(self.idx2hash_file_path) as f:
|
||||
self.idx2hash = orjson.loads(f.read())
|
||||
logger.info(f"{self.namespace}嵌入库的idx2hash映射加载成功")
|
||||
else:
|
||||
@@ -511,7 +509,7 @@ class EmbeddingStore:
|
||||
self.faiss_index = faiss.IndexFlatIP(embedding_dim)
|
||||
self.faiss_index.add(embeddings)
|
||||
|
||||
def search_top_k(self, query: List[float], k: int) -> List[Tuple[str, float]]:
|
||||
def search_top_k(self, query: list[float], k: int) -> list[tuple[str, float]]:
|
||||
"""搜索最相似的k个项,以余弦相似度为度量
|
||||
Args:
|
||||
query: 查询的embedding
|
||||
@@ -575,11 +573,11 @@ class EmbeddingManager:
|
||||
"""对所有嵌入库做模型一致性校验"""
|
||||
return self.paragraphs_embedding_store.check_embedding_model_consistency()
|
||||
|
||||
def _store_pg_into_embedding(self, raw_paragraphs: Dict[str, str]):
|
||||
def _store_pg_into_embedding(self, raw_paragraphs: dict[str, str]):
|
||||
"""将段落编码存入Embedding库"""
|
||||
self.paragraphs_embedding_store.batch_insert_strs(list(raw_paragraphs.values()), times=1)
|
||||
|
||||
def _store_ent_into_embedding(self, triple_list_data: Dict[str, List[List[str]]]):
|
||||
def _store_ent_into_embedding(self, triple_list_data: dict[str, list[list[str]]]):
|
||||
"""将实体编码存入Embedding库"""
|
||||
entities = set()
|
||||
for triple_list in triple_list_data.values():
|
||||
@@ -588,7 +586,7 @@ class EmbeddingManager:
|
||||
entities.add(triple[2])
|
||||
self.entities_embedding_store.batch_insert_strs(list(entities), times=2)
|
||||
|
||||
def _store_rel_into_embedding(self, triple_list_data: Dict[str, List[List[str]]]):
|
||||
def _store_rel_into_embedding(self, triple_list_data: dict[str, list[list[str]]]):
|
||||
"""将关系编码存入Embedding库"""
|
||||
graph_triples = [] # a list of unique relation triple (in tuple) from all chunks
|
||||
for triples in triple_list_data.values():
|
||||
@@ -606,8 +604,8 @@ class EmbeddingManager:
|
||||
|
||||
def store_new_data_set(
|
||||
self,
|
||||
raw_paragraphs: Dict[str, str],
|
||||
triple_list_data: Dict[str, List[List[str]]],
|
||||
raw_paragraphs: dict[str, str],
|
||||
triple_list_data: dict[str, list[list[str]]],
|
||||
):
|
||||
if not self.check_all_embedding_model_consistency():
|
||||
raise Exception("嵌入模型与本地存储不一致,请检查模型设置或清空嵌入库后重试。")
|
||||
|
||||
@@ -1,14 +1,15 @@
|
||||
import asyncio
|
||||
import orjson
|
||||
import time
|
||||
from typing import List, Union
|
||||
|
||||
from .global_logger import logger
|
||||
from . import prompt_template
|
||||
from .knowledge_lib import INVALID_ENTITY
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
import orjson
|
||||
from json_repair import repair_json
|
||||
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
|
||||
from . import prompt_template
|
||||
from .global_logger import logger
|
||||
from .knowledge_lib import INVALID_ENTITY
|
||||
|
||||
|
||||
def _extract_json_from_text(text: str):
|
||||
# sourcery skip: assign-if-exp, extract-method
|
||||
@@ -46,7 +47,7 @@ def _extract_json_from_text(text: str):
|
||||
return []
|
||||
|
||||
|
||||
def _entity_extract(llm_req: LLMRequest, paragraph: str) -> List[str]:
|
||||
def _entity_extract(llm_req: LLMRequest, paragraph: str) -> list[str]:
|
||||
# sourcery skip: reintroduce-else, swap-if-else-branches, use-named-expression
|
||||
"""对段落进行实体提取,返回提取出的实体列表(JSON格式)"""
|
||||
entity_extract_context = prompt_template.build_entity_extract_context(paragraph)
|
||||
@@ -92,7 +93,7 @@ def _entity_extract(llm_req: LLMRequest, paragraph: str) -> List[str]:
|
||||
return entity_extract_result
|
||||
|
||||
|
||||
def _rdf_triple_extract(llm_req: LLMRequest, paragraph: str, entities: list) -> List[List[str]]:
|
||||
def _rdf_triple_extract(llm_req: LLMRequest, paragraph: str, entities: list) -> list[list[str]]:
|
||||
"""对段落进行实体提取,返回提取出的实体列表(JSON格式)"""
|
||||
rdf_extract_context = prompt_template.build_rdf_triple_extract_context(
|
||||
paragraph, entities=orjson.dumps(entities).decode("utf-8")
|
||||
@@ -141,7 +142,7 @@ def _rdf_triple_extract(llm_req: LLMRequest, paragraph: str, entities: list) ->
|
||||
|
||||
def info_extract_from_str(
|
||||
llm_client_for_ner: LLMRequest, llm_client_for_rdf: LLMRequest, paragraph: str
|
||||
) -> Union[tuple[None, None], tuple[list[str], list[list[str]]]]:
|
||||
) -> tuple[None, None] | tuple[list[str], list[list[str]]]:
|
||||
try_count = 0
|
||||
while True:
|
||||
try:
|
||||
|
||||
@@ -1,28 +1,26 @@
|
||||
import orjson
|
||||
import os
|
||||
import time
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
import numpy as np
|
||||
import orjson
|
||||
import pandas as pd
|
||||
from quick_algo import di_graph, pagerank
|
||||
from rich.progress import (
|
||||
Progress,
|
||||
BarColumn,
|
||||
MofNCompleteColumn,
|
||||
Progress,
|
||||
SpinnerColumn,
|
||||
TaskProgressColumn,
|
||||
TextColumn,
|
||||
TimeElapsedColumn,
|
||||
TimeRemainingColumn,
|
||||
TaskProgressColumn,
|
||||
MofNCompleteColumn,
|
||||
SpinnerColumn,
|
||||
TextColumn,
|
||||
)
|
||||
from quick_algo import di_graph, pagerank
|
||||
|
||||
|
||||
from .utils.hash import get_sha256
|
||||
from .embedding_store import EmbeddingManager, EmbeddingStoreItem
|
||||
from src.config.config import global_config
|
||||
|
||||
from .embedding_store import EmbeddingManager, EmbeddingStoreItem
|
||||
from .global_logger import logger
|
||||
from .utils.hash import get_sha256
|
||||
|
||||
|
||||
def _get_kg_dir():
|
||||
@@ -87,7 +85,7 @@ class KGManager:
|
||||
raise FileNotFoundError(f"KG图文件{self.graph_data_path}不存在")
|
||||
|
||||
# 加载段落hash
|
||||
with open(self.pg_hash_file_path, "r", encoding="utf-8") as f:
|
||||
with open(self.pg_hash_file_path, encoding="utf-8") as f:
|
||||
data = orjson.loads(f.read())
|
||||
self.stored_paragraph_hashes = set(data["stored_paragraph_hashes"])
|
||||
|
||||
@@ -100,8 +98,8 @@ class KGManager:
|
||||
|
||||
def _build_edges_between_ent(
|
||||
self,
|
||||
node_to_node: Dict[Tuple[str, str], float],
|
||||
triple_list_data: Dict[str, List[List[str]]],
|
||||
node_to_node: dict[tuple[str, str], float],
|
||||
triple_list_data: dict[str, list[list[str]]],
|
||||
):
|
||||
"""构建实体节点之间的关系,同时统计实体出现次数"""
|
||||
for triple_list in triple_list_data.values():
|
||||
@@ -124,8 +122,8 @@ class KGManager:
|
||||
|
||||
@staticmethod
|
||||
def _build_edges_between_ent_pg(
|
||||
node_to_node: Dict[Tuple[str, str], float],
|
||||
triple_list_data: Dict[str, List[List[str]]],
|
||||
node_to_node: dict[tuple[str, str], float],
|
||||
triple_list_data: dict[str, list[list[str]]],
|
||||
):
|
||||
"""构建实体节点与文段节点之间的关系"""
|
||||
for idx in triple_list_data:
|
||||
@@ -136,8 +134,8 @@ class KGManager:
|
||||
|
||||
@staticmethod
|
||||
def _synonym_connect(
|
||||
node_to_node: Dict[Tuple[str, str], float],
|
||||
triple_list_data: Dict[str, List[List[str]]],
|
||||
node_to_node: dict[tuple[str, str], float],
|
||||
triple_list_data: dict[str, list[list[str]]],
|
||||
embedding_manager: EmbeddingManager,
|
||||
) -> int:
|
||||
"""同义词连接"""
|
||||
@@ -208,7 +206,7 @@ class KGManager:
|
||||
|
||||
def _update_graph(
|
||||
self,
|
||||
node_to_node: Dict[Tuple[str, str], float],
|
||||
node_to_node: dict[tuple[str, str], float],
|
||||
embedding_manager: EmbeddingManager,
|
||||
):
|
||||
"""更新KG图结构
|
||||
@@ -280,7 +278,7 @@ class KGManager:
|
||||
|
||||
def build_kg(
|
||||
self,
|
||||
triple_list_data: Dict[str, List[List[str]]],
|
||||
triple_list_data: dict[str, list[list[str]]],
|
||||
embedding_manager: EmbeddingManager,
|
||||
):
|
||||
"""增量式构建KG
|
||||
@@ -317,8 +315,8 @@ class KGManager:
|
||||
|
||||
def kg_search(
|
||||
self,
|
||||
relation_search_result: List[Tuple[Tuple[str, str, str], float]],
|
||||
paragraph_search_result: List[Tuple[str, float]],
|
||||
relation_search_result: list[tuple[tuple[str, str, str], float]],
|
||||
paragraph_search_result: list[tuple[str, float]],
|
||||
embed_manager: EmbeddingManager,
|
||||
):
|
||||
"""RAG搜索与PageRank
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
from src.chat.knowledge.embedding_store import EmbeddingManager
|
||||
from src.chat.knowledge.qa_manager import QAManager
|
||||
from src.chat.knowledge.kg_manager import KGManager
|
||||
from src.chat.knowledge.global_logger import logger
|
||||
from src.config.config import global_config
|
||||
import os
|
||||
|
||||
from src.chat.knowledge.embedding_store import EmbeddingManager
|
||||
from src.chat.knowledge.global_logger import logger
|
||||
from src.chat.knowledge.kg_manager import KGManager
|
||||
from src.chat.knowledge.qa_manager import QAManager
|
||||
from src.config.config import global_config
|
||||
|
||||
INVALID_ENTITY = [
|
||||
"",
|
||||
"你",
|
||||
|
||||
@@ -1,14 +1,15 @@
|
||||
import orjson
|
||||
import os
|
||||
import glob
|
||||
from typing import Any, Dict, List
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
import orjson
|
||||
|
||||
from .knowledge_lib import DATA_PATH, INVALID_ENTITY, ROOT_PATH
|
||||
|
||||
from .knowledge_lib import INVALID_ENTITY, ROOT_PATH, DATA_PATH
|
||||
# from src.manager.local_store_manager import local_storage
|
||||
|
||||
|
||||
def _filter_invalid_entities(entities: List[str]) -> List[str]:
|
||||
def _filter_invalid_entities(entities: list[str]) -> list[str]:
|
||||
"""过滤无效的实体"""
|
||||
valid_entities = set()
|
||||
for entity in entities:
|
||||
@@ -20,7 +21,7 @@ def _filter_invalid_entities(entities: List[str]) -> List[str]:
|
||||
return list(valid_entities)
|
||||
|
||||
|
||||
def _filter_invalid_triples(triples: List[List[str]]) -> List[List[str]]:
|
||||
def _filter_invalid_triples(triples: list[list[str]]) -> list[list[str]]:
|
||||
"""过滤无效的三元组"""
|
||||
unique_triples = set()
|
||||
valid_triples = []
|
||||
@@ -62,7 +63,7 @@ class OpenIE:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
docs: List[Dict[str, Any]],
|
||||
docs: list[dict[str, Any]],
|
||||
avg_ent_chars,
|
||||
avg_ent_words,
|
||||
):
|
||||
@@ -112,7 +113,7 @@ class OpenIE:
|
||||
json_files = sorted(glob.glob(os.path.join(openie_dir, "*.json")))
|
||||
data_list = []
|
||||
for file in json_files:
|
||||
with open(file, "r", encoding="utf-8") as f:
|
||||
with open(file, encoding="utf-8") as f:
|
||||
data = orjson.loads(f.read())
|
||||
data_list.append(data)
|
||||
if not data_list:
|
||||
|
||||
@@ -1,15 +1,16 @@
|
||||
import time
|
||||
from typing import Tuple, List, Dict, Optional, Any
|
||||
from typing import Any
|
||||
|
||||
from src.chat.utils.utils import get_embedding
|
||||
from src.config.config import global_config, model_config
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
|
||||
from .global_logger import logger
|
||||
from .embedding_store import EmbeddingManager
|
||||
from .global_logger import logger
|
||||
from .kg_manager import KGManager
|
||||
|
||||
# from .lpmmconfig import global_config
|
||||
from .utils.dyn_topk import dyn_select_top_k
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.chat.utils.utils import get_embedding
|
||||
from src.config.config import global_config, model_config
|
||||
|
||||
MAX_KNOWLEDGE_LENGTH = 10000 # 最大知识长度
|
||||
|
||||
@@ -26,7 +27,7 @@ class QAManager:
|
||||
|
||||
async def process_query(
|
||||
self, question: str
|
||||
) -> Optional[Tuple[List[Tuple[str, float, float]], Optional[Dict[str, float]]]]:
|
||||
) -> tuple[list[tuple[str, float, float]], dict[str, float] | None] | None:
|
||||
"""处理查询"""
|
||||
|
||||
# 生成问题的Embedding
|
||||
@@ -98,7 +99,7 @@ class QAManager:
|
||||
|
||||
return result, ppr_node_weights
|
||||
|
||||
async def get_knowledge(self, question: str) -> Optional[Dict[str, Any]]:
|
||||
async def get_knowledge(self, question: str) -> dict[str, Any] | None:
|
||||
"""
|
||||
获取知识,返回结构化字典
|
||||
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
from typing import List, Any, Tuple
|
||||
from typing import Any
|
||||
|
||||
|
||||
def dyn_select_top_k(
|
||||
score: List[Tuple[Any, float]], jmp_factor: float, var_factor: float
|
||||
) -> List[Tuple[Any, float, float]]:
|
||||
score: list[tuple[Any, float]], jmp_factor: float, var_factor: float
|
||||
) -> list[tuple[Any, float, float]]:
|
||||
"""动态TopK选择"""
|
||||
# 检查输入列表是否为空
|
||||
if not score:
|
||||
|
||||
@@ -1,37 +1,35 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
简化记忆系统模块
|
||||
移除即时记忆和长期记忆分类,实现统一记忆架构和智能遗忘机制
|
||||
"""
|
||||
|
||||
# 核心数据结构
|
||||
# 激活器
|
||||
from .enhanced_memory_activator import MemoryActivator, enhanced_memory_activator, memory_activator
|
||||
from .memory_chunk import (
|
||||
ConfidenceLevel,
|
||||
ContentStructure,
|
||||
ImportanceLevel,
|
||||
MemoryChunk,
|
||||
MemoryMetadata,
|
||||
ContentStructure,
|
||||
MemoryType,
|
||||
ImportanceLevel,
|
||||
ConfidenceLevel,
|
||||
create_memory_chunk,
|
||||
)
|
||||
|
||||
# 兼容性别名
|
||||
from .memory_chunk import MemoryChunk as Memory
|
||||
|
||||
# 遗忘引擎
|
||||
from .memory_forgetting_engine import MemoryForgettingEngine, ForgettingConfig, get_memory_forgetting_engine
|
||||
|
||||
# Vector DB存储系统
|
||||
from .vector_memory_storage_v2 import VectorMemoryStorage, VectorStorageConfig, get_vector_memory_storage
|
||||
|
||||
# 记忆核心系统
|
||||
from .memory_system import MemorySystem, MemorySystemConfig, get_memory_system, initialize_memory_system
|
||||
from .memory_forgetting_engine import ForgettingConfig, MemoryForgettingEngine, get_memory_forgetting_engine
|
||||
|
||||
# 记忆管理器
|
||||
from .memory_manager import MemoryManager, MemoryResult, memory_manager
|
||||
|
||||
# 激活器
|
||||
from .enhanced_memory_activator import MemoryActivator, memory_activator, enhanced_memory_activator
|
||||
# 记忆核心系统
|
||||
from .memory_system import MemorySystem, MemorySystemConfig, get_memory_system, initialize_memory_system
|
||||
|
||||
# 兼容性别名
|
||||
from .memory_chunk import MemoryChunk as Memory
|
||||
# Vector DB存储系统
|
||||
from .vector_memory_storage_v2 import VectorMemoryStorage, VectorStorageConfig, get_vector_memory_storage
|
||||
|
||||
__all__ = [
|
||||
# 核心数据结构
|
||||
|
||||
@@ -1,17 +1,17 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
增强记忆系统适配器
|
||||
将增强记忆系统集成到现有MoFox Bot架构中
|
||||
"""
|
||||
|
||||
import time
|
||||
from typing import Dict, List, Optional, Any
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.memory_system.integration_layer import MemoryIntegrationLayer, IntegrationConfig, IntegrationMode
|
||||
from src.chat.memory_system.memory_chunk import MemoryChunk, MemoryType
|
||||
from src.chat.memory_system.integration_layer import IntegrationConfig, IntegrationMode, MemoryIntegrationLayer
|
||||
from src.chat.memory_system.memory_formatter import FormatterConfig, format_memories_for_llm
|
||||
|
||||
from src.chat.memory_system.memory_chunk import MemoryChunk, MemoryType
|
||||
from src.common.logger import get_logger
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
|
||||
logger = get_logger(__name__)
|
||||
@@ -47,10 +47,10 @@ class AdapterConfig:
|
||||
class EnhancedMemoryAdapter:
|
||||
"""增强记忆系统适配器"""
|
||||
|
||||
def __init__(self, llm_model: LLMRequest, config: Optional[AdapterConfig] = None):
|
||||
def __init__(self, llm_model: LLMRequest, config: AdapterConfig | None = None):
|
||||
self.llm_model = llm_model
|
||||
self.config = config or AdapterConfig()
|
||||
self.integration_layer: Optional[MemoryIntegrationLayer] = None
|
||||
self.integration_layer: MemoryIntegrationLayer | None = None
|
||||
self._initialized = False
|
||||
|
||||
# 统计信息
|
||||
@@ -96,7 +96,7 @@ class EnhancedMemoryAdapter:
|
||||
# 如果初始化失败,禁用增强记忆功能
|
||||
self.config.enable_enhanced_memory = False
|
||||
|
||||
async def process_conversation_memory(self, context: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
|
||||
async def process_conversation_memory(self, context: dict[str, Any] | None = None) -> dict[str, Any]:
|
||||
"""处理对话记忆,以上下文为唯一输入"""
|
||||
if not self._initialized or not self.config.enable_enhanced_memory:
|
||||
return {"success": False, "error": "Enhanced memory not available"}
|
||||
@@ -105,7 +105,7 @@ class EnhancedMemoryAdapter:
|
||||
self.adapter_stats["total_processed"] += 1
|
||||
|
||||
try:
|
||||
payload_context: Dict[str, Any] = dict(context or {})
|
||||
payload_context: dict[str, Any] = dict(context or {})
|
||||
|
||||
conversation_text = payload_context.get("conversation_text")
|
||||
if not conversation_text:
|
||||
@@ -146,8 +146,8 @@ class EnhancedMemoryAdapter:
|
||||
return {"success": False, "error": str(e)}
|
||||
|
||||
async def retrieve_relevant_memories(
|
||||
self, query: str, user_id: str, context: Optional[Dict[str, Any]] = None, limit: Optional[int] = None
|
||||
) -> List[MemoryChunk]:
|
||||
self, query: str, user_id: str, context: dict[str, Any] | None = None, limit: int | None = None
|
||||
) -> list[MemoryChunk]:
|
||||
"""检索相关记忆"""
|
||||
if not self._initialized or not self.config.enable_enhanced_memory:
|
||||
return []
|
||||
@@ -166,7 +166,7 @@ class EnhancedMemoryAdapter:
|
||||
return []
|
||||
|
||||
async def get_memory_context_for_prompt(
|
||||
self, query: str, user_id: str, context: Optional[Dict[str, Any]] = None, max_memories: int = 5
|
||||
self, query: str, user_id: str, context: dict[str, Any] | None = None, max_memories: int = 5
|
||||
) -> str:
|
||||
"""获取用于提示词的记忆上下文"""
|
||||
memories = await self.retrieve_relevant_memories(query, user_id, context, max_memories)
|
||||
@@ -186,7 +186,7 @@ class EnhancedMemoryAdapter:
|
||||
|
||||
return format_memories_for_llm(memories=memories, query_context=query, config=formatter_config)
|
||||
|
||||
async def get_enhanced_memory_summary(self, user_id: str) -> Dict[str, Any]:
|
||||
async def get_enhanced_memory_summary(self, user_id: str) -> dict[str, Any]:
|
||||
"""获取增强记忆系统摘要"""
|
||||
if not self._initialized or not self.config.enable_enhanced_memory:
|
||||
return {"available": False, "reason": "Not initialized or disabled"}
|
||||
@@ -222,7 +222,7 @@ class EnhancedMemoryAdapter:
|
||||
new_avg = (current_avg * (total_processed - 1) + processing_time) / total_processed
|
||||
self.adapter_stats["average_processing_time"] = new_avg
|
||||
|
||||
def get_adapter_stats(self) -> Dict[str, Any]:
|
||||
def get_adapter_stats(self) -> dict[str, Any]:
|
||||
"""获取适配器统计信息"""
|
||||
return self.adapter_stats.copy()
|
||||
|
||||
@@ -253,7 +253,7 @@ class EnhancedMemoryAdapter:
|
||||
|
||||
|
||||
# 全局适配器实例
|
||||
_enhanced_memory_adapter: Optional[EnhancedMemoryAdapter] = None
|
||||
_enhanced_memory_adapter: EnhancedMemoryAdapter | None = None
|
||||
|
||||
|
||||
async def get_enhanced_memory_adapter(llm_model: LLMRequest) -> EnhancedMemoryAdapter:
|
||||
@@ -292,8 +292,8 @@ async def initialize_enhanced_memory_system(llm_model: LLMRequest):
|
||||
|
||||
|
||||
async def process_conversation_with_enhanced_memory(
|
||||
context: Dict[str, Any], llm_model: Optional[LLMRequest] = None
|
||||
) -> Dict[str, Any]:
|
||||
context: dict[str, Any], llm_model: LLMRequest | None = None
|
||||
) -> dict[str, Any]:
|
||||
"""使用增强记忆系统处理对话,上下文需包含 conversation_text 等信息"""
|
||||
if not llm_model:
|
||||
# 获取默认的LLM模型
|
||||
@@ -323,10 +323,10 @@ async def process_conversation_with_enhanced_memory(
|
||||
async def retrieve_memories_with_enhanced_system(
|
||||
query: str,
|
||||
user_id: str,
|
||||
context: Optional[Dict[str, Any]] = None,
|
||||
context: dict[str, Any] | None = None,
|
||||
limit: int = 10,
|
||||
llm_model: Optional[LLMRequest] = None,
|
||||
) -> List[MemoryChunk]:
|
||||
llm_model: LLMRequest | None = None,
|
||||
) -> list[MemoryChunk]:
|
||||
"""使用增强记忆系统检索记忆"""
|
||||
if not llm_model:
|
||||
# 获取默认的LLM模型
|
||||
@@ -345,9 +345,9 @@ async def retrieve_memories_with_enhanced_system(
|
||||
async def get_memory_context_for_prompt(
|
||||
query: str,
|
||||
user_id: str,
|
||||
context: Optional[Dict[str, Any]] = None,
|
||||
context: dict[str, Any] | None = None,
|
||||
max_memories: int = 5,
|
||||
llm_model: Optional[LLMRequest] = None,
|
||||
llm_model: LLMRequest | None = None,
|
||||
) -> str:
|
||||
"""获取用于提示词的记忆上下文"""
|
||||
if not llm_model:
|
||||
|
||||
@@ -1,15 +1,15 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
增强记忆系统钩子
|
||||
用于在消息处理过程中自动构建和检索记忆
|
||||
"""
|
||||
|
||||
from typing import Dict, List, Any, Optional
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from src.chat.memory_system.enhanced_memory_manager import enhanced_memory_manager
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.chat.memory_system.enhanced_memory_manager import enhanced_memory_manager
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -27,7 +27,7 @@ class EnhancedMemoryHooks:
|
||||
user_id: str,
|
||||
chat_id: str,
|
||||
message_id: str,
|
||||
context: Optional[Dict[str, Any]] = None,
|
||||
context: dict[str, Any] | None = None,
|
||||
) -> bool:
|
||||
"""
|
||||
处理消息并构建记忆
|
||||
@@ -106,8 +106,8 @@ class EnhancedMemoryHooks:
|
||||
user_id: str,
|
||||
chat_id: str,
|
||||
limit: int = 5,
|
||||
extra_context: Optional[Dict[str, Any]] = None,
|
||||
) -> List[Dict[str, Any]]:
|
||||
extra_context: dict[str, Any] | None = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""
|
||||
为回复获取相关记忆
|
||||
|
||||
|
||||
@@ -1,19 +1,19 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
增强记忆系统集成脚本
|
||||
用于在现有系统中无缝集成增强记忆功能
|
||||
"""
|
||||
|
||||
from typing import Dict, Any, Optional
|
||||
from typing import Any
|
||||
|
||||
from src.chat.memory_system.enhanced_memory_hooks import enhanced_memory_hooks
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.memory_system.enhanced_memory_hooks import enhanced_memory_hooks
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
async def process_user_message_memory(
|
||||
message_content: str, user_id: str, chat_id: str, message_id: str, context: Optional[Dict[str, Any]] = None
|
||||
message_content: str, user_id: str, chat_id: str, message_id: str, context: dict[str, Any] | None = None
|
||||
) -> bool:
|
||||
"""
|
||||
处理用户消息并构建记忆
|
||||
@@ -44,8 +44,8 @@ async def process_user_message_memory(
|
||||
|
||||
|
||||
async def get_relevant_memories_for_response(
|
||||
query_text: str, user_id: str, chat_id: str, limit: int = 5, extra_context: Optional[Dict[str, Any]] = None
|
||||
) -> Dict[str, Any]:
|
||||
query_text: str, user_id: str, chat_id: str, limit: int = 5, extra_context: dict[str, Any] | None = None
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
为回复获取相关记忆
|
||||
|
||||
@@ -74,7 +74,7 @@ async def get_relevant_memories_for_response(
|
||||
return {"has_memories": False, "memories": [], "memory_count": 0}
|
||||
|
||||
|
||||
def format_memories_for_prompt(memories: Dict[str, Any]) -> str:
|
||||
def format_memories_for_prompt(memories: dict[str, Any]) -> str:
|
||||
"""
|
||||
格式化记忆信息用于Prompt
|
||||
|
||||
@@ -114,7 +114,7 @@ async def cleanup_memory_system():
|
||||
logger.error(f"记忆系统清理失败: {e}")
|
||||
|
||||
|
||||
def get_memory_system_status() -> Dict[str, Any]:
|
||||
def get_memory_system_status() -> dict[str, Any]:
|
||||
"""
|
||||
获取记忆系统状态
|
||||
|
||||
@@ -133,7 +133,7 @@ def get_memory_system_status() -> Dict[str, Any]:
|
||||
|
||||
# 便捷函数
|
||||
async def remember_message(
|
||||
message: str, user_id: str = "default_user", chat_id: str = "default_chat", context: Optional[Dict[str, Any]] = None
|
||||
message: str, user_id: str = "default_user", chat_id: str = "default_chat", context: dict[str, Any] | None = None
|
||||
) -> bool:
|
||||
"""
|
||||
便捷的记忆构建函数
|
||||
@@ -159,8 +159,8 @@ async def recall_memories(
|
||||
user_id: str = "default_user",
|
||||
chat_id: str = "default_chat",
|
||||
limit: int = 5,
|
||||
context: Optional[Dict[str, Any]] = None,
|
||||
) -> Dict[str, Any]:
|
||||
context: dict[str, Any] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
便捷的记忆检索函数
|
||||
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
增强重排序器
|
||||
实现文档设计的多维度评分模型
|
||||
@@ -6,12 +5,12 @@
|
||||
|
||||
import math
|
||||
import time
|
||||
from typing import Dict, List, Optional, Any, Tuple
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.memory_system.memory_chunk import MemoryChunk, MemoryType
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -44,7 +43,7 @@ class ReRankingConfig:
|
||||
freq_max_score: float = 5.0 # 最大频率得分
|
||||
|
||||
# 类型匹配权重映射
|
||||
type_match_weights: Dict[str, Dict[str, float]] = None
|
||||
type_match_weights: dict[str, dict[str, float]] = None
|
||||
|
||||
def __post_init__(self):
|
||||
"""初始化类型匹配权重"""
|
||||
@@ -157,7 +156,7 @@ class IntentClassifier:
|
||||
],
|
||||
}
|
||||
|
||||
def classify_intent(self, query: str, context: Dict[str, Any]) -> IntentType:
|
||||
def classify_intent(self, query: str, context: dict[str, Any]) -> IntentType:
|
||||
"""识别对话意图"""
|
||||
if not query:
|
||||
return IntentType.UNKNOWN
|
||||
@@ -165,7 +164,7 @@ class IntentClassifier:
|
||||
query_lower = query.lower()
|
||||
|
||||
# 统计各意图的匹配分数
|
||||
intent_scores = {intent: 0 for intent in IntentType}
|
||||
intent_scores = dict.fromkeys(IntentType, 0)
|
||||
|
||||
for intent, patterns in self.patterns.items():
|
||||
for pattern in patterns:
|
||||
@@ -187,7 +186,7 @@ class IntentClassifier:
|
||||
class EnhancedReRanker:
|
||||
"""增强重排序器 - 实现文档设计的多维度评分模型"""
|
||||
|
||||
def __init__(self, config: Optional[ReRankingConfig] = None):
|
||||
def __init__(self, config: ReRankingConfig | None = None):
|
||||
self.config = config or ReRankingConfig()
|
||||
self.intent_classifier = IntentClassifier()
|
||||
|
||||
@@ -210,10 +209,10 @@ class EnhancedReRanker:
|
||||
def rerank_memories(
|
||||
self,
|
||||
query: str,
|
||||
candidate_memories: List[Tuple[str, MemoryChunk, float]], # (memory_id, memory, vector_similarity)
|
||||
context: Dict[str, Any],
|
||||
candidate_memories: list[tuple[str, MemoryChunk, float]], # (memory_id, memory, vector_similarity)
|
||||
context: dict[str, Any],
|
||||
limit: int = 10,
|
||||
) -> List[Tuple[str, MemoryChunk, float]]:
|
||||
) -> list[tuple[str, MemoryChunk, float]]:
|
||||
"""
|
||||
对候选记忆进行重排序
|
||||
|
||||
@@ -341,11 +340,11 @@ default_reranker = EnhancedReRanker()
|
||||
|
||||
def rerank_candidate_memories(
|
||||
query: str,
|
||||
candidate_memories: List[Tuple[str, MemoryChunk, float]],
|
||||
context: Dict[str, Any],
|
||||
candidate_memories: list[tuple[str, MemoryChunk, float]],
|
||||
context: dict[str, Any],
|
||||
limit: int = 10,
|
||||
config: Optional[ReRankingConfig] = None,
|
||||
) -> List[Tuple[str, MemoryChunk, float]]:
|
||||
config: ReRankingConfig | None = None,
|
||||
) -> list[tuple[str, MemoryChunk, float]]:
|
||||
"""
|
||||
便捷函数:对候选记忆进行重排序
|
||||
"""
|
||||
|
||||
@@ -1,18 +1,18 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
增强记忆系统集成层
|
||||
现在只管理新的增强记忆系统,旧系统已被完全移除
|
||||
"""
|
||||
|
||||
import time
|
||||
import asyncio
|
||||
from typing import Dict, List, Optional, Any
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.memory_system.enhanced_memory_core import EnhancedMemorySystem
|
||||
|
||||
from src.chat.memory_system.memory_chunk import MemoryChunk
|
||||
from src.common.logger import get_logger
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
|
||||
logger = get_logger(__name__)
|
||||
@@ -40,12 +40,12 @@ class IntegrationConfig:
|
||||
class MemoryIntegrationLayer:
|
||||
"""记忆系统集成层 - 现在只管理增强记忆系统"""
|
||||
|
||||
def __init__(self, llm_model: LLMRequest, config: Optional[IntegrationConfig] = None):
|
||||
def __init__(self, llm_model: LLMRequest, config: IntegrationConfig | None = None):
|
||||
self.llm_model = llm_model
|
||||
self.config = config or IntegrationConfig()
|
||||
|
||||
# 只初始化增强记忆系统
|
||||
self.enhanced_memory: Optional[EnhancedMemorySystem] = None
|
||||
self.enhanced_memory: EnhancedMemorySystem | None = None
|
||||
|
||||
# 集成统计
|
||||
self.integration_stats = {
|
||||
@@ -113,7 +113,7 @@ class MemoryIntegrationLayer:
|
||||
logger.error(f"❌ 增强记忆系统初始化失败: {e}", exc_info=True)
|
||||
raise
|
||||
|
||||
async def process_conversation(self, context: Dict[str, Any]) -> Dict[str, Any]:
|
||||
async def process_conversation(self, context: dict[str, Any]) -> dict[str, Any]:
|
||||
"""处理对话记忆,仅使用上下文信息"""
|
||||
if not self._initialized or not self.enhanced_memory:
|
||||
return {"success": False, "error": "Memory system not available"}
|
||||
@@ -150,10 +150,10 @@ class MemoryIntegrationLayer:
|
||||
async def retrieve_relevant_memories(
|
||||
self,
|
||||
query: str,
|
||||
user_id: Optional[str] = None,
|
||||
context: Optional[Dict[str, Any]] = None,
|
||||
limit: Optional[int] = None,
|
||||
) -> List[MemoryChunk]:
|
||||
user_id: str | None = None,
|
||||
context: dict[str, Any] | None = None,
|
||||
limit: int | None = None,
|
||||
) -> list[MemoryChunk]:
|
||||
"""检索相关记忆"""
|
||||
if not self._initialized or not self.enhanced_memory:
|
||||
return []
|
||||
@@ -172,7 +172,7 @@ class MemoryIntegrationLayer:
|
||||
logger.error(f"检索相关记忆失败: {e}", exc_info=True)
|
||||
return []
|
||||
|
||||
async def get_system_status(self) -> Dict[str, Any]:
|
||||
async def get_system_status(self) -> dict[str, Any]:
|
||||
"""获取系统状态"""
|
||||
if not self._initialized:
|
||||
return {"status": "not_initialized"}
|
||||
@@ -193,7 +193,7 @@ class MemoryIntegrationLayer:
|
||||
logger.error(f"获取系统状态失败: {e}", exc_info=True)
|
||||
return {"status": "error", "error": str(e)}
|
||||
|
||||
def get_integration_stats(self) -> Dict[str, Any]:
|
||||
def get_integration_stats(self) -> dict[str, Any]:
|
||||
"""获取集成统计信息"""
|
||||
return self.integration_stats.copy()
|
||||
|
||||
|
||||
@@ -1,20 +1,20 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
记忆系统集成钩子
|
||||
提供与现有MoFox Bot系统的无缝集成点
|
||||
"""
|
||||
|
||||
import time
|
||||
from typing import Dict, Optional, Any
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.memory_system.enhanced_memory_adapter import (
|
||||
get_memory_context_for_prompt,
|
||||
process_conversation_with_enhanced_memory,
|
||||
retrieve_memories_with_enhanced_system,
|
||||
get_memory_context_for_prompt,
|
||||
)
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -24,7 +24,7 @@ class HookResult:
|
||||
|
||||
success: bool
|
||||
data: Any = None
|
||||
error: Optional[str] = None
|
||||
error: str | None = None
|
||||
processing_time: float = 0.0
|
||||
|
||||
|
||||
@@ -125,8 +125,8 @@ class MemoryIntegrationHooks:
|
||||
|
||||
# 尝试注册到事件系统
|
||||
try:
|
||||
from src.plugin_system.core.event_manager import event_manager
|
||||
from src.plugin_system.base.component_types import EventType
|
||||
from src.plugin_system.core.event_manager import event_manager
|
||||
|
||||
# 注册消息后处理事件
|
||||
event_manager.subscribe(EventType.MESSAGE_PROCESSED, self._on_message_processed_handler)
|
||||
@@ -238,11 +238,11 @@ class MemoryIntegrationHooks:
|
||||
|
||||
# 钩子处理器方法
|
||||
|
||||
async def _on_message_processed_handler(self, event_data: Dict[str, Any]) -> HookResult:
|
||||
async def _on_message_processed_handler(self, event_data: dict[str, Any]) -> HookResult:
|
||||
"""事件系统的消息处理处理器"""
|
||||
return await self._on_message_processed_hook(event_data)
|
||||
|
||||
async def _on_message_processed_hook(self, message_data: Dict[str, Any]) -> HookResult:
|
||||
async def _on_message_processed_hook(self, message_data: dict[str, Any]) -> HookResult:
|
||||
"""消息后处理钩子"""
|
||||
start_time = time.time()
|
||||
|
||||
@@ -289,7 +289,7 @@ class MemoryIntegrationHooks:
|
||||
logger.error(f"消息处理钩子执行异常: {e}", exc_info=True)
|
||||
return HookResult(success=False, error=str(e), processing_time=processing_time)
|
||||
|
||||
async def _on_chat_stream_save_hook(self, chat_stream_data: Dict[str, Any]) -> HookResult:
|
||||
async def _on_chat_stream_save_hook(self, chat_stream_data: dict[str, Any]) -> HookResult:
|
||||
"""聊天流保存钩子"""
|
||||
start_time = time.time()
|
||||
|
||||
@@ -345,7 +345,7 @@ class MemoryIntegrationHooks:
|
||||
logger.error(f"聊天流保存钩子执行异常: {e}", exc_info=True)
|
||||
return HookResult(success=False, error=str(e), processing_time=processing_time)
|
||||
|
||||
async def _on_pre_response_hook(self, response_data: Dict[str, Any]) -> HookResult:
|
||||
async def _on_pre_response_hook(self, response_data: dict[str, Any]) -> HookResult:
|
||||
"""回复前钩子"""
|
||||
start_time = time.time()
|
||||
|
||||
@@ -380,7 +380,7 @@ class MemoryIntegrationHooks:
|
||||
logger.error(f"回复前钩子执行异常: {e}", exc_info=True)
|
||||
return HookResult(success=False, error=str(e), processing_time=processing_time)
|
||||
|
||||
async def _on_knowledge_query_hook(self, query_data: Dict[str, Any]) -> HookResult:
|
||||
async def _on_knowledge_query_hook(self, query_data: dict[str, Any]) -> HookResult:
|
||||
"""知识库查询钩子"""
|
||||
start_time = time.time()
|
||||
|
||||
@@ -411,7 +411,7 @@ class MemoryIntegrationHooks:
|
||||
logger.error(f"知识库查询钩子执行异常: {e}", exc_info=True)
|
||||
return HookResult(success=False, error=str(e), processing_time=processing_time)
|
||||
|
||||
async def _on_prompt_building_hook(self, prompt_data: Dict[str, Any]) -> HookResult:
|
||||
async def _on_prompt_building_hook(self, prompt_data: dict[str, Any]) -> HookResult:
|
||||
"""提示词构建钩子"""
|
||||
start_time = time.time()
|
||||
|
||||
@@ -459,7 +459,7 @@ class MemoryIntegrationHooks:
|
||||
new_avg = (current_avg * (total_executions - 1) + processing_time) / total_executions
|
||||
self.hook_stats["average_hook_time"] = new_avg
|
||||
|
||||
def get_hook_stats(self) -> Dict[str, Any]:
|
||||
def get_hook_stats(self) -> dict[str, Any]:
|
||||
"""获取钩子统计信息"""
|
||||
return self.hook_stats.copy()
|
||||
|
||||
@@ -501,7 +501,7 @@ class MemoryMaintenanceTask:
|
||||
|
||||
|
||||
# 全局钩子实例
|
||||
_memory_hooks: Optional[MemoryIntegrationHooks] = None
|
||||
_memory_hooks: MemoryIntegrationHooks | None = None
|
||||
|
||||
|
||||
async def get_memory_integration_hooks() -> MemoryIntegrationHooks:
|
||||
|
||||
@@ -1,20 +1,20 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
元数据索引系统
|
||||
为记忆系统提供多维度的精准过滤和查询能力
|
||||
"""
|
||||
|
||||
import threading
|
||||
import time
|
||||
import orjson
|
||||
from typing import Dict, List, Optional, Tuple, Set, Any, Union
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
import threading
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import orjson
|
||||
|
||||
from src.chat.memory_system.memory_chunk import ConfidenceLevel, ImportanceLevel, MemoryChunk, MemoryType
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.memory_system.memory_chunk import MemoryChunk, MemoryType, ConfidenceLevel, ImportanceLevel
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -40,21 +40,21 @@ class IndexType(Enum):
|
||||
class IndexQuery:
|
||||
"""索引查询条件"""
|
||||
|
||||
user_ids: Optional[List[str]] = None
|
||||
memory_types: Optional[List[MemoryType]] = None
|
||||
subjects: Optional[List[str]] = None
|
||||
keywords: Optional[List[str]] = None
|
||||
tags: Optional[List[str]] = None
|
||||
categories: Optional[List[str]] = None
|
||||
time_range: Optional[Tuple[float, float]] = None
|
||||
confidence_levels: Optional[List[ConfidenceLevel]] = None
|
||||
importance_levels: Optional[List[ImportanceLevel]] = None
|
||||
min_relationship_score: Optional[float] = None
|
||||
max_relationship_score: Optional[float] = None
|
||||
min_access_count: Optional[int] = None
|
||||
semantic_hashes: Optional[List[str]] = None
|
||||
limit: Optional[int] = None
|
||||
sort_by: Optional[str] = None # "created_at", "access_count", "relevance_score"
|
||||
user_ids: list[str] | None = None
|
||||
memory_types: list[MemoryType] | None = None
|
||||
subjects: list[str] | None = None
|
||||
keywords: list[str] | None = None
|
||||
tags: list[str] | None = None
|
||||
categories: list[str] | None = None
|
||||
time_range: tuple[float, float] | None = None
|
||||
confidence_levels: list[ConfidenceLevel] | None = None
|
||||
importance_levels: list[ImportanceLevel] | None = None
|
||||
min_relationship_score: float | None = None
|
||||
max_relationship_score: float | None = None
|
||||
min_access_count: int | None = None
|
||||
semantic_hashes: list[str] | None = None
|
||||
limit: int | None = None
|
||||
sort_by: str | None = None # "created_at", "access_count", "relevance_score"
|
||||
sort_order: str = "desc" # "asc", "desc"
|
||||
|
||||
|
||||
@@ -62,10 +62,10 @@ class IndexQuery:
|
||||
class IndexResult:
|
||||
"""索引结果"""
|
||||
|
||||
memory_ids: List[str]
|
||||
memory_ids: list[str]
|
||||
total_count: int
|
||||
query_time: float
|
||||
filtered_by: List[str]
|
||||
filtered_by: list[str]
|
||||
|
||||
|
||||
class MetadataIndexManager:
|
||||
@@ -94,7 +94,7 @@ class MetadataIndexManager:
|
||||
self.access_frequency_index = [] # [(access_count, memory_id), ...]
|
||||
|
||||
# 内存缓存
|
||||
self.memory_metadata_cache: Dict[str, Dict[str, Any]] = {}
|
||||
self.memory_metadata_cache: dict[str, dict[str, Any]] = {}
|
||||
|
||||
# 统计信息
|
||||
self.index_stats = {
|
||||
@@ -140,7 +140,7 @@ class MetadataIndexManager:
|
||||
return key
|
||||
|
||||
@staticmethod
|
||||
def _serialize_metadata_entry(metadata: Dict[str, Any]) -> Dict[str, Any]:
|
||||
def _serialize_metadata_entry(metadata: dict[str, Any]) -> dict[str, Any]:
|
||||
serialized = {}
|
||||
for field_name, value in metadata.items():
|
||||
if isinstance(value, Enum):
|
||||
@@ -149,7 +149,7 @@ class MetadataIndexManager:
|
||||
serialized[field_name] = value
|
||||
return serialized
|
||||
|
||||
async def index_memories(self, memories: List[MemoryChunk]):
|
||||
async def index_memories(self, memories: list[MemoryChunk]):
|
||||
"""为记忆建立索引"""
|
||||
if not memories:
|
||||
return
|
||||
@@ -375,7 +375,7 @@ class MetadataIndexManager:
|
||||
logger.error(f"❌ 元数据查询失败: {e}", exc_info=True)
|
||||
return IndexResult(memory_ids=[], total_count=0, query_time=0.0, filtered_by=[])
|
||||
|
||||
def _get_candidate_memories(self, query: IndexQuery) -> Set[str]:
|
||||
def _get_candidate_memories(self, query: IndexQuery) -> set[str]:
|
||||
"""获取候选记忆ID集合"""
|
||||
candidate_ids = set()
|
||||
|
||||
@@ -444,7 +444,7 @@ class MetadataIndexManager:
|
||||
|
||||
return candidate_ids
|
||||
|
||||
def _collect_index_matches(self, index_type: IndexType, token: Optional[Union[str, Enum]]) -> Set[str]:
|
||||
def _collect_index_matches(self, index_type: IndexType, token: str | Enum | None) -> set[str]:
|
||||
"""根据给定token收集索引匹配,支持部分匹配"""
|
||||
mapping = self.indices.get(index_type)
|
||||
if mapping is None:
|
||||
@@ -461,7 +461,7 @@ class MetadataIndexManager:
|
||||
if not key:
|
||||
return set()
|
||||
|
||||
matches: Set[str] = set(mapping.get(key, set()))
|
||||
matches: set[str] = set(mapping.get(key, set()))
|
||||
|
||||
if matches:
|
||||
return set(matches)
|
||||
@@ -477,7 +477,7 @@ class MetadataIndexManager:
|
||||
|
||||
return matches
|
||||
|
||||
def _apply_filters(self, candidate_ids: Set[str], query: IndexQuery) -> List[str]:
|
||||
def _apply_filters(self, candidate_ids: set[str], query: IndexQuery) -> list[str]:
|
||||
"""应用过滤条件"""
|
||||
filtered_ids = list(candidate_ids)
|
||||
|
||||
@@ -545,7 +545,7 @@ class MetadataIndexManager:
|
||||
created_at = self.memory_metadata_cache[memory_id]["created_at"]
|
||||
return start_time <= created_at <= end_time
|
||||
|
||||
def _sort_memories(self, memory_ids: List[str], sort_by: str, sort_order: str) -> List[str]:
|
||||
def _sort_memories(self, memory_ids: list[str], sort_by: str, sort_order: str) -> list[str]:
|
||||
"""对记忆进行排序"""
|
||||
if sort_by == "created_at":
|
||||
# 使用时间索引(已经有序)
|
||||
@@ -582,7 +582,7 @@ class MetadataIndexManager:
|
||||
|
||||
return memory_ids
|
||||
|
||||
def _get_applied_filters(self, query: IndexQuery) -> List[str]:
|
||||
def _get_applied_filters(self, query: IndexQuery) -> list[str]:
|
||||
"""获取应用的过滤器列表"""
|
||||
filters = []
|
||||
if query.memory_types:
|
||||
@@ -686,11 +686,11 @@ class MetadataIndexManager:
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 移除记忆索引失败: {e}")
|
||||
|
||||
async def get_memory_metadata(self, memory_id: str) -> Optional[Dict[str, Any]]:
|
||||
async def get_memory_metadata(self, memory_id: str) -> dict[str, Any] | None:
|
||||
"""获取记忆元数据"""
|
||||
return self.memory_metadata_cache.get(memory_id)
|
||||
|
||||
async def get_user_memory_ids(self, user_id: str, limit: Optional[int] = None) -> List[str]:
|
||||
async def get_user_memory_ids(self, user_id: str, limit: int | None = None) -> list[str]:
|
||||
"""获取用户的所有记忆ID"""
|
||||
user_memory_ids = list(self.indices[IndexType.USER_ID].get(user_id, set()))
|
||||
|
||||
@@ -699,7 +699,7 @@ class MetadataIndexManager:
|
||||
|
||||
return user_memory_ids
|
||||
|
||||
async def get_memory_statistics(self, user_id: Optional[str] = None) -> Dict[str, Any]:
|
||||
async def get_memory_statistics(self, user_id: str | None = None) -> dict[str, Any]:
|
||||
"""获取记忆统计信息"""
|
||||
stats = {
|
||||
"total_memories": self.index_stats["total_memories"],
|
||||
@@ -784,7 +784,7 @@ class MetadataIndexManager:
|
||||
logger.info("正在保存元数据索引...")
|
||||
|
||||
# 保存各类索引
|
||||
indices_data: Dict[str, Dict[str, List[str]]] = {}
|
||||
indices_data: dict[str, dict[str, list[str]]] = {}
|
||||
for index_type, index_data in self.indices.items():
|
||||
serialized_index = {}
|
||||
for key, values in index_data.items():
|
||||
@@ -839,7 +839,7 @@ class MetadataIndexManager:
|
||||
# 加载各类索引
|
||||
indices_file = self.index_path / "indices.json"
|
||||
if indices_file.exists():
|
||||
with open(indices_file, "r", encoding="utf-8") as f:
|
||||
with open(indices_file, encoding="utf-8") as f:
|
||||
indices_data = orjson.loads(f.read())
|
||||
|
||||
for index_type_value, index_data in indices_data.items():
|
||||
@@ -853,25 +853,25 @@ class MetadataIndexManager:
|
||||
# 加载时间索引
|
||||
time_index_file = self.index_path / "time_index.json"
|
||||
if time_index_file.exists():
|
||||
with open(time_index_file, "r", encoding="utf-8") as f:
|
||||
with open(time_index_file, encoding="utf-8") as f:
|
||||
self.time_index = orjson.loads(f.read())
|
||||
|
||||
# 加载关系分索引
|
||||
relationship_index_file = self.index_path / "relationship_index.json"
|
||||
if relationship_index_file.exists():
|
||||
with open(relationship_index_file, "r", encoding="utf-8") as f:
|
||||
with open(relationship_index_file, encoding="utf-8") as f:
|
||||
self.relationship_index = orjson.loads(f.read())
|
||||
|
||||
# 加载访问频率索引
|
||||
access_frequency_index_file = self.index_path / "access_frequency_index.json"
|
||||
if access_frequency_index_file.exists():
|
||||
with open(access_frequency_index_file, "r", encoding="utf-8") as f:
|
||||
with open(access_frequency_index_file, encoding="utf-8") as f:
|
||||
self.access_frequency_index = orjson.loads(f.read())
|
||||
|
||||
# 加载元数据缓存
|
||||
metadata_cache_file = self.index_path / "metadata_cache.json"
|
||||
if metadata_cache_file.exists():
|
||||
with open(metadata_cache_file, "r", encoding="utf-8") as f:
|
||||
with open(metadata_cache_file, encoding="utf-8") as f:
|
||||
cache_data = orjson.loads(f.read())
|
||||
|
||||
# 转换置信度和重要性为枚举类型
|
||||
@@ -914,7 +914,7 @@ class MetadataIndexManager:
|
||||
# 加载统计信息
|
||||
stats_file = self.index_path / "index_stats.json"
|
||||
if stats_file.exists():
|
||||
with open(stats_file, "r", encoding="utf-8") as f:
|
||||
with open(stats_file, encoding="utf-8") as f:
|
||||
self.index_stats = orjson.loads(f.read())
|
||||
|
||||
# 更新记忆计数
|
||||
@@ -1004,7 +1004,7 @@ class MetadataIndexManager:
|
||||
if len(self.indices[IndexType.CATEGORY][category]) < min_frequency:
|
||||
del self.indices[IndexType.CATEGORY][category]
|
||||
|
||||
def get_index_stats(self) -> Dict[str, Any]:
|
||||
def get_index_stats(self) -> dict[str, Any]:
|
||||
"""获取索引统计信息"""
|
||||
stats = self.index_stats.copy()
|
||||
if stats["total_queries"] > 0:
|
||||
|
||||
@@ -1,19 +1,19 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
多阶段召回机制
|
||||
实现粗粒度到细粒度的记忆检索优化
|
||||
"""
|
||||
|
||||
import time
|
||||
from typing import Dict, List, Optional, Set, Any
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
import orjson
|
||||
from typing import Any
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.memory_system.memory_chunk import MemoryChunk, MemoryType
|
||||
import orjson
|
||||
from src.chat.memory_system.enhanced_reranker import EnhancedReRanker, ReRankingConfig
|
||||
|
||||
from src.chat.memory_system.memory_chunk import MemoryChunk, MemoryType
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -73,11 +73,11 @@ class StageResult:
|
||||
"""阶段结果"""
|
||||
|
||||
stage: RetrievalStage
|
||||
memory_ids: List[str]
|
||||
memory_ids: list[str]
|
||||
processing_time: float
|
||||
filtered_count: int
|
||||
score_threshold: float
|
||||
details: List[Dict[str, Any]] = field(default_factory=list)
|
||||
details: list[dict[str, Any]] = field(default_factory=list)
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -86,17 +86,17 @@ class RetrievalResult:
|
||||
|
||||
query: str
|
||||
user_id: str
|
||||
final_memories: List[MemoryChunk]
|
||||
stage_results: List[StageResult]
|
||||
final_memories: list[MemoryChunk]
|
||||
stage_results: list[StageResult]
|
||||
total_processing_time: float
|
||||
total_filtered: int
|
||||
retrieval_stats: Dict[str, Any]
|
||||
retrieval_stats: dict[str, Any]
|
||||
|
||||
|
||||
class MultiStageRetrieval:
|
||||
"""多阶段召回系统"""
|
||||
|
||||
def __init__(self, config: Optional[RetrievalConfig] = None):
|
||||
def __init__(self, config: RetrievalConfig | None = None):
|
||||
self.config = config or RetrievalConfig.from_global_config()
|
||||
|
||||
# 初始化增强重排序器
|
||||
@@ -124,11 +124,11 @@ class MultiStageRetrieval:
|
||||
self,
|
||||
query: str,
|
||||
user_id: str,
|
||||
context: Dict[str, Any],
|
||||
context: dict[str, Any],
|
||||
metadata_index,
|
||||
vector_storage,
|
||||
all_memories_cache: Dict[str, MemoryChunk],
|
||||
limit: Optional[int] = None,
|
||||
all_memories_cache: dict[str, MemoryChunk],
|
||||
limit: int | None = None,
|
||||
) -> RetrievalResult:
|
||||
"""多阶段记忆检索"""
|
||||
start_time = time.time()
|
||||
@@ -136,7 +136,7 @@ class MultiStageRetrieval:
|
||||
|
||||
stage_results = []
|
||||
current_memory_ids = set()
|
||||
memory_debug_info: Dict[str, Dict[str, Any]] = {}
|
||||
memory_debug_info: dict[str, dict[str, Any]] = {}
|
||||
|
||||
try:
|
||||
logger.debug(f"开始多阶段检索:query='{query}', user_id='{user_id}'")
|
||||
@@ -311,11 +311,11 @@ class MultiStageRetrieval:
|
||||
self,
|
||||
query: str,
|
||||
user_id: str,
|
||||
context: Dict[str, Any],
|
||||
context: dict[str, Any],
|
||||
metadata_index,
|
||||
all_memories_cache: Dict[str, MemoryChunk],
|
||||
all_memories_cache: dict[str, MemoryChunk],
|
||||
*,
|
||||
debug_log: Optional[Dict[str, Dict[str, Any]]] = None,
|
||||
debug_log: dict[str, dict[str, Any]] | None = None,
|
||||
) -> StageResult:
|
||||
"""阶段1:元数据过滤"""
|
||||
start_time = time.time()
|
||||
@@ -345,7 +345,7 @@ class MultiStageRetrieval:
|
||||
result = await metadata_index.query_memories(index_query)
|
||||
result_ids = list(result.memory_ids)
|
||||
filtered_count = max(0, len(all_memories_cache) - len(result_ids))
|
||||
details: List[Dict[str, Any]] = []
|
||||
details: list[dict[str, Any]] = []
|
||||
|
||||
# 如果未命中任何索引且未指定所有者过滤,则回退到最近访问的记忆
|
||||
if not result_ids:
|
||||
@@ -440,12 +440,12 @@ class MultiStageRetrieval:
|
||||
self,
|
||||
query: str,
|
||||
user_id: str,
|
||||
context: Dict[str, Any],
|
||||
context: dict[str, Any],
|
||||
vector_storage,
|
||||
candidate_ids: Set[str],
|
||||
all_memories_cache: Dict[str, MemoryChunk],
|
||||
candidate_ids: set[str],
|
||||
all_memories_cache: dict[str, MemoryChunk],
|
||||
*,
|
||||
debug_log: Optional[Dict[str, Dict[str, Any]]] = None,
|
||||
debug_log: dict[str, dict[str, Any]] | None = None,
|
||||
) -> StageResult:
|
||||
"""阶段2:向量搜索"""
|
||||
start_time = time.time()
|
||||
@@ -479,8 +479,8 @@ class MultiStageRetrieval:
|
||||
|
||||
# 过滤候选记忆
|
||||
filtered_memories = []
|
||||
details: List[Dict[str, Any]] = []
|
||||
raw_details: List[Dict[str, Any]] = []
|
||||
details: list[dict[str, Any]] = []
|
||||
raw_details: list[dict[str, Any]] = []
|
||||
threshold = self.config.vector_similarity_threshold
|
||||
|
||||
for memory_id, similarity in search_result:
|
||||
@@ -561,7 +561,7 @@ class MultiStageRetrieval:
|
||||
)
|
||||
|
||||
def _create_text_search_fallback(
|
||||
self, candidate_ids: Set[str], all_memories_cache: Dict[str, MemoryChunk], query_text: str, start_time: float
|
||||
self, candidate_ids: set[str], all_memories_cache: dict[str, MemoryChunk], query_text: str, start_time: float
|
||||
) -> StageResult:
|
||||
"""当向量搜索失败时,使用文本搜索作为回退策略"""
|
||||
try:
|
||||
@@ -618,18 +618,18 @@ class MultiStageRetrieval:
|
||||
self,
|
||||
query: str,
|
||||
user_id: str,
|
||||
context: Dict[str, Any],
|
||||
candidate_ids: Set[str],
|
||||
all_memories_cache: Dict[str, MemoryChunk],
|
||||
context: dict[str, Any],
|
||||
candidate_ids: set[str],
|
||||
all_memories_cache: dict[str, MemoryChunk],
|
||||
*,
|
||||
debug_log: Optional[Dict[str, Dict[str, Any]]] = None,
|
||||
debug_log: dict[str, dict[str, Any]] | None = None,
|
||||
) -> StageResult:
|
||||
"""阶段3:语义重排序"""
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
reranked_memories = []
|
||||
details: List[Dict[str, Any]] = []
|
||||
details: list[dict[str, Any]] = []
|
||||
threshold = self.config.semantic_similarity_threshold
|
||||
|
||||
for memory_id in candidate_ids:
|
||||
@@ -704,19 +704,19 @@ class MultiStageRetrieval:
|
||||
self,
|
||||
query: str,
|
||||
user_id: str,
|
||||
context: Dict[str, Any],
|
||||
candidate_ids: List[str],
|
||||
all_memories_cache: Dict[str, MemoryChunk],
|
||||
context: dict[str, Any],
|
||||
candidate_ids: list[str],
|
||||
all_memories_cache: dict[str, MemoryChunk],
|
||||
limit: int,
|
||||
*,
|
||||
debug_log: Optional[Dict[str, Dict[str, Any]]] = None,
|
||||
debug_log: dict[str, dict[str, Any]] | None = None,
|
||||
) -> StageResult:
|
||||
"""阶段4:上下文过滤"""
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
final_memories = []
|
||||
details: List[Dict[str, Any]] = []
|
||||
details: list[dict[str, Any]] = []
|
||||
|
||||
for memory_id in candidate_ids:
|
||||
if memory_id not in all_memories_cache:
|
||||
@@ -793,12 +793,12 @@ class MultiStageRetrieval:
|
||||
self,
|
||||
query: str,
|
||||
user_id: str,
|
||||
context: Dict[str, Any],
|
||||
all_memories_cache: Dict[str, MemoryChunk],
|
||||
context: dict[str, Any],
|
||||
all_memories_cache: dict[str, MemoryChunk],
|
||||
limit: int,
|
||||
*,
|
||||
excluded_ids: Optional[Set[str]] = None,
|
||||
debug_log: Optional[Dict[str, Dict[str, Any]]] = None,
|
||||
excluded_ids: set[str] | None = None,
|
||||
debug_log: dict[str, dict[str, Any]] | None = None,
|
||||
) -> StageResult:
|
||||
"""回退检索阶段 - 当主检索失败时使用更宽松的策略"""
|
||||
start_time = time.time()
|
||||
@@ -881,8 +881,8 @@ class MultiStageRetrieval:
|
||||
)
|
||||
|
||||
async def _generate_query_embedding(
|
||||
self, query: str, context: Dict[str, Any], vector_storage
|
||||
) -> Optional[List[float]]:
|
||||
self, query: str, context: dict[str, Any], vector_storage
|
||||
) -> list[float] | None:
|
||||
"""生成查询向量"""
|
||||
try:
|
||||
query_plan = context.get("query_plan")
|
||||
@@ -916,7 +916,7 @@ class MultiStageRetrieval:
|
||||
logger.error(f"生成查询向量时发生异常: {e}", exc_info=True)
|
||||
return None
|
||||
|
||||
async def _calculate_semantic_similarity(self, query: str, memory: MemoryChunk, context: Dict[str, Any]) -> float:
|
||||
async def _calculate_semantic_similarity(self, query: str, memory: MemoryChunk, context: dict[str, Any]) -> float:
|
||||
"""计算语义相似度 - 简化优化版本,提升召回率"""
|
||||
try:
|
||||
query_plan = context.get("query_plan")
|
||||
@@ -947,9 +947,10 @@ class MultiStageRetrieval:
|
||||
# 核心匹配策略2:词汇匹配
|
||||
word_score = 0.0
|
||||
try:
|
||||
import jieba
|
||||
import re
|
||||
|
||||
import jieba
|
||||
|
||||
# 分词处理
|
||||
query_words = list(jieba.cut(query_text)) + re.findall(r"[a-zA-Z]+", query_text)
|
||||
memory_words = list(jieba.cut(memory_text)) + re.findall(r"[a-zA-Z]+", memory_text)
|
||||
@@ -1059,7 +1060,7 @@ class MultiStageRetrieval:
|
||||
logger.warning(f"计算语义相似度失败: {e}")
|
||||
return 0.0
|
||||
|
||||
async def _calculate_context_relevance(self, query: str, memory: MemoryChunk, context: Dict[str, Any]) -> float:
|
||||
async def _calculate_context_relevance(self, query: str, memory: MemoryChunk, context: dict[str, Any]) -> float:
|
||||
"""计算上下文相关度"""
|
||||
try:
|
||||
score = 0.0
|
||||
@@ -1132,7 +1133,7 @@ class MultiStageRetrieval:
|
||||
return 0.0
|
||||
|
||||
async def _calculate_final_score(
|
||||
self, query: str, memory: MemoryChunk, context: Dict[str, Any], context_score: float
|
||||
self, query: str, memory: MemoryChunk, context: dict[str, Any], context_score: float
|
||||
) -> float:
|
||||
"""计算最终评分"""
|
||||
try:
|
||||
@@ -1184,7 +1185,7 @@ class MultiStageRetrieval:
|
||||
logger.warning(f"计算最终评分失败: {e}")
|
||||
return 0.0
|
||||
|
||||
def _calculate_subject_overlap(self, memory: MemoryChunk, required_subjects: Optional[List[str]]) -> float:
|
||||
def _calculate_subject_overlap(self, memory: MemoryChunk, required_subjects: list[str] | None) -> float:
|
||||
if not required_subjects:
|
||||
return 0.0
|
||||
|
||||
@@ -1229,7 +1230,7 @@ class MultiStageRetrieval:
|
||||
except Exception:
|
||||
return 0.5
|
||||
|
||||
def _extract_memory_types_from_context(self, context: Dict[str, Any]) -> List[MemoryType]:
|
||||
def _extract_memory_types_from_context(self, context: dict[str, Any]) -> list[MemoryType]:
|
||||
"""从上下文中提取记忆类型"""
|
||||
try:
|
||||
query_plan = context.get("query_plan")
|
||||
@@ -1256,10 +1257,10 @@ class MultiStageRetrieval:
|
||||
except Exception:
|
||||
return []
|
||||
|
||||
def _extract_keywords_from_query(self, query: str, query_plan: Optional[Any] = None) -> List[str]:
|
||||
def _extract_keywords_from_query(self, query: str, query_plan: Any | None = None) -> list[str]:
|
||||
"""从查询中提取关键词"""
|
||||
try:
|
||||
extracted: List[str] = []
|
||||
extracted: list[str] = []
|
||||
|
||||
if query_plan and getattr(query_plan, "required_keywords", None):
|
||||
extracted.extend([kw.lower() for kw in query_plan.required_keywords if isinstance(kw, str)])
|
||||
@@ -1283,7 +1284,7 @@ class MultiStageRetrieval:
|
||||
except Exception:
|
||||
return []
|
||||
|
||||
def _update_retrieval_stats(self, total_time: float, stage_results: List[StageResult]):
|
||||
def _update_retrieval_stats(self, total_time: float, stage_results: list[StageResult]):
|
||||
"""更新检索统计"""
|
||||
self.retrieval_stats["total_queries"] += 1
|
||||
|
||||
@@ -1306,7 +1307,7 @@ class MultiStageRetrieval:
|
||||
]
|
||||
stage_stat["avg_time"] = new_stage_avg
|
||||
|
||||
def get_retrieval_stats(self) -> Dict[str, Any]:
|
||||
def get_retrieval_stats(self) -> dict[str, Any]:
|
||||
"""获取检索统计信息"""
|
||||
return self.retrieval_stats.copy()
|
||||
|
||||
@@ -1328,12 +1329,12 @@ class MultiStageRetrieval:
|
||||
self,
|
||||
query: str,
|
||||
user_id: str,
|
||||
context: Dict[str, Any],
|
||||
candidate_ids: List[str],
|
||||
all_memories_cache: Dict[str, MemoryChunk],
|
||||
context: dict[str, Any],
|
||||
candidate_ids: list[str],
|
||||
all_memories_cache: dict[str, MemoryChunk],
|
||||
limit: int,
|
||||
*,
|
||||
debug_log: Optional[Dict[str, Dict[str, Any]]] = None,
|
||||
debug_log: dict[str, dict[str, Any]] | None = None,
|
||||
) -> StageResult:
|
||||
"""阶段5:增强重排序 - 使用多维度评分模型"""
|
||||
start_time = time.time()
|
||||
|
||||
@@ -1,24 +1,23 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
向量数据库存储接口
|
||||
为记忆系统提供高效的向量存储和语义搜索能力
|
||||
"""
|
||||
|
||||
import time
|
||||
import orjson
|
||||
import asyncio
|
||||
from typing import Dict, List, Optional, Tuple, Any
|
||||
from dataclasses import dataclass
|
||||
import threading
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
import orjson
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import model_config
|
||||
from src.common.config_helpers import resolve_embedding_dimension
|
||||
from src.chat.memory_system.memory_chunk import MemoryChunk
|
||||
from src.common.config_helpers import resolve_embedding_dimension
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import model_config
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -48,7 +47,7 @@ class VectorStorageConfig:
|
||||
class VectorStorageManager:
|
||||
"""向量存储管理器"""
|
||||
|
||||
def __init__(self, config: Optional[VectorStorageConfig] = None):
|
||||
def __init__(self, config: VectorStorageConfig | None = None):
|
||||
self.config = config or VectorStorageConfig()
|
||||
|
||||
resolved_dimension = resolve_embedding_dimension(self.config.dimension)
|
||||
@@ -68,8 +67,8 @@ class VectorStorageManager:
|
||||
self.index_to_memory_id = {} # vector index -> memory_id
|
||||
|
||||
# 内存缓存
|
||||
self.memory_cache: Dict[str, MemoryChunk] = {}
|
||||
self.vector_cache: Dict[str, List[float]] = {}
|
||||
self.memory_cache: dict[str, MemoryChunk] = {}
|
||||
self.vector_cache: dict[str, list[float]] = {}
|
||||
|
||||
# 统计信息
|
||||
self.storage_stats = {
|
||||
@@ -125,7 +124,7 @@ class VectorStorageManager:
|
||||
)
|
||||
logger.info("✅ 嵌入模型初始化完成")
|
||||
|
||||
async def generate_query_embedding(self, query_text: str) -> Optional[List[float]]:
|
||||
async def generate_query_embedding(self, query_text: str) -> list[float] | None:
|
||||
"""生成查询向量,用于记忆召回"""
|
||||
if not query_text:
|
||||
logger.warning("查询文本为空,无法生成向量")
|
||||
@@ -155,7 +154,7 @@ class VectorStorageManager:
|
||||
logger.error(f"❌ 生成查询向量失败: {exc}", exc_info=True)
|
||||
return None
|
||||
|
||||
async def store_memories(self, memories: List[MemoryChunk]):
|
||||
async def store_memories(self, memories: list[MemoryChunk]):
|
||||
"""存储记忆向量"""
|
||||
if not memories:
|
||||
return
|
||||
@@ -231,7 +230,7 @@ class VectorStorageManager:
|
||||
logger.debug("记忆 %s 缺少可用展示文本,使用占位符生成嵌入输入", memory.memory_id)
|
||||
return memory.memory_id
|
||||
|
||||
async def _batch_generate_and_store_embeddings(self, memory_texts: List[Tuple[str, str]]):
|
||||
async def _batch_generate_and_store_embeddings(self, memory_texts: list[tuple[str, str]]):
|
||||
"""批量生成和存储嵌入向量"""
|
||||
if not memory_texts:
|
||||
return
|
||||
@@ -253,12 +252,12 @@ class VectorStorageManager:
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 批量生成嵌入向量失败: {e}")
|
||||
|
||||
async def _batch_generate_embeddings(self, memory_ids: List[str], texts: List[str]) -> Dict[str, List[float]]:
|
||||
async def _batch_generate_embeddings(self, memory_ids: list[str], texts: list[str]) -> dict[str, list[float]]:
|
||||
"""批量生成嵌入向量"""
|
||||
if not texts:
|
||||
return {}
|
||||
|
||||
results: Dict[str, List[float]] = {}
|
||||
results: dict[str, list[float]] = {}
|
||||
|
||||
try:
|
||||
semaphore = asyncio.Semaphore(min(4, max(1, len(texts))))
|
||||
@@ -281,7 +280,9 @@ class VectorStorageManager:
|
||||
logger.warning("生成记忆 %s 的嵌入向量失败: %s", memory_id, exc)
|
||||
results[memory_id] = []
|
||||
|
||||
tasks = [asyncio.create_task(generate_embedding(mid, text)) for mid, text in zip(memory_ids, texts, strict=False)]
|
||||
tasks = [
|
||||
asyncio.create_task(generate_embedding(mid, text)) for mid, text in zip(memory_ids, texts, strict=False)
|
||||
]
|
||||
await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
except Exception as e:
|
||||
@@ -291,7 +292,7 @@ class VectorStorageManager:
|
||||
|
||||
return results
|
||||
|
||||
async def _add_single_memory(self, memory: MemoryChunk, embedding: List[float]):
|
||||
async def _add_single_memory(self, memory: MemoryChunk, embedding: list[float]):
|
||||
"""添加单个记忆到向量存储"""
|
||||
with self._lock:
|
||||
try:
|
||||
@@ -337,7 +338,7 @@ class VectorStorageManager:
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 添加记忆到向量存储失败: {e}")
|
||||
|
||||
def _normalize_vector(self, vector: List[float]) -> List[float]:
|
||||
def _normalize_vector(self, vector: list[float]) -> list[float]:
|
||||
"""L2归一化向量"""
|
||||
if not vector:
|
||||
return vector
|
||||
@@ -357,12 +358,12 @@ class VectorStorageManager:
|
||||
|
||||
async def search_similar_memories(
|
||||
self,
|
||||
query_vector: Optional[List[float]] = None,
|
||||
query_vector: list[float] | None = None,
|
||||
*,
|
||||
query_text: Optional[str] = None,
|
||||
query_text: str | None = None,
|
||||
limit: int = 10,
|
||||
scope_id: Optional[str] = None,
|
||||
) -> List[Tuple[str, float]]:
|
||||
scope_id: str | None = None,
|
||||
) -> list[tuple[str, float]]:
|
||||
"""搜索相似记忆"""
|
||||
start_time = time.time()
|
||||
|
||||
@@ -379,7 +380,7 @@ class VectorStorageManager:
|
||||
logger.warning("查询向量生成失败")
|
||||
return []
|
||||
|
||||
scope_filter: Optional[str] = None
|
||||
scope_filter: str | None = None
|
||||
if isinstance(scope_id, str):
|
||||
normalized_scope = scope_id.strip().lower()
|
||||
if normalized_scope and normalized_scope not in {"global", "global_memory"}:
|
||||
@@ -491,7 +492,7 @@ class VectorStorageManager:
|
||||
logger.error(f"❌ 向量搜索失败: {e}", exc_info=True)
|
||||
return []
|
||||
|
||||
async def get_memory_by_id(self, memory_id: str) -> Optional[MemoryChunk]:
|
||||
async def get_memory_by_id(self, memory_id: str) -> MemoryChunk | None:
|
||||
"""根据ID获取记忆"""
|
||||
# 先检查缓存
|
||||
if memory_id in self.memory_cache:
|
||||
@@ -501,7 +502,7 @@ class VectorStorageManager:
|
||||
self.storage_stats["total_searches"] += 1
|
||||
return None
|
||||
|
||||
async def update_memory_embedding(self, memory_id: str, new_embedding: List[float]):
|
||||
async def update_memory_embedding(self, memory_id: str, new_embedding: list[float]):
|
||||
"""更新记忆的嵌入向量"""
|
||||
with self._lock:
|
||||
try:
|
||||
@@ -636,7 +637,7 @@ class VectorStorageManager:
|
||||
# 加载记忆缓存
|
||||
cache_file = self.storage_path / "memory_cache.json"
|
||||
if cache_file.exists():
|
||||
with open(cache_file, "r", encoding="utf-8") as f:
|
||||
with open(cache_file, encoding="utf-8") as f:
|
||||
cache_data = orjson.loads(f.read())
|
||||
|
||||
self.memory_cache = {
|
||||
@@ -646,13 +647,13 @@ class VectorStorageManager:
|
||||
# 加载向量缓存
|
||||
vector_cache_file = self.storage_path / "vector_cache.json"
|
||||
if vector_cache_file.exists():
|
||||
with open(vector_cache_file, "r", encoding="utf-8") as f:
|
||||
with open(vector_cache_file, encoding="utf-8") as f:
|
||||
self.vector_cache = orjson.loads(f.read())
|
||||
|
||||
# 加载映射关系
|
||||
mapping_file = self.storage_path / "id_mapping.json"
|
||||
if mapping_file.exists():
|
||||
with open(mapping_file, "r", encoding="utf-8") as f:
|
||||
with open(mapping_file, encoding="utf-8") as f:
|
||||
mapping_data = orjson.loads(f.read())
|
||||
raw_memory_to_index = mapping_data.get("memory_id_to_index", {})
|
||||
self.memory_id_to_index = {
|
||||
@@ -689,7 +690,7 @@ class VectorStorageManager:
|
||||
# 加载统计信息
|
||||
stats_file = self.storage_path / "storage_stats.json"
|
||||
if stats_file.exists():
|
||||
with open(stats_file, "r", encoding="utf-8") as f:
|
||||
with open(stats_file, encoding="utf-8") as f:
|
||||
self.storage_stats = orjson.loads(f.read())
|
||||
|
||||
# 更新向量计数
|
||||
@@ -806,7 +807,7 @@ class VectorStorageManager:
|
||||
if invalid_memory_ids:
|
||||
logger.info(f"清理了 {len(invalid_memory_ids)} 个无效引用")
|
||||
|
||||
def get_storage_stats(self) -> Dict[str, Any]:
|
||||
def get_storage_stats(self) -> dict[str, Any]:
|
||||
"""获取存储统计信息"""
|
||||
stats = self.storage_stats.copy()
|
||||
if stats["total_searches"] > 0:
|
||||
@@ -821,11 +822,11 @@ class SimpleVectorIndex:
|
||||
|
||||
def __init__(self, dimension: int):
|
||||
self.dimension = dimension
|
||||
self.vectors: List[List[float]] = []
|
||||
self.vector_ids: List[int] = []
|
||||
self.vectors: list[list[float]] = []
|
||||
self.vector_ids: list[int] = []
|
||||
self.next_id = 0
|
||||
|
||||
def add_vector(self, vector: List[float]) -> int:
|
||||
def add_vector(self, vector: list[float]) -> int:
|
||||
"""添加向量"""
|
||||
if len(vector) != self.dimension:
|
||||
raise ValueError(f"向量维度不匹配,期望 {self.dimension},实际 {len(vector)}")
|
||||
@@ -837,7 +838,7 @@ class SimpleVectorIndex:
|
||||
|
||||
return vector_id
|
||||
|
||||
def search(self, query_vector: List[float], limit: int) -> List[Tuple[int, float]]:
|
||||
def search(self, query_vector: list[float], limit: int) -> list[tuple[int, float]]:
|
||||
"""搜索相似向量"""
|
||||
if len(query_vector) != self.dimension:
|
||||
raise ValueError(f"查询向量维度不匹配,期望 {self.dimension},实际 {len(query_vector)}")
|
||||
@@ -853,7 +854,7 @@ class SimpleVectorIndex:
|
||||
|
||||
return results[:limit]
|
||||
|
||||
def _calculate_cosine_similarity(self, v1: List[float], v2: List[float]) -> float:
|
||||
def _calculate_cosine_similarity(self, v1: list[float], v2: list[float]) -> float:
|
||||
"""计算余弦相似度"""
|
||||
try:
|
||||
dot_product = sum(x * y for x, y in zip(v1, v2, strict=False))
|
||||
|
||||
@@ -1,25 +1,24 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
记忆激活器
|
||||
记忆系统的激活器组件
|
||||
"""
|
||||
|
||||
import difflib
|
||||
import orjson
|
||||
from typing import List, Dict, Optional
|
||||
from datetime import datetime
|
||||
|
||||
import orjson
|
||||
from json_repair import repair_json
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import global_config, model_config
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.utils.prompt import Prompt, global_prompt_manager
|
||||
|
||||
from src.chat.memory_system.memory_manager import MemoryResult
|
||||
from src.chat.utils.prompt import Prompt, global_prompt_manager
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config, model_config
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
|
||||
logger = get_logger("memory_activator")
|
||||
|
||||
|
||||
def get_keywords_from_json(json_str) -> List:
|
||||
def get_keywords_from_json(json_str) -> list:
|
||||
"""
|
||||
从JSON字符串中提取关键词列表
|
||||
|
||||
@@ -81,7 +80,7 @@ class MemoryActivator:
|
||||
self.cached_keywords = set() # 用于缓存历史关键词
|
||||
self.last_memory_query_time = 0 # 上次查询记忆的时间
|
||||
|
||||
async def activate_memory_with_chat_history(self, target_message, chat_history_prompt) -> List[Dict]:
|
||||
async def activate_memory_with_chat_history(self, target_message, chat_history_prompt) -> list[dict]:
|
||||
"""
|
||||
激活记忆
|
||||
"""
|
||||
@@ -155,7 +154,7 @@ class MemoryActivator:
|
||||
|
||||
return self.running_memory
|
||||
|
||||
async def _query_unified_memory(self, keywords: List[str], query_text: str) -> List[MemoryResult]:
|
||||
async def _query_unified_memory(self, keywords: list[str], query_text: str) -> list[MemoryResult]:
|
||||
"""查询统一记忆系统"""
|
||||
try:
|
||||
# 使用记忆系统
|
||||
@@ -198,7 +197,7 @@ class MemoryActivator:
|
||||
logger.error(f"查询统一记忆失败: {e}")
|
||||
return []
|
||||
|
||||
async def get_instant_memory(self, target_message: str, chat_id: str) -> Optional[str]:
|
||||
async def get_instant_memory(self, target_message: str, chat_id: str) -> str | None:
|
||||
"""
|
||||
获取即时记忆 - 兼容原有接口(使用统一存储)
|
||||
"""
|
||||
|
||||
@@ -1,25 +1,24 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
记忆激活器
|
||||
记忆系统的激活器组件
|
||||
"""
|
||||
|
||||
import difflib
|
||||
import orjson
|
||||
from typing import List, Dict, Optional
|
||||
from datetime import datetime
|
||||
|
||||
import orjson
|
||||
from json_repair import repair_json
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import global_config, model_config
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.utils.prompt import Prompt, global_prompt_manager
|
||||
|
||||
from src.chat.memory_system.memory_manager import MemoryResult
|
||||
from src.chat.utils.prompt import Prompt, global_prompt_manager
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config, model_config
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
|
||||
logger = get_logger("memory_activator")
|
||||
|
||||
|
||||
def get_keywords_from_json(json_str) -> List:
|
||||
def get_keywords_from_json(json_str) -> list:
|
||||
"""
|
||||
从JSON字符串中提取关键词列表
|
||||
|
||||
@@ -81,7 +80,7 @@ class MemoryActivator:
|
||||
self.cached_keywords = set() # 用于缓存历史关键词
|
||||
self.last_memory_query_time = 0 # 上次查询记忆的时间
|
||||
|
||||
async def activate_memory_with_chat_history(self, target_message, chat_history_prompt) -> List[Dict]:
|
||||
async def activate_memory_with_chat_history(self, target_message, chat_history_prompt) -> list[dict]:
|
||||
"""
|
||||
激活记忆
|
||||
"""
|
||||
@@ -155,7 +154,7 @@ class MemoryActivator:
|
||||
|
||||
return self.running_memory
|
||||
|
||||
async def _query_unified_memory(self, keywords: List[str], query_text: str) -> List[MemoryResult]:
|
||||
async def _query_unified_memory(self, keywords: list[str], query_text: str) -> list[MemoryResult]:
|
||||
"""查询统一记忆系统"""
|
||||
try:
|
||||
# 使用记忆系统
|
||||
@@ -198,7 +197,7 @@ class MemoryActivator:
|
||||
logger.error(f"查询统一记忆失败: {e}")
|
||||
return []
|
||||
|
||||
async def get_instant_memory(self, target_message: str, chat_id: str) -> Optional[str]:
|
||||
async def get_instant_memory(self, target_message: str, chat_id: str) -> str | None:
|
||||
"""
|
||||
获取即时记忆 - 兼容原有接口(使用统一存储)
|
||||
"""
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
记忆构建模块
|
||||
从对话流中提取高质量、结构化记忆单元
|
||||
@@ -33,19 +32,19 @@ import time
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Optional, Union, Type
|
||||
from typing import Any
|
||||
|
||||
import orjson
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.chat.memory_system.memory_chunk import (
|
||||
MemoryChunk,
|
||||
MemoryType,
|
||||
ConfidenceLevel,
|
||||
ImportanceLevel,
|
||||
MemoryChunk,
|
||||
MemoryType,
|
||||
create_memory_chunk,
|
||||
)
|
||||
from src.common.logger import get_logger
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -62,8 +61,8 @@ class ExtractionStrategy(Enum):
|
||||
class ExtractionResult:
|
||||
"""提取结果"""
|
||||
|
||||
memories: List[MemoryChunk]
|
||||
confidence_scores: List[float]
|
||||
memories: list[MemoryChunk]
|
||||
confidence_scores: list[float]
|
||||
extraction_time: float
|
||||
strategy_used: ExtractionStrategy
|
||||
|
||||
@@ -85,8 +84,8 @@ class MemoryBuilder:
|
||||
}
|
||||
|
||||
async def build_memories(
|
||||
self, conversation_text: str, context: Dict[str, Any], user_id: str, timestamp: float
|
||||
) -> List[MemoryChunk]:
|
||||
self, conversation_text: str, context: dict[str, Any], user_id: str, timestamp: float
|
||||
) -> list[MemoryChunk]:
|
||||
"""从对话中构建记忆"""
|
||||
start_time = time.time()
|
||||
|
||||
@@ -116,8 +115,8 @@ class MemoryBuilder:
|
||||
raise
|
||||
|
||||
async def _extract_with_llm(
|
||||
self, text: str, context: Dict[str, Any], user_id: str, timestamp: float
|
||||
) -> List[MemoryChunk]:
|
||||
self, text: str, context: dict[str, Any], user_id: str, timestamp: float
|
||||
) -> list[MemoryChunk]:
|
||||
"""使用LLM提取记忆"""
|
||||
try:
|
||||
prompt = self._build_llm_extraction_prompt(text, context)
|
||||
@@ -135,7 +134,7 @@ class MemoryBuilder:
|
||||
logger.error(f"LLM提取失败: {e}")
|
||||
raise MemoryExtractionError(str(e)) from e
|
||||
|
||||
def _build_llm_extraction_prompt(self, text: str, context: Dict[str, Any]) -> str:
|
||||
def _build_llm_extraction_prompt(self, text: str, context: dict[str, Any]) -> str:
|
||||
"""构建LLM提取提示"""
|
||||
current_date = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
message_type = context.get("message_type", "normal")
|
||||
@@ -315,7 +314,7 @@ class MemoryBuilder:
|
||||
|
||||
return prompt
|
||||
|
||||
def _extract_json_payload(self, response: str) -> Optional[str]:
|
||||
def _extract_json_payload(self, response: str) -> str | None:
|
||||
"""从模型响应中提取JSON部分,兼容Markdown代码块等格式"""
|
||||
if not response:
|
||||
return None
|
||||
@@ -338,8 +337,8 @@ class MemoryBuilder:
|
||||
return stripped if stripped.startswith("{") and stripped.endswith("}") else None
|
||||
|
||||
def _parse_llm_response(
|
||||
self, response: str, user_id: str, timestamp: float, context: Dict[str, Any]
|
||||
) -> List[MemoryChunk]:
|
||||
self, response: str, user_id: str, timestamp: float, context: dict[str, Any]
|
||||
) -> list[MemoryChunk]:
|
||||
"""解析LLM响应"""
|
||||
if not response:
|
||||
raise MemoryExtractionError("LLM未返回任何响应")
|
||||
@@ -385,7 +384,7 @@ class MemoryBuilder:
|
||||
|
||||
bot_display = self._clean_subject_text(bot_display)
|
||||
|
||||
memories: List[MemoryChunk] = []
|
||||
memories: list[MemoryChunk] = []
|
||||
|
||||
for mem_data in memory_list:
|
||||
try:
|
||||
@@ -460,7 +459,7 @@ class MemoryBuilder:
|
||||
|
||||
return memories
|
||||
|
||||
def _parse_enum_value(self, enum_cls: Type[Enum], raw_value: Any, default: Enum, field_name: str) -> Enum:
|
||||
def _parse_enum_value(self, enum_cls: type[Enum], raw_value: Any, default: Enum, field_name: str) -> Enum:
|
||||
"""解析枚举值,兼容数字/字符串表示"""
|
||||
if isinstance(raw_value, enum_cls):
|
||||
return raw_value
|
||||
@@ -514,7 +513,7 @@ class MemoryBuilder:
|
||||
)
|
||||
return default
|
||||
|
||||
def _collect_bot_identifiers(self, context: Optional[Dict[str, Any]]) -> set[str]:
|
||||
def _collect_bot_identifiers(self, context: dict[str, Any] | None) -> set[str]:
|
||||
identifiers: set[str] = {"bot", "机器人", "ai助手"}
|
||||
if not context:
|
||||
return identifiers
|
||||
@@ -540,7 +539,7 @@ class MemoryBuilder:
|
||||
|
||||
return identifiers
|
||||
|
||||
def _collect_system_identifiers(self, context: Optional[Dict[str, Any]]) -> set[str]:
|
||||
def _collect_system_identifiers(self, context: dict[str, Any] | None) -> set[str]:
|
||||
identifiers: set[str] = set()
|
||||
if not context:
|
||||
return identifiers
|
||||
@@ -568,8 +567,8 @@ class MemoryBuilder:
|
||||
|
||||
return identifiers
|
||||
|
||||
def _resolve_conversation_participants(self, context: Optional[Dict[str, Any]], user_id: str) -> List[str]:
|
||||
participants: List[str] = []
|
||||
def _resolve_conversation_participants(self, context: dict[str, Any] | None, user_id: str) -> list[str]:
|
||||
participants: list[str] = []
|
||||
|
||||
if context:
|
||||
candidate_keys = [
|
||||
@@ -609,7 +608,7 @@ class MemoryBuilder:
|
||||
if not participants:
|
||||
participants = ["对话参与者"]
|
||||
|
||||
deduplicated: List[str] = []
|
||||
deduplicated: list[str] = []
|
||||
seen = set()
|
||||
for name in participants:
|
||||
key = name.lower()
|
||||
@@ -620,7 +619,7 @@ class MemoryBuilder:
|
||||
|
||||
return deduplicated
|
||||
|
||||
def _resolve_user_display(self, context: Optional[Dict[str, Any]], user_id: str) -> str:
|
||||
def _resolve_user_display(self, context: dict[str, Any] | None, user_id: str) -> str:
|
||||
candidate_keys = [
|
||||
"user_display_name",
|
||||
"user_name",
|
||||
@@ -683,7 +682,7 @@ class MemoryBuilder:
|
||||
|
||||
return False
|
||||
|
||||
def _split_subject_string(self, value: str) -> List[str]:
|
||||
def _split_subject_string(self, value: str) -> list[str]:
|
||||
if not value:
|
||||
return []
|
||||
|
||||
@@ -699,12 +698,12 @@ class MemoryBuilder:
|
||||
subject: Any,
|
||||
bot_identifiers: set[str],
|
||||
system_identifiers: set[str],
|
||||
default_subjects: List[str],
|
||||
bot_display: Optional[str] = None,
|
||||
) -> List[str]:
|
||||
default_subjects: list[str],
|
||||
bot_display: str | None = None,
|
||||
) -> list[str]:
|
||||
defaults = default_subjects or ["对话参与者"]
|
||||
|
||||
raw_candidates: List[str] = []
|
||||
raw_candidates: list[str] = []
|
||||
if isinstance(subject, list):
|
||||
for item in subject:
|
||||
if isinstance(item, str):
|
||||
@@ -716,7 +715,7 @@ class MemoryBuilder:
|
||||
elif subject is not None:
|
||||
raw_candidates.extend(self._split_subject_string(str(subject)))
|
||||
|
||||
normalized: List[str] = []
|
||||
normalized: list[str] = []
|
||||
bot_primary = self._clean_subject_text(bot_display or "")
|
||||
|
||||
for candidate in raw_candidates:
|
||||
@@ -741,7 +740,7 @@ class MemoryBuilder:
|
||||
if not normalized:
|
||||
normalized = list(defaults)
|
||||
|
||||
deduplicated: List[str] = []
|
||||
deduplicated: list[str] = []
|
||||
seen = set()
|
||||
for name in normalized:
|
||||
key = name.lower()
|
||||
@@ -752,7 +751,7 @@ class MemoryBuilder:
|
||||
|
||||
return deduplicated
|
||||
|
||||
def _extract_value_from_object(self, obj: Union[str, Dict[str, Any], List[Any]], keys: List[str]) -> Optional[str]:
|
||||
def _extract_value_from_object(self, obj: str | dict[str, Any] | list[Any], keys: list[str]) -> str | None:
|
||||
if isinstance(obj, dict):
|
||||
for key in keys:
|
||||
value = obj.get(key)
|
||||
@@ -773,9 +772,7 @@ class MemoryBuilder:
|
||||
return obj.strip() or None
|
||||
return None
|
||||
|
||||
def _compose_display_text(
|
||||
self, subjects: List[str], predicate: str, obj: Union[str, Dict[str, Any], List[Any]]
|
||||
) -> str:
|
||||
def _compose_display_text(self, subjects: list[str], predicate: str, obj: str | dict[str, Any] | list[Any]) -> str:
|
||||
subject_phrase = "、".join(subjects) if subjects else "对话参与者"
|
||||
predicate = (predicate or "").strip()
|
||||
|
||||
@@ -841,7 +838,7 @@ class MemoryBuilder:
|
||||
return f"{subject_phrase}{predicate}".strip()
|
||||
return subject_phrase
|
||||
|
||||
def _validate_and_enhance_memories(self, memories: List[MemoryChunk], context: Dict[str, Any]) -> List[MemoryChunk]:
|
||||
def _validate_and_enhance_memories(self, memories: list[MemoryChunk], context: dict[str, Any]) -> list[MemoryChunk]:
|
||||
"""验证和增强记忆"""
|
||||
validated_memories = []
|
||||
|
||||
@@ -876,7 +873,7 @@ class MemoryBuilder:
|
||||
|
||||
return True
|
||||
|
||||
def _enhance_memory(self, memory: MemoryChunk, context: Dict[str, Any]) -> MemoryChunk:
|
||||
def _enhance_memory(self, memory: MemoryChunk, context: dict[str, Any]) -> MemoryChunk:
|
||||
"""增强记忆块"""
|
||||
# 时间规范化处理
|
||||
self._normalize_time_in_memory(memory)
|
||||
@@ -985,7 +982,7 @@ class MemoryBuilder:
|
||||
total_confidence / self.extraction_stats["successful_extractions"]
|
||||
)
|
||||
|
||||
def get_extraction_stats(self) -> Dict[str, Any]:
|
||||
def get_extraction_stats(self) -> dict[str, Any]:
|
||||
"""获取提取统计信息"""
|
||||
return self.extraction_stats.copy()
|
||||
|
||||
|
||||
@@ -1,18 +1,19 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
结构化记忆单元设计
|
||||
实现高质量、结构化的记忆单元,符合文档设计规范
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import time
|
||||
import uuid
|
||||
import orjson
|
||||
from typing import Dict, List, Optional, Any, Union, Iterable
|
||||
from collections.abc import Iterable
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
import hashlib
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import orjson
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
@@ -56,17 +57,17 @@ class ImportanceLevel(Enum):
|
||||
class ContentStructure:
|
||||
"""主谓宾结构,包含自然语言描述"""
|
||||
|
||||
subject: Union[str, List[str]]
|
||||
subject: str | list[str]
|
||||
predicate: str
|
||||
object: Union[str, Dict]
|
||||
object: str | dict
|
||||
display: str = ""
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""转换为字典格式"""
|
||||
return {"subject": self.subject, "predicate": self.predicate, "object": self.object, "display": self.display}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> "ContentStructure":
|
||||
def from_dict(cls, data: dict[str, Any]) -> "ContentStructure":
|
||||
"""从字典创建实例"""
|
||||
return cls(
|
||||
subject=data.get("subject", ""),
|
||||
@@ -75,7 +76,7 @@ class ContentStructure:
|
||||
display=data.get("display", ""),
|
||||
)
|
||||
|
||||
def to_subject_list(self) -> List[str]:
|
||||
def to_subject_list(self) -> list[str]:
|
||||
"""将主语转换为列表形式"""
|
||||
if isinstance(self.subject, list):
|
||||
return [s for s in self.subject if isinstance(s, str) and s.strip()]
|
||||
@@ -99,7 +100,7 @@ class MemoryMetadata:
|
||||
# 基础信息
|
||||
memory_id: str # 唯一标识符
|
||||
user_id: str # 用户ID
|
||||
chat_id: Optional[str] = None # 聊天ID(群聊或私聊)
|
||||
chat_id: str | None = None # 聊天ID(群聊或私聊)
|
||||
|
||||
# 时间信息
|
||||
created_at: float = 0.0 # 创建时间戳
|
||||
@@ -124,9 +125,9 @@ class MemoryMetadata:
|
||||
last_forgetting_check: float = 0.0 # 上次遗忘检查时间
|
||||
|
||||
# 来源信息
|
||||
source_context: Optional[str] = None # 来源上下文片段
|
||||
source_context: str | None = None # 来源上下文片段
|
||||
# 兼容旧字段: 一些代码或旧版本可能直接访问 metadata.source
|
||||
source: Optional[str] = None
|
||||
source: str | None = None
|
||||
|
||||
def __post_init__(self):
|
||||
"""后初始化处理"""
|
||||
@@ -209,7 +210,7 @@ class MemoryMetadata:
|
||||
# 设置最小和最大阈值
|
||||
return max(7.0, min(threshold, 365.0)) # 7天到1年之间
|
||||
|
||||
def should_forget(self, current_time: Optional[float] = None) -> bool:
|
||||
def should_forget(self, current_time: float | None = None) -> bool:
|
||||
"""判断是否应该遗忘"""
|
||||
if current_time is None:
|
||||
current_time = time.time()
|
||||
@@ -222,7 +223,7 @@ class MemoryMetadata:
|
||||
|
||||
return days_since_activation > self.forgetting_threshold
|
||||
|
||||
def is_dormant(self, current_time: Optional[float] = None, inactive_days: int = 90) -> bool:
|
||||
def is_dormant(self, current_time: float | None = None, inactive_days: int = 90) -> bool:
|
||||
"""判断是否处于休眠状态(长期未激活)"""
|
||||
if current_time is None:
|
||||
current_time = time.time()
|
||||
@@ -230,7 +231,7 @@ class MemoryMetadata:
|
||||
days_since_last_access = (current_time - self.last_accessed) / 86400
|
||||
return days_since_last_access > inactive_days
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""转换为字典格式"""
|
||||
return {
|
||||
"memory_id": self.memory_id,
|
||||
@@ -252,7 +253,7 @@ class MemoryMetadata:
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> "MemoryMetadata":
|
||||
def from_dict(cls, data: dict[str, Any]) -> "MemoryMetadata":
|
||||
"""从字典创建实例"""
|
||||
return cls(
|
||||
memory_id=data.get("memory_id", ""),
|
||||
@@ -286,17 +287,17 @@ class MemoryChunk:
|
||||
memory_type: MemoryType # 记忆类型
|
||||
|
||||
# 扩展信息
|
||||
keywords: List[str] = field(default_factory=list) # 关键词列表
|
||||
tags: List[str] = field(default_factory=list) # 标签列表
|
||||
categories: List[str] = field(default_factory=list) # 分类列表
|
||||
keywords: list[str] = field(default_factory=list) # 关键词列表
|
||||
tags: list[str] = field(default_factory=list) # 标签列表
|
||||
categories: list[str] = field(default_factory=list) # 分类列表
|
||||
|
||||
# 语义信息
|
||||
embedding: Optional[List[float]] = None # 语义向量
|
||||
semantic_hash: Optional[str] = None # 语义哈希值
|
||||
embedding: list[float] | None = None # 语义向量
|
||||
semantic_hash: str | None = None # 语义哈希值
|
||||
|
||||
# 关联信息
|
||||
related_memories: List[str] = field(default_factory=list) # 关联记忆ID列表
|
||||
temporal_context: Optional[Dict[str, Any]] = None # 时间上下文
|
||||
related_memories: list[str] = field(default_factory=list) # 关联记忆ID列表
|
||||
temporal_context: dict[str, Any] | None = None # 时间上下文
|
||||
|
||||
def __post_init__(self):
|
||||
"""后初始化处理"""
|
||||
@@ -310,7 +311,7 @@ class MemoryChunk:
|
||||
|
||||
try:
|
||||
# 使用向量和内容生成稳定的哈希
|
||||
content_str = f"{self.content.subject}:{self.content.predicate}:{str(self.content.object)}"
|
||||
content_str = f"{self.content.subject}:{self.content.predicate}:{self.content.object!s}"
|
||||
embedding_str = ",".join(map(str, [round(x, 6) for x in self.embedding]))
|
||||
|
||||
hash_input = f"{content_str}|{embedding_str}"
|
||||
@@ -342,7 +343,7 @@ class MemoryChunk:
|
||||
return self.content.display or str(self.content)
|
||||
|
||||
@property
|
||||
def subjects(self) -> List[str]:
|
||||
def subjects(self) -> list[str]:
|
||||
"""获取主语列表"""
|
||||
return self.content.to_subject_list()
|
||||
|
||||
@@ -354,11 +355,11 @@ class MemoryChunk:
|
||||
"""更新相关度评分"""
|
||||
self.metadata.update_relevance(new_score)
|
||||
|
||||
def should_forget(self, current_time: Optional[float] = None) -> bool:
|
||||
def should_forget(self, current_time: float | None = None) -> bool:
|
||||
"""判断是否应该遗忘"""
|
||||
return self.metadata.should_forget(current_time)
|
||||
|
||||
def is_dormant(self, current_time: Optional[float] = None, inactive_days: int = 90) -> bool:
|
||||
def is_dormant(self, current_time: float | None = None, inactive_days: int = 90) -> bool:
|
||||
"""判断是否处于休眠状态(长期未激活)"""
|
||||
return self.metadata.is_dormant(current_time, inactive_days)
|
||||
|
||||
@@ -386,7 +387,7 @@ class MemoryChunk:
|
||||
if memory_id and memory_id not in self.related_memories:
|
||||
self.related_memories.append(memory_id)
|
||||
|
||||
def set_embedding(self, embedding: List[float]):
|
||||
def set_embedding(self, embedding: list[float]):
|
||||
"""设置语义向量"""
|
||||
self.embedding = embedding
|
||||
self._generate_semantic_hash()
|
||||
@@ -415,7 +416,7 @@ class MemoryChunk:
|
||||
logger.warning(f"计算记忆相似度失败: {e}")
|
||||
return 0.0
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""转换为完整的字典格式"""
|
||||
return {
|
||||
"metadata": self.metadata.to_dict(),
|
||||
@@ -431,7 +432,7 @@ class MemoryChunk:
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> "MemoryChunk":
|
||||
def from_dict(cls, data: dict[str, Any]) -> "MemoryChunk":
|
||||
"""从字典创建实例"""
|
||||
metadata = MemoryMetadata.from_dict(data.get("metadata", {}))
|
||||
content = ContentStructure.from_dict(data.get("content", {}))
|
||||
@@ -541,7 +542,7 @@ class MemoryChunk:
|
||||
return f"MemoryChunk(id={self.memory_id[:8]}..., type={self.memory_type.value}, user={self.user_id})"
|
||||
|
||||
|
||||
def _build_display_text(subjects: Iterable[str], predicate: str, obj: Union[str, Dict]) -> str:
|
||||
def _build_display_text(subjects: Iterable[str], predicate: str, obj: str | dict) -> str:
|
||||
"""根据主谓宾生成自然语言描述"""
|
||||
subjects_clean = [s.strip() for s in subjects if s and isinstance(s, str)]
|
||||
subject_part = "、".join(subjects_clean) if subjects_clean else "对话参与者"
|
||||
@@ -569,15 +570,15 @@ def _build_display_text(subjects: Iterable[str], predicate: str, obj: Union[str,
|
||||
|
||||
def create_memory_chunk(
|
||||
user_id: str,
|
||||
subject: Union[str, List[str]],
|
||||
subject: str | list[str],
|
||||
predicate: str,
|
||||
obj: Union[str, Dict],
|
||||
obj: str | dict,
|
||||
memory_type: MemoryType,
|
||||
chat_id: Optional[str] = None,
|
||||
source_context: Optional[str] = None,
|
||||
chat_id: str | None = None,
|
||||
source_context: str | None = None,
|
||||
importance: ImportanceLevel = ImportanceLevel.NORMAL,
|
||||
confidence: ConfidenceLevel = ConfidenceLevel.MEDIUM,
|
||||
display: Optional[str] = None,
|
||||
display: str | None = None,
|
||||
**kwargs,
|
||||
) -> MemoryChunk:
|
||||
"""便捷的内存块创建函数"""
|
||||
@@ -593,10 +594,10 @@ def create_memory_chunk(
|
||||
source_context=source_context,
|
||||
)
|
||||
|
||||
subjects: List[str]
|
||||
subjects: list[str]
|
||||
if isinstance(subject, list):
|
||||
subjects = [s for s in subject if isinstance(s, str) and s.strip()]
|
||||
subject_payload: Union[str, List[str]] = subjects
|
||||
subject_payload: str | list[str] = subjects
|
||||
else:
|
||||
cleaned = subject.strip() if isinstance(subject, str) else ""
|
||||
subjects = [cleaned] if cleaned else []
|
||||
|
||||
@@ -1,17 +1,15 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
智能记忆遗忘引擎
|
||||
基于重要程度、置信度和激活频率的智能遗忘机制
|
||||
"""
|
||||
|
||||
import time
|
||||
import asyncio
|
||||
from typing import List, Dict, Optional, Tuple
|
||||
from datetime import datetime
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
|
||||
from src.chat.memory_system.memory_chunk import ConfidenceLevel, ImportanceLevel, MemoryChunk
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.memory_system.memory_chunk import MemoryChunk, ImportanceLevel, ConfidenceLevel
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -65,7 +63,7 @@ class ForgettingConfig:
|
||||
class MemoryForgettingEngine:
|
||||
"""智能记忆遗忘引擎"""
|
||||
|
||||
def __init__(self, config: Optional[ForgettingConfig] = None):
|
||||
def __init__(self, config: ForgettingConfig | None = None):
|
||||
self.config = config or ForgettingConfig()
|
||||
self.stats = ForgettingStats()
|
||||
self._last_forgetting_check = 0.0
|
||||
@@ -116,7 +114,7 @@ class MemoryForgettingEngine:
|
||||
# 确保在合理范围内
|
||||
return max(self.config.min_forgetting_days, min(threshold, self.config.max_forgetting_days))
|
||||
|
||||
def should_forget_memory(self, memory: MemoryChunk, current_time: Optional[float] = None) -> bool:
|
||||
def should_forget_memory(self, memory: MemoryChunk, current_time: float | None = None) -> bool:
|
||||
"""
|
||||
判断记忆是否应该被遗忘
|
||||
|
||||
@@ -155,7 +153,7 @@ class MemoryForgettingEngine:
|
||||
|
||||
return should_forget
|
||||
|
||||
def is_dormant_memory(self, memory: MemoryChunk, current_time: Optional[float] = None) -> bool:
|
||||
def is_dormant_memory(self, memory: MemoryChunk, current_time: float | None = None) -> bool:
|
||||
"""
|
||||
判断记忆是否处于休眠状态
|
||||
|
||||
@@ -168,7 +166,7 @@ class MemoryForgettingEngine:
|
||||
"""
|
||||
return memory.is_dormant(current_time, self.config.dormant_threshold_days)
|
||||
|
||||
def should_force_forget_dormant(self, memory: MemoryChunk, current_time: Optional[float] = None) -> bool:
|
||||
def should_force_forget_dormant(self, memory: MemoryChunk, current_time: float | None = None) -> bool:
|
||||
"""
|
||||
判断是否应该强制遗忘休眠记忆
|
||||
|
||||
@@ -189,7 +187,7 @@ class MemoryForgettingEngine:
|
||||
days_since_last_access = (current_time - memory.metadata.last_accessed) / 86400
|
||||
return days_since_last_access > self.config.force_forget_dormant_days
|
||||
|
||||
async def check_memories_for_forgetting(self, memories: List[MemoryChunk]) -> Tuple[List[str], List[str]]:
|
||||
async def check_memories_for_forgetting(self, memories: list[MemoryChunk]) -> tuple[list[str], list[str]]:
|
||||
"""
|
||||
检查记忆列表,识别需要遗忘的记忆
|
||||
|
||||
@@ -241,7 +239,7 @@ class MemoryForgettingEngine:
|
||||
|
||||
return normal_forgetting_ids, force_forgetting_ids
|
||||
|
||||
async def perform_forgetting_check(self, memories: List[MemoryChunk]) -> Dict[str, any]:
|
||||
async def perform_forgetting_check(self, memories: list[MemoryChunk]) -> dict[str, any]:
|
||||
"""
|
||||
执行完整的遗忘检查流程
|
||||
|
||||
@@ -314,7 +312,7 @@ class MemoryForgettingEngine:
|
||||
except Exception as e:
|
||||
logger.error(f"定期遗忘检查失败: {e}", exc_info=True)
|
||||
|
||||
def get_forgetting_stats(self) -> Dict[str, any]:
|
||||
def get_forgetting_stats(self) -> dict[str, any]:
|
||||
"""获取遗忘统计信息"""
|
||||
return {
|
||||
"total_checked": self.stats.total_checked,
|
||||
|
||||
@@ -1,16 +1,14 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
记忆融合与去重机制
|
||||
避免记忆碎片化,确保长期记忆库的高质量
|
||||
"""
|
||||
|
||||
import time
|
||||
from typing import Dict, List, Optional, Tuple, Any
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
|
||||
from src.chat.memory_system.memory_chunk import ConfidenceLevel, ImportanceLevel, MemoryChunk
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.memory_system.memory_chunk import MemoryChunk, ConfidenceLevel, ImportanceLevel
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -22,9 +20,9 @@ class FusionResult:
|
||||
original_count: int
|
||||
fused_count: int
|
||||
removed_duplicates: int
|
||||
merged_memories: List[MemoryChunk]
|
||||
merged_memories: list[MemoryChunk]
|
||||
fusion_time: float
|
||||
details: List[str]
|
||||
details: list[str]
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -32,9 +30,9 @@ class DuplicateGroup:
|
||||
"""重复记忆组"""
|
||||
|
||||
group_id: str
|
||||
memories: List[MemoryChunk]
|
||||
similarity_matrix: List[List[float]]
|
||||
representative_memory: Optional[MemoryChunk] = None
|
||||
memories: list[MemoryChunk]
|
||||
similarity_matrix: list[list[float]]
|
||||
representative_memory: MemoryChunk | None = None
|
||||
|
||||
|
||||
class MemoryFusionEngine:
|
||||
@@ -59,8 +57,8 @@ class MemoryFusionEngine:
|
||||
}
|
||||
|
||||
async def fuse_memories(
|
||||
self, new_memories: List[MemoryChunk], existing_memories: Optional[List[MemoryChunk]] = None
|
||||
) -> List[MemoryChunk]:
|
||||
self, new_memories: list[MemoryChunk], existing_memories: list[MemoryChunk] | None = None
|
||||
) -> list[MemoryChunk]:
|
||||
"""融合记忆列表"""
|
||||
start_time = time.time()
|
||||
|
||||
@@ -106,8 +104,8 @@ class MemoryFusionEngine:
|
||||
return new_memories # 失败时返回原始记忆
|
||||
|
||||
async def _detect_duplicate_groups(
|
||||
self, new_memories: List[MemoryChunk], existing_memories: List[MemoryChunk]
|
||||
) -> List[DuplicateGroup]:
|
||||
self, new_memories: list[MemoryChunk], existing_memories: list[MemoryChunk]
|
||||
) -> list[DuplicateGroup]:
|
||||
"""检测重复记忆组"""
|
||||
all_memories = new_memories + existing_memories
|
||||
new_memory_ids = {memory.memory_id for memory in new_memories}
|
||||
@@ -212,7 +210,7 @@ class MemoryFusionEngine:
|
||||
jaccard_similarity = len(intersection) / len(union)
|
||||
return jaccard_similarity
|
||||
|
||||
def _calculate_keyword_similarity(self, keywords1: List[str], keywords2: List[str]) -> float:
|
||||
def _calculate_keyword_similarity(self, keywords1: list[str], keywords2: list[str]) -> float:
|
||||
"""计算关键词相似度"""
|
||||
if not keywords1 or not keywords2:
|
||||
return 0.0
|
||||
@@ -302,7 +300,7 @@ class MemoryFusionEngine:
|
||||
|
||||
return best_memory
|
||||
|
||||
async def _fuse_memory_group(self, group: DuplicateGroup) -> Optional[MemoryChunk]:
|
||||
async def _fuse_memory_group(self, group: DuplicateGroup) -> MemoryChunk | None:
|
||||
"""融合记忆组"""
|
||||
if not group.memories:
|
||||
return None
|
||||
@@ -328,7 +326,7 @@ class MemoryFusionEngine:
|
||||
# 返回置信度最高的记忆
|
||||
return max(group.memories, key=lambda m: m.metadata.confidence.value)
|
||||
|
||||
async def _merge_memory_attributes(self, base_memory: MemoryChunk, memories: List[MemoryChunk]) -> MemoryChunk:
|
||||
async def _merge_memory_attributes(self, base_memory: MemoryChunk, memories: list[MemoryChunk]) -> MemoryChunk:
|
||||
"""合并记忆属性"""
|
||||
# 创建基础记忆的深拷贝
|
||||
fused_memory = MemoryChunk.from_dict(base_memory.to_dict())
|
||||
@@ -395,7 +393,7 @@ class MemoryFusionEngine:
|
||||
source_ids = [m.memory_id[:8] for m in group.memories]
|
||||
fused_memory.metadata.source_context = f"Fused from {len(group.memories)} memories: {', '.join(source_ids)}"
|
||||
|
||||
def _merge_temporal_context(self, memories: List[MemoryChunk]) -> Dict[str, Any]:
|
||||
def _merge_temporal_context(self, memories: list[MemoryChunk]) -> dict[str, Any]:
|
||||
"""合并时间上下文"""
|
||||
contexts = [m.temporal_context for m in memories if m.temporal_context]
|
||||
|
||||
@@ -426,8 +424,8 @@ class MemoryFusionEngine:
|
||||
return merged_context
|
||||
|
||||
async def incremental_fusion(
|
||||
self, new_memory: MemoryChunk, existing_memories: List[MemoryChunk]
|
||||
) -> Tuple[MemoryChunk, List[MemoryChunk]]:
|
||||
self, new_memory: MemoryChunk, existing_memories: list[MemoryChunk]
|
||||
) -> tuple[MemoryChunk, list[MemoryChunk]]:
|
||||
"""增量融合(单个新记忆与现有记忆融合)"""
|
||||
# 寻找相似记忆
|
||||
similar_memories = []
|
||||
@@ -493,7 +491,7 @@ class MemoryFusionEngine:
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 记忆融合引擎维护失败: {e}", exc_info=True)
|
||||
|
||||
def get_fusion_stats(self) -> Dict[str, Any]:
|
||||
def get_fusion_stats(self) -> dict[str, Any]:
|
||||
"""获取融合统计信息"""
|
||||
return self.fusion_stats.copy()
|
||||
|
||||
|
||||
@@ -1,17 +1,15 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
记忆系统管理器
|
||||
替代原有的 Hippocampus 和 instant_memory 系统
|
||||
"""
|
||||
|
||||
import re
|
||||
from typing import Dict, List, Optional, Any, Tuple
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.memory_system.memory_system import MemorySystem
|
||||
from src.chat.memory_system.memory_chunk import MemoryChunk, MemoryType
|
||||
from src.chat.memory_system.memory_system import initialize_memory_system
|
||||
from src.chat.memory_system.memory_system import MemorySystem, initialize_memory_system
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -27,14 +25,14 @@ class MemoryResult:
|
||||
timestamp: float
|
||||
source: str = "memory"
|
||||
relevance_score: float = 0.0
|
||||
structure: Dict[str, Any] | None = None
|
||||
structure: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class MemoryManager:
|
||||
"""记忆系统管理器 - 替代原有的 HippocampusManager"""
|
||||
|
||||
def __init__(self):
|
||||
self.memory_system: Optional[MemorySystem] = None
|
||||
self.memory_system: MemorySystem | None = None
|
||||
self.is_initialized = False
|
||||
self.user_cache = {} # 用户记忆缓存
|
||||
|
||||
@@ -63,8 +61,8 @@ class MemoryManager:
|
||||
logger.info("正在初始化记忆系统...")
|
||||
|
||||
# 获取LLM模型
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import model_config
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
|
||||
llm_model = LLMRequest(model_set=model_config.model_task_config.utils, request_type="memory")
|
||||
|
||||
@@ -121,7 +119,7 @@ class MemoryManager:
|
||||
max_memory_length: int = 2,
|
||||
time_weight: float = 1.0,
|
||||
keyword_weight: float = 1.0,
|
||||
) -> List[Tuple[str, str]]:
|
||||
) -> list[tuple[str, str]]:
|
||||
"""从文本获取相关记忆 - 兼容原有接口"""
|
||||
if not self.is_initialized or not self.memory_system:
|
||||
return []
|
||||
@@ -152,8 +150,8 @@ class MemoryManager:
|
||||
return []
|
||||
|
||||
async def get_memory_from_topic(
|
||||
self, valid_keywords: List[str], max_memory_num: int = 3, max_memory_length: int = 2, max_depth: int = 3
|
||||
) -> List[Tuple[str, str]]:
|
||||
self, valid_keywords: list[str], max_memory_num: int = 3, max_memory_length: int = 2, max_depth: int = 3
|
||||
) -> list[tuple[str, str]]:
|
||||
"""从关键词获取记忆 - 兼容原有接口"""
|
||||
if not self.is_initialized or not self.memory_system:
|
||||
return []
|
||||
@@ -208,8 +206,8 @@ class MemoryManager:
|
||||
return []
|
||||
|
||||
async def process_conversation(
|
||||
self, conversation_text: str, context: Dict[str, Any], user_id: str, timestamp: Optional[float] = None
|
||||
) -> List[MemoryChunk]:
|
||||
self, conversation_text: str, context: dict[str, Any], user_id: str, timestamp: float | None = None
|
||||
) -> list[MemoryChunk]:
|
||||
"""处理对话并构建记忆 - 新增功能"""
|
||||
if not self.is_initialized or not self.memory_system:
|
||||
return []
|
||||
@@ -235,8 +233,8 @@ class MemoryManager:
|
||||
return []
|
||||
|
||||
async def get_enhanced_memory_context(
|
||||
self, query_text: str, user_id: str, context: Optional[Dict[str, Any]] = None, limit: int = 5
|
||||
) -> List[MemoryResult]:
|
||||
self, query_text: str, user_id: str, context: dict[str, Any] | None = None, limit: int = 5
|
||||
) -> list[MemoryResult]:
|
||||
"""获取增强记忆上下文 - 新增功能"""
|
||||
if not self.is_initialized or not self.memory_system:
|
||||
return []
|
||||
@@ -267,7 +265,7 @@ class MemoryManager:
|
||||
logger.error(f"get_enhanced_memory_context 失败: {e}")
|
||||
return []
|
||||
|
||||
def _format_memory_chunk(self, memory: MemoryChunk) -> Tuple[str, Dict[str, Any]]:
|
||||
def _format_memory_chunk(self, memory: MemoryChunk) -> tuple[str, dict[str, Any]]:
|
||||
"""将记忆块转换为更易读的文本描述"""
|
||||
structure = memory.content.to_dict()
|
||||
if memory.display:
|
||||
@@ -289,7 +287,7 @@ class MemoryManager:
|
||||
|
||||
return formatted, structure
|
||||
|
||||
def _format_subject(self, subject: Optional[str], memory: MemoryChunk) -> str:
|
||||
def _format_subject(self, subject: str | None, memory: MemoryChunk) -> str:
|
||||
if not subject:
|
||||
return "该用户"
|
||||
|
||||
@@ -299,7 +297,7 @@ class MemoryManager:
|
||||
return "该聊天"
|
||||
return self._clean_text(subject)
|
||||
|
||||
def _apply_predicate_format(self, subject: str, predicate: str, obj: Any) -> Optional[str]:
|
||||
def _apply_predicate_format(self, subject: str, predicate: str, obj: Any) -> str | None:
|
||||
predicate = (predicate or "").strip()
|
||||
obj_value = obj
|
||||
|
||||
@@ -446,10 +444,10 @@ class MemoryManager:
|
||||
text = self._truncate(str(obj).strip())
|
||||
return self._clean_text(text)
|
||||
|
||||
def _extract_from_object(self, obj: Any, keys: List[str]) -> Optional[str]:
|
||||
def _extract_from_object(self, obj: Any, keys: list[str]) -> str | None:
|
||||
if isinstance(obj, dict):
|
||||
for key in keys:
|
||||
if key in obj and obj[key]:
|
||||
if obj.get(key):
|
||||
value = obj[key]
|
||||
if isinstance(value, (dict, list)):
|
||||
return self._clean_text(self._format_object(value))
|
||||
|
||||
@@ -1,15 +1,15 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
记忆元数据索引管理器
|
||||
使用JSON文件存储记忆元数据,支持快速模糊搜索和过滤
|
||||
"""
|
||||
|
||||
import orjson
|
||||
import threading
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Set, Any
|
||||
from dataclasses import dataclass, asdict
|
||||
from dataclasses import asdict, dataclass
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import orjson
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
@@ -25,10 +25,10 @@ class MemoryMetadataIndexEntry:
|
||||
|
||||
# 分类信息
|
||||
memory_type: str # MemoryType.value
|
||||
subjects: List[str] # 主语列表
|
||||
objects: List[str] # 宾语列表
|
||||
keywords: List[str] # 关键词列表
|
||||
tags: List[str] # 标签列表
|
||||
subjects: list[str] # 主语列表
|
||||
objects: list[str] # 宾语列表
|
||||
keywords: list[str] # 关键词列表
|
||||
tags: list[str] # 标签列表
|
||||
|
||||
# 数值字段(用于范围过滤)
|
||||
importance: int # ImportanceLevel.value (1-4)
|
||||
@@ -37,8 +37,8 @@ class MemoryMetadataIndexEntry:
|
||||
access_count: int # 访问次数
|
||||
|
||||
# 可选字段
|
||||
chat_id: Optional[str] = None
|
||||
content_preview: Optional[str] = None # 内容预览(前100字符)
|
||||
chat_id: str | None = None
|
||||
content_preview: str | None = None # 内容预览(前100字符)
|
||||
|
||||
|
||||
class MemoryMetadataIndex:
|
||||
@@ -46,13 +46,13 @@ class MemoryMetadataIndex:
|
||||
|
||||
def __init__(self, index_file: str = "data/memory_metadata_index.json"):
|
||||
self.index_file = Path(index_file)
|
||||
self.index: Dict[str, MemoryMetadataIndexEntry] = {} # memory_id -> entry
|
||||
self.index: dict[str, MemoryMetadataIndexEntry] = {} # memory_id -> entry
|
||||
|
||||
# 倒排索引(用于快速查找)
|
||||
self.type_index: Dict[str, Set[str]] = {} # type -> {memory_ids}
|
||||
self.subject_index: Dict[str, Set[str]] = {} # subject -> {memory_ids}
|
||||
self.keyword_index: Dict[str, Set[str]] = {} # keyword -> {memory_ids}
|
||||
self.tag_index: Dict[str, Set[str]] = {} # tag -> {memory_ids}
|
||||
self.type_index: dict[str, set[str]] = {} # type -> {memory_ids}
|
||||
self.subject_index: dict[str, set[str]] = {} # subject -> {memory_ids}
|
||||
self.keyword_index: dict[str, set[str]] = {} # keyword -> {memory_ids}
|
||||
self.tag_index: dict[str, set[str]] = {} # tag -> {memory_ids}
|
||||
|
||||
self.lock = threading.RLock()
|
||||
|
||||
@@ -178,7 +178,7 @@ class MemoryMetadataIndex:
|
||||
self._remove_from_inverted_indices(memory_id)
|
||||
del self.index[memory_id]
|
||||
|
||||
def batch_add_or_update(self, entries: List[MemoryMetadataIndexEntry]):
|
||||
def batch_add_or_update(self, entries: list[MemoryMetadataIndexEntry]):
|
||||
"""批量添加或更新"""
|
||||
with self.lock:
|
||||
for entry in entries:
|
||||
@@ -191,18 +191,18 @@ class MemoryMetadataIndex:
|
||||
|
||||
def search(
|
||||
self,
|
||||
memory_types: Optional[List[str]] = None,
|
||||
subjects: Optional[List[str]] = None,
|
||||
keywords: Optional[List[str]] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
importance_min: Optional[int] = None,
|
||||
importance_max: Optional[int] = None,
|
||||
created_after: Optional[float] = None,
|
||||
created_before: Optional[float] = None,
|
||||
user_id: Optional[str] = None,
|
||||
limit: Optional[int] = None,
|
||||
memory_types: list[str] | None = None,
|
||||
subjects: list[str] | None = None,
|
||||
keywords: list[str] | None = None,
|
||||
tags: list[str] | None = None,
|
||||
importance_min: int | None = None,
|
||||
importance_max: int | None = None,
|
||||
created_after: float | None = None,
|
||||
created_before: float | None = None,
|
||||
user_id: str | None = None,
|
||||
limit: int | None = None,
|
||||
flexible_mode: bool = True, # 新增:灵活匹配模式
|
||||
) -> List[str]:
|
||||
) -> list[str]:
|
||||
"""
|
||||
搜索符合条件的记忆ID列表(支持模糊匹配)
|
||||
|
||||
@@ -237,14 +237,14 @@ class MemoryMetadataIndex:
|
||||
|
||||
def _search_flexible(
|
||||
self,
|
||||
memory_types: Optional[List[str]] = None,
|
||||
subjects: Optional[List[str]] = None,
|
||||
created_after: Optional[float] = None,
|
||||
created_before: Optional[float] = None,
|
||||
user_id: Optional[str] = None,
|
||||
limit: Optional[int] = None,
|
||||
memory_types: list[str] | None = None,
|
||||
subjects: list[str] | None = None,
|
||||
created_after: float | None = None,
|
||||
created_before: float | None = None,
|
||||
user_id: str | None = None,
|
||||
limit: int | None = None,
|
||||
**kwargs, # 接受但不使用的参数
|
||||
) -> List[str]:
|
||||
) -> list[str]:
|
||||
"""
|
||||
灵活搜索模式:2/4项匹配即可,支持部分匹配
|
||||
|
||||
@@ -374,20 +374,20 @@ class MemoryMetadataIndex:
|
||||
|
||||
def _search_strict(
|
||||
self,
|
||||
memory_types: Optional[List[str]] = None,
|
||||
subjects: Optional[List[str]] = None,
|
||||
keywords: Optional[List[str]] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
importance_min: Optional[int] = None,
|
||||
importance_max: Optional[int] = None,
|
||||
created_after: Optional[float] = None,
|
||||
created_before: Optional[float] = None,
|
||||
user_id: Optional[str] = None,
|
||||
limit: Optional[int] = None,
|
||||
) -> List[str]:
|
||||
memory_types: list[str] | None = None,
|
||||
subjects: list[str] | None = None,
|
||||
keywords: list[str] | None = None,
|
||||
tags: list[str] | None = None,
|
||||
importance_min: int | None = None,
|
||||
importance_max: int | None = None,
|
||||
created_after: float | None = None,
|
||||
created_before: float | None = None,
|
||||
user_id: str | None = None,
|
||||
limit: int | None = None,
|
||||
) -> list[str]:
|
||||
"""严格搜索模式(原有逻辑)"""
|
||||
# 初始候选集(所有记忆)
|
||||
candidate_ids: Optional[Set[str]] = None
|
||||
candidate_ids: set[str] | None = None
|
||||
|
||||
# 用户过滤(必选)
|
||||
if user_id:
|
||||
@@ -471,11 +471,11 @@ class MemoryMetadataIndex:
|
||||
|
||||
return result_ids
|
||||
|
||||
def get_entry(self, memory_id: str) -> Optional[MemoryMetadataIndexEntry]:
|
||||
def get_entry(self, memory_id: str) -> MemoryMetadataIndexEntry | None:
|
||||
"""获取单个索引条目"""
|
||||
return self.index.get(memory_id)
|
||||
|
||||
def get_stats(self) -> Dict[str, Any]:
|
||||
def get_stats(self) -> dict[str, Any]:
|
||||
"""获取索引统计信息"""
|
||||
with self.lock:
|
||||
return {
|
||||
|
||||
@@ -1,11 +1,10 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""记忆检索查询规划器"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any
|
||||
|
||||
import orjson
|
||||
|
||||
@@ -21,16 +20,16 @@ class MemoryQueryPlan:
|
||||
"""查询规划结果"""
|
||||
|
||||
semantic_query: str
|
||||
memory_types: List[MemoryType] = field(default_factory=list)
|
||||
subject_includes: List[str] = field(default_factory=list)
|
||||
object_includes: List[str] = field(default_factory=list)
|
||||
required_keywords: List[str] = field(default_factory=list)
|
||||
optional_keywords: List[str] = field(default_factory=list)
|
||||
owner_filters: List[str] = field(default_factory=list)
|
||||
memory_types: list[MemoryType] = field(default_factory=list)
|
||||
subject_includes: list[str] = field(default_factory=list)
|
||||
object_includes: list[str] = field(default_factory=list)
|
||||
required_keywords: list[str] = field(default_factory=list)
|
||||
optional_keywords: list[str] = field(default_factory=list)
|
||||
owner_filters: list[str] = field(default_factory=list)
|
||||
recency_preference: str = "any"
|
||||
limit: int = 10
|
||||
emphasis: Optional[str] = None
|
||||
raw_plan: Dict[str, Any] = field(default_factory=dict)
|
||||
emphasis: str | None = None
|
||||
raw_plan: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
def ensure_defaults(self, fallback_query: str, default_limit: int) -> None:
|
||||
if not self.semantic_query:
|
||||
@@ -46,11 +45,11 @@ class MemoryQueryPlan:
|
||||
class MemoryQueryPlanner:
|
||||
"""基于小模型的记忆检索查询规划器"""
|
||||
|
||||
def __init__(self, planner_model: Optional[LLMRequest], default_limit: int = 10):
|
||||
def __init__(self, planner_model: LLMRequest | None, default_limit: int = 10):
|
||||
self.model = planner_model
|
||||
self.default_limit = default_limit
|
||||
|
||||
async def plan_query(self, query_text: str, context: Dict[str, Any]) -> MemoryQueryPlan:
|
||||
async def plan_query(self, query_text: str, context: dict[str, Any]) -> MemoryQueryPlan:
|
||||
if not self.model:
|
||||
logger.debug("未提供查询规划模型,使用默认规划")
|
||||
return self._default_plan(query_text)
|
||||
@@ -82,10 +81,10 @@ class MemoryQueryPlanner:
|
||||
def _default_plan(self, query_text: str) -> MemoryQueryPlan:
|
||||
return MemoryQueryPlan(semantic_query=query_text, limit=self.default_limit)
|
||||
|
||||
def _parse_plan_dict(self, data: Dict[str, Any], fallback_query: str) -> MemoryQueryPlan:
|
||||
def _parse_plan_dict(self, data: dict[str, Any], fallback_query: str) -> MemoryQueryPlan:
|
||||
semantic_query = self._safe_str(data.get("semantic_query")) or fallback_query
|
||||
|
||||
def _collect_list(key: str) -> List[str]:
|
||||
def _collect_list(key: str) -> list[str]:
|
||||
value = data.get(key)
|
||||
if isinstance(value, str):
|
||||
return [value]
|
||||
@@ -94,7 +93,7 @@ class MemoryQueryPlanner:
|
||||
return []
|
||||
|
||||
memory_type_values = _collect_list("memory_types")
|
||||
memory_types: List[MemoryType] = []
|
||||
memory_types: list[MemoryType] = []
|
||||
for item in memory_type_values:
|
||||
if not item:
|
||||
continue
|
||||
@@ -123,7 +122,7 @@ class MemoryQueryPlanner:
|
||||
)
|
||||
return plan
|
||||
|
||||
def _build_prompt(self, query_text: str, context: Dict[str, Any]) -> str:
|
||||
def _build_prompt(self, query_text: str, context: dict[str, Any]) -> str:
|
||||
participants = context.get("participants") or context.get("speaker_names") or []
|
||||
if isinstance(participants, str):
|
||||
participants = [participants]
|
||||
@@ -206,7 +205,7 @@ class MemoryQueryPlanner:
|
||||
请直接输出符合要求的 JSON 对象,禁止添加额外文本或 Markdown 代码块。
|
||||
"""
|
||||
|
||||
def _extract_json_payload(self, response: str) -> Optional[str]:
|
||||
def _extract_json_payload(self, response: str) -> str | None:
|
||||
if not response:
|
||||
return None
|
||||
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
精准记忆系统核心模块
|
||||
1. 基于文档设计的高效记忆构建、存储与召回优化系统,覆盖构建、向量化与多阶段检索全流程。
|
||||
@@ -6,26 +5,27 @@
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
import orjson
|
||||
import re
|
||||
import hashlib
|
||||
from typing import Dict, List, Optional, Set, Any, TYPE_CHECKING
|
||||
import re
|
||||
import time
|
||||
from dataclasses import asdict, dataclass
|
||||
from datetime import datetime, timedelta
|
||||
from dataclasses import dataclass, asdict
|
||||
from enum import Enum
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import orjson
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import model_config, global_config
|
||||
from src.chat.memory_system.memory_chunk import MemoryChunk
|
||||
from src.chat.memory_system.memory_builder import MemoryBuilder, MemoryExtractionError
|
||||
from src.chat.memory_system.memory_chunk import MemoryChunk
|
||||
from src.chat.memory_system.memory_fusion import MemoryFusionEngine
|
||||
from src.chat.memory_system.memory_query_planner import MemoryQueryPlanner
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config, model_config
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from src.common.data_models.database_data_model import DatabaseMessages
|
||||
from src.chat.memory_system.memory_forgetting_engine import MemoryForgettingEngine
|
||||
from src.common.data_models.database_data_model import DatabaseMessages
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -121,7 +121,7 @@ class MemorySystemConfig:
|
||||
class MemorySystem:
|
||||
"""精准记忆系统核心类"""
|
||||
|
||||
def __init__(self, llm_model: Optional[LLMRequest] = None, config: Optional[MemorySystemConfig] = None):
|
||||
def __init__(self, llm_model: LLMRequest | None = None, config: MemorySystemConfig | None = None):
|
||||
self.config = config or MemorySystemConfig.from_global_config()
|
||||
self.llm_model = llm_model
|
||||
self.status = MemorySystemStatus.INITIALIZING
|
||||
@@ -131,7 +131,7 @@ class MemorySystem:
|
||||
self.fusion_engine: MemoryFusionEngine = None
|
||||
self.unified_storage = None # 统一存储系统
|
||||
self.query_planner: MemoryQueryPlanner = None
|
||||
self.forgetting_engine: Optional[MemoryForgettingEngine] = None
|
||||
self.forgetting_engine: MemoryForgettingEngine | None = None
|
||||
|
||||
# LLM模型
|
||||
self.value_assessment_model: LLMRequest = None
|
||||
@@ -143,10 +143,10 @@ class MemorySystem:
|
||||
self.last_retrieval_time = None
|
||||
|
||||
# 构建节流记录
|
||||
self._last_memory_build_times: Dict[str, float] = {}
|
||||
self._last_memory_build_times: dict[str, float] = {}
|
||||
|
||||
# 记忆指纹缓存,用于快速检测重复记忆
|
||||
self._memory_fingerprints: Dict[str, str] = {}
|
||||
self._memory_fingerprints: dict[str, str] = {}
|
||||
|
||||
logger.info("MemorySystem 初始化开始")
|
||||
|
||||
@@ -210,7 +210,7 @@ class MemorySystem:
|
||||
raise
|
||||
|
||||
# 初始化遗忘引擎
|
||||
from src.chat.memory_system.memory_forgetting_engine import MemoryForgettingEngine, ForgettingConfig
|
||||
from src.chat.memory_system.memory_forgetting_engine import ForgettingConfig, MemoryForgettingEngine
|
||||
|
||||
# 从全局配置创建遗忘引擎配置
|
||||
forgetting_config = ForgettingConfig(
|
||||
@@ -241,7 +241,7 @@ class MemorySystem:
|
||||
self.forgetting_engine = MemoryForgettingEngine(forgetting_config)
|
||||
|
||||
planner_task_config = getattr(model_config.model_task_config, "utils_small", None)
|
||||
planner_model: Optional[LLMRequest] = None
|
||||
planner_model: LLMRequest | None = None
|
||||
try:
|
||||
planner_model = LLMRequest(model_set=planner_task_config, request_type="memory.query_planner")
|
||||
except Exception as planner_exc:
|
||||
@@ -261,8 +261,8 @@ class MemorySystem:
|
||||
raise
|
||||
|
||||
async def retrieve_memories_for_building(
|
||||
self, query_text: str, user_id: Optional[str] = None, context: Optional[Dict[str, Any]] = None, limit: int = 5
|
||||
) -> List[MemoryChunk]:
|
||||
self, query_text: str, user_id: str | None = None, context: dict[str, Any] | None = None, limit: int = 5
|
||||
) -> list[MemoryChunk]:
|
||||
"""在构建记忆时检索相关记忆,使用统一存储系统
|
||||
|
||||
Args:
|
||||
@@ -302,8 +302,8 @@ class MemorySystem:
|
||||
return []
|
||||
|
||||
async def build_memory_from_conversation(
|
||||
self, conversation_text: str, context: Dict[str, Any], timestamp: Optional[float] = None
|
||||
) -> List[MemoryChunk]:
|
||||
self, conversation_text: str, context: dict[str, Any], timestamp: float | None = None
|
||||
) -> list[MemoryChunk]:
|
||||
"""从对话中构建记忆
|
||||
|
||||
Args:
|
||||
@@ -318,8 +318,8 @@ class MemorySystem:
|
||||
self.status = MemorySystemStatus.BUILDING
|
||||
start_time = time.time()
|
||||
|
||||
build_scope_key: Optional[str] = None
|
||||
build_marker_time: Optional[float] = None
|
||||
build_scope_key: str | None = None
|
||||
build_marker_time: float | None = None
|
||||
|
||||
try:
|
||||
normalized_context = self._normalize_context(context, GLOBAL_MEMORY_SCOPE, timestamp)
|
||||
@@ -408,7 +408,7 @@ class MemorySystem:
|
||||
logger.error(f"❌ 记忆构建失败: {e}", exc_info=True)
|
||||
raise
|
||||
|
||||
def _log_memory_preview(self, memories: List[MemoryChunk]) -> None:
|
||||
def _log_memory_preview(self, memories: list[MemoryChunk]) -> None:
|
||||
"""在控制台输出记忆预览,便于人工检查"""
|
||||
if not memories:
|
||||
logger.info("📝 本次未生成新的记忆")
|
||||
@@ -425,12 +425,12 @@ class MemorySystem:
|
||||
f"置信度={memory.metadata.confidence.name} | 内容={text}"
|
||||
)
|
||||
|
||||
async def _collect_fusion_candidates(self, new_memories: List[MemoryChunk]) -> List[MemoryChunk]:
|
||||
async def _collect_fusion_candidates(self, new_memories: list[MemoryChunk]) -> list[MemoryChunk]:
|
||||
"""收集与新记忆相似的现有记忆,便于融合去重"""
|
||||
if not new_memories:
|
||||
return []
|
||||
|
||||
candidate_ids: Set[str] = set()
|
||||
candidate_ids: set[str] = set()
|
||||
new_memory_ids = {memory.memory_id for memory in new_memories if memory and getattr(memory, "memory_id", None)}
|
||||
|
||||
# 基于指纹的直接匹配
|
||||
@@ -493,7 +493,7 @@ class MemorySystem:
|
||||
continue
|
||||
candidate_ids.add(memory_id)
|
||||
|
||||
existing_candidates: List[MemoryChunk] = []
|
||||
existing_candidates: list[MemoryChunk] = []
|
||||
cache = self.unified_storage.memory_cache if self.unified_storage else {}
|
||||
for candidate_id in candidate_ids:
|
||||
if candidate_id in new_memory_ids:
|
||||
@@ -511,7 +511,7 @@ class MemorySystem:
|
||||
|
||||
return existing_candidates
|
||||
|
||||
async def process_conversation_memory(self, context: Dict[str, Any]) -> Dict[str, Any]:
|
||||
async def process_conversation_memory(self, context: dict[str, Any]) -> dict[str, Any]:
|
||||
"""对外暴露的对话记忆处理接口,仅依赖上下文信息"""
|
||||
start_time = time.time()
|
||||
|
||||
@@ -559,12 +559,12 @@ class MemorySystem:
|
||||
|
||||
async def retrieve_relevant_memories(
|
||||
self,
|
||||
query_text: Optional[str] = None,
|
||||
user_id: Optional[str] = None,
|
||||
context: Optional[Dict[str, Any]] = None,
|
||||
query_text: str | None = None,
|
||||
user_id: str | None = None,
|
||||
context: dict[str, Any] | None = None,
|
||||
limit: int = 5,
|
||||
**kwargs,
|
||||
) -> List[MemoryChunk]:
|
||||
) -> list[MemoryChunk]:
|
||||
"""检索相关记忆(三阶段召回:元数据粗筛 → 向量精筛 → 综合重排)"""
|
||||
raw_query = query_text or kwargs.get("query")
|
||||
if not raw_query:
|
||||
@@ -750,7 +750,7 @@ class MemorySystem:
|
||||
raise
|
||||
|
||||
@staticmethod
|
||||
def _extract_json_payload(response: str) -> Optional[str]:
|
||||
def _extract_json_payload(response: str) -> str | None:
|
||||
"""从模型响应中提取JSON部分,兼容Markdown代码块等格式"""
|
||||
if not response:
|
||||
return None
|
||||
@@ -773,10 +773,10 @@ class MemorySystem:
|
||||
return stripped if stripped.startswith("{") and stripped.endswith("}") else None
|
||||
|
||||
def _normalize_context(
|
||||
self, raw_context: Optional[Dict[str, Any]], user_id: Optional[str], timestamp: Optional[float]
|
||||
) -> Dict[str, Any]:
|
||||
self, raw_context: dict[str, Any] | None, user_id: str | None, timestamp: float | None
|
||||
) -> dict[str, Any]:
|
||||
"""标准化上下文,确保必备字段存在且格式正确"""
|
||||
context: Dict[str, Any] = {}
|
||||
context: dict[str, Any] = {}
|
||||
if raw_context:
|
||||
try:
|
||||
context = dict(raw_context)
|
||||
@@ -822,7 +822,7 @@ class MemorySystem:
|
||||
|
||||
return context
|
||||
|
||||
async def _build_enhanced_query_context(self, raw_query: str, normalized_context: Dict[str, Any]) -> Dict[str, Any]:
|
||||
async def _build_enhanced_query_context(self, raw_query: str, normalized_context: dict[str, Any]) -> dict[str, Any]:
|
||||
"""构建包含未读消息综合上下文的增强查询上下文
|
||||
|
||||
Args:
|
||||
@@ -861,7 +861,7 @@ class MemorySystem:
|
||||
|
||||
return enhanced_context
|
||||
|
||||
async def _collect_unread_messages_context(self, stream_id: str) -> Optional[Dict[str, Any]]:
|
||||
async def _collect_unread_messages_context(self, stream_id: str) -> dict[str, Any] | None:
|
||||
"""收集未读消息的综合上下文信息
|
||||
|
||||
Args:
|
||||
@@ -953,7 +953,7 @@ class MemorySystem:
|
||||
logger.warning(f"收集未读消息上下文失败: {e}", exc_info=True)
|
||||
return None
|
||||
|
||||
def _build_unread_context_summary(self, messages_summary: List[Dict[str, Any]]) -> str:
|
||||
def _build_unread_context_summary(self, messages_summary: list[dict[str, Any]]) -> str:
|
||||
"""构建未读消息的文本摘要
|
||||
|
||||
Args:
|
||||
@@ -974,7 +974,7 @@ class MemorySystem:
|
||||
|
||||
return " | ".join(summary_parts)
|
||||
|
||||
async def _resolve_conversation_context(self, fallback_text: str, context: Optional[Dict[str, Any]]) -> str:
|
||||
async def _resolve_conversation_context(self, fallback_text: str, context: dict[str, Any] | None) -> str:
|
||||
"""使用 stream_id 历史消息和相关记忆充实对话文本,默认回退到传入文本"""
|
||||
if not context:
|
||||
return fallback_text
|
||||
@@ -1043,11 +1043,11 @@ class MemorySystem:
|
||||
# 回退到传入文本
|
||||
return fallback_text
|
||||
|
||||
def _get_build_scope_key(self, context: Dict[str, Any], user_id: Optional[str]) -> Optional[str]:
|
||||
def _get_build_scope_key(self, context: dict[str, Any], user_id: str | None) -> str | None:
|
||||
"""确定用于节流控制的记忆构建作用域"""
|
||||
return "global_scope"
|
||||
|
||||
def _determine_history_limit(self, context: Dict[str, Any]) -> int:
|
||||
def _determine_history_limit(self, context: dict[str, Any]) -> int:
|
||||
"""确定历史消息获取数量,限制在30-50之间"""
|
||||
default_limit = 40
|
||||
candidate = context.get("history_limit") or context.get("history_window") or context.get("memory_history_limit")
|
||||
@@ -1065,12 +1065,12 @@ class MemorySystem:
|
||||
|
||||
return history_limit
|
||||
|
||||
def _format_history_messages(self, messages: List["DatabaseMessages"]) -> Optional[str]:
|
||||
def _format_history_messages(self, messages: list["DatabaseMessages"]) -> str | None:
|
||||
"""将历史消息格式化为可供LLM处理的多轮对话文本"""
|
||||
if not messages:
|
||||
return None
|
||||
|
||||
lines: List[str] = []
|
||||
lines: list[str] = []
|
||||
for msg in messages:
|
||||
try:
|
||||
content = getattr(msg, "processed_plain_text", None) or getattr(msg, "display_message", None)
|
||||
@@ -1105,7 +1105,7 @@ class MemorySystem:
|
||||
|
||||
return "\n".join(lines) if lines else None
|
||||
|
||||
async def _assess_information_value(self, text: str, context: Dict[str, Any]) -> float:
|
||||
async def _assess_information_value(self, text: str, context: dict[str, Any]) -> float:
|
||||
"""评估信息价值
|
||||
|
||||
Args:
|
||||
@@ -1201,7 +1201,7 @@ class MemorySystem:
|
||||
logger.error(f"信息价值评估失败: {e}", exc_info=True)
|
||||
return 0.5 # 默认中等价值
|
||||
|
||||
async def _store_memories_unified(self, memory_chunks: List[MemoryChunk]) -> int:
|
||||
async def _store_memories_unified(self, memory_chunks: list[MemoryChunk]) -> int:
|
||||
"""使用统一存储系统存储记忆块"""
|
||||
if not memory_chunks or not self.unified_storage:
|
||||
return 0
|
||||
@@ -1222,7 +1222,7 @@ class MemorySystem:
|
||||
return 0
|
||||
|
||||
# 保留原有方法以兼容旧代码
|
||||
async def _store_memories(self, memory_chunks: List[MemoryChunk]) -> int:
|
||||
async def _store_memories(self, memory_chunks: list[MemoryChunk]) -> int:
|
||||
"""兼容性方法:重定向到统一存储"""
|
||||
return await self._store_memories_unified(memory_chunks)
|
||||
|
||||
@@ -1271,7 +1271,7 @@ class MemorySystem:
|
||||
key = self._fingerprint_key(memory.user_id, fingerprint)
|
||||
self._memory_fingerprints[key] = memory.memory_id
|
||||
|
||||
def _register_memory_fingerprints(self, memories: List[MemoryChunk]) -> None:
|
||||
def _register_memory_fingerprints(self, memories: list[MemoryChunk]) -> None:
|
||||
for memory in memories:
|
||||
fingerprint = self._build_memory_fingerprint(memory)
|
||||
key = self._fingerprint_key(memory.user_id, fingerprint)
|
||||
@@ -1302,9 +1302,9 @@ class MemorySystem:
|
||||
|
||||
@staticmethod
|
||||
def _fingerprint_key(user_id: str, fingerprint: str) -> str:
|
||||
return f"{str(user_id)}:{fingerprint}"
|
||||
return f"{user_id!s}:{fingerprint}"
|
||||
|
||||
def get_system_stats(self) -> Dict[str, Any]:
|
||||
def get_system_stats(self) -> dict[str, Any]:
|
||||
"""获取系统统计信息"""
|
||||
return {
|
||||
"status": self.status.value,
|
||||
@@ -1314,7 +1314,7 @@ class MemorySystem:
|
||||
"config": asdict(self.config),
|
||||
}
|
||||
|
||||
def _compute_memory_score(self, query_text: str, memory: MemoryChunk, context: Dict[str, Any]) -> float:
|
||||
def _compute_memory_score(self, query_text: str, memory: MemoryChunk, context: dict[str, Any]) -> float:
|
||||
"""根据查询和上下文为记忆计算匹配分数"""
|
||||
tokens_query = self._tokenize_text(query_text)
|
||||
tokens_memory = self._tokenize_text(memory.text_content)
|
||||
@@ -1338,7 +1338,7 @@ class MemorySystem:
|
||||
final_score = base_score * 0.7 + keyword_overlap * 0.15 + importance_boost + confidence_boost
|
||||
return max(0.0, min(1.0, final_score))
|
||||
|
||||
def _tokenize_text(self, text: str) -> Set[str]:
|
||||
def _tokenize_text(self, text: str) -> set[str]:
|
||||
"""简单分词,兼容中英文"""
|
||||
if not text:
|
||||
return set()
|
||||
@@ -1450,7 +1450,7 @@ def get_memory_system() -> MemorySystem:
|
||||
return memory_system
|
||||
|
||||
|
||||
async def initialize_memory_system(llm_model: Optional[LLMRequest] = None):
|
||||
async def initialize_memory_system(llm_model: LLMRequest | None = None):
|
||||
"""初始化全局记忆系统"""
|
||||
global memory_system
|
||||
if memory_system is None:
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
基于Vector DB的统一记忆存储系统 V2
|
||||
使用ChromaDB作为底层存储,替代JSON存储方式
|
||||
@@ -11,20 +10,21 @@
|
||||
- 自动清理过期记忆
|
||||
"""
|
||||
|
||||
import time
|
||||
import orjson
|
||||
import asyncio
|
||||
import threading
|
||||
from typing import Dict, List, Optional, Tuple, Any
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.common.vector_db import vector_db_service
|
||||
from src.chat.utils.utils import get_embedding
|
||||
from src.chat.memory_system.memory_chunk import MemoryChunk, ConfidenceLevel, ImportanceLevel
|
||||
import orjson
|
||||
|
||||
from src.chat.memory_system.memory_chunk import ConfidenceLevel, ImportanceLevel, MemoryChunk
|
||||
from src.chat.memory_system.memory_forgetting_engine import MemoryForgettingEngine
|
||||
from src.chat.memory_system.memory_metadata_index import MemoryMetadataIndex, MemoryMetadataIndexEntry
|
||||
from src.chat.utils.utils import get_embedding
|
||||
from src.common.logger import get_logger
|
||||
from src.common.vector_db import vector_db_service
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -32,7 +32,7 @@ logger = get_logger(__name__)
|
||||
_ENUM_MAPPINGS_CACHE = {}
|
||||
|
||||
|
||||
def _build_enum_mapping(enum_class: type) -> Dict[str, Any]:
|
||||
def _build_enum_mapping(enum_class: type) -> dict[str, Any]:
|
||||
"""构建枚举类的完整映射表
|
||||
|
||||
Args:
|
||||
@@ -145,7 +145,7 @@ class VectorMemoryStorage:
|
||||
|
||||
"""基于Vector DB的记忆存储系统"""
|
||||
|
||||
def __init__(self, config: Optional[VectorStorageConfig] = None):
|
||||
def __init__(self, config: VectorStorageConfig | None = None):
|
||||
# 默认从全局配置读取,如果没有传入config
|
||||
if config is None:
|
||||
try:
|
||||
@@ -163,15 +163,15 @@ class VectorMemoryStorage:
|
||||
self.vector_db_service = vector_db_service
|
||||
|
||||
# 内存缓存
|
||||
self.memory_cache: Dict[str, MemoryChunk] = {}
|
||||
self.cache_timestamps: Dict[str, float] = {}
|
||||
self.memory_cache: dict[str, MemoryChunk] = {}
|
||||
self.cache_timestamps: dict[str, float] = {}
|
||||
self._cache = self.memory_cache # 别名,兼容旧代码
|
||||
|
||||
# 元数据索引管理器(JSON文件索引)
|
||||
self.metadata_index = MemoryMetadataIndex()
|
||||
|
||||
# 遗忘引擎
|
||||
self.forgetting_engine: Optional[MemoryForgettingEngine] = None
|
||||
self.forgetting_engine: MemoryForgettingEngine | None = None
|
||||
if self.config.enable_forgetting:
|
||||
self.forgetting_engine = MemoryForgettingEngine()
|
||||
|
||||
@@ -267,7 +267,7 @@ class VectorMemoryStorage:
|
||||
except Exception as e:
|
||||
logger.error(f"自动清理失败: {e}")
|
||||
|
||||
def _memory_to_vector_format(self, memory: MemoryChunk) -> Dict[str, Any]:
|
||||
def _memory_to_vector_format(self, memory: MemoryChunk) -> dict[str, Any]:
|
||||
"""将MemoryChunk转换为向量存储格式"""
|
||||
try:
|
||||
# 获取memory_id
|
||||
@@ -323,7 +323,7 @@ class VectorMemoryStorage:
|
||||
logger.error(f"转换记忆 {memory_id} 到向量格式失败: {e}", exc_info=True)
|
||||
raise
|
||||
|
||||
def _vector_result_to_memory(self, document: str, metadata: Dict[str, Any]) -> Optional[MemoryChunk]:
|
||||
def _vector_result_to_memory(self, document: str, metadata: dict[str, Any]) -> MemoryChunk | None:
|
||||
"""将Vector DB结果转换为MemoryChunk"""
|
||||
try:
|
||||
# 从元数据中恢复完整记忆
|
||||
@@ -440,7 +440,7 @@ class VectorMemoryStorage:
|
||||
logger.warning(f"不支持的{enum_class.__name__}值类型: {type(value)},使用默认值")
|
||||
return default
|
||||
|
||||
def _get_from_cache(self, memory_id: str) -> Optional[MemoryChunk]:
|
||||
def _get_from_cache(self, memory_id: str) -> MemoryChunk | None:
|
||||
"""从缓存获取记忆"""
|
||||
if not self.config.enable_caching:
|
||||
return None
|
||||
@@ -472,7 +472,7 @@ class VectorMemoryStorage:
|
||||
self.memory_cache[memory_id] = memory
|
||||
self.cache_timestamps[memory_id] = time.time()
|
||||
|
||||
async def store_memories(self, memories: List[MemoryChunk]) -> int:
|
||||
async def store_memories(self, memories: list[MemoryChunk]) -> int:
|
||||
"""批量存储记忆"""
|
||||
if not memories:
|
||||
return 0
|
||||
@@ -603,11 +603,11 @@ class VectorMemoryStorage:
|
||||
self,
|
||||
query_text: str,
|
||||
limit: int = 10,
|
||||
similarity_threshold: Optional[float] = None,
|
||||
filters: Optional[Dict[str, Any]] = None,
|
||||
similarity_threshold: float | None = None,
|
||||
filters: dict[str, Any] | None = None,
|
||||
# 新增:元数据过滤参数(用于JSON索引粗筛)
|
||||
metadata_filters: Optional[Dict[str, Any]] = None,
|
||||
) -> List[Tuple[MemoryChunk, float]]:
|
||||
metadata_filters: dict[str, Any] | None = None,
|
||||
) -> list[tuple[MemoryChunk, float]]:
|
||||
"""
|
||||
搜索相似记忆(混合索引模式)
|
||||
|
||||
@@ -632,7 +632,7 @@ class VectorMemoryStorage:
|
||||
|
||||
try:
|
||||
# === 阶段一:JSON元数据粗筛(可选) ===
|
||||
candidate_ids: Optional[List[str]] = None
|
||||
candidate_ids: list[str] | None = None
|
||||
if metadata_filters:
|
||||
logger.debug(f"[JSON元数据粗筛] 开始,过滤条件: {metadata_filters}")
|
||||
candidate_ids = self.metadata_index.search(
|
||||
@@ -746,7 +746,7 @@ class VectorMemoryStorage:
|
||||
logger.error(f"搜索相似记忆失败: {e}")
|
||||
return []
|
||||
|
||||
async def get_memory_by_id(self, memory_id: str) -> Optional[MemoryChunk]:
|
||||
async def get_memory_by_id(self, memory_id: str) -> MemoryChunk | None:
|
||||
"""根据ID获取记忆"""
|
||||
# 首先尝试从缓存获取
|
||||
memory = self._get_from_cache(memory_id)
|
||||
@@ -772,7 +772,7 @@ class VectorMemoryStorage:
|
||||
|
||||
return None
|
||||
|
||||
async def get_memories_by_filters(self, filters: Dict[str, Any], limit: int = 100) -> List[MemoryChunk]:
|
||||
async def get_memories_by_filters(self, filters: dict[str, Any], limit: int = 100) -> list[MemoryChunk]:
|
||||
"""根据过滤条件获取记忆"""
|
||||
try:
|
||||
results = vector_db_service.get(collection_name=self.config.memory_collection, where=filters, limit=limit)
|
||||
@@ -848,7 +848,7 @@ class VectorMemoryStorage:
|
||||
logger.error(f"删除记忆 {memory_id} 失败: {e}")
|
||||
return False
|
||||
|
||||
async def delete_memories_by_filters(self, filters: Dict[str, Any]) -> int:
|
||||
async def delete_memories_by_filters(self, filters: dict[str, Any]) -> int:
|
||||
"""根据过滤条件批量删除记忆"""
|
||||
try:
|
||||
# 先获取要删除的记忆ID
|
||||
@@ -880,7 +880,7 @@ class VectorMemoryStorage:
|
||||
logger.error(f"批量删除记忆失败: {e}")
|
||||
return 0
|
||||
|
||||
async def perform_forgetting_check(self) -> Dict[str, Any]:
|
||||
async def perform_forgetting_check(self) -> dict[str, Any]:
|
||||
"""执行遗忘检查"""
|
||||
if not self.forgetting_engine:
|
||||
return {"error": "遗忘引擎未启用"}
|
||||
@@ -925,7 +925,7 @@ class VectorMemoryStorage:
|
||||
logger.error(f"执行遗忘检查失败: {e}")
|
||||
return {"error": str(e)}
|
||||
|
||||
def get_storage_stats(self) -> Dict[str, Any]:
|
||||
def get_storage_stats(self) -> dict[str, Any]:
|
||||
"""获取存储统计信息"""
|
||||
try:
|
||||
current_total = vector_db_service.count(self.config.memory_collection)
|
||||
@@ -960,7 +960,7 @@ class VectorMemoryStorage:
|
||||
_global_vector_storage = None
|
||||
|
||||
|
||||
def get_vector_memory_storage(config: Optional[VectorStorageConfig] = None) -> VectorMemoryStorage:
|
||||
def get_vector_memory_storage(config: VectorStorageConfig | None = None) -> VectorMemoryStorage:
|
||||
"""获取全局Vector记忆存储实例"""
|
||||
global _global_vector_storage
|
||||
|
||||
@@ -974,15 +974,15 @@ def get_vector_memory_storage(config: Optional[VectorStorageConfig] = None) -> V
|
||||
class VectorMemoryStorageAdapter:
|
||||
"""适配器类,提供与原UnifiedMemoryStorage兼容的接口"""
|
||||
|
||||
def __init__(self, config: Optional[VectorStorageConfig] = None):
|
||||
def __init__(self, config: VectorStorageConfig | None = None):
|
||||
self.storage = VectorMemoryStorage(config)
|
||||
|
||||
async def store_memories(self, memories: List[MemoryChunk]) -> int:
|
||||
async def store_memories(self, memories: list[MemoryChunk]) -> int:
|
||||
return await self.storage.store_memories(memories)
|
||||
|
||||
async def search_similar_memories(
|
||||
self, query_text: str, limit: int = 10, scope_id: Optional[str] = None, filters: Optional[Dict[str, Any]] = None
|
||||
) -> List[Tuple[str, float]]:
|
||||
self, query_text: str, limit: int = 10, scope_id: str | None = None, filters: dict[str, Any] | None = None
|
||||
) -> list[tuple[str, float]]:
|
||||
results = await self.storage.search_similar_memories(query_text, limit, filters=filters)
|
||||
# 转换为原格式:(memory_id, similarity)
|
||||
return [
|
||||
@@ -990,7 +990,7 @@ class VectorMemoryStorageAdapter:
|
||||
for memory, similarity in results
|
||||
]
|
||||
|
||||
def get_stats(self) -> Dict[str, Any]:
|
||||
def get_stats(self) -> dict[str, Any]:
|
||||
return self.storage.get_storage_stats()
|
||||
|
||||
|
||||
|
||||
@@ -3,14 +3,14 @@
|
||||
提供统一的消息管理、上下文管理和流循环调度功能
|
||||
"""
|
||||
|
||||
from .message_manager import MessageManager, message_manager
|
||||
from .context_manager import SingleStreamContextManager
|
||||
from .distribution_manager import StreamLoopManager, stream_loop_manager
|
||||
from .message_manager import MessageManager, message_manager
|
||||
|
||||
__all__ = [
|
||||
"MessageManager",
|
||||
"message_manager",
|
||||
"SingleStreamContextManager",
|
||||
"StreamLoopManager",
|
||||
"message_manager",
|
||||
"stream_loop_manager",
|
||||
]
|
||||
|
||||
@@ -6,13 +6,14 @@
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from typing import Dict, List, Optional, Any
|
||||
from typing import Any
|
||||
|
||||
from src.chat.energy_system import energy_manager
|
||||
from src.common.data_models.database_data_model import DatabaseMessages
|
||||
from src.common.data_models.message_manager_data_model import StreamContext
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.common.data_models.database_data_model import DatabaseMessages
|
||||
from src.chat.energy_system import energy_manager
|
||||
|
||||
from .distribution_manager import stream_loop_manager
|
||||
|
||||
logger = get_logger("context_manager")
|
||||
@@ -21,7 +22,7 @@ logger = get_logger("context_manager")
|
||||
class SingleStreamContextManager:
|
||||
"""单流上下文管理器 - 每个实例只管理一个 stream 的上下文"""
|
||||
|
||||
def __init__(self, stream_id: str, context: StreamContext, max_context_size: Optional[int] = None):
|
||||
def __init__(self, stream_id: str, context: StreamContext, max_context_size: int | None = None):
|
||||
self.stream_id = stream_id
|
||||
self.context = context
|
||||
|
||||
@@ -66,7 +67,7 @@ class SingleStreamContextManager:
|
||||
logger.error(f"添加消息到单流上下文失败 {self.stream_id}: {e}", exc_info=True)
|
||||
return False
|
||||
|
||||
async def update_message(self, message_id: str, updates: Dict[str, Any]) -> bool:
|
||||
async def update_message(self, message_id: str, updates: dict[str, Any]) -> bool:
|
||||
"""更新上下文中的消息
|
||||
|
||||
Args:
|
||||
@@ -84,7 +85,7 @@ class SingleStreamContextManager:
|
||||
logger.error(f"更新单流上下文消息失败 {self.stream_id}/{message_id}: {e}", exc_info=True)
|
||||
return False
|
||||
|
||||
def get_messages(self, limit: Optional[int] = None, include_unread: bool = True) -> List[DatabaseMessages]:
|
||||
def get_messages(self, limit: int | None = None, include_unread: bool = True) -> list[DatabaseMessages]:
|
||||
"""获取上下文消息
|
||||
|
||||
Args:
|
||||
@@ -117,7 +118,7 @@ class SingleStreamContextManager:
|
||||
logger.error(f"获取单流上下文消息失败 {self.stream_id}: {e}", exc_info=True)
|
||||
return []
|
||||
|
||||
def get_unread_messages(self) -> List[DatabaseMessages]:
|
||||
def get_unread_messages(self) -> list[DatabaseMessages]:
|
||||
"""获取未读消息"""
|
||||
try:
|
||||
return self.context.get_unread_messages()
|
||||
@@ -125,7 +126,7 @@ class SingleStreamContextManager:
|
||||
logger.error(f"获取单流未读消息失败 {self.stream_id}: {e}", exc_info=True)
|
||||
return []
|
||||
|
||||
def mark_messages_as_read(self, message_ids: List[str]) -> bool:
|
||||
def mark_messages_as_read(self, message_ids: list[str]) -> bool:
|
||||
"""标记消息为已读"""
|
||||
try:
|
||||
if not hasattr(self.context, "mark_message_as_read"):
|
||||
@@ -168,7 +169,7 @@ class SingleStreamContextManager:
|
||||
logger.error(f"清空单流上下文失败 {self.stream_id}: {e}", exc_info=True)
|
||||
return False
|
||||
|
||||
def get_statistics(self) -> Dict[str, Any]:
|
||||
def get_statistics(self) -> dict[str, Any]:
|
||||
"""获取流统计信息"""
|
||||
try:
|
||||
current_time = time.time()
|
||||
@@ -285,7 +286,7 @@ class SingleStreamContextManager:
|
||||
logger.error(f"添加消息到单流上下文失败 (async) {self.stream_id}: {e}", exc_info=True)
|
||||
return False
|
||||
|
||||
async def update_message_async(self, message_id: str, updates: Dict[str, Any]) -> bool:
|
||||
async def update_message_async(self, message_id: str, updates: dict[str, Any]) -> bool:
|
||||
"""异步实现的 update_message:更新消息并在需要时 await 能量更新。"""
|
||||
try:
|
||||
self.context.update_message_info(message_id, **updates)
|
||||
@@ -327,7 +328,7 @@ class SingleStreamContextManager:
|
||||
"""更新流能量"""
|
||||
try:
|
||||
history_messages = self.context.get_history_messages(limit=self.max_context_size)
|
||||
messages: List[DatabaseMessages] = list(history_messages)
|
||||
messages: list[DatabaseMessages] = list(history_messages)
|
||||
|
||||
if include_unread:
|
||||
messages.extend(self.get_unread_messages())
|
||||
|
||||
@@ -5,12 +5,12 @@
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from typing import Dict, Optional, Any
|
||||
from typing import Any
|
||||
|
||||
from src.chat.chatter_manager import ChatterManager
|
||||
from src.chat.energy_system import energy_manager
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.chat.energy_system import energy_manager
|
||||
from src.chat.chatter_manager import ChatterManager
|
||||
from src.plugin_system.apis.chat_api import get_chat_manager
|
||||
|
||||
logger = get_logger("stream_loop_manager")
|
||||
@@ -19,13 +19,13 @@ logger = get_logger("stream_loop_manager")
|
||||
class StreamLoopManager:
|
||||
"""流循环管理器 - 每个流一个独立的无限循环任务"""
|
||||
|
||||
def __init__(self, max_concurrent_streams: Optional[int] = None):
|
||||
def __init__(self, max_concurrent_streams: int | None = None):
|
||||
# 流循环任务管理
|
||||
self.stream_loops: Dict[str, asyncio.Task] = {}
|
||||
self.stream_loops: dict[str, asyncio.Task] = {}
|
||||
self.loop_lock = asyncio.Lock()
|
||||
|
||||
# 统计信息
|
||||
self.stats: Dict[str, Any] = {
|
||||
self.stats: dict[str, Any] = {
|
||||
"active_streams": 0,
|
||||
"total_loops": 0,
|
||||
"total_process_cycles": 0,
|
||||
@@ -37,13 +37,13 @@ class StreamLoopManager:
|
||||
self.max_concurrent_streams = max_concurrent_streams or global_config.chat.max_concurrent_distributions
|
||||
|
||||
# 强制分发策略
|
||||
self.force_dispatch_unread_threshold: Optional[int] = getattr(
|
||||
self.force_dispatch_unread_threshold: int | None = getattr(
|
||||
global_config.chat, "force_dispatch_unread_threshold", 20
|
||||
)
|
||||
self.force_dispatch_min_interval: float = getattr(global_config.chat, "force_dispatch_min_interval", 0.1)
|
||||
|
||||
# Chatter管理器
|
||||
self.chatter_manager: Optional[ChatterManager] = None
|
||||
self.chatter_manager: ChatterManager | None = None
|
||||
|
||||
# 状态控制
|
||||
self.is_running = False
|
||||
@@ -212,7 +212,7 @@ class StreamLoopManager:
|
||||
|
||||
logger.info(f"流循环结束: {stream_id}")
|
||||
|
||||
async def _get_stream_context(self, stream_id: str) -> Optional[Any]:
|
||||
async def _get_stream_context(self, stream_id: str) -> Any | None:
|
||||
"""获取流上下文
|
||||
|
||||
Args:
|
||||
@@ -320,7 +320,7 @@ class StreamLoopManager:
|
||||
logger.debug(f"流 {stream_id} 使用默认间隔: {base_interval:.2f}s ({e})")
|
||||
return base_interval
|
||||
|
||||
def get_queue_status(self) -> Dict[str, Any]:
|
||||
def get_queue_status(self) -> dict[str, Any]:
|
||||
"""获取队列状态
|
||||
|
||||
Returns:
|
||||
@@ -374,14 +374,14 @@ class StreamLoopManager:
|
||||
except Exception:
|
||||
return 0
|
||||
|
||||
def _needs_force_dispatch_for_context(self, context: Any, unread_count: Optional[int] = None) -> bool:
|
||||
def _needs_force_dispatch_for_context(self, context: Any, unread_count: int | None = None) -> bool:
|
||||
if not self.force_dispatch_unread_threshold or self.force_dispatch_unread_threshold <= 0:
|
||||
return False
|
||||
|
||||
count = unread_count if unread_count is not None else self._get_unread_count(context)
|
||||
return count > self.force_dispatch_unread_threshold
|
||||
|
||||
def get_performance_summary(self) -> Dict[str, Any]:
|
||||
def get_performance_summary(self) -> dict[str, Any]:
|
||||
"""获取性能摘要
|
||||
|
||||
Returns:
|
||||
|
||||
@@ -6,19 +6,20 @@
|
||||
import asyncio
|
||||
import random
|
||||
import time
|
||||
from typing import Dict, Optional, Any, TYPE_CHECKING, List
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from src.chat.chatter_manager import ChatterManager
|
||||
from src.chat.message_receive.chat_stream import ChatStream
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.planner_actions.action_manager import ChatterActionManager
|
||||
from src.common.data_models.database_data_model import DatabaseMessages
|
||||
from src.common.data_models.message_manager_data_model import MessageManagerStats, StreamStats
|
||||
from src.chat.chatter_manager import ChatterManager
|
||||
from src.chat.planner_actions.action_manager import ChatterActionManager
|
||||
from .sleep_manager.sleep_manager import SleepManager
|
||||
from .sleep_manager.wakeup_manager import WakeUpManager
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.plugin_system.apis.chat_api import get_chat_manager
|
||||
|
||||
from .distribution_manager import stream_loop_manager
|
||||
from .sleep_manager.sleep_manager import SleepManager
|
||||
from .sleep_manager.wakeup_manager import WakeUpManager
|
||||
|
||||
if TYPE_CHECKING:
|
||||
pass
|
||||
@@ -32,7 +33,7 @@ class MessageManager:
|
||||
def __init__(self, check_interval: float = 5.0):
|
||||
self.check_interval = check_interval # 检查间隔(秒)
|
||||
self.is_running = False
|
||||
self.manager_task: Optional[asyncio.Task] = None
|
||||
self.manager_task: asyncio.Task | None = None
|
||||
|
||||
# 统计信息
|
||||
self.stats = MessageManagerStats()
|
||||
@@ -125,7 +126,7 @@ class MessageManager:
|
||||
except Exception as e:
|
||||
logger.error(f"更新消息 {message_id} 时发生错误: {e}")
|
||||
|
||||
async def bulk_update_messages(self, stream_id: str, updates: List[Dict[str, Any]]) -> int:
|
||||
async def bulk_update_messages(self, stream_id: str, updates: list[dict[str, Any]]) -> int:
|
||||
"""批量更新消息信息,降低更新频率"""
|
||||
if not updates:
|
||||
return 0
|
||||
@@ -214,7 +215,7 @@ class MessageManager:
|
||||
except Exception as e:
|
||||
logger.error(f"激活聊天流 {stream_id} 时发生错误: {e}")
|
||||
|
||||
def get_stream_stats(self, stream_id: str) -> Optional[StreamStats]:
|
||||
def get_stream_stats(self, stream_id: str) -> StreamStats | None:
|
||||
"""获取聊天流统计"""
|
||||
try:
|
||||
# 通过 ChatManager 获取 ChatStream
|
||||
@@ -243,7 +244,7 @@ class MessageManager:
|
||||
logger.error(f"获取聊天流 {stream_id} 统计时发生错误: {e}")
|
||||
return None
|
||||
|
||||
def get_manager_stats(self) -> Dict[str, Any]:
|
||||
def get_manager_stats(self) -> dict[str, Any]:
|
||||
"""获取管理器统计"""
|
||||
return {
|
||||
"total_streams": self.stats.total_streams,
|
||||
@@ -278,7 +279,7 @@ class MessageManager:
|
||||
except Exception as e:
|
||||
logger.error(f"清理不活跃聊天流时发生错误: {e}")
|
||||
|
||||
async def _check_and_handle_interruption(self, chat_stream: Optional[ChatStream] = None):
|
||||
async def _check_and_handle_interruption(self, chat_stream: ChatStream | None = None):
|
||||
"""检查并处理消息打断"""
|
||||
if not global_config.chat.interruption_enabled:
|
||||
return
|
||||
|
||||
@@ -1,12 +1,13 @@
|
||||
import asyncio
|
||||
import random
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Optional, TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
|
||||
from .notification_sender import NotificationSender
|
||||
from .sleep_state import SleepState, SleepContext
|
||||
from .sleep_state import SleepContext, SleepState
|
||||
from .time_checker import TimeChecker
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -92,7 +93,7 @@ class SleepManager:
|
||||
elif current_state == SleepState.WOKEN_UP:
|
||||
self._handle_woken_up(now, is_in_theoretical_sleep, wakeup_manager)
|
||||
|
||||
def _handle_awake_to_sleep(self, now: datetime, activity: Optional[str], wakeup_manager: Optional["WakeUpManager"]):
|
||||
def _handle_awake_to_sleep(self, now: datetime, activity: str | None, wakeup_manager: Optional["WakeUpManager"]):
|
||||
"""处理从“清醒”到“准备入睡”的状态转换。"""
|
||||
if activity:
|
||||
logger.info(f"进入理论休眠时间 '{activity}',开始进行睡眠决策...")
|
||||
@@ -181,7 +182,7 @@ class SleepManager:
|
||||
self,
|
||||
now: datetime,
|
||||
is_in_theoretical_sleep: bool,
|
||||
activity: Optional[str],
|
||||
activity: str | None,
|
||||
wakeup_manager: Optional["WakeUpManager"],
|
||||
):
|
||||
"""处理“正在睡觉”状态下的逻辑。"""
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
from datetime import date, datetime
|
||||
from enum import Enum, auto
|
||||
from datetime import datetime, date
|
||||
from typing import Optional
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.manager.local_store_manager import local_storage
|
||||
@@ -29,10 +28,10 @@ class SleepContext:
|
||||
def __init__(self):
|
||||
"""初始化睡眠上下文,并从本地存储加载初始状态。"""
|
||||
self.current_state: SleepState = SleepState.AWAKE
|
||||
self.sleep_buffer_end_time: Optional[datetime] = None
|
||||
self.sleep_buffer_end_time: datetime | None = None
|
||||
self.total_delayed_minutes_today: float = 0.0
|
||||
self.last_sleep_check_date: Optional[date] = None
|
||||
self.re_sleep_attempt_time: Optional[datetime] = None
|
||||
self.last_sleep_check_date: date | None = None
|
||||
self.re_sleep_attempt_time: datetime | None = None
|
||||
self.load()
|
||||
|
||||
def save(self):
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from datetime import datetime, time, timedelta
|
||||
from typing import Optional, List, Dict, Any
|
||||
import random
|
||||
from datetime import datetime, time, timedelta
|
||||
from typing import Any
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
@@ -37,11 +37,11 @@ class TimeChecker:
|
||||
return self._daily_sleep_offset, self._daily_wake_offset
|
||||
|
||||
@staticmethod
|
||||
def get_today_schedule() -> Optional[List[Dict[str, Any]]]:
|
||||
def get_today_schedule() -> list[dict[str, Any]] | None:
|
||||
"""从全局 ScheduleManager 获取今天的日程安排。"""
|
||||
return schedule_manager.today_schedule
|
||||
|
||||
def is_in_theoretical_sleep_time(self, now_time: time) -> tuple[bool, Optional[str]]:
|
||||
def is_in_theoretical_sleep_time(self, now_time: time) -> tuple[bool, str | None]:
|
||||
if global_config.sleep_system.sleep_by_schedule:
|
||||
if self.get_today_schedule():
|
||||
return self._is_in_schedule_sleep_time(now_time)
|
||||
@@ -50,7 +50,7 @@ class TimeChecker:
|
||||
else:
|
||||
return self._is_in_sleep_time(now_time)
|
||||
|
||||
def _is_in_schedule_sleep_time(self, now_time: time) -> tuple[bool, Optional[str]]:
|
||||
def _is_in_schedule_sleep_time(self, now_time: time) -> tuple[bool, str | None]:
|
||||
"""检查当前时间是否落在日程表的任何一个睡眠活动中"""
|
||||
sleep_keywords = ["休眠", "睡觉", "梦乡"]
|
||||
today_schedule = self.get_today_schedule()
|
||||
@@ -79,7 +79,7 @@ class TimeChecker:
|
||||
continue
|
||||
return False, None
|
||||
|
||||
def _is_in_sleep_time(self, now_time: time) -> tuple[bool, Optional[str]]:
|
||||
def _is_in_sleep_time(self, now_time: time) -> tuple[bool, str | None]:
|
||||
"""检查当前时间是否在固定的睡眠时间内(应用偏移量)"""
|
||||
try:
|
||||
start_time_str = global_config.sleep_system.fixed_sleep_time
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
import asyncio
|
||||
import time
|
||||
from typing import Optional, TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from src.chat.message_manager.sleep_manager.wakeup_context import WakeUpContext
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.chat.message_manager.sleep_manager.wakeup_context import WakeUpContext
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .sleep_manager import SleepManager
|
||||
@@ -27,9 +28,9 @@ class WakeUpManager:
|
||||
"""
|
||||
self.sleep_manager = sleep_manager
|
||||
self.context = WakeUpContext() # 使用新的上下文管理器
|
||||
self.angry_chat_id: Optional[str] = None
|
||||
self.angry_chat_id: str | None = None
|
||||
self.last_decay_time = time.time()
|
||||
self._decay_task: Optional[asyncio.Task] = None
|
||||
self._decay_task: asyncio.Task | None = None
|
||||
self.is_running = False
|
||||
self.last_log_time = 0
|
||||
self.log_interval = 30
|
||||
@@ -104,9 +105,7 @@ class WakeUpManager:
|
||||
logger.debug(f"唤醒度衰减: {old_value:.1f} -> {self.context.wakeup_value:.1f}")
|
||||
self.context.save()
|
||||
|
||||
def add_wakeup_value(
|
||||
self, is_private_chat: bool, is_mentioned: bool = False, chat_id: Optional[str] = None
|
||||
) -> bool:
|
||||
def add_wakeup_value(self, is_private_chat: bool, is_mentioned: bool = False, chat_id: str | None = None) -> bool:
|
||||
"""
|
||||
增加唤醒度值
|
||||
|
||||
|
||||
@@ -2,9 +2,8 @@ from src.chat.emoji_system.emoji_manager import get_emoji_manager
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from src.chat.message_receive.storage import MessageStorage
|
||||
|
||||
|
||||
__all__ = [
|
||||
"get_emoji_manager",
|
||||
"get_chat_manager",
|
||||
"MessageStorage",
|
||||
"get_chat_manager",
|
||||
"get_emoji_manager",
|
||||
]
|
||||
|
||||
@@ -1,25 +1,24 @@
|
||||
import traceback
|
||||
import os
|
||||
import re
|
||||
import traceback
|
||||
from typing import Any
|
||||
|
||||
from typing import Dict, Any, Optional
|
||||
from maim_message import UserInfo
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.mood.mood_manager import mood_manager # 导入情绪管理器
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager, ChatStream
|
||||
from src.chat.message_receive.message import MessageRecv, MessageRecvS4U
|
||||
from src.chat.message_receive.storage import MessageStorage
|
||||
from src.chat.message_manager import message_manager
|
||||
from src.chat.utils.prompt import Prompt, global_prompt_manager
|
||||
from src.plugin_system.core import component_registry, event_manager, global_announcement_manager
|
||||
from src.plugin_system.base import BaseCommand, EventType
|
||||
from src.mais4u.mais4u_chat.s4u_msg_processor import S4UMessageProcessor
|
||||
from src.chat.utils.utils import is_mentioned_bot_in_message
|
||||
|
||||
# 导入反注入系统
|
||||
from src.chat.antipromptinjector import initialize_anti_injector
|
||||
from src.chat.message_manager import message_manager
|
||||
from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager
|
||||
from src.chat.message_receive.message import MessageRecv, MessageRecvS4U
|
||||
from src.chat.message_receive.storage import MessageStorage
|
||||
from src.chat.utils.prompt import Prompt, global_prompt_manager
|
||||
from src.chat.utils.utils import is_mentioned_bot_in_message
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.mais4u.mais4u_chat.s4u_msg_processor import S4UMessageProcessor
|
||||
from src.mood.mood_manager import mood_manager # 导入情绪管理器
|
||||
from src.plugin_system.base import BaseCommand, EventType
|
||||
from src.plugin_system.core import component_registry, event_manager, global_announcement_manager
|
||||
|
||||
# 获取项目根目录(假设本文件在src/chat/message_receive/下,根目录为上上上级目录)
|
||||
PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../.."))
|
||||
@@ -220,7 +219,7 @@ class ChatBot:
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
try:
|
||||
await plus_command_instance.send_text(f"命令执行出错: {str(e)}")
|
||||
await plus_command_instance.send_text(f"命令执行出错: {e!s}")
|
||||
except Exception as send_error:
|
||||
logger.error(f"发送错误消息失败: {send_error}")
|
||||
|
||||
@@ -288,7 +287,7 @@ class ChatBot:
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
try:
|
||||
await command_instance.send_text(f"命令执行出错: {str(e)}")
|
||||
await command_instance.send_text(f"命令执行出错: {e!s}")
|
||||
except Exception as send_error:
|
||||
logger.error(f"发送错误消息失败: {send_error}")
|
||||
|
||||
@@ -341,7 +340,7 @@ class ChatBot:
|
||||
except Exception as e:
|
||||
logger.error(f"处理适配器响应时出错: {e}")
|
||||
|
||||
async def do_s4u(self, message_data: Dict[str, Any]):
|
||||
async def do_s4u(self, message_data: dict[str, Any]):
|
||||
message = MessageRecvS4U(message_data)
|
||||
group_info = message.message_info.group_info
|
||||
user_info = message.message_info.user_info
|
||||
@@ -364,7 +363,7 @@ class ChatBot:
|
||||
|
||||
return
|
||||
|
||||
async def message_process(self, message_data: Dict[str, Any]) -> None:
|
||||
async def message_process(self, message_data: dict[str, Any]) -> None:
|
||||
"""处理转化后的统一格式消息"""
|
||||
try:
|
||||
# 首先处理可能的切片消息重组
|
||||
@@ -462,7 +461,7 @@ class ChatBot:
|
||||
# TODO:暂不可用
|
||||
# 确认从接口发来的message是否有自定义的prompt模板信息
|
||||
if message.message_info.template_info and not message.message_info.template_info.template_default:
|
||||
template_group_name: Optional[str] = message.message_info.template_info.template_name # type: ignore
|
||||
template_group_name: str | None = message.message_info.template_info.template_name # type: ignore
|
||||
template_items = message.message_info.template_info.template_items
|
||||
async with global_prompt_manager.async_message_scope(template_group_name):
|
||||
if isinstance(template_items, dict):
|
||||
|
||||
@@ -1,17 +1,18 @@
|
||||
import asyncio
|
||||
import copy
|
||||
import hashlib
|
||||
import time
|
||||
import copy
|
||||
from typing import Dict, Optional, TYPE_CHECKING
|
||||
from rich.traceback import install
|
||||
from maim_message import GroupInfo, UserInfo
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from maim_message import GroupInfo, UserInfo
|
||||
from rich.traceback import install
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.dialects.sqlite import insert as sqlite_insert
|
||||
from sqlalchemy.dialects.mysql import insert as mysql_insert
|
||||
from src.common.database.sqlalchemy_models import ChatStreams # 新增导入
|
||||
from sqlalchemy.dialects.sqlite import insert as sqlite_insert
|
||||
|
||||
from src.common.database.sqlalchemy_database_api import get_db_session
|
||||
from src.common.database.sqlalchemy_models import ChatStreams # 新增导入
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config # 新增导入
|
||||
|
||||
# 避免循环导入,使用TYPE_CHECKING进行类型提示
|
||||
@@ -33,8 +34,8 @@ class ChatStream:
|
||||
stream_id: str,
|
||||
platform: str,
|
||||
user_info: UserInfo,
|
||||
group_info: Optional[GroupInfo] = None,
|
||||
data: Optional[dict] = None,
|
||||
group_info: GroupInfo | None = None,
|
||||
data: dict | None = None,
|
||||
):
|
||||
self.stream_id = stream_id
|
||||
self.platform = platform
|
||||
@@ -47,7 +48,7 @@ class ChatStream:
|
||||
|
||||
# 使用StreamContext替代ChatMessageContext
|
||||
from src.common.data_models.message_manager_data_model import StreamContext
|
||||
from src.plugin_system.base.component_types import ChatType, ChatMode
|
||||
from src.plugin_system.base.component_types import ChatMode, ChatType
|
||||
|
||||
# 创建StreamContext
|
||||
self.stream_context: StreamContext = StreamContext(
|
||||
@@ -133,11 +134,11 @@ class ChatStream:
|
||||
|
||||
# 恢复stream_context信息
|
||||
if "stream_context_chat_type" in data:
|
||||
from src.plugin_system.base.component_types import ChatType, ChatMode
|
||||
from src.plugin_system.base.component_types import ChatMode, ChatType
|
||||
|
||||
instance.stream_context.chat_type = ChatType(data["stream_context_chat_type"])
|
||||
if "stream_context_chat_mode" in data:
|
||||
from src.plugin_system.base.component_types import ChatType, ChatMode
|
||||
from src.plugin_system.base.component_types import ChatMode, ChatType
|
||||
|
||||
instance.stream_context.chat_mode = ChatMode(data["stream_context_chat_mode"])
|
||||
|
||||
@@ -163,9 +164,10 @@ class ChatStream:
|
||||
def set_context(self, message: "MessageRecv"):
|
||||
"""设置聊天消息上下文"""
|
||||
# 将MessageRecv转换为DatabaseMessages并设置到stream_context
|
||||
from src.common.data_models.database_data_model import DatabaseMessages
|
||||
import json
|
||||
|
||||
from src.common.data_models.database_data_model import DatabaseMessages
|
||||
|
||||
# 安全获取message_info中的数据
|
||||
message_info = getattr(message, "message_info", {})
|
||||
user_info = getattr(message_info, "user_info", {})
|
||||
@@ -248,7 +250,7 @@ class ChatStream:
|
||||
f"interest_value: {db_message.interest_value}"
|
||||
)
|
||||
|
||||
def _safe_get_actions(self, message: "MessageRecv") -> Optional[list]:
|
||||
def _safe_get_actions(self, message: "MessageRecv") -> list | None:
|
||||
"""安全获取消息的actions字段"""
|
||||
try:
|
||||
actions = getattr(message, "actions", None)
|
||||
@@ -278,7 +280,7 @@ class ChatStream:
|
||||
logger.warning(f"获取actions字段失败: {e}")
|
||||
return None
|
||||
|
||||
def _extract_reply_from_segment(self, segment) -> Optional[str]:
|
||||
def _extract_reply_from_segment(self, segment) -> str | None:
|
||||
"""从消息段中提取reply_to信息"""
|
||||
try:
|
||||
if hasattr(segment, "type") and segment.type == "seglist":
|
||||
@@ -391,8 +393,8 @@ class ChatManager:
|
||||
|
||||
def __init__(self):
|
||||
if not self._initialized:
|
||||
self.streams: Dict[str, ChatStream] = {} # stream_id -> ChatStream
|
||||
self.last_messages: Dict[str, "MessageRecv"] = {} # stream_id -> last_message
|
||||
self.streams: dict[str, ChatStream] = {} # stream_id -> ChatStream
|
||||
self.last_messages: dict[str, "MessageRecv"] = {} # stream_id -> last_message
|
||||
# try:
|
||||
# async with get_db_session() as session:
|
||||
# db.connect(reuse_if_open=True)
|
||||
@@ -414,7 +416,7 @@ class ChatManager:
|
||||
await self.load_all_streams()
|
||||
logger.info(f"聊天管理器已启动,已加载 {len(self.streams)} 个聊天流")
|
||||
except Exception as e:
|
||||
logger.error(f"聊天管理器启动失败: {str(e)}")
|
||||
logger.error(f"聊天管理器启动失败: {e!s}")
|
||||
|
||||
async def _auto_save_task(self):
|
||||
"""定期自动保存所有聊天流"""
|
||||
@@ -424,7 +426,7 @@ class ChatManager:
|
||||
await self._save_all_streams()
|
||||
logger.info("聊天流自动保存完成")
|
||||
except Exception as e:
|
||||
logger.error(f"聊天流自动保存失败: {str(e)}")
|
||||
logger.error(f"聊天流自动保存失败: {e!s}")
|
||||
|
||||
def register_message(self, message: "MessageRecv"):
|
||||
"""注册消息到聊天流"""
|
||||
@@ -437,9 +439,7 @@ class ChatManager:
|
||||
# logger.debug(f"注册消息到聊天流: {stream_id}")
|
||||
|
||||
@staticmethod
|
||||
def _generate_stream_id(
|
||||
platform: str, user_info: Optional[UserInfo], group_info: Optional[GroupInfo] = None
|
||||
) -> str:
|
||||
def _generate_stream_id(platform: str, user_info: UserInfo | None, group_info: GroupInfo | None = None) -> str:
|
||||
"""生成聊天流唯一ID"""
|
||||
if not user_info and not group_info:
|
||||
raise ValueError("用户信息或群组信息必须提供")
|
||||
@@ -462,7 +462,7 @@ class ChatManager:
|
||||
return hashlib.md5(key.encode()).hexdigest()
|
||||
|
||||
async def get_or_create_stream(
|
||||
self, platform: str, user_info: UserInfo, group_info: Optional[GroupInfo] = None
|
||||
self, platform: str, user_info: UserInfo, group_info: GroupInfo | None = None
|
||||
) -> ChatStream:
|
||||
"""获取或创建聊天流
|
||||
|
||||
@@ -572,7 +572,7 @@ class ChatManager:
|
||||
await self._save_stream(stream)
|
||||
return stream
|
||||
|
||||
def get_stream(self, stream_id: str) -> Optional[ChatStream]:
|
||||
def get_stream(self, stream_id: str) -> ChatStream | None:
|
||||
"""通过stream_id获取聊天流"""
|
||||
stream = self.streams.get(stream_id)
|
||||
if not stream:
|
||||
@@ -582,13 +582,13 @@ class ChatManager:
|
||||
return stream
|
||||
|
||||
def get_stream_by_info(
|
||||
self, platform: str, user_info: UserInfo, group_info: Optional[GroupInfo] = None
|
||||
) -> Optional[ChatStream]:
|
||||
self, platform: str, user_info: UserInfo, group_info: GroupInfo | None = None
|
||||
) -> ChatStream | None:
|
||||
"""通过信息获取聊天流"""
|
||||
stream_id = self._generate_stream_id(platform, user_info, group_info)
|
||||
return self.streams.get(stream_id)
|
||||
|
||||
def get_stream_name(self, stream_id: str) -> Optional[str]:
|
||||
def get_stream_name(self, stream_id: str) -> str | None:
|
||||
"""根据 stream_id 获取聊天流名称"""
|
||||
stream = self.get_stream(stream_id)
|
||||
if not stream:
|
||||
|
||||
@@ -1,20 +1,19 @@
|
||||
import base64
|
||||
import time
|
||||
from abc import abstractmethod, ABCMeta
|
||||
from abc import ABCMeta, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Any
|
||||
from typing import Any, Optional
|
||||
|
||||
import urllib3
|
||||
from maim_message import Seg, UserInfo, BaseMessageInfo, MessageBase
|
||||
from maim_message import BaseMessageInfo, MessageBase, Seg, UserInfo
|
||||
from rich.traceback import install
|
||||
|
||||
from src.chat.message_receive.chat_stream import ChatStream
|
||||
from src.chat.utils.utils_image import get_image_manager
|
||||
from src.chat.utils.utils_video import get_video_analyzer, is_video_analysis_available
|
||||
from src.chat.utils.utils_voice import get_voice_text
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.chat.message_receive.chat_stream import ChatStream
|
||||
|
||||
|
||||
install(extra_lines=3)
|
||||
|
||||
@@ -41,8 +40,8 @@ class Message(MessageBase, metaclass=ABCMeta):
|
||||
message_id: str,
|
||||
chat_stream: "ChatStream",
|
||||
user_info: UserInfo,
|
||||
message_segment: Optional[Seg] = None,
|
||||
timestamp: Optional[float] = None,
|
||||
message_segment: Seg | None = None,
|
||||
timestamp: float | None = None,
|
||||
reply: Optional["MessageRecv"] = None,
|
||||
processed_plain_text: str = "",
|
||||
):
|
||||
@@ -264,7 +263,7 @@ class MessageRecv(Message):
|
||||
logger.warning("视频消息中没有base64数据")
|
||||
return "[收到视频消息,但数据异常]"
|
||||
except Exception as e:
|
||||
logger.error(f"视频处理失败: {str(e)}")
|
||||
logger.error(f"视频处理失败: {e!s}")
|
||||
import traceback
|
||||
|
||||
logger.error(f"错误详情: {traceback.format_exc()}")
|
||||
@@ -278,7 +277,7 @@ class MessageRecv(Message):
|
||||
logger.info("未启用视频识别")
|
||||
return "[视频]"
|
||||
except Exception as e:
|
||||
logger.error(f"处理消息段失败: {str(e)}, 类型: {segment.type}, 数据: {segment.data}")
|
||||
logger.error(f"处理消息段失败: {e!s}, 类型: {segment.type}, 数据: {segment.data}")
|
||||
return f"[处理失败的{segment.type}消息]"
|
||||
|
||||
|
||||
@@ -291,7 +290,7 @@ class MessageRecvS4U(MessageRecv):
|
||||
self.is_superchat = False
|
||||
self.gift_info = None
|
||||
self.gift_name = None
|
||||
self.gift_count: Optional[str] = None
|
||||
self.gift_count: str | None = None
|
||||
self.superchat_info = None
|
||||
self.superchat_price = None
|
||||
self.superchat_message_text = None
|
||||
@@ -444,7 +443,7 @@ class MessageRecvS4U(MessageRecv):
|
||||
logger.warning("视频消息中没有base64数据")
|
||||
return "[收到视频消息,但数据异常]"
|
||||
except Exception as e:
|
||||
logger.error(f"视频处理失败: {str(e)}")
|
||||
logger.error(f"视频处理失败: {e!s}")
|
||||
import traceback
|
||||
|
||||
logger.error(f"错误详情: {traceback.format_exc()}")
|
||||
@@ -458,7 +457,7 @@ class MessageRecvS4U(MessageRecv):
|
||||
logger.info("未启用视频识别")
|
||||
return "[视频]"
|
||||
except Exception as e:
|
||||
logger.error(f"处理消息段失败: {str(e)}, 类型: {segment.type}, 数据: {segment.data}")
|
||||
logger.error(f"处理消息段失败: {e!s}, 类型: {segment.type}, 数据: {segment.data}")
|
||||
return f"[处理失败的{segment.type}消息]"
|
||||
|
||||
|
||||
@@ -471,10 +470,10 @@ class MessageProcessBase(Message):
|
||||
message_id: str,
|
||||
chat_stream: "ChatStream",
|
||||
bot_user_info: UserInfo,
|
||||
message_segment: Optional[Seg] = None,
|
||||
message_segment: Seg | None = None,
|
||||
reply: Optional["MessageRecv"] = None,
|
||||
thinking_start_time: float = 0,
|
||||
timestamp: Optional[float] = None,
|
||||
timestamp: float | None = None,
|
||||
):
|
||||
# 调用父类初始化,传递时间戳
|
||||
super().__init__(
|
||||
@@ -533,9 +532,9 @@ class MessageProcessBase(Message):
|
||||
return f"[回复<{self.reply.message_info.user_info.user_nickname}> 的消息:{self.reply.processed_plain_text}]" # type: ignore
|
||||
return None
|
||||
else:
|
||||
return f"[{seg.type}:{str(seg.data)}]"
|
||||
return f"[{seg.type}:{seg.data!s}]"
|
||||
except Exception as e:
|
||||
logger.error(f"处理消息段失败: {str(e)}, 类型: {seg.type}, 数据: {seg.data}")
|
||||
logger.error(f"处理消息段失败: {e!s}, 类型: {seg.type}, 数据: {seg.data}")
|
||||
return f"[处理失败的{seg.type}消息]"
|
||||
|
||||
def _generate_detailed_text(self) -> str:
|
||||
@@ -565,8 +564,7 @@ class MessageSending(MessageProcessBase):
|
||||
is_emoji: bool = False,
|
||||
thinking_start_time: float = 0,
|
||||
apply_set_reply_logic: bool = False,
|
||||
reply_to: Optional[str] = None,
|
||||
selected_expressions:List[int] = None,
|
||||
reply_to: str | None = None,
|
||||
):
|
||||
# 调用父类初始化
|
||||
super().__init__(
|
||||
@@ -638,11 +636,11 @@ class MessageSet:
|
||||
self.messages.append(message)
|
||||
self.messages.sort(key=lambda x: x.message_info.time) # type: ignore
|
||||
|
||||
def get_message_by_index(self, index: int) -> Optional[MessageSending]:
|
||||
def get_message_by_index(self, index: int) -> MessageSending | None:
|
||||
"""通过索引获取消息"""
|
||||
return self.messages[index] if 0 <= index < len(self.messages) else None
|
||||
|
||||
def get_message_by_time(self, target_time: float) -> Optional[MessageSending]:
|
||||
def get_message_by_time(self, target_time: float) -> MessageSending | None:
|
||||
"""获取最接近指定时间的消息"""
|
||||
if not self.messages:
|
||||
return None
|
||||
|
||||
@@ -1,15 +1,16 @@
|
||||
import re
|
||||
import json
|
||||
import traceback
|
||||
from typing import Union
|
||||
|
||||
import orjson
|
||||
from sqlalchemy import select, desc, update
|
||||
from sqlalchemy import desc, select, update
|
||||
|
||||
from src.common.database.sqlalchemy_models import Messages, Images, get_db_session
|
||||
from src.common.database.sqlalchemy_database_api import get_db_session
|
||||
from src.common.database.sqlalchemy_models import Images, Messages
|
||||
from src.common.logger import get_logger
|
||||
|
||||
from .chat_stream import ChatStream
|
||||
from .message import MessageSending, MessageRecv
|
||||
from .message import MessageRecv, MessageSending
|
||||
|
||||
logger = get_logger("message_storage")
|
||||
|
||||
@@ -33,7 +34,7 @@ class MessageStorage:
|
||||
return []
|
||||
|
||||
@staticmethod
|
||||
async def store_message(message: Union[MessageSending, MessageRecv], chat_stream: ChatStream) -> None:
|
||||
async def store_message(message: MessageSending | MessageRecv, chat_stream: ChatStream) -> None:
|
||||
"""存储消息到数据库"""
|
||||
try:
|
||||
# 过滤敏感信息的正则模式
|
||||
@@ -292,6 +293,7 @@ class MessageStorage:
|
||||
try:
|
||||
async with get_db_session() as session:
|
||||
from sqlalchemy import select, update
|
||||
|
||||
from src.common.database.sqlalchemy_models import Messages
|
||||
|
||||
# 查找需要修复的记录:interest_value为0、null或很小的值
|
||||
|
||||
@@ -3,12 +3,11 @@ import traceback
|
||||
|
||||
from rich.traceback import install
|
||||
|
||||
from src.common.message.api import get_global_api
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.message_receive.message import MessageSending
|
||||
from src.chat.message_receive.storage import MessageStorage
|
||||
from src.chat.utils.utils import truncate_message
|
||||
from src.chat.utils.utils import calculate_typing_time
|
||||
from src.chat.utils.utils import calculate_typing_time, truncate_message
|
||||
from src.common.logger import get_logger
|
||||
from src.common.message.api import get_global_api
|
||||
|
||||
install(extra_lines=3)
|
||||
|
||||
@@ -27,7 +26,7 @@ async def send_message(message: MessageSending, show_log=True) -> bool:
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"发送消息 '{message_preview}' 发往平台'{message.message_info.platform}' 失败: {str(e)}")
|
||||
logger.error(f"发送消息 '{message_preview}' 发往平台'{message.message_info.platform}' 失败: {e!s}")
|
||||
traceback.print_exc()
|
||||
raise e # 重新抛出其他异常
|
||||
|
||||
|
||||
@@ -1,19 +1,17 @@
|
||||
import asyncio
|
||||
import traceback
|
||||
import time
|
||||
from typing import Dict, Optional, Type, Any, Tuple
|
||||
import traceback
|
||||
from typing import Any
|
||||
|
||||
|
||||
from src.chat.utils.timer_calculator import Timer
|
||||
from src.person_info.person_info import get_person_info_manager
|
||||
from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager
|
||||
from src.chat.utils.timer_calculator import Timer
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.plugin_system.core.component_registry import component_registry
|
||||
from src.plugin_system.base.component_types import ComponentType, ActionInfo
|
||||
from src.person_info.person_info import get_person_info_manager
|
||||
from src.plugin_system.apis import database_api, generator_api, message_api, send_api
|
||||
from src.plugin_system.base.base_action import BaseAction
|
||||
from src.plugin_system.apis import generator_api, database_api, send_api, message_api
|
||||
|
||||
from src.plugin_system.base.component_types import ActionInfo, ComponentType
|
||||
from src.plugin_system.core.component_registry import component_registry
|
||||
|
||||
logger = get_logger("action_manager")
|
||||
|
||||
@@ -29,7 +27,7 @@ class ChatterActionManager:
|
||||
"""初始化动作管理器"""
|
||||
|
||||
# 当前正在使用的动作集合,默认加载默认动作
|
||||
self._using_actions: Dict[str, ActionInfo] = {}
|
||||
self._using_actions: dict[str, ActionInfo] = {}
|
||||
|
||||
# 初始化时将默认动作加载到使用中的动作
|
||||
self._using_actions = component_registry.get_default_actions()
|
||||
@@ -48,8 +46,8 @@ class ChatterActionManager:
|
||||
chat_stream: ChatStream,
|
||||
log_prefix: str,
|
||||
shutting_down: bool = False,
|
||||
action_message: Optional[dict] = None,
|
||||
) -> Optional[BaseAction]:
|
||||
action_message: dict | None = None,
|
||||
) -> BaseAction | None:
|
||||
"""
|
||||
创建动作处理器实例
|
||||
|
||||
@@ -68,7 +66,7 @@ class ChatterActionManager:
|
||||
"""
|
||||
try:
|
||||
# 获取组件类 - 明确指定查询Action类型
|
||||
component_class: Type[BaseAction] = component_registry.get_component_class(
|
||||
component_class: type[BaseAction] = component_registry.get_component_class(
|
||||
action_name, ComponentType.ACTION
|
||||
) # type: ignore
|
||||
if not component_class:
|
||||
@@ -107,7 +105,7 @@ class ChatterActionManager:
|
||||
logger.error(traceback.format_exc())
|
||||
return None
|
||||
|
||||
def get_using_actions(self) -> Dict[str, ActionInfo]:
|
||||
def get_using_actions(self) -> dict[str, ActionInfo]:
|
||||
"""获取当前正在使用的动作集合"""
|
||||
return self._using_actions.copy()
|
||||
|
||||
@@ -140,10 +138,10 @@ class ChatterActionManager:
|
||||
self,
|
||||
action_name: str,
|
||||
chat_id: str,
|
||||
target_message: Optional[dict] = None,
|
||||
target_message: dict | None = None,
|
||||
reasoning: str = "",
|
||||
action_data: Optional[dict] = None,
|
||||
thinking_id: Optional[str] = None,
|
||||
action_data: dict | None = None,
|
||||
thinking_id: str | None = None,
|
||||
log_prefix: str = "",
|
||||
clear_unread_messages: bool = True,
|
||||
) -> Any:
|
||||
@@ -437,10 +435,10 @@ class ChatterActionManager:
|
||||
response_set,
|
||||
loop_start_time,
|
||||
action_message,
|
||||
cycle_timers: Dict[str, float],
|
||||
cycle_timers: dict[str, float],
|
||||
thinking_id,
|
||||
actions,
|
||||
) -> Tuple[Dict[str, Any], str, Dict[str, float]]:
|
||||
) -> tuple[dict[str, Any], str, dict[str, float]]:
|
||||
"""
|
||||
发送并存储回复信息
|
||||
|
||||
@@ -488,7 +486,7 @@ class ChatterActionManager:
|
||||
)
|
||||
|
||||
# 构建循环信息
|
||||
loop_info: Dict[str, Any] = {
|
||||
loop_info: dict[str, Any] = {
|
||||
"loop_plan_info": {
|
||||
"action_result": actions,
|
||||
},
|
||||
|
||||
@@ -1,17 +1,17 @@
|
||||
import random
|
||||
import asyncio
|
||||
import hashlib
|
||||
import random
|
||||
import time
|
||||
from typing import List, Any, Dict, TYPE_CHECKING, Tuple
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from src.chat.planner_actions.action_manager import ChatterActionManager
|
||||
from src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_before_timestamp_with_chat
|
||||
from src.common.data_models.message_manager_data_model import StreamContext
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config, model_config
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from src.common.data_models.message_manager_data_model import StreamContext
|
||||
from src.chat.planner_actions.action_manager import ChatterActionManager
|
||||
from src.chat.utils.chat_message_builder import get_raw_msg_before_timestamp_with_chat, build_readable_messages
|
||||
from src.plugin_system.base.component_types import ActionInfo, ActionActivationType
|
||||
from src.plugin_system.base.component_types import ActionActivationType, ActionInfo
|
||||
from src.plugin_system.core.global_announcement_manager import global_announcement_manager
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -59,18 +59,17 @@ class ActionModifier:
|
||||
"""
|
||||
logger.debug(f"{self.log_prefix}开始完整动作修改流程")
|
||||
|
||||
removals_s1: List[Tuple[str, str]] = []
|
||||
removals_s2: List[Tuple[str, str]] = []
|
||||
removals_s3: List[Tuple[str, str]] = []
|
||||
removals_s1: list[tuple[str, str]] = []
|
||||
removals_s2: list[tuple[str, str]] = []
|
||||
removals_s3: list[tuple[str, str]] = []
|
||||
|
||||
self.action_manager.restore_actions()
|
||||
all_actions = self.action_manager.get_using_actions()
|
||||
|
||||
# === 第0阶段:根据聊天类型过滤动作 ===
|
||||
from src.plugin_system.base.component_types import ChatType
|
||||
from src.plugin_system.core.component_registry import component_registry
|
||||
from src.plugin_system.base.component_types import ComponentType
|
||||
from src.chat.utils.utils import get_chat_type_and_target_info
|
||||
from src.plugin_system.base.component_types import ChatType, ComponentType
|
||||
from src.plugin_system.core.component_registry import component_registry
|
||||
|
||||
# 获取聊天类型
|
||||
is_group_chat, _ = get_chat_type_and_target_info(self.chat_id)
|
||||
@@ -167,8 +166,8 @@ class ActionModifier:
|
||||
|
||||
logger.info(f"{self.log_prefix} 当前可用动作: {available_actions_text}||移除: {removals_summary}")
|
||||
|
||||
def _check_action_associated_types(self, all_actions: Dict[str, ActionInfo], chat_context: StreamContext):
|
||||
type_mismatched_actions: List[Tuple[str, str]] = []
|
||||
def _check_action_associated_types(self, all_actions: dict[str, ActionInfo], chat_context: StreamContext):
|
||||
type_mismatched_actions: list[tuple[str, str]] = []
|
||||
for action_name, action_info in all_actions.items():
|
||||
if action_info.associated_types and not chat_context.check_types(action_info.associated_types):
|
||||
associated_types_str = ", ".join(action_info.associated_types)
|
||||
@@ -179,9 +178,9 @@ class ActionModifier:
|
||||
|
||||
async def _get_deactivated_actions_by_type(
|
||||
self,
|
||||
actions_with_info: Dict[str, ActionInfo],
|
||||
actions_with_info: dict[str, ActionInfo],
|
||||
chat_content: str = "",
|
||||
) -> List[tuple[str, str]]:
|
||||
) -> list[tuple[str, str]]:
|
||||
"""
|
||||
根据激活类型过滤,返回需要停用的动作列表及原因
|
||||
|
||||
@@ -254,9 +253,9 @@ class ActionModifier:
|
||||
|
||||
async def _process_llm_judge_actions_parallel(
|
||||
self,
|
||||
llm_judge_actions: Dict[str, Any],
|
||||
llm_judge_actions: dict[str, Any],
|
||||
chat_content: str = "",
|
||||
) -> Dict[str, bool]:
|
||||
) -> dict[str, bool]:
|
||||
"""
|
||||
并行处理LLM判定actions,支持智能缓存
|
||||
|
||||
|
||||
@@ -3,42 +3,41 @@
|
||||
使用重构后的统一Prompt系统替换原有的复杂提示词构建逻辑
|
||||
"""
|
||||
|
||||
import traceback
|
||||
import time
|
||||
import asyncio
|
||||
import random
|
||||
import re
|
||||
|
||||
from typing import List, Optional, Dict, Any, Tuple
|
||||
import time
|
||||
import traceback
|
||||
from datetime import datetime
|
||||
from src.mais4u.mai_think import mai_thinking_manager
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config, model_config
|
||||
from src.individuality.individuality import get_individuality
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.chat.message_receive.message import UserInfo, Seg, MessageRecv, MessageSending
|
||||
from typing import Any
|
||||
|
||||
from src.chat.express.expression_selector import expression_selector
|
||||
from src.chat.message_receive.chat_stream import ChatStream
|
||||
from src.chat.utils.memory_mappings import get_memory_type_chinese_label
|
||||
from src.chat.message_receive.message import MessageRecv, MessageSending, Seg, UserInfo
|
||||
from src.chat.message_receive.uni_message_sender import HeartFCSender
|
||||
from src.chat.utils.timer_calculator import Timer
|
||||
from src.chat.utils.utils import get_chat_type_and_target_info
|
||||
from src.chat.utils.prompt import Prompt, global_prompt_manager
|
||||
from src.chat.utils.chat_message_builder import (
|
||||
build_readable_messages,
|
||||
get_raw_msg_before_timestamp_with_chat,
|
||||
replace_user_references_sync,
|
||||
)
|
||||
from src.chat.express.expression_selector import expression_selector
|
||||
from src.chat.utils.memory_mappings import get_memory_type_chinese_label
|
||||
|
||||
# 导入新的统一Prompt系统
|
||||
from src.chat.utils.prompt import Prompt, PromptParameters, global_prompt_manager
|
||||
from src.chat.utils.timer_calculator import Timer
|
||||
from src.chat.utils.utils import get_chat_type_and_target_info
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config, model_config
|
||||
from src.individuality.individuality import get_individuality
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.mais4u.mai_think import mai_thinking_manager
|
||||
|
||||
# 旧记忆系统已被移除
|
||||
# 旧记忆系统已被移除
|
||||
from src.mood.mood_manager import mood_manager
|
||||
from src.person_info.person_info import get_person_info_manager
|
||||
from src.plugin_system.base.component_types import ActionInfo, EventType
|
||||
from src.plugin_system.apis import llm_api
|
||||
|
||||
# 导入新的统一Prompt系统
|
||||
from src.chat.utils.prompt import PromptParameters
|
||||
from src.plugin_system.base.component_types import ActionInfo, EventType
|
||||
|
||||
logger = get_logger("replyer")
|
||||
|
||||
@@ -259,15 +258,12 @@ class DefaultReplyer:
|
||||
async def generate_reply_with_context(
|
||||
self,
|
||||
extra_info: str = "",
|
||||
reply_reason: str = "",
|
||||
available_actions: Optional[Dict[str, ActionInfo]] = None,
|
||||
choosen_actions: Optional[List[Dict[str, Any]]] = None,
|
||||
available_actions: dict[str, ActionInfo] | None = None,
|
||||
enable_tool: bool = True,
|
||||
from_plugin: bool = True,
|
||||
stream_id: Optional[str] = None,
|
||||
reply_message: Optional[Dict[str, Any]] = None,
|
||||
read_mark: float = 0.0,
|
||||
) -> Tuple[bool, Optional[Dict[str, Any]], Optional[str]]:
|
||||
stream_id: str | None = None,
|
||||
reply_message: dict[str, Any] | None = None,
|
||||
) -> tuple[bool, dict[str, Any] | None, str | None]:
|
||||
# sourcery skip: merge-nested-ifs
|
||||
"""
|
||||
回复器 (Replier): 负责生成回复文本的核心逻辑。
|
||||
@@ -376,7 +372,7 @@ class DefaultReplyer:
|
||||
reason: str = "",
|
||||
reply_to: str = "",
|
||||
return_prompt: bool = False,
|
||||
) -> Tuple[bool, Optional[str], Optional[str]]:
|
||||
) -> tuple[bool, str | None, str | None]:
|
||||
"""
|
||||
表达器 (Expressor): 负责重写和优化回复文本。
|
||||
|
||||
@@ -740,8 +736,7 @@ class DefaultReplyer:
|
||||
logger.error(f"工具信息获取失败: {e}")
|
||||
return ""
|
||||
|
||||
@staticmethod
|
||||
def _parse_reply_target(target_message: str) -> Tuple[str, str]:
|
||||
def _parse_reply_target(self, target_message: str) -> tuple[str, str]:
|
||||
"""解析回复目标消息 - 使用共享工具"""
|
||||
from src.chat.utils.prompt import Prompt
|
||||
|
||||
@@ -750,8 +745,7 @@ class DefaultReplyer:
|
||||
return "未知用户", "(无消息内容)"
|
||||
return Prompt.parse_reply_target(target_message)
|
||||
|
||||
@staticmethod
|
||||
async def build_keywords_reaction_prompt(target: Optional[str]) -> str:
|
||||
async def build_keywords_reaction_prompt(self, target: str | None) -> str:
|
||||
"""构建关键词反应提示
|
||||
|
||||
Args:
|
||||
@@ -786,15 +780,14 @@ class DefaultReplyer:
|
||||
keywords_reaction_prompt += f"{reaction},"
|
||||
break
|
||||
except re.error as e:
|
||||
logger.error(f"正则表达式编译错误: {pattern_str}, 错误信息: {str(e)}")
|
||||
logger.error(f"正则表达式编译错误: {pattern_str}, 错误信息: {e!s}")
|
||||
continue
|
||||
except Exception as e:
|
||||
logger.error(f"关键词检测与反应时发生异常: {str(e)}", exc_info=True)
|
||||
logger.error(f"关键词检测与反应时发生异常: {e!s}", exc_info=True)
|
||||
|
||||
return keywords_reaction_prompt
|
||||
|
||||
@staticmethod
|
||||
async def _time_and_run_task(coroutine, name: str) -> Tuple[str, Any, float]:
|
||||
async def _time_and_run_task(self, coroutine, name: str) -> tuple[str, Any, float]:
|
||||
"""计时并运行异步任务的辅助函数
|
||||
|
||||
Args:
|
||||
@@ -811,8 +804,8 @@ class DefaultReplyer:
|
||||
return name, result, duration
|
||||
|
||||
async def build_s4u_chat_history_prompts(
|
||||
self, message_list_before_now: List[Dict[str, Any]], target_user_id: str, sender: str, chat_id: str
|
||||
) -> Tuple[str, str]:
|
||||
self, message_list_before_now: list[dict[str, Any]], target_user_id: str, sender: str, chat_id: str
|
||||
) -> tuple[str, str]:
|
||||
"""
|
||||
构建 s4u 风格的已读/未读历史消息 prompt
|
||||
|
||||
@@ -928,8 +921,8 @@ class DefaultReplyer:
|
||||
return await self._fallback_build_chat_history_prompts(message_list_before_now, target_user_id, sender)
|
||||
|
||||
async def _fallback_build_chat_history_prompts(
|
||||
self, message_list_before_now: List[Dict[str, Any]], target_user_id: str, sender: str
|
||||
) -> Tuple[str, str]:
|
||||
self, message_list_before_now: list[dict[str, Any]], target_user_id: str, sender: str
|
||||
) -> tuple[str, str]:
|
||||
"""
|
||||
回退的已读/未读历史消息构建方法
|
||||
"""
|
||||
@@ -1021,15 +1014,15 @@ class DefaultReplyer:
|
||||
|
||||
return read_history_prompt, unread_history_prompt
|
||||
|
||||
async def _get_interest_scores_for_messages(self, messages: List[dict]) -> dict[str, float]:
|
||||
async def _get_interest_scores_for_messages(self, messages: list[dict]) -> dict[str, float]:
|
||||
"""为消息获取兴趣度评分"""
|
||||
interest_scores = {}
|
||||
|
||||
try:
|
||||
from src.common.data_models.database_data_model import DatabaseMessages
|
||||
from src.plugins.built_in.affinity_flow_chatter.interest_scoring import (
|
||||
chatter_interest_scoring_system as interest_scoring_system,
|
||||
)
|
||||
from src.common.data_models.database_data_model import DatabaseMessages
|
||||
|
||||
# 转换消息格式
|
||||
db_messages = []
|
||||
@@ -1148,12 +1141,9 @@ class DefaultReplyer:
|
||||
async def build_prompt_reply_context(
|
||||
self,
|
||||
extra_info: str = "",
|
||||
reply_reason: str = "",
|
||||
available_actions: Optional[Dict[str, ActionInfo]] = None,
|
||||
choosen_actions: Optional[List[Dict[str, Any]]] = None,
|
||||
available_actions: dict[str, ActionInfo] | None = None,
|
||||
enable_tool: bool = True,
|
||||
reply_message: Optional[Dict[str, Any]] = None,
|
||||
read_mark: float = 0.0,
|
||||
reply_message: dict[str, Any] | None = None,
|
||||
) -> str:
|
||||
"""
|
||||
构建回复器上下文
|
||||
@@ -1521,7 +1511,7 @@ class DefaultReplyer:
|
||||
raw_reply: str,
|
||||
reason: str,
|
||||
reply_to: str,
|
||||
reply_message: Optional[Dict[str, Any]] = None,
|
||||
reply_message: dict[str, Any] | None = None,
|
||||
) -> str: # sourcery skip: merge-else-if-into-elif, remove-redundant-if
|
||||
await self._async_init()
|
||||
chat_stream = self.chat_stream
|
||||
@@ -1659,7 +1649,7 @@ class DefaultReplyer:
|
||||
is_emoji: bool,
|
||||
thinking_start_time: float,
|
||||
display_message: str,
|
||||
anchor_message: Optional[MessageRecv] = None,
|
||||
anchor_message: MessageRecv | None = None,
|
||||
) -> MessageSending:
|
||||
"""构建单个发送消息"""
|
||||
|
||||
@@ -1750,7 +1740,7 @@ class DefaultReplyer:
|
||||
logger.debug("从LPMM知识库获取知识失败,可能是从未导入过知识,返回空知识...")
|
||||
return ""
|
||||
except Exception as e:
|
||||
logger.error(f"获取知识库内容时发生异常: {str(e)}")
|
||||
logger.error(f"获取知识库内容时发生异常: {e!s}")
|
||||
return ""
|
||||
|
||||
async def build_relation_info(self, sender: str, target: str):
|
||||
@@ -1766,10 +1756,9 @@ class DefaultReplyer:
|
||||
|
||||
# 使用AFC关系追踪器获取关系信息
|
||||
try:
|
||||
from src.plugins.built_in.affinity_flow_chatter.relationship_tracker import ChatterRelationshipTracker
|
||||
|
||||
# 创建关系追踪器实例
|
||||
from src.plugins.built_in.affinity_flow_chatter.interest_scoring import chatter_interest_scoring_system
|
||||
from src.plugins.built_in.affinity_flow_chatter.relationship_tracker import ChatterRelationshipTracker
|
||||
|
||||
relationship_tracker = ChatterRelationshipTracker(chatter_interest_scoring_system)
|
||||
if relationship_tracker:
|
||||
@@ -1810,7 +1799,7 @@ class DefaultReplyer:
|
||||
logger.error(f"获取AFC关系信息失败: {e}")
|
||||
return f"你与{sender}是普通朋友关系。"
|
||||
|
||||
async def _store_chat_memory_async(self, reply_to: str, reply_message: Optional[Dict[str, Any]] = None):
|
||||
async def _store_chat_memory_async(self, reply_to: str, reply_message: dict[str, Any] | None = None):
|
||||
"""
|
||||
异步存储聊天记忆(从build_memory_block迁移而来)
|
||||
|
||||
|
||||
@@ -1,22 +1,20 @@
|
||||
from typing import Dict, Optional
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager
|
||||
from src.chat.replyer.default_generator import DefaultReplyer
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("ReplyerManager")
|
||||
|
||||
|
||||
class ReplyerManager:
|
||||
def __init__(self):
|
||||
self._repliers: Dict[str, DefaultReplyer] = {}
|
||||
self._repliers: dict[str, DefaultReplyer] = {}
|
||||
|
||||
def get_replyer(
|
||||
self,
|
||||
chat_stream: Optional[ChatStream] = None,
|
||||
chat_id: Optional[str] = None,
|
||||
chat_stream: ChatStream | None = None,
|
||||
chat_id: str | None = None,
|
||||
request_type: str = "replyer",
|
||||
) -> Optional[DefaultReplyer]:
|
||||
) -> DefaultReplyer | None:
|
||||
"""
|
||||
获取或创建回复器实例。
|
||||
|
||||
|
||||
@@ -1,18 +1,19 @@
|
||||
import time # 导入 time 模块以获取当前时间
|
||||
import random
|
||||
import re
|
||||
import time # 导入 time 模块以获取当前时间
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
from typing import List, Dict, Any, Tuple, Optional, Callable
|
||||
from rich.traceback import install
|
||||
from sqlalchemy import and_, select
|
||||
|
||||
from src.config.config import global_config
|
||||
from src.common.message_repository import find_messages, count_messages
|
||||
from src.common.database.sqlalchemy_models import ActionRecords, Images
|
||||
from src.person_info.person_info import PersonInfoManager, get_person_info_manager
|
||||
from src.chat.utils.utils import translate_timestamp_to_human_readable, assign_message_ids
|
||||
from src.chat.utils.utils import assign_message_ids, translate_timestamp_to_human_readable
|
||||
from src.common.database.sqlalchemy_database_api import get_db_session
|
||||
from sqlalchemy import select, and_
|
||||
from src.common.database.sqlalchemy_models import ActionRecords, Images
|
||||
from src.common.logger import get_logger
|
||||
from src.common.message_repository import count_messages, find_messages
|
||||
from src.config.config import global_config
|
||||
from src.person_info.person_info import PersonInfoManager, get_person_info_manager
|
||||
|
||||
logger = get_logger("chat_message_builder")
|
||||
|
||||
@@ -22,7 +23,7 @@ install(extra_lines=3)
|
||||
def replace_user_references_sync(
|
||||
content: str,
|
||||
platform: str,
|
||||
name_resolver: Optional[Callable[[str, str], str]] = None,
|
||||
name_resolver: Callable[[str, str], str] | None = None,
|
||||
replace_bot_name: bool = True,
|
||||
) -> str:
|
||||
"""
|
||||
@@ -98,7 +99,7 @@ def replace_user_references_sync(
|
||||
async def replace_user_references_async(
|
||||
content: str,
|
||||
platform: str,
|
||||
name_resolver: Optional[Callable[[str, str], Any]] = None,
|
||||
name_resolver: Callable[[str, str], Any] | None = None,
|
||||
replace_bot_name: bool = True,
|
||||
) -> str:
|
||||
"""
|
||||
@@ -171,7 +172,7 @@ async def replace_user_references_async(
|
||||
|
||||
async def get_raw_msg_by_timestamp(
|
||||
timestamp_start: float, timestamp_end: float, limit: int = 0, limit_mode: str = "latest"
|
||||
) -> List[Dict[str, Any]]:
|
||||
) -> list[dict[str, Any]]:
|
||||
"""
|
||||
获取从指定时间戳到指定时间戳的消息,按时间升序排序,返回消息列表
|
||||
limit: 限制返回的消息数量,0为不限制
|
||||
@@ -191,7 +192,7 @@ async def get_raw_msg_by_timestamp_with_chat(
|
||||
limit_mode: str = "latest",
|
||||
filter_bot=False,
|
||||
filter_command=False,
|
||||
) -> List[Dict[str, Any]]:
|
||||
) -> list[dict[str, Any]]:
|
||||
"""获取在特定聊天从指定时间戳到指定时间戳的消息,按时间升序排序,返回消息列表
|
||||
limit: 限制返回的消息数量,0为不限制
|
||||
limit_mode: 当 limit > 0 时生效。 'earliest' 表示获取最早的记录, 'latest' 表示获取最新的记录。默认为 'latest'。
|
||||
@@ -217,7 +218,7 @@ async def get_raw_msg_by_timestamp_with_chat_inclusive(
|
||||
limit: int = 0,
|
||||
limit_mode: str = "latest",
|
||||
filter_bot=False,
|
||||
) -> List[Dict[str, Any]]:
|
||||
) -> list[dict[str, Any]]:
|
||||
"""获取在特定聊天从指定时间戳到指定时间戳的消息(包含边界),按时间升序排序,返回消息列表
|
||||
limit: 限制返回的消息数量,0为不限制
|
||||
limit_mode: 当 limit > 0 时生效。 'earliest' 表示获取最早的记录, 'latest' 表示获取最新的记录。默认为 'latest'。
|
||||
@@ -236,10 +237,10 @@ async def get_raw_msg_by_timestamp_with_chat_users(
|
||||
chat_id: str,
|
||||
timestamp_start: float,
|
||||
timestamp_end: float,
|
||||
person_ids: List[str],
|
||||
person_ids: list[str],
|
||||
limit: int = 0,
|
||||
limit_mode: str = "latest",
|
||||
) -> List[Dict[str, Any]]:
|
||||
) -> list[dict[str, Any]]:
|
||||
"""获取某些特定用户在特定聊天从指定时间戳到指定时间戳的消息,按时间升序排序,返回消息列表
|
||||
limit: 限制返回的消息数量,0为不限制
|
||||
limit_mode: 当 limit > 0 时生效。 'earliest' 表示获取最早的记录, 'latest' 表示获取最新的记录。默认为 'latest'。
|
||||
@@ -260,7 +261,7 @@ async def get_actions_by_timestamp_with_chat(
|
||||
timestamp_end: float = time.time(),
|
||||
limit: int = 0,
|
||||
limit_mode: str = "latest",
|
||||
) -> List[Dict[str, Any]]:
|
||||
) -> list[dict[str, Any]]:
|
||||
"""获取在特定聊天从指定时间戳到指定时间戳的动作记录,按时间升序排序,返回动作记录列表"""
|
||||
from src.common.logger import get_logger
|
||||
|
||||
@@ -369,7 +370,7 @@ async def get_actions_by_timestamp_with_chat(
|
||||
|
||||
async def get_actions_by_timestamp_with_chat_inclusive(
|
||||
chat_id: str, timestamp_start: float, timestamp_end: float, limit: int = 0, limit_mode: str = "latest"
|
||||
) -> List[Dict[str, Any]]:
|
||||
) -> list[dict[str, Any]]:
|
||||
"""获取在特定聊天从指定时间戳到指定时间戳的动作记录(包含边界),按时间升序排序,返回动作记录列表"""
|
||||
async with get_db_session() as session:
|
||||
if limit > 0:
|
||||
@@ -420,7 +421,7 @@ async def get_actions_by_timestamp_with_chat_inclusive(
|
||||
|
||||
async def get_raw_msg_by_timestamp_random(
|
||||
timestamp_start: float, timestamp_end: float, limit: int = 0, limit_mode: str = "latest"
|
||||
) -> List[Dict[str, Any]]:
|
||||
) -> list[dict[str, Any]]:
|
||||
"""
|
||||
先在范围时间戳内随机选择一条消息,取得消息的chat_id,然后根据chat_id获取该聊天在指定时间戳范围内的消息
|
||||
"""
|
||||
@@ -438,7 +439,7 @@ async def get_raw_msg_by_timestamp_random(
|
||||
|
||||
async def get_raw_msg_by_timestamp_with_users(
|
||||
timestamp_start: float, timestamp_end: float, person_ids: list, limit: int = 0, limit_mode: str = "latest"
|
||||
) -> List[Dict[str, Any]]:
|
||||
) -> list[dict[str, Any]]:
|
||||
"""获取某些特定用户在 *所有聊天* 中从指定时间戳到指定时间戳的消息,按时间升序排序,返回消息列表
|
||||
limit: 限制返回的消息数量,0为不限制
|
||||
limit_mode: 当 limit > 0 时生效。 'earliest' 表示获取最早的记录, 'latest' 表示获取最新的记录。默认为 'latest'。
|
||||
@@ -449,7 +450,7 @@ async def get_raw_msg_by_timestamp_with_users(
|
||||
return await find_messages(message_filter=filter_query, sort=sort_order, limit=limit, limit_mode=limit_mode)
|
||||
|
||||
|
||||
async def get_raw_msg_before_timestamp(timestamp: float, limit: int = 0) -> List[Dict[str, Any]]:
|
||||
async def get_raw_msg_before_timestamp(timestamp: float, limit: int = 0) -> list[dict[str, Any]]:
|
||||
"""获取指定时间戳之前的消息,按时间升序排序,返回消息列表
|
||||
limit: 限制返回的消息数量,0为不限制
|
||||
"""
|
||||
@@ -460,7 +461,7 @@ async def get_raw_msg_before_timestamp(timestamp: float, limit: int = 0) -> List
|
||||
|
||||
async def get_raw_msg_before_timestamp_with_chat(
|
||||
chat_id: str, timestamp: float, limit: int = 0
|
||||
) -> List[Dict[str, Any]]:
|
||||
) -> list[dict[str, Any]]:
|
||||
"""获取指定时间戳之前的消息,按时间升序排序,返回消息列表
|
||||
limit: 限制返回的消息数量,0为不限制
|
||||
"""
|
||||
@@ -471,7 +472,7 @@ async def get_raw_msg_before_timestamp_with_chat(
|
||||
|
||||
async def get_raw_msg_before_timestamp_with_users(
|
||||
timestamp: float, person_ids: list, limit: int = 0
|
||||
) -> List[Dict[str, Any]]:
|
||||
) -> list[dict[str, Any]]:
|
||||
"""获取指定时间戳之前的消息,按时间升序排序,返回消息列表
|
||||
limit: 限制返回的消息数量,0为不限制
|
||||
"""
|
||||
@@ -480,9 +481,7 @@ async def get_raw_msg_before_timestamp_with_users(
|
||||
return await find_messages(message_filter=filter_query, sort=sort_order, limit=limit)
|
||||
|
||||
|
||||
async def num_new_messages_since(
|
||||
chat_id: str, timestamp_start: float = 0.0, timestamp_end: Optional[float] = None
|
||||
) -> int:
|
||||
async def num_new_messages_since(chat_id: str, timestamp_start: float = 0.0, timestamp_end: float | None = None) -> int:
|
||||
"""
|
||||
检查特定聊天从 timestamp_start (不含) 到 timestamp_end (不含) 之间有多少新消息。
|
||||
如果 timestamp_end 为 None,则检查从 timestamp_start (不含) 到当前时间的消息。
|
||||
@@ -514,17 +513,16 @@ async def num_new_messages_since_with_users(
|
||||
|
||||
|
||||
async def _build_readable_messages_internal(
|
||||
messages: List[Dict[str, Any]],
|
||||
messages: list[dict[str, Any]],
|
||||
replace_bot_name: bool = True,
|
||||
merge_messages: bool = False,
|
||||
timestamp_mode: str = "relative",
|
||||
truncate: bool = False,
|
||||
pic_id_mapping: Optional[Dict[str, str]] = None,
|
||||
pic_id_mapping: dict[str, str] | None = None,
|
||||
pic_counter: int = 1,
|
||||
show_pic: bool = True,
|
||||
message_id_list: Optional[List[Dict[str, Any]]] = None,
|
||||
read_mark: float = 0.0,
|
||||
) -> Tuple[str, List[Tuple[float, str, str]], Dict[str, str], int]:
|
||||
message_id_list: list[dict[str, Any]] | None = None,
|
||||
) -> tuple[str, list[tuple[float, str, str]], dict[str, str], int]:
|
||||
"""
|
||||
内部辅助函数,构建可读消息字符串和原始消息详情列表。
|
||||
|
||||
@@ -543,7 +541,7 @@ async def _build_readable_messages_internal(
|
||||
if not messages:
|
||||
return "", [], pic_id_mapping or {}, pic_counter
|
||||
|
||||
message_details_raw: List[Tuple[float, str, str, bool]] = []
|
||||
message_details_raw: list[tuple[float, str, str, bool]] = []
|
||||
|
||||
# 使用传入的映射字典,如果没有则创建新的
|
||||
if pic_id_mapping is None:
|
||||
@@ -669,7 +667,7 @@ async def _build_readable_messages_internal(
|
||||
message_details_with_flags.append((timestamp, name, content, is_action))
|
||||
|
||||
# 应用截断逻辑 (如果 truncate 为 True)
|
||||
message_details: List[Tuple[float, str, str, bool]] = []
|
||||
message_details: list[tuple[float, str, str, bool]] = []
|
||||
n_messages = len(message_details_with_flags)
|
||||
if truncate and n_messages > 0:
|
||||
for i, (timestamp, name, content, is_action) in enumerate(message_details_with_flags):
|
||||
@@ -811,7 +809,7 @@ async def _build_readable_messages_internal(
|
||||
)
|
||||
|
||||
|
||||
async def build_pic_mapping_info(pic_id_mapping: Dict[str, str]) -> str:
|
||||
async def build_pic_mapping_info(pic_id_mapping: dict[str, str]) -> str:
|
||||
# sourcery skip: use-contextlib-suppress
|
||||
"""
|
||||
构建图片映射信息字符串,显示图片的具体描述内容
|
||||
@@ -849,7 +847,7 @@ async def build_pic_mapping_info(pic_id_mapping: Dict[str, str]) -> str:
|
||||
return "\n".join(mapping_lines)
|
||||
|
||||
|
||||
def build_readable_actions(actions: List[Dict[str, Any]]) -> str:
|
||||
def build_readable_actions(actions: list[dict[str, Any]]) -> str:
|
||||
"""
|
||||
将动作列表转换为可读的文本格式。
|
||||
格式: 在()分钟前,你使用了(action_name),具体内容是:(action_prompt_display)
|
||||
@@ -924,12 +922,12 @@ def build_readable_actions(actions: List[Dict[str, Any]]) -> str:
|
||||
|
||||
|
||||
async def build_readable_messages_with_list(
|
||||
messages: List[Dict[str, Any]],
|
||||
messages: list[dict[str, Any]],
|
||||
replace_bot_name: bool = True,
|
||||
merge_messages: bool = False,
|
||||
timestamp_mode: str = "relative",
|
||||
truncate: bool = False,
|
||||
) -> Tuple[str, List[Tuple[float, str, str]]]:
|
||||
) -> tuple[str, list[tuple[float, str, str]]]:
|
||||
"""
|
||||
将消息列表转换为可读的文本格式,并返回原始(时间戳, 昵称, 内容)列表。
|
||||
允许通过参数控制格式化行为。
|
||||
@@ -945,7 +943,7 @@ async def build_readable_messages_with_list(
|
||||
|
||||
|
||||
async def build_readable_messages_with_id(
|
||||
messages: List[Dict[str, Any]],
|
||||
messages: list[dict[str, Any]],
|
||||
replace_bot_name: bool = True,
|
||||
merge_messages: bool = False,
|
||||
timestamp_mode: str = "relative",
|
||||
@@ -953,7 +951,7 @@ async def build_readable_messages_with_id(
|
||||
truncate: bool = False,
|
||||
show_actions: bool = False,
|
||||
show_pic: bool = True,
|
||||
) -> Tuple[str, List[Dict[str, Any]]]:
|
||||
) -> tuple[str, list[dict[str, Any]]]:
|
||||
"""
|
||||
将消息列表转换为可读的文本格式,并返回原始(时间戳, 昵称, 内容)列表。
|
||||
允许通过参数控制格式化行为。
|
||||
@@ -982,7 +980,7 @@ async def build_readable_messages_with_id(
|
||||
|
||||
|
||||
async def build_readable_messages(
|
||||
messages: List[Dict[str, Any]],
|
||||
messages: list[dict[str, Any]],
|
||||
replace_bot_name: bool = True,
|
||||
merge_messages: bool = False,
|
||||
timestamp_mode: str = "relative",
|
||||
@@ -990,7 +988,7 @@ async def build_readable_messages(
|
||||
truncate: bool = False,
|
||||
show_actions: bool = True,
|
||||
show_pic: bool = True,
|
||||
message_id_list: Optional[List[Dict[str, Any]]] = None,
|
||||
message_id_list: list[dict[str, Any]] | None = None,
|
||||
) -> str: # sourcery skip: extract-method
|
||||
"""
|
||||
将消息列表转换为可读的文本格式。
|
||||
@@ -1150,7 +1148,7 @@ async def build_readable_messages(
|
||||
return "".join(result_parts)
|
||||
|
||||
|
||||
async def build_anonymous_messages(messages: List[Dict[str, Any]]) -> str:
|
||||
async def build_anonymous_messages(messages: list[dict[str, Any]]) -> str:
|
||||
"""
|
||||
构建匿名可读消息,将不同人的名称转为唯一占位符(A、B、C...),bot自己用SELF。
|
||||
处理 回复<aaa:bbb> 和 @<aaa:bbb> 字段,将bbb映射为匿名占位符。
|
||||
@@ -1263,7 +1261,7 @@ async def build_anonymous_messages(messages: List[Dict[str, Any]]) -> str:
|
||||
return formatted_string
|
||||
|
||||
|
||||
async def get_person_id_list(messages: List[Dict[str, Any]]) -> List[str]:
|
||||
async def get_person_id_list(messages: list[dict[str, Any]]) -> list[str]:
|
||||
"""
|
||||
从消息列表中提取不重复的 person_id 列表 (忽略机器人自身)。
|
||||
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
记忆系统相关的映射表和工具函数
|
||||
提供记忆类型、置信度、重要性等的中文标签映射
|
||||
|
||||
@@ -3,19 +3,20 @@
|
||||
将原有的Prompt类和SmartPrompt功能整合为一个真正的Prompt类
|
||||
"""
|
||||
|
||||
import re
|
||||
import asyncio
|
||||
import time
|
||||
import contextvars
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Dict, Any, Optional, List, Literal, Tuple
|
||||
import re
|
||||
import time
|
||||
from contextlib import asynccontextmanager
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Literal, Optional
|
||||
|
||||
from rich.traceback import install
|
||||
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from src.chat.utils.chat_message_builder import build_readable_messages
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.chat.utils.chat_message_builder import build_readable_messages
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from src.person_info.person_info import get_person_info_manager
|
||||
|
||||
install(extra_lines=3)
|
||||
@@ -50,11 +51,11 @@ class PromptParameters:
|
||||
debug_mode: bool = False
|
||||
|
||||
# 聊天历史和上下文
|
||||
chat_target_info: Optional[Dict[str, Any]] = None
|
||||
message_list_before_now_long: List[Dict[str, Any]] = field(default_factory=list)
|
||||
message_list_before_short: List[Dict[str, Any]] = field(default_factory=list)
|
||||
chat_target_info: dict[str, Any] | None = None
|
||||
message_list_before_now_long: list[dict[str, Any]] = field(default_factory=list)
|
||||
message_list_before_short: list[dict[str, Any]] = field(default_factory=list)
|
||||
chat_talking_prompt_short: str = ""
|
||||
target_user_info: Optional[Dict[str, Any]] = None
|
||||
target_user_info: dict[str, Any] | None = None
|
||||
|
||||
# 已构建的内容块
|
||||
expression_habits_block: str = ""
|
||||
@@ -77,12 +78,12 @@ class PromptParameters:
|
||||
action_descriptions: str = ""
|
||||
|
||||
# 可用动作信息
|
||||
available_actions: Optional[Dict[str, Any]] = None
|
||||
available_actions: dict[str, Any] | None = None
|
||||
|
||||
# 动态生成的聊天场景提示
|
||||
chat_scene: str = ""
|
||||
|
||||
def validate(self) -> List[str]:
|
||||
def validate(self) -> list[str]:
|
||||
"""参数验证"""
|
||||
errors = []
|
||||
if not self.chat_id:
|
||||
@@ -98,22 +99,22 @@ class PromptContext:
|
||||
"""提示词上下文管理器"""
|
||||
|
||||
def __init__(self):
|
||||
self._context_prompts: Dict[str, Dict[str, "Prompt"]] = {}
|
||||
self._context_prompts: dict[str, dict[str, "Prompt"]] = {}
|
||||
self._current_context_var = contextvars.ContextVar("current_context", default=None)
|
||||
self._context_lock = asyncio.Lock()
|
||||
|
||||
@property
|
||||
def _current_context(self) -> Optional[str]:
|
||||
def _current_context(self) -> str | None:
|
||||
"""获取当前协程的上下文ID"""
|
||||
return self._current_context_var.get()
|
||||
|
||||
@_current_context.setter
|
||||
def _current_context(self, value: Optional[str]):
|
||||
def _current_context(self, value: str | None):
|
||||
"""设置当前协程的上下文ID"""
|
||||
self._current_context_var.set(value) # type: ignore
|
||||
|
||||
@asynccontextmanager
|
||||
async def async_scope(self, context_id: Optional[str] = None):
|
||||
async def async_scope(self, context_id: str | None = None):
|
||||
"""创建一个异步的临时提示模板作用域"""
|
||||
if context_id is not None:
|
||||
try:
|
||||
@@ -159,7 +160,7 @@ class PromptContext:
|
||||
return self._context_prompts[current_context][name]
|
||||
return None
|
||||
|
||||
async def register_async(self, prompt: "Prompt", context_id: Optional[str] = None) -> None:
|
||||
async def register_async(self, prompt: "Prompt", context_id: str | None = None) -> None:
|
||||
"""异步注册提示模板到指定作用域"""
|
||||
async with self._context_lock:
|
||||
if target_context := context_id or self._current_context:
|
||||
@@ -177,7 +178,7 @@ class PromptManager:
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
@asynccontextmanager
|
||||
async def async_message_scope(self, message_id: Optional[str] = None):
|
||||
async def async_message_scope(self, message_id: str | None = None):
|
||||
"""为消息处理创建异步临时作用域"""
|
||||
async with self._context.async_scope(message_id):
|
||||
yield self
|
||||
@@ -240,8 +241,8 @@ class Prompt:
|
||||
def __init__(
|
||||
self,
|
||||
template: str,
|
||||
name: Optional[str] = None,
|
||||
parameters: Optional[PromptParameters] = None,
|
||||
name: str | None = None,
|
||||
parameters: PromptParameters | None = None,
|
||||
should_register: bool = True,
|
||||
):
|
||||
"""
|
||||
@@ -281,7 +282,7 @@ class Prompt:
|
||||
"""将临时标记还原为实际的花括号字符"""
|
||||
return template.replace(Prompt._TEMP_LEFT_BRACE, "{").replace(Prompt._TEMP_RIGHT_BRACE, "}")
|
||||
|
||||
def _parse_template_args(self, template: str) -> List[str]:
|
||||
def _parse_template_args(self, template: str) -> list[str]:
|
||||
"""解析模板参数"""
|
||||
template_args = []
|
||||
processed_template = self._process_escaped_braces(template)
|
||||
@@ -325,7 +326,7 @@ class Prompt:
|
||||
logger.error(f"构建Prompt失败: {e}")
|
||||
raise RuntimeError(f"构建Prompt失败: {e}") from e
|
||||
|
||||
async def _build_context_data(self) -> Dict[str, Any]:
|
||||
async def _build_context_data(self) -> dict[str, Any]:
|
||||
"""构建智能上下文数据"""
|
||||
# 并行执行所有构建任务
|
||||
start_time = time.time()
|
||||
@@ -405,7 +406,7 @@ class Prompt:
|
||||
default_result = self._get_default_result_for_task(task_name)
|
||||
results.append(default_result)
|
||||
except Exception as e:
|
||||
logger.error(f"构建任务{task_name}失败: {str(e)}")
|
||||
logger.error(f"构建任务{task_name}失败: {e!s}")
|
||||
default_result = self._get_default_result_for_task(task_name)
|
||||
results.append(default_result)
|
||||
|
||||
@@ -415,7 +416,7 @@ class Prompt:
|
||||
task_name = task_names[i] if i < len(task_names) else f"task_{i}"
|
||||
|
||||
if isinstance(result, Exception):
|
||||
logger.error(f"构建任务{task_name}失败: {str(result)}")
|
||||
logger.error(f"构建任务{task_name}失败: {result!s}")
|
||||
elif isinstance(result, dict):
|
||||
context_data.update(result)
|
||||
|
||||
@@ -457,7 +458,7 @@ class Prompt:
|
||||
|
||||
return context_data
|
||||
|
||||
async def _build_s4u_chat_context(self, context_data: Dict[str, Any]) -> None:
|
||||
async def _build_s4u_chat_context(self, context_data: dict[str, Any]) -> None:
|
||||
"""构建S4U模式的聊天上下文"""
|
||||
if not self.parameters.message_list_before_now_long:
|
||||
return
|
||||
@@ -472,7 +473,7 @@ class Prompt:
|
||||
context_data["read_history_prompt"] = read_history_prompt
|
||||
context_data["unread_history_prompt"] = unread_history_prompt
|
||||
|
||||
async def _build_normal_chat_context(self, context_data: Dict[str, Any]) -> None:
|
||||
async def _build_normal_chat_context(self, context_data: dict[str, Any]) -> None:
|
||||
"""构建normal模式的聊天上下文"""
|
||||
if not self.parameters.chat_talking_prompt_short:
|
||||
return
|
||||
@@ -481,8 +482,8 @@ class Prompt:
|
||||
{self.parameters.chat_talking_prompt_short}"""
|
||||
|
||||
async def _build_s4u_chat_history_prompts(
|
||||
self, message_list_before_now: List[Dict[str, Any]], target_user_id: str, sender: str, chat_id: str
|
||||
) -> Tuple[str, str]:
|
||||
self, message_list_before_now: list[dict[str, Any]], target_user_id: str, sender: str, chat_id: str
|
||||
) -> tuple[str, str]:
|
||||
"""构建S4U风格的已读/未读历史消息prompt"""
|
||||
try:
|
||||
# 动态导入default_generator以避免循环导入
|
||||
@@ -496,7 +497,7 @@ class Prompt:
|
||||
except Exception as e:
|
||||
logger.error(f"构建S4U历史消息prompt失败: {e}")
|
||||
|
||||
async def _build_expression_habits(self) -> Dict[str, Any]:
|
||||
async def _build_expression_habits(self) -> dict[str, Any]:
|
||||
"""构建表达习惯"""
|
||||
use_expression, _, _ = global_config.expression.get_expression_config_for_chat(self.parameters.chat_id)
|
||||
if not use_expression:
|
||||
@@ -537,7 +538,7 @@ class Prompt:
|
||||
logger.error(f"构建表达习惯失败: {e}")
|
||||
return {"expression_habits_block": ""}
|
||||
|
||||
async def _build_memory_block(self) -> Dict[str, Any]:
|
||||
async def _build_memory_block(self) -> dict[str, Any]:
|
||||
"""构建记忆块"""
|
||||
if not global_config.memory.enable_memory:
|
||||
return {"memory_block": ""}
|
||||
@@ -657,7 +658,7 @@ class Prompt:
|
||||
logger.error(f"构建记忆块失败: {e}")
|
||||
return {"memory_block": ""}
|
||||
|
||||
async def _build_memory_block_fast(self) -> Dict[str, Any]:
|
||||
async def _build_memory_block_fast(self) -> dict[str, Any]:
|
||||
"""快速构建记忆块(简化版本,用于未预构建时的后备方案)"""
|
||||
if not global_config.memory.enable_memory:
|
||||
return {"memory_block": ""}
|
||||
@@ -681,7 +682,7 @@ class Prompt:
|
||||
logger.warning(f"快速构建记忆块失败: {e}")
|
||||
return {"memory_block": ""}
|
||||
|
||||
async def _build_relation_info(self) -> Dict[str, Any]:
|
||||
async def _build_relation_info(self) -> dict[str, Any]:
|
||||
"""构建关系信息"""
|
||||
try:
|
||||
relation_info = await Prompt.build_relation_info(self.parameters.chat_id, self.parameters.reply_to)
|
||||
@@ -690,7 +691,7 @@ class Prompt:
|
||||
logger.error(f"构建关系信息失败: {e}")
|
||||
return {"relation_info_block": ""}
|
||||
|
||||
async def _build_tool_info(self) -> Dict[str, Any]:
|
||||
async def _build_tool_info(self) -> dict[str, Any]:
|
||||
"""构建工具信息"""
|
||||
if not global_config.tool.enable_tool:
|
||||
return {"tool_info_block": ""}
|
||||
@@ -738,7 +739,7 @@ class Prompt:
|
||||
logger.error(f"构建工具信息失败: {e}")
|
||||
return {"tool_info_block": ""}
|
||||
|
||||
async def _build_knowledge_info(self) -> Dict[str, Any]:
|
||||
async def _build_knowledge_info(self) -> dict[str, Any]:
|
||||
"""构建知识信息"""
|
||||
if not global_config.lpmm_knowledge.enable:
|
||||
return {"knowledge_prompt": ""}
|
||||
@@ -787,7 +788,7 @@ class Prompt:
|
||||
logger.error(f"构建知识信息失败: {e}")
|
||||
return {"knowledge_prompt": ""}
|
||||
|
||||
async def _build_cross_context(self) -> Dict[str, Any]:
|
||||
async def _build_cross_context(self) -> dict[str, Any]:
|
||||
"""构建跨群上下文"""
|
||||
try:
|
||||
cross_context = await Prompt.build_cross_context(
|
||||
@@ -798,7 +799,7 @@ class Prompt:
|
||||
logger.error(f"构建跨群上下文失败: {e}")
|
||||
return {"cross_context_block": ""}
|
||||
|
||||
async def _format_with_context(self, context_data: Dict[str, Any]) -> str:
|
||||
async def _format_with_context(self, context_data: dict[str, Any]) -> str:
|
||||
"""使用上下文数据格式化模板"""
|
||||
if self.parameters.prompt_mode == "s4u":
|
||||
params = self._prepare_s4u_params(context_data)
|
||||
@@ -809,7 +810,7 @@ class Prompt:
|
||||
|
||||
return await global_prompt_manager.format_prompt(self.name, **params) if self.name else self.format(**params)
|
||||
|
||||
def _prepare_s4u_params(self, context_data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
def _prepare_s4u_params(self, context_data: dict[str, Any]) -> dict[str, Any]:
|
||||
"""准备S4U模式的参数"""
|
||||
return {
|
||||
**context_data,
|
||||
@@ -838,7 +839,7 @@ class Prompt:
|
||||
or "你正在一个QQ群里聊天,你需要理解整个群的聊天动态和话题走向,并做出自然的回应。",
|
||||
}
|
||||
|
||||
def _prepare_normal_params(self, context_data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
def _prepare_normal_params(self, context_data: dict[str, Any]) -> dict[str, Any]:
|
||||
"""准备Normal模式的参数"""
|
||||
return {
|
||||
**context_data,
|
||||
@@ -866,7 +867,7 @@ class Prompt:
|
||||
or "你正在一个QQ群里聊天,你需要理解整个群的聊天动态和话题走向,并做出自然的回应。",
|
||||
}
|
||||
|
||||
def _prepare_default_params(self, context_data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
def _prepare_default_params(self, context_data: dict[str, Any]) -> dict[str, Any]:
|
||||
"""准备默认模式的参数"""
|
||||
return {
|
||||
"expression_habits_block": context_data.get("expression_habits_block", ""),
|
||||
@@ -909,7 +910,7 @@ class Prompt:
|
||||
result = self._restore_escaped_braces(processed_template)
|
||||
return result
|
||||
except (IndexError, KeyError) as e:
|
||||
raise ValueError(f"格式化模板失败: {self.template}, args={args}, kwargs={kwargs} {str(e)}") from e
|
||||
raise ValueError(f"格式化模板失败: {self.template}, args={args}, kwargs={kwargs} {e!s}") from e
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""返回格式化后的结果或原始模板"""
|
||||
@@ -926,7 +927,7 @@ class Prompt:
|
||||
# =============================================================================
|
||||
|
||||
@staticmethod
|
||||
def parse_reply_target(target_message: str) -> Tuple[str, str]:
|
||||
def parse_reply_target(target_message: str) -> tuple[str, str]:
|
||||
"""
|
||||
解析回复目标消息 - 统一实现
|
||||
|
||||
@@ -985,7 +986,7 @@ class Prompt:
|
||||
|
||||
return await relationship_fetcher.build_relation_info(person_id, points_num=5)
|
||||
|
||||
def _get_default_result_for_task(self, task_name: str) -> Dict[str, Any]:
|
||||
def _get_default_result_for_task(self, task_name: str) -> dict[str, Any]:
|
||||
"""
|
||||
为超时的任务提供默认结果
|
||||
|
||||
@@ -1012,7 +1013,7 @@ class Prompt:
|
||||
return {}
|
||||
|
||||
@staticmethod
|
||||
async def build_cross_context(chat_id: str, prompt_mode: str, target_user_info: Optional[Dict[str, Any]]) -> str:
|
||||
async def build_cross_context(chat_id: str, prompt_mode: str, target_user_info: dict[str, Any] | None) -> str:
|
||||
"""
|
||||
构建跨群聊上下文 - 统一实现
|
||||
|
||||
@@ -1075,7 +1076,7 @@ class Prompt:
|
||||
|
||||
# 工厂函数
|
||||
def create_prompt(
|
||||
template: str, name: Optional[str] = None, parameters: Optional[PromptParameters] = None, **kwargs
|
||||
template: str, name: str | None = None, parameters: PromptParameters | None = None, **kwargs
|
||||
) -> Prompt:
|
||||
"""快速创建Prompt实例的工厂函数"""
|
||||
if parameters is None:
|
||||
@@ -1084,7 +1085,7 @@ def create_prompt(
|
||||
|
||||
|
||||
async def create_prompt_async(
|
||||
template: str, name: Optional[str] = None, parameters: Optional[PromptParameters] = None, **kwargs
|
||||
template: str, name: str | None = None, parameters: PromptParameters | None = None, **kwargs
|
||||
) -> Prompt:
|
||||
"""异步创建Prompt实例"""
|
||||
prompt = create_prompt(template, name, parameters, **kwargs)
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
import asyncio
|
||||
from collections import defaultdict
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any, Dict, Tuple, List
|
||||
from typing import Any
|
||||
|
||||
from src.common.database.sqlalchemy_database_api import db_get, db_query, db_save
|
||||
from src.common.database.sqlalchemy_models import LLMUsage, Messages, OnlineTime
|
||||
from src.common.logger import get_logger
|
||||
from src.common.database.sqlalchemy_models import OnlineTime, LLMUsage, Messages
|
||||
from src.common.database.sqlalchemy_database_api import db_query, db_save, db_get
|
||||
from src.manager.async_task_manager import AsyncTask
|
||||
from src.manager.local_store_manager import local_storage
|
||||
|
||||
@@ -162,7 +162,7 @@ class StatisticOutputTask(AsyncTask):
|
||||
# 延迟300秒启动,运行间隔300秒
|
||||
super().__init__(task_name="Statistics Data Output Task", wait_before_start=0, run_interval=300)
|
||||
|
||||
self.name_mapping: Dict[str, Tuple[str, float]] = {}
|
||||
self.name_mapping: dict[str, tuple[str, float]] = {}
|
||||
"""
|
||||
联系人/群聊名称映射 {聊天ID: (联系人/群聊名称, 记录时间(timestamp))}
|
||||
注:设计记录时间的目的是方便更新名称,使联系人/群聊名称保持最新
|
||||
@@ -182,7 +182,7 @@ class StatisticOutputTask(AsyncTask):
|
||||
deploy_time = datetime(2000, 1, 1)
|
||||
local_storage["deploy_time"] = now.timestamp()
|
||||
|
||||
self.stat_period: List[Tuple[str, timedelta, str]] = [
|
||||
self.stat_period: list[tuple[str, timedelta, str]] = [
|
||||
("all_time", now - deploy_time, "自部署以来"), # 必须保留"all_time"
|
||||
("last_7_days", timedelta(days=7), "最近7天"),
|
||||
("last_24_hours", timedelta(days=1), "最近24小时"),
|
||||
@@ -193,7 +193,7 @@ class StatisticOutputTask(AsyncTask):
|
||||
统计时间段 [(统计名称, 统计时间段, 统计描述), ...]
|
||||
"""
|
||||
|
||||
def _statistic_console_output(self, stats: Dict[str, Any], now: datetime):
|
||||
def _statistic_console_output(self, stats: dict[str, Any], now: datetime):
|
||||
"""
|
||||
输出统计数据到控制台
|
||||
:param stats: 统计数据
|
||||
@@ -251,7 +251,7 @@ class StatisticOutputTask(AsyncTask):
|
||||
# -- 以下为统计数据收集方法 --
|
||||
|
||||
@staticmethod
|
||||
async def _collect_model_request_for_period(collect_period: List[Tuple[str, datetime]]) -> Dict[str, Any]:
|
||||
async def _collect_model_request_for_period(collect_period: list[tuple[str, datetime]]) -> dict[str, Any]:
|
||||
"""
|
||||
收集指定时间段的LLM请求统计数据
|
||||
|
||||
@@ -405,8 +405,8 @@ class StatisticOutputTask(AsyncTask):
|
||||
|
||||
@staticmethod
|
||||
async def _collect_online_time_for_period(
|
||||
collect_period: List[Tuple[str, datetime]], now: datetime
|
||||
) -> Dict[str, Any]:
|
||||
collect_period: list[tuple[str, datetime]], now: datetime
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
收集指定时间段的在线时间统计数据
|
||||
|
||||
@@ -464,7 +464,7 @@ class StatisticOutputTask(AsyncTask):
|
||||
break
|
||||
return stats
|
||||
|
||||
async def _collect_message_count_for_period(self, collect_period: List[Tuple[str, datetime]]) -> Dict[str, Any]:
|
||||
async def _collect_message_count_for_period(self, collect_period: list[tuple[str, datetime]]) -> dict[str, Any]:
|
||||
"""
|
||||
收集指定时间段的消息统计数据
|
||||
|
||||
@@ -535,7 +535,7 @@ class StatisticOutputTask(AsyncTask):
|
||||
break
|
||||
return stats
|
||||
|
||||
async def _collect_all_statistics(self, now: datetime) -> Dict[str, Dict[str, Any]]:
|
||||
async def _collect_all_statistics(self, now: datetime) -> dict[str, dict[str, Any]]:
|
||||
"""
|
||||
收集各时间段的统计数据
|
||||
:param now: 基准当前时间
|
||||
@@ -545,7 +545,7 @@ class StatisticOutputTask(AsyncTask):
|
||||
|
||||
if "last_full_statistics" in local_storage:
|
||||
# 如果存在上次完整统计数据,则使用该数据进行增量统计
|
||||
last_stat: Dict[str, Any] = local_storage["last_full_statistics"] # 上次完整统计数据 # type: ignore
|
||||
last_stat: dict[str, Any] = local_storage["last_full_statistics"] # 上次完整统计数据 # type: ignore
|
||||
|
||||
self.name_mapping = last_stat["name_mapping"] # 上次完整统计数据的名称映射
|
||||
last_all_time_stat = last_stat["stat_data"] # 上次完整统计的统计数据
|
||||
@@ -632,7 +632,7 @@ class StatisticOutputTask(AsyncTask):
|
||||
# -- 以下为统计数据格式化方法 --
|
||||
|
||||
@staticmethod
|
||||
def _format_total_stat(stats: Dict[str, Any]) -> str:
|
||||
def _format_total_stat(stats: dict[str, Any]) -> str:
|
||||
"""
|
||||
格式化总统计数据
|
||||
"""
|
||||
@@ -648,7 +648,7 @@ class StatisticOutputTask(AsyncTask):
|
||||
return "\n".join(output)
|
||||
|
||||
@staticmethod
|
||||
def _format_model_classified_stat(stats: Dict[str, Any]) -> str:
|
||||
def _format_model_classified_stat(stats: dict[str, Any]) -> str:
|
||||
"""
|
||||
格式化按模型分类的统计数据
|
||||
"""
|
||||
@@ -674,7 +674,7 @@ class StatisticOutputTask(AsyncTask):
|
||||
output.append("")
|
||||
return "\n".join(output)
|
||||
|
||||
def _format_chat_stat(self, stats: Dict[str, Any]) -> str:
|
||||
def _format_chat_stat(self, stats: dict[str, Any]) -> str:
|
||||
"""
|
||||
格式化聊天统计数据
|
||||
"""
|
||||
@@ -1019,7 +1019,7 @@ class StatisticOutputTask(AsyncTask):
|
||||
async def _generate_chart_data(self, stat: dict[str, Any]) -> dict:
|
||||
"""生成图表数据 (异步)"""
|
||||
now = datetime.now()
|
||||
chart_data: Dict[str, Any] = {}
|
||||
chart_data: dict[str, Any] = {}
|
||||
|
||||
time_ranges = [
|
||||
("6h", 6, 10),
|
||||
@@ -1035,16 +1035,16 @@ class StatisticOutputTask(AsyncTask):
|
||||
|
||||
async def _collect_interval_data(self, now: datetime, hours: int, interval_minutes: int) -> dict:
|
||||
start_time = now - timedelta(hours=hours)
|
||||
time_points: List[datetime] = []
|
||||
time_points: list[datetime] = []
|
||||
current_time = start_time
|
||||
while current_time <= now:
|
||||
time_points.append(current_time)
|
||||
current_time += timedelta(minutes=interval_minutes)
|
||||
|
||||
total_cost_data = [0.0] * len(time_points)
|
||||
cost_by_model: Dict[str, List[float]] = {}
|
||||
cost_by_module: Dict[str, List[float]] = {}
|
||||
message_by_chat: Dict[str, List[int]] = {}
|
||||
cost_by_model: dict[str, list[float]] = {}
|
||||
cost_by_module: dict[str, list[float]] = {}
|
||||
message_by_chat: dict[str, list[int]] = {}
|
||||
time_labels = [t.strftime("%H:%M") for t in time_points]
|
||||
interval_seconds = interval_minutes * 60
|
||||
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
import asyncio
|
||||
|
||||
from time import perf_counter
|
||||
from collections.abc import Callable
|
||||
from functools import wraps
|
||||
from typing import Optional, Dict, Callable
|
||||
from time import perf_counter
|
||||
|
||||
from rich.traceback import install
|
||||
|
||||
install(extra_lines=3)
|
||||
@@ -75,12 +75,12 @@ class Timer:
|
||||
3. 直接实例化:如果不调用 __enter__,打印对象时将显示当前 perf_counter 的值
|
||||
"""
|
||||
|
||||
__slots__ = ("name", "storage", "elapsed", "auto_unit", "start")
|
||||
__slots__ = ("auto_unit", "elapsed", "name", "start", "storage")
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: Optional[str] = None,
|
||||
storage: Optional[Dict[str, float]] = None,
|
||||
name: str | None = None,
|
||||
storage: dict[str, float] | None = None,
|
||||
auto_unit: bool = True,
|
||||
do_type_check: bool = False,
|
||||
):
|
||||
@@ -103,7 +103,7 @@ class Timer:
|
||||
if storage is not None and not isinstance(storage, dict):
|
||||
raise TimerTypeError("storage", "Optional[dict]", type(storage))
|
||||
|
||||
def __call__(self, func: Optional[Callable] = None) -> Callable:
|
||||
def __call__(self, func: Callable | None = None) -> Callable:
|
||||
"""装饰器模式"""
|
||||
if func is None:
|
||||
return lambda f: Timer(name=self.name or f.__name__, storage=self.storage, auto_unit=self.auto_unit)(f)
|
||||
|
||||
@@ -2,15 +2,15 @@
|
||||
错别字生成器 - 基于拼音和字频的中文错别字生成工具
|
||||
"""
|
||||
|
||||
import orjson
|
||||
import math
|
||||
import os
|
||||
import random
|
||||
import time
|
||||
import jieba
|
||||
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
|
||||
import jieba
|
||||
import orjson
|
||||
from pypinyin import Style, pinyin
|
||||
|
||||
from src.common.logger import get_logger
|
||||
@@ -51,7 +51,7 @@ class ChineseTypoGenerator:
|
||||
|
||||
# 如果缓存文件存在,直接加载
|
||||
if cache_file.exists():
|
||||
with open(cache_file, "r", encoding="utf-8") as f:
|
||||
with open(cache_file, encoding="utf-8") as f:
|
||||
return orjson.loads(f.read())
|
||||
|
||||
# 使用内置的词频文件
|
||||
@@ -59,7 +59,7 @@ class ChineseTypoGenerator:
|
||||
dict_path = os.path.join(os.path.dirname(jieba.__file__), "dict.txt")
|
||||
|
||||
# 读取jieba的词典文件
|
||||
with open(dict_path, "r", encoding="utf-8") as f:
|
||||
with open(dict_path, encoding="utf-8") as f:
|
||||
for line in f:
|
||||
word, freq = line.strip().split()[:2]
|
||||
# 对词中的每个字进行频率累加
|
||||
@@ -254,7 +254,7 @@ class ChineseTypoGenerator:
|
||||
# 获取jieba词典和词频信息
|
||||
dict_path = os.path.join(os.path.dirname(jieba.__file__), "dict.txt")
|
||||
valid_words = {} # 改用字典存储词语及其频率
|
||||
with open(dict_path, "r", encoding="utf-8") as f:
|
||||
with open(dict_path, encoding="utf-8") as f:
|
||||
for line in f:
|
||||
parts = line.strip().split()
|
||||
if len(parts) >= 2:
|
||||
|
||||
@@ -3,20 +3,21 @@ import random
|
||||
import re
|
||||
import string
|
||||
import time
|
||||
from collections import Counter
|
||||
from typing import Any
|
||||
|
||||
import jieba
|
||||
import numpy as np
|
||||
|
||||
from collections import Counter
|
||||
from maim_message import UserInfo
|
||||
from typing import Optional, Tuple, Dict, List, Any, Coroutine
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.common.message_repository import find_messages, count_messages
|
||||
from src.config.config import global_config, model_config
|
||||
from src.chat.message_receive.message import MessageRecv
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from src.chat.message_receive.message import MessageRecv
|
||||
from src.common.logger import get_logger
|
||||
from src.common.message_repository import count_messages, find_messages
|
||||
from src.config.config import global_config, model_config
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.person_info.person_info import Person
|
||||
from src.person_info.person_info import PersonInfoManager, get_person_info_manager
|
||||
|
||||
from .typo_generator import ChineseTypoGenerator
|
||||
|
||||
logger = get_logger("chat_utils")
|
||||
@@ -86,9 +87,9 @@ def is_mentioned_bot_in_message(message: MessageRecv) -> tuple[bool, float]:
|
||||
if not is_mentioned:
|
||||
# 判断是否被回复
|
||||
if re.match(
|
||||
rf"\[回复 (.+?)\({str(global_config.bot.qq_account)}\):(.+?)\],说:", message.processed_plain_text
|
||||
rf"\[回复 (.+?)\({global_config.bot.qq_account!s}\):(.+?)\],说:", message.processed_plain_text
|
||||
) or re.match(
|
||||
rf"\[回复<(.+?)(?=:{str(global_config.bot.qq_account)}>)\:{str(global_config.bot.qq_account)}>:(.+?)\],说:",
|
||||
rf"\[回复<(.+?)(?=:{global_config.bot.qq_account!s}>)\:{global_config.bot.qq_account!s}>:(.+?)\],说:",
|
||||
message.processed_plain_text,
|
||||
):
|
||||
is_mentioned = True
|
||||
@@ -110,14 +111,14 @@ def is_mentioned_bot_in_message(message: MessageRecv) -> tuple[bool, float]:
|
||||
return is_mentioned, reply_probability
|
||||
|
||||
|
||||
async def get_embedding(text, request_type="embedding") -> Optional[List[float]]:
|
||||
async def get_embedding(text, request_type="embedding") -> list[float] | None:
|
||||
"""获取文本的embedding向量"""
|
||||
# 每次都创建新的LLMRequest实例以避免事件循环冲突
|
||||
llm = LLMRequest(model_set=model_config.model_task_config.embedding, request_type=request_type)
|
||||
try:
|
||||
embedding, _ = await llm.get_embedding(text)
|
||||
except Exception as e:
|
||||
logger.error(f"获取embedding失败: {str(e)}")
|
||||
logger.error(f"获取embedding失败: {e!s}")
|
||||
embedding = None
|
||||
return embedding
|
||||
|
||||
@@ -622,7 +623,7 @@ def translate_timestamp_to_human_readable(timestamp: float, mode: str = "normal"
|
||||
return time.strftime("%H:%M:%S", time.localtime(timestamp))
|
||||
|
||||
|
||||
async def get_chat_type_and_target_info(chat_id: str) -> Tuple[bool, Optional[Dict]]:
|
||||
def get_chat_type_and_target_info(chat_id: str) -> tuple[bool, dict | None]:
|
||||
"""
|
||||
获取聊天类型(是否群聊)和私聊对象信息。
|
||||
|
||||
@@ -675,7 +676,6 @@ async def get_chat_type_and_target_info(chat_id: str) -> Tuple[bool, Optional[Di
|
||||
if loop.is_running():
|
||||
# 如果事件循环在运行,从其他线程提交并等待结果
|
||||
try:
|
||||
|
||||
fut = asyncio.run_coroutine_threadsafe(
|
||||
person_info_manager.get_value(person_id, "person_name"), loop
|
||||
)
|
||||
@@ -711,7 +711,7 @@ async def get_chat_type_and_target_info(chat_id: str) -> Tuple[bool, Optional[Di
|
||||
return is_group_chat, chat_target_info
|
||||
|
||||
|
||||
def assign_message_ids(messages: List[Any]) -> List[Dict[str, Any]]:
|
||||
def assign_message_ids(messages: list[Any]) -> list[dict[str, Any]]:
|
||||
"""
|
||||
为消息列表中的每个消息分配唯一的简短随机ID
|
||||
|
||||
|
||||
@@ -1,29 +1,27 @@
|
||||
import base64
|
||||
import hashlib
|
||||
import io
|
||||
import os
|
||||
import time
|
||||
import hashlib
|
||||
import uuid
|
||||
import io
|
||||
import numpy as np
|
||||
from typing import Any
|
||||
|
||||
from typing import Optional, Tuple, Dict, Any
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from rich.traceback import install
|
||||
from sqlalchemy import and_, select
|
||||
|
||||
from src.common.database.sqlalchemy_models import ImageDescriptions, Images, get_db_session
|
||||
from src.common.logger import get_logger
|
||||
from src.common.database.sqlalchemy_models import Images, ImageDescriptions
|
||||
from src.config.config import global_config, model_config
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.common.database.sqlalchemy_models import get_db_session
|
||||
|
||||
from sqlalchemy import select, and_
|
||||
|
||||
install(extra_lines=3)
|
||||
|
||||
logger = get_logger("chat_image")
|
||||
|
||||
|
||||
def is_image_message(message: Dict[str, Any]) -> bool:
|
||||
def is_image_message(message: dict[str, Any]) -> bool:
|
||||
"""
|
||||
判断消息是否为图片消息
|
||||
|
||||
@@ -69,7 +67,7 @@ class ImageManager:
|
||||
os.makedirs(self.IMAGE_DIR, exist_ok=True)
|
||||
|
||||
@staticmethod
|
||||
async def _get_description_from_db(image_hash: str, description_type: str) -> Optional[str]:
|
||||
async def _get_description_from_db(image_hash: str, description_type: str) -> str | None:
|
||||
"""从数据库获取图片描述
|
||||
|
||||
Args:
|
||||
@@ -93,7 +91,7 @@ class ImageManager:
|
||||
).scalar()
|
||||
return record.description if record else None
|
||||
except Exception as e:
|
||||
logger.error(f"从数据库获取描述失败 (SQLAlchemy): {str(e)}")
|
||||
logger.error(f"从数据库获取描述失败 (SQLAlchemy): {e!s}")
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
@@ -136,7 +134,7 @@ class ImageManager:
|
||||
await session.commit()
|
||||
# 会在上下文管理器中自动调用
|
||||
except Exception as e:
|
||||
logger.error(f"保存描述到数据库失败 (SQLAlchemy): {str(e)}")
|
||||
logger.error(f"保存描述到数据库失败 (SQLAlchemy): {e!s}")
|
||||
|
||||
@staticmethod
|
||||
async def get_emoji_tag(image_base64: str) -> str:
|
||||
@@ -287,10 +285,10 @@ class ImageManager:
|
||||
session.add(new_img)
|
||||
await session.commit()
|
||||
except Exception as e:
|
||||
logger.error(f"保存到Images表失败: {str(e)}")
|
||||
logger.error(f"保存到Images表失败: {e!s}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"保存表情包文件或元数据失败: {str(e)}")
|
||||
logger.error(f"保存表情包文件或元数据失败: {e!s}")
|
||||
else:
|
||||
logger.debug("偷取表情包功能已关闭,跳过保存。")
|
||||
|
||||
@@ -300,7 +298,7 @@ class ImageManager:
|
||||
return f"[表情包:{final_emotion}]"
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取表情包描述失败: {str(e)}")
|
||||
logger.error(f"获取表情包描述失败: {e!s}")
|
||||
return "[表情包(处理失败)]"
|
||||
|
||||
async def get_image_description(self, image_base64: str) -> str:
|
||||
@@ -391,11 +389,11 @@ class ImageManager:
|
||||
logger.info(f"[VLM完成] 图片描述生成: {description}...")
|
||||
return f"[图片:{description}]"
|
||||
except Exception as e:
|
||||
logger.error(f"获取图片描述失败: {str(e)}")
|
||||
logger.error(f"获取图片描述失败: {e!s}")
|
||||
return "[图片(处理失败)]"
|
||||
|
||||
@staticmethod
|
||||
def transform_gif(gif_base64: str, similarity_threshold: float = 1000.0, max_frames: int = 15) -> Optional[str]:
|
||||
def transform_gif(gif_base64: str, similarity_threshold: float = 1000.0, max_frames: int = 15) -> str | None:
|
||||
# sourcery skip: use-contextlib-suppress
|
||||
"""将GIF转换为水平拼接的静态图像, 跳过相似的帧
|
||||
|
||||
@@ -512,10 +510,10 @@ class ImageManager:
|
||||
logger.error("GIF转换失败: 内存不足,可能是GIF太大或帧数太多")
|
||||
return None # 内存不够啦
|
||||
except Exception as e:
|
||||
logger.error(f"GIF转换失败: {str(e)}", exc_info=True) # 记录详细错误信息
|
||||
logger.error(f"GIF转换失败: {e!s}", exc_info=True) # 记录详细错误信息
|
||||
return None # 其他错误也返回None
|
||||
|
||||
async def process_image(self, image_base64: str) -> Tuple[str, str]:
|
||||
async def process_image(self, image_base64: str) -> tuple[str, str]:
|
||||
# sourcery skip: hoist-if-from-if
|
||||
"""处理图片并返回图片ID和描述
|
||||
|
||||
@@ -604,7 +602,7 @@ class ImageManager:
|
||||
return image_id, f"[picid:{image_id}]"
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"处理图片失败: {str(e)}")
|
||||
logger.error(f"处理图片失败: {e!s}")
|
||||
return "", "[图片]"
|
||||
|
||||
|
||||
@@ -637,4 +635,4 @@ def image_path_to_base64(image_path: str) -> str:
|
||||
if image_data := f.read():
|
||||
return base64.b64encode(image_data).decode("utf-8")
|
||||
else:
|
||||
raise IOError(f"读取图片文件失败: {image_path}")
|
||||
raise OSError(f"读取图片文件失败: {image_path}")
|
||||
|
||||
@@ -1,35 +1,31 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
"""纯 inkfox 视频关键帧分析工具
|
||||
|
||||
仅依赖 `inkfox.video` 提供的 Rust 扩展能力:
|
||||
- extract_keyframes_from_video
|
||||
- get_system_info
|
||||
|
||||
功能:
|
||||
- 关键帧提取 (base64, timestamp)
|
||||
- 批量 / 逐帧 LLM 描述
|
||||
- 自动模式 (<=3 帧批量,否则逐帧)
|
||||
"""
|
||||
视频分析器模块 - Rust优化版本
|
||||
集成了Rust视频关键帧提取模块,提供高性能的视频分析功能
|
||||
支持SIMD优化、多线程处理和智能关键帧检测
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import io
|
||||
import asyncio
|
||||
import base64
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from typing import List, Tuple, Optional, Dict, Any
|
||||
import hashlib
|
||||
import io
|
||||
import os
|
||||
import tempfile
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.common.database.sqlalchemy_models import get_db_session, Videos
|
||||
from sqlalchemy import select
|
||||
|
||||
from src.common.database.sqlalchemy_models import Videos, get_db_session
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config, model_config
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
|
||||
logger = get_logger("utils_video")
|
||||
|
||||
# Rust模块可用性检测
|
||||
@@ -205,7 +201,7 @@ class VideoAnalyzer:
|
||||
hash_obj.update(video_data)
|
||||
return hash_obj.hexdigest()
|
||||
|
||||
async def _check_video_exists(self, video_hash: str) -> Optional[Videos]:
|
||||
async def _check_video_exists(self, video_hash: str) -> Videos | None:
|
||||
"""检查视频是否已经分析过"""
|
||||
try:
|
||||
async with get_db_session() as session:
|
||||
@@ -222,8 +218,8 @@ class VideoAnalyzer:
|
||||
return None
|
||||
|
||||
async def _store_video_result(
|
||||
self, video_hash: str, description: str, metadata: Optional[Dict] = None
|
||||
) -> Optional[Videos]:
|
||||
self, video_hash: str, description: str, metadata: dict | None = None
|
||||
) -> Videos | None:
|
||||
"""存储视频分析结果到数据库"""
|
||||
# 检查描述是否为错误信息,如果是则不保存
|
||||
if description.startswith("❌"):
|
||||
@@ -283,7 +279,7 @@ class VideoAnalyzer:
|
||||
else:
|
||||
logger.warning(f"无效的分析模式: {mode}")
|
||||
|
||||
async def extract_frames(self, video_path: str) -> List[Tuple[str, float]]:
|
||||
async def extract_frames(self, video_path: str) -> list[tuple[str, float]]:
|
||||
"""提取视频帧 - 智能选择最佳实现"""
|
||||
# 检查是否应该使用Rust实现
|
||||
if RUST_VIDEO_AVAILABLE and self.frame_extraction_mode == "keyframe":
|
||||
@@ -305,8 +301,8 @@ class VideoAnalyzer:
|
||||
logger.info(f"🔄 抽帧模式为 {self.frame_extraction_mode},使用Python抽帧实现")
|
||||
return await self._extract_frames_python_fallback(video_path)
|
||||
|
||||
# ---- 系统信息 ----
|
||||
def _log_system(self) -> None:
|
||||
async def _extract_frames_rust_advanced(self, video_path: str) -> list[tuple[str, float]]:
|
||||
"""使用 Rust 高级接口的帧提取"""
|
||||
try:
|
||||
info = video.get_system_info() # type: ignore[attr-defined]
|
||||
logger.info(
|
||||
@@ -329,25 +325,174 @@ class VideoAnalyzer:
|
||||
threads=self.threads,
|
||||
verbose=False,
|
||||
)
|
||||
files = sorted(Path(tmp).glob("keyframe_*.jpg"))[: self.max_frames]
|
||||
total_ms = getattr(result, "total_time_ms", 0)
|
||||
frames: List[Tuple[str, float]] = []
|
||||
for i, f in enumerate(files):
|
||||
img = Image.open(f).convert("RGB")
|
||||
if max(img.size) > self.max_image_size:
|
||||
scale = self.max_image_size / max(img.size)
|
||||
img = img.resize((int(img.width * scale), int(img.height * scale)), Image.Resampling.LANCZOS)
|
||||
buf = io.BytesIO()
|
||||
img.save(buf, format="JPEG", quality=self.frame_quality)
|
||||
b64 = base64.b64encode(buf.getvalue()).decode()
|
||||
ts = (i / max(1, len(files) - 1)) * (total_ms / 1000.0) if total_ms else float(i)
|
||||
frames.append((b64, ts))
|
||||
|
||||
logger.info(f"检测到 {len(keyframe_indices)} 个关键帧")
|
||||
|
||||
# 3. 转换选定的关键帧为 base64
|
||||
frames = []
|
||||
frame_count = 0
|
||||
|
||||
for idx in keyframe_indices[: self.max_frames]:
|
||||
if idx < len(frames_data):
|
||||
try:
|
||||
frame = frames_data[idx]
|
||||
frame_data = frame.get_data()
|
||||
|
||||
# 将灰度数据转换为PIL图像
|
||||
frame_array = np.frombuffer(frame_data, dtype=np.uint8).reshape((frame.height, frame.width))
|
||||
pil_image = Image.fromarray(
|
||||
frame_array,
|
||||
mode="L", # 灰度模式
|
||||
)
|
||||
|
||||
# 转换为RGB模式以便保存为JPEG
|
||||
pil_image = pil_image.convert("RGB")
|
||||
|
||||
# 调整图像大小
|
||||
if max(pil_image.size) > self.max_image_size:
|
||||
ratio = self.max_image_size / max(pil_image.size)
|
||||
new_size = tuple(int(dim * ratio) for dim in pil_image.size)
|
||||
pil_image = pil_image.resize(new_size, Image.Resampling.LANCZOS)
|
||||
|
||||
# 转换为 base64
|
||||
buffer = io.BytesIO()
|
||||
pil_image.save(buffer, format="JPEG", quality=self.frame_quality)
|
||||
frame_base64 = base64.b64encode(buffer.getvalue()).decode("utf-8")
|
||||
|
||||
# 估算时间戳
|
||||
estimated_timestamp = frame.frame_number * (1.0 / 30.0) # 假设30fps
|
||||
|
||||
frames.append((frame_base64, estimated_timestamp))
|
||||
frame_count += 1
|
||||
|
||||
logger.debug(
|
||||
f"处理关键帧 {frame_count}: 帧号 {frame.frame_number}, 时间 {estimated_timestamp:.2f}s"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"处理关键帧 {idx} 失败: {e}")
|
||||
continue
|
||||
|
||||
logger.info(f"✅ Rust 高级提取完成: {len(frames)} 关键帧")
|
||||
return frames
|
||||
|
||||
# ---- 批量分析 ----
|
||||
async def _analyze_batch(self, frames: List[Tuple[str, float]], question: Optional[str]) -> str:
|
||||
from src.llm_models.payload_content.message import MessageBuilder, RoleType
|
||||
from src.llm_models.utils_model import RequestType
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Rust 高级帧提取失败: {e}")
|
||||
# 回退到基础方法
|
||||
logger.info("回退到基础 Rust 方法")
|
||||
return await self._extract_frames_rust(video_path)
|
||||
|
||||
async def _extract_frames_rust(self, video_path: str) -> list[tuple[str, float]]:
|
||||
"""使用 Rust 实现的帧提取"""
|
||||
try:
|
||||
logger.info("🔄 使用 Rust 模块提取关键帧...")
|
||||
|
||||
# 创建临时输出目录
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
# 使用便捷函数进行关键帧提取,使用配置参数
|
||||
result = rust_video.extract_keyframes_from_video(
|
||||
video_path=video_path,
|
||||
output_dir=temp_dir,
|
||||
threshold=self.rust_keyframe_threshold,
|
||||
max_frames=self.max_frames * 2, # 提取更多帧以便筛选
|
||||
max_save=self.max_frames,
|
||||
ffmpeg_path=self.ffmpeg_path,
|
||||
use_simd=self.rust_use_simd,
|
||||
threads=self.rust_threads,
|
||||
verbose=False, # 使用固定值,不需要配置
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Rust 处理完成: 总帧数 {result.total_frames}, 关键帧 {result.keyframes_extracted}, 处理速度 {result.processing_fps:.1f} FPS"
|
||||
)
|
||||
|
||||
# 转换保存的关键帧为 base64 格式
|
||||
frames = []
|
||||
temp_dir_path = Path(temp_dir)
|
||||
|
||||
# 获取所有保存的关键帧文件
|
||||
keyframe_files = sorted(temp_dir_path.glob("keyframe_*.jpg"))
|
||||
|
||||
for i, keyframe_file in enumerate(keyframe_files):
|
||||
if len(frames) >= self.max_frames:
|
||||
break
|
||||
|
||||
try:
|
||||
# 读取关键帧文件
|
||||
with open(keyframe_file, "rb") as f:
|
||||
image_data = f.read()
|
||||
|
||||
# 转换为 PIL 图像并压缩
|
||||
pil_image = Image.open(io.BytesIO(image_data))
|
||||
|
||||
# 调整图像大小
|
||||
if max(pil_image.size) > self.max_image_size:
|
||||
ratio = self.max_image_size / max(pil_image.size)
|
||||
new_size = tuple(int(dim * ratio) for dim in pil_image.size)
|
||||
pil_image = pil_image.resize(new_size, Image.Resampling.LANCZOS)
|
||||
|
||||
# 转换为 base64
|
||||
buffer = io.BytesIO()
|
||||
pil_image.save(buffer, format="JPEG", quality=self.frame_quality)
|
||||
frame_base64 = base64.b64encode(buffer.getvalue()).decode("utf-8")
|
||||
|
||||
# 估算时间戳(基于帧索引和总时长)
|
||||
if result.total_frames > 0:
|
||||
# 假设关键帧在时间上均匀分布
|
||||
estimated_timestamp = (i * result.total_time_ms / 1000.0) / result.keyframes_extracted
|
||||
else:
|
||||
estimated_timestamp = i * 1.0 # 默认每秒一帧
|
||||
|
||||
frames.append((frame_base64, estimated_timestamp))
|
||||
|
||||
logger.debug(f"处理关键帧 {i + 1}: 估算时间 {estimated_timestamp:.2f}s")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"处理关键帧 {keyframe_file.name} 失败: {e}")
|
||||
continue
|
||||
|
||||
logger.info(f"✅ Rust 提取完成: {len(frames)} 关键帧")
|
||||
return frames
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Rust 帧提取失败: {e}")
|
||||
raise e
|
||||
|
||||
async def _extract_frames_python_fallback(self, video_path: str) -> list[tuple[str, float]]:
|
||||
"""Python降级抽帧实现 - 支持多种抽帧模式"""
|
||||
try:
|
||||
# 导入旧版本分析器
|
||||
from .utils_video_legacy import get_legacy_video_analyzer
|
||||
|
||||
logger.info("🔄 使用Python降级抽帧实现...")
|
||||
legacy_analyzer = get_legacy_video_analyzer()
|
||||
|
||||
# 同步配置参数
|
||||
legacy_analyzer.max_frames = self.max_frames
|
||||
legacy_analyzer.frame_quality = self.frame_quality
|
||||
legacy_analyzer.max_image_size = self.max_image_size
|
||||
legacy_analyzer.frame_extraction_mode = self.frame_extraction_mode
|
||||
legacy_analyzer.frame_interval_seconds = self.frame_interval_seconds
|
||||
legacy_analyzer.use_multiprocessing = self.use_multiprocessing
|
||||
|
||||
# 使用旧版本的抽帧功能
|
||||
frames = await legacy_analyzer.extract_frames(video_path)
|
||||
|
||||
logger.info(f"✅ Python降级抽帧完成: {len(frames)} 帧")
|
||||
return frames
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Python降级抽帧失败: {e}")
|
||||
return []
|
||||
|
||||
async def analyze_frames_batch(self, frames: list[tuple[str, float]], user_question: str = None) -> str:
|
||||
"""批量分析所有帧"""
|
||||
logger.info(f"开始批量分析{len(frames)}帧")
|
||||
|
||||
if not frames:
|
||||
return "❌ 没有可分析的帧"
|
||||
|
||||
# 构建提示词并格式化人格信息,要不然占位符的那个会爆炸
|
||||
prompt = self.batch_analysis_prompt.format(
|
||||
personality_core=self.personality_core, personality_side=self.personality_side
|
||||
)
|
||||
@@ -376,7 +521,7 @@ class VideoAnalyzer:
|
||||
logger.error(f"❌ 视频识别失败: {e}")
|
||||
raise e
|
||||
|
||||
async def _analyze_multiple_frames(self, frames: List[Tuple[str, float]], prompt: str) -> str:
|
||||
async def _analyze_multiple_frames(self, frames: list[tuple[str, float]], prompt: str) -> str:
|
||||
"""使用多图片分析方法"""
|
||||
logger.info(f"开始构建包含{len(frames)}帧的分析请求")
|
||||
|
||||
@@ -412,53 +557,75 @@ class VideoAnalyzer:
|
||||
temperature=None,
|
||||
max_tokens=None,
|
||||
)
|
||||
return resp.content or "❌ 未获得响应"
|
||||
|
||||
# ---- 逐帧分析 ----
|
||||
async def _analyze_sequential(self, frames: List[Tuple[str, float]], question: Optional[str]) -> str:
|
||||
results: List[str] = []
|
||||
for i, (b64, ts) in enumerate(frames):
|
||||
prompt = f"分析第{i+1}帧" + (f" (时间: {ts:.2f}s)" if self.enable_frame_timing else "")
|
||||
if question:
|
||||
prompt += f"\n关注: {question}"
|
||||
logger.info(f"视频识别完成,响应长度: {len(api_response.content or '')} ")
|
||||
return api_response.content or "❌ 未获得响应内容"
|
||||
|
||||
async def analyze_frames_sequential(self, frames: list[tuple[str, float]], user_question: str = None) -> str:
|
||||
"""逐帧分析并汇总"""
|
||||
logger.info(f"开始逐帧分析{len(frames)}帧")
|
||||
|
||||
frame_analyses = []
|
||||
|
||||
for i, (frame_base64, timestamp) in enumerate(frames):
|
||||
try:
|
||||
text, _ = await self.video_llm.generate_response_for_image(
|
||||
prompt=prompt, image_base64=b64, image_format="jpeg"
|
||||
)
|
||||
results.append(f"第{i+1}帧: {text}")
|
||||
except Exception as e: # pragma: no cover
|
||||
results.append(f"第{i+1}帧: 失败 {e}")
|
||||
if i < len(frames) - 1:
|
||||
await asyncio.sleep(self.frame_analysis_delay)
|
||||
summary_prompt = "基于以下逐帧结果给出完整总结:\n\n" + "\n".join(results)
|
||||
try:
|
||||
final, _ = await self.video_llm.generate_response_for_image(
|
||||
prompt=summary_prompt, image_base64=frames[-1][0], image_format="jpeg"
|
||||
)
|
||||
return final
|
||||
except Exception: # pragma: no cover
|
||||
return "\n".join(results)
|
||||
logger.info("✅ 逐帧分析和汇总完成")
|
||||
return summary
|
||||
else:
|
||||
return "❌ 没有可用于汇总的帧"
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 汇总分析失败: {e}")
|
||||
# 如果汇总失败,返回各帧分析结果
|
||||
return f"视频逐帧分析结果:\n\n{chr(10).join(frame_analyses)}"
|
||||
|
||||
# ---- 主入口 ----
|
||||
async def analyze_video(self, video_path: str, question: Optional[str] = None) -> Tuple[bool, str]:
|
||||
if not os.path.exists(video_path):
|
||||
return False, "❌ 文件不存在"
|
||||
frames = await self.extract_keyframes(video_path)
|
||||
if not frames:
|
||||
return False, "❌ 未提取到关键帧"
|
||||
mode = self.analysis_mode
|
||||
if mode == "auto":
|
||||
mode = "batch" if len(frames) <= 20 else "sequential"
|
||||
text = await (self._analyze_batch(frames, question) if mode == "batch" else self._analyze_sequential(frames, question))
|
||||
return True, text
|
||||
async def analyze_video(self, video_path: str, user_question: str = None) -> tuple[bool, str]:
|
||||
"""分析视频的主要方法
|
||||
|
||||
Returns:
|
||||
Tuple[bool, str]: (是否成功, 分析结果或错误信息)
|
||||
"""
|
||||
if self.disabled:
|
||||
error_msg = "❌ 视频分析功能已禁用:没有可用的视频处理实现"
|
||||
logger.warning(error_msg)
|
||||
return (False, error_msg)
|
||||
|
||||
try:
|
||||
logger.info(f"开始分析视频: {os.path.basename(video_path)}")
|
||||
|
||||
# 提取帧
|
||||
frames = await self.extract_frames(video_path)
|
||||
if not frames:
|
||||
error_msg = "❌ 无法从视频中提取有效帧"
|
||||
return (False, error_msg)
|
||||
|
||||
# 根据模式选择分析方法
|
||||
if self.analysis_mode == "auto":
|
||||
# 智能选择:少于等于3帧用批量,否则用逐帧
|
||||
mode = "batch" if len(frames) <= 3 else "sequential"
|
||||
logger.info(f"自动选择分析模式: {mode} (基于{len(frames)}帧)")
|
||||
else:
|
||||
mode = self.analysis_mode
|
||||
|
||||
# 执行分析
|
||||
if mode == "batch":
|
||||
result = await self.analyze_frames_batch(frames, user_question)
|
||||
else: # sequential
|
||||
result = await self.analyze_frames_sequential(frames, user_question)
|
||||
|
||||
logger.info("✅ 视频分析完成")
|
||||
return (True, result)
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"❌ 视频分析失败: {e!s}"
|
||||
logger.error(error_msg)
|
||||
return (False, error_msg)
|
||||
|
||||
async def analyze_video_from_bytes(
|
||||
self,
|
||||
video_bytes: bytes,
|
||||
filename: Optional[str] = None,
|
||||
prompt: Optional[str] = None,
|
||||
question: Optional[str] = None,
|
||||
) -> Dict[str, str]:
|
||||
self, video_bytes: bytes, filename: str = None, user_question: str = None, prompt: str = None
|
||||
) -> dict[str, str]:
|
||||
"""从字节数据分析视频
|
||||
|
||||
Args:
|
||||
@@ -568,34 +735,81 @@ class VideoAnalyzer:
|
||||
return {"summary": result}
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"❌ 从字节数据分析视频失败: {str(e)}"
|
||||
error_msg = f"❌ 从字节数据分析视频失败: {e!s}"
|
||||
logger.error(error_msg)
|
||||
|
||||
async def _save_cache(self, video_hash: str, summary: str, file_size: int) -> None:
|
||||
# 不保存错误信息到数据库,允许后续重试
|
||||
logger.info("💡 错误信息不保存到数据库,允许后续重试")
|
||||
|
||||
# 处理失败,通知等待者并清理资源
|
||||
try:
|
||||
if video_hash and video_event:
|
||||
async with video_lock_manager:
|
||||
if video_hash in video_events:
|
||||
video_events[video_hash].set()
|
||||
video_locks.pop(video_hash, None)
|
||||
video_events.pop(video_hash, None)
|
||||
except Exception as cleanup_e:
|
||||
logger.error(f"❌ 清理锁资源失败: {cleanup_e}")
|
||||
|
||||
return {"summary": error_msg}
|
||||
|
||||
def is_supported_video(self, file_path: str) -> bool:
|
||||
"""检查是否为支持的视频格式"""
|
||||
supported_formats = {".mp4", ".avi", ".mov", ".mkv", ".flv", ".wmv", ".m4v", ".3gp", ".webm"}
|
||||
return Path(file_path).suffix.lower() in supported_formats
|
||||
|
||||
def get_processing_capabilities(self) -> dict[str, any]:
|
||||
"""获取处理能力信息"""
|
||||
if not RUST_VIDEO_AVAILABLE:
|
||||
return {"error": "Rust视频处理模块不可用", "available": False, "reason": "rust_video模块未安装或加载失败"}
|
||||
|
||||
try:
|
||||
async with get_db_session() as session: # type: ignore
|
||||
stmt = insert(Videos).values( # type: ignore
|
||||
video_id="",
|
||||
video_hash=video_hash,
|
||||
description=summary,
|
||||
count=1,
|
||||
timestamp=time.time(),
|
||||
vlm_processed=True,
|
||||
duration=None,
|
||||
frame_count=None,
|
||||
fps=None,
|
||||
resolution=None,
|
||||
file_size=file_size,
|
||||
)
|
||||
try:
|
||||
await session.execute(stmt)
|
||||
await session.commit()
|
||||
logger.debug(f"视频缓存写入 success hash={video_hash}")
|
||||
except sa_exc.IntegrityError: # 可能并发已写入
|
||||
await session.rollback()
|
||||
logger.debug(f"视频缓存已存在 hash={video_hash}")
|
||||
except Exception: # pragma: no cover
|
||||
logger.debug("视频缓存写入失败")
|
||||
system_info = rust_video.get_system_info()
|
||||
|
||||
# 创建一个临时的extractor来获取CPU特性
|
||||
extractor = rust_video.VideoKeyframeExtractor(threads=0, verbose=False)
|
||||
cpu_features = extractor.get_cpu_features()
|
||||
|
||||
capabilities = {
|
||||
"system": {
|
||||
"threads": system_info.get("threads", 0),
|
||||
"rust_version": system_info.get("version", "unknown"),
|
||||
},
|
||||
"cpu_features": cpu_features,
|
||||
"recommended_settings": self._get_recommended_settings(cpu_features),
|
||||
"analysis_modes": ["auto", "batch", "sequential"],
|
||||
"supported_formats": [".mp4", ".avi", ".mov", ".mkv", ".flv", ".wmv", ".m4v", ".3gp", ".webm"],
|
||||
"available": True,
|
||||
}
|
||||
|
||||
return capabilities
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取处理能力信息失败: {e}")
|
||||
return {"error": str(e), "available": False}
|
||||
|
||||
def _get_recommended_settings(self, cpu_features: dict[str, bool]) -> dict[str, any]:
|
||||
"""根据CPU特性推荐最佳设置"""
|
||||
settings = {
|
||||
"use_simd": any(cpu_features.values()),
|
||||
"block_size": 8192,
|
||||
"threads": 0, # 自动检测
|
||||
}
|
||||
|
||||
# 根据CPU特性调整设置
|
||||
if cpu_features.get("avx2", False):
|
||||
settings["block_size"] = 16384 # AVX2支持更大的块
|
||||
settings["optimization_level"] = "avx2"
|
||||
elif cpu_features.get("sse2", False):
|
||||
settings["block_size"] = 8192
|
||||
settings["optimization_level"] = "sse2"
|
||||
else:
|
||||
settings["use_simd"] = False
|
||||
settings["block_size"] = 4096
|
||||
settings["optimization_level"] = "scalar"
|
||||
|
||||
return settings
|
||||
|
||||
|
||||
# ---- 外部接口 ----
|
||||
@@ -613,7 +827,14 @@ def is_video_analysis_available() -> bool:
|
||||
return True
|
||||
|
||||
|
||||
def get_video_analysis_status() -> Dict[str, Any]:
|
||||
def get_video_analysis_status() -> dict[str, any]:
|
||||
"""获取视频分析功能的详细状态信息
|
||||
|
||||
Returns:
|
||||
Dict[str, any]: 包含功能状态信息的字典
|
||||
"""
|
||||
# 检查OpenCV是否可用
|
||||
opencv_available = False
|
||||
try:
|
||||
info = video.get_system_info() # type: ignore[attr-defined]
|
||||
except Exception as e: # pragma: no cover
|
||||
|
||||
@@ -1,25 +1,25 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
视频分析器模块 - 旧版本兼容模块
|
||||
支持多种分析模式:批处理、逐帧、自动选择
|
||||
包含Python原生的抽帧功能,作为Rust模块的降级方案
|
||||
"""
|
||||
|
||||
import os
|
||||
import cv2
|
||||
import asyncio
|
||||
import base64
|
||||
import io
|
||||
import os
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from pathlib import Path
|
||||
from typing import List, Tuple, Optional, Any
|
||||
import io
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import global_config, model_config
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config, model_config
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
|
||||
logger = get_logger("utils_video_legacy")
|
||||
|
||||
@@ -30,7 +30,7 @@ def _extract_frames_worker(
|
||||
frame_quality: int,
|
||||
max_image_size: int,
|
||||
frame_extraction_mode: str,
|
||||
frame_interval_seconds: Optional[float],
|
||||
frame_interval_seconds: float | None,
|
||||
) -> list[Any] | list[tuple[str, str]]:
|
||||
"""线程池中提取视频帧的工作函数"""
|
||||
frames = []
|
||||
@@ -221,7 +221,7 @@ class LegacyVideoAnalyzer:
|
||||
f"✅ 旧版本视频分析器初始化完成,分析模式: {self.analysis_mode}, 线程池: {self.use_multiprocessing}"
|
||||
)
|
||||
|
||||
async def extract_frames(self, video_path: str) -> List[Tuple[str, float]]:
|
||||
async def extract_frames(self, video_path: str) -> list[tuple[str, float]]:
|
||||
"""提取视频帧 - 支持多进程和单线程模式"""
|
||||
# 先获取视频信息
|
||||
cap = cv2.VideoCapture(video_path)
|
||||
@@ -247,7 +247,7 @@ class LegacyVideoAnalyzer:
|
||||
else:
|
||||
return await self._extract_frames_fallback(video_path)
|
||||
|
||||
async def _extract_frames_multiprocess(self, video_path: str) -> List[Tuple[str, float]]:
|
||||
async def _extract_frames_multiprocess(self, video_path: str) -> list[tuple[str, float]]:
|
||||
"""线程池版本的帧提取"""
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
@@ -282,7 +282,7 @@ class LegacyVideoAnalyzer:
|
||||
logger.info("🔄 降级到单线程模式...")
|
||||
return await self._extract_frames_fallback(video_path)
|
||||
|
||||
async def _extract_frames_fallback(self, video_path: str) -> List[Tuple[str, float]]:
|
||||
async def _extract_frames_fallback(self, video_path: str) -> list[tuple[str, float]]:
|
||||
"""帧提取的降级方法 - 原始异步版本"""
|
||||
frames = []
|
||||
extracted_count = 0
|
||||
@@ -389,7 +389,7 @@ class LegacyVideoAnalyzer:
|
||||
logger.info(f"✅ 成功提取{len(frames)}帧")
|
||||
return frames
|
||||
|
||||
async def analyze_frames_batch(self, frames: List[Tuple[str, float]], user_question: str = None) -> str:
|
||||
async def analyze_frames_batch(self, frames: list[tuple[str, float]], user_question: str = None) -> str:
|
||||
"""批量分析所有帧"""
|
||||
logger.info(f"开始批量分析{len(frames)}帧")
|
||||
|
||||
@@ -441,7 +441,7 @@ class LegacyVideoAnalyzer:
|
||||
logger.error(f"❌ 降级分析也失败: {fallback_e}")
|
||||
raise
|
||||
|
||||
async def _analyze_multiple_frames(self, frames: List[Tuple[str, float]], prompt: str) -> str:
|
||||
async def _analyze_multiple_frames(self, frames: list[tuple[str, float]], prompt: str) -> str:
|
||||
"""使用多图片分析方法"""
|
||||
logger.info(f"开始构建包含{len(frames)}帧的分析请求")
|
||||
|
||||
@@ -481,7 +481,7 @@ class LegacyVideoAnalyzer:
|
||||
logger.info(f"视频识别完成,响应长度: {len(api_response.content or '')} ")
|
||||
return api_response.content or "❌ 未获得响应内容"
|
||||
|
||||
async def analyze_frames_sequential(self, frames: List[Tuple[str, float]], user_question: str = None) -> str:
|
||||
async def analyze_frames_sequential(self, frames: list[tuple[str, float]], user_question: str = None) -> str:
|
||||
"""逐帧分析并汇总"""
|
||||
logger.info(f"开始逐帧分析{len(frames)}帧")
|
||||
|
||||
@@ -567,7 +567,7 @@ class LegacyVideoAnalyzer:
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"❌ 视频分析失败: {str(e)}"
|
||||
error_msg = f"❌ 视频分析失败: {e!s}"
|
||||
logger.error(error_msg)
|
||||
return error_msg
|
||||
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
from src.config.config import global_config, model_config
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from rich.traceback import install
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from rich.traceback import install
|
||||
from src.config.config import global_config, model_config
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
|
||||
install(extra_lines=3)
|
||||
|
||||
@@ -25,5 +25,5 @@ async def get_voice_text(voice_base64: str) -> str:
|
||||
|
||||
return f"[语音:{text}]"
|
||||
except Exception as e:
|
||||
logger.error(f"语音转文字失败: {str(e)}")
|
||||
logger.error(f"语音转文字失败: {e!s}")
|
||||
return "[语音]"
|
||||
|
||||
Reference in New Issue
Block a user