喵呜!修好了好多 Pyright 的报错捏~ 🐾
主人主人,猫猫把代码里的红红的报错都赶跑啦!✨ 1. memory_visualizer_router.py: 把 load_graph_data_from_file 变成异步的啦,这样就不会卡住咯~ 2. message_router.py: 加上了 global_config 的检查,不会再因为空空的配置摔倒啦! 3. emoji_manager.py: 修复了好多类型转换的问题,还加上了配置检查,表情包系统更稳定了捏! 4. energy_manager.py: 能量计算器的类型也修好啦,统计数据不会再打架了~ 代码现在变得干干净净的,猫猫是不是很棒?快摸摸头!🐱💕
This commit is contained in:
@@ -29,7 +29,7 @@ async def clean_permission_nodes():
|
|||||||
result = await session.execute(stmt)
|
result = await session.execute(stmt)
|
||||||
await session.commit()
|
await session.commit()
|
||||||
|
|
||||||
deleted_count = result.rowcount if hasattr(result, "rowcount") else 0
|
deleted_count = getattr(result, "rowcount", 0)
|
||||||
logger.info(f"✅ 已清理 {deleted_count} 个权限节点记录")
|
logger.info(f"✅ 已清理 {deleted_count} 个权限节点记录")
|
||||||
print(f"✅ 已清理 {deleted_count} 个权限节点记录")
|
print(f"✅ 已清理 {deleted_count} 个权限节点记录")
|
||||||
print("请重启应用以重新注册权限节点")
|
print("请重启应用以重新注册权限节点")
|
||||||
|
|||||||
@@ -63,7 +63,7 @@ def find_available_data_files() -> list[Path]:
|
|||||||
return sorted(files, key=lambda f: f.stat().st_mtime, reverse=True)
|
return sorted(files, key=lambda f: f.stat().st_mtime, reverse=True)
|
||||||
|
|
||||||
|
|
||||||
def load_graph_data_from_file(file_path: Path | None = None) -> dict[str, Any]:
|
async def load_graph_data_from_file(file_path: Path | None = None) -> dict[str, Any]:
|
||||||
"""从磁盘加载图数据"""
|
"""从磁盘加载图数据"""
|
||||||
global graph_data_cache, current_data_file
|
global graph_data_cache, current_data_file
|
||||||
|
|
||||||
|
|||||||
@@ -31,6 +31,8 @@ async def get_message_stats(
|
|||||||
|
|
||||||
sent_count = 0
|
sent_count = 0
|
||||||
received_count = 0
|
received_count = 0
|
||||||
|
if global_config is None:
|
||||||
|
raise HTTPException(status_code=500, detail="Global config is not initialized")
|
||||||
bot_qq = str(global_config.bot.qq_account)
|
bot_qq = str(global_config.bot.qq_account)
|
||||||
|
|
||||||
for msg in messages:
|
for msg in messages:
|
||||||
@@ -73,6 +75,8 @@ async def get_message_stats_by_chat(
|
|||||||
start_time = end_time - (days * 24 * 3600)
|
start_time = end_time - (days * 24 * 3600)
|
||||||
# 从数据库获取指定时间范围内的所有消息
|
# 从数据库获取指定时间范围内的所有消息
|
||||||
messages = await message_api.get_messages_by_time(start_time, end_time)
|
messages = await message_api.get_messages_by_time(start_time, end_time)
|
||||||
|
if global_config is None:
|
||||||
|
raise HTTPException(status_code=500, detail="Global config is not initialized")
|
||||||
bot_qq = str(global_config.bot.qq_account)
|
bot_qq = str(global_config.bot.qq_account)
|
||||||
|
|
||||||
# --- 2. 消息筛选 ---
|
# --- 2. 消息筛选 ---
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ import random
|
|||||||
import re
|
import re
|
||||||
import time
|
import time
|
||||||
import traceback
|
import traceback
|
||||||
from typing import Any, Optional
|
from typing import Any, Optional, cast
|
||||||
|
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from rich.traceback import install
|
from rich.traceback import install
|
||||||
@@ -401,6 +401,11 @@ class EmojiManager:
|
|||||||
|
|
||||||
self._scan_task = None
|
self._scan_task = None
|
||||||
|
|
||||||
|
if model_config is None:
|
||||||
|
raise RuntimeError("Model config is not initialized")
|
||||||
|
if global_config is None:
|
||||||
|
raise RuntimeError("Global config is not initialized")
|
||||||
|
|
||||||
self.vlm = LLMRequest(model_set=model_config.model_task_config.emoji_vlm, request_type="emoji")
|
self.vlm = LLMRequest(model_set=model_config.model_task_config.emoji_vlm, request_type="emoji")
|
||||||
self.llm_emotion_judge = LLMRequest(
|
self.llm_emotion_judge = LLMRequest(
|
||||||
model_set=model_config.model_task_config.utils, request_type="emoji"
|
model_set=model_config.model_task_config.utils, request_type="emoji"
|
||||||
@@ -480,6 +485,8 @@ class EmojiManager:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
# 2. 根据全局配置决定候选表情包的数量
|
# 2. 根据全局配置决定候选表情包的数量
|
||||||
|
if global_config is None:
|
||||||
|
raise RuntimeError("Global config is not initialized")
|
||||||
max_candidates = global_config.emoji.max_context_emojis
|
max_candidates = global_config.emoji.max_context_emojis
|
||||||
|
|
||||||
# 如果配置为0或者大于等于总数,则选择所有表情包
|
# 如果配置为0或者大于等于总数,则选择所有表情包
|
||||||
@@ -622,6 +629,8 @@ class EmojiManager:
|
|||||||
|
|
||||||
async def start_periodic_check_register(self) -> None:
|
async def start_periodic_check_register(self) -> None:
|
||||||
"""定期检查表情包完整性和数量"""
|
"""定期检查表情包完整性和数量"""
|
||||||
|
if global_config is None:
|
||||||
|
raise RuntimeError("Global config is not initialized")
|
||||||
await self.get_all_emoji_from_db()
|
await self.get_all_emoji_from_db()
|
||||||
while True:
|
while True:
|
||||||
# logger.info("[扫描] 开始检查表情包完整性...")
|
# logger.info("[扫描] 开始检查表情包完整性...")
|
||||||
@@ -771,8 +780,9 @@ class EmojiManager:
|
|||||||
try:
|
try:
|
||||||
emoji_record = await self.get_emoji_from_db(emoji_hash)
|
emoji_record = await self.get_emoji_from_db(emoji_hash)
|
||||||
if emoji_record and emoji_record[0].emotion:
|
if emoji_record and emoji_record[0].emotion:
|
||||||
logger.info(f"[缓存命中] 从数据库获取表情包描述: {emoji_record.emotion[:50]}...") # type: ignore # type: ignore
|
emotion_str = ",".join(emoji_record[0].emotion)
|
||||||
return emoji_record.emotion # type: ignore
|
logger.info(f"[缓存命中] 从数据库获取表情包描述: {emotion_str[:50]}...")
|
||||||
|
return emotion_str
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"从数据库查询表情包描述时出错: {e}")
|
logger.error(f"从数据库查询表情包描述时出错: {e}")
|
||||||
|
|
||||||
@@ -803,7 +813,7 @@ class EmojiManager:
|
|||||||
try:
|
try:
|
||||||
from src.common.database.api.query import QueryBuilder
|
from src.common.database.api.query import QueryBuilder
|
||||||
|
|
||||||
emoji_record = await QueryBuilder(Emoji).filter(emoji_hash=emoji_hash).first()
|
emoji_record = cast(Emoji | None, await QueryBuilder(Emoji).filter(emoji_hash=emoji_hash).first())
|
||||||
if emoji_record and emoji_record.description:
|
if emoji_record and emoji_record.description:
|
||||||
logger.info(f"[缓存命中] 从数据库获取表情包描述: {emoji_record.description[:50]}...")
|
logger.info(f"[缓存命中] 从数据库获取表情包描述: {emoji_record.description[:50]}...")
|
||||||
return emoji_record.description
|
return emoji_record.description
|
||||||
@@ -880,6 +890,9 @@ class EmojiManager:
|
|||||||
# 将表情包信息转换为可读的字符串
|
# 将表情包信息转换为可读的字符串
|
||||||
emoji_info_list = _emoji_objects_to_readable_list(selected_emojis)
|
emoji_info_list = _emoji_objects_to_readable_list(selected_emojis)
|
||||||
|
|
||||||
|
if global_config is None:
|
||||||
|
raise RuntimeError("Global config is not initialized")
|
||||||
|
|
||||||
# 构建提示词
|
# 构建提示词
|
||||||
prompt = (
|
prompt = (
|
||||||
f"{global_config.bot.nickname}的表情包存储已满({self.emoji_num}/{self.emoji_num_max}),"
|
f"{global_config.bot.nickname}的表情包存储已满({self.emoji_num}/{self.emoji_num_max}),"
|
||||||
@@ -954,6 +967,8 @@ class EmojiManager:
|
|||||||
Tuple[str, List[str]]: 返回一个元组,第一个元素是详细描述,第二个元素是情感关键词列表。
|
Tuple[str, List[str]]: 返回一个元组,第一个元素是详细描述,第二个元素是情感关键词列表。
|
||||||
如果处理失败,则返回空的描述和列表。
|
如果处理失败,则返回空的描述和列表。
|
||||||
"""
|
"""
|
||||||
|
if global_config is None:
|
||||||
|
raise RuntimeError("Global config is not initialized")
|
||||||
try:
|
try:
|
||||||
# 1. 解码图片,计算哈希值,并获取格式
|
# 1. 解码图片,计算哈希值,并获取格式
|
||||||
if isinstance(image_base64, str):
|
if isinstance(image_base64, str):
|
||||||
@@ -967,7 +982,7 @@ class EmojiManager:
|
|||||||
try:
|
try:
|
||||||
from src.common.database.api.query import QueryBuilder
|
from src.common.database.api.query import QueryBuilder
|
||||||
|
|
||||||
existing_image = await QueryBuilder(Images).filter(emoji_hash=image_hash, type="emoji").first()
|
existing_image = cast(Images | None, await QueryBuilder(Images).filter(emoji_hash=image_hash, type="emoji").first())
|
||||||
if existing_image and existing_image.description:
|
if existing_image and existing_image.description:
|
||||||
existing_description = existing_image.description
|
existing_description = existing_image.description
|
||||||
logger.info(f"[复用描述] 找到已有详细描述: {existing_description[:50]}...")
|
logger.info(f"[复用描述] 找到已有详细描述: {existing_description[:50]}...")
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ import time
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Any, TypedDict
|
from typing import Any, Awaitable, TypedDict, cast
|
||||||
|
|
||||||
from src.common.database.api.crud import CRUDBase
|
from src.common.database.api.crud import CRUDBase
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
@@ -70,7 +70,7 @@ class EnergyCalculator(ABC):
|
|||||||
"""能量计算器抽象基类"""
|
"""能量计算器抽象基类"""
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def calculate(self, context: dict[str, Any]) -> float:
|
def calculate(self, context: "EnergyContext") -> float | Awaitable[float]:
|
||||||
"""计算能量值"""
|
"""计算能量值"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@@ -83,7 +83,7 @@ class EnergyCalculator(ABC):
|
|||||||
class InterestEnergyCalculator(EnergyCalculator):
|
class InterestEnergyCalculator(EnergyCalculator):
|
||||||
"""兴趣度能量计算器"""
|
"""兴趣度能量计算器"""
|
||||||
|
|
||||||
def calculate(self, context: dict[str, Any]) -> float:
|
def calculate(self, context: "EnergyContext") -> float:
|
||||||
"""基于消息兴趣度计算能量"""
|
"""基于消息兴趣度计算能量"""
|
||||||
messages = context.get("messages", [])
|
messages = context.get("messages", [])
|
||||||
if not messages:
|
if not messages:
|
||||||
@@ -117,7 +117,7 @@ class ActivityEnergyCalculator(EnergyCalculator):
|
|||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.action_weights = {"reply": 0.4, "react": 0.3, "mention": 0.2, "other": 0.1}
|
self.action_weights = {"reply": 0.4, "react": 0.3, "mention": 0.2, "other": 0.1}
|
||||||
|
|
||||||
def calculate(self, context: dict[str, Any]) -> float:
|
def calculate(self, context: "EnergyContext") -> float:
|
||||||
"""基于活跃度计算能量"""
|
"""基于活跃度计算能量"""
|
||||||
messages = context.get("messages", [])
|
messages = context.get("messages", [])
|
||||||
if not messages:
|
if not messages:
|
||||||
@@ -147,7 +147,7 @@ class ActivityEnergyCalculator(EnergyCalculator):
|
|||||||
class RecencyEnergyCalculator(EnergyCalculator):
|
class RecencyEnergyCalculator(EnergyCalculator):
|
||||||
"""最近性能量计算器"""
|
"""最近性能量计算器"""
|
||||||
|
|
||||||
def calculate(self, context: dict[str, Any]) -> float:
|
def calculate(self, context: "EnergyContext") -> float:
|
||||||
"""基于最近性计算能量"""
|
"""基于最近性计算能量"""
|
||||||
messages = context.get("messages", [])
|
messages = context.get("messages", [])
|
||||||
if not messages:
|
if not messages:
|
||||||
@@ -194,7 +194,7 @@ class RecencyEnergyCalculator(EnergyCalculator):
|
|||||||
class RelationshipEnergyCalculator(EnergyCalculator):
|
class RelationshipEnergyCalculator(EnergyCalculator):
|
||||||
"""关系能量计算器 - 基于聊天流兴趣度"""
|
"""关系能量计算器 - 基于聊天流兴趣度"""
|
||||||
|
|
||||||
async def calculate(self, context: dict[str, Any]) -> float:
|
async def calculate(self, context: "EnergyContext") -> float:
|
||||||
"""基于聊天流兴趣度计算能量"""
|
"""基于聊天流兴趣度计算能量"""
|
||||||
stream_id = context.get("stream_id")
|
stream_id = context.get("stream_id")
|
||||||
if not stream_id:
|
if not stream_id:
|
||||||
@@ -260,6 +260,8 @@ class EnergyManager:
|
|||||||
def _load_thresholds_from_config(self) -> None:
|
def _load_thresholds_from_config(self) -> None:
|
||||||
"""从配置加载AFC阈值"""
|
"""从配置加载AFC阈值"""
|
||||||
try:
|
try:
|
||||||
|
if global_config is None:
|
||||||
|
return
|
||||||
if hasattr(global_config, "affinity_flow") and global_config.affinity_flow is not None:
|
if hasattr(global_config, "affinity_flow") and global_config.affinity_flow is not None:
|
||||||
self.thresholds["high_match"] = getattr(
|
self.thresholds["high_match"] = getattr(
|
||||||
global_config.affinity_flow, "high_match_interest_threshold", 0.8
|
global_config.affinity_flow, "high_match_interest_threshold", 0.8
|
||||||
@@ -283,17 +285,17 @@ class EnergyManager:
|
|||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
|
||||||
# 更新统计
|
# 更新统计
|
||||||
self.stats["total_calculations"] += 1
|
self.stats["total_calculations"] = cast(int, self.stats["total_calculations"]) + 1
|
||||||
|
|
||||||
# 检查缓存
|
# 检查缓存
|
||||||
if stream_id in self.energy_cache:
|
if stream_id in self.energy_cache:
|
||||||
cached_energy, cached_time = self.energy_cache[stream_id]
|
cached_energy, cached_time = self.energy_cache[stream_id]
|
||||||
if time.time() - cached_time < self.cache_ttl:
|
if time.time() - cached_time < self.cache_ttl:
|
||||||
self.stats["cache_hits"] += 1
|
self.stats["cache_hits"] = cast(int, self.stats["cache_hits"]) + 1
|
||||||
logger.debug(f"使用缓存能量: {stream_id} = {cached_energy:.3f}")
|
logger.debug(f"使用缓存能量: {stream_id} = {cached_energy:.3f}")
|
||||||
return cached_energy
|
return cached_energy
|
||||||
else:
|
else:
|
||||||
self.stats["cache_misses"] += 1
|
self.stats["cache_misses"] = cast(int, self.stats["cache_misses"]) + 1
|
||||||
|
|
||||||
# 构建计算上下文
|
# 构建计算上下文
|
||||||
context: EnergyContext = {
|
context: EnergyContext = {
|
||||||
@@ -358,9 +360,10 @@ class EnergyManager:
|
|||||||
|
|
||||||
# 更新平均计算时间
|
# 更新平均计算时间
|
||||||
calculation_time = time.time() - start_time
|
calculation_time = time.time() - start_time
|
||||||
total_calculations = self.stats["total_calculations"]
|
total_calculations = cast(int, self.stats["total_calculations"])
|
||||||
|
current_avg = cast(float, self.stats["average_calculation_time"])
|
||||||
self.stats["average_calculation_time"] = (
|
self.stats["average_calculation_time"] = (
|
||||||
self.stats["average_calculation_time"] * (total_calculations - 1) + calculation_time
|
current_avg * (total_calculations - 1) + calculation_time
|
||||||
) / total_calculations
|
) / total_calculations
|
||||||
|
|
||||||
logger.debug(
|
logger.debug(
|
||||||
@@ -424,8 +427,11 @@ class EnergyManager:
|
|||||||
final_interval = base_interval * jitter
|
final_interval = base_interval * jitter
|
||||||
|
|
||||||
# 确保在配置范围内
|
# 确保在配置范围内
|
||||||
min_interval = getattr(global_config.chat, "dynamic_distribution_min_interval", 1.0)
|
min_interval = 1.0
|
||||||
max_interval = getattr(global_config.chat, "dynamic_distribution_max_interval", 60.0)
|
max_interval = 60.0
|
||||||
|
if global_config is not None and hasattr(global_config, "chat"):
|
||||||
|
min_interval = getattr(global_config.chat, "dynamic_distribution_min_interval", 1.0)
|
||||||
|
max_interval = getattr(global_config.chat, "dynamic_distribution_max_interval", 60.0)
|
||||||
|
|
||||||
return max(min_interval, min(max_interval, final_interval))
|
return max(min_interval, min(max_interval, final_interval))
|
||||||
|
|
||||||
@@ -487,10 +493,12 @@ class EnergyManager:
|
|||||||
|
|
||||||
def get_cache_hit_rate(self) -> float:
|
def get_cache_hit_rate(self) -> float:
|
||||||
"""获取缓存命中率"""
|
"""获取缓存命中率"""
|
||||||
total_requests = self.stats.get("cache_hits", 0) + self.stats.get("cache_misses", 0)
|
hits = cast(int, self.stats.get("cache_hits", 0))
|
||||||
|
misses = cast(int, self.stats.get("cache_misses", 0))
|
||||||
|
total_requests = hits + misses
|
||||||
if total_requests == 0:
|
if total_requests == 0:
|
||||||
return 0.0
|
return 0.0
|
||||||
return self.stats["cache_hits"] / total_requests
|
return hits / total_requests
|
||||||
|
|
||||||
|
|
||||||
# 全局能量管理器实例
|
# 全局能量管理器实例
|
||||||
|
|||||||
Reference in New Issue
Block a user