This commit is contained in:
minecraft1024a
2025-11-29 20:06:59 +08:00
5 changed files with 49 additions and 22 deletions

View File

@@ -29,7 +29,7 @@ async def clean_permission_nodes():
result = await session.execute(stmt) result = await session.execute(stmt)
await session.commit() await session.commit()
deleted_count = result.rowcount if hasattr(result, "rowcount") else 0 deleted_count = getattr(result, "rowcount", 0)
logger.info(f"✅ 已清理 {deleted_count} 个权限节点记录") logger.info(f"✅ 已清理 {deleted_count} 个权限节点记录")
print(f"✅ 已清理 {deleted_count} 个权限节点记录") print(f"✅ 已清理 {deleted_count} 个权限节点记录")
print("请重启应用以重新注册权限节点") print("请重启应用以重新注册权限节点")

View File

@@ -63,7 +63,7 @@ def find_available_data_files() -> list[Path]:
return sorted(files, key=lambda f: f.stat().st_mtime, reverse=True) return sorted(files, key=lambda f: f.stat().st_mtime, reverse=True)
def load_graph_data_from_file(file_path: Path | None = None) -> dict[str, Any]: async def load_graph_data_from_file(file_path: Path | None = None) -> dict[str, Any]:
"""从磁盘加载图数据""" """从磁盘加载图数据"""
global graph_data_cache, current_data_file global graph_data_cache, current_data_file

View File

