style: 统一代码风格并进行现代化改进
对整个代码库进行了一次全面的风格统一和现代化改进。主要变更包括:
- 将 `hasattr` 等内置函数中的字符串参数从单引号 `'` 统一为双引号 `"`。
- 采用现代类型注解,例如将 `Optional[T]` 替换为 `T | None`,`List[T]` 替换为 `list[T]` 等。
- 移除不再需要的 Python 2 兼容性声明 `# -*- coding: utf-8 -*-`。
- 清理了多余的空行、注释和未使用的导入。
- 统一了文件末尾的换行符。
- 优化了部分日志输出和字符串格式化 (`f"{e!s}"`)。
这些改动旨在提升代码的可读性、一致性和可维护性,使其更符合现代 Python 编码规范。
This commit is contained in:
8
bot.py
8
bot.py
@@ -103,7 +103,7 @@ async def graceful_shutdown(main_system_instance):
|
|||||||
logger.info("正在优雅关闭麦麦...")
|
logger.info("正在优雅关闭麦麦...")
|
||||||
|
|
||||||
# 停止MainSystem中的组件,它会处理服务器等
|
# 停止MainSystem中的组件,它会处理服务器等
|
||||||
if main_system_instance and hasattr(main_system_instance, 'shutdown'):
|
if main_system_instance and hasattr(main_system_instance, "shutdown"):
|
||||||
logger.info("正在关闭MainSystem...")
|
logger.info("正在关闭MainSystem...")
|
||||||
await main_system_instance.shutdown()
|
await main_system_instance.shutdown()
|
||||||
|
|
||||||
@@ -111,7 +111,7 @@ async def graceful_shutdown(main_system_instance):
|
|||||||
try:
|
try:
|
||||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||||
chat_manager = get_chat_manager()
|
chat_manager = get_chat_manager()
|
||||||
if hasattr(chat_manager, '_stop_auto_save'):
|
if hasattr(chat_manager, "_stop_auto_save"):
|
||||||
logger.info("正在停止聊天管理器...")
|
logger.info("正在停止聊天管理器...")
|
||||||
chat_manager._stop_auto_save()
|
chat_manager._stop_auto_save()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -120,7 +120,7 @@ async def graceful_shutdown(main_system_instance):
|
|||||||
# 停止情绪管理器
|
# 停止情绪管理器
|
||||||
try:
|
try:
|
||||||
from src.mood.mood_manager import mood_manager
|
from src.mood.mood_manager import mood_manager
|
||||||
if hasattr(mood_manager, 'stop'):
|
if hasattr(mood_manager, "stop"):
|
||||||
logger.info("正在停止情绪管理器...")
|
logger.info("正在停止情绪管理器...")
|
||||||
await mood_manager.stop()
|
await mood_manager.stop()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -129,7 +129,7 @@ async def graceful_shutdown(main_system_instance):
|
|||||||
# 停止记忆系统
|
# 停止记忆系统
|
||||||
try:
|
try:
|
||||||
from src.chat.memory_system.memory_manager import memory_manager
|
from src.chat.memory_system.memory_manager import memory_manager
|
||||||
if hasattr(memory_manager, 'shutdown'):
|
if hasattr(memory_manager, "shutdown"):
|
||||||
logger.info("正在停止记忆系统...")
|
logger.info("正在停止记忆系统...")
|
||||||
await memory_manager.shutdown()
|
await memory_manager.shutdown()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|||||||
@@ -7,4 +7,4 @@ __plugin_meta__ = PluginMetadata(
|
|||||||
version="1.0.0",
|
version="1.0.0",
|
||||||
author="Your Name",
|
author="Your Name",
|
||||||
license="MIT",
|
license="MIT",
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -7,4 +7,4 @@ __plugin_meta__ = PluginMetadata(
|
|||||||
version="1.0.0",
|
version="1.0.0",
|
||||||
author="Your Name",
|
author="Your Name",
|
||||||
license="MIT",
|
license="MIT",
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -1 +1 @@
|
|||||||
# This file makes src/api a Python package.
|
# This file makes src/api a Python package.
|
||||||
|
|||||||
@@ -3,10 +3,10 @@ from typing import Literal
|
|||||||
|
|
||||||
from fastapi import APIRouter, HTTPException, Query
|
from fastapi import APIRouter, HTTPException, Query
|
||||||
|
|
||||||
from src.config.config import global_config
|
|
||||||
from src.plugin_system.apis import message_api, chat_api, person_api
|
|
||||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
|
from src.config.config import global_config
|
||||||
|
from src.plugin_system.apis import message_api, person_api
|
||||||
|
|
||||||
logger = get_logger("HTTP消息API")
|
logger = get_logger("HTTP消息API")
|
||||||
|
|
||||||
@@ -86,7 +86,7 @@ async def get_message_stats_by_chat(
|
|||||||
if group_by_user:
|
if group_by_user:
|
||||||
if user_id not in stats[chat_id]["user_stats"]:
|
if user_id not in stats[chat_id]["user_stats"]:
|
||||||
stats[chat_id]["user_stats"][user_id] = 0
|
stats[chat_id]["user_stats"][user_id] = 0
|
||||||
|
|
||||||
stats[chat_id]["user_stats"][user_id] += 1
|
stats[chat_id]["user_stats"][user_id] += 1
|
||||||
|
|
||||||
if not group_by_user:
|
if not group_by_user:
|
||||||
@@ -120,7 +120,7 @@ async def get_message_stats_by_chat(
|
|||||||
"nickname": nickname,
|
"nickname": nickname,
|
||||||
"count": count
|
"count": count
|
||||||
}
|
}
|
||||||
|
|
||||||
formatted_stats[chat_id] = formatted_data
|
formatted_stats[chat_id] = formatted_data
|
||||||
return formatted_stats
|
return formatted_stats
|
||||||
|
|
||||||
@@ -164,7 +164,7 @@ async def get_bot_message_stats_by_chat(
|
|||||||
chat_name = stream.group_info.group_name
|
chat_name = stream.group_info.group_name
|
||||||
elif stream.user_info and stream.user_info.user_nickname:
|
elif stream.user_info and stream.user_info.user_nickname:
|
||||||
chat_name = stream.user_info.user_nickname
|
chat_name = stream.user_info.user_nickname
|
||||||
|
|
||||||
formatted_stats[chat_id] = {
|
formatted_stats[chat_id] = {
|
||||||
"chat_name": chat_name,
|
"chat_name": chat_name,
|
||||||
"count": count
|
"count": count
|
||||||
@@ -174,4 +174,4 @@ async def get_bot_message_stats_by_chat(
|
|||||||
return stats
|
return stats
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(status_code=500, detail=str(e))
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|||||||
@@ -112,7 +112,7 @@ class InterestManager:
|
|||||||
# 返回默认结果
|
# 返回默认结果
|
||||||
return InterestCalculationResult(
|
return InterestCalculationResult(
|
||||||
success=False,
|
success=False,
|
||||||
message_id=getattr(message, 'message_id', ''),
|
message_id=getattr(message, "message_id", ""),
|
||||||
interest_value=0.3,
|
interest_value=0.3,
|
||||||
error_message="没有可用的兴趣值计算组件"
|
error_message="没有可用的兴趣值计算组件"
|
||||||
)
|
)
|
||||||
@@ -129,7 +129,7 @@ class InterestManager:
|
|||||||
logger.warning(f"兴趣值计算超时 ({timeout}s),消息 {getattr(message, 'message_id', '')} 使用默认兴趣值 0.5")
|
logger.warning(f"兴趣值计算超时 ({timeout}s),消息 {getattr(message, 'message_id', '')} 使用默认兴趣值 0.5")
|
||||||
return InterestCalculationResult(
|
return InterestCalculationResult(
|
||||||
success=True,
|
success=True,
|
||||||
message_id=getattr(message, 'message_id', ''),
|
message_id=getattr(message, "message_id", ""),
|
||||||
interest_value=0.5, # 固定默认兴趣值
|
interest_value=0.5, # 固定默认兴趣值
|
||||||
should_reply=False,
|
should_reply=False,
|
||||||
should_act=False,
|
should_act=False,
|
||||||
@@ -140,9 +140,9 @@ class InterestManager:
|
|||||||
logger.error(f"兴趣值计算异常: {e}")
|
logger.error(f"兴趣值计算异常: {e}")
|
||||||
return InterestCalculationResult(
|
return InterestCalculationResult(
|
||||||
success=False,
|
success=False,
|
||||||
message_id=getattr(message, 'message_id', ''),
|
message_id=getattr(message, "message_id", ""),
|
||||||
interest_value=0.3,
|
interest_value=0.3,
|
||||||
error_message=f"计算异常: {str(e)}"
|
error_message=f"计算异常: {e!s}"
|
||||||
)
|
)
|
||||||
|
|
||||||
async def _async_calculate(self, message: "DatabaseMessages") -> InterestCalculationResult:
|
async def _async_calculate(self, message: "DatabaseMessages") -> InterestCalculationResult:
|
||||||
@@ -168,9 +168,9 @@ class InterestManager:
|
|||||||
logger.error(f"兴趣值计算异常: {e}", exc_info=True)
|
logger.error(f"兴趣值计算异常: {e}", exc_info=True)
|
||||||
return InterestCalculationResult(
|
return InterestCalculationResult(
|
||||||
success=False,
|
success=False,
|
||||||
message_id=getattr(message, 'message_id', ''),
|
message_id=getattr(message, "message_id", ""),
|
||||||
interest_value=0.0,
|
interest_value=0.0,
|
||||||
error_message=f"计算异常: {str(e)}",
|
error_message=f"计算异常: {e!s}",
|
||||||
calculation_time=time.time() - start_time
|
calculation_time=time.time() - start_time
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -245,4 +245,4 @@ def get_interest_manager() -> InterestManager:
|
|||||||
global _interest_manager
|
global _interest_manager
|
||||||
if _interest_manager is None:
|
if _interest_manager is None:
|
||||||
_interest_manager = InterestManager()
|
_interest_manager = InterestManager()
|
||||||
return _interest_manager
|
return _interest_manager
|
||||||
|
|||||||
@@ -1,4 +1,3 @@
|
|||||||
# -*- coding: utf-8 -*-
|
|
||||||
"""
|
"""
|
||||||
海马体双峰分布采样器
|
海马体双峰分布采样器
|
||||||
基于旧版海马体的采样策略,适配新版记忆系统
|
基于旧版海马体的采样策略,适配新版记忆系统
|
||||||
@@ -8,16 +7,15 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import random
|
import random
|
||||||
import time
|
import time
|
||||||
from datetime import datetime, timedelta
|
|
||||||
from typing import List, Optional, Tuple, Dict, Any
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import orjson
|
|
||||||
|
|
||||||
from src.chat.utils.chat_message_builder import (
|
from src.chat.utils.chat_message_builder import (
|
||||||
get_raw_msg_by_timestamp,
|
|
||||||
build_readable_messages,
|
build_readable_messages,
|
||||||
|
get_raw_msg_by_timestamp,
|
||||||
get_raw_msg_by_timestamp_with_chat,
|
get_raw_msg_by_timestamp_with_chat,
|
||||||
)
|
)
|
||||||
from src.chat.utils.utils import translate_timestamp_to_human_readable
|
from src.chat.utils.utils import translate_timestamp_to_human_readable
|
||||||
@@ -47,7 +45,7 @@ class HippocampusSampleConfig:
|
|||||||
batch_size: int = 5 # 批处理大小
|
batch_size: int = 5 # 批处理大小
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_global_config(cls) -> 'HippocampusSampleConfig':
|
def from_global_config(cls) -> "HippocampusSampleConfig":
|
||||||
"""从全局配置创建海马体采样配置"""
|
"""从全局配置创建海马体采样配置"""
|
||||||
config = global_config.memory.hippocampus_distribution_config
|
config = global_config.memory.hippocampus_distribution_config
|
||||||
return cls(
|
return cls(
|
||||||
@@ -74,12 +72,12 @@ class HippocampusSampler:
|
|||||||
self.is_running = False
|
self.is_running = False
|
||||||
|
|
||||||
# 记忆构建模型
|
# 记忆构建模型
|
||||||
self.memory_builder_model: Optional[LLMRequest] = None
|
self.memory_builder_model: LLMRequest | None = None
|
||||||
|
|
||||||
# 统计信息
|
# 统计信息
|
||||||
self.sample_count = 0
|
self.sample_count = 0
|
||||||
self.success_count = 0
|
self.success_count = 0
|
||||||
self.last_sample_results: List[Dict[str, Any]] = []
|
self.last_sample_results: list[dict[str, Any]] = []
|
||||||
|
|
||||||
async def initialize(self):
|
async def initialize(self):
|
||||||
"""初始化采样器"""
|
"""初始化采样器"""
|
||||||
@@ -101,7 +99,7 @@ class HippocampusSampler:
|
|||||||
logger.error(f"❌ 海马体采样器初始化失败: {e}")
|
logger.error(f"❌ 海马体采样器初始化失败: {e}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
def generate_time_samples(self) -> List[datetime]:
|
def generate_time_samples(self) -> list[datetime]:
|
||||||
"""生成双峰分布的时间采样点"""
|
"""生成双峰分布的时间采样点"""
|
||||||
# 计算每个分布的样本数
|
# 计算每个分布的样本数
|
||||||
recent_samples = max(1, int(self.config.total_samples * self.config.recent_weight))
|
recent_samples = max(1, int(self.config.total_samples * self.config.recent_weight))
|
||||||
@@ -132,7 +130,7 @@ class HippocampusSampler:
|
|||||||
# 按时间排序(从最早到最近)
|
# 按时间排序(从最早到最近)
|
||||||
return sorted(timestamps)
|
return sorted(timestamps)
|
||||||
|
|
||||||
async def collect_message_samples(self, target_timestamp: float) -> Optional[List[Dict[str, Any]]]:
|
async def collect_message_samples(self, target_timestamp: float) -> list[dict[str, Any]] | None:
|
||||||
"""收集指定时间戳附近的消息样本"""
|
"""收集指定时间戳附近的消息样本"""
|
||||||
try:
|
try:
|
||||||
# 随机时间窗口:5-30分钟
|
# 随机时间窗口:5-30分钟
|
||||||
@@ -190,7 +188,7 @@ class HippocampusSampler:
|
|||||||
logger.error(f"收集消息样本失败: {e}")
|
logger.error(f"收集消息样本失败: {e}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
async def build_memory_from_samples(self, messages: List[Dict[str, Any]], target_timestamp: float) -> Optional[str]:
|
async def build_memory_from_samples(self, messages: list[dict[str, Any]], target_timestamp: float) -> str | None:
|
||||||
"""从消息样本构建记忆"""
|
"""从消息样本构建记忆"""
|
||||||
if not messages or not self.memory_system or not self.memory_builder_model:
|
if not messages or not self.memory_system or not self.memory_builder_model:
|
||||||
return None
|
return None
|
||||||
@@ -262,7 +260,7 @@ class HippocampusSampler:
|
|||||||
logger.error(f"海马体采样构建记忆失败: {e}")
|
logger.error(f"海马体采样构建记忆失败: {e}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
async def perform_sampling_cycle(self) -> Dict[str, Any]:
|
async def perform_sampling_cycle(self) -> dict[str, Any]:
|
||||||
"""执行一次完整的采样周期(优化版:批量融合构建)"""
|
"""执行一次完整的采样周期(优化版:批量融合构建)"""
|
||||||
if not self.should_sample():
|
if not self.should_sample():
|
||||||
return {"status": "skipped", "reason": "interval_not_met"}
|
return {"status": "skipped", "reason": "interval_not_met"}
|
||||||
@@ -363,7 +361,7 @@ class HippocampusSampler:
|
|||||||
"duration": time.time() - start_time,
|
"duration": time.time() - start_time,
|
||||||
}
|
}
|
||||||
|
|
||||||
async def _collect_all_message_samples(self, time_samples: List[datetime]) -> List[List[Dict[str, Any]]]:
|
async def _collect_all_message_samples(self, time_samples: list[datetime]) -> list[list[dict[str, Any]]]:
|
||||||
"""批量收集所有时间点的消息样本"""
|
"""批量收集所有时间点的消息样本"""
|
||||||
collected_messages = []
|
collected_messages = []
|
||||||
max_concurrent = min(5, len(time_samples)) # 提高并发数到5
|
max_concurrent = min(5, len(time_samples)) # 提高并发数到5
|
||||||
@@ -394,7 +392,7 @@ class HippocampusSampler:
|
|||||||
|
|
||||||
return collected_messages
|
return collected_messages
|
||||||
|
|
||||||
async def _fuse_and_deduplicate_messages(self, collected_messages: List[List[Dict[str, Any]]]) -> List[List[Dict[str, Any]]]:
|
async def _fuse_and_deduplicate_messages(self, collected_messages: list[list[dict[str, Any]]]) -> list[list[dict[str, Any]]]:
|
||||||
"""融合和去重消息样本"""
|
"""融合和去重消息样本"""
|
||||||
if not collected_messages:
|
if not collected_messages:
|
||||||
return []
|
return []
|
||||||
@@ -450,7 +448,7 @@ class HippocampusSampler:
|
|||||||
# 返回原始消息组作为备选
|
# 返回原始消息组作为备选
|
||||||
return collected_messages[:5] # 限制返回数量
|
return collected_messages[:5] # 限制返回数量
|
||||||
|
|
||||||
def _merge_adjacent_messages(self, messages: List[Dict[str, Any]], time_gap: int = 1800) -> List[List[Dict[str, Any]]]:
|
def _merge_adjacent_messages(self, messages: list[dict[str, Any]], time_gap: int = 1800) -> list[list[dict[str, Any]]]:
|
||||||
"""合并时间间隔内的消息"""
|
"""合并时间间隔内的消息"""
|
||||||
if not messages:
|
if not messages:
|
||||||
return []
|
return []
|
||||||
@@ -481,7 +479,7 @@ class HippocampusSampler:
|
|||||||
|
|
||||||
return result_groups
|
return result_groups
|
||||||
|
|
||||||
async def _build_batch_memory(self, fused_messages: List[List[Dict[str, Any]]], time_samples: List[datetime]) -> Dict[str, Any]:
|
async def _build_batch_memory(self, fused_messages: list[list[dict[str, Any]]], time_samples: list[datetime]) -> dict[str, Any]:
|
||||||
"""批量构建记忆"""
|
"""批量构建记忆"""
|
||||||
if not fused_messages:
|
if not fused_messages:
|
||||||
return {"memory_count": 0, "memories": []}
|
return {"memory_count": 0, "memories": []}
|
||||||
@@ -557,7 +555,7 @@ class HippocampusSampler:
|
|||||||
logger.error(f"批量构建记忆失败: {e}")
|
logger.error(f"批量构建记忆失败: {e}")
|
||||||
return {"memory_count": 0, "error": str(e)}
|
return {"memory_count": 0, "error": str(e)}
|
||||||
|
|
||||||
async def _build_fused_conversation_text(self, fused_messages: List[List[Dict[str, Any]]]) -> str:
|
async def _build_fused_conversation_text(self, fused_messages: list[list[dict[str, Any]]]) -> str:
|
||||||
"""构建融合后的对话文本"""
|
"""构建融合后的对话文本"""
|
||||||
try:
|
try:
|
||||||
# 添加批次标识
|
# 添加批次标识
|
||||||
@@ -589,7 +587,7 @@ class HippocampusSampler:
|
|||||||
logger.error(f"构建融合文本失败: {e}")
|
logger.error(f"构建融合文本失败: {e}")
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
async def _fallback_individual_build(self, fused_messages: List[List[Dict[str, Any]]]) -> Dict[str, Any]:
|
async def _fallback_individual_build(self, fused_messages: list[list[dict[str, Any]]]) -> dict[str, Any]:
|
||||||
"""备选方案:单独构建每个消息组"""
|
"""备选方案:单独构建每个消息组"""
|
||||||
total_memories = []
|
total_memories = []
|
||||||
total_count = 0
|
total_count = 0
|
||||||
@@ -609,7 +607,7 @@ class HippocampusSampler:
|
|||||||
"fallback_mode": True
|
"fallback_mode": True
|
||||||
}
|
}
|
||||||
|
|
||||||
async def process_sample_timestamp(self, target_timestamp: float) -> Optional[str]:
|
async def process_sample_timestamp(self, target_timestamp: float) -> str | None:
|
||||||
"""处理单个时间戳采样(保留作为备选方法)"""
|
"""处理单个时间戳采样(保留作为备选方法)"""
|
||||||
try:
|
try:
|
||||||
# 收集消息样本
|
# 收集消息样本
|
||||||
@@ -676,7 +674,7 @@ class HippocampusSampler:
|
|||||||
self.is_running = False
|
self.is_running = False
|
||||||
logger.info("🛑 停止海马体后台采样任务")
|
logger.info("🛑 停止海马体后台采样任务")
|
||||||
|
|
||||||
def get_sampling_stats(self) -> Dict[str, Any]:
|
def get_sampling_stats(self) -> dict[str, Any]:
|
||||||
"""获取采样统计信息"""
|
"""获取采样统计信息"""
|
||||||
success_rate = (self.success_count / self.sample_count * 100) if self.sample_count > 0 else 0
|
success_rate = (self.success_count / self.sample_count * 100) if self.sample_count > 0 else 0
|
||||||
|
|
||||||
@@ -713,7 +711,7 @@ class HippocampusSampler:
|
|||||||
|
|
||||||
|
|
||||||
# 全局海马体采样器实例
|
# 全局海马体采样器实例
|
||||||
_hippocampus_sampler: Optional[HippocampusSampler] = None
|
_hippocampus_sampler: HippocampusSampler | None = None
|
||||||
|
|
||||||
|
|
||||||
def get_hippocampus_sampler(memory_system=None) -> HippocampusSampler:
|
def get_hippocampus_sampler(memory_system=None) -> HippocampusSampler:
|
||||||
@@ -728,4 +726,4 @@ async def initialize_hippocampus_sampler(memory_system=None) -> HippocampusSampl
|
|||||||
"""初始化全局海马体采样器"""
|
"""初始化全局海马体采样器"""
|
||||||
sampler = get_hippocampus_sampler(memory_system)
|
sampler = get_hippocampus_sampler(memory_system)
|
||||||
await sampler.initialize()
|
await sampler.initialize()
|
||||||
return sampler
|
return sampler
|
||||||
|
|||||||
@@ -32,7 +32,7 @@ import time
|
|||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Any, Type, TypeVar
|
from typing import Any, TypeVar
|
||||||
|
|
||||||
E = TypeVar("E", bound=Enum)
|
E = TypeVar("E", bound=Enum)
|
||||||
|
|
||||||
@@ -503,7 +503,7 @@ class MemoryBuilder:
|
|||||||
logger.warning(f"无法解析未知的记忆类型 '{type_str}',回退到上下文类型")
|
logger.warning(f"无法解析未知的记忆类型 '{type_str}',回退到上下文类型")
|
||||||
return MemoryType.CONTEXTUAL
|
return MemoryType.CONTEXTUAL
|
||||||
|
|
||||||
def _parse_enum_value(self, enum_cls: Type[E], raw_value: Any, default: E, field_name: str) -> E:
|
def _parse_enum_value(self, enum_cls: type[E], raw_value: Any, default: E, field_name: str) -> E:
|
||||||
"""解析枚举值,兼容数字/字符串表示"""
|
"""解析枚举值,兼容数字/字符串表示"""
|
||||||
if isinstance(raw_value, enum_cls):
|
if isinstance(raw_value, enum_cls):
|
||||||
return raw_value
|
return raw_value
|
||||||
|
|||||||
@@ -556,11 +556,11 @@ class MemorySystem:
|
|||||||
context = dict(context or {})
|
context = dict(context or {})
|
||||||
|
|
||||||
# 获取配置的采样模式
|
# 获取配置的采样模式
|
||||||
sampling_mode = getattr(global_config.memory, 'memory_sampling_mode', 'precision')
|
sampling_mode = getattr(global_config.memory, "memory_sampling_mode", "precision")
|
||||||
current_mode = MemorySamplingMode(sampling_mode)
|
current_mode = MemorySamplingMode(sampling_mode)
|
||||||
|
|
||||||
|
|
||||||
context['__sampling_mode'] = current_mode.value
|
context["__sampling_mode"] = current_mode.value
|
||||||
logger.debug(f"使用记忆采样模式: {current_mode.value}")
|
logger.debug(f"使用记忆采样模式: {current_mode.value}")
|
||||||
|
|
||||||
# 根据采样模式处理记忆
|
# 根据采样模式处理记忆
|
||||||
@@ -636,7 +636,7 @@ class MemorySystem:
|
|||||||
|
|
||||||
# 检查信息价值阈值
|
# 检查信息价值阈值
|
||||||
value_score = await self._assess_information_value(conversation_text, normalized_context)
|
value_score = await self._assess_information_value(conversation_text, normalized_context)
|
||||||
threshold = getattr(global_config.memory, 'precision_memory_reply_threshold', 0.5)
|
threshold = getattr(global_config.memory, "precision_memory_reply_threshold", 0.5)
|
||||||
|
|
||||||
if value_score < threshold:
|
if value_score < threshold:
|
||||||
logger.debug(f"信息价值评分 {value_score:.2f} 低于阈值 {threshold},跳过记忆构建")
|
logger.debug(f"信息价值评分 {value_score:.2f} 低于阈值 {threshold},跳过记忆构建")
|
||||||
@@ -1614,8 +1614,8 @@ async def initialize_memory_system(llm_model: LLMRequest | None = None):
|
|||||||
await memory_system.initialize()
|
await memory_system.initialize()
|
||||||
|
|
||||||
# 根据配置启动海马体采样
|
# 根据配置启动海马体采样
|
||||||
sampling_mode = getattr(global_config.memory, 'memory_sampling_mode', 'immediate')
|
sampling_mode = getattr(global_config.memory, "memory_sampling_mode", "immediate")
|
||||||
if sampling_mode in ['hippocampus', 'all']:
|
if sampling_mode in ["hippocampus", "all"]:
|
||||||
memory_system.start_hippocampus_sampling()
|
memory_system.start_hippocampus_sampling()
|
||||||
|
|
||||||
return memory_system
|
return memory_system
|
||||||
|
|||||||
@@ -4,14 +4,13 @@
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import psutil
|
|
||||||
import time
|
import time
|
||||||
from typing import Dict, List, Optional, Set, Tuple
|
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
|
|
||||||
|
import psutil
|
||||||
|
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.chat.message_receive.chat_stream import ChatStream
|
|
||||||
|
|
||||||
logger = get_logger("adaptive_stream_manager")
|
logger = get_logger("adaptive_stream_manager")
|
||||||
|
|
||||||
@@ -71,16 +70,16 @@ class AdaptiveStreamManager:
|
|||||||
|
|
||||||
# 当前状态
|
# 当前状态
|
||||||
self.current_limit = base_concurrent_limit
|
self.current_limit = base_concurrent_limit
|
||||||
self.active_streams: Set[str] = set()
|
self.active_streams: set[str] = set()
|
||||||
self.pending_streams: Set[str] = set()
|
self.pending_streams: set[str] = set()
|
||||||
self.stream_metrics: Dict[str, StreamMetrics] = {}
|
self.stream_metrics: dict[str, StreamMetrics] = {}
|
||||||
|
|
||||||
# 异步信号量
|
# 异步信号量
|
||||||
self.semaphore = asyncio.Semaphore(base_concurrent_limit)
|
self.semaphore = asyncio.Semaphore(base_concurrent_limit)
|
||||||
self.priority_semaphore = asyncio.Semaphore(5) # 高优先级专用信号量
|
self.priority_semaphore = asyncio.Semaphore(5) # 高优先级专用信号量
|
||||||
|
|
||||||
# 系统监控
|
# 系统监控
|
||||||
self.system_metrics: List[SystemMetrics] = []
|
self.system_metrics: list[SystemMetrics] = []
|
||||||
self.last_adjustment_time = 0.0
|
self.last_adjustment_time = 0.0
|
||||||
|
|
||||||
# 统计信息
|
# 统计信息
|
||||||
@@ -95,8 +94,8 @@ class AdaptiveStreamManager:
|
|||||||
}
|
}
|
||||||
|
|
||||||
# 监控任务
|
# 监控任务
|
||||||
self.monitor_task: Optional[asyncio.Task] = None
|
self.monitor_task: asyncio.Task | None = None
|
||||||
self.adjustment_task: Optional[asyncio.Task] = None
|
self.adjustment_task: asyncio.Task | None = None
|
||||||
self.is_running = False
|
self.is_running = False
|
||||||
|
|
||||||
logger.info(f"自适应流管理器初始化完成 (base_limit={base_concurrent_limit}, max_limit={max_concurrent_limit})")
|
logger.info(f"自适应流管理器初始化完成 (base_limit={base_concurrent_limit}, max_limit={max_concurrent_limit})")
|
||||||
@@ -443,7 +442,7 @@ class AdaptiveStreamManager:
|
|||||||
if hasattr(metrics, key):
|
if hasattr(metrics, key):
|
||||||
setattr(metrics, key, value)
|
setattr(metrics, key, value)
|
||||||
|
|
||||||
def get_stats(self) -> Dict:
|
def get_stats(self) -> dict:
|
||||||
"""获取统计信息"""
|
"""获取统计信息"""
|
||||||
stats = self.stats.copy()
|
stats = self.stats.copy()
|
||||||
stats.update({
|
stats.update({
|
||||||
@@ -465,7 +464,7 @@ class AdaptiveStreamManager:
|
|||||||
|
|
||||||
|
|
||||||
# 全局自适应管理器实例
|
# 全局自适应管理器实例
|
||||||
_adaptive_manager: Optional[AdaptiveStreamManager] = None
|
_adaptive_manager: AdaptiveStreamManager | None = None
|
||||||
|
|
||||||
|
|
||||||
def get_adaptive_stream_manager() -> AdaptiveStreamManager:
|
def get_adaptive_stream_manager() -> AdaptiveStreamManager:
|
||||||
@@ -485,4 +484,4 @@ async def init_adaptive_stream_manager():
|
|||||||
async def shutdown_adaptive_stream_manager():
|
async def shutdown_adaptive_stream_manager():
|
||||||
"""关闭自适应流管理器"""
|
"""关闭自适应流管理器"""
|
||||||
manager = get_adaptive_stream_manager()
|
manager = get_adaptive_stream_manager()
|
||||||
await manager.stop()
|
await manager.stop()
|
||||||
|
|||||||
@@ -5,9 +5,9 @@
|
|||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import time
|
import time
|
||||||
from typing import Any, Dict, List, Optional
|
|
||||||
from dataclasses import dataclass, field
|
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
from src.common.database.sqlalchemy_database_api import get_db_session
|
from src.common.database.sqlalchemy_database_api import get_db_session
|
||||||
from src.common.database.sqlalchemy_models import ChatStreams
|
from src.common.database.sqlalchemy_models import ChatStreams
|
||||||
@@ -21,7 +21,7 @@ logger = get_logger("batch_database_writer")
|
|||||||
class StreamUpdatePayload:
|
class StreamUpdatePayload:
|
||||||
"""流更新数据结构"""
|
"""流更新数据结构"""
|
||||||
stream_id: str
|
stream_id: str
|
||||||
update_data: Dict[str, Any]
|
update_data: dict[str, Any]
|
||||||
priority: int = 0 # 优先级,数字越大优先级越高
|
priority: int = 0 # 优先级,数字越大优先级越高
|
||||||
timestamp: float = field(default_factory=time.time)
|
timestamp: float = field(default_factory=time.time)
|
||||||
|
|
||||||
@@ -47,7 +47,7 @@ class BatchDatabaseWriter:
|
|||||||
|
|
||||||
# 运行状态
|
# 运行状态
|
||||||
self.is_running = False
|
self.is_running = False
|
||||||
self.writer_task: Optional[asyncio.Task] = None
|
self.writer_task: asyncio.Task | None = None
|
||||||
|
|
||||||
# 统计信息
|
# 统计信息
|
||||||
self.stats = {
|
self.stats = {
|
||||||
@@ -60,7 +60,7 @@ class BatchDatabaseWriter:
|
|||||||
}
|
}
|
||||||
|
|
||||||
# 按优先级分类的批次
|
# 按优先级分类的批次
|
||||||
self.priority_batches: Dict[int, List[StreamUpdatePayload]] = defaultdict(list)
|
self.priority_batches: dict[int, list[StreamUpdatePayload]] = defaultdict(list)
|
||||||
|
|
||||||
logger.info(f"批量数据库写入器初始化完成 (batch_size={batch_size}, interval={flush_interval}s)")
|
logger.info(f"批量数据库写入器初始化完成 (batch_size={batch_size}, interval={flush_interval}s)")
|
||||||
|
|
||||||
@@ -98,7 +98,7 @@ class BatchDatabaseWriter:
|
|||||||
async def schedule_stream_update(
|
async def schedule_stream_update(
|
||||||
self,
|
self,
|
||||||
stream_id: str,
|
stream_id: str,
|
||||||
update_data: Dict[str, Any],
|
update_data: dict[str, Any],
|
||||||
priority: int = 0
|
priority: int = 0
|
||||||
) -> bool:
|
) -> bool:
|
||||||
"""
|
"""
|
||||||
@@ -166,7 +166,7 @@ class BatchDatabaseWriter:
|
|||||||
await self._flush_all_batches()
|
await self._flush_all_batches()
|
||||||
logger.info("批量写入循环结束")
|
logger.info("批量写入循环结束")
|
||||||
|
|
||||||
async def _collect_batch(self) -> List[StreamUpdatePayload]:
|
async def _collect_batch(self) -> list[StreamUpdatePayload]:
|
||||||
"""收集一个批次的数据"""
|
"""收集一个批次的数据"""
|
||||||
batch = []
|
batch = []
|
||||||
deadline = time.time() + self.flush_interval
|
deadline = time.time() + self.flush_interval
|
||||||
@@ -189,7 +189,7 @@ class BatchDatabaseWriter:
|
|||||||
|
|
||||||
return batch
|
return batch
|
||||||
|
|
||||||
async def _write_batch(self, batch: List[StreamUpdatePayload]):
|
async def _write_batch(self, batch: list[StreamUpdatePayload]):
|
||||||
"""批量写入数据库"""
|
"""批量写入数据库"""
|
||||||
if not batch:
|
if not batch:
|
||||||
return
|
return
|
||||||
@@ -228,7 +228,7 @@ class BatchDatabaseWriter:
|
|||||||
except Exception as single_e:
|
except Exception as single_e:
|
||||||
logger.error(f"单个写入也失败: {single_e}")
|
logger.error(f"单个写入也失败: {single_e}")
|
||||||
|
|
||||||
async def _batch_write_to_database(self, payloads: List[StreamUpdatePayload]):
|
async def _batch_write_to_database(self, payloads: list[StreamUpdatePayload]):
|
||||||
"""批量写入数据库"""
|
"""批量写入数据库"""
|
||||||
async with get_db_session() as session:
|
async with get_db_session() as session:
|
||||||
for payload in payloads:
|
for payload in payloads:
|
||||||
@@ -268,7 +268,7 @@ class BatchDatabaseWriter:
|
|||||||
|
|
||||||
await session.commit()
|
await session.commit()
|
||||||
|
|
||||||
async def _direct_write(self, stream_id: str, update_data: Dict[str, Any]):
|
async def _direct_write(self, stream_id: str, update_data: dict[str, Any]):
|
||||||
"""直接写入数据库(降级方案)"""
|
"""直接写入数据库(降级方案)"""
|
||||||
async with get_db_session() as session:
|
async with get_db_session() as session:
|
||||||
if global_config.database.database_type == "sqlite":
|
if global_config.database.database_type == "sqlite":
|
||||||
@@ -315,7 +315,7 @@ class BatchDatabaseWriter:
|
|||||||
if remaining_batch:
|
if remaining_batch:
|
||||||
await self._write_batch(remaining_batch)
|
await self._write_batch(remaining_batch)
|
||||||
|
|
||||||
def get_stats(self) -> Dict[str, Any]:
|
def get_stats(self) -> dict[str, Any]:
|
||||||
"""获取统计信息"""
|
"""获取统计信息"""
|
||||||
stats = self.stats.copy()
|
stats = self.stats.copy()
|
||||||
stats["is_running"] = self.is_running
|
stats["is_running"] = self.is_running
|
||||||
@@ -324,7 +324,7 @@ class BatchDatabaseWriter:
|
|||||||
|
|
||||||
|
|
||||||
# 全局批量写入器实例
|
# 全局批量写入器实例
|
||||||
_batch_writer: Optional[BatchDatabaseWriter] = None
|
_batch_writer: BatchDatabaseWriter | None = None
|
||||||
|
|
||||||
|
|
||||||
def get_batch_writer() -> BatchDatabaseWriter:
|
def get_batch_writer() -> BatchDatabaseWriter:
|
||||||
@@ -344,4 +344,4 @@ async def init_batch_writer():
|
|||||||
async def shutdown_batch_writer():
|
async def shutdown_batch_writer():
|
||||||
"""关闭批量写入器"""
|
"""关闭批量写入器"""
|
||||||
writer = get_batch_writer()
|
writer = get_batch_writer()
|
||||||
await writer.stop()
|
await writer.stop()
|
||||||
|
|||||||
@@ -117,7 +117,7 @@ class StreamLoopManager:
|
|||||||
# 使用自适应流管理器获取槽位
|
# 使用自适应流管理器获取槽位
|
||||||
use_adaptive = False
|
use_adaptive = False
|
||||||
try:
|
try:
|
||||||
from src.chat.message_manager.adaptive_stream_manager import get_adaptive_stream_manager, StreamPriority
|
from src.chat.message_manager.adaptive_stream_manager import get_adaptive_stream_manager
|
||||||
adaptive_manager = get_adaptive_stream_manager()
|
adaptive_manager = get_adaptive_stream_manager()
|
||||||
|
|
||||||
if adaptive_manager.is_running:
|
if adaptive_manager.is_running:
|
||||||
@@ -137,7 +137,7 @@ class StreamLoopManager:
|
|||||||
else:
|
else:
|
||||||
logger.debug(f"自适应管理器拒绝槽位请求: {stream_id},尝试回退方案")
|
logger.debug(f"自适应管理器拒绝槽位请求: {stream_id},尝试回退方案")
|
||||||
else:
|
else:
|
||||||
logger.debug(f"自适应管理器未运行,使用原始方法")
|
logger.debug("自适应管理器未运行,使用原始方法")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.debug(f"自适应管理器获取槽位失败,使用原始方法: {e}")
|
logger.debug(f"自适应管理器获取槽位失败,使用原始方法: {e}")
|
||||||
|
|||||||
@@ -5,13 +5,13 @@
|
|||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import time
|
import time
|
||||||
from typing import Dict, List, Optional, Set
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
from maim_message import GroupInfo, UserInfo
|
from maim_message import GroupInfo, UserInfo
|
||||||
from src.common.logger import get_logger
|
|
||||||
from src.chat.message_receive.optimized_chat_stream import OptimizedChatStream, create_optimized_chat_stream
|
from src.chat.message_receive.optimized_chat_stream import OptimizedChatStream, create_optimized_chat_stream
|
||||||
|
from src.common.logger import get_logger
|
||||||
|
|
||||||
logger = get_logger("stream_cache_manager")
|
logger = get_logger("stream_cache_manager")
|
||||||
|
|
||||||
@@ -52,14 +52,14 @@ class TieredStreamCache:
|
|||||||
|
|
||||||
# 三层缓存存储
|
# 三层缓存存储
|
||||||
self.hot_cache: OrderedDict[str, OptimizedChatStream] = OrderedDict() # 热数据(LRU)
|
self.hot_cache: OrderedDict[str, OptimizedChatStream] = OrderedDict() # 热数据(LRU)
|
||||||
self.warm_storage: Dict[str, tuple[OptimizedChatStream, float]] = {} # 温数据(最后访问时间)
|
self.warm_storage: dict[str, tuple[OptimizedChatStream, float]] = {} # 温数据(最后访问时间)
|
||||||
self.cold_storage: Dict[str, tuple[OptimizedChatStream, float]] = {} # 冷数据(最后访问时间)
|
self.cold_storage: dict[str, tuple[OptimizedChatStream, float]] = {} # 冷数据(最后访问时间)
|
||||||
|
|
||||||
# 统计信息
|
# 统计信息
|
||||||
self.stats = StreamCacheStats()
|
self.stats = StreamCacheStats()
|
||||||
|
|
||||||
# 清理任务
|
# 清理任务
|
||||||
self.cleanup_task: Optional[asyncio.Task] = None
|
self.cleanup_task: asyncio.Task | None = None
|
||||||
self.is_running = False
|
self.is_running = False
|
||||||
|
|
||||||
logger.info(f"分层流缓存管理器初始化完成 (hot:{max_hot_size}, warm:{max_warm_size}, cold:{max_cold_size})")
|
logger.info(f"分层流缓存管理器初始化完成 (hot:{max_hot_size}, warm:{max_warm_size}, cold:{max_cold_size})")
|
||||||
@@ -96,8 +96,8 @@ class TieredStreamCache:
|
|||||||
stream_id: str,
|
stream_id: str,
|
||||||
platform: str,
|
platform: str,
|
||||||
user_info: UserInfo,
|
user_info: UserInfo,
|
||||||
group_info: Optional[GroupInfo] = None,
|
group_info: GroupInfo | None = None,
|
||||||
data: Optional[Dict] = None,
|
data: dict | None = None,
|
||||||
) -> OptimizedChatStream:
|
) -> OptimizedChatStream:
|
||||||
"""获取或创建流 - 优化版本"""
|
"""获取或创建流 - 优化版本"""
|
||||||
current_time = time.time()
|
current_time = time.time()
|
||||||
@@ -255,7 +255,7 @@ class TieredStreamCache:
|
|||||||
hot_to_demote = []
|
hot_to_demote = []
|
||||||
for stream_id, stream in self.hot_cache.items():
|
for stream_id, stream in self.hot_cache.items():
|
||||||
# 获取最后访问时间(简化:使用创建时间作为近似)
|
# 获取最后访问时间(简化:使用创建时间作为近似)
|
||||||
last_access = getattr(stream, 'last_active_time', stream.create_time)
|
last_access = getattr(stream, "last_active_time", stream.create_time)
|
||||||
if current_time - last_access > self.hot_timeout:
|
if current_time - last_access > self.hot_timeout:
|
||||||
hot_to_demote.append(stream_id)
|
hot_to_demote.append(stream_id)
|
||||||
|
|
||||||
@@ -341,7 +341,7 @@ class TieredStreamCache:
|
|||||||
|
|
||||||
logger.info("所有缓存已清空")
|
logger.info("所有缓存已清空")
|
||||||
|
|
||||||
async def get_stream_snapshot(self, stream_id: str) -> Optional[OptimizedChatStream]:
|
async def get_stream_snapshot(self, stream_id: str) -> OptimizedChatStream | None:
|
||||||
"""获取流的快照(不修改缓存状态)"""
|
"""获取流的快照(不修改缓存状态)"""
|
||||||
if stream_id in self.hot_cache:
|
if stream_id in self.hot_cache:
|
||||||
return self.hot_cache[stream_id].create_snapshot()
|
return self.hot_cache[stream_id].create_snapshot()
|
||||||
@@ -351,13 +351,13 @@ class TieredStreamCache:
|
|||||||
return self.cold_storage[stream_id][0].create_snapshot()
|
return self.cold_storage[stream_id][0].create_snapshot()
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def get_cached_stream_ids(self) -> Set[str]:
|
def get_cached_stream_ids(self) -> set[str]:
|
||||||
"""获取所有缓存的流ID"""
|
"""获取所有缓存的流ID"""
|
||||||
return set(self.hot_cache.keys()) | set(self.warm_storage.keys()) | set(self.cold_storage.keys())
|
return set(self.hot_cache.keys()) | set(self.warm_storage.keys()) | set(self.cold_storage.keys())
|
||||||
|
|
||||||
|
|
||||||
# 全局缓存管理器实例
|
# 全局缓存管理器实例
|
||||||
_cache_manager: Optional[TieredStreamCache] = None
|
_cache_manager: TieredStreamCache | None = None
|
||||||
|
|
||||||
|
|
||||||
def get_stream_cache_manager() -> TieredStreamCache:
|
def get_stream_cache_manager() -> TieredStreamCache:
|
||||||
@@ -377,4 +377,4 @@ async def init_stream_cache_manager():
|
|||||||
async def shutdown_stream_cache_manager():
|
async def shutdown_stream_cache_manager():
|
||||||
"""关闭流缓存管理器"""
|
"""关闭流缓存管理器"""
|
||||||
manager = get_stream_cache_manager()
|
manager = get_stream_cache_manager()
|
||||||
await manager.stop()
|
await manager.stop()
|
||||||
|
|||||||
@@ -313,11 +313,11 @@ class ChatStream:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"计算消息兴趣值失败: {e}", exc_info=True)
|
logger.error(f"计算消息兴趣值失败: {e}", exc_info=True)
|
||||||
# 异常情况下使用默认值
|
# 异常情况下使用默认值
|
||||||
if hasattr(db_message, 'interest_value'):
|
if hasattr(db_message, "interest_value"):
|
||||||
db_message.interest_value = 0.3
|
db_message.interest_value = 0.3
|
||||||
if hasattr(db_message, 'should_reply'):
|
if hasattr(db_message, "should_reply"):
|
||||||
db_message.should_reply = False
|
db_message.should_reply = False
|
||||||
if hasattr(db_message, 'should_act'):
|
if hasattr(db_message, "should_act"):
|
||||||
db_message.should_act = False
|
db_message.should_act = False
|
||||||
|
|
||||||
def _extract_reply_from_segment(self, segment) -> str | None:
|
def _extract_reply_from_segment(self, segment) -> str | None:
|
||||||
@@ -894,10 +894,10 @@ def _convert_to_original_stream(self, optimized_stream) -> "ChatStream":
|
|||||||
original_stream.saved = optimized_stream.saved
|
original_stream.saved = optimized_stream.saved
|
||||||
|
|
||||||
# 复制上下文信息(如果存在)
|
# 复制上下文信息(如果存在)
|
||||||
if hasattr(optimized_stream, '_stream_context') and optimized_stream._stream_context:
|
if hasattr(optimized_stream, "_stream_context") and optimized_stream._stream_context:
|
||||||
original_stream.stream_context = optimized_stream._stream_context
|
original_stream.stream_context = optimized_stream._stream_context
|
||||||
|
|
||||||
if hasattr(optimized_stream, '_context_manager') and optimized_stream._context_manager:
|
if hasattr(optimized_stream, "_context_manager") and optimized_stream._context_manager:
|
||||||
original_stream.context_manager = optimized_stream._context_manager
|
original_stream.context_manager = optimized_stream._context_manager
|
||||||
|
|
||||||
return original_stream
|
return original_stream
|
||||||
|
|||||||
@@ -3,17 +3,12 @@
|
|||||||
避免不必要的深拷贝开销,提升多流并发性能
|
避免不必要的深拷贝开销,提升多流并发性能
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import copy
|
|
||||||
import hashlib
|
|
||||||
import time
|
import time
|
||||||
from typing import TYPE_CHECKING, Any, Dict, Optional
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
from maim_message import GroupInfo, UserInfo
|
from maim_message import GroupInfo, UserInfo
|
||||||
from rich.traceback import install
|
from rich.traceback import install
|
||||||
|
|
||||||
from src.common.database.sqlalchemy_database_api import get_db_session
|
|
||||||
from src.common.database.sqlalchemy_models import ChatStreams
|
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.config.config import global_config
|
from src.config.config import global_config
|
||||||
|
|
||||||
@@ -28,7 +23,7 @@ logger = get_logger("optimized_chat_stream")
|
|||||||
class SharedContext:
|
class SharedContext:
|
||||||
"""共享上下文数据 - 只读数据结构"""
|
"""共享上下文数据 - 只读数据结构"""
|
||||||
|
|
||||||
def __init__(self, stream_id: str, platform: str, user_info: UserInfo, group_info: Optional[GroupInfo] = None):
|
def __init__(self, stream_id: str, platform: str, user_info: UserInfo, group_info: GroupInfo | None = None):
|
||||||
self.stream_id = stream_id
|
self.stream_id = stream_id
|
||||||
self.platform = platform
|
self.platform = platform
|
||||||
self.user_info = user_info
|
self.user_info = user_info
|
||||||
@@ -37,7 +32,7 @@ class SharedContext:
|
|||||||
self._frozen = True
|
self._frozen = True
|
||||||
|
|
||||||
def __setattr__(self, name, value):
|
def __setattr__(self, name, value):
|
||||||
if hasattr(self, '_frozen') and self._frozen and name not in ['_frozen']:
|
if hasattr(self, "_frozen") and self._frozen and name not in ["_frozen"]:
|
||||||
raise AttributeError(f"SharedContext is frozen, cannot modify {name}")
|
raise AttributeError(f"SharedContext is frozen, cannot modify {name}")
|
||||||
super().__setattr__(name, value)
|
super().__setattr__(name, value)
|
||||||
|
|
||||||
@@ -46,7 +41,7 @@ class LocalChanges:
|
|||||||
"""本地修改跟踪器"""
|
"""本地修改跟踪器"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self._changes: Dict[str, Any] = {}
|
self._changes: dict[str, Any] = {}
|
||||||
self._dirty = False
|
self._dirty = False
|
||||||
|
|
||||||
def set_change(self, key: str, value: Any):
|
def set_change(self, key: str, value: Any):
|
||||||
@@ -62,7 +57,7 @@ class LocalChanges:
|
|||||||
"""是否有修改"""
|
"""是否有修改"""
|
||||||
return self._dirty
|
return self._dirty
|
||||||
|
|
||||||
def get_changes(self) -> Dict[str, Any]:
|
def get_changes(self) -> dict[str, Any]:
|
||||||
"""获取所有修改"""
|
"""获取所有修改"""
|
||||||
return self._changes.copy()
|
return self._changes.copy()
|
||||||
|
|
||||||
@@ -80,8 +75,8 @@ class OptimizedChatStream:
|
|||||||
stream_id: str,
|
stream_id: str,
|
||||||
platform: str,
|
platform: str,
|
||||||
user_info: UserInfo,
|
user_info: UserInfo,
|
||||||
group_info: Optional[GroupInfo] = None,
|
group_info: GroupInfo | None = None,
|
||||||
data: Optional[Dict] = None,
|
data: dict | None = None,
|
||||||
):
|
):
|
||||||
# 共享的只读数据
|
# 共享的只读数据
|
||||||
self._shared_context = SharedContext(
|
self._shared_context = SharedContext(
|
||||||
@@ -129,42 +124,42 @@ class OptimizedChatStream:
|
|||||||
"""修改用户信息时触发写时复制"""
|
"""修改用户信息时触发写时复制"""
|
||||||
self._ensure_copy_on_write()
|
self._ensure_copy_on_write()
|
||||||
# 由于SharedContext是frozen的,我们需要在本地修改中记录
|
# 由于SharedContext是frozen的,我们需要在本地修改中记录
|
||||||
self._local_changes.set_change('user_info', value)
|
self._local_changes.set_change("user_info", value)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def group_info(self) -> Optional[GroupInfo]:
|
def group_info(self) -> GroupInfo | None:
|
||||||
if self._local_changes.has_changes() and 'group_info' in self._local_changes._changes:
|
if self._local_changes.has_changes() and "group_info" in self._local_changes._changes:
|
||||||
return self._local_changes.get_change('group_info')
|
return self._local_changes.get_change("group_info")
|
||||||
return self._shared_context.group_info
|
return self._shared_context.group_info
|
||||||
|
|
||||||
@group_info.setter
|
@group_info.setter
|
||||||
def group_info(self, value: Optional[GroupInfo]):
|
def group_info(self, value: GroupInfo | None):
|
||||||
"""修改群组信息时触发写时复制"""
|
"""修改群组信息时触发写时复制"""
|
||||||
self._ensure_copy_on_write()
|
self._ensure_copy_on_write()
|
||||||
self._local_changes.set_change('group_info', value)
|
self._local_changes.set_change("group_info", value)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def create_time(self) -> float:
|
def create_time(self) -> float:
|
||||||
if self._local_changes.has_changes() and 'create_time' in self._local_changes._changes:
|
if self._local_changes.has_changes() and "create_time" in self._local_changes._changes:
|
||||||
return self._local_changes.get_change('create_time')
|
return self._local_changes.get_change("create_time")
|
||||||
return self._shared_context.create_time
|
return self._shared_context.create_time
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def last_active_time(self) -> float:
|
def last_active_time(self) -> float:
|
||||||
return self._local_changes.get_change('last_active_time', self.create_time)
|
return self._local_changes.get_change("last_active_time", self.create_time)
|
||||||
|
|
||||||
@last_active_time.setter
|
@last_active_time.setter
|
||||||
def last_active_time(self, value: float):
|
def last_active_time(self, value: float):
|
||||||
self._local_changes.set_change('last_active_time', value)
|
self._local_changes.set_change("last_active_time", value)
|
||||||
self.saved = False
|
self.saved = False
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def sleep_pressure(self) -> float:
|
def sleep_pressure(self) -> float:
|
||||||
return self._local_changes.get_change('sleep_pressure', 0.0)
|
return self._local_changes.get_change("sleep_pressure", 0.0)
|
||||||
|
|
||||||
@sleep_pressure.setter
|
@sleep_pressure.setter
|
||||||
def sleep_pressure(self, value: float):
|
def sleep_pressure(self, value: float):
|
||||||
self._local_changes.set_change('sleep_pressure', value)
|
self._local_changes.set_change("sleep_pressure", value)
|
||||||
self.saved = False
|
self.saved = False
|
||||||
|
|
||||||
def _ensure_copy_on_write(self):
|
def _ensure_copy_on_write(self):
|
||||||
@@ -176,14 +171,14 @@ class OptimizedChatStream:
|
|||||||
|
|
||||||
def _get_effective_user_info(self) -> UserInfo:
|
def _get_effective_user_info(self) -> UserInfo:
|
||||||
"""获取有效的用户信息"""
|
"""获取有效的用户信息"""
|
||||||
if self._local_changes.has_changes() and 'user_info' in self._local_changes._changes:
|
if self._local_changes.has_changes() and "user_info" in self._local_changes._changes:
|
||||||
return self._local_changes.get_change('user_info')
|
return self._local_changes.get_change("user_info")
|
||||||
return self._shared_context.user_info
|
return self._shared_context.user_info
|
||||||
|
|
||||||
def _get_effective_group_info(self) -> Optional[GroupInfo]:
|
def _get_effective_group_info(self) -> GroupInfo | None:
|
||||||
"""获取有效的群组信息"""
|
"""获取有效的群组信息"""
|
||||||
if self._local_changes.has_changes() and 'group_info' in self._local_changes._changes:
|
if self._local_changes.has_changes() and "group_info" in self._local_changes._changes:
|
||||||
return self._local_changes.get_change('group_info')
|
return self._local_changes.get_change("group_info")
|
||||||
return self._shared_context.group_info
|
return self._shared_context.group_info
|
||||||
|
|
||||||
def update_active_time(self):
|
def update_active_time(self):
|
||||||
@@ -199,6 +194,7 @@ class OptimizedChatStream:
|
|||||||
|
|
||||||
# 将MessageRecv转换为DatabaseMessages并设置到stream_context
|
# 将MessageRecv转换为DatabaseMessages并设置到stream_context
|
||||||
import json
|
import json
|
||||||
|
|
||||||
from src.common.data_models.database_data_model import DatabaseMessages
|
from src.common.data_models.database_data_model import DatabaseMessages
|
||||||
|
|
||||||
message_info = getattr(message, "message_info", {})
|
message_info = getattr(message, "message_info", {})
|
||||||
@@ -298,7 +294,7 @@ class OptimizedChatStream:
|
|||||||
self._create_stream_context()
|
self._create_stream_context()
|
||||||
return self._context_manager
|
return self._context_manager
|
||||||
|
|
||||||
def to_dict(self) -> Dict[str, Any]:
|
def to_dict(self) -> dict[str, Any]:
|
||||||
"""转换为字典格式 - 考虑本地修改"""
|
"""转换为字典格式 - 考虑本地修改"""
|
||||||
user_info = self._get_effective_user_info()
|
user_info = self._get_effective_user_info()
|
||||||
group_info = self._get_effective_group_info()
|
group_info = self._get_effective_group_info()
|
||||||
@@ -319,7 +315,7 @@ class OptimizedChatStream:
|
|||||||
}
|
}
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_dict(cls, data: Dict) -> "OptimizedChatStream":
|
def from_dict(cls, data: dict) -> "OptimizedChatStream":
|
||||||
"""从字典创建实例"""
|
"""从字典创建实例"""
|
||||||
user_info = UserInfo.from_dict(data.get("user_info", {})) if data.get("user_info") else None
|
user_info = UserInfo.from_dict(data.get("user_info", {})) if data.get("user_info") else None
|
||||||
group_info = GroupInfo.from_dict(data.get("group_info", {})) if data.get("group_info") else None
|
group_info = GroupInfo.from_dict(data.get("group_info", {})) if data.get("group_info") else None
|
||||||
@@ -481,8 +477,8 @@ def create_optimized_chat_stream(
|
|||||||
stream_id: str,
|
stream_id: str,
|
||||||
platform: str,
|
platform: str,
|
||||||
user_info: UserInfo,
|
user_info: UserInfo,
|
||||||
group_info: Optional[GroupInfo] = None,
|
group_info: GroupInfo | None = None,
|
||||||
data: Optional[Dict] = None,
|
data: dict | None = None,
|
||||||
) -> OptimizedChatStream:
|
) -> OptimizedChatStream:
|
||||||
"""创建优化版聊天流实例"""
|
"""创建优化版聊天流实例"""
|
||||||
return OptimizedChatStream(
|
return OptimizedChatStream(
|
||||||
@@ -491,4 +487,4 @@ def create_optimized_chat_stream(
|
|||||||
user_info=user_info,
|
user_info=user_info,
|
||||||
group_info=group_info,
|
group_info=group_info,
|
||||||
data=data
|
data=data
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -15,7 +15,7 @@ from src.plugin_system.base.component_types import ActionActivationType, ActionI
|
|||||||
from src.plugin_system.core.global_announcement_manager import global_announcement_manager
|
from src.plugin_system.core.global_announcement_manager import global_announcement_manager
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from src.chat.message_receive.chat_stream import ChatStream
|
pass
|
||||||
|
|
||||||
logger = get_logger("action_manager")
|
logger = get_logger("action_manager")
|
||||||
|
|
||||||
|
|||||||
@@ -536,7 +536,7 @@ class Prompt:
|
|||||||
style = expr.get("style", "")
|
style = expr.get("style", "")
|
||||||
if situation and style:
|
if situation and style:
|
||||||
formatted_expressions.append(f"- {situation}:{style}")
|
formatted_expressions.append(f"- {situation}:{style}")
|
||||||
|
|
||||||
if formatted_expressions:
|
if formatted_expressions:
|
||||||
style_habits_str = "\n".join(formatted_expressions)
|
style_habits_str = "\n".join(formatted_expressions)
|
||||||
expression_habits_block = f"你可以参考以下的语言习惯,当情景合适就使用,但不要生硬使用,以合理的方式结合到你的回复中:\n{style_habits_str}"
|
expression_habits_block = f"你可以参考以下的语言习惯,当情景合适就使用,但不要生硬使用,以合理的方式结合到你的回复中:\n{style_habits_str}"
|
||||||
|
|||||||
@@ -9,8 +9,8 @@ import time
|
|||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import rjieba
|
|
||||||
import orjson
|
import orjson
|
||||||
|
import rjieba
|
||||||
from pypinyin import Style, pinyin
|
from pypinyin import Style, pinyin
|
||||||
|
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
|
|||||||
@@ -6,8 +6,8 @@ import time
|
|||||||
from collections import Counter
|
from collections import Counter
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import rjieba
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import rjieba
|
||||||
from maim_message import UserInfo
|
from maim_message import UserInfo
|
||||||
|
|
||||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||||
|
|||||||
@@ -5,9 +5,8 @@
|
|||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import time
|
import time
|
||||||
import weakref
|
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
from typing import Any, Dict, Optional, Set
|
from typing import Any
|
||||||
|
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
||||||
|
|
||||||
@@ -69,7 +68,7 @@ class ConnectionPoolManager:
|
|||||||
self.max_idle = max_idle
|
self.max_idle = max_idle
|
||||||
|
|
||||||
# 连接池
|
# 连接池
|
||||||
self._connections: Set[ConnectionInfo] = set()
|
self._connections: set[ConnectionInfo] = set()
|
||||||
self._lock = asyncio.Lock()
|
self._lock = asyncio.Lock()
|
||||||
|
|
||||||
# 统计信息
|
# 统计信息
|
||||||
@@ -83,7 +82,7 @@ class ConnectionPoolManager:
|
|||||||
}
|
}
|
||||||
|
|
||||||
# 后台清理任务
|
# 后台清理任务
|
||||||
self._cleanup_task: Optional[asyncio.Task] = None
|
self._cleanup_task: asyncio.Task | None = None
|
||||||
self._should_cleanup = False
|
self._should_cleanup = False
|
||||||
|
|
||||||
logger.info(f"连接池管理器初始化完成 (最大池大小: {max_pool_size})")
|
logger.info(f"连接池管理器初始化完成 (最大池大小: {max_pool_size})")
|
||||||
@@ -144,7 +143,7 @@ class ConnectionPoolManager:
|
|||||||
|
|
||||||
yield connection_info.session
|
yield connection_info.session
|
||||||
|
|
||||||
except Exception as e:
|
except Exception:
|
||||||
# 发生错误时回滚连接
|
# 发生错误时回滚连接
|
||||||
if connection_info and connection_info.session:
|
if connection_info and connection_info.session:
|
||||||
try:
|
try:
|
||||||
@@ -157,7 +156,7 @@ class ConnectionPoolManager:
|
|||||||
if connection_info:
|
if connection_info:
|
||||||
connection_info.mark_released()
|
connection_info.mark_released()
|
||||||
|
|
||||||
async def _get_reusable_connection(self, session_factory: async_sessionmaker[AsyncSession]) -> Optional[ConnectionInfo]:
|
async def _get_reusable_connection(self, session_factory: async_sessionmaker[AsyncSession]) -> ConnectionInfo | None:
|
||||||
"""获取可复用的连接"""
|
"""获取可复用的连接"""
|
||||||
async with self._lock:
|
async with self._lock:
|
||||||
# 清理过期连接
|
# 清理过期连接
|
||||||
@@ -231,7 +230,7 @@ class ConnectionPoolManager:
|
|||||||
self._connections.clear()
|
self._connections.clear()
|
||||||
logger.info("所有连接已关闭")
|
logger.info("所有连接已关闭")
|
||||||
|
|
||||||
def get_stats(self) -> Dict[str, Any]:
|
def get_stats(self) -> dict[str, Any]:
|
||||||
"""获取连接池统计信息"""
|
"""获取连接池统计信息"""
|
||||||
return {
|
return {
|
||||||
**self._stats,
|
**self._stats,
|
||||||
@@ -244,7 +243,7 @@ class ConnectionPoolManager:
|
|||||||
|
|
||||||
|
|
||||||
# 全局连接池管理器实例
|
# 全局连接池管理器实例
|
||||||
_connection_pool_manager: Optional[ConnectionPoolManager] = None
|
_connection_pool_manager: ConnectionPoolManager | None = None
|
||||||
|
|
||||||
|
|
||||||
def get_connection_pool_manager() -> ConnectionPoolManager:
|
def get_connection_pool_manager() -> ConnectionPoolManager:
|
||||||
@@ -266,4 +265,4 @@ async def stop_connection_pool():
|
|||||||
global _connection_pool_manager
|
global _connection_pool_manager
|
||||||
if _connection_pool_manager:
|
if _connection_pool_manager:
|
||||||
await _connection_pool_manager.stop()
|
await _connection_pool_manager.stop()
|
||||||
_connection_pool_manager = None
|
_connection_pool_manager = None
|
||||||
|
|||||||
@@ -2,15 +2,16 @@ import os
|
|||||||
|
|
||||||
from rich.traceback import install
|
from rich.traceback import install
|
||||||
|
|
||||||
|
from src.common.database.connection_pool_manager import start_connection_pool, stop_connection_pool
|
||||||
|
|
||||||
|
# 数据库批量调度器和连接池
|
||||||
|
from src.common.database.db_batch_scheduler import get_db_batch_scheduler
|
||||||
|
|
||||||
# SQLAlchemy相关导入
|
# SQLAlchemy相关导入
|
||||||
from src.common.database.sqlalchemy_init import initialize_database_compat
|
from src.common.database.sqlalchemy_init import initialize_database_compat
|
||||||
from src.common.database.sqlalchemy_models import get_db_session, get_engine
|
from src.common.database.sqlalchemy_models import get_db_session, get_engine
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
|
|
||||||
# 数据库批量调度器和连接池
|
|
||||||
from src.common.database.db_batch_scheduler import get_db_batch_scheduler
|
|
||||||
from src.common.database.connection_pool_manager import start_connection_pool, stop_connection_pool
|
|
||||||
|
|
||||||
install(extra_lines=3)
|
install(extra_lines=3)
|
||||||
|
|
||||||
_sql_engine = None
|
_sql_engine = None
|
||||||
|
|||||||
@@ -6,19 +6,19 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import time
|
import time
|
||||||
from collections import defaultdict, deque
|
from collections import defaultdict, deque
|
||||||
from dataclasses import dataclass
|
from collections.abc import Callable
|
||||||
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, TypeVar
|
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Any, TypeVar
|
||||||
|
|
||||||
from sqlalchemy import select, delete, insert, update
|
from sqlalchemy import delete, insert, select, update
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
|
||||||
|
|
||||||
from src.common.database.sqlalchemy_database_api import get_db_session
|
from src.common.database.sqlalchemy_database_api import get_db_session
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
|
|
||||||
logger = get_logger("db_batch_scheduler")
|
logger = get_logger("db_batch_scheduler")
|
||||||
|
|
||||||
T = TypeVar('T')
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -26,10 +26,10 @@ class BatchOperation:
|
|||||||
"""批量操作基础类"""
|
"""批量操作基础类"""
|
||||||
operation_type: str # 'select', 'insert', 'update', 'delete'
|
operation_type: str # 'select', 'insert', 'update', 'delete'
|
||||||
model_class: Any
|
model_class: Any
|
||||||
conditions: Dict[str, Any]
|
conditions: dict[str, Any]
|
||||||
data: Optional[Dict[str, Any]] = None
|
data: dict[str, Any] | None = None
|
||||||
callback: Optional[Callable] = None
|
callback: Callable | None = None
|
||||||
future: Optional[asyncio.Future] = None
|
future: asyncio.Future | None = None
|
||||||
timestamp: float = 0.0
|
timestamp: float = 0.0
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
@@ -42,7 +42,7 @@ class BatchResult:
|
|||||||
"""批量操作结果"""
|
"""批量操作结果"""
|
||||||
success: bool
|
success: bool
|
||||||
data: Any = None
|
data: Any = None
|
||||||
error: Optional[str] = None
|
error: str | None = None
|
||||||
|
|
||||||
|
|
||||||
class DatabaseBatchScheduler:
|
class DatabaseBatchScheduler:
|
||||||
@@ -57,23 +57,23 @@ class DatabaseBatchScheduler:
|
|||||||
self.max_queue_size = max_queue_size
|
self.max_queue_size = max_queue_size
|
||||||
|
|
||||||
# 操作队列,按操作类型和模型分类
|
# 操作队列,按操作类型和模型分类
|
||||||
self.operation_queues: Dict[str, deque] = defaultdict(deque)
|
self.operation_queues: dict[str, deque] = defaultdict(deque)
|
||||||
|
|
||||||
# 调度控制
|
# 调度控制
|
||||||
self._scheduler_task: Optional[asyncio.Task] = None
|
self._scheduler_task: asyncio.Task | None = None
|
||||||
self._is_running = bool = False
|
self._is_running = bool = False
|
||||||
self._lock = asyncio.Lock()
|
self._lock = asyncio.Lock()
|
||||||
|
|
||||||
# 统计信息
|
# 统计信息
|
||||||
self.stats = {
|
self.stats = {
|
||||||
'total_operations': 0,
|
"total_operations": 0,
|
||||||
'batched_operations': 0,
|
"batched_operations": 0,
|
||||||
'cache_hits': 0,
|
"cache_hits": 0,
|
||||||
'execution_time': 0.0
|
"execution_time": 0.0
|
||||||
}
|
}
|
||||||
|
|
||||||
# 简单的结果缓存(用于频繁的查询)
|
# 简单的结果缓存(用于频繁的查询)
|
||||||
self._result_cache: Dict[str, Tuple[Any, float]] = {}
|
self._result_cache: dict[str, tuple[Any, float]] = {}
|
||||||
self._cache_ttl = 5.0 # 5秒缓存
|
self._cache_ttl = 5.0 # 5秒缓存
|
||||||
|
|
||||||
async def start(self):
|
async def start(self):
|
||||||
@@ -102,7 +102,7 @@ class DatabaseBatchScheduler:
|
|||||||
await self._flush_all_queues()
|
await self._flush_all_queues()
|
||||||
logger.info("数据库批量调度器已停止")
|
logger.info("数据库批量调度器已停止")
|
||||||
|
|
||||||
def _generate_cache_key(self, operation_type: str, model_class: Any, conditions: Dict[str, Any]) -> str:
|
def _generate_cache_key(self, operation_type: str, model_class: Any, conditions: dict[str, Any]) -> str:
|
||||||
"""生成缓存键"""
|
"""生成缓存键"""
|
||||||
# 简单的缓存键生成,实际可以根据需要优化
|
# 简单的缓存键生成,实际可以根据需要优化
|
||||||
key_parts = [
|
key_parts = [
|
||||||
@@ -112,12 +112,12 @@ class DatabaseBatchScheduler:
|
|||||||
]
|
]
|
||||||
return "|".join(key_parts)
|
return "|".join(key_parts)
|
||||||
|
|
||||||
def _get_from_cache(self, cache_key: str) -> Optional[Any]:
|
def _get_from_cache(self, cache_key: str) -> Any | None:
|
||||||
"""从缓存获取结果"""
|
"""从缓存获取结果"""
|
||||||
if cache_key in self._result_cache:
|
if cache_key in self._result_cache:
|
||||||
result, timestamp = self._result_cache[cache_key]
|
result, timestamp = self._result_cache[cache_key]
|
||||||
if time.time() - timestamp < self._cache_ttl:
|
if time.time() - timestamp < self._cache_ttl:
|
||||||
self.stats['cache_hits'] += 1
|
self.stats["cache_hits"] += 1
|
||||||
return result
|
return result
|
||||||
else:
|
else:
|
||||||
# 清理过期缓存
|
# 清理过期缓存
|
||||||
@@ -131,7 +131,7 @@ class DatabaseBatchScheduler:
|
|||||||
async def add_operation(self, operation: BatchOperation) -> asyncio.Future:
|
async def add_operation(self, operation: BatchOperation) -> asyncio.Future:
|
||||||
"""添加操作到队列"""
|
"""添加操作到队列"""
|
||||||
# 检查是否可以立即返回缓存结果
|
# 检查是否可以立即返回缓存结果
|
||||||
if operation.operation_type == 'select':
|
if operation.operation_type == "select":
|
||||||
cache_key = self._generate_cache_key(
|
cache_key = self._generate_cache_key(
|
||||||
operation.operation_type,
|
operation.operation_type,
|
||||||
operation.model_class,
|
operation.model_class,
|
||||||
@@ -158,7 +158,7 @@ class DatabaseBatchScheduler:
|
|||||||
await self._execute_operations([operation])
|
await self._execute_operations([operation])
|
||||||
else:
|
else:
|
||||||
self.operation_queues[queue_key].append(operation)
|
self.operation_queues[queue_key].append(operation)
|
||||||
self.stats['total_operations'] += 1
|
self.stats["total_operations"] += 1
|
||||||
|
|
||||||
return future
|
return future
|
||||||
|
|
||||||
@@ -193,7 +193,7 @@ class DatabaseBatchScheduler:
|
|||||||
if operations:
|
if operations:
|
||||||
await self._execute_operations(list(operations))
|
await self._execute_operations(list(operations))
|
||||||
|
|
||||||
async def _execute_operations(self, operations: List[BatchOperation]):
|
async def _execute_operations(self, operations: list[BatchOperation]):
|
||||||
"""执行批量操作"""
|
"""执行批量操作"""
|
||||||
if not operations:
|
if not operations:
|
||||||
return
|
return
|
||||||
@@ -209,13 +209,13 @@ class DatabaseBatchScheduler:
|
|||||||
# 为每种操作类型创建批量执行任务
|
# 为每种操作类型创建批量执行任务
|
||||||
tasks = []
|
tasks = []
|
||||||
for op_type, ops in op_groups.items():
|
for op_type, ops in op_groups.items():
|
||||||
if op_type == 'select':
|
if op_type == "select":
|
||||||
tasks.append(self._execute_select_batch(ops))
|
tasks.append(self._execute_select_batch(ops))
|
||||||
elif op_type == 'insert':
|
elif op_type == "insert":
|
||||||
tasks.append(self._execute_insert_batch(ops))
|
tasks.append(self._execute_insert_batch(ops))
|
||||||
elif op_type == 'update':
|
elif op_type == "update":
|
||||||
tasks.append(self._execute_update_batch(ops))
|
tasks.append(self._execute_update_batch(ops))
|
||||||
elif op_type == 'delete':
|
elif op_type == "delete":
|
||||||
tasks.append(self._execute_delete_batch(ops))
|
tasks.append(self._execute_delete_batch(ops))
|
||||||
|
|
||||||
# 并发执行所有操作
|
# 并发执行所有操作
|
||||||
@@ -238,7 +238,7 @@ class DatabaseBatchScheduler:
|
|||||||
operation.future.set_result(result)
|
operation.future.set_result(result)
|
||||||
|
|
||||||
# 缓存查询结果
|
# 缓存查询结果
|
||||||
if operation.operation_type == 'select':
|
if operation.operation_type == "select":
|
||||||
cache_key = self._generate_cache_key(
|
cache_key = self._generate_cache_key(
|
||||||
operation.operation_type,
|
operation.operation_type,
|
||||||
operation.model_class,
|
operation.model_class,
|
||||||
@@ -246,7 +246,7 @@ class DatabaseBatchScheduler:
|
|||||||
)
|
)
|
||||||
self._set_cache(cache_key, result)
|
self._set_cache(cache_key, result)
|
||||||
|
|
||||||
self.stats['batched_operations'] += len(operations)
|
self.stats["batched_operations"] += len(operations)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"批量操作执行失败: {e}", exc_info="")
|
logger.error(f"批量操作执行失败: {e}", exc_info="")
|
||||||
@@ -255,9 +255,9 @@ class DatabaseBatchScheduler:
|
|||||||
if operation.future and not operation.future.done():
|
if operation.future and not operation.future.done():
|
||||||
operation.future.set_exception(e)
|
operation.future.set_exception(e)
|
||||||
finally:
|
finally:
|
||||||
self.stats['execution_time'] += time.time() - start_time
|
self.stats["execution_time"] += time.time() - start_time
|
||||||
|
|
||||||
async def _execute_select_batch(self, operations: List[BatchOperation]):
|
async def _execute_select_batch(self, operations: list[BatchOperation]):
|
||||||
"""批量执行查询操作"""
|
"""批量执行查询操作"""
|
||||||
# 合并相似的查询条件
|
# 合并相似的查询条件
|
||||||
merged_conditions = self._merge_select_conditions(operations)
|
merged_conditions = self._merge_select_conditions(operations)
|
||||||
@@ -302,7 +302,7 @@ class DatabaseBatchScheduler:
|
|||||||
|
|
||||||
return results if len(results) > 1 else results[0] if results else []
|
return results if len(results) > 1 else results[0] if results else []
|
||||||
|
|
||||||
async def _execute_insert_batch(self, operations: List[BatchOperation]):
|
async def _execute_insert_batch(self, operations: list[BatchOperation]):
|
||||||
"""批量执行插入操作"""
|
"""批量执行插入操作"""
|
||||||
async with get_db_session() as session:
|
async with get_db_session() as session:
|
||||||
try:
|
try:
|
||||||
@@ -323,7 +323,7 @@ class DatabaseBatchScheduler:
|
|||||||
logger.error(f"批量插入失败: {e}", exc_info=True)
|
logger.error(f"批量插入失败: {e}", exc_info=True)
|
||||||
return [0] * len(operations)
|
return [0] * len(operations)
|
||||||
|
|
||||||
async def _execute_update_batch(self, operations: List[BatchOperation]):
|
async def _execute_update_batch(self, operations: list[BatchOperation]):
|
||||||
"""批量执行更新操作"""
|
"""批量执行更新操作"""
|
||||||
async with get_db_session() as session:
|
async with get_db_session() as session:
|
||||||
try:
|
try:
|
||||||
@@ -353,7 +353,7 @@ class DatabaseBatchScheduler:
|
|||||||
logger.error(f"批量更新失败: {e}", exc_info=True)
|
logger.error(f"批量更新失败: {e}", exc_info=True)
|
||||||
return [0] * len(operations)
|
return [0] * len(operations)
|
||||||
|
|
||||||
async def _execute_delete_batch(self, operations: List[BatchOperation]):
|
async def _execute_delete_batch(self, operations: list[BatchOperation]):
|
||||||
"""批量执行删除操作"""
|
"""批量执行删除操作"""
|
||||||
async with get_db_session() as session:
|
async with get_db_session() as session:
|
||||||
try:
|
try:
|
||||||
@@ -382,7 +382,7 @@ class DatabaseBatchScheduler:
|
|||||||
logger.error(f"批量删除失败: {e}", exc_info=True)
|
logger.error(f"批量删除失败: {e}", exc_info=True)
|
||||||
return [0] * len(operations)
|
return [0] * len(operations)
|
||||||
|
|
||||||
def _merge_select_conditions(self, operations: List[BatchOperation]) -> Dict[Tuple, List[BatchOperation]]:
|
def _merge_select_conditions(self, operations: list[BatchOperation]) -> dict[tuple, list[BatchOperation]]:
|
||||||
"""合并相似的查询条件"""
|
"""合并相似的查询条件"""
|
||||||
merged = {}
|
merged = {}
|
||||||
|
|
||||||
@@ -405,15 +405,15 @@ class DatabaseBatchScheduler:
|
|||||||
|
|
||||||
# 记录操作
|
# 记录操作
|
||||||
if condition_key not in merged:
|
if condition_key not in merged:
|
||||||
merged[condition_key] = {'_operations': []}
|
merged[condition_key] = {"_operations": []}
|
||||||
if '_operations' not in merged[condition_key]:
|
if "_operations" not in merged[condition_key]:
|
||||||
merged[condition_key]['_operations'] = []
|
merged[condition_key]["_operations"] = []
|
||||||
merged[condition_key]['_operations'].append(op)
|
merged[condition_key]["_operations"].append(op)
|
||||||
|
|
||||||
# 去重并构建最终条件
|
# 去重并构建最终条件
|
||||||
final_merged = {}
|
final_merged = {}
|
||||||
for condition_key, conditions in merged.items():
|
for condition_key, conditions in merged.items():
|
||||||
operations = conditions.pop('_operations')
|
operations = conditions.pop("_operations")
|
||||||
|
|
||||||
# 去重
|
# 去重
|
||||||
for field_name, values in conditions.items():
|
for field_name, values in conditions.items():
|
||||||
@@ -423,13 +423,13 @@ class DatabaseBatchScheduler:
|
|||||||
|
|
||||||
return final_merged
|
return final_merged
|
||||||
|
|
||||||
def get_stats(self) -> Dict[str, Any]:
|
def get_stats(self) -> dict[str, Any]:
|
||||||
"""获取统计信息"""
|
"""获取统计信息"""
|
||||||
return {
|
return {
|
||||||
**self.stats,
|
**self.stats,
|
||||||
'cache_size': len(self._result_cache),
|
"cache_size": len(self._result_cache),
|
||||||
'queue_sizes': {k: len(v) for k, v in self.operation_queues.items()},
|
"queue_sizes": {k: len(v) for k, v in self.operation_queues.items()},
|
||||||
'is_running': self._is_running
|
"is_running": self._is_running
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -450,20 +450,20 @@ async def get_batch_session():
|
|||||||
|
|
||||||
|
|
||||||
# 便捷函数
|
# 便捷函数
|
||||||
async def batch_select(model_class: Any, conditions: Dict[str, Any]) -> Any:
|
async def batch_select(model_class: Any, conditions: dict[str, Any]) -> Any:
|
||||||
"""批量查询"""
|
"""批量查询"""
|
||||||
operation = BatchOperation(
|
operation = BatchOperation(
|
||||||
operation_type='select',
|
operation_type="select",
|
||||||
model_class=model_class,
|
model_class=model_class,
|
||||||
conditions=conditions
|
conditions=conditions
|
||||||
)
|
)
|
||||||
return await db_batch_scheduler.add_operation(operation)
|
return await db_batch_scheduler.add_operation(operation)
|
||||||
|
|
||||||
|
|
||||||
async def batch_insert(model_class: Any, data: Dict[str, Any]) -> int:
|
async def batch_insert(model_class: Any, data: dict[str, Any]) -> int:
|
||||||
"""批量插入"""
|
"""批量插入"""
|
||||||
operation = BatchOperation(
|
operation = BatchOperation(
|
||||||
operation_type='insert',
|
operation_type="insert",
|
||||||
model_class=model_class,
|
model_class=model_class,
|
||||||
conditions={},
|
conditions={},
|
||||||
data=data
|
data=data
|
||||||
@@ -471,10 +471,10 @@ async def batch_insert(model_class: Any, data: Dict[str, Any]) -> int:
|
|||||||
return await db_batch_scheduler.add_operation(operation)
|
return await db_batch_scheduler.add_operation(operation)
|
||||||
|
|
||||||
|
|
||||||
async def batch_update(model_class: Any, conditions: Dict[str, Any], data: Dict[str, Any]) -> int:
|
async def batch_update(model_class: Any, conditions: dict[str, Any], data: dict[str, Any]) -> int:
|
||||||
"""批量更新"""
|
"""批量更新"""
|
||||||
operation = BatchOperation(
|
operation = BatchOperation(
|
||||||
operation_type='update',
|
operation_type="update",
|
||||||
model_class=model_class,
|
model_class=model_class,
|
||||||
conditions=conditions,
|
conditions=conditions,
|
||||||
data=data
|
data=data
|
||||||
@@ -482,10 +482,10 @@ async def batch_update(model_class: Any, conditions: Dict[str, Any], data: Dict[
|
|||||||
return await db_batch_scheduler.add_operation(operation)
|
return await db_batch_scheduler.add_operation(operation)
|
||||||
|
|
||||||
|
|
||||||
async def batch_delete(model_class: Any, conditions: Dict[str, Any]) -> int:
|
async def batch_delete(model_class: Any, conditions: dict[str, Any]) -> int:
|
||||||
"""批量删除"""
|
"""批量删除"""
|
||||||
operation = BatchOperation(
|
operation = BatchOperation(
|
||||||
operation_type='delete',
|
operation_type="delete",
|
||||||
model_class=model_class,
|
model_class=model_class,
|
||||||
conditions=conditions
|
conditions=conditions
|
||||||
)
|
)
|
||||||
@@ -494,4 +494,4 @@ async def batch_delete(model_class: Any, conditions: Dict[str, Any]) -> int:
|
|||||||
|
|
||||||
def get_db_batch_scheduler() -> DatabaseBatchScheduler:
|
def get_db_batch_scheduler() -> DatabaseBatchScheduler:
|
||||||
"""获取数据库批量调度器实例"""
|
"""获取数据库批量调度器实例"""
|
||||||
return db_batch_scheduler
|
return db_batch_scheduler
|
||||||
|
|||||||
@@ -15,8 +15,8 @@ from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_asyn
|
|||||||
from sqlalchemy.ext.declarative import declarative_base
|
from sqlalchemy.ext.declarative import declarative_base
|
||||||
from sqlalchemy.orm import Mapped, mapped_column
|
from sqlalchemy.orm import Mapped, mapped_column
|
||||||
|
|
||||||
from src.common.logger import get_logger
|
|
||||||
from src.common.database.connection_pool_manager import get_connection_pool_manager
|
from src.common.database.connection_pool_manager import get_connection_pool_manager
|
||||||
|
from src.common.logger import get_logger
|
||||||
|
|
||||||
logger = get_logger("sqlalchemy_models")
|
logger = get_logger("sqlalchemy_models")
|
||||||
|
|
||||||
|
|||||||
@@ -1,13 +1,13 @@
|
|||||||
# 使用基于时间戳的文件处理器,简单的轮转份数限制
|
# 使用基于时间戳的文件处理器,简单的轮转份数限制
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
import tarfile
|
||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
import tarfile
|
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Optional, Dict
|
from typing import Any
|
||||||
|
|
||||||
import orjson
|
import orjson
|
||||||
import structlog
|
import structlog
|
||||||
@@ -18,15 +18,15 @@ LOG_DIR = Path("logs")
|
|||||||
LOG_DIR.mkdir(exist_ok=True)
|
LOG_DIR.mkdir(exist_ok=True)
|
||||||
|
|
||||||
# 全局handler实例,避免重复创建(可能为None表示禁用文件日志)
|
# 全局handler实例,避免重复创建(可能为None表示禁用文件日志)
|
||||||
_file_handler: Optional[logging.Handler] = None
|
_file_handler: logging.Handler | None = None
|
||||||
_console_handler: Optional[logging.Handler] = None
|
_console_handler: logging.Handler | None = None
|
||||||
|
|
||||||
# 动态 logger 元数据注册表 (name -> {alias:str|None, color:str|None})
|
# 动态 logger 元数据注册表 (name -> {alias:str|None, color:str|None})
|
||||||
_LOGGER_META_LOCK = threading.Lock()
|
_LOGGER_META_LOCK = threading.Lock()
|
||||||
_LOGGER_META: Dict[str, Dict[str, Optional[str]]] = {}
|
_LOGGER_META: dict[str, dict[str, str | None]] = {}
|
||||||
|
|
||||||
|
|
||||||
def _normalize_color(color: Optional[str]) -> Optional[str]:
|
def _normalize_color(color: str | None) -> str | None:
|
||||||
"""接受 ANSI 码 / #RRGGBB / rgb(r,g,b) / 颜色名(直接返回) -> ANSI 码.
|
"""接受 ANSI 码 / #RRGGBB / rgb(r,g,b) / 颜色名(直接返回) -> ANSI 码.
|
||||||
不做复杂解析,只支持 #RRGGBB 转 24bit ANSI。
|
不做复杂解析,只支持 #RRGGBB 转 24bit ANSI。
|
||||||
"""
|
"""
|
||||||
@@ -49,13 +49,13 @@ def _normalize_color(color: Optional[str]) -> Optional[str]:
|
|||||||
nums = color[color.find("(") + 1 : -1].split(",")
|
nums = color[color.find("(") + 1 : -1].split(",")
|
||||||
r, g, b = (int(x) for x in nums[:3])
|
r, g, b = (int(x) for x in nums[:3])
|
||||||
return f"\033[38;2;{r};{g};{b}m"
|
return f"\033[38;2;{r};{g};{b}m"
|
||||||
except Exception: # noqa: BLE001
|
except Exception:
|
||||||
return None
|
return None
|
||||||
# 其他情况直接返回,假设是短ANSI或名称(控制台渲染器不做翻译,仅输出)
|
# 其他情况直接返回,假设是短ANSI或名称(控制台渲染器不做翻译,仅输出)
|
||||||
return color
|
return color
|
||||||
|
|
||||||
|
|
||||||
def _register_logger_meta(name: str, *, alias: Optional[str] = None, color: Optional[str] = None):
|
def _register_logger_meta(name: str, *, alias: str | None = None, color: str | None = None):
|
||||||
"""注册/更新 logger 元数据。"""
|
"""注册/更新 logger 元数据。"""
|
||||||
if not name:
|
if not name:
|
||||||
return
|
return
|
||||||
@@ -67,7 +67,7 @@ def _register_logger_meta(name: str, *, alias: Optional[str] = None, color: Opti
|
|||||||
meta["color"] = _normalize_color(color)
|
meta["color"] = _normalize_color(color)
|
||||||
|
|
||||||
|
|
||||||
def get_logger_meta(name: str) -> Dict[str, Optional[str]]:
|
def get_logger_meta(name: str) -> dict[str, str | None]:
|
||||||
with _LOGGER_META_LOCK:
|
with _LOGGER_META_LOCK:
|
||||||
return _LOGGER_META.get(name, {"alias": None, "color": None}).copy()
|
return _LOGGER_META.get(name, {"alias": None, "color": None}).copy()
|
||||||
|
|
||||||
@@ -170,7 +170,7 @@ class TimestampedFileHandler(logging.Handler):
|
|||||||
try:
|
try:
|
||||||
self._compress_stale_logs()
|
self._compress_stale_logs()
|
||||||
self._cleanup_old_files()
|
self._cleanup_old_files()
|
||||||
except Exception as e: # noqa: BLE001
|
except Exception as e:
|
||||||
print(f"[日志轮转] 轮转过程出错: {e}")
|
print(f"[日志轮转] 轮转过程出错: {e}")
|
||||||
|
|
||||||
def _compress_stale_logs(self): # sourcery skip: extract-method
|
def _compress_stale_logs(self): # sourcery skip: extract-method
|
||||||
@@ -184,12 +184,12 @@ class TimestampedFileHandler(logging.Handler):
|
|||||||
continue
|
continue
|
||||||
# 压缩
|
# 压缩
|
||||||
try:
|
try:
|
||||||
with tarfile.open(tar_path, "w:gz") as tf: # noqa: SIM117
|
with tarfile.open(tar_path, "w:gz") as tf:
|
||||||
tf.add(f, arcname=f.name)
|
tf.add(f, arcname=f.name)
|
||||||
f.unlink(missing_ok=True)
|
f.unlink(missing_ok=True)
|
||||||
except Exception as e: # noqa: BLE001
|
except Exception as e:
|
||||||
print(f"[日志压缩] 压缩 {f.name} 失败: {e}")
|
print(f"[日志压缩] 压缩 {f.name} 失败: {e}")
|
||||||
except Exception as e: # noqa: BLE001
|
except Exception as e:
|
||||||
print(f"[日志压缩] 过程出错: {e}")
|
print(f"[日志压缩] 过程出错: {e}")
|
||||||
|
|
||||||
def _cleanup_old_files(self):
|
def _cleanup_old_files(self):
|
||||||
@@ -206,9 +206,9 @@ class TimestampedFileHandler(logging.Handler):
|
|||||||
mtime = datetime.fromtimestamp(f.stat().st_mtime)
|
mtime = datetime.fromtimestamp(f.stat().st_mtime)
|
||||||
if mtime < cutoff:
|
if mtime < cutoff:
|
||||||
f.unlink(missing_ok=True)
|
f.unlink(missing_ok=True)
|
||||||
except Exception as e: # noqa: BLE001
|
except Exception as e:
|
||||||
print(f"[日志清理] 删除 {f} 失败: {e}")
|
print(f"[日志清理] 删除 {f} 失败: {e}")
|
||||||
except Exception as e: # noqa: BLE001
|
except Exception as e:
|
||||||
print(f"[日志清理] 清理过程出错: {e}")
|
print(f"[日志清理] 清理过程出错: {e}")
|
||||||
|
|
||||||
def emit(self, record):
|
def emit(self, record):
|
||||||
@@ -850,7 +850,7 @@ class ModuleColoredConsoleRenderer:
|
|||||||
if logger_name:
|
if logger_name:
|
||||||
# 获取别名,如果没有别名则使用原名称
|
# 获取别名,如果没有别名则使用原名称
|
||||||
# 若上面条件不成立需要再次获取 meta
|
# 若上面条件不成立需要再次获取 meta
|
||||||
if 'meta' not in locals():
|
if "meta" not in locals():
|
||||||
meta = get_logger_meta(logger_name)
|
meta = get_logger_meta(logger_name)
|
||||||
display_name = meta.get("alias") or DEFAULT_MODULE_ALIASES.get(logger_name, logger_name)
|
display_name = meta.get("alias") or DEFAULT_MODULE_ALIASES.get(logger_name, logger_name)
|
||||||
|
|
||||||
@@ -1066,7 +1066,7 @@ raw_logger: structlog.stdlib.BoundLogger = structlog.get_logger()
|
|||||||
binds: dict[str, Callable] = {}
|
binds: dict[str, Callable] = {}
|
||||||
|
|
||||||
|
|
||||||
def get_logger(name: str | None, *, color: Optional[str] = None, alias: Optional[str] = None) -> structlog.stdlib.BoundLogger:
|
def get_logger(name: str | None, *, color: str | None = None, alias: str | None = None) -> structlog.stdlib.BoundLogger:
|
||||||
"""获取/创建 structlog logger。
|
"""获取/创建 structlog logger。
|
||||||
|
|
||||||
新增:
|
新增:
|
||||||
@@ -1132,10 +1132,10 @@ def cleanup_old_logs():
|
|||||||
tar_path = f.with_suffix(f.suffix + ".tar.gz")
|
tar_path = f.with_suffix(f.suffix + ".tar.gz")
|
||||||
if tar_path.exists():
|
if tar_path.exists():
|
||||||
continue
|
continue
|
||||||
with tarfile.open(tar_path, "w:gz") as tf: # noqa: SIM117
|
with tarfile.open(tar_path, "w:gz") as tf:
|
||||||
tf.add(f, arcname=f.name)
|
tf.add(f, arcname=f.name)
|
||||||
f.unlink(missing_ok=True)
|
f.unlink(missing_ok=True)
|
||||||
except Exception as e: # noqa: BLE001
|
except Exception as e:
|
||||||
logger = get_logger("logger")
|
logger = get_logger("logger")
|
||||||
logger.warning(f"周期压缩日志时出错: {e}")
|
logger.warning(f"周期压缩日志时出错: {e}")
|
||||||
|
|
||||||
@@ -1152,7 +1152,7 @@ def cleanup_old_logs():
|
|||||||
log_file.unlink(missing_ok=True)
|
log_file.unlink(missing_ok=True)
|
||||||
deleted_count += 1
|
deleted_count += 1
|
||||||
deleted_size += size
|
deleted_size += size
|
||||||
except Exception as e: # noqa: BLE001
|
except Exception as e:
|
||||||
logger = get_logger("logger")
|
logger = get_logger("logger")
|
||||||
logger.warning(f"清理日志文件 {log_file} 时出错: {e}")
|
logger.warning(f"清理日志文件 {log_file} 时出错: {e}")
|
||||||
if deleted_count:
|
if deleted_count:
|
||||||
@@ -1160,7 +1160,7 @@ def cleanup_old_logs():
|
|||||||
logger.info(
|
logger.info(
|
||||||
f"清理 {deleted_count} 个过期日志 (≈{deleted_size / 1024 / 1024:.2f}MB), 保留策略={retention_days}天"
|
f"清理 {deleted_count} 个过期日志 (≈{deleted_size / 1024 / 1024:.2f}MB), 保留策略={retention_days}天"
|
||||||
)
|
)
|
||||||
except Exception as e: # noqa: BLE001
|
except Exception as e:
|
||||||
logger = get_logger("logger")
|
logger = get_logger("logger")
|
||||||
logger.error(f"清理旧日志文件时出错: {e}")
|
logger.error(f"清理旧日志文件时出错: {e}")
|
||||||
|
|
||||||
@@ -1183,7 +1183,7 @@ def start_log_cleanup_task():
|
|||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
cleanup_old_logs()
|
cleanup_old_logs()
|
||||||
except Exception as e: # noqa: BLE001
|
except Exception as e:
|
||||||
print(f"[日志任务] 执行清理出错: {e}")
|
print(f"[日志任务] 执行清理出错: {e}")
|
||||||
# 再次等待到下一个午夜
|
# 再次等待到下一个午夜
|
||||||
time.sleep(max(1, seconds_until_next_midnight()))
|
time.sleep(max(1, seconds_until_next_midnight()))
|
||||||
|
|||||||
24
src/main.py
24
src/main.py
@@ -120,10 +120,10 @@ class MainSystem:
|
|||||||
logger.warning("未发现任何兴趣计算器组件")
|
logger.warning("未发现任何兴趣计算器组件")
|
||||||
return
|
return
|
||||||
|
|
||||||
logger.info(f"发现的兴趣计算器组件:")
|
logger.info("发现的兴趣计算器组件:")
|
||||||
for calc_name, calc_info in interest_calculators.items():
|
for calc_name, calc_info in interest_calculators.items():
|
||||||
enabled = getattr(calc_info, 'enabled', True)
|
enabled = getattr(calc_info, "enabled", True)
|
||||||
default_enabled = getattr(calc_info, 'enabled_by_default', True)
|
default_enabled = getattr(calc_info, "enabled_by_default", True)
|
||||||
logger.info(f" - {calc_name}: 启用: {enabled}, 默认启用: {default_enabled}")
|
logger.info(f" - {calc_name}: 启用: {enabled}, 默认启用: {default_enabled}")
|
||||||
|
|
||||||
# 初始化兴趣度管理器
|
# 初始化兴趣度管理器
|
||||||
@@ -136,8 +136,8 @@ class MainSystem:
|
|||||||
|
|
||||||
# 使用组件注册表获取组件类并注册
|
# 使用组件注册表获取组件类并注册
|
||||||
for calc_name, calc_info in interest_calculators.items():
|
for calc_name, calc_info in interest_calculators.items():
|
||||||
enabled = getattr(calc_info, 'enabled', True)
|
enabled = getattr(calc_info, "enabled", True)
|
||||||
default_enabled = getattr(calc_info, 'enabled_by_default', True)
|
default_enabled = getattr(calc_info, "enabled_by_default", True)
|
||||||
|
|
||||||
if not enabled or not default_enabled:
|
if not enabled or not default_enabled:
|
||||||
logger.info(f"兴趣计算器 {calc_name} 未启用,跳过")
|
logger.info(f"兴趣计算器 {calc_name} 未启用,跳过")
|
||||||
@@ -183,7 +183,7 @@ class MainSystem:
|
|||||||
async def _async_cleanup(self):
|
async def _async_cleanup(self):
|
||||||
"""异步清理资源"""
|
"""异步清理资源"""
|
||||||
try:
|
try:
|
||||||
|
|
||||||
# 停止数据库服务
|
# 停止数据库服务
|
||||||
try:
|
try:
|
||||||
from src.common.database.database import stop_database
|
from src.common.database.database import stop_database
|
||||||
@@ -343,8 +343,8 @@ MoFox_Bot(第三方修改版)
|
|||||||
# 初始化表情管理器
|
# 初始化表情管理器
|
||||||
get_emoji_manager().initialize()
|
get_emoji_manager().initialize()
|
||||||
logger.info("表情包管理器初始化成功")
|
logger.info("表情包管理器初始化成功")
|
||||||
|
|
||||||
'''
|
"""
|
||||||
# 初始化回复后关系追踪系统
|
# 初始化回复后关系追踪系统
|
||||||
try:
|
try:
|
||||||
from src.plugins.built_in.affinity_flow_chatter.interest_scoring import chatter_interest_scoring_system
|
from src.plugins.built_in.affinity_flow_chatter.interest_scoring import chatter_interest_scoring_system
|
||||||
@@ -356,8 +356,8 @@ MoFox_Bot(第三方修改版)
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"回复后关系追踪系统初始化失败: {e}")
|
logger.error(f"回复后关系追踪系统初始化失败: {e}")
|
||||||
relationship_tracker = None
|
relationship_tracker = None
|
||||||
'''
|
"""
|
||||||
|
|
||||||
# 启动情绪管理器
|
# 启动情绪管理器
|
||||||
await mood_manager.start()
|
await mood_manager.start()
|
||||||
logger.info("情绪管理器初始化成功")
|
logger.info("情绪管理器初始化成功")
|
||||||
@@ -487,10 +487,10 @@ MoFox_Bot(第三方修改版)
|
|||||||
# 关闭应用 (MessageServer可能没有shutdown方法)
|
# 关闭应用 (MessageServer可能没有shutdown方法)
|
||||||
try:
|
try:
|
||||||
if self.app:
|
if self.app:
|
||||||
if hasattr(self.app, 'shutdown'):
|
if hasattr(self.app, "shutdown"):
|
||||||
await self.app.shutdown()
|
await self.app.shutdown()
|
||||||
logger.info("应用已关闭")
|
logger.info("应用已关闭")
|
||||||
elif hasattr(self.app, 'stop'):
|
elif hasattr(self.app, "stop"):
|
||||||
await self.app.stop()
|
await self.app.stop()
|
||||||
logger.info("应用已停止")
|
logger.info("应用已停止")
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -2,7 +2,6 @@ import math
|
|||||||
import random
|
import random
|
||||||
import time
|
import time
|
||||||
|
|
||||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
|
||||||
from src.chat.message_receive.message import MessageRecv
|
from src.chat.message_receive.message import MessageRecv
|
||||||
from src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_by_timestamp_with_chat_inclusive
|
from src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_by_timestamp_with_chat_inclusive
|
||||||
from src.chat.utils.prompt import Prompt, global_prompt_manager
|
from src.chat.utils.prompt import Prompt, global_prompt_manager
|
||||||
|
|||||||
@@ -5,7 +5,6 @@ import time
|
|||||||
import traceback
|
import traceback
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
|
||||||
from src.chat.utils.chat_message_builder import (
|
from src.chat.utils.chat_message_builder import (
|
||||||
get_raw_msg_before_timestamp_with_chat,
|
get_raw_msg_before_timestamp_with_chat,
|
||||||
get_raw_msg_by_timestamp_with_chat,
|
get_raw_msg_by_timestamp_with_chat,
|
||||||
|
|||||||
@@ -5,7 +5,6 @@ from typing import Any
|
|||||||
import orjson
|
import orjson
|
||||||
from json_repair import repair_json
|
from json_repair import repair_json
|
||||||
|
|
||||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
|
||||||
from src.chat.utils.prompt import Prompt, global_prompt_manager
|
from src.chat.utils.prompt import Prompt, global_prompt_manager
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.config.config import global_config, model_config
|
from src.config.config import global_config, model_config
|
||||||
|
|||||||
@@ -4,8 +4,8 @@ from datetime import datetime
|
|||||||
from difflib import SequenceMatcher
|
from difflib import SequenceMatcher
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import rjieba
|
|
||||||
import orjson
|
import orjson
|
||||||
|
import rjieba
|
||||||
from json_repair import repair_json
|
from json_repair import repair_json
|
||||||
from sklearn.feature_extraction.text import TfidfVectorizer
|
from sklearn.feature_extraction.text import TfidfVectorizer
|
||||||
from sklearn.metrics.pairwise import cosine_similarity
|
from sklearn.metrics.pairwise import cosine_similarity
|
||||||
|
|||||||
@@ -49,7 +49,6 @@ from .base import (
|
|||||||
ToolParamType,
|
ToolParamType,
|
||||||
create_plus_command_adapter,
|
create_plus_command_adapter,
|
||||||
)
|
)
|
||||||
|
|
||||||
from .utils.dependency_config import configure_dependency_settings, get_dependency_config
|
from .utils.dependency_config import configure_dependency_settings, get_dependency_config
|
||||||
|
|
||||||
# 导入依赖管理模块
|
# 导入依赖管理模块
|
||||||
|
|||||||
@@ -113,7 +113,7 @@ class BaseInterestCalculator(ABC):
|
|||||||
try:
|
try:
|
||||||
self._enabled = True
|
self._enabled = True
|
||||||
return True
|
return True
|
||||||
except Exception as e:
|
except Exception:
|
||||||
self._enabled = False
|
self._enabled = False
|
||||||
return False
|
return False
|
||||||
|
|
||||||
@@ -170,7 +170,7 @@ class BaseInterestCalculator(ABC):
|
|||||||
if not self._enabled:
|
if not self._enabled:
|
||||||
return InterestCalculationResult(
|
return InterestCalculationResult(
|
||||||
success=False,
|
success=False,
|
||||||
message_id=getattr(message, 'message_id', ''),
|
message_id=getattr(message, "message_id", ""),
|
||||||
interest_value=0.0,
|
interest_value=0.0,
|
||||||
error_message="组件未启用"
|
error_message="组件未启用"
|
||||||
)
|
)
|
||||||
@@ -184,9 +184,9 @@ class BaseInterestCalculator(ABC):
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
result = InterestCalculationResult(
|
result = InterestCalculationResult(
|
||||||
success=False,
|
success=False,
|
||||||
message_id=getattr(message, 'message_id', ''),
|
message_id=getattr(message, "message_id", ""),
|
||||||
interest_value=0.0,
|
interest_value=0.0,
|
||||||
error_message=f"计算执行失败: {str(e)}",
|
error_message=f"计算执行失败: {e!s}",
|
||||||
calculation_time=time.time() - start_time
|
calculation_time=time.time() - start_time
|
||||||
)
|
)
|
||||||
self._update_statistics(result)
|
self._update_statistics(result)
|
||||||
@@ -201,7 +201,7 @@ class BaseInterestCalculator(ABC):
|
|||||||
Returns:
|
Returns:
|
||||||
InterestCalculatorInfo: 生成的兴趣计算器信息对象
|
InterestCalculatorInfo: 生成的兴趣计算器信息对象
|
||||||
"""
|
"""
|
||||||
name = getattr(cls, 'component_name', cls.__name__.lower().replace('calculator', ''))
|
name = getattr(cls, "component_name", cls.__name__.lower().replace("calculator", ""))
|
||||||
if "." in name:
|
if "." in name:
|
||||||
logger.error(f"InterestCalculator名称 '{name}' 包含非法字符 '.',请使用下划线替代")
|
logger.error(f"InterestCalculator名称 '{name}' 包含非法字符 '.',请使用下划线替代")
|
||||||
raise ValueError(f"InterestCalculator名称 '{name}' 包含非法字符 '.',请使用下划线替代")
|
raise ValueError(f"InterestCalculator名称 '{name}' 包含非法字符 '.',请使用下划线替代")
|
||||||
@@ -209,12 +209,12 @@ class BaseInterestCalculator(ABC):
|
|||||||
return InterestCalculatorInfo(
|
return InterestCalculatorInfo(
|
||||||
name=name,
|
name=name,
|
||||||
component_type=ComponentType.INTEREST_CALCULATOR,
|
component_type=ComponentType.INTEREST_CALCULATOR,
|
||||||
description=getattr(cls, 'component_description', cls.__doc__ or "兴趣度计算器"),
|
description=getattr(cls, "component_description", cls.__doc__ or "兴趣度计算器"),
|
||||||
enabled_by_default=getattr(cls, 'enabled_by_default', True),
|
enabled_by_default=getattr(cls, "enabled_by_default", True),
|
||||||
)
|
)
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
return (f"{self.__class__.__name__}("
|
return (f"{self.__class__.__name__}("
|
||||||
f"name={self.component_name}, "
|
f"name={self.component_name}, "
|
||||||
f"version={self.component_version}, "
|
f"version={self.component_version}, "
|
||||||
f"enabled={self._enabled})")
|
f"enabled={self._enabled})")
|
||||||
|
|||||||
@@ -43,21 +43,21 @@ class BasePlugin(PluginBase):
|
|||||||
对应类型的ComponentInfo对象
|
对应类型的ComponentInfo对象
|
||||||
"""
|
"""
|
||||||
if component_type == ComponentType.COMMAND:
|
if component_type == ComponentType.COMMAND:
|
||||||
if hasattr(component_class, 'get_command_info'):
|
if hasattr(component_class, "get_command_info"):
|
||||||
return component_class.get_command_info()
|
return component_class.get_command_info()
|
||||||
else:
|
else:
|
||||||
logger.warning(f"Command类 {component_class.__name__} 缺少 get_command_info 方法")
|
logger.warning(f"Command类 {component_class.__name__} 缺少 get_command_info 方法")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
elif component_type == ComponentType.ACTION:
|
elif component_type == ComponentType.ACTION:
|
||||||
if hasattr(component_class, 'get_action_info'):
|
if hasattr(component_class, "get_action_info"):
|
||||||
return component_class.get_action_info()
|
return component_class.get_action_info()
|
||||||
else:
|
else:
|
||||||
logger.warning(f"Action类 {component_class.__name__} 缺少 get_action_info 方法")
|
logger.warning(f"Action类 {component_class.__name__} 缺少 get_action_info 方法")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
elif component_type == ComponentType.INTEREST_CALCULATOR:
|
elif component_type == ComponentType.INTEREST_CALCULATOR:
|
||||||
if hasattr(component_class, 'get_interest_calculator_info'):
|
if hasattr(component_class, "get_interest_calculator_info"):
|
||||||
return component_class.get_interest_calculator_info()
|
return component_class.get_interest_calculator_info()
|
||||||
else:
|
else:
|
||||||
logger.warning(f"InterestCalculator类 {component_class.__name__} 缺少 get_interest_calculator_info 方法")
|
logger.warning(f"InterestCalculator类 {component_class.__name__} 缺少 get_interest_calculator_info 方法")
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Any, Dict, List, Optional, Set
|
from typing import Any
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class PluginMetadata:
|
class PluginMetadata:
|
||||||
@@ -11,15 +12,15 @@ class PluginMetadata:
|
|||||||
usage: str # 插件使用方法
|
usage: str # 插件使用方法
|
||||||
|
|
||||||
# 以下为可选字段,参考自 _manifest.json 和 NoneBot 设计
|
# 以下为可选字段,参考自 _manifest.json 和 NoneBot 设计
|
||||||
type: Optional[str] = None # 插件类别: "library", "application"
|
type: str | None = None # 插件类别: "library", "application"
|
||||||
|
|
||||||
# 从原 _manifest.json 迁移的字段
|
# 从原 _manifest.json 迁移的字段
|
||||||
version: str = "1.0.0" # 插件版本
|
version: str = "1.0.0" # 插件版本
|
||||||
author: str = "" # 作者名称
|
author: str = "" # 作者名称
|
||||||
license: Optional[str] = None # 开源协议
|
license: str | None = None # 开源协议
|
||||||
repository_url: Optional[str] = None # 仓库地址
|
repository_url: str | None = None # 仓库地址
|
||||||
keywords: List[str] = field(default_factory=list) # 关键词
|
keywords: list[str] = field(default_factory=list) # 关键词
|
||||||
categories: List[str] = field(default_factory=list) # 分类
|
categories: list[str] = field(default_factory=list) # 分类
|
||||||
|
|
||||||
# 扩展字段
|
# 扩展字段
|
||||||
extra: Dict[str, Any] = field(default_factory=dict) # 其他任意信息
|
extra: dict[str, Any] = field(default_factory=dict) # 其他任意信息
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import importlib
|
import importlib
|
||||||
import os
|
import os
|
||||||
import traceback
|
|
||||||
from importlib.util import module_from_spec, spec_from_file_location
|
from importlib.util import module_from_spec, spec_from_file_location
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Optional
|
from typing import Any, Optional
|
||||||
@@ -104,7 +103,7 @@ class PluginManager:
|
|||||||
return False, 1
|
return False, 1
|
||||||
|
|
||||||
module = self.plugin_modules.get(plugin_name)
|
module = self.plugin_modules.get(plugin_name)
|
||||||
|
|
||||||
if not module or not hasattr(module, "__plugin_meta__"):
|
if not module or not hasattr(module, "__plugin_meta__"):
|
||||||
self.failed_plugins[plugin_name] = "插件模块中缺少 __plugin_meta__"
|
self.failed_plugins[plugin_name] = "插件模块中缺少 __plugin_meta__"
|
||||||
logger.error(f"❌ 插件加载失败: {plugin_name} - 缺少 __plugin_meta__")
|
logger.error(f"❌ 插件加载失败: {plugin_name} - 缺少 __plugin_meta__")
|
||||||
@@ -288,7 +287,7 @@ class PluginManager:
|
|||||||
|
|
||||||
return loaded_count, failed_count
|
return loaded_count, failed_count
|
||||||
|
|
||||||
def _load_plugin_module_file(self, plugin_file: str) -> Optional[Any]:
|
def _load_plugin_module_file(self, plugin_file: str) -> Any | None:
|
||||||
# sourcery skip: extract-method
|
# sourcery skip: extract-method
|
||||||
"""加载单个插件模块文件
|
"""加载单个插件模块文件
|
||||||
|
|
||||||
|
|||||||
@@ -2,7 +2,6 @@ import inspect
|
|||||||
import time
|
import time
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
|
||||||
from src.chat.utils.prompt import Prompt, global_prompt_manager
|
from src.chat.utils.prompt import Prompt, global_prompt_manager
|
||||||
from src.common.cache_manager import tool_cache
|
from src.common.cache_manager import tool_cache
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
|
|||||||
@@ -2,4 +2,4 @@
|
|||||||
插件系统工具模块
|
插件系统工具模块
|
||||||
|
|
||||||
提供插件开发和管理的实用工具
|
提供插件开发和管理的实用工具
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -52,7 +52,7 @@ class AffinityInterestCalculator(BaseInterestCalculator):
|
|||||||
# 用户关系数据缓存
|
# 用户关系数据缓存
|
||||||
self.user_relationships: dict[str, float] = {} # user_id -> relationship_score
|
self.user_relationships: dict[str, float] = {} # user_id -> relationship_score
|
||||||
|
|
||||||
logger.info(f"[Affinity兴趣计算器] 初始化完成:")
|
logger.info("[Affinity兴趣计算器] 初始化完成:")
|
||||||
logger.info(f" - 权重配置: {self.score_weights}")
|
logger.info(f" - 权重配置: {self.score_weights}")
|
||||||
logger.info(f" - 回复阈值: {self.reply_threshold}")
|
logger.info(f" - 回复阈值: {self.reply_threshold}")
|
||||||
logger.info(f" - 智能匹配: {self.use_smart_matching}")
|
logger.info(f" - 智能匹配: {self.use_smart_matching}")
|
||||||
@@ -69,9 +69,9 @@ class AffinityInterestCalculator(BaseInterestCalculator):
|
|||||||
"""执行AffinityFlow风格的兴趣值计算"""
|
"""执行AffinityFlow风格的兴趣值计算"""
|
||||||
try:
|
try:
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
message_id = getattr(message, 'message_id', '')
|
message_id = getattr(message, "message_id", "")
|
||||||
content = getattr(message, 'processed_plain_text', '')
|
content = getattr(message, "processed_plain_text", "")
|
||||||
user_id = getattr(message, 'user_info', {}).user_id if hasattr(message, 'user_info') and hasattr(message.user_info, 'user_id') else ''
|
user_id = getattr(message, "user_info", {}).user_id if hasattr(message, "user_info") and hasattr(message.user_info, "user_id") else ""
|
||||||
|
|
||||||
logger.debug(f"[Affinity兴趣计算] 开始处理消息 {message_id}")
|
logger.debug(f"[Affinity兴趣计算] 开始处理消息 {message_id}")
|
||||||
logger.debug(f"[Affinity兴趣计算] 消息内容: {content[:50]}...")
|
logger.debug(f"[Affinity兴趣计算] 消息内容: {content[:50]}...")
|
||||||
@@ -135,7 +135,7 @@ class AffinityInterestCalculator(BaseInterestCalculator):
|
|||||||
logger.error(f"Affinity兴趣值计算失败: {e}", exc_info=True)
|
logger.error(f"Affinity兴趣值计算失败: {e}", exc_info=True)
|
||||||
return InterestCalculationResult(
|
return InterestCalculationResult(
|
||||||
success=False,
|
success=False,
|
||||||
message_id=getattr(message, 'message_id', ''),
|
message_id=getattr(message, "message_id", ""),
|
||||||
interest_value=0.0,
|
interest_value=0.0,
|
||||||
error_message=str(e)
|
error_message=str(e)
|
||||||
)
|
)
|
||||||
@@ -206,9 +206,9 @@ class AffinityInterestCalculator(BaseInterestCalculator):
|
|||||||
|
|
||||||
def _calculate_mentioned_score(self, message: "DatabaseMessages", bot_nickname: str) -> float:
|
def _calculate_mentioned_score(self, message: "DatabaseMessages", bot_nickname: str) -> float:
|
||||||
"""计算提及分"""
|
"""计算提及分"""
|
||||||
is_mentioned = getattr(message, 'is_mentioned', False)
|
is_mentioned = getattr(message, "is_mentioned", False)
|
||||||
is_at = getattr(message, 'is_at', False)
|
is_at = getattr(message, "is_at", False)
|
||||||
processed_plain_text = getattr(message, 'processed_plain_text', '')
|
processed_plain_text = getattr(message, "processed_plain_text", "")
|
||||||
|
|
||||||
if is_mentioned:
|
if is_mentioned:
|
||||||
if is_at:
|
if is_at:
|
||||||
@@ -238,7 +238,7 @@ class AffinityInterestCalculator(BaseInterestCalculator):
|
|||||||
keywords = []
|
keywords = []
|
||||||
|
|
||||||
# 尝试从 key_words 字段提取(存储的是JSON字符串)
|
# 尝试从 key_words 字段提取(存储的是JSON字符串)
|
||||||
key_words = getattr(message, 'key_words', '')
|
key_words = getattr(message, "key_words", "")
|
||||||
if key_words:
|
if key_words:
|
||||||
try:
|
try:
|
||||||
import orjson
|
import orjson
|
||||||
@@ -250,7 +250,7 @@ class AffinityInterestCalculator(BaseInterestCalculator):
|
|||||||
|
|
||||||
# 如果没有 keywords,尝试从 key_words_lite 提取
|
# 如果没有 keywords,尝试从 key_words_lite 提取
|
||||||
if not keywords:
|
if not keywords:
|
||||||
key_words_lite = getattr(message, 'key_words_lite', '')
|
key_words_lite = getattr(message, "key_words_lite", "")
|
||||||
if key_words_lite:
|
if key_words_lite:
|
||||||
try:
|
try:
|
||||||
import orjson
|
import orjson
|
||||||
@@ -262,7 +262,7 @@ class AffinityInterestCalculator(BaseInterestCalculator):
|
|||||||
|
|
||||||
# 如果还是没有,从消息内容中提取(降级方案)
|
# 如果还是没有,从消息内容中提取(降级方案)
|
||||||
if not keywords:
|
if not keywords:
|
||||||
content = getattr(message, 'processed_plain_text', '') or ''
|
content = getattr(message, "processed_plain_text", "") or ""
|
||||||
keywords = self._extract_keywords_from_content(content)
|
keywords = self._extract_keywords_from_content(content)
|
||||||
|
|
||||||
return keywords[:15] # 返回前15个关键词
|
return keywords[:15] # 返回前15个关键词
|
||||||
@@ -298,4 +298,4 @@ class AffinityInterestCalculator(BaseInterestCalculator):
|
|||||||
self.no_reply_count = min(self.no_reply_count + 1, self.max_no_reply_count)
|
self.no_reply_count = min(self.no_reply_count + 1, self.max_no_reply_count)
|
||||||
|
|
||||||
# 是否使用智能兴趣匹配(作为类属性)
|
# 是否使用智能兴趣匹配(作为类属性)
|
||||||
use_smart_matching = True
|
use_smart_matching = True
|
||||||
|
|||||||
@@ -107,9 +107,9 @@ class ChatterActionPlanner:
|
|||||||
# 直接使用消息中已计算的标志,无需重复计算兴趣值
|
# 直接使用消息中已计算的标志,无需重复计算兴趣值
|
||||||
for message in unread_messages:
|
for message in unread_messages:
|
||||||
try:
|
try:
|
||||||
message_interest = getattr(message, 'interest_value', 0.3)
|
message_interest = getattr(message, "interest_value", 0.3)
|
||||||
message_should_reply = getattr(message, 'should_reply', False)
|
message_should_reply = getattr(message, "should_reply", False)
|
||||||
message_should_act = getattr(message, 'should_act', False)
|
message_should_act = getattr(message, "should_act", False)
|
||||||
|
|
||||||
# 确保interest_value不是None
|
# 确保interest_value不是None
|
||||||
if message_interest is None:
|
if message_interest is None:
|
||||||
|
|||||||
@@ -5,7 +5,7 @@
|
|||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.plugin_system.apis.plugin_register_api import register_plugin
|
from src.plugin_system.apis.plugin_register_api import register_plugin
|
||||||
from src.plugin_system.base.base_plugin import BasePlugin
|
from src.plugin_system.base.base_plugin import BasePlugin
|
||||||
from src.plugin_system.base.component_types import ComponentInfo, ComponentType, InterestCalculatorInfo
|
from src.plugin_system.base.component_types import ComponentInfo
|
||||||
|
|
||||||
logger = get_logger("affinity_chatter_plugin")
|
logger = get_logger("affinity_chatter_plugin")
|
||||||
|
|
||||||
@@ -52,4 +52,3 @@ class AffinityChatterPlugin(BasePlugin):
|
|||||||
|
|
||||||
return components
|
return components
|
||||||
|
|
||||||
|
|
||||||
@@ -14,4 +14,4 @@ __plugin_meta__ = PluginMetadata(
|
|||||||
"is_built_in": True,
|
"is_built_in": True,
|
||||||
"plugin_type": "action_provider",
|
"plugin_type": "action_provider",
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -13,4 +13,4 @@ __plugin_meta__ = PluginMetadata(
|
|||||||
"is_built_in": True,
|
"is_built_in": True,
|
||||||
"plugin_type": "permission",
|
"plugin_type": "permission",
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -14,4 +14,4 @@ __plugin_meta__ = PluginMetadata(
|
|||||||
"is_built_in": True,
|
"is_built_in": True,
|
||||||
"plugin_type": "plugin_management",
|
"plugin_type": "plugin_management",
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -14,4 +14,4 @@ __plugin_meta__ = PluginMetadata(
|
|||||||
"is_built_in": True,
|
"is_built_in": True,
|
||||||
"plugin_type": "functional"
|
"plugin_type": "functional"
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -65,10 +65,10 @@ class ColdStartTask(AsyncTask):
|
|||||||
nickname = await person_api.get_person_value(person_id, "nickname")
|
nickname = await person_api.get_person_value(person_id, "nickname")
|
||||||
user_nickname = nickname or f"用户{user_id}"
|
user_nickname = nickname or f"用户{user_id}"
|
||||||
user_info = UserInfo(platform=platform, user_id=str(user_id), user_nickname=user_nickname)
|
user_info = UserInfo(platform=platform, user_id=str(user_id), user_nickname=user_nickname)
|
||||||
|
|
||||||
# 使用 get_or_create_stream 来安全地获取或创建流
|
# 使用 get_or_create_stream 来安全地获取或创建流
|
||||||
stream = await self.chat_manager.get_or_create_stream(platform, user_info)
|
stream = await self.chat_manager.get_or_create_stream(platform, user_info)
|
||||||
|
|
||||||
formatted_stream_id = f"{stream.user_info.platform}:{stream.user_info.user_id}:private"
|
formatted_stream_id = f"{stream.user_info.platform}:{stream.user_info.user_id}:private"
|
||||||
await self.executor.execute(stream_id=formatted_stream_id, start_mode="cold_start")
|
await self.executor.execute(stream_id=formatted_stream_id, start_mode="cold_start")
|
||||||
logger.info(f"【冷启动】已为用户 {chat_id} (昵称: {user_nickname}) 发送唤醒/问候消息。")
|
logger.info(f"【冷启动】已为用户 {chat_id} (昵称: {user_nickname}) 发送唤醒/问候消息。")
|
||||||
|
|||||||
@@ -14,4 +14,4 @@ __plugin_meta__ = PluginMetadata(
|
|||||||
"is_built_in": "true",
|
"is_built_in": "true",
|
||||||
"plugin_type": "functional"
|
"plugin_type": "functional"
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -14,4 +14,4 @@ __plugin_meta__ = PluginMetadata(
|
|||||||
"is_built_in": True,
|
"is_built_in": True,
|
||||||
"plugin_type": "audio_processor",
|
"plugin_type": "audio_processor",
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -13,4 +13,4 @@ __plugin_meta__ = PluginMetadata(
|
|||||||
extra={
|
extra={
|
||||||
"is_built_in": True,
|
"is_built_in": True,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user