ruff
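Formatting-only commit: every hunk below is a mechanical rewrite (long calls collapsed or wrapped to the configured line length, trailing commas added, comment spacing and blank lines after docstrings and imports normalized) with no behavior change. As a minimal sketch of the kind of configuration and invocation that produces such a pass; the exact settings of this repository are an assumption, not taken from the commit:

    # pyproject.toml (assumed settings, not from this commit)
    [tool.ruff]
    line-length = 120          # consistent with the ~120-column wrapping seen in the hunks
    target-version = "py310"   # the `int | float` unions below imply at least 3.10

    # lint autofixes first, then layout
    ruff check --fix .
    ruff format .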
@@ -313,7 +313,9 @@ class EnergyManager:
             # 确保 score 是 float 类型
             if not isinstance(score, int | float):
-                logger.warning(f"计算器 {calculator.__class__.__name__} 返回了非数值类型: {type(score)},跳过此组件")
+                logger.warning(
+                    f"计算器 {calculator.__class__.__name__} 返回了非数值类型: {type(score)},跳过此组件"
+                )
                 continue

             component_scores[calculator.__class__.__name__] = float(score)

@@ -527,7 +527,7 @@ class ExpressionLearnerManager:
             os.makedirs(directory, exist_ok=True)
             logger.debug(f"确保目录存在: {directory}")
         except Exception as e:
-            logger.error(f"创建目录失败 {directory}: {e}")
+            logger.error(f"创建目录失败 {directory}: {e}")

     @staticmethod
     async def _auto_migrate_json_to_db():

@@ -429,7 +429,9 @@ class BotInterestManager:
         except Exception as e:
             logger.error(f"❌ 计算相似度分数失败: {e}")

-    async def calculate_interest_match(self, message_text: str, keywords: list[str] | None = None) -> InterestMatchResult:
+    async def calculate_interest_match(
+        self, message_text: str, keywords: list[str] | None = None
+    ) -> InterestMatchResult:
         """计算消息与机器人兴趣的匹配度"""
         if not self.current_interests or not self._initialized:
             raise RuntimeError("❌ 兴趣标签系统未初始化")
@@ -79,7 +79,9 @@ class InterestManager:

         # 如果已有组件在运行,先清理并替换
         if self._current_calculator:
-            logger.info(f"替换现有兴趣值计算组件: {self._current_calculator.component_name} -> {calculator.component_name}")
+            logger.info(
+                f"替换现有兴趣值计算组件: {self._current_calculator.component_name} -> {calculator.component_name}"
+            )
             await self._current_calculator.cleanup()
         else:
             logger.info(f"注册新的兴趣值计算组件: {calculator.component_name}")

@@ -114,7 +116,7 @@ class InterestManager:
                 success=False,
                 message_id=getattr(message, "message_id", ""),
                 interest_value=0.3,
-                error_message="没有可用的兴趣值计算组件"
+                error_message="没有可用的兴趣值计算组件",
             )

         # 使用 create_task 异步执行计算

@@ -133,7 +135,7 @@ class InterestManager:
                 interest_value=0.5,  # 固定默认兴趣值
                 should_reply=False,
                 should_act=False,
-                error_message=f"计算超时({timeout}s),使用默认值"
+                error_message=f"计算超时({timeout}s),使用默认值",
             )
         except Exception as e:
             # 发生异常,返回默认结果

@@ -142,7 +144,7 @@ class InterestManager:
                 success=False,
                 message_id=getattr(message, "message_id", ""),
                 interest_value=0.3,
-                error_message=f"计算异常: {e!s}"
+                error_message=f"计算异常: {e!s}",
             )

     async def _async_calculate(self, message: "DatabaseMessages") -> InterestCalculationResult:

@@ -171,7 +173,7 @@ class InterestManager:
                 message_id=getattr(message, "message_id", ""),
                 interest_value=0.0,
                 error_message=f"计算异常: {e!s}",
-                calculation_time=time.time() - start_time
+                calculation_time=time.time() - start_time,
             )

     async def _calculation_worker(self):

@@ -179,10 +181,7 @@ class InterestManager:
         while not self._shutdown_event.is_set():
             try:
                 # 等待计算任务或关闭信号
-                await asyncio.wait_for(
-                    self._calculation_queue.get(),
-                    timeout=1.0
-                )
+                await asyncio.wait_for(self._calculation_queue.get(), timeout=1.0)

                 # 处理计算任务
                 # 这里可以实现批量处理逻辑

@@ -210,7 +209,7 @@ class InterestManager:
                 "failed_calculations": self._failed_calculations,
                 "success_rate": success_rate,
                 "last_calculation_time": self._last_calculation_time,
-                "current_calculator": self._current_calculator.component_name if self._current_calculator else None
+                "current_calculator": self._current_calculator.component_name if self._current_calculator else None,
             }
         }

@@ -125,19 +125,19 @@ class OpenIE:
     def extract_entity_dict(self):
         """提取实体列表"""
         ner_output_dict = {
-            doc_item["idx"]: doc_item["extracted_entities"]
-            for doc_item in self.docs
-            if len(doc_item["extracted_entities"]) > 0
-        }
+            doc_item["idx"]: doc_item["extracted_entities"]
+            for doc_item in self.docs
+            if len(doc_item["extracted_entities"]) > 0
+        }
         return ner_output_dict

     def extract_triple_dict(self):
         """提取三元组列表"""
         triple_output_dict = {
-            doc_item["idx"]: doc_item["extracted_triples"]
-            for doc_item in self.docs
-            if len(doc_item["extracted_triples"]) > 0
-        }
+            doc_item["idx"]: doc_item["extracted_triples"]
+            for doc_item in self.docs
+            if len(doc_item["extracted_triples"]) > 0
+        }
         return triple_output_dict

     def extract_raw_paragraph_dict(self):

@@ -19,10 +19,10 @@ def dyn_select_top_k(
     for score_item in sorted_score:
         normalized_score.append(
             (
-            score_item[0],
-            score_item[1],
-            (score_item[1] - min_score) / (max_score - min_score),
-            )
+                score_item[0],
+                score_item[1],
+                (score_item[1] - min_score) / (max_score - min_score),
+            )
         )

     # 寻找跳变点:score变化最大的位置
@@ -29,20 +29,21 @@ logger = get_logger(__name__)
 @dataclass
 class HippocampusSampleConfig:
     """海马体采样配置"""
+
     # 双峰分布参数
     recent_mean_hours: float = 12.0  # 近期分布均值(小时)
-    recent_std_hours: float = 8.0  # 近期分布标准差(小时)
-    recent_weight: float = 0.7  # 近期分布权重
+    recent_std_hours: float = 8.0  # 近期分布标准差(小时)
+    recent_weight: float = 0.7  # 近期分布权重

     distant_mean_hours: float = 48.0  # 远期分布均值(小时)
-    distant_std_hours: float = 24.0  # 远期分布标准差(小时)
-    distant_weight: float = 0.3  # 远期分布权重
+    distant_std_hours: float = 24.0  # 远期分布标准差(小时)
+    distant_weight: float = 0.3  # 远期分布权重

     # 采样参数
-    total_samples: int = 50  # 总采样数
-    sample_interval: int = 1800  # 采样间隔(秒)
-    max_sample_length: int = 30  # 每次采样的最大消息数量
-    batch_size: int = 5  # 批处理大小
+    total_samples: int = 50  # 总采样数
+    sample_interval: int = 1800  # 采样间隔(秒)
+    max_sample_length: int = 30  # 每次采样的最大消息数量
+    batch_size: int = 5  # 批处理大小

     @classmethod
     def from_global_config(cls) -> "HippocampusSampleConfig":

@@ -84,12 +85,10 @@ class HippocampusSampler:
         try:
             # 初始化LLM模型
             from src.config.config import model_config

             task_config = getattr(model_config.model_task_config, "utils", None)
             if task_config:
-                self.memory_builder_model = LLMRequest(
-                    model_set=task_config,
-                    request_type="memory.hippocampus_build"
-                )
+                self.memory_builder_model = LLMRequest(model_set=task_config, request_type="memory.hippocampus_build")
                 asyncio.create_task(self.start_background_sampling())
                 logger.info("✅ 海马体采样器初始化成功")
             else:

@@ -107,14 +106,10 @@ class HippocampusSampler:

         # 生成两个正态分布的小时偏移
         recent_offsets = np.random.normal(
-            loc=self.config.recent_mean_hours,
-            scale=self.config.recent_std_hours,
-            size=recent_samples
+            loc=self.config.recent_mean_hours, scale=self.config.recent_std_hours, size=recent_samples
         )
         distant_offsets = np.random.normal(
-            loc=self.config.distant_mean_hours,
-            scale=self.config.distant_std_hours,
-            size=distant_samples
+            loc=self.config.distant_mean_hours, scale=self.config.distant_std_hours, size=distant_samples
         )

         # 合并两个分布的偏移

@@ -122,10 +117,7 @@ class HippocampusSampler:

         # 转换为时间戳(使用绝对值确保时间点在过去)
         base_time = datetime.now()
-        timestamps = [
-            base_time - timedelta(hours=abs(offset))
-            for offset in all_offsets
-        ]
+        timestamps = [base_time - timedelta(hours=abs(offset)) for offset in all_offsets]

         # 按时间排序(从最早到最近)
         return sorted(timestamps)

@@ -171,7 +163,8 @@ class HippocampusSampler:
             if messages and len(messages) >= 2:  # 至少需要2条消息
                 # 过滤掉已经记忆过的消息
                 filtered_messages = [
-                    msg for msg in messages
+                    msg
+                    for msg in messages
                     if msg.get("memorized_times", 0) < 2  # 最多记忆2次
                 ]

@@ -229,7 +222,7 @@ class HippocampusSampler:
                 conversation_text=input_text,
                 context=context,
                 timestamp=time.time(),
-                bypass_interval=True  # 海马体采样器绕过构建间隔限制
+                bypass_interval=True,  # 海马体采样器绕过构建间隔限制
             )

             if memories:

@@ -367,7 +360,7 @@ class HippocampusSampler:
         max_concurrent = min(5, len(time_samples))  # 提高并发数到5

         for i in range(0, len(time_samples), max_concurrent):
-            batch = time_samples[i:i + max_concurrent]
+            batch = time_samples[i : i + max_concurrent]
             tasks = []

             # 创建并发收集任务
@@ -392,7 +385,9 @@ class HippocampusSampler:

         return collected_messages

-    async def _fuse_and_deduplicate_messages(self, collected_messages: list[list[dict[str, Any]]]) -> list[list[dict[str, Any]]]:
+    async def _fuse_and_deduplicate_messages(
+        self, collected_messages: list[list[dict[str, Any]]]
+    ) -> list[list[dict[str, Any]]]:
         """融合和去重消息样本"""
         if not collected_messages:
             return []

@@ -416,7 +411,7 @@ class HippocampusSampler:
                 chat_id = message.get("chat_id", "")

                 # 简单哈希:内容前50字符 + 时间戳(精确到分钟) + 聊天ID
-                hash_key = f"{content[:50]}_{int(timestamp//60)}_{chat_id}"
+                hash_key = f"{content[:50]}_{int(timestamp // 60)}_{chat_id}"

                 if hash_key not in seen_hashes and len(content.strip()) > 10:
                     seen_hashes.add(hash_key)

@@ -448,7 +443,9 @@ class HippocampusSampler:
             # 返回原始消息组作为备选
             return collected_messages[:5]  # 限制返回数量

-    def _merge_adjacent_messages(self, messages: list[dict[str, Any]], time_gap: int = 1800) -> list[list[dict[str, Any]]]:
+    def _merge_adjacent_messages(
+        self, messages: list[dict[str, Any]], time_gap: int = 1800
+    ) -> list[list[dict[str, Any]]]:
         """合并时间间隔内的消息"""
         if not messages:
             return []

@@ -479,7 +476,9 @@ class HippocampusSampler:

         return result_groups

-    async def _build_batch_memory(self, fused_messages: list[list[dict[str, Any]]], time_samples: list[datetime]) -> dict[str, Any]:
+    async def _build_batch_memory(
+        self, fused_messages: list[list[dict[str, Any]]], time_samples: list[datetime]
+    ) -> dict[str, Any]:
         """批量构建记忆"""
         if not fused_messages:
             return {"memory_count": 0, "memories": []}

@@ -513,10 +512,7 @@ class HippocampusSampler:

             # 一次性构建记忆
             memories = await self.memory_system.build_memory_from_conversation(
-                conversation_text=batch_input_text,
-                context=batch_context,
-                timestamp=time.time(),
-                bypass_interval=True
+                conversation_text=batch_input_text, context=batch_context, timestamp=time.time(), bypass_interval=True
             )

             if memories:

@@ -545,11 +541,7 @@ class HippocampusSampler:
             if len(self.last_sample_results) > 10:
                 self.last_sample_results.pop(0)

-            return {
-                "memory_count": total_memory_count,
-                "memories": total_memories,
-                "result": result
-            }
+            return {"memory_count": total_memory_count, "memories": total_memories, "result": result}

         except Exception as e:
             logger.error(f"批量构建记忆失败: {e}")

@@ -601,11 +593,7 @@ class HippocampusSampler:
             except Exception as e:
                 logger.debug(f"单独构建失败: {e}")

-        return {
-            "memory_count": total_count,
-            "memories": total_memories,
-            "fallback_mode": True
-        }
+        return {"memory_count": total_count, "memories": total_memories, "fallback_mode": True}

     async def process_sample_timestamp(self, target_timestamp: float) -> str | None:
         """处理单个时间戳采样(保留作为备选方法)"""

@@ -696,7 +684,9 @@ class HippocampusSampler:
             "performance_metrics": {
                 "avg_messages_per_sample": f"{recent_avg_messages:.1f}",
                 "avg_memories_per_sample": f"{recent_avg_memory_count:.1f}",
-                "fusion_efficiency": f"{(recent_avg_messages/max(recent_avg_memory_count, 1)):.1f}x" if recent_avg_messages > 0 else "N/A"
+                "fusion_efficiency": f"{(recent_avg_messages / max(recent_avg_memory_count, 1)):.1f}x"
+                if recent_avg_messages > 0
+                else "N/A",
             },
             "config": {
                 "sample_interval": self.config.sample_interval,
@@ -15,6 +15,7 @@

 返回: 适合直接嵌入提示词的大段文本;若无有效记忆返回空串。
 """
+
 from __future__ import annotations

 import time

@@ -24,9 +24,12 @@ from src.chat.memory_system.memory_query_planner import MemoryQueryPlanner
 # 记忆采样模式枚举
 class MemorySamplingMode(Enum):
     """记忆采样模式"""
+
     HIPPOCAMPUS = "hippocampus"  # 海马体模式:定时任务采样
-    IMMEDIATE = "immediate"  # 即时模式:回复后立即采样
-    ALL = "all"  # 所有模式:同时使用海马体和即时采样
+    IMMEDIATE = "immediate"  # 即时模式:回复后立即采样
+    ALL = "all"  # 所有模式:同时使用海马体和即时采样


 from src.common.logger import get_logger
 from src.config.config import global_config, model_config
 from src.llm_models.utils_model import LLMRequest

@@ -165,7 +168,6 @@ class MemorySystem:
     async def initialize(self):
         """异步初始化记忆系统"""
         try:
-
             # 初始化LLM模型
             fallback_task = getattr(self.llm_model, "model_for_task", None) if self.llm_model else None

@@ -264,6 +266,7 @@ class MemorySystem:
         if global_config.memory.enable_hippocampus_sampling:
             try:
                 from .hippocampus_sampler import initialize_hippocampus_sampler
+
                 self.hippocampus_sampler = await initialize_hippocampus_sampler(self)
                 logger.info("✅ 海马体采样器初始化成功")
             except Exception as e:

@@ -321,7 +324,11 @@ class MemorySystem:
             return []

     async def build_memory_from_conversation(
-        self, conversation_text: str, context: dict[str, Any], timestamp: float | None = None, bypass_interval: bool = False
+        self,
+        conversation_text: str,
+        context: dict[str, Any],
+        timestamp: float | None = None,
+        bypass_interval: bool = False,
     ) -> list[MemoryChunk]:
         """从对话中构建记忆

@@ -560,7 +567,6 @@ class MemorySystem:
         sampling_mode = getattr(global_config.memory, "memory_sampling_mode", "precision")
         current_mode = MemorySamplingMode(sampling_mode)

-
         context["__sampling_mode"] = current_mode.value
         logger.debug(f"使用记忆采样模式: {current_mode.value}")
@@ -17,6 +17,7 @@ logger = get_logger("adaptive_stream_manager")

 class StreamPriority(Enum):
     """流优先级"""
+
     LOW = 1
     NORMAL = 2
     HIGH = 3

@@ -26,6 +27,7 @@ class StreamPriority(Enum):
 @dataclass
 class SystemMetrics:
     """系统指标"""
+
     cpu_usage: float = 0.0
     memory_usage: float = 0.0
     active_coroutines: int = 0

@@ -36,6 +38,7 @@ class SystemMetrics:
 @dataclass
 class StreamMetrics:
     """流指标"""
+
     stream_id: str
     priority: StreamPriority
     message_rate: float = 0.0  # 消息速率(消息/分钟)

@@ -56,7 +59,7 @@ class AdaptiveStreamManager:
         metrics_window: float = 60.0,  # 指标窗口时间
         adjustment_interval: float = 30.0,  # 调整间隔
         cpu_threshold_high: float = 0.8,  # CPU高负载阈值
-        cpu_threshold_low: float = 0.3,  # CPU低负载阈值
+        cpu_threshold_low: float = 0.3,  # CPU低负载阈值
         memory_threshold_high: float = 0.85,  # 内存高负载阈值
     ):
         self.base_concurrent_limit = base_concurrent_limit

@@ -139,10 +142,7 @@ class AdaptiveStreamManager:
         logger.info("自适应流管理器已停止")

     async def acquire_stream_slot(
-        self,
-        stream_id: str,
-        priority: StreamPriority = StreamPriority.NORMAL,
-        force: bool = False
+        self, stream_id: str, priority: StreamPriority = StreamPriority.NORMAL, force: bool = False
     ) -> bool:
         """
         获取流处理槽位

@@ -165,10 +165,7 @@ class AdaptiveStreamManager:

         # 更新流指标
         if stream_id not in self.stream_metrics:
-            self.stream_metrics[stream_id] = StreamMetrics(
-                stream_id=stream_id,
-                priority=priority
-            )
+            self.stream_metrics[stream_id] = StreamMetrics(stream_id=stream_id, priority=priority)
         self.stream_metrics[stream_id].last_activity = current_time

         # 检查是否已经活跃

@@ -271,8 +268,10 @@ class AdaptiveStreamManager:

         # 如果最近有活跃且响应时间较长,可能需要强制分发
         current_time = time.time()
-        if (current_time - metrics.last_activity < 300 and  # 5分钟内有活动
-            metrics.response_time > 5.0):  # 响应时间超过5秒
+        if (
+            current_time - metrics.last_activity < 300  # 5分钟内有活动
+            and metrics.response_time > 5.0
+        ):  # 响应时间超过5秒
             return True

         return False

@@ -324,26 +323,20 @@ class AdaptiveStreamManager:
                 memory_usage=memory_usage,
                 active_coroutines=active_coroutines,
                 event_loop_lag=event_loop_lag,
-                timestamp=time.time()
+                timestamp=time.time(),
             )

             self.system_metrics.append(metrics)

             # 保持指标窗口大小
             cutoff_time = time.time() - self.metrics_window
-            self.system_metrics = [
-                m for m in self.system_metrics
-                if m.timestamp > cutoff_time
-            ]
+            self.system_metrics = [m for m in self.system_metrics if m.timestamp > cutoff_time]

             # 更新统计信息
             self.stats["avg_concurrent_streams"] = (
                 self.stats["avg_concurrent_streams"] * 0.9 + len(self.active_streams) * 0.1
             )
-            self.stats["peak_concurrent_streams"] = max(
-                self.stats["peak_concurrent_streams"],
-                len(self.active_streams)
-            )
+            self.stats["peak_concurrent_streams"] = max(self.stats["peak_concurrent_streams"], len(self.active_streams))

         except Exception as e:
             logger.error(f"收集系统指标失败: {e}")

@@ -445,14 +438,16 @@ class AdaptiveStreamManager:
     def get_stats(self) -> dict:
         """获取统计信息"""
         stats = self.stats.copy()
-        stats.update({
-            "current_limit": self.current_limit,
-            "active_streams": len(self.active_streams),
-            "pending_streams": len(self.pending_streams),
-            "is_running": self.is_running,
-            "system_cpu": self.system_metrics[-1].cpu_usage if self.system_metrics else 0,
-            "system_memory": self.system_metrics[-1].memory_usage if self.system_metrics else 0,
-        })
+        stats.update(
+            {
+                "current_limit": self.current_limit,
+                "active_streams": len(self.active_streams),
+                "pending_streams": len(self.pending_streams),
+                "is_running": self.is_running,
+                "system_cpu": self.system_metrics[-1].cpu_usage if self.system_metrics else 0,
+                "system_memory": self.system_metrics[-1].memory_usage if self.system_metrics else 0,
+            }
+        )

         # 计算接受率
         if stats["total_requests"] > 0:
@@ -20,6 +20,7 @@ logger = get_logger("batch_database_writer")
 @dataclass
 class StreamUpdatePayload:
     """流更新数据结构"""
+
     stream_id: str
     update_data: dict[str, Any]
     priority: int = 0  # 优先级,数字越大优先级越高

@@ -95,12 +96,7 @@ class BatchDatabaseWriter:

         logger.info("批量数据库写入器已停止")

-    async def schedule_stream_update(
-        self,
-        stream_id: str,
-        update_data: dict[str, Any],
-        priority: int = 0
-    ) -> bool:
+    async def schedule_stream_update(self, stream_id: str, update_data: dict[str, Any], priority: int = 0) -> bool:
         """
         调度流更新

@@ -119,11 +115,7 @@ class BatchDatabaseWriter:
             return True

         # 创建更新载荷
-        payload = StreamUpdatePayload(
-            stream_id=stream_id,
-            update_data=update_data,
-            priority=priority
-        )
+        payload = StreamUpdatePayload(stream_id=stream_id, update_data=update_data, priority=priority)

         # 非阻塞方式加入队列
         try:

@@ -178,10 +170,7 @@ class BatchDatabaseWriter:
                 if remaining_time == 0:
                     break

-                payload = await asyncio.wait_for(
-                    self.write_queue.get(),
-                    timeout=remaining_time
-                )
+                payload = await asyncio.wait_for(self.write_queue.get(), timeout=remaining_time)
                 batch.append(payload)

             except asyncio.TimeoutError:

@@ -203,7 +192,10 @@ class BatchDatabaseWriter:
             # 合并同一流ID的更新(保留最新的)
             merged_updates = {}
             for payload in batch:
-                if payload.stream_id not in merged_updates or payload.timestamp > merged_updates[payload.stream_id].timestamp:
+                if (
+                    payload.stream_id not in merged_updates
+                    or payload.timestamp > merged_updates[payload.stream_id].timestamp
+                ):
                     merged_updates[payload.stream_id] = payload

             # 批量写入

@@ -211,9 +203,7 @@ class BatchDatabaseWriter:

             # 更新统计
             self.stats["batch_writes"] += 1
-            self.stats["avg_batch_size"] = (
-                self.stats["avg_batch_size"] * 0.9 + len(batch) * 0.1
-            )  # 滑动平均
+            self.stats["avg_batch_size"] = self.stats["avg_batch_size"] * 0.9 + len(batch) * 0.1  # 滑动平均
             self.stats["last_flush_time"] = start_time

             logger.debug(f"批量写入完成: {len(batch)} 个更新,耗时 {time.time() - start_time:.3f}s")

@@ -238,31 +228,22 @@ class BatchDatabaseWriter:
                 # 根据数据库类型选择不同的插入/更新策略
                 if global_config.database.database_type == "sqlite":
                     from sqlalchemy.dialects.sqlite import insert as sqlite_insert
-                    stmt = sqlite_insert(ChatStreams).values(
-                        stream_id=stream_id, **update_data
-                    )
-                    stmt = stmt.on_conflict_do_update(
-                        index_elements=["stream_id"],
-                        set_=update_data
-                    )
+
+                    stmt = sqlite_insert(ChatStreams).values(stream_id=stream_id, **update_data)
+                    stmt = stmt.on_conflict_do_update(index_elements=["stream_id"], set_=update_data)
                 elif global_config.database.database_type == "mysql":
                     from sqlalchemy.dialects.mysql import insert as mysql_insert
-                    stmt = mysql_insert(ChatStreams).values(
-                        stream_id=stream_id, **update_data
-                    )
+
+                    stmt = mysql_insert(ChatStreams).values(stream_id=stream_id, **update_data)
                     stmt = stmt.on_duplicate_key_update(
                         **{key: value for key, value in update_data.items() if key != "stream_id"}
                     )
                 else:
                     # 默认使用SQLite语法
                     from sqlalchemy.dialects.sqlite import insert as sqlite_insert
-                    stmt = sqlite_insert(ChatStreams).values(
-                        stream_id=stream_id, **update_data
-                    )
-                    stmt = stmt.on_conflict_do_update(
-                        index_elements=["stream_id"],
-                        set_=update_data
-                    )
+
+                    stmt = sqlite_insert(ChatStreams).values(stream_id=stream_id, **update_data)
+                    stmt = stmt.on_conflict_do_update(index_elements=["stream_id"], set_=update_data)

                 await session.execute(stmt)

@@ -273,30 +254,21 @@ class BatchDatabaseWriter:
         async with get_db_session() as session:
             if global_config.database.database_type == "sqlite":
                 from sqlalchemy.dialects.sqlite import insert as sqlite_insert
-                stmt = sqlite_insert(ChatStreams).values(
-                    stream_id=stream_id, **update_data
-                )
-                stmt = stmt.on_conflict_do_update(
-                    index_elements=["stream_id"],
-                    set_=update_data
-                )
+
+                stmt = sqlite_insert(ChatStreams).values(stream_id=stream_id, **update_data)
+                stmt = stmt.on_conflict_do_update(index_elements=["stream_id"], set_=update_data)
             elif global_config.database.database_type == "mysql":
                 from sqlalchemy.dialects.mysql import insert as mysql_insert
-                stmt = mysql_insert(ChatStreams).values(
-                    stream_id=stream_id, **update_data
-                )
+
+                stmt = mysql_insert(ChatStreams).values(stream_id=stream_id, **update_data)
                 stmt = stmt.on_duplicate_key_update(
                     **{key: value for key, value in update_data.items() if key != "stream_id"}
                 )
             else:
                 from sqlalchemy.dialects.sqlite import insert as sqlite_insert
-                stmt = sqlite_insert(ChatStreams).values(
-                    stream_id=stream_id, **update_data
-                )
-                stmt = stmt.on_conflict_do_update(
-                    index_elements=["stream_id"],
-                    set_=update_data
-                )
+
+                stmt = sqlite_insert(ChatStreams).values(stream_id=stream_id, **update_data)
+                stmt = stmt.on_conflict_do_update(index_elements=["stream_id"], set_=update_data)

             await session.execute(stmt)
             await session.commit()
@@ -273,8 +273,10 @@ class SingleStreamContextManager:
                 message.should_reply = result.should_reply
                 message.should_act = result.should_act

-                logger.debug(f"消息 {message.message_id} 兴趣值已更新: {result.interest_value:.3f}, "
-                             f"should_reply: {result.should_reply}, should_act: {result.should_act}")
+                logger.debug(
+                    f"消息 {message.message_id} 兴趣值已更新: {result.interest_value:.3f}, "
+                    f"should_reply: {result.should_reply}, should_act: {result.should_act}"
+                )
                 return result.interest_value
             else:
                 logger.warning(f"消息 {message.message_id} 兴趣值计算失败: {result.error_message}")
@@ -79,7 +79,7 @@ class StreamLoopManager:
             logger.info(f"正在取消 {len(cancel_tasks)} 个流循环任务...")
             await asyncio.gather(
                 *[self._wait_for_task_cancel(stream_id, task) for stream_id, task in cancel_tasks],
-                return_exceptions=True
+                return_exceptions=True,
             )

         # 取消所有活跃的 chatter 处理任务

@@ -115,6 +115,7 @@ class StreamLoopManager:
         # 使用自适应流管理器获取槽位
         try:
             from src.chat.message_manager.adaptive_stream_manager import get_adaptive_stream_manager
+
             adaptive_manager = get_adaptive_stream_manager()

             if adaptive_manager.is_running:

@@ -123,9 +124,7 @@ class StreamLoopManager:

                 # 获取处理槽位
                 slot_acquired = await adaptive_manager.acquire_stream_slot(
-                    stream_id=stream_id,
-                    priority=priority,
-                    force=force
+                    stream_id=stream_id, priority=priority, force=force
                 )

                 if slot_acquired:

@@ -140,10 +139,7 @@ class StreamLoopManager:

         # 创建流循环任务
         try:
-            loop_task = asyncio.create_task(
-                self._stream_loop_worker(stream_id),
-                name=f"stream_loop_{stream_id}"
-            )
+            loop_task = asyncio.create_task(self._stream_loop_worker(stream_id), name=f"stream_loop_{stream_id}")
             self.stream_loops[stream_id] = loop_task
             # 更新统计信息
             self.stats["active_streams"] += 1

@@ -156,6 +152,7 @@ class StreamLoopManager:
             logger.error(f"启动流循环任务失败 {stream_id}: {e}")
             # 释放槽位
             from src.chat.message_manager.adaptive_stream_manager import get_adaptive_stream_manager
+
             adaptive_manager = get_adaptive_stream_manager()
             adaptive_manager.release_stream_slot(stream_id)

@@ -179,8 +176,8 @@ class StreamLoopManager:

         except Exception:
             from src.chat.message_manager.adaptive_stream_manager import StreamPriority
-            return StreamPriority.NORMAL

+            return StreamPriority.NORMAL

     async def stop_stream_loop(self, stream_id: str) -> bool:
         """停止指定流的循环任务

@@ -244,11 +241,12 @@ class StreamLoopManager:
         # 3. 更新自适应管理器指标
         try:
             from src.chat.message_manager.adaptive_stream_manager import get_adaptive_stream_manager
+
             adaptive_manager = get_adaptive_stream_manager()
             adaptive_manager.update_stream_metrics(
                 stream_id,
                 message_rate=unread_count / 5.0 if unread_count > 0 else 0.0,  # 简化计算
-                last_activity=time.time()
+                last_activity=time.time(),
             )
         except Exception as e:
             logger.debug(f"更新流指标失败: {e}")

@@ -300,6 +298,7 @@ class StreamLoopManager:
         # 释放自适应管理器的槽位
         try:
             from src.chat.message_manager.adaptive_stream_manager import get_adaptive_stream_manager
+
             adaptive_manager = get_adaptive_stream_manager()
             adaptive_manager.release_stream_slot(stream_id)
             logger.debug(f"释放自适应流处理槽位: {stream_id}")

@@ -553,12 +552,12 @@ class StreamLoopManager:
             existing_task.cancel()
             # 创建异步任务来等待取消完成,并添加异常处理
             cancel_task = asyncio.create_task(
-                self._wait_for_task_cancel(stream_id, existing_task),
-                name=f"cancel_existing_loop_{stream_id}"
+                self._wait_for_task_cancel(stream_id, existing_task), name=f"cancel_existing_loop_{stream_id}"
             )
             # 为取消任务添加异常处理,避免孤儿任务
             cancel_task.add_done_callback(
-                lambda task: logger.debug(f"取消任务完成: {stream_id}") if not task.exception()
+                lambda task: logger.debug(f"取消任务完成: {stream_id}")
+                if not task.exception()
                 else logger.error(f"取消任务异常: {stream_id} - {task.exception()}")
             )
             # 从字典中移除

@@ -582,10 +581,7 @@ class StreamLoopManager:
             logger.info(f"流 {stream_id} 当前未读消息数: {unread_count}")

             # 创建新的流循环任务
-            new_task = asyncio.create_task(
-                self._stream_loop(stream_id),
-                name=f"force_stream_loop_{stream_id}"
-            )
+            new_task = asyncio.create_task(self._stream_loop(stream_id), name=f"force_stream_loop_{stream_id}")
             self.stream_loops[stream_id] = new_task
             self.stats["total_loops"] += 1

@@ -59,6 +59,7 @@ class MessageManager:
         # 启动批量数据库写入器
         try:
             from src.chat.message_manager.batch_database_writer import init_batch_writer
+
             await init_batch_writer()
         except Exception as e:
             logger.error(f"启动批量数据库写入器失败: {e}")

@@ -66,6 +67,7 @@ class MessageManager:
         # 启动流缓存管理器
         try:
             from src.chat.message_manager.stream_cache_manager import init_stream_cache_manager
+
             await init_stream_cache_manager()
         except Exception as e:
             logger.error(f"启动流缓存管理器失败: {e}")

@@ -73,6 +75,7 @@ class MessageManager:
         # 启动自适应流管理器
         try:
             from src.chat.message_manager.adaptive_stream_manager import init_adaptive_stream_manager
+
             await init_adaptive_stream_manager()
             logger.info("🎯 自适应流管理器已启动")
         except Exception as e:

@@ -97,6 +100,7 @@ class MessageManager:
         # 停止批量数据库写入器
         try:
             from src.chat.message_manager.batch_database_writer import shutdown_batch_writer
+
             await shutdown_batch_writer()
             logger.info("📦 批量数据库写入器已停止")
         except Exception as e:

@@ -105,6 +109,7 @@ class MessageManager:
         # 停止流缓存管理器
         try:
             from src.chat.message_manager.stream_cache_manager import shutdown_stream_cache_manager
+
             await shutdown_stream_cache_manager()
             logger.info("🗄️ 流缓存管理器已停止")
         except Exception as e:

@@ -113,6 +118,7 @@ class MessageManager:
         # 停止自适应流管理器
         try:
             from src.chat.message_manager.adaptive_stream_manager import shutdown_adaptive_stream_manager
+
             await shutdown_adaptive_stream_manager()
             logger.info("🎯 自适应流管理器已停止")
         except Exception as e:
@@ -19,6 +19,7 @@ logger = get_logger("stream_cache_manager")
 @dataclass
 class StreamCacheStats:
     """缓存统计信息"""
+
     hot_cache_size: int = 0
     warm_storage_size: int = 0
     cold_storage_size: int = 0

@@ -38,9 +39,9 @@ class TieredStreamCache:
         max_warm_size: int = 500,
         max_cold_size: int = 2000,
         cleanup_interval: float = 300.0,  # 5分钟清理一次
-        hot_timeout: float = 1800.0,  # 30分钟未访问降级到warm
-        warm_timeout: float = 7200.0,  # 2小时未访问降级到cold
-        cold_timeout: float = 86400.0,  # 24小时未访问删除
+        hot_timeout: float = 1800.0,  # 30分钟未访问降级到warm
+        warm_timeout: float = 7200.0,  # 2小时未访问降级到cold
+        cold_timeout: float = 86400.0,  # 24小时未访问删除
     ):
         self.max_hot_size = max_hot_size
         self.max_warm_size = max_warm_size

@@ -52,8 +53,8 @@ class TieredStreamCache:

         # 三层缓存存储
         self.hot_cache: OrderedDict[str, OptimizedChatStream] = OrderedDict()  # 热数据(LRU)
-        self.warm_storage: dict[str, tuple[OptimizedChatStream, float]] = {}  # 温数据(最后访问时间)
-        self.cold_storage: dict[str, tuple[OptimizedChatStream, float]] = {}  # 冷数据(最后访问时间)
+        self.warm_storage: dict[str, tuple[OptimizedChatStream, float]] = {}  # 温数据(最后访问时间)
+        self.cold_storage: dict[str, tuple[OptimizedChatStream, float]] = {}  # 冷数据(最后访问时间)

         # 统计信息
         self.stats = StreamCacheStats()

@@ -134,11 +135,7 @@ class TieredStreamCache:
         # 4. 缓存未命中,创建新流
         self.stats.cache_misses += 1
         stream = create_optimized_chat_stream(
-            stream_id=stream_id,
-            platform=platform,
-            user_info=user_info,
-            group_info=group_info,
-            data=data
+            stream_id=stream_id, platform=platform, user_info=user_info, group_info=group_info, data=data
         )
         logger.debug(f"缓存未命中,创建新流: {stream_id}")

@@ -294,9 +291,9 @@ class TieredStreamCache:

         # 估算内存使用(粗略估计)
         self.stats.total_memory_usage = (
-            len(self.hot_cache) * 1024 +  # 每个热流约1KB
-            len(self.warm_storage) * 512 +  # 每个温流约512B
-            len(self.cold_storage) * 256  # 每个冷流约256B
+            len(self.hot_cache) * 1024  # 每个热流约1KB
+            + len(self.warm_storage) * 512  # 每个温流约512B
+            + len(self.cold_storage) * 256  # 每个冷流约256B
         )

         if sum(cleanup_stats.values()) > 0:
@@ -557,7 +557,11 @@ class ChatBot:

                 # 将兴趣度结果同步回原始消息,便于后续流程使用
                 message.interest_value = getattr(db_message, "interest_value", getattr(message, "interest_value", 0.0))
-                setattr(message, "should_reply", getattr(db_message, "should_reply", getattr(message, "should_reply", False)))
+                setattr(
+                    message,
+                    "should_reply",
+                    getattr(db_message, "should_reply", getattr(message, "should_reply", False)),
+                )
                 setattr(message, "should_act", getattr(db_message, "should_act", getattr(message, "should_act", False)))

                 # 存储消息到数据库,只进行一次写入
@@ -298,8 +298,10 @@ class ChatStream:
                 db_message.should_reply = result.should_reply
                 db_message.should_act = result.should_act

-                logger.debug(f"消息 {db_message.message_id} 兴趣值已更新: {result.interest_value:.3f}, "
-                             f"should_reply: {result.should_reply}, should_act: {result.should_act}")
+                logger.debug(
+                    f"消息 {db_message.message_id} 兴趣值已更新: {result.interest_value:.3f}, "
+                    f"should_reply: {result.should_reply}, should_act: {result.should_act}"
+                )
             else:
                 logger.warning(f"消息 {db_message.message_id} 兴趣值计算失败: {result.error_message}")
                 # 使用默认值

@@ -521,18 +523,17 @@ class ChatManager:
         # 优先使用缓存管理器(优化版本)
         try:
             from src.chat.message_manager.stream_cache_manager import get_stream_cache_manager

             cache_manager = get_stream_cache_manager()

             if cache_manager.is_running:
                 optimized_stream = await cache_manager.get_or_create_stream(
-                    stream_id=stream_id,
-                    platform=platform,
-                    user_info=user_info,
-                    group_info=group_info
+                    stream_id=stream_id, platform=platform, user_info=user_info, group_info=group_info
                 )

                 # 设置消息上下文
                 from .message import MessageRecv
+
                 if stream_id in self.last_messages and isinstance(self.last_messages[stream_id], MessageRecv):
                     optimized_stream.set_context(self.last_messages[stream_id])

@@ -715,7 +716,7 @@ class ChatManager:
                 success = await batch_writer.schedule_stream_update(
                     stream_id=stream_data_dict["stream_id"],
                     update_data=ChatManager._prepare_stream_data(stream_data_dict),
-                    priority=1  # 流更新的优先级
+                    priority=1,  # 流更新的优先级
                 )
                 if success:
                     stream.saved = True

@@ -738,7 +739,7 @@ class ChatManager:
                 result = await batch_update(
                     model_class=ChatStreams,
                     conditions={"stream_id": stream_data_dict["stream_id"]},
-                    data=ChatManager._prepare_stream_data(stream_data_dict)
+                    data=ChatManager._prepare_stream_data(stream_data_dict),
                 )
                 if result and result > 0:
                     stream.saved = True

@@ -874,43 +875,43 @@ chat_manager = None


 def _convert_to_original_stream(self, optimized_stream) -> "ChatStream":
-    """将OptimizedChatStream转换为原始ChatStream以保持兼容性"""
-    try:
-        # 创建原始ChatStream实例
-        original_stream = ChatStream(
-            stream_id=optimized_stream.stream_id,
-            platform=optimized_stream.platform,
-            user_info=optimized_stream._get_effective_user_info(),
-            group_info=optimized_stream._get_effective_group_info()
-        )
+    """将OptimizedChatStream转换为原始ChatStream以保持兼容性"""
+    try:
+        # 创建原始ChatStream实例
+        original_stream = ChatStream(
+            stream_id=optimized_stream.stream_id,
+            platform=optimized_stream.platform,
+            user_info=optimized_stream._get_effective_user_info(),
+            group_info=optimized_stream._get_effective_group_info(),
+        )

-        # 复制状态
-        original_stream.create_time = optimized_stream.create_time
-        original_stream.last_active_time = optimized_stream.last_active_time
-        original_stream.sleep_pressure = optimized_stream.sleep_pressure
-        original_stream.base_interest_energy = optimized_stream.base_interest_energy
-        original_stream._focus_energy = optimized_stream._focus_energy
-        original_stream.no_reply_consecutive = optimized_stream.no_reply_consecutive
-        original_stream.saved = optimized_stream.saved
+        # 复制状态
+        original_stream.create_time = optimized_stream.create_time
+        original_stream.last_active_time = optimized_stream.last_active_time
+        original_stream.sleep_pressure = optimized_stream.sleep_pressure
+        original_stream.base_interest_energy = optimized_stream.base_interest_energy
+        original_stream._focus_energy = optimized_stream._focus_energy
+        original_stream.no_reply_consecutive = optimized_stream.no_reply_consecutive
+        original_stream.saved = optimized_stream.saved

-        # 复制上下文信息(如果存在)
-        if hasattr(optimized_stream, "_stream_context") and optimized_stream._stream_context:
-            original_stream.stream_context = optimized_stream._stream_context
+        # 复制上下文信息(如果存在)
+        if hasattr(optimized_stream, "_stream_context") and optimized_stream._stream_context:
+            original_stream.stream_context = optimized_stream._stream_context

-        if hasattr(optimized_stream, "_context_manager") and optimized_stream._context_manager:
-            original_stream.context_manager = optimized_stream._context_manager
+        if hasattr(optimized_stream, "_context_manager") and optimized_stream._context_manager:
+            original_stream.context_manager = optimized_stream._context_manager

-        return original_stream
+        return original_stream

-    except Exception as e:
-        logger.error(f"转换OptimizedChatStream失败: {e}")
-        # 如果转换失败,创建一个新的原始流
-        return ChatStream(
-            stream_id=optimized_stream.stream_id,
-            platform=optimized_stream.platform,
-            user_info=optimized_stream._get_effective_user_info(),
-            group_info=optimized_stream._get_effective_group_info()
-        )
+    except Exception as e:
+        logger.error(f"转换OptimizedChatStream失败: {e}")
+        # 如果转换失败,创建一个新的原始流
+        return ChatStream(
+            stream_id=optimized_stream.stream_id,
+            platform=optimized_stream.platform,
+            user_info=optimized_stream._get_effective_user_info(),
+            group_info=optimized_stream._get_effective_group_info(),
+        )


 def get_chat_manager():
@@ -80,10 +80,7 @@ class OptimizedChatStream:
     ):
         # 共享的只读数据
         self._shared_context = SharedContext(
-            stream_id=stream_id,
-            platform=platform,
-            user_info=user_info,
-            group_info=group_info
+            stream_id=stream_id, platform=platform, user_info=user_info, group_info=group_info
         )

         # 本地修改数据

@@ -269,14 +266,13 @@ class OptimizedChatStream:
             self._stream_context = StreamContext(
                 stream_id=self.stream_id,
                 chat_type=ChatType.GROUP if self.group_info else ChatType.PRIVATE,
-                chat_mode=ChatMode.NORMAL
+                chat_mode=ChatMode.NORMAL,
             )

             # 创建单流上下文管理器
             from src.chat.message_manager.context_manager import SingleStreamContextManager
-            self._context_manager = SingleStreamContextManager(
-                stream_id=self.stream_id, context=self._stream_context
-            )
+
+            self._context_manager = SingleStreamContextManager(stream_id=self.stream_id, context=self._stream_context)

     @property
     def stream_context(self):

@@ -331,9 +327,11 @@ class OptimizedChatStream:
         # 恢复stream_context信息
         if "stream_context_chat_type" in data:
             from src.plugin_system.base.component_types import ChatMode, ChatType
+
             instance.stream_context.chat_type = ChatType(data["stream_context_chat_type"])
         if "stream_context_chat_mode" in data:
             from src.plugin_system.base.component_types import ChatMode, ChatType
+
             instance.stream_context.chat_mode = ChatMode(data["stream_context_chat_mode"])

         # 恢复interruption_count信息

@@ -352,6 +350,7 @@ class OptimizedChatStream:
         if isinstance(actions, str):
             try:
                 import json
+
                 actions = json.loads(actions)
             except json.JSONDecodeError:
                 logger.warning(f"无法解析actions JSON字符串: {actions}")

@@ -458,7 +457,7 @@ class OptimizedChatStream:
             stream_id=self.stream_id,
             platform=self.platform,
             user_info=self._get_effective_user_info(),
-            group_info=self._get_effective_group_info()
+            group_info=self._get_effective_group_info(),
         )

         # 复制本地修改(但不触发写时复制)

@@ -482,9 +481,5 @@ def create_optimized_chat_stream(
 ) -> OptimizedChatStream:
     """创建优化版聊天流实例"""
     return OptimizedChatStream(
-        stream_id=stream_id,
-        platform=platform,
-        user_info=user_info,
-        group_info=group_info,
-        data=data
+        stream_id=stream_id, platform=platform, user_info=user_info, group_info=group_info, data=data
     )
@@ -196,18 +196,20 @@ class ChatterActionManager:
                 thinking_id=thinking_id or "",
                 action_done=True,
                 action_build_into_prompt=False,
-                action_prompt_display=reason
+                action_prompt_display=reason,
             )
         else:
-            asyncio.create_task(database_api.store_action_info(
-                chat_stream=chat_stream,
-                action_build_into_prompt=False,
-                action_prompt_display=reason,
-                action_done=True,
-                thinking_id=thinking_id,
-                action_data={"reason": reason},
-                action_name="no_reply",
-            ))
+            asyncio.create_task(
+                database_api.store_action_info(
+                    chat_stream=chat_stream,
+                    action_build_into_prompt=False,
+                    action_prompt_display=reason,
+                    action_done=True,
+                    thinking_id=thinking_id,
+                    action_data={"reason": reason},
+                    action_name="no_reply",
+                )
+            )

         # 自动清空所有未读消息
         asyncio.create_task(self._clear_all_unread_messages(chat_stream.stream_id, "no_reply"))

@@ -228,7 +230,9 @@ class ChatterActionManager:

         # 记录执行的动作到目标消息
         if success:
-            asyncio.create_task(self._record_action_to_message(chat_stream, action_name, target_message, action_data))
+            asyncio.create_task(
+                self._record_action_to_message(chat_stream, action_name, target_message, action_data)
+            )
         # 自动清空所有未读消息
         if clear_unread_messages:
             asyncio.create_task(self._clear_all_unread_messages(chat_stream.stream_id, action_name))

@@ -496,7 +500,7 @@ class ChatterActionManager:
                 thinking_id=thinking_id or "",
                 action_done=True,
                 action_build_into_prompt=False,
-                action_prompt_display=action_prompt_display
+                action_prompt_display=action_prompt_display,
             )
         else:
             await database_api.store_action_info(

@@ -618,9 +622,15 @@ class ChatterActionManager:
         self._pending_actions = []  # 清空队列
         logger.debug("已禁用批量存储模式")

-    def add_action_to_batch(self, action_name: str, action_data: dict, thinking_id: str = "",
-                            action_done: bool = True, action_build_into_prompt: bool = False,
-                            action_prompt_display: str = ""):
+    def add_action_to_batch(
+        self,
+        action_name: str,
+        action_data: dict,
+        thinking_id: str = "",
+        action_done: bool = True,
+        action_build_into_prompt: bool = False,
+        action_prompt_display: str = "",
+    ):
         """添加动作到批量存储列表"""
         if not self._batch_storage_enabled:
             return False

@@ -632,7 +642,7 @@ class ChatterActionManager:
             "action_done": action_done,
             "action_build_into_prompt": action_build_into_prompt,
             "action_prompt_display": action_prompt_display,
-            "timestamp": time.time()
+            "timestamp": time.time(),
         }
         self._pending_actions.append(action_record)
         logger.debug(f"已添加动作到批量存储列表: {action_name} (当前待处理: {len(self._pending_actions)} 个)")

@@ -658,7 +668,7 @@ class ChatterActionManager:
                 action_done=action_data.get("action_done", True),
                 action_build_into_prompt=action_data.get("action_build_into_prompt", False),
                 action_prompt_display=action_data.get("action_prompt_display", ""),
-                thinking_id=action_data.get("thinking_id", "")
+                thinking_id=action_data.get("thinking_id", ""),
             )
             if result:
                 stored_count += 1
@@ -589,7 +589,7 @@ class DefaultReplyer:
             # 获取记忆系统实例
             memory_system = get_memory_system()

-            # 使用统一记忆系统检索相关记忆
+            # 使用统一记忆系统检索相关记忆
             enhanced_memories = await memory_system.retrieve_relevant_memories(
                 query=target, user_id=memory_user_id, scope_id=stream.stream_id, context=memory_context, limit=10
             )

@@ -1208,12 +1208,32 @@ class DefaultReplyer:

         # 并行执行六个构建任务
         tasks = {
-            "expression_habits": asyncio.create_task(self._time_and_run_task(self.build_expression_habits(chat_talking_prompt_short, target), "expression_habits")),
-            "relation_info": asyncio.create_task(self._time_and_run_task(self.build_relation_info(sender, target), "relation_info")),
-            "memory_block": asyncio.create_task(self._time_and_run_task(self.build_memory_block(chat_talking_prompt_short, target), "memory_block")),
-            "tool_info": asyncio.create_task(self._time_and_run_task(self.build_tool_info(chat_talking_prompt_short, sender, target, enable_tool=enable_tool), "tool_info")),
-            "prompt_info": asyncio.create_task(self._time_and_run_task(self.get_prompt_info(chat_talking_prompt_short, sender, target), "prompt_info")),
-            "cross_context": asyncio.create_task(self._time_and_run_task(Prompt.build_cross_context(chat_id, global_config.personality.prompt_mode, target_user_info), "cross_context")),
+            "expression_habits": asyncio.create_task(
+                self._time_and_run_task(
+                    self.build_expression_habits(chat_talking_prompt_short, target), "expression_habits"
+                )
+            ),
+            "relation_info": asyncio.create_task(
+                self._time_and_run_task(self.build_relation_info(sender, target), "relation_info")
+            ),
+            "memory_block": asyncio.create_task(
+                self._time_and_run_task(self.build_memory_block(chat_talking_prompt_short, target), "memory_block")
+            ),
+            "tool_info": asyncio.create_task(
+                self._time_and_run_task(
+                    self.build_tool_info(chat_talking_prompt_short, sender, target, enable_tool=enable_tool),
+                    "tool_info",
+                )
+            ),
+            "prompt_info": asyncio.create_task(
+                self._time_and_run_task(self.get_prompt_info(chat_talking_prompt_short, sender, target), "prompt_info")
+            ),
+            "cross_context": asyncio.create_task(
+                self._time_and_run_task(
+                    Prompt.build_cross_context(chat_id, global_config.personality.prompt_mode, target_user_info),
+                    "cross_context",
+                )
+            ),
         }

         # 设置超时

@@ -1512,13 +1532,8 @@ class DefaultReplyer:
             chat_target_name = (
                 self.chat_target_info.get("person_name") or self.chat_target_info.get("user_nickname") or "对方"
             )
-            await global_prompt_manager.format_prompt(
-                "chat_target_private1", sender_name=chat_target_name
-            )
-            await global_prompt_manager.format_prompt(
-                "chat_target_private2", sender_name=chat_target_name
-            )
+            await global_prompt_manager.format_prompt("chat_target_private1", sender_name=chat_target_name)
+            await global_prompt_manager.format_prompt("chat_target_private2", sender_name=chat_target_name)

         # 使用新的统一Prompt系统 - Expressor模式,创建PromptParameters
         prompt_parameters = PromptParameters(
@@ -121,13 +121,14 @@ class VideoAnalyzer:
     async def _analyze_batch(self, frames: list[tuple[str, float]], question: str | None) -> str:
         from src.llm_models.payload_content.message import MessageBuilder, RoleType
         from src.llm_models.utils_model import RequestType

         prompt = self.batch_analysis_prompt.format(
             personality_core=self.personality_core, personality_side=self.personality_side
         )
         if question:
             prompt += f"\n用户关注: {question}"
         desc = [
-            (f"第{i+1}帧 (时间: {ts:.2f}s)" if self.enable_frame_timing else f"第{i+1}帧")
+            (f"第{i + 1}帧 (时间: {ts:.2f}s)" if self.enable_frame_timing else f"第{i + 1}帧")
             for i, (_b, ts) in enumerate(frames)
         ]
         prompt += "\n帧列表: " + ", ".join(desc)

@@ -151,16 +152,16 @@ class VideoAnalyzer:
     async def _analyze_sequential(self, frames: list[tuple[str, float]], question: str | None) -> str:
         results: list[str] = []
         for i, (b64, ts) in enumerate(frames):
-            prompt = f"分析第{i+1}帧" + (f" (时间: {ts:.2f}s)" if self.enable_frame_timing else "")
+            prompt = f"分析第{i + 1}帧" + (f" (时间: {ts:.2f}s)" if self.enable_frame_timing else "")
             if question:
                 prompt += f"\n关注: {question}"
             try:
                 text, _ = await self.video_llm.generate_response_for_image(
                     prompt=prompt, image_base64=b64, image_format="jpeg"
                 )
-                results.append(f"第{i+1}帧: {text}")
+                results.append(f"第{i + 1}帧: {text}")
             except Exception as e:  # pragma: no cover
-                results.append(f"第{i+1}帧: 失败 {e}")
+                results.append(f"第{i + 1}帧: 失败 {e}")
             if i < len(frames) - 1:
                 await asyncio.sleep(self.frame_analysis_delay)
         summary_prompt = "基于以下逐帧结果给出完整总结:\n\n" + "\n".join(results)

@@ -182,7 +183,9 @@ class VideoAnalyzer:
         mode = self.analysis_mode
         if mode == "auto":
             mode = "batch" if len(frames) <= 20 else "sequential"
-        text = await (self._analyze_batch(frames, question) if mode == "batch" else self._analyze_sequential(frames, question))
+        text = await (
+            self._analyze_batch(frames, question) if mode == "batch" else self._analyze_sequential(frames, question)
+        )
         return True, text

     async def analyze_video_from_bytes(