@@ -31,6 +31,8 @@ async def get_message_stats(
sent_count = 0 sent_count = 0
received_count = 0 received_count = 0
if global_config is None:
raise HTTPException(status_code=500, detail="Global config is not initialized")
bot_qq = str(global_config.bot.qq_account) bot_qq = str(global_config.bot.qq_account)
for msg in messages: for msg in messages:
@@ -73,6 +75,8 @@ async def get_message_stats_by_chat(
start_time = end_time - (days * 24 * 3600) start_time = end_time - (days * 24 * 3600)
# 从数据库获取指定时间范围内的所有消息 # 从数据库获取指定时间范围内的所有消息
messages = await message_api.get_messages_by_time(start_time, end_time) messages = await message_api.get_messages_by_time(start_time, end_time)
if global_config is None:
raise HTTPException(status_code=500, detail="Global config is not initialized")
bot_qq = str(global_config.bot.qq_account) bot_qq = str(global_config.bot.qq_account)
# --- 2. 消息筛选 --- # --- 2. 消息筛选 ---

View File

@@ -9,7 +9,7 @@ import random
import re import re
import time import time
import traceback import traceback
from typing import Any, Optional from typing import Any, Optional, cast
from PIL import Image from PIL import Image
from rich.traceback import install from rich.traceback import install
@@ -401,6 +401,11 @@ class EmojiManager:
self._scan_task = None self._scan_task = None
if model_config is None:
raise RuntimeError("Model config is not initialized")
if global_config is None:
raise RuntimeError("Global config is not initialized")
self.vlm = LLMRequest(model_set=model_config.model_task_config.emoji_vlm, request_type="emoji") self.vlm = LLMRequest(model_set=model_config.model_task_config.emoji_vlm, request_type="emoji")
self.llm_emotion_judge = LLMRequest( self.llm_emotion_judge = LLMRequest(
model_set=model_config.model_task_config.utils, request_type="emoji" model_set=model_config.model_task_config.utils, request_type="emoji"
@@ -480,6 +485,8 @@ class EmojiManager:
return None return None
# 2. 根据全局配置决定候选表情包的数量 # 2. 根据全局配置决定候选表情包的数量
if global_config is None:
raise RuntimeError("Global config is not initialized")
max_candidates = global_config.emoji.max_context_emojis max_candidates = global_config.emoji.max_context_emojis
# 如果配置为0或者大于等于总数则选择所有表情包 # 如果配置为0或者大于等于总数则选择所有表情包
@@ -622,6 +629,8 @@ class EmojiManager:
async def start_periodic_check_register(self) -> None: async def start_periodic_check_register(self) -> None:
"""定期检查表情包完整性和数量""" """定期检查表情包完整性和数量"""
if global_config is None:
raise RuntimeError("Global config is not initialized")
await self.get_all_emoji_from_db() await self.get_all_emoji_from_db()
while True: while True:
# logger.info("[扫描] 开始检查表情包完整性...") # logger.info("[扫描] 开始检查表情包完整性...")
@@ -771,8 +780,9 @@ class EmojiManager:
try: try:
emoji_record = await self.get_emoji_from_db(emoji_hash) emoji_record = await self.get_emoji_from_db(emoji_hash)
if emoji_record and emoji_record[0].emotion: if emoji_record and emoji_record[0].emotion:
logger.info(f"[缓存命中] 从数据库获取表情包描述: {emoji_record.emotion[:50]}...") # type: ignore # type: ignore emotion_str = ",".join(emoji_record[0].emotion)
return emoji_record.emotion # type: ignore logger.info(f"[缓存命中] 从数据库获取表情包描述: {emotion_str[:50]}...")
return emotion_str
except Exception as e: except Exception as e:
logger.error(f"从数据库查询表情包描述时出错: {e}") logger.error(f"从数据库查询表情包描述时出错: {e}")
@@ -803,7 +813,7 @@ class EmojiManager:
try: try:
from src.common.database.api.query import QueryBuilder from src.common.database.api.query import QueryBuilder
emoji_record = await QueryBuilder(Emoji).filter(emoji_hash=emoji_hash).first() emoji_record = cast(Emoji | None, await QueryBuilder(Emoji).filter(emoji_hash=emoji_hash).first())
if emoji_record and emoji_record.description: if emoji_record and emoji_record.description:
logger.info(f"[缓存命中] 从数据库获取表情包描述: {emoji_record.description[:50]}...") logger.info(f"[缓存命中] 从数据库获取表情包描述: {emoji_record.description[:50]}...")
return emoji_record.description return emoji_record.description
@@ -880,6 +890,9 @@ class EmojiManager:
# 将表情包信息转换为可读的字符串 # 将表情包信息转换为可读的字符串
emoji_info_list = _emoji_objects_to_readable_list(selected_emojis) emoji_info_list = _emoji_objects_to_readable_list(selected_emojis)
if global_config is None:
raise RuntimeError("Global config is not initialized")
# 构建提示词 # 构建提示词
prompt = ( prompt = (
f"{global_config.bot.nickname}的表情包存储已满({self.emoji_num}/{self.emoji_num_max})" f"{global_config.bot.nickname}的表情包存储已满({self.emoji_num}/{self.emoji_num_max})"
@@ -954,6 +967,8 @@ class EmojiManager:
Tuple[str, List[str]]: 返回一个元组,第一个元素是详细描述,第二个元素是情感关键词列表。 Tuple[str, List[str]]: 返回一个元组,第一个元素是详细描述,第二个元素是情感关键词列表。
如果处理失败,则返回空的描述和列表。 如果处理失败,则返回空的描述和列表。
""" """
if global_config is None:
raise RuntimeError("Global config is not initialized")
try: try:
# 1. 解码图片,计算哈希值,并获取格式 # 1. 解码图片,计算哈希值,并获取格式
if isinstance(image_base64, str): if isinstance(image_base64, str):
@@ -967,7 +982,7 @@ class EmojiManager:
try: try:
from src.common.database.api.query import QueryBuilder from src.common.database.api.query import QueryBuilder
existing_image = await QueryBuilder(Images).filter(emoji_hash=image_hash, type="emoji").first() existing_image = cast(Images | None, await QueryBuilder(Images).filter(emoji_hash=image_hash, type="emoji").first())
if existing_image and existing_image.description: if existing_image and existing_image.description:
existing_description = existing_image.description existing_description = existing_image.description
logger.info(f"[复用描述] 找到已有详细描述: {existing_description[:50]}...") logger.info(f"[复用描述] 找到已有详细描述: {existing_description[:50]}...")

View File

@@ -7,7 +7,7 @@ import time
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from dataclasses import dataclass, field from dataclasses import dataclass, field
from enum import Enum from enum import Enum
from typing import Any, TypedDict from typing import Any, Awaitable, TypedDict, cast
from src.common.database.api.crud import CRUDBase from src.common.database.api.crud import CRUDBase
from src.common.logger import get_logger from src.common.logger import get_logger
@@ -70,7 +70,7 @@ class EnergyCalculator(ABC):
"""能量计算器抽象基类""" """能量计算器抽象基类"""
@abstractmethod @abstractmethod
def calculate(self, context: dict[str, Any]) -> float: def calculate(self, context: "EnergyContext") -> float | Awaitable[float]:
"""计算能量值""" """计算能量值"""
pass pass
@@ -83,7 +83,7 @@ class EnergyCalculator(ABC):
class InterestEnergyCalculator(EnergyCalculator): class InterestEnergyCalculator(EnergyCalculator):
"""兴趣度能量计算器""" """兴趣度能量计算器"""
def calculate(self, context: dict[str, Any]) -> float: def calculate(self, context: "EnergyContext") -> float:
"""基于消息兴趣度计算能量""" """基于消息兴趣度计算能量"""
messages = context.get("messages", []) messages = context.get("messages", [])
if not messages: if not messages:
@@ -117,7 +117,7 @@ class ActivityEnergyCalculator(EnergyCalculator):
def __init__(self): def __init__(self):
self.action_weights = {"reply": 0.4, "react": 0.3, "mention": 0.2, "other": 0.1} self.action_weights = {"reply": 0.4, "react": 0.3, "mention": 0.2, "other": 0.1}
def calculate(self, context: dict[str, Any]) -> float: def calculate(self, context: "EnergyContext") -> float:
"""基于活跃度计算能量""" """基于活跃度计算能量"""
messages = context.get("messages", []) messages = context.get("messages", [])
if not messages: if not messages:
@@ -147,7 +147,7 @@ class ActivityEnergyCalculator(EnergyCalculator):
class RecencyEnergyCalculator(EnergyCalculator): class RecencyEnergyCalculator(EnergyCalculator):
"""最近性能量计算器""" """最近性能量计算器"""
def calculate(self, context: dict[str, Any]) -> float: def calculate(self, context: "EnergyContext") -> float:
"""基于最近性计算能量""" """基于最近性计算能量"""
messages = context.get("messages", []) messages = context.get("messages", [])
if not messages: if not messages:
@@ -194,7 +194,7 @@ class RecencyEnergyCalculator(EnergyCalculator):
class RelationshipEnergyCalculator(EnergyCalculator): class RelationshipEnergyCalculator(EnergyCalculator):
"""关系能量计算器 - 基于聊天流兴趣度""" """关系能量计算器 - 基于聊天流兴趣度"""
async def calculate(self, context: dict[str, Any]) -> float: async def calculate(self, context: "EnergyContext") -> float:
"""基于聊天流兴趣度计算能量""" """基于聊天流兴趣度计算能量"""
stream_id = context.get("stream_id") stream_id = context.get("stream_id")
if not stream_id: if not stream_id:
@@ -260,6 +260,8 @@ class EnergyManager:
def _load_thresholds_from_config(self) -> None: def _load_thresholds_from_config(self) -> None:
"""从配置加载AFC阈值""" """从配置加载AFC阈值"""
try: try:
if global_config is None:
return
if hasattr(global_config, "affinity_flow") and global_config.affinity_flow is not None: if hasattr(global_config, "affinity_flow") and global_config.affinity_flow is not None:
self.thresholds["high_match"] = getattr( self.thresholds["high_match"] = getattr(
global_config.affinity_flow, "high_match_interest_threshold", 0.8 global_config.affinity_flow, "high_match_interest_threshold", 0.8
@@ -283,17 +285,17 @@ class EnergyManager:
start_time = time.time() start_time = time.time()
# 更新统计 # 更新统计
self.stats["total_calculations"] += 1 self.stats["total_calculations"] = cast(int, self.stats["total_calculations"]) + 1
# 检查缓存 # 检查缓存
if stream_id in self.energy_cache: if stream_id in self.energy_cache:
cached_energy, cached_time = self.energy_cache[stream_id] cached_energy, cached_time = self.energy_cache[stream_id]
if time.time() - cached_time < self.cache_ttl: if time.time() - cached_time < self.cache_ttl:
self.stats["cache_hits"] += 1 self.stats["cache_hits"] = cast(int, self.stats["cache_hits"]) + 1
logger.debug(f"使用缓存能量: {stream_id} = {cached_energy:.3f}") logger.debug(f"使用缓存能量: {stream_id} = {cached_energy:.3f}")
return cached_energy return cached_energy
else: else:
self.stats["cache_misses"] += 1 self.stats["cache_misses"] = cast(int, self.stats["cache_misses"]) + 1
# 构建计算上下文 # 构建计算上下文
context: EnergyContext = { context: EnergyContext = {
@@ -358,9 +360,10 @@ class EnergyManager:
# 更新平均计算时间 # 更新平均计算时间
calculation_time = time.time() - start_time calculation_time = time.time() - start_time
total_calculations = self.stats["total_calculations"] total_calculations = cast(int, self.stats["total_calculations"])
current_avg = cast(float, self.stats["average_calculation_time"])
self.stats["average_calculation_time"] = ( self.stats["average_calculation_time"] = (
self.stats["average_calculation_time"] * (total_calculations - 1) + calculation_time current_avg * (total_calculations - 1) + calculation_time
) / total_calculations ) / total_calculations
logger.debug( logger.debug(
@@ -424,8 +427,11 @@ class EnergyManager:
final_interval = base_interval * jitter final_interval = base_interval * jitter
# 确保在配置范围内 # 确保在配置范围内
min_interval = getattr(global_config.chat, "dynamic_distribution_min_interval", 1.0) min_interval = 1.0
max_interval = getattr(global_config.chat, "dynamic_distribution_max_interval", 60.0) max_interval = 60.0
if global_config is not None and hasattr(global_config, "chat"):
min_interval = getattr(global_config.chat, "dynamic_distribution_min_interval", 1.0)
max_interval = getattr(global_config.chat, "dynamic_distribution_max_interval", 60.0)
return max(min_interval, min(max_interval, final_interval)) return max(min_interval, min(max_interval, final_interval))
@@ -487,10 +493,12 @@ class EnergyManager:
def get_cache_hit_rate(self) -> float: def get_cache_hit_rate(self) -> float:
"""获取缓存命中率""" """获取缓存命中率"""
total_requests = self.stats.get("cache_hits", 0) + self.stats.get("cache_misses", 0) hits = cast(int, self.stats.get("cache_hits", 0))
misses = cast(int, self.stats.get("cache_misses", 0))
total_requests = hits + misses
if total_requests == 0: if total_requests == 0:
return 0.0 return 0.0
return self.stats["cache_hits"] / total_requests return hits / total_requests
# 全局能量管理器实例 # 全局能量管理器实例