This commit is contained in:
Windpicker-owo
2025-11-07 21:16:45 +08:00
69 changed files with 1061 additions and 3246 deletions

View File

@@ -0,0 +1,361 @@
"""
Memory graph visualization - API route module.
Provides a Web API for visualizing memory graph data.
"""
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Optional
import orjson
from fastapi import APIRouter, HTTPException, Request
from fastapi.responses import HTMLResponse, JSONResponse
from fastapi.templating import Jinja2Templates
# Compute the project root relative to this file
project_root = Path(__file__).parent.parent.parent
data_dir = project_root / "data" / "memory_graph"
# Module-level cache
graph_data_cache = None
current_data_file = None
# FastAPI router
router = APIRouter()
# Jinja2 template engine
templates = Jinja2Templates(directory=str(Path(__file__).parent / "templates"))
def find_available_data_files() -> List[Path]:
"""查找所有可用的记忆图数据文件"""
files = []
if not data_dir.exists():
return files
possible_files = ["graph_store.json", "memory_graph.json", "graph_data.json"]
for filename in possible_files:
file_path = data_dir / filename
if file_path.exists():
files.append(file_path)
for pattern in ["graph_store_*.json", "memory_graph_*.json", "graph_data_*.json"]:
for backup_file in data_dir.glob(pattern):
if backup_file not in files:
files.append(backup_file)
backups_dir = data_dir / "backups"
if backups_dir.exists():
for backup_file in backups_dir.glob("**/*.json"):
if backup_file not in files:
files.append(backup_file)
backup_dir = data_dir.parent / "backup"
if backup_dir.exists():
for pattern in ["**/graph_*.json", "**/memory_*.json"]:
for backup_file in backup_dir.glob(pattern):
if backup_file not in files:
files.append(backup_file)
return sorted(files, key=lambda f: f.stat().st_mtime, reverse=True)
def load_graph_data_from_file(file_path: Optional[Path] = None) -> Dict[str, Any]:
"""从磁盘加载图数据"""
global graph_data_cache, current_data_file
if file_path and file_path != current_data_file:
graph_data_cache = None
current_data_file = file_path
if graph_data_cache:
return graph_data_cache
try:
graph_file = current_data_file
if not graph_file:
available_files = find_available_data_files()
if not available_files:
return {"error": "未找到数据文件", "nodes": [], "edges": [], "stats": {}}
graph_file = available_files[0]
current_data_file = graph_file
if not graph_file.exists():
return {"error": f"文件不存在: {graph_file}", "nodes": [], "edges": [], "stats": {}}
with open(graph_file, "r", encoding="utf-8") as f:
data = orjson.loads(f.read())
nodes = data.get("nodes", [])
edges = data.get("edges", [])
metadata = data.get("metadata", {})
nodes_dict = {
node["id"]: {
**node,
"label": node.get("content", ""),
"group": node.get("node_type", ""),
"title": f"{node.get('node_type', '')}: {node.get('content', '')}",
}
for node in nodes
if node.get("id")
}
edges_list = [
{
**edge,
"from": edge.get("source", edge.get("source_id")),
"to": edge.get("target", edge.get("target_id")),
"label": edge.get("relation", ""),
"arrows": "to",
}
for edge in edges
]
stats = metadata.get("statistics", {})
total_memories = stats.get("total_memories", 0)
graph_data_cache = {
"nodes": list(nodes_dict.values()),
"edges": edges_list,
"memories": [],
"stats": {
"total_nodes": len(nodes_dict),
"total_edges": len(edges_list),
"total_memories": total_memories,
},
"current_file": str(graph_file),
"file_size": graph_file.stat().st_size,
"file_modified": datetime.fromtimestamp(graph_file.stat().st_mtime).isoformat(),
}
return graph_data_cache
except Exception as e:
import traceback
traceback.print_exc()
raise HTTPException(status_code=500, detail=f"加载图数据失败: {e}")
@router.get("/", response_class=HTMLResponse)
async def index(request: Request):
"""主页面"""
return templates.TemplateResponse("visualizer.html", {"request": request})
def _format_graph_data_from_manager(memory_manager) -> Dict[str, Any]:
"""从 MemoryManager 提取并格式化图数据"""
if not memory_manager.graph_store:
return {"nodes": [], "edges": [], "memories": [], "stats": {}}
all_memories = memory_manager.graph_store.get_all_memories()
nodes_dict = {}
edges_list = []
memory_info = []
for memory in all_memories:
memory_info.append(
{
"id": memory.id,
"type": memory.memory_type.value,
"importance": memory.importance,
"text": memory.to_text(),
}
)
for node in memory.nodes:
if node.id not in nodes_dict:
nodes_dict[node.id] = {
"id": node.id,
"label": node.content,
"type": node.node_type.value,
"group": node.node_type.name,
"title": f"{node.node_type.value}: {node.content}",
}
for edge in memory.edges:
edges_list.append( # noqa: PERF401
{
"id": edge.id,
"from": edge.source_id,
"to": edge.target_id,
"label": edge.relation,
"arrows": "to",
"memory_id": memory.id,
}
)
stats = memory_manager.get_statistics()
return {
"nodes": list(nodes_dict.values()),
"edges": edges_list,
"memories": memory_info,
"stats": {
"total_nodes": stats.get("total_nodes", 0),
"total_edges": stats.get("total_edges", 0),
"total_memories": stats.get("total_memories", 0),
},
"current_file": "memory_manager (实时数据)",
}
@router.get("/api/graph/full")
async def get_full_graph():
"""获取完整记忆图数据"""
try:
from src.memory_graph.manager_singleton import get_memory_manager
memory_manager = get_memory_manager()
data = {}
if memory_manager and memory_manager._initialized:
data = _format_graph_data_from_manager(memory_manager)
else:
# Fall back to loading from file when the memory manager is unavailable
data = load_graph_data_from_file()
return JSONResponse(content={"success": True, "data": data})
except Exception as e:
import traceback
traceback.print_exc()
return JSONResponse(content={"success": False, "error": str(e)}, status_code=500)
@router.get("/api/files")
async def list_files_api():
"""列出所有可用的数据文件"""
try:
files = find_available_data_files()
file_list = []
for f in files:
stat = f.stat()
file_list.append(
{
"path": str(f),
"name": f.name,
"size": stat.st_size,
"size_kb": round(stat.st_size / 1024, 2),
"modified": datetime.fromtimestamp(stat.st_mtime).isoformat(),
"modified_readable": datetime.fromtimestamp(stat.st_mtime).strftime("%Y-%m-%d %H:%M:%S"),
"is_current": str(f) == str(current_data_file) if current_data_file else False,
}
)
return JSONResponse(
content={
"success": True,
"files": file_list,
"count": len(file_list),
"current_file": str(current_data_file) if current_data_file else None,
}
)
except Exception as e:
# TODO: add proper logging here, e.g.:
# logger.error(f"Failed to list data files: {e}", exc_info=True)
return JSONResponse(content={"success": False, "error": str(e)}, status_code=500)
@router.post("/select_file")
async def select_file(request: Request):
"""选择要加载的数据文件"""
global graph_data_cache, current_data_file
try:
data = await request.json()
file_path = data.get("file_path")
if not file_path:
raise HTTPException(status_code=400, detail="未提供文件路径")
file_to_load = Path(file_path)
if not file_to_load.exists():
raise HTTPException(status_code=404, detail=f"文件不存在: {file_path}")
graph_data_cache = None
current_data_file = file_to_load
graph_data = load_graph_data_from_file(file_to_load)
return JSONResponse(
content={
"success": True,
"message": f"已切换到文件: {file_to_load.name}",
"stats": graph_data.get("stats", {}),
}
)
except Exception as e:
return JSONResponse(content={"success": False, "error": str(e)}, status_code=500)
@router.get("/reload")
async def reload_data():
"""重新加载数据"""
global graph_data_cache
graph_data_cache = None
data = load_graph_data_from_file()
return JSONResponse(content={"success": True, "message": "数据已重新加载", "stats": data.get("stats", {})})
@router.get("/api/search")
async def search_memories(q: str, limit: int = 50):
"""搜索记忆"""
try:
from src.memory_graph.manager_singleton import get_memory_manager
memory_manager = get_memory_manager()
results = []
if memory_manager and memory_manager._initialized and memory_manager.graph_store:
# Search via the memory_manager
all_memories = memory_manager.graph_store.get_all_memories()
for memory in all_memories:
if q.lower() in memory.to_text().lower():
results.append(
{
"id": memory.id,
"type": memory.memory_type.value,
"importance": memory.importance,
"text": memory.to_text(),
}
)
else:
# Fallback: search the data loaded from file
data = load_graph_data_from_file()
for memory in data.get("memories", []):
if q.lower() in memory.get("text", "").lower():
results.append(memory)
return JSONResponse(
content={
"success": True,
"data": {
"results": results[:limit],
"count": len(results),
},
}
)
except Exception as e:
return JSONResponse(content={"success": False, "error": str(e)}, status_code=500)
@router.get("/api/stats")
async def get_statistics():
"""获取统计信息"""
try:
data = load_graph_data_from_file()
node_types = {}
memory_types = {}
for node in data["nodes"]:
node_type = node.get("type", "Unknown")
node_types[node_type] = node_types.get(node_type, 0) + 1
for memory in data.get("memories", []):
mem_type = memory.get("type", "Unknown")
memory_types[mem_type] = memory_types.get(mem_type, 0) + 1
stats = data.get("stats", {})
stats["node_types"] = node_types
stats["memory_types"] = memory_types
return JSONResponse(content={"success": True, "data": stats})
except Exception as e:
return JSONResponse(content={"success": False, "error": str(e)}, status_code=500)
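For reference, a minimal client sketch against the routes above (the host, port, and mount prefix are assumptions; adjust to wherever this router is actually included):

import requests

BASE = "http://127.0.0.1:8000"  # assumed host/port and root mount

# Full graph plus stats
graph = requests.get(f"{BASE}/api/graph/full").json()
if graph["success"]:
    print(graph["data"]["stats"])

# List data files, then switch to the newest one
files = requests.get(f"{BASE}/api/files").json()["files"]
if files:
    requests.post(f"{BASE}/select_file", json={"file_path": files[0]["path"]})

# Keyword search over memories
hits = requests.get(f"{BASE}/api/search", params={"q": "python", "limit": 10}).json()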

View File

@@ -4,10 +4,10 @@ from typing import Any, Literal
from fastapi import APIRouter, HTTPException, Query
from src.common.database.compatibility import db_get
from src.common.database.core.models import LLMUsage
from src.chat.utils.statistic import (
StatisticOutputTask,
)
from src.common.logger import get_logger
from src.config.config import model_config
logger = get_logger("LLM统计API")
@@ -37,108 +37,6 @@ COST_BY_USER = "costs_by_user"
COST_BY_MODEL = "costs_by_model"
COST_BY_MODULE = "costs_by_module"
async def _collect_stats_in_period(start_time: datetime, end_time: datetime) -> dict[str, Any]:
"""在指定时间段内收集LLM使用统计信息"""
records = await db_get(
model_class=LLMUsage,
filters={"timestamp": {"$gte": start_time, "$lt": end_time}},
)
if not records:
return {}
# Build a mapping from model_identifier to the configured model name
model_identifier_to_name_map = {model.model_identifier: model.name for model in model_config.models}
stats: dict[str, Any] = {
TOTAL_REQ_CNT: 0,
TOTAL_COST: 0.0,
REQ_CNT_BY_TYPE: defaultdict(int),
REQ_CNT_BY_USER: defaultdict(int),
REQ_CNT_BY_MODEL: defaultdict(int),
REQ_CNT_BY_MODULE: defaultdict(int),
IN_TOK_BY_TYPE: defaultdict(int),
IN_TOK_BY_USER: defaultdict(int),
IN_TOK_BY_MODEL: defaultdict(int),
IN_TOK_BY_MODULE: defaultdict(int),
OUT_TOK_BY_TYPE: defaultdict(int),
OUT_TOK_BY_USER: defaultdict(int),
OUT_TOK_BY_MODEL: defaultdict(int),
OUT_TOK_BY_MODULE: defaultdict(int),
TOTAL_TOK_BY_TYPE: defaultdict(int),
TOTAL_TOK_BY_USER: defaultdict(int),
TOTAL_TOK_BY_MODEL: defaultdict(int),
TOTAL_TOK_BY_MODULE: defaultdict(int),
COST_BY_TYPE: defaultdict(float),
COST_BY_USER: defaultdict(float),
COST_BY_MODEL: defaultdict(float),
COST_BY_MODULE: defaultdict(float),
}
for record in records:
if not isinstance(record, dict):
continue
stats[TOTAL_REQ_CNT] += 1
request_type = record.get("request_type") or "unknown"
user_id = record.get("user_id") or "unknown"
# The database stores the real model name (model_identifier)
real_model_name = record.get("model_name") or "unknown"
module_name = request_type.split(".")[0] if "." in request_type else request_type
# Try to map the real model name back to the name used in the config
config_model_name = model_identifier_to_name_map.get(real_model_name, real_model_name)
prompt_tokens = record.get("prompt_tokens") or 0
completion_tokens = record.get("completion_tokens") or 0
total_tokens = prompt_tokens + completion_tokens
cost = 0.0
try:
# Use the configured model name to look up model info
model_info = model_config.get_model_info(config_model_name)
if model_info:
input_cost = (prompt_tokens / 1000000) * model_info.price_in
output_cost = (completion_tokens / 1000000) * model_info.price_out
cost = round(input_cost + output_cost, 6)
except KeyError as e:
logger.info(str(e))
logger.warning(f"模型 '{config_model_name}' (真实名称: '{real_model_name}') 在配置中未找到,成本计算将使用默认值 0.0")
stats[TOTAL_COST] += cost
# Aggregate by request type
stats[REQ_CNT_BY_TYPE][request_type] += 1
stats[IN_TOK_BY_TYPE][request_type] += prompt_tokens
stats[OUT_TOK_BY_TYPE][request_type] += completion_tokens
stats[TOTAL_TOK_BY_TYPE][request_type] += total_tokens
stats[COST_BY_TYPE][request_type] += cost
# Aggregate by user
stats[REQ_CNT_BY_USER][user_id] += 1
stats[IN_TOK_BY_USER][user_id] += prompt_tokens
stats[OUT_TOK_BY_USER][user_id] += completion_tokens
stats[TOTAL_TOK_BY_USER][user_id] += total_tokens
stats[COST_BY_USER][user_id] += cost
# Aggregate by model (using the configured name)
stats[REQ_CNT_BY_MODEL][config_model_name] += 1
stats[IN_TOK_BY_MODEL][config_model_name] += prompt_tokens
stats[OUT_TOK_BY_MODEL][config_model_name] += completion_tokens
stats[TOTAL_TOK_BY_MODEL][config_model_name] += total_tokens
stats[COST_BY_MODEL][config_model_name] += cost
# Aggregate by module
stats[REQ_CNT_BY_MODULE][module_name] += 1
stats[IN_TOK_BY_MODULE][module_name] += prompt_tokens
stats[OUT_TOK_BY_MODULE][module_name] += completion_tokens
stats[TOTAL_TOK_BY_MODULE][module_name] += total_tokens
stats[COST_BY_MODULE][module_name] += cost
return stats
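The cost arithmetic above scales token counts by per-million-token prices; a quick worked example with made-up prices:

# Hypothetical prices: 2.0 in / 6.0 out (currency units per million tokens)
prompt_tokens, completion_tokens = 1500, 400
price_in, price_out = 2.0, 6.0

input_cost = (prompt_tokens / 1_000_000) * price_in        # 0.003
output_cost = (completion_tokens / 1_000_000) * price_out  # 0.0024
cost = round(input_cost + output_cost, 6)                  # 0.0054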
@router.get("/llm/stats")
async def get_llm_stats(
period_type: Literal[
@@ -179,7 +77,8 @@ async def get_llm_stats(
if start_time is None:
raise HTTPException(status_code=400, detail="无法确定查询的起始时间")
period_stats = await _collect_stats_in_period(start_time, end_time)
stats_data = await StatisticOutputTask._collect_model_request_for_period([("custom", start_time)])
period_stats = stats_data.get("custom", {})
if not period_stats:
return {"period": {"start": start_time.isoformat(), "end": end_time.isoformat()}, "data": {}}

File diff suppressed because it is too large

View File

@@ -680,9 +680,9 @@ class EmojiManager:
try:
# 🔧 Use QueryBuilder so the database cache is enabled
from src.common.database.api.query import QueryBuilder
logger.debug("[数据库] 开始加载所有表情包记录 ...")
emoji_instances = await QueryBuilder(Emoji).all()
emoji_objects, load_errors = _to_emoji_objects(emoji_instances)
@@ -802,7 +802,7 @@ class EmojiManager:
# Not found in memory: look it up in the database (QueryBuilder enables the DB cache)
try:
from src.common.database.api.query import QueryBuilder
emoji_record = await QueryBuilder(Emoji).filter(emoji_hash=emoji_hash).first()
if emoji_record and emoji_record.description:
logger.info(f"[缓存命中] 从数据库获取表情包描述: {emoji_record.description[:50]}...")
@@ -966,7 +966,7 @@ class EmojiManager:
existing_description = None
try:
from src.common.database.api.query import QueryBuilder
existing_image = await QueryBuilder(Images).filter(emoji_hash=image_hash, type="emoji").first()
if existing_image and existing_image.description:
existing_description = existing_image.description

View File

@@ -1,5 +1,4 @@
import os
import random
import time
from datetime import datetime
from typing import Any
@@ -135,20 +134,20 @@ class ExpressionLearner:
async def cleanup_expired_expressions(self, expiration_days: int | None = None) -> int:
"""
Clean up expired expressions.
Args:
expiration_days: Expiry window in days; expressions inactive for longer are deleted (read from config when not given).
Returns:
int: Number of expressions deleted.
"""
# Read the expiry window from config
if expiration_days is None:
expiration_days = global_config.expression.expiration_days
current_time = time.time()
expiration_threshold = current_time - (expiration_days * 24 * 3600)
try:
deleted_count = 0
async with get_db_session() as session:
@@ -160,15 +159,15 @@ class ExpressionLearner:
)
)
expired_expressions = list(query.scalars())
if expired_expressions:
for expr in expired_expressions:
await session.delete(expr)
deleted_count += 1
await session.commit()
logger.info(f"清理了 {deleted_count} 个过期表达方式(超过 {expiration_days} 天未使用)")
# Invalidate the cache
from src.common.database.optimization.cache_manager import get_cache
from src.common.database.utils.decorators import generate_cache_key
@@ -176,7 +175,7 @@ class ExpressionLearner:
await cache.delete(generate_cache_key("chat_expressions", self.chat_id))
else:
logger.debug(f"没有发现过期的表达方式(阈值:{expiration_days} 天)")
return deleted_count
except Exception as e:
logger.error(f"清理过期表达方式失败: {e}")
@@ -460,7 +459,7 @@ class ExpressionLearner:
)
)
same_situation_expr = query_same_situation.scalar()
# Case 2: same chat_id + type + style (same expression, different situation)
query_same_style = await session.execute(
select(Expression).where(
@@ -470,7 +469,7 @@ class ExpressionLearner:
)
)
same_style_expr = query_same_style.scalar()
# Case 3: exact match (same situation + same expression)
query_exact_match = await session.execute(
select(Expression).where(
@@ -481,7 +480,7 @@ class ExpressionLearner:
)
)
exact_match_expr = query_exact_match.scalar()
# Handle the exact-match case first
if exact_match_expr:
# Exact match: increment count and refresh the timestamp

View File

@@ -72,21 +72,21 @@ class ExpressorModel:
Whether the deletion succeeded.
"""
removed = False
if cid in self._candidates:
del self._candidates[cid]
removed = True
if cid in self._situations:
del self._situations[cid]
# Remove from the naive-Bayes model as well
if cid in self.nb.cls_counts:
del self.nb.cls_counts[cid]
if cid in self.nb.token_counts:
del self.nb.token_counts[cid]
return removed
def predict(self, text: str, k: int | None = None) -> tuple[str | None, dict[str, float]]:

View File

@@ -72,7 +72,7 @@ class StyleLearner:
# Check whether a cleanup is needed
current_count = len(self.style_to_id)
cleanup_trigger = int(self.max_styles * self.cleanup_threshold)
if current_count >= cleanup_trigger:
if current_count >= self.max_styles:
# Hard limit reached; cleanup is mandatory
@@ -109,7 +109,7 @@ class StyleLearner:
def _cleanup_styles(self):
"""
Clean up low-value styles to make room for new ones.
Strategy:
1. Score each style by both usage count and last-used time.
2. Delete the lowest-scoring styles.
@@ -118,34 +118,34 @@ class StyleLearner:
try:
current_time = time.time()
cleanup_count = max(1, int(len(self.style_to_id) * self.cleanup_ratio))
# Compute a value score for every style
style_scores = []
for style_id in self.style_to_id.values():
# Usage count
usage_count = self.learning_stats["style_counts"].get(style_id, 0)
# Last-used time (more recent is better)
last_used = self.learning_stats["style_last_used"].get(style_id, 0)
time_since_used = current_time - last_used if last_used > 0 else float('inf')
time_since_used = current_time - last_used if last_used > 0 else float("inf")
# Combined score: more usage is better, and more recent use is better.
# A logarithm smooths the influence of raw usage counts.
import math
usage_score = math.log1p(usage_count)  # log(1 + count)
# Time score: convert to days and apply exponential decay
days_unused = time_since_used / 86400  # convert to days
time_score = math.exp(-days_unused / 30)  # 30-day decay constant
# Final score: 80% usage frequency + 20% recency
total_score = 0.8 * usage_score + 0.2 * time_score
style_scores.append((style_id, total_score, usage_count, days_unused))
# Sort by score; lowest scores are deleted first
style_scores.sort(key=lambda x: x[1])
# Delete the lowest-scoring styles
deleted_styles = []
for style_id, score, usage, days in style_scores[:cleanup_count]:
@@ -156,27 +156,27 @@ class StyleLearner:
del self.id_to_style[style_id]
if style_id in self.id_to_situation:
del self.id_to_situation[style_id]
# Remove from the statistics
if style_id in self.learning_stats["style_counts"]:
del self.learning_stats["style_counts"][style_id]
if style_id in self.learning_stats["style_last_used"]:
del self.learning_stats["style_last_used"][style_id]
# Remove from the expressor model
self.expressor.remove_candidate(style_id)
deleted_styles.append((style_text[:30], usage, f"{days:.1f}"))
logger.info(
f"风格清理完成: 删除了 {len(deleted_styles)}/{len(style_scores)} 个风格,"
f"剩余 {len(self.style_to_id)} 个风格"
)
# Log the first five deleted styles for debugging
if deleted_styles:
logger.debug(f"被删除的风格样例(前5): {deleted_styles[:5]}")
except Exception as e:
logger.error(f"清理风格失败: {e}", exc_info=True)
@@ -303,10 +303,10 @@ class StyleLearner:
def cleanup_old_styles(self, ratio: float | None = None) -> int:
"""
Manually clean up old styles.
Args:
ratio: Cleanup ratio; the default cleanup_ratio is used when None.
Returns:
Number of styles cleaned up.
"""
@@ -318,7 +318,7 @@ class StyleLearner:
self.cleanup_ratio = old_cleanup_ratio
else:
self._cleanup_styles()
new_count = len(self.style_to_id)
cleaned = old_count - new_count
logger.info(f"手动清理完成: chat_id={self.chat_id}, 清理了 {cleaned} 个风格")
@@ -357,11 +357,11 @@ class StyleLearner:
import pickle
meta_path = os.path.join(save_dir, "meta.pkl")
# Make sure learning_stats contains all required fields
if "style_last_used" not in self.learning_stats:
self.learning_stats["style_last_used"] = {}
meta_data = {
"style_to_id": self.style_to_id,
"id_to_style": self.id_to_style,
@@ -416,7 +416,7 @@ class StyleLearner:
self.id_to_situation = meta_data["id_to_situation"]
self.next_style_id = meta_data["next_style_id"]
self.learning_stats = meta_data["learning_stats"]
# Backward compatibility: add style_last_used if the old data lacks it
if "style_last_used" not in self.learning_stats:
self.learning_stats["style_last_used"] = {}
@@ -526,10 +526,10 @@ class StyleLearnerManager:
def cleanup_all_old_styles(self, ratio: float | None = None) -> dict[str, int]:
"""
Clean up old styles across all learners.
Args:
ratio: Cleanup ratio
Returns:
{chat_id: number cleaned}
"""
@@ -538,7 +538,7 @@ class StyleLearnerManager:
cleaned = learner.cleanup_old_styles(ratio)
if cleaned > 0:
cleanup_results[chat_id] = cleaned
total_cleaned = sum(cleanup_results.values())
logger.info(f"清理所有StyleLearner完成: 总共清理了 {total_cleaned} 个风格")
return cleanup_results

View File

@@ -8,7 +8,6 @@ from datetime import datetime
from typing import Any
import numpy as np
import orjson
from sqlalchemy import select
from src.common.config_helpers import resolve_embedding_dimension
@@ -124,7 +123,7 @@ class BotInterestManager:
tags_info = [f" - '{tag.tag_name}' (权重: {tag.weight:.2f})" for tag in loaded_interests.get_active_tags()]
tags_str = "\n".join(tags_info)
logger.info(f"当前兴趣标签:\n{tags_str}")
# Generate embeddings for the loaded tags (the database does not store embeddings; they are generated at startup)
logger.info("🧠 为加载的标签生成embedding向量...")
await self._generate_embeddings_for_tags(loaded_interests)
@@ -326,13 +325,13 @@ class BotInterestManager:
raise RuntimeError("❌ Embedding客户端未初始化无法生成embedding")
total_tags = len(interests.interest_tags)
# Try loading the embedding cache from file
file_cache = await self._load_embedding_cache_from_file(interests.personality_id)
if file_cache:
logger.info(f"📂 从文件加载 {len(file_cache)} 个embedding缓存")
self.embedding_cache.update(file_cache)
logger.info(f"🧠 开始为 {total_tags} 个兴趣标签生成embedding向量...")
memory_cached_count = 0
@@ -477,14 +476,14 @@ class BotInterestManager:
self, message_text: str, keywords: list[str] | None = None
) -> InterestMatchResult:
"""计算消息与机器人兴趣的匹配度(优化版 - 标签扩展策略)
核心优化:将短标签扩展为完整的描述性句子,解决语义粒度不匹配问题
原问题:
- 消息: "今天天气不错" (完整句子)
- 标签: "蹭人治愈" (2-4字短语)
- 标签: "蹭人治愈" (2-4字短语)
- 结果: 误匹配,因为短标签的 embedding 过于抽象
解决方案:
- 标签扩展: "蹭人治愈" -> "表达亲近、寻求安慰、撒娇的内容"
- 现在是: 句子 vs 句子,匹配更准确
@@ -527,18 +526,18 @@ class BotInterestManager:
if tag.embedding:
# 🔧 Optimization: fetch the expanded tag's embedding (cached)
expanded_embedding = await self._get_expanded_tag_embedding(tag.tag_name)
if expanded_embedding:
# Match against the expanded tag's embedding
similarity = self._calculate_cosine_similarity(message_embedding, expanded_embedding)
# Also compute similarity against the original tag, for reference
original_similarity = self._calculate_cosine_similarity(message_embedding, tag.embedding)
# Blended strategy: the expanded tag carries the higher weight (70%), the original tag supplements it (30%).
# This balances accuracy (expanded) with flexibility (original).
final_similarity = similarity * 0.7 + original_similarity * 0.3
logger.debug(f"标签'{tag.tag_name}': 原始={original_similarity:.3f}, 扩展={similarity:.3f}, 最终={final_similarity:.3f}")
else:
# Fall back to the original embedding if the expanded one cannot be fetched
@@ -603,27 +602,27 @@ class BotInterestManager:
logger.debug(
f"最终结果: 总分={result.overall_score:.3f}, 置信度={result.confidence:.3f}, 匹配标签数={len(result.matched_tags)}"
)
# Persist newly generated expanded embeddings to the cache file
if hasattr(self, '_new_expanded_embeddings_generated') and self._new_expanded_embeddings_generated:
if hasattr(self, "_new_expanded_embeddings_generated") and self._new_expanded_embeddings_generated:
await self._save_embedding_cache_to_file(self.current_interests.personality_id)
self._new_expanded_embeddings_generated = False
logger.debug("💾 已保存新生成的扩展embedding到缓存文件")
return result
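The 70/30 blend is a plain weighted average of two cosine similarities; a minimal numpy sketch with made-up vectors:

import numpy as np

def cosine(a, b) -> float:
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))

msg = np.array([0.2, 0.7, 0.1])           # message embedding
expanded = np.array([0.25, 0.65, 0.15])   # embedding of the expanded sentence
original = np.array([0.9, 0.1, 0.0])      # embedding of the short raw tag

final_similarity = 0.7 * cosine(msg, expanded) + 0.3 * cosine(msg, original)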
async def _get_expanded_tag_embedding(self, tag_name: str) -> list[float] | None:
"""获取扩展标签的 embedding带缓存
优先使用缓存,如果没有则生成并缓存
"""
# Check the cache
if tag_name in self.expanded_embedding_cache:
return self.expanded_embedding_cache[tag_name]
# Expand the tag
expanded_tag = self._expand_tag_for_matching(tag_name)
# Generate the embedding
try:
embedding = await self._get_embedding(expanded_tag)
@@ -636,19 +635,19 @@ class BotInterestManager:
return embedding
except Exception as e:
logger.warning(f"为标签'{tag_name}'生成扩展embedding失败: {e}")
return None
def _expand_tag_for_matching(self, tag_name: str) -> str:
"""将短标签扩展为完整的描述性句子
这是解决"标签太短导致误匹配"的核心方法
策略:
1. 优先使用 LLM 生成的 expanded 字段(最准确)
2. 如果没有,使用基于规则的回退方案
3. 最后使用通用模板
示例:
- "Python" + expanded -> "讨论Python编程语言、写Python代码、Python脚本开发、Python技术问题"
- "蹭人治愈" + expanded -> "想要获得安慰、寻求温暖关怀、撒娇卖萌、表达亲昵、求抱抱求陪伴的对话"
@@ -656,7 +655,7 @@ class BotInterestManager:
# Use the cache
if tag_name in self.expanded_tag_cache:
return self.expanded_tag_cache[tag_name]
# 🎯 Preferred strategy: use the LLM-generated expanded field
if self.current_interests:
for tag in self.current_interests.interest_tags:
@@ -664,66 +663,66 @@ class BotInterestManager:
logger.debug(f"✅ 使用LLM生成的扩展描述: {tag_name} -> {tag.expanded[:50]}...")
self.expanded_tag_cache[tag_name] = tag.expanded
return tag.expanded
# 🔧 Fallback: rule-based expansion, for legacy data or tags the LLM never expanded
logger.debug(f"⚠️ 标签'{tag_name}'没有LLM扩展描述使用规则回退方案")
tag_lower = tag_name.lower()
# Tech/programming tags (concrete descriptions)
if any(word in tag_lower for word in ['python', 'java', 'code', '代码', '编程', '脚本', '算法', '开发']):
if 'python' in tag_lower:
return f"讨论Python编程语言、写Python代码、Python脚本开发、Python技术问题"
elif '算法' in tag_lower:
return f"讨论算法题目、数据结构、编程竞赛、刷LeetCode题目、代码优化"
elif '代码' in tag_lower or '被窝' in tag_lower:
return f"讨论写代码、编程开发、代码实现、技术方案、编程技巧"
if any(word in tag_lower for word in ["python", "java", "code", "代码", "编程", "脚本", "算法", "开发"]):
if "python" in tag_lower:
return "讨论Python编程语言、写Python代码、Python脚本开发、Python技术问题"
elif "算法" in tag_lower:
return "讨论算法题目、数据结构、编程竞赛、刷LeetCode题目、代码优化"
elif "代码" in tag_lower or "被窝" in tag_lower:
return "讨论写代码、编程开发、代码实现、技术方案、编程技巧"
else:
return f"讨论编程开发、软件技术、代码编写、技术实现"
return "讨论编程开发、软件技术、代码编写、技术实现"
# Emotional-expression tags (grounded in real conversation scenarios)
elif any(word in tag_lower for word in ['治愈', '撒娇', '安慰', '呼噜', '', '卖萌']):
return f"想要获得安慰、寻求温暖关怀、撒娇卖萌、表达亲昵、求抱抱求陪伴的对话"
elif any(word in tag_lower for word in ["治愈", "撒娇", "安慰", "呼噜", "", "卖萌"]):
return "想要获得安慰、寻求温暖关怀、撒娇卖萌、表达亲昵、求抱抱求陪伴的对话"
# Gaming/entertainment tags (concrete gaming scenarios)
elif any(word in tag_lower for word in ['游戏', '网游', 'mmo', '', '']):
return f"讨论网络游戏、MMO游戏、游戏玩法、组队打副本、游戏攻略心得"
elif any(word in tag_lower for word in ["游戏", "网游", "mmo", "", ""]):
return "讨论网络游戏、MMO游戏、游戏玩法、组队打副本、游戏攻略心得"
# Anime/video tags (concrete viewing behavior)
elif any(word in tag_lower for word in ['', '动漫', '视频', 'b站', '弹幕', '追番', '云新番']):
elif any(word in tag_lower for word in ["", "动漫", "视频", "b站", "弹幕", "追番", "云新番"]):
# Special-case "云新番" - it means watching new anime online, not "new things" in general
if '' in tag_lower or '新番' in tag_lower:
return f"讨论正在播出的新动漫、新番剧集、动漫剧情、追番心得、动漫角色"
if "" in tag_lower or "新番" in tag_lower:
return "讨论正在播出的新动漫、新番剧集、动漫剧情、追番心得、动漫角色"
else:
return f"讨论动漫番剧内容、B站视频、弹幕文化、追番体验"
return "讨论动漫番剧内容、B站视频、弹幕文化、追番体验"
# Social-platform tags (concrete platform behavior)
elif any(word in tag_lower for word in ['小红书', '贴吧', '论坛', '社区', '吃瓜', '八卦']):
if '吃瓜' in tag_lower:
return f"聊八卦爆料、吃瓜看热闹、网络热点事件、社交平台热议话题"
elif any(word in tag_lower for word in ["小红书", "贴吧", "论坛", "社区", "吃瓜", "八卦"]):
if "吃瓜" in tag_lower:
return "聊八卦爆料、吃瓜看热闹、网络热点事件、社交平台热议话题"
else:
return f"讨论社交平台内容、网络社区话题、论坛讨论、分享生活"
return "讨论社交平台内容、网络社区话题、论坛讨论、分享生活"
# Daily-life tags (concrete pet scenarios)
elif any(word in tag_lower for word in ['', '宠物', '尾巴', '耳朵', '毛绒']):
return f"讨论猫咪宠物、晒猫分享、萌宠日常、可爱猫猫、养猫心得"
elif any(word in tag_lower for word in ["", "宠物", "尾巴", "耳朵", "毛绒"]):
return "讨论猫咪宠物、晒猫分享、萌宠日常、可爱猫猫、养猫心得"
# Mood/status tags (concrete emotional states)
elif any(word in tag_lower for word in ['社恐', '隐身', '流浪', '深夜', '被窝']):
if '社恐' in tag_lower:
return f"表达社交焦虑、不想见人、想躲起来、害怕社交的心情"
elif '深夜' in tag_lower:
return f"深夜睡不着、熬夜、夜猫子、深夜思考人生的对话"
elif any(word in tag_lower for word in ["社恐", "隐身", "流浪", "深夜", "被窝"]):
if "社恐" in tag_lower:
return "表达社交焦虑、不想见人、想躲起来、害怕社交的心情"
elif "深夜" in tag_lower:
return "深夜睡不着、熬夜、夜猫子、深夜思考人生的对话"
else:
return f"表达当前心情状态、个人感受、生活状态"
return "表达当前心情状态、个人感受、生活状态"
# Gear/equipment tags (concrete usage scenarios)
elif any(word in tag_lower for word in ['键盘', '耳机', '装备', '设备']):
return f"讨论键盘耳机装备、数码产品、使用体验、装备推荐评测"
elif any(word in tag_lower for word in ["键盘", "耳机", "装备", "设备"]):
return "讨论键盘耳机装备、数码产品、使用体验、装备推荐评测"
# Interaction/relationship tags
elif any(word in tag_lower for word in ['拾风', '互怼', '互动']):
return f"聊天互动、开玩笑、友好互怼、日常对话交流"
elif any(word in tag_lower for word in ["拾风", "互怼", "互动"]):
return "聊天互动、开玩笑、友好互怼、日常对话交流"
# Default: make the description as concrete as possible
else:
return f"明确讨论{tag_name}这个特定主题的具体内容和相关话题"
@@ -1011,56 +1010,58 @@ class BotInterestManager:
async def _load_embedding_cache_from_file(self, personality_id: str) -> dict[str, list[float]] | None:
"""从文件加载embedding缓存"""
try:
import orjson
from pathlib import Path
import orjson
cache_dir = Path("data/embedding")
cache_dir.mkdir(parents=True, exist_ok=True)
cache_file = cache_dir / f"{personality_id}_embeddings.json"
if not cache_file.exists():
logger.debug(f"📂 Embedding缓存文件不存在: {cache_file}")
return None
# Read the cache file
with open(cache_file, "rb") as f:
cache_data = orjson.loads(f.read())
# Validate the cache version and embedding model
cache_version = cache_data.get("version", 1)
cache_embedding_model = cache_data.get("embedding_model", "")
current_embedding_model = self.embedding_config.model_list[0] if hasattr(self.embedding_config, "model_list") else ""
if cache_embedding_model != current_embedding_model:
logger.warning(f"⚠️ Embedding模型已变更 ({cache_embedding_model}{current_embedding_model}),忽略旧缓存")
return None
embeddings = cache_data.get("embeddings", {})
# Also load the expanded-tag embedding cache
expanded_embeddings = cache_data.get("expanded_embeddings", {})
if expanded_embeddings:
self.expanded_embedding_cache.update(expanded_embeddings)
logger.info(f"📂 加载 {len(expanded_embeddings)} 个扩展标签embedding缓存")
logger.info(f"✅ 成功从文件加载 {len(embeddings)} 个标签embedding缓存 (版本: {cache_version}, 模型: {cache_embedding_model})")
return embeddings
except Exception as e:
logger.warning(f"⚠️ 加载embedding缓存文件失败: {e}")
return None
async def _save_embedding_cache_to_file(self, personality_id: str):
"""保存embedding缓存到文件包括扩展标签的embedding"""
try:
import orjson
from pathlib import Path
from datetime import datetime
from pathlib import Path
import orjson
cache_dir = Path("data/embedding")
cache_dir.mkdir(parents=True, exist_ok=True)
cache_file = cache_dir / f"{personality_id}_embeddings.json"
# Prepare the cache payload
current_embedding_model = self.embedding_config.model_list[0] if hasattr(self.embedding_config, "model_list") and self.embedding_config.model_list else ""
cache_data = {
@@ -1071,13 +1072,13 @@ class BotInterestManager:
"embeddings": self.embedding_cache,
"expanded_embeddings": self.expanded_embedding_cache, # 同时保存扩展标签的embedding
}
# Write to file
with open(cache_file, "wb") as f:
f.write(orjson.dumps(cache_data, option=orjson.OPT_INDENT_2))
logger.debug(f"💾 已保存 {len(self.embedding_cache)} 个标签embedding和 {len(self.expanded_embedding_cache)} 个扩展embedding到缓存文件: {cache_file}")
except Exception as e:
logger.warning(f"⚠️ 保存embedding缓存文件失败: {e}")

View File

@@ -9,8 +9,8 @@ from .scheduler_dispatcher import SchedulerDispatcher, scheduler_dispatcher
__all__ = [
"MessageManager",
"SingleStreamContextManager",
"SchedulerDispatcher",
"SingleStreamContextManager",
"message_manager",
"scheduler_dispatcher",
]

View File

@@ -73,7 +73,7 @@ class SingleStreamContextManager:
cache_enabled = global_config.chat.enable_message_cache
use_cache_system = message_manager.is_running and cache_enabled
if not cache_enabled:
logger.debug(f"消息缓存系统已在配置中禁用")
logger.debug("消息缓存系统已在配置中禁用")
except Exception as e:
logger.debug(f"MessageManager不可用使用直接添加: {e}")
use_cache_system = False
@@ -129,13 +129,13 @@ class SingleStreamContextManager:
await self._calculate_message_interest(message)
self.total_messages += 1
self.last_access_time = time.time()
logger.debug(f"添加消息{message.processed_plain_text}到单流上下文: {self.stream_id}")
return True
# Should be unreachable; return a value to satisfy the type checker
return True
except Exception as e:
logger.error(f"添加消息到单流上下文失败 {self.stream_id}: {e}", exc_info=True)
return False

View File

@@ -4,13 +4,11 @@
"""
import asyncio
import random
import time
from collections import defaultdict, deque
from typing import TYPE_CHECKING, Any
from src.chat.chatter_manager import ChatterManager
from src.chat.message_receive.chat_stream import ChatStream
from src.chat.planner_actions.action_manager import ChatterActionManager
from src.common.data_models.database_data_model import DatabaseMessages
from src.common.data_models.message_manager_data_model import MessageManagerStats, StreamStats
@@ -77,7 +75,7 @@ class MessageManager:
# Start the scheduler-based message dispatcher
await scheduler_dispatcher.start()
scheduler_dispatcher.set_chatter_manager(self.chatter_manager)
# Keep the old stream-loop manager (for now) to allow a smooth transition
# TODO: remove once the new mechanism is confirmed stable
# await stream_loop_manager.start()
@@ -108,7 +106,7 @@ class MessageManager:
# Stop the scheduler-based message dispatcher
await scheduler_dispatcher.stop()
# Stop the old stream-loop manager (if enabled)
# await stream_loop_manager.stop()
@@ -116,7 +114,7 @@ class MessageManager:
async def add_message(self, stream_id: str, message: DatabaseMessages):
"""添加消息到指定聊天流
新的流程:
1. 检查 notice 消息
2. 将消息添加到上下文(缓存)
@@ -149,10 +147,10 @@ class MessageManager:
if not chat_stream:
logger.warning(f"MessageManager.add_message: 聊天流 {stream_id} 不存在")
return
# Add the message to the context
await chat_stream.context_manager.add_message(message)
# Notify scheduler_dispatcher of the message-received event;
# the dispatcher decides whether to interrupt, and whether to create or update a schedule
await scheduler_dispatcher.on_message_received(stream_id)

View File

@@ -20,7 +20,7 @@ logger = get_logger("scheduler_dispatcher")
class SchedulerDispatcher:
"""基于 scheduler 的消息分发器
工作流程:
1. 接收消息时,将消息添加到聊天流上下文
2. 检查是否有活跃的 schedule如果没有则创建
@@ -32,10 +32,17 @@ class SchedulerDispatcher:
def __init__(self):
# Track each stream's schedule_id
self.stream_schedules: dict[str, str] = {}  # stream_id -> schedule_id
# Locks guarding schedule creation/removal, to avoid race conditions
self.schedule_locks: dict[str, asyncio.Lock] = {}  # stream_id -> Lock
# Chatter manager
self.chatter_manager: ChatterManager | None = None
# Statistics
self.stats = {
"total_schedules_created": 0,
@@ -45,9 +52,9 @@ class SchedulerDispatcher:
"total_failures": 0,
"start_time": time.time(),
}
self.is_running = False
logger.info("基于 Scheduler 的消息分发器初始化完成")
async def start(self) -> None:
@@ -55,7 +62,7 @@ class SchedulerDispatcher:
if self.is_running:
logger.warning("分发器已在运行")
return
self.is_running = True
logger.info("基于 Scheduler 的消息分发器已启动")
@@ -63,9 +70,9 @@ class SchedulerDispatcher:
"""停止分发器"""
if not self.is_running:
return
self.is_running = False
# Cancel all active schedules
schedule_ids = list(self.stream_schedules.values())
for schedule_id in schedule_ids:
@@ -73,7 +80,7 @@ class SchedulerDispatcher:
await unified_scheduler.remove_schedule(schedule_id)
except Exception as e:
logger.error(f"移除 schedule {schedule_id} 失败: {e}")
self.stream_schedules.clear()
logger.info("基于 Scheduler 的消息分发器已停止")
@@ -81,43 +88,52 @@ class SchedulerDispatcher:
"""设置 Chatter 管理器"""
self.chatter_manager = chatter_manager
logger.debug(f"设置 Chatter 管理器: {chatter_manager.__class__.__name__}")
def _get_schedule_lock(self, stream_id: str) -> asyncio.Lock:
"""Get the schedule lock for a stream."""
if stream_id not in self.schedule_locks:
self.schedule_locks[stream_id] = asyncio.Lock()
return self.schedule_locks[stream_id]
async def on_message_received(self, stream_id: str) -> None:
"""消息接收时的处理逻辑
Args:
stream_id: 聊天流ID
"""
if not self.is_running:
logger.warning("分发器未运行,忽略消息")
return
try:
# 1. Get the stream context
context = await self._get_stream_context(stream_id)
if not context:
logger.warning(f"无法获取流上下文: {stream_id}")
return
# 2. Check whether an active schedule exists
has_active_schedule = stream_id in self.stream_schedules
if not has_active_schedule:
# 3. No active schedule: create a new one
await self._create_schedule(stream_id, context)
return
# 4. An active schedule exists: run the interruption check
if has_active_schedule:
should_interrupt = await self._check_interruption(stream_id, context)
if should_interrupt:
# Remove the old schedule and create a new one
await self._cancel_and_recreate_schedule(stream_id, context)
logger.debug(f"⚡ 打断成功: 流={stream_id[:8]}..., 已重新创建 schedule")
else:
logger.debug(f"打断判定失败,保持原有 schedule: 流={stream_id[:8]}...")
except Exception as e:
logger.error(f"处理消息接收事件失败 {stream_id}: {e}", exc_info=True)
@@ -135,18 +151,18 @@ class SchedulerDispatcher:
async def _check_interruption(self, stream_id: str, context: StreamContext) -> bool:
"""检查是否应该打断当前处理
Args:
stream_id: 流ID
context: 流上下文
Returns:
bool: 是否应该打断
"""
# Check whether interruption is enabled
if not global_config.chat.interruption_enabled:
return False
# Check whether a reply is in progress, and whether interrupting mid-reply is allowed
if context.is_replying:
if not global_config.chat.allow_reply_interruption:
@@ -154,49 +170,49 @@ class SchedulerDispatcher:
return False
else:
logger.debug(f"聊天流 {stream_id} 正在回复中,但配置允许回复时打断")
# Only check interruption while the Chatter is actually processing
if not context.is_chatter_processing:
logger.debug(f"聊天流 {stream_id} Chatter 未在处理,无需打断")
return False
# Inspect the last message
last_message = context.get_last_message()
if not last_message:
return False
# Skip sticker/emoji messages
if last_message.is_picid or last_message.is_emoji:
logger.info(f"消息 {last_message.message_id} 是表情包或Emoji跳过打断检查")
return False
# Check the triggering user ID
triggering_user_id = context.triggering_user_id
if triggering_user_id and last_message.user_info.user_id != triggering_user_id:
logger.info(f"消息来自非触发用户 {last_message.user_info.user_id},实际触发用户为 {triggering_user_id},跳过打断检查")
return False
# Check whether the interruption cap has been reached
if context.interruption_count >= global_config.chat.interruption_max_limit:
logger.debug(
f"聊天流 {stream_id} 已达到最大打断次数 {context.interruption_count}/{global_config.chat.interruption_max_limit}"
)
return False
# Compute the interruption probability
interruption_probability = context.calculate_interruption_probability(
global_config.chat.interruption_max_limit
)
# Decide probabilistically whether to interrupt
import random
if random.random() < interruption_probability:
logger.debug(f"聊天流 {stream_id} 触发消息打断,打断概率: {interruption_probability:.2f}")
# Increment the interruption counter
await context.increment_interruption_count()
self.stats["total_interruptions"] += 1
# Check whether the cap has now been reached
if context.interruption_count >= global_config.chat.interruption_max_limit:
logger.warning(
@@ -206,7 +222,7 @@ class SchedulerDispatcher:
logger.info(
f"聊天流 {stream_id} 已打断,当前打断次数: {context.interruption_count}/{global_config.chat.interruption_max_limit}"
)
return True
else:
logger.debug(f"聊天流 {stream_id} 未触发打断,打断概率: {interruption_probability:.2f}")
@@ -214,7 +230,7 @@ class SchedulerDispatcher:
async def _cancel_and_recreate_schedule(self, stream_id: str, context: StreamContext) -> None:
"""取消旧的 schedule 并创建新的(打断模式,使用极短延迟)
Args:
stream_id: 流ID
context: 流上下文
@@ -235,13 +251,13 @@ class SchedulerDispatcher:
)
# Removal failed; do not create a new schedule (avoids duplicates)
return
# Create the new schedule in immediate mode (very short delay)
await self._create_schedule(stream_id, context, immediate_mode=True)
async def _create_schedule(self, stream_id: str, context: StreamContext, immediate_mode: bool = False) -> None:
"""为聊天流创建新的 schedule
Args:
stream_id: 流ID
context: 流上下文
@@ -257,7 +273,7 @@ class SchedulerDispatcher:
)
await unified_scheduler.remove_schedule(old_schedule_id)
del self.stream_schedules[stream_id]
# In immediate mode (after an interruption), use a fixed 1-second delay so reprocessing starts quickly
if immediate_mode:
delay = 1.0  # hard-coded 1s delay, ensuring fast reprocessing after an interruption
@@ -268,10 +284,10 @@ class SchedulerDispatcher:
else:
# Regular mode: compute the initial delay
delay = await self._calculate_initial_delay(stream_id, context)
# Unread message count, for logging
unread_count = len(context.unread_messages) if context.unread_messages else 0
# Create the schedule
schedule_id = await unified_scheduler.create_schedule(
callback=self._on_schedule_triggered,
@@ -281,41 +297,41 @@ class SchedulerDispatcher:
task_name=f"dispatch_{stream_id[:8]}",
callback_args=(stream_id,),
)
# Track the schedule
self.stream_schedules[stream_id] = schedule_id
self.stats["total_schedules_created"] += 1
mode_indicator = "⚡打断" if immediate_mode else "📅常规"
logger.info(
f"{mode_indicator} 创建 schedule: 流={stream_id[:8]}..., "
f"延迟={delay:.3f}s, 未读={unread_count}, "
f"ID={schedule_id[:8]}..."
)
except Exception as e:
logger.error(f"创建 schedule 失败 {stream_id}: {e}", exc_info=True)
async def _calculate_initial_delay(self, stream_id: str, context: StreamContext) -> float:
"""计算初始延迟时间
Args:
stream_id: 流ID
context: 流上下文
Returns:
float: 延迟时间(秒)
"""
# Base interval
base_interval = getattr(global_config.chat, "distribution_interval", 5.0)
# Check for unread messages
unread_count = len(context.unread_messages) if context.unread_messages else 0
# Forced-dispatch threshold
force_dispatch_threshold = getattr(global_config.chat, "force_dispatch_unread_threshold", 20)
# Too many unread messages: use the minimum interval
if force_dispatch_threshold and unread_count > force_dispatch_threshold:
min_interval = getattr(global_config.chat, "force_dispatch_min_interval", 0.1)
@@ -325,24 +341,24 @@ class SchedulerDispatcher:
f"使用最小间隔={min_interval}s"
)
return min_interval
# Try computing the interval via the energy manager
try:
# Refresh the energy value
await self._update_stream_energy(stream_id, context)
# Read the current focus_energy
focus_energy = energy_manager.energy_cache.get(stream_id, (0.5, 0))[0]
# Compute the interval from the energy value
interval = energy_manager.get_distribution_interval(focus_energy)
logger.info(
f"📊 动态间隔计算: 流={stream_id[:8]}..., "
f"能量={focus_energy:.3f}, 间隔={interval:.2f}s"
)
return interval
except Exception as e:
logger.info(
f"📊 使用默认间隔: 流={stream_id[:8]}..., "
@@ -352,96 +368,96 @@ class SchedulerDispatcher:
async def _update_stream_energy(self, stream_id: str, context: StreamContext) -> None:
"""更新流的能量值
Args:
stream_id: 流ID
context: 流上下文
"""
try:
from src.chat.message_receive.chat_stream import get_chat_manager
# Get the chat stream
chat_manager = get_chat_manager()
chat_stream = await chat_manager.get_stream(stream_id)
if not chat_stream:
logger.debug(f"无法找到聊天流 {stream_id},跳过能量更新")
return
# Merge unread and history messages
all_messages = []
# Add history messages
history_messages = context.get_history_messages(limit=global_config.chat.max_context_size)
all_messages.extend(history_messages)
# Add unread messages
unread_messages = context.get_unread_messages()
all_messages.extend(unread_messages)
# Sort by time and cap the count
all_messages.sort(key=lambda m: m.time)
messages = all_messages[-global_config.chat.max_context_size:]
# Get the user ID
user_id = context.triggering_user_id
# Compute and cache the energy value via the energy manager
energy = await energy_manager.calculate_focus_energy(
stream_id=stream_id,
messages=messages,
user_id=user_id
)
# Mirror the value onto the ChatStream
chat_stream._focus_energy = energy
logger.debug(f"已更新流 {stream_id} 的能量值: {energy:.3f}")
except Exception as e:
logger.warning(f"更新流能量失败 {stream_id}: {e}", exc_info=False)
async def _on_schedule_triggered(self, stream_id: str) -> None:
"""schedule 触发时的回调
Args:
stream_id: 流ID
"""
try:
old_schedule_id = self.stream_schedules.get(stream_id)
logger.info(
f"⏰ Schedule 触发: 流={stream_id[:8]}..., "
f"ID={old_schedule_id[:8] if old_schedule_id else 'None'}..., "
f"开始处理消息"
)
# Get the stream context
context = await self._get_stream_context(stream_id)
if not context:
logger.warning(f"Schedule 触发时无法获取流上下文: {stream_id}")
return
# Check for unread messages
if not context.unread_messages:
logger.debug(f"{stream_id} 没有未读消息,跳过处理")
return
# Activate chatter processing (no lock needed; concurrent processing is allowed)
success = await self._process_stream(stream_id, context)
# Update statistics
self.stats["total_process_cycles"] += 1
if not success:
self.stats["total_failures"] += 1
self.stream_schedules.pop(stream_id, None)
# Check whether cached messages are pending
from src.chat.message_manager.message_manager import message_manager
has_cached = message_manager.has_cached_messages(stream_id)
if has_cached:
# 有缓存消息,立即创建新 schedule 继续处理
logger.info(
@@ -455,60 +471,60 @@ class SchedulerDispatcher:
f"✅ 处理完成且无缓存消息: 流={stream_id[:8]}..., "
f"等待新消息到达"
)
except Exception as e:
logger.error(f"Schedule 回调执行失败 {stream_id}: {e}", exc_info=True)
async def _process_stream(self, stream_id: str, context: StreamContext) -> bool:
"""处理流消息
Args:
stream_id: 流ID
context: 流上下文
Returns:
bool: 是否处理成功
"""
if not self.chatter_manager:
logger.warning(f"Chatter 管理器未设置: {stream_id}")
return False
# Mark the stream as processing
self._set_stream_processing_status(stream_id, True)
try:
start_time = time.time()
# Record the triggering user ID
last_message = context.get_last_message()
if last_message:
context.triggering_user_id = last_message.user_info.user_id
# Refresh energy in a background task (does not block the main flow)
energy_task = asyncio.create_task(self._refresh_focus_energy(stream_id))
# Set the Chatter-processing flag
context.is_chatter_processing = True
logger.debug(f"设置 Chatter 处理标志: {stream_id}")
try:
# Delegate stream-context processing to chatter_manager
results = await self.chatter_manager.process_stream_context(stream_id, context)
success = results.get("success", False)
if success:
process_time = time.time() - start_time
logger.debug(f"流处理成功: {stream_id} (耗时: {process_time:.2f}s)")
else:
logger.warning(f"流处理失败: {stream_id} - {results.get('error_message', '未知错误')}")
return success
finally:
# Clear the Chatter-processing flag
context.is_chatter_processing = False
logger.debug(f"清除 Chatter 处理标志: {stream_id}")
# Wait for the energy-refresh task to finish
try:
await asyncio.wait_for(energy_task, timeout=5.0)
@@ -516,11 +532,11 @@ class SchedulerDispatcher:
logger.warning(f"等待能量刷新超时: {stream_id}")
except Exception as e:
logger.debug(f"能量刷新任务异常: {e}")
except Exception as e:
logger.error(f"流处理异常: {stream_id} - {e}", exc_info=True)
return False
finally:
# Mark the stream as no longer processing
self._set_stream_processing_status(stream_id, False)
@@ -529,11 +545,11 @@ class SchedulerDispatcher:
"""设置流的处理状态"""
try:
from src.chat.message_manager.message_manager import message_manager
if message_manager.is_running:
message_manager.set_stream_processing_status(stream_id, is_processing)
logger.debug(f"设置流处理状态: stream={stream_id}, processing={is_processing}")
except ImportError:
logger.debug("MessageManager 不可用,跳过状态设置")
except Exception as e:
@@ -547,7 +563,7 @@ class SchedulerDispatcher:
if not chat_stream:
logger.debug(f"刷新能量时未找到聊天流: {stream_id}")
return
await chat_stream.context_manager.refresh_focus_energy_from_history()
logger.debug(f"已刷新聊天流 {stream_id} 的聚焦能量")
except Exception as e:

View File

@@ -367,7 +367,7 @@ class ChatBot:
message_segment = message_data.get("message_segment")
if message_segment and isinstance(message_segment, dict):
if message_segment.get("type") == "adapter_response":
logger.info(f"[DEBUG bot.py message_process] 检测到adapter_response立即处理")
logger.info("[DEBUG bot.py message_process] 检测到adapter_response立即处理")
await self._handle_adapter_response_from_dict(message_segment.get("data"))
return

View File

@@ -205,7 +205,7 @@ async def _process_single_segment(segment: Seg, state: dict, message_info: BaseM
return result
else:
logger.warning(f"[at处理] 无法解析格式: '{segment.data}'")
return f"@{segment.data}"
return f"@{segment.data}"
logger.warning(f"[at处理] 数据类型异常: {type(segment.data)}")
return f"@{segment.data}" if isinstance(segment.data, str) else "@未知用户"

View File

@@ -542,7 +542,7 @@ class DefaultReplyer:
all_memories = []
try:
from src.memory_graph.manager_singleton import get_memory_manager, is_initialized
if is_initialized():
manager = get_memory_manager()
if manager:
@@ -552,12 +552,12 @@ class DefaultReplyer:
sender_name = ""
if user_info_obj:
sender_name = getattr(user_info_obj, "user_nickname", "") or getattr(user_info_obj, "user_cardname", "")
# Gather participant info
participants = []
try:
# Try to get participant info from the chat stream
if hasattr(stream, 'chat_history_manager'):
if hasattr(stream, "chat_history_manager"):
history_manager = stream.chat_history_manager
# Fetch the recent participant list
recent_records = history_manager.get_memory_chat_history(
@@ -586,16 +586,16 @@ class DefaultReplyer:
formatted_history = ""
if chat_history:
# Trim overly long history; keep only the most recent part
lines = chat_history.strip().split('\n')
lines = chat_history.strip().split("\n")
recent_lines = lines[-10:] if len(lines) > 10 else lines
formatted_history = '\n'.join(recent_lines)
formatted_history = "\n".join(recent_lines)
query_context = {
"chat_history": formatted_history,
"sender": sender_name,
"participants": participants,
}
# Use the memory manager's smart retrieval (multi-query strategy)
memories = await manager.search_memories(
query=target,
@@ -605,23 +605,23 @@ class DefaultReplyer:
use_multi_query=True,
context=query_context,
)
if memories:
logger.info(f"[记忆图] 检索到 {len(memories)} 条相关记忆")
# Build the full memory description with the new formatting utility
from src.memory_graph.utils.memory_formatter import (
format_memory_for_prompt,
get_memory_type_label,
)
for memory in memories:
# Produce a full subject-verb-object description via the formatter
content = format_memory_for_prompt(memory, include_metadata=False)
# Get the memory type
mem_type = memory.memory_type.value if memory.memory_type else "未知"
if content:
all_memories.append({
"content": content,
@@ -636,7 +636,7 @@ class DefaultReplyer:
except Exception as e:
logger.debug(f"[记忆图] 检索失败: {e}")
all_memories = []
# Build the memory string, using bracketed formatting
memory_str = ""
has_any_memory = False
@@ -725,7 +725,7 @@ class DefaultReplyer:
for tool_result in tool_results:
tool_name = tool_result.get("tool_name", "unknown")
content = tool_result.get("content", "")
result_type = tool_result.get("type", "tool_result")
tool_result.get("type", "tool_result")
# No truncation; each tool manages its own result length
current_results_parts.append(f"- **{tool_name}**: {content}")
@@ -744,7 +744,7 @@ class DefaultReplyer:
logger.error(f"工具信息获取失败: {e}")
return ""
def _parse_reply_target(self, target_message: str) -> tuple[str, str]:
"""解析回复目标消息 - 使用共享工具"""
from src.chat.utils.prompt import Prompt
@@ -1897,7 +1897,7 @@ class DefaultReplyer:
async def _store_chat_memory_async(self, reply_to: str, reply_message: DatabaseMessages | dict[str, Any] | None = None):
"""
[Deprecated] Asynchronously store chat memory (migrated from build_memory_block).
This function has been replaced by the memory graph system's tool-call flow.
Memories are now created proactively by the LLM during conversation via CreateMemoryTool.
@@ -1906,14 +1906,13 @@ class DefaultReplyer:
reply_message: The original message being replied to
"""
return  # Disabled; the signature is kept in case it is referenced elsewhere
# The code below is deprecated and never executes
try:
if not global_config.memory.enable_memory:
return
# Store the memory via the unified memory system
from src.chat.memory_system import get_memory_system
stream = self.chat_stream
user_info_obj = getattr(stream, "user_info", None)
@@ -2036,7 +2035,7 @@ class DefaultReplyer:
timestamp=time.time(),
limit=int(global_config.chat.max_context_size),
)
chat_history = await build_readable_messages(
await build_readable_messages(
message_list_before_short,
replace_bot_name=True,
merge_messages=False,

View File

@@ -400,7 +400,7 @@ class Prompt:
# Initialize the dict of pre-built parameters
pre_built_params = {}
try:
# --- Step 1: prepare the build tasks ---
tasks = []

View File

@@ -87,20 +87,18 @@ def is_mentioned_bot_in_message(message) -> tuple[bool, float]:
)
processed_text = message.processed_plain_text or ""
# 1. A private chat counts as a strong mention
group_info = getattr(message, "group_info", None)
if not group_info or not getattr(group_info, "group_id", None):
is_private = True
mention_type = 2
logger.debug("检测到私聊消息 - 强提及")
# 2. Being @-mentioned is a strong mention
if re.search(rf"@<(.+?):{global_config.bot.qq_account}>", processed_text):
is_at = True
mention_type = 2
logger.debug("检测到@提及 - 强提及")
# 3. Being replied to is a strong mention
if re.match(
rf"\[回复 (.+?)\({global_config.bot.qq_account!s}\)(.+?)\],说:", processed_text
@@ -108,10 +106,9 @@ def is_mentioned_bot_in_message(message) -> tuple[bool, float]:
rf"\[回复<(.+?)(?=:{global_config.bot.qq_account!s}>)\:{global_config.bot.qq_account!s}>(.+?)\],说:",
processed_text,
):
is_replied = True
mention_type = 2
logger.debug("检测到回复消息 - 强提及")
# 4. A mention of the bot's name or alias in the text is a weak mention
if mention_type == 0:  # only check weak mentions when no strong mention was found
# Strip @ and reply markers before checking
@@ -119,21 +116,19 @@ def is_mentioned_bot_in_message(message) -> tuple[bool, float]:
message_content = re.sub(r"@<(.+?)(?=:(\d+))\:(\d+)>", "", message_content)
message_content = re.sub(r"\[回复 (.+?)\(((\d+)|未知id)\)(.+?)\],说:", "", message_content)
message_content = re.sub(r"\[回复<(.+?)(?=:(\d+))\:(\d+)>(.+?)\],说:", "", message_content)
# Check the bot's primary name
if global_config.bot.nickname in message_content:
is_text_mentioned = True
mention_type = 1
logger.debug(f"检测到文本提及bot主名字 '{global_config.bot.nickname}' - 弱提及")
# If the primary name did not match, check the aliases
elif nicknames:
for alias_name in nicknames:
if alias_name in message_content:
is_text_mentioned = True
mention_type = 1
logger.debug(f"检测到文本提及bot别名 '{alias_name}' - 弱提及")
break
# Return the result
is_mentioned = mention_type > 0
return is_mentioned, float(mention_type)
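So the function returns (is_mentioned, mention_type) with 0 = no mention, 1 = weak (name or alias appears in the text), 2 = strong (private chat, @, or reply). A minimal check of the strong @-pattern shown above, with an assumed account number:

import re

qq_account = "12345"  # assumed bot account, for illustration only
text = "hello @<MaiBot:12345> how are you"
if re.search(rf"@<(.+?):{qq_account}>", text):
    mention_type = 2  # strong mention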

View File

@@ -368,13 +368,13 @@ class CacheManager:
if expired_keys:
logger.info(f"清理了 {len(expired_keys)} 个过期的L1缓存条目")
def get_health_stats(self) -> dict[str, Any]:
"""获取缓存健康统计信息"""
# 简化的健康统计,不包含内存监控(因为相关属性未定义)
return {
"l1_count": len(self.l1_kv_cache),
"l1_vector_count": self.l1_vector_index.ntotal if hasattr(self.l1_vector_index, 'ntotal') else 0,
"l1_vector_count": self.l1_vector_index.ntotal if hasattr(self.l1_vector_index, "ntotal") else 0,
"tool_stats": {
"total_tool_calls": self.tool_stats.get("total_tool_calls", 0),
"tracked_tools": len(self.tool_stats.get("most_used_tools", {})),
@@ -397,7 +397,7 @@ class CacheManager:
warnings.append(f"⚠️ L1缓存条目数较多: {l1_size}")
# Check the vector index size
vector_count = self.l1_vector_index.ntotal if hasattr(self.l1_vector_index, 'ntotal') else 0
vector_count = self.l1_vector_index.ntotal if hasattr(self.l1_vector_index, "ntotal") else 0
if isinstance(vector_count, int) and vector_count > 500:
warnings.append(f"⚠️ 向量索引条目数较多: {vector_count}")

View File

@@ -66,7 +66,7 @@ class BatchStats:
last_batch_duration: float = 0.0
last_batch_size: int = 0
congestion_score: float = 0.0  # congestion score (0-1)
# 🔧 New: cache statistics
cache_size: int = 0  # number of cache entries
cache_memory_mb: float = 0.0  # cache memory usage (MB)
@@ -539,8 +539,7 @@ class AdaptiveBatchScheduler:
def _set_cache(self, cache_key: str, result: Any) -> None:
"""设置缓存(改进版,带大小限制和内存统计)"""
import sys
# 🔧 Enforce the cache size limit
if len(self._result_cache) >= self._cache_max_size:
# First evict expired entries
@@ -549,18 +548,18 @@ class AdaptiveBatchScheduler:
k for k, (_, ts) in self._result_cache.items()
if current_time - ts >= self.cache_ttl
]
for k in expired_keys:
# Update memory accounting
if k in self._cache_size_map:
self._cache_memory_estimate -= self._cache_size_map[k]
del self._cache_size_map[k]
del self._result_cache[k]
# Still too large: evict the oldest entry (LRU)
if len(self._result_cache) >= self._cache_max_size:
oldest_key = min(
self._result_cache.keys(),
self._result_cache.keys(),
key=lambda k: self._result_cache[k][1]
)
# Update memory accounting
@@ -569,7 +568,7 @@ class AdaptiveBatchScheduler:
del self._cache_size_map[oldest_key]
del self._result_cache[oldest_key]
logger.debug(f"缓存已满,淘汰最老条目: {oldest_key}")
# 🔧 Use the accurate memory-estimation method
try:
total_size = estimate_size_smart(cache_key) + estimate_size_smart(result)
@@ -580,7 +579,7 @@ class AdaptiveBatchScheduler:
# Fall back to a default
self._cache_size_map[cache_key] = 1024
self._cache_memory_estimate += 1024
self._result_cache[cache_key] = (result, time.time())
async def get_stats(self) -> BatchStats:

View File

@@ -171,7 +171,7 @@ class LRUCache(Generic[T]):
)
else:
adjusted_created_at = now
entry = CacheEntry(
value=value,
created_at=adjusted_created_at,
@@ -345,7 +345,7 @@ class MultiLevelCache:
# Estimate the data size if not provided
if size is None:
size = estimate_size_smart(value)
# Reject single entries exceeding the size limit
if size > self.max_item_size_bytes:
logger.warning(
@@ -354,7 +354,7 @@ class MultiLevelCache:
f"limit={self.max_item_size_bytes / (1024 * 1024):.2f}MB"
)
return
# Choose the cache tier based on the TTL
if ttl is not None:
# Custom TTL: its magnitude decides which tier to write to
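The tier choice itself is elided by this hunk. One plausible policy, shown purely as an illustrative sketch (the 60 s cutoff and the set(key, value, ttl=..., size=...) signature of the tier caches are assumptions, not confirmed by this diff), sends short TTLs to the hot L1 tier and long TTLs to L2:

# Illustrative only: route writes by TTL (cutoff and signatures assumed)
L1_TTL_CUTOFF = 60.0  # seconds

async def set_by_ttl(cache, key, value, ttl, size):
    tier = cache.l1_cache if ttl <= L1_TTL_CUTOFF else cache.l2_cache
    await tier.set(key, value, ttl=ttl, size=size)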
@@ -394,37 +394,37 @@ class MultiLevelCache:
"""获取所有缓存层的统计信息(修正版,避免重复计数)"""
l1_stats = await self.l1_cache.get_stats()
l2_stats = await self.l2_cache.get_stats()
# 🔧 Fix: compute truly exclusive memory, avoiding double counting of data shared by L1 and L2
l1_keys = set(self.l1_cache._cache.keys())
l2_keys = set(self.l2_cache._cache.keys())
shared_keys = l1_keys & l2_keys
l1_only_keys = l1_keys - l2_keys
l2_only_keys = l2_keys - l1_keys
# Actual total memory (no double counting)
# L1-exclusive memory
l1_only_size = sum(
self.l1_cache._cache[k].size
for k in l1_only_keys
self.l1_cache._cache[k].size
for k in l1_only_keys
if k in self.l1_cache._cache
)
# L2-exclusive memory
l2_only_size = sum(
self.l2_cache._cache[k].size
for k in l2_only_keys
self.l2_cache._cache[k].size
for k in l2_only_keys
if k in self.l2_cache._cache
)
# Shared memory is counted only once (using L1's figures)
shared_size = sum(
self.l1_cache._cache[k].size
for k in shared_keys
self.l1_cache._cache[k].size
for k in shared_keys
if k in self.l1_cache._cache
)
actual_total_size = l1_only_size + l2_only_size + shared_size
return {
"l1": l1_stats,
"l2": l2_stats,
@@ -442,7 +442,7 @@ class MultiLevelCache:
"""检查并强制清理超出内存限制的缓存"""
stats = await self.get_stats()
total_size = stats["l1"].total_size + stats["l2"].total_size
if total_size > self.max_memory_bytes:
memory_mb = total_size / (1024 * 1024)
max_mb = self.max_memory_bytes / (1024 * 1024)
@@ -452,14 +452,14 @@ class MultiLevelCache:
)
# Clear the L2 cache (warm data) first
await self.l2_cache.clear()
# If still over the limit after clearing L2, clear L1 as well
stats_after_l2 = await self.get_stats()
total_after_l2 = stats_after_l2["l1"].total_size + stats_after_l2["l2"].total_size
if total_after_l2 > self.max_memory_bytes:
logger.warning("清理L2后仍超限继续清理L1缓存")
await self.l1_cache.clear()
logger.info("缓存强制清理完成")
async def start_cleanup_task(self, interval: float = 60) -> None:
@@ -476,10 +476,10 @@ class MultiLevelCache:
while not self._is_closing:
try:
await asyncio.sleep(interval)
if self._is_closing:
break
stats = await self.get_stats()
l1_stats = stats["l1"]
l2_stats = stats["l2"]
@@ -493,13 +493,13 @@ class MultiLevelCache:
f"共享: {stats['shared_keys_count']}键/{stats['shared_mb']:.2f}MB "
f"(去重节省{stats['dedup_savings_mb']:.2f}MB)"
)
# 🔧 Evict expired entries
await self._clean_expired_entries()
# Check the memory limit
await self.check_memory_limit()
except asyncio.CancelledError:
break
except Exception as e:
@@ -511,7 +511,7 @@ class MultiLevelCache:
async def stop_cleanup_task(self) -> None:
"""停止清理任务"""
self._is_closing = True
if self._cleanup_task is not None:
self._cleanup_task.cancel()
try:
@@ -520,43 +520,43 @@ class MultiLevelCache:
pass
self._cleanup_task = None
logger.info("缓存清理任务已停止")
async def _clean_expired_entries(self) -> None:
"""清理过期的缓存条目"""
try:
current_time = time.time()
# Evict expired L1 entries
async with self.l1_cache._lock:
expired_keys = [
key for key, entry in self.l1_cache._cache.items()
if current_time - entry.created_at > self.l1_cache.ttl
]
for key in expired_keys:
entry = self.l1_cache._cache.pop(key, None)
if entry:
self.l1_cache._stats.evictions += 1
self.l1_cache._stats.item_count -= 1
self.l1_cache._stats.total_size -= entry.size
# Evict expired L2 entries
async with self.l2_cache._lock:
expired_keys = [
key for key, entry in self.l2_cache._cache.items()
if current_time - entry.created_at > self.l2_cache.ttl
]
for key in expired_keys:
entry = self.l2_cache._cache.pop(key, None)
if entry:
self.l2_cache._stats.evictions += 1
self.l2_cache._stats.item_count -= 1
self.l2_cache._stats.total_size -= entry.size
if expired_keys:
logger.debug(f"清理了 {len(expired_keys)} 个过期缓存条目")
except Exception as e:
logger.error(f"清理过期条目失败: {e}", exc_info=True)
@@ -568,7 +568,7 @@ _cache_lock = asyncio.Lock()
async def get_cache() -> MultiLevelCache:
"""获取全局缓存实例(单例)
从配置文件读取缓存参数,如果配置未加载则使用默认值
如果配置中禁用了缓存返回一个最小化的缓存实例容量为1
"""
@@ -580,9 +580,9 @@ async def get_cache() -> MultiLevelCache:
# 尝试从配置读取参数
try:
from src.config.config import global_config
db_config = global_config.database
# 检查是否启用缓存
if not db_config.enable_database_cache:
logger.info("数据库缓存已禁用,使用最小化缓存实例")
@@ -594,7 +594,7 @@ async def get_cache() -> MultiLevelCache:
max_memory_mb=1,
)
return _global_cache
l1_max_size = db_config.cache_l1_max_size
l1_ttl = db_config.cache_l1_ttl
l2_max_size = db_config.cache_l2_max_size
@@ -602,7 +602,7 @@ async def get_cache() -> MultiLevelCache:
max_memory_mb = db_config.cache_max_memory_mb
max_item_size_mb = db_config.cache_max_item_size_mb
cleanup_interval = db_config.cache_cleanup_interval
logger.info(
f"从配置加载缓存参数: L1({l1_max_size}/{l1_ttl}s), "
f"L2({l2_max_size}/{l2_ttl}s), 内存限制({max_memory_mb}MB), "
@@ -618,7 +618,7 @@ async def get_cache() -> MultiLevelCache:
max_memory_mb = 100
max_item_size_mb = 1
cleanup_interval = 60
_global_cache = MultiLevelCache(
l1_max_size=l1_max_size,
l1_ttl=l1_ttl,
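Editor's note: `get_cache()` above follows the classic async double-checked singleton: check, acquire the lock, check again. A minimal runnable sketch of just that pattern (names are illustrative):

import asyncio

_instance: object | None = None
_lock = asyncio.Lock()

async def get_instance() -> object:
    global _instance
    if _instance is None:
        async with _lock:
            if _instance is None:  # re-check after acquiring the lock
                _instance = object()
    return _instance

async def main() -> None:
    a, b = await asyncio.gather(get_instance(), get_instance())
    assert a is b  # concurrent callers get the same instance

asyncio.run(main())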

View File

@@ -4,73 +4,74 @@
提供比 sys.getsizeof() 更准确的内存占用估算方法
"""
import pickle
import sys
from typing import Any
import numpy as np
def get_accurate_size(obj: Any, seen: set | None = None) -> int:
"""
准确估算对象的内存大小(递归计算所有引用对象)
比 sys.getsizeof() 准确得多,特别是对于复杂嵌套对象。
Args:
obj: 要估算大小的对象
seen: 已访问对象的集合(用于避免循环引用)
Returns:
估算的字节数
"""
if seen is None:
seen = set()
obj_id = id(obj)
if obj_id in seen:
return 0
seen.add(obj_id)
size = sys.getsizeof(obj)
# NumPy 数组特殊处理
if isinstance(obj, np.ndarray):
size += obj.nbytes
return size
# 字典:递归计算所有键值对
if isinstance(obj, dict):
size += sum(get_accurate_size(k, seen) + get_accurate_size(v, seen)
for k, v in obj.items())
# 列表、元组、集合:递归计算所有元素
elif isinstance(obj, list | tuple | set | frozenset):
size += sum(get_accurate_size(item, seen) for item in obj)
# 有 __dict__ 的对象:递归计算属性
elif hasattr(obj, "__dict__"):
size += get_accurate_size(obj.__dict__, seen)
# 其他可迭代对象
elif hasattr(obj, "__iter__") and not isinstance(obj, str | bytes | bytearray):
try:
size += sum(get_accurate_size(item, seen) for item in obj)
except Exception:
pass
return size
def get_pickle_size(obj: Any) -> int:
"""
使用 pickle 序列化大小作为参考
通常比 sys.getsizeof() 更接近实际内存占用,
但可能略小于真实内存占用(不包括 Python 对象开销)
Args:
obj: 要估算大小的对象
Returns:
pickle 序列化后的字节数,失败返回 0
"""
@@ -83,17 +84,17 @@ def get_pickle_size(obj: Any) -> int:
def estimate_size_smart(obj: Any, max_depth: int = 5, sample_large: bool = True) -> int:
"""
智能估算对象大小(平衡准确性和性能)
使用深度受限的递归估算+采样策略,平衡准确性和性能:
- 深度5层足以覆盖99%的缓存数据结构
- 对大型容器(>100项进行采样估算
- 性能开销约60倍于sys.getsizeof但准确度提升1000+倍
Args:
obj: 要估算大小的对象
max_depth: 最大递归深度默认5层可覆盖大多数嵌套结构
sample_large: 对大型容器是否采样默认True提升性能
Returns:
估算的字节数
"""
@@ -105,24 +106,24 @@ def _estimate_recursive(obj: Any, depth: int, seen: set, sample_large: bool) ->
# 检查深度限制
if depth <= 0:
return sys.getsizeof(obj)
# 检查循环引用
obj_id = id(obj)
if obj_id in seen:
return 0
seen.add(obj_id)
# 基本大小
size = sys.getsizeof(obj)
# 简单类型直接返回
if isinstance(obj, int | float | bool | type(None) | str | bytes | bytearray):
return size
# NumPy 数组特殊处理
if isinstance(obj, np.ndarray):
return size + obj.nbytes
# 字典递归
if isinstance(obj, dict):
items = list(obj.items())
@@ -130,7 +131,7 @@ def _estimate_recursive(obj: Any, depth: int, seen: set, sample_large: bool) ->
# 大字典采样前50 + 中间50 + 最后50
sample_items = items[:50] + items[len(items)//2-25:len(items)//2+25] + items[-50:]
sampled_size = sum(
_estimate_recursive(k, depth - 1, seen, sample_large) +
_estimate_recursive(v, depth - 1, seen, sample_large)
for k, v in sample_items
)
@@ -142,9 +143,9 @@ def _estimate_recursive(obj: Any, depth: int, seen: set, sample_large: bool) ->
size += _estimate_recursive(k, depth - 1, seen, sample_large)
size += _estimate_recursive(v, depth - 1, seen, sample_large)
return size
# 列表、元组、集合递归
if isinstance(obj, list | tuple | set | frozenset):
items = list(obj)
if sample_large and len(items) > 100:
# 大容器采样前50 + 中间50 + 最后50
@@ -160,21 +161,21 @@ def _estimate_recursive(obj: Any, depth: int, seen: set, sample_large: bool) ->
for item in items:
size += _estimate_recursive(item, depth - 1, seen, sample_large)
return size
# 有 __dict__ 的对象
if hasattr(obj, "__dict__"):
size += _estimate_recursive(obj.__dict__, depth - 1, seen, sample_large)
return size
def format_size(size_bytes: int) -> str:
"""
格式化字节数为人类可读的格式
Args:
size_bytes: 字节数
Returns:
格式化后的字符串,如 "1.23 MB"
"""

View File

@@ -2,7 +2,8 @@ import os
import socket
from fastapi import APIRouter, FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from rich.traceback import install
from uvicorn import Config
from uvicorn import Server as UvicornServer

View File

@@ -2,7 +2,6 @@ import os
import shutil
import sys
from datetime import datetime
import tomlkit
from pydantic import Field
@@ -381,7 +380,7 @@ class Config(ValidatedConfigBase):
notice: NoticeConfig = Field(..., description="Notice消息配置")
emoji: EmojiConfig = Field(..., description="表情配置")
expression: ExpressionConfig = Field(..., description="表达配置")
memory: MemoryConfig | None = Field(default=None, description="记忆配置")
mood: MoodConfig = Field(..., description="情绪配置")
reaction: ReactionConfig = Field(default_factory=ReactionConfig, description="反应规则配置")
chinese_typo: ChineseTypoConfig = Field(..., description="中文错别字配置")

View File

@@ -401,16 +401,16 @@ class MemoryConfig(ValidatedConfigBase):
memory_system_load_balancing: bool = Field(default=True, description="启用记忆系统负载均衡")
memory_build_throttling: bool = Field(default=True, description="启用记忆构建节流")
memory_priority_queue_enabled: bool = Field(default=True, description="启用记忆优先级队列")
# === 记忆图系统配置 (Memory Graph System) ===
# 新一代记忆系统的配置项
enable: bool = Field(default=True, description="启用记忆图系统")
data_dir: str = Field(default="data/memory_graph", description="记忆数据存储目录")
# 向量存储配置
vector_collection_name: str = Field(default="memory_nodes", description="向量集合名称")
vector_db_path: str = Field(default="data/memory_graph/chroma_db", description="向量数据库路径")
# 检索配置
search_top_k: int = Field(default=10, description="默认检索返回数量")
search_min_importance: float = Field(default=0.3, description="最小重要性阈值")
@@ -418,13 +418,13 @@ class MemoryConfig(ValidatedConfigBase):
search_max_expand_depth: int = Field(default=2, description="检索时图扩展深度0-3")
search_expand_semantic_threshold: float = Field(default=0.3, description="图扩展时语义相似度阈值建议0.3-0.5,过低可能引入无关记忆,过高无法扩展)")
enable_query_optimization: bool = Field(default=True, description="启用查询优化")
# 检索权重配置 (记忆图系统)
search_vector_weight: float = Field(default=0.4, description="向量相似度权重")
search_graph_distance_weight: float = Field(default=0.2, description="图距离权重")
search_importance_weight: float = Field(default=0.2, description="重要性权重")
search_recency_weight: float = Field(default=0.2, description="时效性权重")
# 记忆整合配置
consolidation_enabled: bool = Field(default=False, description="是否启用记忆整合")
consolidation_interval_hours: float = Field(default=2.0, description="整合任务执行间隔(小时)")
@@ -442,21 +442,21 @@ class MemoryConfig(ValidatedConfigBase):
consolidation_linking_min_confidence: float = Field(default=0.7, description="LLM分析最低置信度阈值")
consolidation_linking_llm_temperature: float = Field(default=0.2, description="LLM分析温度参数")
consolidation_linking_llm_max_tokens: int = Field(default=1500, description="LLM分析最大输出长度")
# 遗忘配置 (记忆图系统)
forgetting_enabled: bool = Field(default=True, description="是否启用自动遗忘")
forgetting_activation_threshold: float = Field(default=0.1, description="激活度阈值")
forgetting_min_importance: float = Field(default=0.8, description="最小保护重要性")
# 激活配置
activation_decay_rate: float = Field(default=0.9, description="激活度衰减率")
activation_propagation_strength: float = Field(default=0.5, description="激活传播强度")
activation_propagation_depth: int = Field(default=2, description="激活传播深度")
# 性能配置
max_memory_nodes_per_memory: int = Field(default=10, description="每个记忆最多包含的节点数")
max_related_memories: int = Field(default=5, description="相关记忆最大数量")
# 节点去重合并配置
node_merger_similarity_threshold: float = Field(default=0.85, description="节点去重相似度阈值")
node_merger_context_match_required: bool = Field(default=True, description="节点合并是否要求上下文匹配")
@@ -637,6 +637,7 @@ class WebSearchConfig(ValidatedConfigBase):
exa_api_keys: list[str] = Field(default_factory=lambda: [], description="exa API密钥列表支持轮询机制")
searxng_instances: list[str] = Field(default_factory=list, description="SearXNG 实例 URL 列表")
searxng_api_keys: list[str] = Field(default_factory=list, description="SearXNG 实例 API 密钥列表")
serper_api_keys: list[str] = Field(default_factory=list, description="serper API 密钥列表")
enabled_engines: list[str] = Field(default_factory=lambda: ["ddg"], description="启用的搜索引擎")
search_strategy: Literal["fallback", "single", "parallel"] = Field(default="single", description="搜索策略")

View File

@@ -534,7 +534,7 @@ class _RequestExecutor:
model_name = model_info.name
retry_interval = api_provider.retry_interval
if isinstance(e, NetworkConnectionError | ReqAbortException):
return await self._check_retry(remain_try, retry_interval, "连接异常", model_name)
elif isinstance(e, RespNotOkException):
return await self._handle_resp_not_ok(e, model_info, api_provider, remain_try, messages_info)
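Editor's note: the `(A, B)` → `A | B` rewrites throughout this commit rely on Python 3.10+, where `isinstance()` accepts a PEP 604 union directly. A quick demonstration (stub exception classes standing in for the project's own):

class NetworkConnectionError(Exception): ...
class ReqAbortException(Exception): ...

err: Exception = ReqAbortException()
assert isinstance(err, NetworkConnectionError | ReqAbortException)   # PEP 604 union
assert isinstance(err, (NetworkConnectionError, ReqAbortException))  # tuple form still works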

View File

@@ -423,15 +423,16 @@ MoFox_Bot(第三方修改版)
# 注册API路由
try:
from src.api.memory_visualizer_router import router as visualizer_router
from src.api.message_router import router as message_router
from src.api.statistic_router import router as llm_statistic_router
self.server.register_router(message_router, prefix="/api")
self.server.register_router(llm_statistic_router, prefix="/api")
self.server.register_router(visualizer_router, prefix="/visualizer")
logger.info("API路由注册成功")
except Exception as e:
logger.error(f"注册API路由失败: {e}")
# 初始化统一调度器
try:
from src.schedule.unified_scheduler import initialize_scheduler

View File

@@ -100,10 +100,10 @@ class VectorStore:
# 处理额外的元数据,将 list 转换为 JSON 字符串
for key, value in node.metadata.items():
if isinstance(value, list | dict):
import orjson
metadata[key] = orjson.dumps(value, option=orjson.OPT_NON_STR_KEYS).decode("utf-8")
elif isinstance(value, str | int | float | bool) or value is None:
metadata[key] = value
else:
metadata[key] = str(value)
@@ -149,9 +149,9 @@ class VectorStore:
"created_at": n.created_at.isoformat(),
}
for key, value in n.metadata.items():
if isinstance(value, list | dict):
metadata[key] = orjson.dumps(value, option=orjson.OPT_NON_STR_KEYS).decode("utf-8")
elif isinstance(value, str | int | float | bool) or value is None:
metadata[key] = value # type: ignore
else:
metadata[key] = str(value)
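Editor's note: vector stores generally accept only scalar metadata, hence the list/dict-to-JSON conversion above. A self-contained sketch of that sanitization step:

import orjson

def sanitize(meta: dict) -> dict:
    out = {}
    for key, value in meta.items():
        if isinstance(value, list | dict):
            # Containers are serialized to JSON strings the store can hold.
            out[key] = orjson.dumps(value, option=orjson.OPT_NON_STR_KEYS).decode("utf-8")
        elif isinstance(value, str | int | float | bool) or value is None:
            out[key] = value
        else:
            out[key] = str(value)  # fallback: stringify unknown types
    return out

print(sanitize({"tags": ["a", "b"], "score": 0.9, "when": None}))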

View File

@@ -4,8 +4,6 @@
from __future__ import annotations
import asyncio
import numpy as np
from src.common.logger import get_logger
@@ -72,7 +70,7 @@ class EmbeddingGenerator:
logger.warning(f"⚠️ Embedding API 初始化失败: {e}")
self._api_available = False
async def generate(self, text: str) -> np.ndarray | None:
"""
生成单个文本的嵌入向量
@@ -130,7 +128,7 @@ class EmbeddingGenerator:
logger.debug(f"API 嵌入生成失败: {e}")
return None
def _get_dimension(self) -> int:
"""获取嵌入维度"""
# 优先使用 API 维度

View File

@@ -7,11 +7,12 @@
"""
import atexit
import os
import threading
from typing import Any, ClassVar
import orjson
from src.common.logger import get_logger
# 获取日志记录器
@@ -125,7 +126,7 @@ class PluginStorage:
try:
with open(self.file_path, "w", encoding="utf-8") as f:
f.write(orjson.dumps(self._data, option=orjson.OPT_INDENT_2 | orjson.OPT_NON_STR_KEYS).decode("utf-8"))
self._dirty = False # 保存后重置标志
logger.debug(f"插件 '{self.name}' 的数据已成功保存到磁盘。")
except Exception as e:

View File

@@ -5,12 +5,12 @@ MCP Client Manager
"""
import asyncio
import shutil
from pathlib import Path
from typing import Any
import mcp.types
import orjson
from fastmcp.client import Client, StdioTransport, StreamableHttpTransport
from src.common.logger import get_logger

View File

@@ -4,11 +4,13 @@
"""
import time
from dataclasses import dataclass, field
from typing import Any
import orjson
from src.common.cache_manager import tool_cache
from src.common.logger import get_logger
logger = get_logger("stream_tool_history")
@@ -18,10 +20,10 @@ class ToolCallRecord:
"""工具调用记录"""
tool_name: str
args: dict[str, Any]
result: dict[str, Any] | None = None
status: str = "success" # success, error, pending
timestamp: float = field(default_factory=time.time)
execution_time: float | None = None # 执行耗时(秒)
cache_hit: bool = False # 是否命中缓存
result_preview: str = "" # 结果预览
error_message: str = "" # 错误信息
@@ -32,9 +34,9 @@ class ToolCallRecord:
content = self.result.get("content", "")
if isinstance(content, str):
self.result_preview = content[:500] + ("..." if len(content) > 500 else "")
elif isinstance(content, list | dict):
try:
self.result_preview = orjson.dumps(content, option=orjson.OPT_NON_STR_KEYS).decode("utf-8")[:500] + "..."
except Exception:
self.result_preview = str(content)[:500] + "..."
else:
@@ -105,7 +107,7 @@ class StreamToolHistoryManager:
logger.debug(f"[{self.chat_id}] 添加工具调用记录: {record.tool_name}, 缓存命中: {record.cache_hit}")
async def get_cached_result(self, tool_name: str, args: dict[str, Any]) -> dict[str, Any] | None:
"""从缓存或历史记录中获取结果
Args:
@@ -160,9 +162,9 @@ class StreamToolHistoryManager:
return None
async def cache_result(self, tool_name: str, args: dict[str, Any], result: dict[str, Any],
execution_time: float | None = None,
tool_file_path: str | None = None,
ttl: int | None = None) -> None:
"""缓存工具调用结果
Args:
@@ -207,7 +209,7 @@ class StreamToolHistoryManager:
except Exception as e:
logger.warning(f"[{self.chat_id}] 缓存设置失败: {e}")
async def get_recent_history(self, count: int = 5, status_filter: str | None = None) -> list[ToolCallRecord]:
"""获取最近的历史记录
Args:
@@ -295,7 +297,7 @@ class StreamToolHistoryManager:
self._history.clear()
logger.info(f"[{self.chat_id}] 工具历史记录已清除")
def _search_memory_cache(self, tool_name: str, args: dict[str, Any]) -> dict[str, Any] | None:
"""在内存历史记录中搜索缓存
Args:
@@ -333,7 +335,7 @@ class StreamToolHistoryManager:
return tool_path_mapping.get(tool_name, f"src/plugins/tools/{tool_name}.py")
def _extract_semantic_query(self, tool_name: str, args: dict[str, Any]) -> str | None:
"""提取语义查询参数
Args:
@@ -370,7 +372,7 @@ class StreamToolHistoryManager:
return ""
try:
args_str = orjson.dumps(args, option=orjson.OPT_SORT_KEYS).decode("utf-8")
if len(args_str) > max_length:
args_str = args_str[:max_length] + "..."
return args_str
@@ -411,4 +413,4 @@ def cleanup_stream_manager(chat_id: str) -> None:
"""
if chat_id in _stream_managers:
del _stream_managers[chat_id]
logger.info(f"已清理聊天 {chat_id} 的工具历史记录管理器")
logger.info(f"已清理聊天 {chat_id} 的工具历史记录管理器")

View File

@@ -1,5 +1,6 @@
import inspect
import time
from dataclasses import asdict
from typing import Any
from src.chat.utils.prompt import Prompt, global_prompt_manager
@@ -10,8 +11,7 @@ from src.llm_models.utils_model import LLMRequest
from src.plugin_system.apis.tool_api import get_llm_available_tool_definitions, get_tool_instance
from src.plugin_system.base.base_tool import BaseTool
from src.plugin_system.core.global_announcement_manager import global_announcement_manager
from src.plugin_system.core.stream_tool_history import ToolCallRecord, get_stream_tool_history_manager
logger = get_logger("tool_use")
@@ -140,7 +140,7 @@ class ToolExecutor:
# 构建工具调用历史文本
tool_history = self.history_manager.format_for_prompt(max_records=5, include_results=True)
# 获取人设信息
personality_core = global_config.personality.personality_core
personality_side = global_config.personality.personality_side
@@ -197,7 +197,7 @@ class ToolExecutor:
return tool_definitions
async def execute_tool_calls(self, tool_calls: list[ToolCall] | None) -> tuple[list[dict[str, Any]], list[str]]:
"""执行工具调用
@@ -338,9 +338,8 @@ class ToolExecutor:
if tool_instance and result and tool_instance.enable_cache:
try:
tool_file_path = inspect.getfile(tool_instance.__class__)
semantic_query = None
if tool_instance.semantic_cache_query_key:
semantic_query = function_args.get(tool_instance.semantic_cache_query_key)
await self.history_manager.cache_result(
tool_name=tool_call.func_name,

View File

@@ -122,7 +122,7 @@ class AffinityInterestCalculator(BaseInterestCalculator):
+ relationship_score * self.score_weights["relationship"]
+ mentioned_score * self.score_weights["mentioned"]
)
# 限制总分上限为1.0,确保分数在合理范围内
total_score = min(raw_total_score, 1.0)
@@ -131,7 +131,7 @@ class AffinityInterestCalculator(BaseInterestCalculator):
f"{relationship_score:.3f}*{self.score_weights['relationship']} + "
f"{mentioned_score:.3f}*{self.score_weights['mentioned']} = {raw_total_score:.3f}"
)
if raw_total_score > 1.0:
logger.debug(f"[Affinity兴趣计算] 原始分数 {raw_total_score:.3f} 超过1.0,已限制为 {total_score:.3f}")
@@ -217,7 +217,7 @@ class AffinityInterestCalculator(BaseInterestCalculator):
return 0.0
except asyncio.TimeoutError:
logger.warning(f"⏱️ 兴趣匹配计算超时(>1.5秒)返回默认分值0.5以保留其他分数")
logger.warning("⏱️ 兴趣匹配计算超时(>1.5秒)返回默认分值0.5以保留其他分数")
return 0.5 # 超时时返回默认分值,避免丢失提及分和关系分
except Exception as e:
logger.warning(f"智能兴趣匹配失败: {e}")
@@ -251,19 +251,19 @@ class AffinityInterestCalculator(BaseInterestCalculator):
def _calculate_mentioned_score(self, message: "DatabaseMessages", bot_nickname: str) -> float:
"""计算提及分 - 区分强提及和弱提及
强提及(被@、被回复、私聊): 使用 strong_mention_interest_score
弱提及(文本匹配名字/别名): 使用 weak_mention_interest_score
"""
from src.chat.utils.utils import is_mentioned_bot_in_message
# 使用统一的提及检测函数
is_mentioned, mention_type = is_mentioned_bot_in_message(message)
if not is_mentioned:
logger.debug("[提及分计算] 未提及机器人返回0.0")
return 0.0
# mention_type: 0=未提及, 1=弱提及, 2=强提及
if mention_type >= 2:
# 强提及:被@、被回复、私聊
@@ -281,22 +281,22 @@ class AffinityInterestCalculator(BaseInterestCalculator):
def _apply_no_reply_threshold_adjustment(self) -> tuple[float, float]:
"""应用阈值调整(包括连续不回复和回复后降低机制)
Returns:
tuple[float, float]: (调整后的回复阈值, 调整后的动作阈值)
"""
# 基础阈值
base_reply_threshold = self.reply_threshold
base_action_threshold = global_config.affinity_flow.non_reply_action_interest_threshold
total_reduction = 0.0
# 1. 连续不回复的阈值降低
if self.no_reply_count > 0 and self.no_reply_count < self.max_no_reply_count:
no_reply_reduction = self.no_reply_count * self.probability_boost_per_no_reply
total_reduction += no_reply_reduction
logger.debug(f"[阈值调整] 连续不回复降低: {no_reply_reduction:.3f} (计数: {self.no_reply_count})")
# 2. 回复后的阈值降低使bot更容易连续对话
if self.enable_post_reply_boost and self.post_reply_boost_remaining > 0:
# 计算衰减后的降低值
@@ -309,16 +309,16 @@ class AffinityInterestCalculator(BaseInterestCalculator):
f"[阈值调整] 回复后降低: {post_reply_reduction:.3f} "
f"(剩余次数: {self.post_reply_boost_remaining}, 衰减: {decay_factor:.2f})"
)
# 应用总降低量
adjusted_reply_threshold = max(0.0, base_reply_threshold - total_reduction)
adjusted_action_threshold = max(0.0, base_action_threshold - total_reduction)
return adjusted_reply_threshold, adjusted_action_threshold
def _apply_no_reply_boost(self, base_score: float) -> float:
"""【已弃用】应用连续不回复的概率提升
注意:此方法已被 _apply_no_reply_threshold_adjustment 替代
保留用于向后兼容
"""
@@ -388,7 +388,7 @@ class AffinityInterestCalculator(BaseInterestCalculator):
self.no_reply_count = 0
else:
self.no_reply_count = min(self.no_reply_count + 1, self.max_no_reply_count)
def on_reply_sent(self):
"""当机器人发送回复后调用,激活回复后阈值降低机制"""
if self.enable_post_reply_boost:
@@ -399,16 +399,16 @@ class AffinityInterestCalculator(BaseInterestCalculator):
)
# 同时重置不回复计数
self.no_reply_count = 0
def on_message_processed(self, replied: bool):
"""消息处理完成后调用,更新各种计数器
Args:
replied: 是否回复了此消息
"""
# 更新不回复计数
self.update_no_reply_count(replied)
# 如果已回复,激活回复后降低机制
if replied:
self.on_reply_sent()

View File

@@ -4,10 +4,10 @@ AffinityFlow Chatter 规划器模块
包含计划生成、过滤、执行等规划相关功能
"""
from . import planner_prompts
from .plan_executor import ChatterPlanExecutor
from .plan_filter import ChatterPlanFilter
from .plan_generator import ChatterPlanGenerator
from .planner import ChatterActionPlanner
__all__ = ["ChatterActionPlanner", "ChatterPlanExecutor", "ChatterPlanFilter", "ChatterPlanGenerator", "planner_prompts"]

View File

@@ -14,9 +14,7 @@ from json_repair import repair_json
# 旧的Hippocampus系统已被移除现在使用增强记忆系统
# from src.chat.memory_system.enhanced_memory_manager import enhanced_memory_manager
from src.chat.utils.chat_message_builder import (
build_readable_actions,
build_readable_messages_with_id,
get_actions_by_timestamp_with_chat,
)
from src.chat.utils.prompt import global_prompt_manager
from src.common.data_models.info_data_model import ActionPlannerInfo, Plan
@@ -646,7 +644,7 @@ class ChatterPlanFilter:
memory_manager = get_memory_manager()
if not memory_manager:
return "记忆系统未初始化。"
# 将关键词转换为查询字符串
query = " ".join(keywords)
enhanced_memories = await memory_manager.search_memories(

View File

@@ -21,7 +21,6 @@ if TYPE_CHECKING:
from src.common.data_models.message_manager_data_model import StreamContext
# 导入提示词模块以确保其被初始化
from src.plugins.built_in.affinity_flow_chatter.planner import planner_prompts
logger = get_logger("planner")
@@ -159,10 +158,10 @@ class ChatterActionPlanner:
action_data={},
action_message=None,
)
# 更新连续不回复计数
await self._update_interest_calculator_state(replied=False)
initial_plan = await self.generator.generate(chat_mode)
filtered_plan = initial_plan
filtered_plan.decided_actions = [no_action]
@@ -270,7 +269,7 @@ class ChatterActionPlanner:
try:
# Normal模式开始时刷新缓存消息到未读列表
await self._flush_cached_messages_to_unread(context)
unread_messages = context.get_unread_messages() if context else []
if not unread_messages:
@@ -347,7 +346,7 @@ class ChatterActionPlanner:
self._update_stats_from_execution_result(execution_result)
logger.info("Normal模式: 执行reply动作完成")
# 更新兴趣计算器状态(回复成功,重置不回复计数)
await self._update_interest_calculator_state(replied=True)
@@ -465,7 +464,7 @@ class ChatterActionPlanner:
async def _update_interest_calculator_state(self, replied: bool) -> None:
"""更新兴趣计算器状态(连续不回复计数和回复后降低机制)
Args:
replied: 是否回复了消息
"""
@@ -504,36 +503,36 @@ class ChatterActionPlanner:
async def _flush_cached_messages_to_unread(self, context: "StreamContext | None") -> list:
"""在planner开始时将缓存消息刷新到未读消息列表
此方法在动作修改器执行后、生成初始计划前调用,确保计划阶段能看到所有积累的消息。
Args:
context: 流上下文
Returns:
list: 刷新的消息列表
"""
if not context:
return []
try:
from src.chat.message_manager.message_manager import message_manager
stream_id = context.stream_id
if message_manager.is_running and message_manager.has_cached_messages(stream_id):
# 获取缓存消息
cached_messages = message_manager.flush_cached_messages(stream_id)
if cached_messages:
# 直接添加到上下文的未读消息列表
for message in cached_messages:
context.unread_messages.append(message)
logger.info(f"Planner开始前刷新缓存消息到未读列表: stream={stream_id}, 数量={len(cached_messages)}")
return cached_messages
return []
except ImportError:
logger.debug("MessageManager不可用跳过缓存刷新")
return []

View File

@@ -9,9 +9,9 @@ from .proactive_thinking_executor import execute_proactive_thinking
from .proactive_thinking_scheduler import ProactiveThinkingScheduler, proactive_thinking_scheduler
__all__ = [
"ProactiveThinkingReplyHandler",
"ProactiveThinkingMessageHandler",
"execute_proactive_thinking",
"ProactiveThinkingReplyHandler",
"ProactiveThinkingScheduler",
"execute_proactive_thinking",
"proactive_thinking_scheduler",
]

View File

@@ -3,7 +3,6 @@
当定时任务触发时负责搜集信息、调用LLM决策、并根据决策生成回复
"""
import orjson
from datetime import datetime
from typing import Any, Literal

View File

@@ -14,7 +14,6 @@ from maim_message import UserInfo
from src.chat.message_receive.chat_stream import get_chat_manager
from src.common.logger import get_logger
from src.config.api_ada_configs import TaskConfig
from src.llm_models.utils_model import LLMRequest
from src.plugin_system.apis import config_api, generator_api, llm_api
@@ -320,7 +319,7 @@ class ContentService:
- 禁止在说说中直接、完整地提及当前的年月日,除非日期有特殊含义,但也尽量用节日名/节气名字代替。
2. **严禁重复**:下方会提供你最近发过的说说历史,你必须创作一条全新的、与历史记录内容和主题都不同的说说。
**其他的禁止的内容以及说明**
- 绝对禁止提及当下具体几点几分的时间戳。
- 绝对禁止攻击性内容和过度的负面情绪。

View File

@@ -136,10 +136,10 @@ class QZoneService:
logger.info(f"[DEBUG] 准备获取API客户端qq_account={qq_account}")
api_client = await self._get_api_client(qq_account, stream_id)
if not api_client:
logger.error(f"[DEBUG] API客户端获取失败返回错误")
logger.error("[DEBUG] API客户端获取失败返回错误")
return {"success": False, "message": "获取QZone API客户端失败"}
logger.info(f"[DEBUG] API客户端获取成功准备读取说说")
logger.info("[DEBUG] API客户端获取成功准备读取说说")
num_to_read = self.get_config("read.read_number", 5)
# 尝试执行如果Cookie失效则自动重试一次
@@ -186,7 +186,7 @@ class QZoneService:
# 检查是否是Cookie失效-3000错误
if "错误码: -3000" in error_msg and retry_count == 0:
logger.warning(f"检测到Cookie失效-3000错误准备删除缓存并重试...")
logger.warning("检测到Cookie失效-3000错误准备删除缓存并重试...")
# 删除Cookie缓存文件
cookie_file = self.cookie_service._get_cookie_file_path(qq_account)
@@ -623,7 +623,7 @@ class QZoneService:
logger.error(f"获取API客户端失败Cookie中缺少关键的 'p_skey'。Cookie内容: {cookies}")
return None
logger.info(f"[DEBUG] p_skey获取成功")
logger.info("[DEBUG] p_skey获取成功")
gtk = self._generate_gtk(p_skey)
uin = cookies.get("uin", "").lstrip("o")
@@ -1230,7 +1230,7 @@ class QZoneService:
logger.error(f"监控好友动态失败: {e}", exc_info=True)
return []
logger.info(f"[DEBUG] API客户端构造完成返回包含6个方法的字典")
logger.info("[DEBUG] API客户端构造完成返回包含6个方法的字典")
return {
"publish": _publish,
"list_feeds": _list_feeds,

View File

@@ -3,11 +3,12 @@
负责记录和管理已回复过的评论ID避免重复回复
"""
import orjson
import time
from pathlib import Path
from typing import Any
import orjson
from src.common.logger import get_logger
logger = get_logger("MaiZone.ReplyTrackerService")
@@ -117,8 +118,8 @@ class ReplyTrackerService:
temp_file = self.reply_record_file.with_suffix(".tmp")
# 先写入临时文件
with open(temp_file, "w", encoding="utf-8") as f:
f.write(orjson.dumps(self.replied_comments, option=orjson.OPT_INDENT_2 | orjson.OPT_NON_STR_KEYS).decode("utf-8"))
# 如果写入成功,重命名为正式文件
if temp_file.stat().st_size > 0: # 确保写入成功
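Editor's note: the temp-file-then-rename flow this service aims for is the standard atomic-write pattern. A minimal sketch (file names are illustrative; `Path.replace` is atomic on the same filesystem):

import orjson
from pathlib import Path

def atomic_save(path: Path, data: dict) -> None:
    tmp = path.with_suffix(".tmp")
    tmp.write_bytes(orjson.dumps(data, option=orjson.OPT_INDENT_2 | orjson.OPT_NON_STR_KEYS))
    if tmp.stat().st_size > 0:  # only swap in a non-empty, fully written file
        tmp.replace(path)

atomic_save(Path("replied_comments.json"), {"feed#1": [1001, 1002]})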

View File

@@ -1,7 +1,6 @@
import orjson
import random
import time
import websockets as Server
import uuid
from maim_message import (
@@ -205,7 +204,7 @@ class SendHandler:
# 发送响应回MoFox-Bot
logger.debug(f"[DEBUG handle_adapter_command] 即将调用send_adapter_command_response, request_id={request_id}")
await self.send_adapter_command_response(raw_message_base, response, request_id)
logger.debug(f"[DEBUG handle_adapter_command] send_adapter_command_response调用完成")
logger.debug("[DEBUG handle_adapter_command] send_adapter_command_response调用完成")
if response.get("status") == "ok":
logger.info(f"适配器命令 {action} 执行成功")

View File

@@ -1,10 +1,10 @@
"""
Metaso Search Engine (Chat Completions Mode)
"""
from typing import Any
import httpx
import orjson
from src.common.logger import get_logger
from src.plugin_system.apis import config_api

View File

@@ -3,9 +3,10 @@ Serper search engine implementation
Google Search via Serper.dev API
"""
from typing import Any
import aiohttp
from src.common.logger import get_logger
from src.plugin_system.apis import config_api

View File

@@ -5,7 +5,7 @@ Web Search Tool Plugin
"""
from src.common.logger import get_logger
from src.plugin_system import BasePlugin, ComponentInfo, ConfigField, register_plugin
from src.plugin_system.apis import config_api
from .tools.url_parser import URLParserTool

View File

@@ -113,7 +113,7 @@ class WebSurfingTool(BaseTool):
custom_args["num_results"] = custom_args.get("num_results", 5)
# 如果启用了answer模式且是Exa引擎使用answer_search方法
if answer_mode and engine_name == "exa" and hasattr(engine, "answer_search"):
search_tasks.append(engine.answer_search(custom_args))
else:
search_tasks.append(engine.search(custom_args))
@@ -162,7 +162,7 @@ class WebSurfingTool(BaseTool):
custom_args["num_results"] = custom_args.get("num_results", 5)
# 如果启用了answer模式且是Exa引擎使用answer_search方法
if answer_mode and engine_name == "exa" and hasattr(engine, "answer_search"):
logger.info("使用Exa答案模式进行搜索fallback策略")
results = await engine.answer_search(custom_args)
else:
@@ -195,7 +195,7 @@ class WebSurfingTool(BaseTool):
custom_args["num_results"] = custom_args.get("num_results", 5)
# 如果启用了answer模式且是Exa引擎使用answer_search方法
if answer_mode and engine_name == "exa" and hasattr(engine, "answer_search"):
logger.info("使用Exa答案模式进行搜索")
results = await engine.answer_search(custom_args)
else:

View File

@@ -266,13 +266,13 @@ class UnifiedScheduler:
name=f"execute_{task.task_name}"
)
execution_tasks.append(execution_task)
# 追踪正在执行的任务,以便在 remove_schedule 时可以取消
self._executing_tasks[task.schedule_id] = execution_task
# 等待所有任务完成(使用 return_exceptions=True 避免单个任务失败影响其他任务)
results = await asyncio.gather(*execution_tasks, return_exceptions=True)
# 清理执行追踪
for task in tasks_to_trigger:
self._executing_tasks.pop(task.schedule_id, None)
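Editor's note: the scheduler pattern above, keeping handles to in-flight executions so they can be cancelled, plus `return_exceptions=True` so one failure cannot abort the batch, can be shown standalone (job names and the dict are illustrative):

import asyncio

async def job(name: str) -> str:
    await asyncio.sleep(0.01)
    if name == "bad":
        raise RuntimeError("boom")
    return name

async def run_batch(names: list[str]) -> None:
    executing: dict[str, asyncio.Task] = {}
    for name in names:
        # Track the task handle so a remover could cancel it mid-flight.
        executing[name] = asyncio.create_task(job(name), name=f"execute_{name}")
    results = await asyncio.gather(*executing.values(), return_exceptions=True)
    executing.clear()
    print(results)  # e.g. ['ok', RuntimeError('boom')]

asyncio.run(run_batch(["ok", "bad"]))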
@@ -515,7 +515,7 @@ class UnifiedScheduler:
async def remove_schedule(self, schedule_id: str) -> bool:
"""移除调度任务
如果任务正在执行,会取消执行中的任务
"""
async with self._lock:
@@ -524,7 +524,7 @@ class UnifiedScheduler:
return False
task = self._tasks[schedule_id]
# 检查是否有正在执行的任务
executing_task = self._executing_tasks.get(schedule_id)
if executing_task and not executing_task.done():

View File

@@ -19,42 +19,42 @@ logger = get_logger(__name__)
def extract_and_parse_json(response: str, *, strict: bool = False) -> dict[str, Any] | list | None:
"""
从 LLM 响应中提取并解析 JSON
处理策略:
1. 清理 Markdown 代码块标记(```json 和 ```
2. 提取 JSON 对象或数组
3. 使用 json_repair 修复格式问题
4. 解析为 Python 对象
Args:
response: LLM 响应字符串
strict: 严格模式,如果为 True 则解析失败时返回 None否则尝试容错处理
Returns:
解析后的 dict 或 list失败时返回 None
Examples:
>>> extract_and_parse_json('```json\\n{"key": "value"}\\n```')
{'key': 'value'}
>>> extract_and_parse_json('Some text {"key": "value"} more text')
{'key': 'value'}
>>> extract_and_parse_json('[{"a": 1}, {"b": 2}]')
[{'a': 1}, {'b': 2}]
"""
if not response:
logger.debug("空响应,无法解析 JSON")
return None
try:
# 步骤 1: 清理响应
cleaned = _clean_llm_response(response)
if not cleaned:
logger.warning("清理后的响应为空")
return None
# 步骤 2: 尝试直接解析
try:
result = orjson.loads(cleaned)
@@ -62,11 +62,11 @@ def extract_and_parse_json(response: str, *, strict: bool = False) -> dict[str,
return result
except Exception as direct_error:
logger.debug(f"直接解析失败: {type(direct_error).__name__}: {direct_error}")
# 步骤 3: 使用 json_repair 修复并解析
try:
repaired = repair_json(cleaned)
# repair_json 可能返回字符串或已解析的对象
if isinstance(repaired, str):
result = orjson.loads(repaired)
@@ -74,16 +74,16 @@ def extract_and_parse_json(response: str, *, strict: bool = False) -> dict[str,
else:
result = repaired
logger.debug(f"✅ JSON 修复后解析成功(对象模式),类型: {type(result).__name__}")
return result
except Exception as repair_error:
logger.warning(f"JSON 修复失败: {type(repair_error).__name__}: {repair_error}")
if strict:
logger.error(f"严格模式下解析失败,响应片段: {cleaned[:200]}")
return None
# 最后的容错尝试:返回空字典或空列表
if cleaned.strip().startswith("["):
logger.warning("返回空列表作为容错")
@@ -91,7 +91,7 @@ def extract_and_parse_json(response: str, *, strict: bool = False) -> dict[str,
else:
logger.warning("返回空字典作为容错")
return {}
except Exception as e:
logger.error(f"❌ JSON 解析过程出现异常: {type(e).__name__}: {e}")
if strict:
@@ -102,37 +102,37 @@ def extract_and_parse_json(response: str, *, strict: bool = False) -> dict[str,
def _clean_llm_response(response: str) -> str:
"""
清理 LLM 响应,提取 JSON 部分
处理步骤:
1. 移除 Markdown 代码块标记(```json 和 ```
2. 提取第一个完整的 JSON 对象 {...} 或数组 [...]
3. 清理多余的空格和换行
Args:
response: 原始 LLM 响应
Returns:
清理后的 JSON 字符串
"""
if not response:
return ""
cleaned = response.strip()
# 移除 Markdown 代码块标记
# 匹配 ```json ... ``` 或 ``` ... ```
code_block_patterns = [
r"```json\s*(.*?)```", # ```json ... ```
r"```\s*(.*?)```", # ``` ... ```
]
for pattern in code_block_patterns:
match = re.search(pattern, cleaned, re.IGNORECASE | re.DOTALL)
if match:
cleaned = match.group(1).strip()
logger.debug(f"从 Markdown 代码块中提取内容,长度: {len(cleaned)}")
break
# 提取 JSON 对象或数组
# 优先查找对象 {...},其次查找数组 [...]
for start_char, end_char in [("{", "}"), ("[", "]")]:
@@ -143,7 +143,7 @@ def _clean_llm_response(response: str) -> str:
if extracted:
logger.debug(f"提取到 {start_char}...{end_char} 结构,长度: {len(extracted)}")
return extracted
# 如果没有找到明确的 JSON 结构,返回清理后的原始内容
logger.debug("未找到明确的 JSON 结构,返回清理后的原始内容")
return cleaned
@@ -152,39 +152,39 @@ def _clean_llm_response(response: str) -> str:
def _extract_balanced_json(text: str, start_idx: int, start_char: str, end_char: str) -> str | None:
"""
从指定位置提取平衡的 JSON 结构
使用栈匹配算法找到对应的结束符,处理嵌套和字符串中的特殊字符
Args:
text: 源文本
start_idx: 起始字符的索引
start_char: 起始字符({ 或 [
end_char: 结束字符(} 或 ]
Returns:
提取的 JSON 字符串,失败时返回 None
"""
depth = 0
in_string = False
escape_next = False
for i in range(start_idx, len(text)):
char = text[i]
# 处理转义字符
if escape_next:
escape_next = False
continue
if char == "\\":
escape_next = True
continue
# 处理字符串
if char == '"':
in_string = not in_string
continue
# 只在非字符串内处理括号
if not in_string:
if char == start_char:
@@ -194,7 +194,7 @@ def _extract_balanced_json(text: str, start_idx: int, start_char: str, end_char:
if depth == 0:
# 找到匹配的结束符
return text[start_idx : i + 1].strip()
# 没有找到匹配的结束符
logger.debug(f"未找到匹配的 {end_char},深度: {depth}")
return None
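Editor's note: a compact, runnable demonstration of the same depth-tracking idea, ignoring braces inside JSON strings and honoring backslash escapes, specialized to objects:

def extract_object(text: str) -> str | None:
    start = text.find("{")
    if start == -1:
        return None
    depth, in_string, escape = 0, False, False
    for i in range(start, len(text)):
        ch = text[i]
        if escape:
            escape = False
            continue
        if ch == "\\":
            escape = True
            continue
        if ch == '"':
            in_string = not in_string
            continue
        if not in_string:
            if ch == "{":
                depth += 1
            elif ch == "}":
                depth -= 1
                if depth == 0:
                    return text[start : i + 1]
    return None  # no balanced closer found

assert extract_object('noise {"a": "b}c", "n": {"x": 1}} tail') == '{"a": "b}c", "n": {"x": 1}}'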
@@ -203,11 +203,11 @@ def _extract_balanced_json(text: str, start_idx: int, start_char: str, end_char:
def safe_parse_json(json_str: str, default: Any = None) -> Any:
"""
安全解析 JSON失败时返回默认值
Args:
json_str: JSON 字符串
default: 解析失败时返回的默认值
Returns:
解析结果或默认值
"""
@@ -222,19 +222,19 @@ def safe_parse_json(json_str: str, default: Any = None) -> Any:
def extract_json_field(response: str, field_name: str, default: Any = None) -> Any:
"""
从 LLM 响应中提取特定字段的值
Args:
response: LLM 响应
field_name: 字段名
default: 字段不存在时的默认值
Returns:
字段值或默认值
"""
parsed = extract_and_parse_json(response, strict=False)
if isinstance(parsed, dict):
return parsed.get(field_name, default)
logger.warning(f"解析结果不是字典,无法提取字段 '{field_name}'")
return default