re-style: 格式化代码
This commit is contained in:
@@ -3,13 +3,14 @@
|
||||
提供机器人兴趣标签和智能匹配功能
|
||||
"""
|
||||
|
||||
from .bot_interest_manager import BotInterestManager, bot_interest_manager
|
||||
from src.common.data_models.bot_interest_data_model import BotInterestTag, BotPersonalityInterests, InterestMatchResult
|
||||
|
||||
from .bot_interest_manager import BotInterestManager, bot_interest_manager
|
||||
|
||||
__all__ = [
|
||||
"BotInterestManager",
|
||||
"bot_interest_manager",
|
||||
"BotInterestTag",
|
||||
"BotPersonalityInterests",
|
||||
"InterestMatchResult",
|
||||
"bot_interest_manager",
|
||||
]
|
||||
|
||||
@@ -3,17 +3,18 @@
|
||||
基于人设生成兴趣标签,并使用embedding计算匹配度
|
||||
"""
|
||||
|
||||
import orjson
|
||||
import traceback
|
||||
from typing import List, Dict, Optional, Any
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import orjson
|
||||
from sqlalchemy import select
|
||||
|
||||
from src.common.config_helpers import resolve_embedding_dimension
|
||||
from src.common.data_models.bot_interest_data_model import BotInterestTag, BotPersonalityInterests, InterestMatchResult
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.common.config_helpers import resolve_embedding_dimension
|
||||
from src.common.data_models.bot_interest_data_model import BotPersonalityInterests, BotInterestTag, InterestMatchResult
|
||||
|
||||
logger = get_logger("bot_interest_manager")
|
||||
|
||||
@@ -22,8 +23,8 @@ class BotInterestManager:
|
||||
"""机器人兴趣标签管理器"""
|
||||
|
||||
def __init__(self):
|
||||
self.current_interests: Optional[BotPersonalityInterests] = None
|
||||
self.embedding_cache: Dict[str, List[float]] = {} # embedding缓存
|
||||
self.current_interests: BotPersonalityInterests | None = None
|
||||
self.embedding_cache: dict[str, list[float]] = {} # embedding缓存
|
||||
self._initialized = False
|
||||
|
||||
# Embedding客户端配置
|
||||
@@ -31,7 +32,7 @@ class BotInterestManager:
|
||||
self.embedding_config = None
|
||||
configured_dim = resolve_embedding_dimension()
|
||||
self.embedding_dimension = int(configured_dim) if configured_dim else 0
|
||||
self._detected_embedding_dimension: Optional[int] = None
|
||||
self._detected_embedding_dimension: int | None = None
|
||||
|
||||
@property
|
||||
def is_initialized(self) -> bool:
|
||||
@@ -145,7 +146,7 @@ class BotInterestManager:
|
||||
|
||||
async def _generate_interests_from_personality(
|
||||
self, personality_description: str, personality_id: str
|
||||
) -> Optional[BotPersonalityInterests]:
|
||||
) -> BotPersonalityInterests | None:
|
||||
"""根据人设生成兴趣标签"""
|
||||
try:
|
||||
logger.info("🎨 开始根据人设生成兴趣标签...")
|
||||
@@ -226,14 +227,14 @@ class BotInterestManager:
|
||||
traceback.print_exc()
|
||||
raise
|
||||
|
||||
async def _call_llm_for_interest_generation(self, prompt: str) -> Optional[str]:
|
||||
async def _call_llm_for_interest_generation(self, prompt: str) -> str | None:
|
||||
"""调用LLM生成兴趣标签"""
|
||||
try:
|
||||
logger.info("🔧 配置LLM客户端...")
|
||||
|
||||
# 使用llm_api来处理请求
|
||||
from src.plugin_system.apis import llm_api
|
||||
from src.config.config import model_config
|
||||
from src.plugin_system.apis import llm_api
|
||||
|
||||
# 构建完整的提示词,明确要求只返回纯JSON
|
||||
full_prompt = f"""你是一个专业的机器人人设分析师,擅长根据人设描述生成合适的兴趣标签。
|
||||
@@ -342,7 +343,7 @@ class BotInterestManager:
|
||||
logger.info(f"🗃️ 总缓存大小: {len(self.embedding_cache)}")
|
||||
logger.info("=" * 50)
|
||||
|
||||
async def _get_embedding(self, text: str) -> List[float]:
|
||||
async def _get_embedding(self, text: str) -> list[float]:
|
||||
"""获取文本的embedding向量"""
|
||||
if not hasattr(self, "embedding_request"):
|
||||
raise RuntimeError("❌ Embedding请求客户端未初始化")
|
||||
@@ -383,7 +384,7 @@ class BotInterestManager:
|
||||
else:
|
||||
raise RuntimeError(f"❌ 返回的embedding为空: {embedding}")
|
||||
|
||||
async def _generate_message_embedding(self, message_text: str, keywords: List[str]) -> List[float]:
|
||||
async def _generate_message_embedding(self, message_text: str, keywords: list[str]) -> list[float]:
|
||||
"""为消息生成embedding向量"""
|
||||
# 组合消息文本和关键词作为embedding输入
|
||||
if keywords:
|
||||
@@ -399,7 +400,7 @@ class BotInterestManager:
|
||||
return embedding
|
||||
|
||||
async def _calculate_similarity_scores(
|
||||
self, result: InterestMatchResult, message_embedding: List[float], keywords: List[str]
|
||||
self, result: InterestMatchResult, message_embedding: list[float], keywords: list[str]
|
||||
):
|
||||
"""计算消息与兴趣标签的相似度分数"""
|
||||
try:
|
||||
@@ -428,7 +429,7 @@ class BotInterestManager:
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 计算相似度分数失败: {e}")
|
||||
|
||||
async def calculate_interest_match(self, message_text: str, keywords: List[str] = None) -> InterestMatchResult:
|
||||
async def calculate_interest_match(self, message_text: str, keywords: list[str] = None) -> InterestMatchResult:
|
||||
"""计算消息与机器人兴趣的匹配度"""
|
||||
if not self.current_interests or not self._initialized:
|
||||
raise RuntimeError("❌ 兴趣标签系统未初始化")
|
||||
@@ -528,7 +529,7 @@ class BotInterestManager:
|
||||
)
|
||||
return result
|
||||
|
||||
def _calculate_keyword_match_bonus(self, keywords: List[str], matched_tags: List[str]) -> Dict[str, float]:
|
||||
def _calculate_keyword_match_bonus(self, keywords: list[str], matched_tags: list[str]) -> dict[str, float]:
|
||||
"""计算关键词直接匹配奖励"""
|
||||
if not keywords or not matched_tags:
|
||||
return {}
|
||||
@@ -610,7 +611,7 @@ class BotInterestManager:
|
||||
|
||||
return previous_row[-1]
|
||||
|
||||
def _calculate_cosine_similarity(self, vec1: List[float], vec2: List[float]) -> float:
|
||||
def _calculate_cosine_similarity(self, vec1: list[float], vec2: list[float]) -> float:
|
||||
"""计算余弦相似度"""
|
||||
try:
|
||||
vec1 = np.array(vec1)
|
||||
@@ -629,16 +630,17 @@ class BotInterestManager:
|
||||
logger.error(f"计算余弦相似度失败: {e}")
|
||||
return 0.0
|
||||
|
||||
async def _load_interests_from_database(self, personality_id: str) -> Optional[BotPersonalityInterests]:
|
||||
async def _load_interests_from_database(self, personality_id: str) -> BotPersonalityInterests | None:
|
||||
"""从数据库加载兴趣标签"""
|
||||
try:
|
||||
logger.debug(f"从数据库加载兴趣标签, personality_id: {personality_id}")
|
||||
|
||||
# 导入SQLAlchemy相关模块
|
||||
from src.common.database.sqlalchemy_models import BotPersonalityInterests as DBBotPersonalityInterests
|
||||
from src.common.database.sqlalchemy_database_api import get_db_session
|
||||
import orjson
|
||||
|
||||
from src.common.database.sqlalchemy_database_api import get_db_session
|
||||
from src.common.database.sqlalchemy_models import BotPersonalityInterests as DBBotPersonalityInterests
|
||||
|
||||
async with get_db_session() as session:
|
||||
# 查询最新的兴趣标签配置
|
||||
db_interests = (
|
||||
@@ -716,10 +718,11 @@ class BotInterestManager:
|
||||
logger.info(f"🔄 版本: {interests.version}")
|
||||
|
||||
# 导入SQLAlchemy相关模块
|
||||
from src.common.database.sqlalchemy_models import BotPersonalityInterests as DBBotPersonalityInterests
|
||||
from src.common.database.sqlalchemy_database_api import get_db_session
|
||||
import orjson
|
||||
|
||||
from src.common.database.sqlalchemy_database_api import get_db_session
|
||||
from src.common.database.sqlalchemy_models import BotPersonalityInterests as DBBotPersonalityInterests
|
||||
|
||||
# 将兴趣标签转换为JSON格式
|
||||
tags_data = []
|
||||
for tag in interests.interest_tags:
|
||||
@@ -803,11 +806,11 @@ class BotInterestManager:
|
||||
logger.error("🔍 错误详情:")
|
||||
traceback.print_exc()
|
||||
|
||||
def get_current_interests(self) -> Optional[BotPersonalityInterests]:
|
||||
def get_current_interests(self) -> BotPersonalityInterests | None:
|
||||
"""获取当前的兴趣标签配置"""
|
||||
return self.current_interests
|
||||
|
||||
def get_interest_stats(self) -> Dict[str, Any]:
|
||||
def get_interest_stats(self) -> dict[str, Any]:
|
||||
"""获取兴趣系统统计信息"""
|
||||
if not self.current_interests:
|
||||
return {"initialized": False}
|
||||
|
||||
Reference in New Issue
Block a user