ruff,私聊视为提及了bot
This commit is contained in:
@@ -4,14 +4,12 @@
|
||||
"""
|
||||
|
||||
from .bot_interest_manager import BotInterestManager, bot_interest_manager
|
||||
from src.common.data_models.bot_interest_data_model import (
|
||||
BotInterestTag, BotPersonalityInterests, InterestMatchResult
|
||||
)
|
||||
from src.common.data_models.bot_interest_data_model import BotInterestTag, BotPersonalityInterests, InterestMatchResult
|
||||
|
||||
__all__ = [
|
||||
"BotInterestManager",
|
||||
"bot_interest_manager",
|
||||
"BotInterestTag",
|
||||
"BotPersonalityInterests",
|
||||
"InterestMatchResult"
|
||||
]
|
||||
"InterestMatchResult",
|
||||
]
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
机器人兴趣标签管理系统
|
||||
基于人设生成兴趣标签,并使用embedding计算匹配度
|
||||
"""
|
||||
|
||||
import orjson
|
||||
import traceback
|
||||
from typing import List, Dict, Optional, Any
|
||||
@@ -10,9 +11,7 @@ import numpy as np
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.common.data_models.bot_interest_data_model import (
|
||||
BotPersonalityInterests, BotInterestTag, InterestMatchResult
|
||||
)
|
||||
from src.common.data_models.bot_interest_data_model import BotPersonalityInterests, BotInterestTag, InterestMatchResult
|
||||
|
||||
logger = get_logger("bot_interest_manager")
|
||||
|
||||
@@ -87,7 +86,7 @@ class BotInterestManager:
|
||||
logger.debug("✅ 成功导入embedding相关模块")
|
||||
|
||||
# 检查embedding配置是否存在
|
||||
if not hasattr(model_config.model_task_config, 'embedding'):
|
||||
if not hasattr(model_config.model_task_config, "embedding"):
|
||||
raise RuntimeError("❌ 未找到embedding模型配置")
|
||||
|
||||
logger.info("📋 找到embedding模型配置")
|
||||
@@ -101,7 +100,7 @@ class BotInterestManager:
|
||||
logger.info(f"🔗 客户端类型: {type(self.embedding_request).__name__}")
|
||||
|
||||
# 获取第一个embedding模型的ModelInfo
|
||||
if hasattr(self.embedding_config, 'model_list') and self.embedding_config.model_list:
|
||||
if hasattr(self.embedding_config, "model_list") and self.embedding_config.model_list:
|
||||
first_model_name = self.embedding_config.model_list[0]
|
||||
logger.info(f"🎯 使用embedding模型: {first_model_name}")
|
||||
else:
|
||||
@@ -127,7 +126,9 @@ class BotInterestManager:
|
||||
# 生成新的兴趣标签
|
||||
logger.info("🆕 数据库中未找到兴趣标签,开始生成新的...")
|
||||
logger.info("🤖 正在调用LLM生成个性化兴趣标签...")
|
||||
generated_interests = await self._generate_interests_from_personality(personality_description, personality_id)
|
||||
generated_interests = await self._generate_interests_from_personality(
|
||||
personality_description, personality_id
|
||||
)
|
||||
|
||||
if generated_interests:
|
||||
self.current_interests = generated_interests
|
||||
@@ -140,14 +141,16 @@ class BotInterestManager:
|
||||
else:
|
||||
raise RuntimeError("❌ 兴趣标签生成失败")
|
||||
|
||||
async def _generate_interests_from_personality(self, personality_description: str, personality_id: str) -> Optional[BotPersonalityInterests]:
|
||||
async def _generate_interests_from_personality(
|
||||
self, personality_description: str, personality_id: str
|
||||
) -> Optional[BotPersonalityInterests]:
|
||||
"""根据人设生成兴趣标签"""
|
||||
try:
|
||||
logger.info("🎨 开始根据人设生成兴趣标签...")
|
||||
logger.info(f"📝 人设长度: {len(personality_description)} 字符")
|
||||
|
||||
# 检查embedding客户端是否可用
|
||||
if not hasattr(self, 'embedding_request'):
|
||||
if not hasattr(self, "embedding_request"):
|
||||
raise RuntimeError("❌ Embedding客户端未初始化,无法生成兴趣标签")
|
||||
|
||||
# 构建提示词
|
||||
@@ -190,8 +193,7 @@ class BotInterestManager:
|
||||
interests_data = orjson.loads(response)
|
||||
|
||||
bot_interests = BotPersonalityInterests(
|
||||
personality_id=personality_id,
|
||||
personality_description=personality_description
|
||||
personality_id=personality_id, personality_description=personality_description
|
||||
)
|
||||
|
||||
# 解析生成的兴趣标签
|
||||
@@ -202,10 +204,7 @@ class BotInterestManager:
|
||||
tag_name = tag_data.get("name", f"标签_{i}")
|
||||
weight = tag_data.get("weight", 0.5)
|
||||
|
||||
tag = BotInterestTag(
|
||||
tag_name=tag_name,
|
||||
weight=weight
|
||||
)
|
||||
tag = BotInterestTag(tag_name=tag_name, weight=weight)
|
||||
bot_interests.interest_tags.append(tag)
|
||||
|
||||
logger.debug(f" 🏷️ {tag_name} (权重: {weight:.2f})")
|
||||
@@ -225,7 +224,6 @@ class BotInterestManager:
|
||||
traceback.print_exc()
|
||||
raise
|
||||
|
||||
|
||||
async def _call_llm_for_interest_generation(self, prompt: str) -> Optional[str]:
|
||||
"""调用LLM生成兴趣标签"""
|
||||
try:
|
||||
@@ -241,10 +239,10 @@ class BotInterestManager:
|
||||
{prompt}
|
||||
|
||||
请确保返回格式为有效的JSON,不要包含任何额外的文本、解释或代码块标记。只返回JSON对象本身。"""
|
||||
|
||||
|
||||
# 使用replyer模型配置
|
||||
replyer_config = model_config.model_task_config.replyer
|
||||
|
||||
|
||||
# 调用LLM API
|
||||
logger.info("🚀 正在通过LLM API发送请求...")
|
||||
success, response, reasoning_content, model_name = await llm_api.generate_with_model(
|
||||
@@ -252,15 +250,17 @@ class BotInterestManager:
|
||||
model_config=replyer_config,
|
||||
request_type="interest_generation",
|
||||
temperature=0.7,
|
||||
max_tokens=2000
|
||||
max_tokens=2000,
|
||||
)
|
||||
|
||||
if success and response:
|
||||
logger.info(f"✅ LLM响应成功,模型: {model_name}, 响应长度: {len(response)} 字符")
|
||||
logger.debug(f"📄 LLM响应内容: {response[:200]}..." if len(response) > 200 else f"📄 LLM响应内容: {response}")
|
||||
logger.debug(
|
||||
f"📄 LLM响应内容: {response[:200]}..." if len(response) > 200 else f"📄 LLM响应内容: {response}"
|
||||
)
|
||||
if reasoning_content:
|
||||
logger.debug(f"🧠 推理内容: {reasoning_content[:100]}...")
|
||||
|
||||
|
||||
# 清理响应内容,移除可能的代码块标记
|
||||
cleaned_response = self._clean_llm_response(response)
|
||||
return cleaned_response
|
||||
@@ -277,25 +277,25 @@ class BotInterestManager:
|
||||
def _clean_llm_response(self, response: str) -> str:
|
||||
"""清理LLM响应,移除代码块标记和其他非JSON内容"""
|
||||
import re
|
||||
|
||||
|
||||
# 移除 ```json 和 ``` 标记
|
||||
cleaned = re.sub(r'```json\s*', '', response)
|
||||
cleaned = re.sub(r'\s*```', '', cleaned)
|
||||
|
||||
cleaned = re.sub(r"```json\s*", "", response)
|
||||
cleaned = re.sub(r"\s*```", "", cleaned)
|
||||
|
||||
# 移除可能的多余空格和换行
|
||||
cleaned = cleaned.strip()
|
||||
|
||||
|
||||
# 尝试提取JSON对象(如果响应中有其他文本)
|
||||
json_match = re.search(r'\{.*\}', cleaned, re.DOTALL)
|
||||
json_match = re.search(r"\{.*\}", cleaned, re.DOTALL)
|
||||
if json_match:
|
||||
cleaned = json_match.group(0)
|
||||
|
||||
|
||||
logger.debug(f"🧹 清理后的响应: {cleaned[:200]}..." if len(cleaned) > 200 else f"🧹 清理后的响应: {cleaned}")
|
||||
return cleaned
|
||||
|
||||
async def _generate_embeddings_for_tags(self, interests: BotPersonalityInterests):
|
||||
"""为所有兴趣标签生成embedding"""
|
||||
if not hasattr(self, 'embedding_request'):
|
||||
if not hasattr(self, "embedding_request"):
|
||||
raise RuntimeError("❌ Embedding客户端未初始化,无法生成embedding")
|
||||
|
||||
total_tags = len(interests.interest_tags)
|
||||
@@ -342,7 +342,7 @@ class BotInterestManager:
|
||||
|
||||
async def _get_embedding(self, text: str) -> List[float]:
|
||||
"""获取文本的embedding向量"""
|
||||
if not hasattr(self, 'embedding_request'):
|
||||
if not hasattr(self, "embedding_request"):
|
||||
raise RuntimeError("❌ Embedding请求客户端未初始化")
|
||||
|
||||
# 检查缓存
|
||||
@@ -376,7 +376,9 @@ class BotInterestManager:
|
||||
logger.debug(f"✅ 消息embedding生成成功,维度: {len(embedding)}")
|
||||
return embedding
|
||||
|
||||
async def _calculate_similarity_scores(self, result: InterestMatchResult, message_embedding: List[float], keywords: List[str]):
|
||||
async def _calculate_similarity_scores(
|
||||
self, result: InterestMatchResult, message_embedding: List[float], keywords: List[str]
|
||||
):
|
||||
"""计算消息与兴趣标签的相似度分数"""
|
||||
try:
|
||||
if not self.current_interests:
|
||||
@@ -397,7 +399,9 @@ class BotInterestManager:
|
||||
# 设置相似度阈值为0.3
|
||||
if similarity > 0.3:
|
||||
result.add_match(tag.tag_name, weighted_score, keywords)
|
||||
logger.debug(f" 🏷️ '{tag.tag_name}': 相似度={similarity:.3f}, 权重={tag.weight:.2f}, 加权分数={weighted_score:.3f}")
|
||||
logger.debug(
|
||||
f" 🏷️ '{tag.tag_name}': 相似度={similarity:.3f}, 权重={tag.weight:.2f}, 加权分数={weighted_score:.3f}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 计算相似度分数失败: {e}")
|
||||
@@ -455,7 +459,9 @@ class BotInterestManager:
|
||||
match_count += 1
|
||||
high_similarity_count += 1
|
||||
result.add_match(tag.tag_name, enhanced_score, [tag.tag_name])
|
||||
logger.debug(f" 🏷️ '{tag.tag_name}': 相似度={similarity:.3f}, 权重={tag.weight:.2f}, 基础分数={weighted_score:.3f}, 增强分数={enhanced_score:.3f} [高匹配]")
|
||||
logger.debug(
|
||||
f" 🏷️ '{tag.tag_name}': 相似度={similarity:.3f}, 权重={tag.weight:.2f}, 基础分数={weighted_score:.3f}, 增强分数={enhanced_score:.3f} [高匹配]"
|
||||
)
|
||||
|
||||
elif similarity > medium_threshold:
|
||||
# 中相似度:中等加成
|
||||
@@ -463,7 +469,9 @@ class BotInterestManager:
|
||||
match_count += 1
|
||||
medium_similarity_count += 1
|
||||
result.add_match(tag.tag_name, enhanced_score, [tag.tag_name])
|
||||
logger.debug(f" 🏷️ '{tag.tag_name}': 相似度={similarity:.3f}, 权重={tag.weight:.2f}, 基础分数={weighted_score:.3f}, 增强分数={enhanced_score:.3f} [中匹配]")
|
||||
logger.debug(
|
||||
f" 🏷️ '{tag.tag_name}': 相似度={similarity:.3f}, 权重={tag.weight:.2f}, 基础分数={weighted_score:.3f}, 增强分数={enhanced_score:.3f} [中匹配]"
|
||||
)
|
||||
|
||||
elif similarity > low_threshold:
|
||||
# 低相似度:轻微加成
|
||||
@@ -471,7 +479,9 @@ class BotInterestManager:
|
||||
match_count += 1
|
||||
low_similarity_count += 1
|
||||
result.add_match(tag.tag_name, enhanced_score, [tag.tag_name])
|
||||
logger.debug(f" 🏷️ '{tag.tag_name}': 相似度={similarity:.3f}, 权重={tag.weight:.2f}, 基础分数={weighted_score:.3f}, 增强分数={enhanced_score:.3f} [低匹配]")
|
||||
logger.debug(
|
||||
f" 🏷️ '{tag.tag_name}': 相似度={similarity:.3f}, 权重={tag.weight:.2f}, 基础分数={weighted_score:.3f}, 增强分数={enhanced_score:.3f} [低匹配]"
|
||||
)
|
||||
|
||||
logger.info(f"📈 匹配统计: {match_count}/{len(active_tags)} 个标签超过阈值")
|
||||
logger.info(f"🔥 高相似度匹配(>{high_threshold}): {high_similarity_count} 个")
|
||||
@@ -488,7 +498,9 @@ class BotInterestManager:
|
||||
original_score = result.match_scores[tag_name]
|
||||
bonus = keyword_bonus[tag_name]
|
||||
result.match_scores[tag_name] = original_score + bonus
|
||||
logger.debug(f" 🏷️ '{tag_name}': 原始分数={original_score:.3f}, 奖励={bonus:.3f}, 最终分数={result.match_scores[tag_name]:.3f}")
|
||||
logger.debug(
|
||||
f" 🏷️ '{tag_name}': 原始分数={original_score:.3f}, 奖励={bonus:.3f}, 最终分数={result.match_scores[tag_name]:.3f}"
|
||||
)
|
||||
|
||||
# 计算总体分数
|
||||
result.calculate_overall_score()
|
||||
@@ -499,10 +511,11 @@ class BotInterestManager:
|
||||
result.top_tag = top_tag_name
|
||||
logger.info(f"🏆 最佳匹配标签: '{top_tag_name}' (分数: {result.match_scores[top_tag_name]:.3f})")
|
||||
|
||||
logger.info(f"📊 最终结果: 总分={result.overall_score:.3f}, 置信度={result.confidence:.3f}, 匹配标签数={len(result.matched_tags)}")
|
||||
logger.info(
|
||||
f"📊 最终结果: 总分={result.overall_score:.3f}, 置信度={result.confidence:.3f}, 匹配标签数={len(result.matched_tags)}"
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
def _calculate_keyword_match_bonus(self, keywords: List[str], matched_tags: List[str]) -> Dict[str, float]:
|
||||
"""计算关键词直接匹配奖励"""
|
||||
if not keywords or not matched_tags:
|
||||
@@ -522,17 +535,25 @@ class BotInterestManager:
|
||||
# 完全匹配
|
||||
if keyword_lower == tag_name_lower:
|
||||
bonus += affinity_config.high_match_interest_threshold * 0.6 # 使用高匹配阈值的60%作为完全匹配奖励
|
||||
logger.debug(f" 🎯 关键词完全匹配: '{keyword}' == '{tag_name}' (+{affinity_config.high_match_interest_threshold * 0.6:.3f})")
|
||||
logger.debug(
|
||||
f" 🎯 关键词完全匹配: '{keyword}' == '{tag_name}' (+{affinity_config.high_match_interest_threshold * 0.6:.3f})"
|
||||
)
|
||||
|
||||
# 包含匹配
|
||||
elif keyword_lower in tag_name_lower or tag_name_lower in keyword_lower:
|
||||
bonus += affinity_config.medium_match_interest_threshold * 0.3 # 使用中匹配阈值的30%作为包含匹配奖励
|
||||
logger.debug(f" 🎯 关键词包含匹配: '{keyword}' ⊃ '{tag_name}' (+{affinity_config.medium_match_interest_threshold * 0.3:.3f})")
|
||||
bonus += (
|
||||
affinity_config.medium_match_interest_threshold * 0.3
|
||||
) # 使用中匹配阈值的30%作为包含匹配奖励
|
||||
logger.debug(
|
||||
f" 🎯 关键词包含匹配: '{keyword}' ⊃ '{tag_name}' (+{affinity_config.medium_match_interest_threshold * 0.3:.3f})"
|
||||
)
|
||||
|
||||
# 部分匹配(编辑距离)
|
||||
elif self._calculate_partial_match(keyword_lower, tag_name_lower):
|
||||
bonus += affinity_config.low_match_interest_threshold * 0.4 # 使用低匹配阈值的40%作为部分匹配奖励
|
||||
logger.debug(f" 🎯 关键词部分匹配: '{keyword}' ≈ '{tag_name}' (+{affinity_config.low_match_interest_threshold * 0.4:.3f})")
|
||||
logger.debug(
|
||||
f" 🎯 关键词部分匹配: '{keyword}' ≈ '{tag_name}' (+{affinity_config.low_match_interest_threshold * 0.4:.3f})"
|
||||
)
|
||||
|
||||
if bonus > 0:
|
||||
bonus_dict[tag_name] = min(bonus, affinity_config.max_match_bonus) # 使用配置的最大奖励限制
|
||||
@@ -608,12 +629,12 @@ class BotInterestManager:
|
||||
|
||||
with get_db_session() as session:
|
||||
# 查询最新的兴趣标签配置
|
||||
db_interests = session.query(DBBotPersonalityInterests).filter(
|
||||
DBBotPersonalityInterests.personality_id == personality_id
|
||||
).order_by(
|
||||
DBBotPersonalityInterests.version.desc(),
|
||||
DBBotPersonalityInterests.last_updated.desc()
|
||||
).first()
|
||||
db_interests = (
|
||||
session.query(DBBotPersonalityInterests)
|
||||
.filter(DBBotPersonalityInterests.personality_id == personality_id)
|
||||
.order_by(DBBotPersonalityInterests.version.desc(), DBBotPersonalityInterests.last_updated.desc())
|
||||
.first()
|
||||
)
|
||||
|
||||
if db_interests:
|
||||
logger.info(f"✅ 找到数据库中的兴趣标签配置,版本: {db_interests.version}")
|
||||
@@ -631,7 +652,7 @@ class BotInterestManager:
|
||||
personality_description=db_interests.personality_description,
|
||||
embedding_model=db_interests.embedding_model,
|
||||
version=db_interests.version,
|
||||
last_updated=db_interests.last_updated
|
||||
last_updated=db_interests.last_updated,
|
||||
)
|
||||
|
||||
# 解析兴趣标签
|
||||
@@ -639,10 +660,14 @@ class BotInterestManager:
|
||||
tag = BotInterestTag(
|
||||
tag_name=tag_data.get("tag_name", ""),
|
||||
weight=tag_data.get("weight", 0.5),
|
||||
created_at=datetime.fromisoformat(tag_data.get("created_at", datetime.now().isoformat())),
|
||||
updated_at=datetime.fromisoformat(tag_data.get("updated_at", datetime.now().isoformat())),
|
||||
created_at=datetime.fromisoformat(
|
||||
tag_data.get("created_at", datetime.now().isoformat())
|
||||
),
|
||||
updated_at=datetime.fromisoformat(
|
||||
tag_data.get("updated_at", datetime.now().isoformat())
|
||||
),
|
||||
is_active=tag_data.get("is_active", True),
|
||||
embedding=tag_data.get("embedding")
|
||||
embedding=tag_data.get("embedding"),
|
||||
)
|
||||
interests.interest_tags.append(tag)
|
||||
|
||||
@@ -685,7 +710,7 @@ class BotInterestManager:
|
||||
"created_at": tag.created_at.isoformat(),
|
||||
"updated_at": tag.updated_at.isoformat(),
|
||||
"is_active": tag.is_active,
|
||||
"embedding": tag.embedding
|
||||
"embedding": tag.embedding,
|
||||
}
|
||||
tags_data.append(tag_dict)
|
||||
|
||||
@@ -694,9 +719,11 @@ class BotInterestManager:
|
||||
|
||||
with get_db_session() as session:
|
||||
# 检查是否已存在相同personality_id的记录
|
||||
existing_record = session.query(DBBotPersonalityInterests).filter(
|
||||
DBBotPersonalityInterests.personality_id == interests.personality_id
|
||||
).first()
|
||||
existing_record = (
|
||||
session.query(DBBotPersonalityInterests)
|
||||
.filter(DBBotPersonalityInterests.personality_id == interests.personality_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
if existing_record:
|
||||
# 更新现有记录
|
||||
@@ -718,7 +745,7 @@ class BotInterestManager:
|
||||
interest_tags=json_data,
|
||||
embedding_model=interests.embedding_model,
|
||||
version=interests.version,
|
||||
last_updated=interests.last_updated
|
||||
last_updated=interests.last_updated,
|
||||
)
|
||||
session.add(new_record)
|
||||
session.commit()
|
||||
@@ -728,9 +755,11 @@ class BotInterestManager:
|
||||
|
||||
# 验证保存是否成功
|
||||
with get_db_session() as session:
|
||||
saved_record = session.query(DBBotPersonalityInterests).filter(
|
||||
DBBotPersonalityInterests.personality_id == interests.personality_id
|
||||
).first()
|
||||
saved_record = (
|
||||
session.query(DBBotPersonalityInterests)
|
||||
.filter(DBBotPersonalityInterests.personality_id == interests.personality_id)
|
||||
.first()
|
||||
)
|
||||
session.commit()
|
||||
if saved_record:
|
||||
logger.info(f"✅ 验证成功:数据库中存在personality_id为 {interests.personality_id} 的记录")
|
||||
@@ -760,7 +789,7 @@ class BotInterestManager:
|
||||
"total_tags": len(active_tags),
|
||||
"embedding_model": self.current_interests.embedding_model,
|
||||
"last_updated": self.current_interests.last_updated.isoformat(),
|
||||
"cache_size": len(self.embedding_cache)
|
||||
"cache_size": len(self.embedding_cache),
|
||||
}
|
||||
|
||||
async def update_interest_tags(self, new_personality_description: str = None):
|
||||
@@ -775,8 +804,7 @@ class BotInterestManager:
|
||||
|
||||
# 重新生成兴趣标签
|
||||
new_interests = await self._generate_interests_from_personality(
|
||||
self.current_interests.personality_description,
|
||||
self.current_interests.personality_id
|
||||
self.current_interests.personality_description, self.current_interests.personality_id
|
||||
)
|
||||
|
||||
if new_interests:
|
||||
@@ -791,4 +819,4 @@ class BotInterestManager:
|
||||
|
||||
|
||||
# 创建全局实例(重新创建以包含新的属性)
|
||||
bot_interest_manager = BotInterestManager()
|
||||
bot_interest_manager = BotInterestManager()
|
||||
|
||||
Reference in New Issue
Block a user