feat(affinity-flow): 实现回复后关系追踪系统集成
- 在relationship_tracker.py中添加数据库支持的回复后关系追踪功能 - 新增UserRelationships数据库模型存储用户关系数据 - 集成全局关系追踪器到planner和interest_scoring系统 - 优化兴趣度评分系统的关系分获取逻辑,优先使用数据库存储的关系分 - 在plan_executor中执行回复后关系追踪,分析用户反应并更新关系 - 添加LLM响应清理功能确保JSON解析稳定性 - 更新模型配置模板添加relationship_tracker模型配置
This commit is contained in:
@@ -3,6 +3,7 @@
|
|||||||
基于多维度评分机制,包括兴趣匹配度、用户关系分、提及度和时间因子
|
基于多维度评分机制,包括兴趣匹配度、用户关系分、提及度和时间因子
|
||||||
现在使用embedding计算智能兴趣匹配
|
现在使用embedding计算智能兴趣匹配
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import traceback
|
import traceback
|
||||||
from typing import Dict, List, Any
|
from typing import Dict, List, Any
|
||||||
|
|
||||||
@@ -24,24 +25,26 @@ class InterestScoringSystem:
|
|||||||
|
|
||||||
# 评分权重
|
# 评分权重
|
||||||
self.score_weights = {
|
self.score_weights = {
|
||||||
"interest_match": 0.5, # 兴趣匹配度权重
|
"interest_match": 0.5, # 兴趣匹配度权重
|
||||||
"relationship": 0.3, # 关系分权重
|
"relationship": 0.3, # 关系分权重
|
||||||
"mentioned": 0.2, # 是否提及bot权重
|
"mentioned": 0.2, # 是否提及bot权重
|
||||||
}
|
}
|
||||||
|
|
||||||
# 评分阈值
|
# 评分阈值
|
||||||
self.reply_threshold = 0.56 # 默认回复阈值
|
self.reply_threshold = 0.62 # 默认回复阈值
|
||||||
self.mention_threshold = 0.3 # 提及阈值
|
self.mention_threshold = 0.3 # 提及阈值
|
||||||
|
|
||||||
# 连续不回复概率提升
|
# 连续不回复概率提升
|
||||||
self.no_reply_count = 0
|
self.no_reply_count = 0
|
||||||
self.max_no_reply_count = 20
|
self.max_no_reply_count = 10
|
||||||
self.probability_boost_per_no_reply = 0.02 # 每次不回复增加5%概率
|
self.probability_boost_per_no_reply = 0.01 # 每次不回复增加5%概率
|
||||||
|
|
||||||
# 用户关系数据
|
# 用户关系数据
|
||||||
self.user_relationships: Dict[str, float] = {} # user_id -> relationship_score
|
self.user_relationships: Dict[str, float] = {} # user_id -> relationship_score
|
||||||
|
|
||||||
async def calculate_interest_scores(self, messages: List[DatabaseMessages], bot_nickname: str) -> List[InterestScore]:
|
async def calculate_interest_scores(
|
||||||
|
self, messages: List[DatabaseMessages], bot_nickname: str
|
||||||
|
) -> List[InterestScore]:
|
||||||
"""计算消息的兴趣度评分"""
|
"""计算消息的兴趣度评分"""
|
||||||
logger.info("🚀 开始计算消息兴趣度评分...")
|
logger.info("🚀 开始计算消息兴趣度评分...")
|
||||||
logger.info(f"📨 收到 {len(messages)} 条消息")
|
logger.info(f"📨 收到 {len(messages)} 条消息")
|
||||||
@@ -87,9 +90,9 @@ class InterestScoringSystem:
|
|||||||
# 4. 计算总分
|
# 4. 计算总分
|
||||||
logger.debug("🧮 计算加权总分...")
|
logger.debug("🧮 计算加权总分...")
|
||||||
total_score = (
|
total_score = (
|
||||||
interest_match_score * self.score_weights["interest_match"] +
|
interest_match_score * self.score_weights["interest_match"]
|
||||||
relationship_score * self.score_weights["relationship"] +
|
+ relationship_score * self.score_weights["relationship"]
|
||||||
mentioned_score * self.score_weights["mentioned"]
|
+ mentioned_score * self.score_weights["mentioned"]
|
||||||
)
|
)
|
||||||
|
|
||||||
details = {
|
details = {
|
||||||
@@ -108,7 +111,7 @@ class InterestScoringSystem:
|
|||||||
interest_match_score=interest_match_score,
|
interest_match_score=interest_match_score,
|
||||||
relationship_score=relationship_score,
|
relationship_score=relationship_score,
|
||||||
mentioned_score=mentioned_score,
|
mentioned_score=mentioned_score,
|
||||||
details=details
|
details=details,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def _calculate_interest_match_score(self, content: str, keywords: List[str] = None) -> float:
|
async def _calculate_interest_match_score(self, content: str, keywords: List[str] = None) -> float:
|
||||||
@@ -150,7 +153,9 @@ class InterestScoringSystem:
|
|||||||
# 返回匹配分数,考虑置信度和匹配标签数量
|
# 返回匹配分数,考虑置信度和匹配标签数量
|
||||||
match_count_bonus = min(len(match_result.matched_tags) * 0.05, 0.3) # 每多匹配一个标签+0.05,最高+0.3
|
match_count_bonus = min(len(match_result.matched_tags) * 0.05, 0.3) # 每多匹配一个标签+0.05,最高+0.3
|
||||||
final_score = match_result.overall_score * 1.15 * match_result.confidence + match_count_bonus
|
final_score = match_result.overall_score * 1.15 * match_result.confidence + match_count_bonus
|
||||||
logger.debug(f"⚖️ 最终分数计算: 总分({match_result.overall_score:.3f}) × 1.3 × 置信度({match_result.confidence:.3f}) + 标签数量奖励({match_count_bonus:.3f}) = {final_score:.3f}")
|
logger.debug(
|
||||||
|
f"⚖️ 最终分数计算: 总分({match_result.overall_score:.3f}) × 1.3 × 置信度({match_result.confidence:.3f}) + 标签数量奖励({match_count_bonus:.3f}) = {final_score:.3f}"
|
||||||
|
)
|
||||||
return final_score
|
return final_score
|
||||||
else:
|
else:
|
||||||
logger.warning("⚠️ 智能兴趣匹配未返回结果")
|
logger.warning("⚠️ 智能兴趣匹配未返回结果")
|
||||||
@@ -171,6 +176,7 @@ class InterestScoringSystem:
|
|||||||
if message.key_words:
|
if message.key_words:
|
||||||
try:
|
try:
|
||||||
import orjson
|
import orjson
|
||||||
|
|
||||||
keywords = orjson.loads(message.key_words)
|
keywords = orjson.loads(message.key_words)
|
||||||
if not isinstance(keywords, list):
|
if not isinstance(keywords, list):
|
||||||
keywords = []
|
keywords = []
|
||||||
@@ -181,6 +187,7 @@ class InterestScoringSystem:
|
|||||||
if not keywords and message.key_words_lite:
|
if not keywords and message.key_words_lite:
|
||||||
try:
|
try:
|
||||||
import orjson
|
import orjson
|
||||||
|
|
||||||
keywords = orjson.loads(message.key_words_lite)
|
keywords = orjson.loads(message.key_words_lite)
|
||||||
if not isinstance(keywords, list):
|
if not isinstance(keywords, list):
|
||||||
keywords = []
|
keywords = []
|
||||||
@@ -198,16 +205,18 @@ class InterestScoringSystem:
|
|||||||
import re
|
import re
|
||||||
|
|
||||||
# 清理文本
|
# 清理文本
|
||||||
content = re.sub(r'[^\w\s\u4e00-\u9fff]', ' ', content) # 保留中文、英文、数字
|
content = re.sub(r"[^\w\s\u4e00-\u9fff]", " ", content) # 保留中文、英文、数字
|
||||||
words = content.split()
|
words = content.split()
|
||||||
|
|
||||||
# 过滤和关键词提取
|
# 过滤和关键词提取
|
||||||
keywords = []
|
keywords = []
|
||||||
for word in words:
|
for word in words:
|
||||||
word = word.strip()
|
word = word.strip()
|
||||||
if (len(word) >= 2 and # 至少2个字符
|
if (
|
||||||
word.isalnum() and # 字母数字
|
len(word) >= 2 # 至少2个字符
|
||||||
not word.isdigit()): # 不是纯数字
|
and word.isalnum() # 字母数字
|
||||||
|
and not word.isdigit()
|
||||||
|
): # 不是纯数字
|
||||||
keywords.append(word.lower())
|
keywords.append(word.lower())
|
||||||
|
|
||||||
# 去重并限制数量
|
# 去重并限制数量
|
||||||
@@ -215,11 +224,37 @@ class InterestScoringSystem:
|
|||||||
return unique_keywords[:10] # 返回前10个唯一关键词
|
return unique_keywords[:10] # 返回前10个唯一关键词
|
||||||
|
|
||||||
def _calculate_relationship_score(self, user_id: str) -> float:
|
def _calculate_relationship_score(self, user_id: str) -> float:
|
||||||
"""计算关系分"""
|
"""计算关系分 - 从数据库获取关系分"""
|
||||||
|
# 优先使用内存中的关系分
|
||||||
if user_id in self.user_relationships:
|
if user_id in self.user_relationships:
|
||||||
relationship_value = self.user_relationships[user_id]
|
relationship_value = self.user_relationships[user_id]
|
||||||
return min(relationship_value, 1.0)
|
return min(relationship_value, 1.0)
|
||||||
return 0.3 # 默认新用户的基础分
|
|
||||||
|
# 如果内存中没有,尝试从关系追踪器获取
|
||||||
|
if hasattr(self, "relationship_tracker") and self.relationship_tracker:
|
||||||
|
try:
|
||||||
|
relationship_score = self.relationship_tracker.get_user_relationship_score(user_id)
|
||||||
|
# 同时更新内存缓存
|
||||||
|
self.user_relationships[user_id] = relationship_score
|
||||||
|
return relationship_score
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"从关系追踪器获取关系分失败: {e}")
|
||||||
|
else:
|
||||||
|
# 尝试从全局关系追踪器获取
|
||||||
|
try:
|
||||||
|
from src.chat.affinity_flow.relationship_integration import get_relationship_tracker
|
||||||
|
|
||||||
|
global_tracker = get_relationship_tracker()
|
||||||
|
if global_tracker:
|
||||||
|
relationship_score = global_tracker.get_user_relationship_score(user_id)
|
||||||
|
# 同时更新内存缓存
|
||||||
|
self.user_relationships[user_id] = relationship_score
|
||||||
|
return relationship_score
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"从全局关系追踪器获取关系分失败: {e}")
|
||||||
|
|
||||||
|
# 默认新用户的基础分
|
||||||
|
return 0.3
|
||||||
|
|
||||||
def _calculate_mentioned_score(self, msg: DatabaseMessages, bot_nickname: str) -> float:
|
def _calculate_mentioned_score(self, msg: DatabaseMessages, bot_nickname: str) -> float:
|
||||||
"""计算提及分数"""
|
"""计算提及分数"""
|
||||||
@@ -228,9 +263,9 @@ class InterestScoringSystem:
|
|||||||
|
|
||||||
if msg.is_mentioned or (bot_nickname and bot_nickname in msg.processed_plain_text):
|
if msg.is_mentioned or (bot_nickname and bot_nickname in msg.processed_plain_text):
|
||||||
return 1.0
|
return 1.0
|
||||||
|
|
||||||
return 0.0
|
return 0.0
|
||||||
|
|
||||||
def should_reply(self, score: InterestScore) -> bool:
|
def should_reply(self, score: InterestScore) -> bool:
|
||||||
"""判断是否应该回复"""
|
"""判断是否应该回复"""
|
||||||
logger.info("🤔 评估是否应该回复...")
|
logger.info("🤔 评估是否应该回复...")
|
||||||
@@ -312,7 +347,6 @@ class InterestScoringSystem:
|
|||||||
"user_relationships": len(self.user_relationships),
|
"user_relationships": len(self.user_relationships),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def reset_stats(self):
|
def reset_stats(self):
|
||||||
"""重置统计信息"""
|
"""重置统计信息"""
|
||||||
self.no_reply_count = 0
|
self.no_reply_count = 0
|
||||||
@@ -345,9 +379,11 @@ class InterestScoringSystem:
|
|||||||
return {
|
return {
|
||||||
"use_smart_matching": self.use_smart_matching,
|
"use_smart_matching": self.use_smart_matching,
|
||||||
"smart_system_initialized": bot_interest_manager.is_initialized,
|
"smart_system_initialized": bot_interest_manager.is_initialized,
|
||||||
"smart_system_stats": bot_interest_manager.get_interest_stats() if bot_interest_manager.is_initialized else None
|
"smart_system_stats": bot_interest_manager.get_interest_stats()
|
||||||
|
if bot_interest_manager.is_initialized
|
||||||
|
else None,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
# 创建全局兴趣评分系统实例
|
# 创建全局兴趣评分系统实例
|
||||||
interest_scoring_system = InterestScoringSystem()
|
interest_scoring_system = InterestScoringSystem()
|
||||||
|
|||||||
68
src/chat/affinity_flow/relationship_integration.py
Normal file
68
src/chat/affinity_flow/relationship_integration.py
Normal file
@@ -0,0 +1,68 @@
|
|||||||
|
"""
|
||||||
|
回复后关系追踪集成初始化脚本
|
||||||
|
|
||||||
|
此脚本用于设置回复后关系追踪系统的全局变量和初始化连接
|
||||||
|
确保各组件能正确协同工作
|
||||||
|
"""
|
||||||
|
|
||||||
|
from src.chat.affinity_flow.relationship_tracker import UserRelationshipTracker
|
||||||
|
from src.chat.affinity_flow.interest_scoring import interest_scoring_system
|
||||||
|
from src.common.logger import get_logger
|
||||||
|
|
||||||
|
logger = get_logger("relationship_integration")
|
||||||
|
|
||||||
|
# 全局关系追踪器实例
|
||||||
|
relationship_tracker = None
|
||||||
|
|
||||||
|
|
||||||
|
def initialize_relationship_tracking():
|
||||||
|
"""初始化关系追踪系统"""
|
||||||
|
global relationship_tracker
|
||||||
|
|
||||||
|
try:
|
||||||
|
logger.info("🚀 初始化回复后关系追踪系统...")
|
||||||
|
|
||||||
|
# 创建关系追踪器实例
|
||||||
|
relationship_tracker = UserRelationshipTracker(interest_scoring_system=interest_scoring_system)
|
||||||
|
|
||||||
|
# 设置兴趣度评分系统的关系追踪器引用
|
||||||
|
interest_scoring_system.relationship_tracker = relationship_tracker
|
||||||
|
|
||||||
|
logger.info("✅ 回复后关系追踪系统初始化完成")
|
||||||
|
logger.info("📋 系统功能:")
|
||||||
|
logger.info(" 🔄 自动回复后关系追踪")
|
||||||
|
logger.info(" 💾 数据库持久化存储")
|
||||||
|
logger.info(" 🧠 LLM智能关系分析")
|
||||||
|
logger.info(" ⏰ 5分钟追踪间隔")
|
||||||
|
logger.info(" 🎯 兴趣度评分集成")
|
||||||
|
|
||||||
|
return relationship_tracker
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"❌ 关系追踪系统初始化失败: {e}")
|
||||||
|
logger.debug("错误详情:", exc_info=True)
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def get_relationship_tracker():
|
||||||
|
"""获取全局关系追踪器实例"""
|
||||||
|
global relationship_tracker
|
||||||
|
return relationship_tracker
|
||||||
|
|
||||||
|
|
||||||
|
def setup_plan_executor_relationship_tracker(plan_executor):
|
||||||
|
"""为PlanExecutor设置关系追踪器"""
|
||||||
|
global relationship_tracker
|
||||||
|
|
||||||
|
if relationship_tracker and plan_executor:
|
||||||
|
plan_executor.set_relationship_tracker(relationship_tracker)
|
||||||
|
logger.info("✅ PlanExecutor关系追踪器设置完成")
|
||||||
|
return True
|
||||||
|
|
||||||
|
logger.warning("⚠️ 无法设置PlanExecutor关系追踪器")
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
# 自动初始化
|
||||||
|
if __name__ == "__main__":
|
||||||
|
initialize_relationship_tracking()
|
||||||
@@ -1,13 +1,19 @@
|
|||||||
"""
|
"""
|
||||||
用户关系追踪器
|
用户关系追踪器
|
||||||
负责追踪用户交互历史,并通过LLM分析更新用户关系分
|
负责追踪用户交互历史,并通过LLM分析更新用户关系分
|
||||||
|
支持数据库持久化存储和回复后自动关系更新
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import time
|
import time
|
||||||
from typing import Dict, List, Optional
|
from typing import Dict, List, Optional
|
||||||
|
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.config.config import model_config
|
from src.config.config import model_config
|
||||||
from src.llm_models.utils_model import LLMRequest
|
from src.llm_models.utils_model import LLMRequest
|
||||||
|
from src.common.database.sqlalchemy_database_api import get_db_session
|
||||||
|
from src.common.database.sqlalchemy_models import UserRelationships, Messages
|
||||||
|
from sqlalchemy import select, desc
|
||||||
|
from src.common.data_models.database_data_model import DatabaseMessages
|
||||||
|
|
||||||
logger = get_logger("relationship_tracker")
|
logger = get_logger("relationship_tracker")
|
||||||
|
|
||||||
@@ -23,16 +29,25 @@ class UserRelationshipTracker:
|
|||||||
self.relationship_history: List[Dict] = []
|
self.relationship_history: List[Dict] = []
|
||||||
self.interest_scoring_system = interest_scoring_system
|
self.interest_scoring_system = interest_scoring_system
|
||||||
|
|
||||||
|
# 数据库访问 - 使用SQLAlchemy
|
||||||
|
pass
|
||||||
|
|
||||||
|
# 用户关系缓存 (user_id -> {"relationship_text": str, "relationship_score": float, "last_tracked": float})
|
||||||
|
self.user_relationship_cache: Dict[str, Dict] = {}
|
||||||
|
self.cache_expiry_hours = 1 # 缓存过期时间(小时)
|
||||||
|
|
||||||
# 关系更新LLM
|
# 关系更新LLM
|
||||||
try:
|
try:
|
||||||
self.relationship_llm = LLMRequest(
|
self.relationship_llm = LLMRequest(
|
||||||
model_set=model_config.model_task_config.relationship_tracker,
|
model_set=model_config.model_task_config.relationship_tracker, request_type="relationship_tracker"
|
||||||
request_type="relationship_tracker"
|
|
||||||
)
|
)
|
||||||
except AttributeError:
|
except AttributeError:
|
||||||
# 如果relationship_tracker配置不存在,尝试其他可用的模型配置
|
# 如果relationship_tracker配置不存在,尝试其他可用的模型配置
|
||||||
available_models = [attr for attr in dir(model_config.model_task_config)
|
available_models = [
|
||||||
if not attr.startswith('_') and attr != 'model_dump']
|
attr
|
||||||
|
for attr in dir(model_config.model_task_config)
|
||||||
|
if not attr.startswith("_") and attr != "model_dump"
|
||||||
|
]
|
||||||
|
|
||||||
if available_models:
|
if available_models:
|
||||||
# 使用第一个可用的模型配置
|
# 使用第一个可用的模型配置
|
||||||
@@ -40,14 +55,14 @@ class UserRelationshipTracker:
|
|||||||
logger.warning(f"relationship_tracker model configuration not found, using fallback: {fallback_model}")
|
logger.warning(f"relationship_tracker model configuration not found, using fallback: {fallback_model}")
|
||||||
self.relationship_llm = LLMRequest(
|
self.relationship_llm = LLMRequest(
|
||||||
model_set=getattr(model_config.model_task_config, fallback_model),
|
model_set=getattr(model_config.model_task_config, fallback_model),
|
||||||
request_type="relationship_tracker"
|
request_type="relationship_tracker",
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# 如果没有任何模型配置,创建一个简单的LLMRequest
|
# 如果没有任何模型配置,创建一个简单的LLMRequest
|
||||||
logger.warning("No model configurations found, creating basic LLMRequest")
|
logger.warning("No model configurations found, creating basic LLMRequest")
|
||||||
self.relationship_llm = LLMRequest(
|
self.relationship_llm = LLMRequest(
|
||||||
model_set="gpt-3.5-turbo", # 默认模型
|
model_set="gpt-3.5-turbo", # 默认模型
|
||||||
request_type="relationship_tracker"
|
request_type="relationship_tracker",
|
||||||
)
|
)
|
||||||
|
|
||||||
def set_interest_scoring_system(self, interest_scoring_system):
|
def set_interest_scoring_system(self, interest_scoring_system):
|
||||||
@@ -58,8 +73,9 @@ class UserRelationshipTracker:
|
|||||||
"""添加用户交互记录"""
|
"""添加用户交互记录"""
|
||||||
if len(self.tracking_users) >= self.max_tracking_users:
|
if len(self.tracking_users) >= self.max_tracking_users:
|
||||||
# 移除最旧的记录
|
# 移除最旧的记录
|
||||||
oldest_user = min(self.tracking_users.keys(),
|
oldest_user = min(
|
||||||
key=lambda k: self.tracking_users[k].get("reply_timestamp", 0))
|
self.tracking_users.keys(), key=lambda k: self.tracking_users[k].get("reply_timestamp", 0)
|
||||||
|
)
|
||||||
del self.tracking_users[oldest_user]
|
del self.tracking_users[oldest_user]
|
||||||
|
|
||||||
# 获取当前关系分
|
# 获取当前关系分
|
||||||
@@ -73,7 +89,7 @@ class UserRelationshipTracker:
|
|||||||
"user_message": user_message,
|
"user_message": user_message,
|
||||||
"bot_reply": bot_reply,
|
"bot_reply": bot_reply,
|
||||||
"reply_timestamp": reply_timestamp,
|
"reply_timestamp": reply_timestamp,
|
||||||
"current_relationship_score": current_relationship_score
|
"current_relationship_score": current_relationship_score,
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.debug(f"添加用户交互追踪: {user_id}")
|
logger.debug(f"添加用户交互追踪: {user_id}")
|
||||||
@@ -101,11 +117,11 @@ class UserRelationshipTracker:
|
|||||||
prompt = f"""
|
prompt = f"""
|
||||||
分析以下用户交互,更新用户关系:
|
分析以下用户交互,更新用户关系:
|
||||||
|
|
||||||
用户ID: {interaction['user_id']}
|
用户ID: {interaction["user_id"]}
|
||||||
用户名: {interaction['user_name']}
|
用户名: {interaction["user_name"]}
|
||||||
用户消息: {interaction['user_message']}
|
用户消息: {interaction["user_message"]}
|
||||||
Bot回复: {interaction['bot_reply']}
|
Bot回复: {interaction["bot_reply"]}
|
||||||
当前关系分: {interaction['current_relationship_score']}
|
当前关系分: {interaction["current_relationship_score"]}
|
||||||
|
|
||||||
请以JSON格式返回更新结果:
|
请以JSON格式返回更新结果:
|
||||||
{{
|
{{
|
||||||
@@ -118,21 +134,30 @@ Bot回复: {interaction['bot_reply']}
|
|||||||
llm_response, _ = await self.relationship_llm.generate_response_async(prompt=prompt)
|
llm_response, _ = await self.relationship_llm.generate_response_async(prompt=prompt)
|
||||||
if llm_response:
|
if llm_response:
|
||||||
import json
|
import json
|
||||||
response_data = json.loads(llm_response)
|
|
||||||
new_score = max(0.0, min(1.0, float(response_data.get("new_relationship_score", 0.3))))
|
|
||||||
|
|
||||||
if self.interest_scoring_system:
|
try:
|
||||||
self.interest_scoring_system.update_user_relationship(
|
# 清理LLM响应,移除可能的格式标记
|
||||||
interaction['user_id'],
|
cleaned_response = self._clean_llm_json_response(llm_response)
|
||||||
new_score - interaction['current_relationship_score']
|
response_data = json.loads(cleaned_response)
|
||||||
)
|
new_score = max(0.0, min(1.0, float(response_data.get("new_relationship_score", 0.3))))
|
||||||
|
|
||||||
return {
|
if self.interest_scoring_system:
|
||||||
"user_id": interaction['user_id'],
|
self.interest_scoring_system.update_user_relationship(
|
||||||
"new_relationship_score": new_score,
|
interaction["user_id"], new_score - interaction["current_relationship_score"]
|
||||||
"reasoning": response_data.get("reasoning", ""),
|
)
|
||||||
"interaction_summary": response_data.get("interaction_summary", "")
|
|
||||||
}
|
return {
|
||||||
|
"user_id": interaction["user_id"],
|
||||||
|
"new_relationship_score": new_score,
|
||||||
|
"reasoning": response_data.get("reasoning", ""),
|
||||||
|
"interaction_summary": response_data.get("interaction_summary", ""),
|
||||||
|
}
|
||||||
|
|
||||||
|
except json.JSONDecodeError as e:
|
||||||
|
logger.error(f"LLM响应JSON解析失败: {e}")
|
||||||
|
logger.debug(f"LLM原始响应: {llm_response}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"处理关系更新数据失败: {e}")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"更新用户关系时出错: {e}")
|
logger.error(f"更新用户关系时出错: {e}")
|
||||||
@@ -164,10 +189,7 @@ Bot回复: {interaction['bot_reply']}
|
|||||||
|
|
||||||
def add_to_history(self, relationship_update: Dict):
|
def add_to_history(self, relationship_update: Dict):
|
||||||
"""添加到关系历史"""
|
"""添加到关系历史"""
|
||||||
self.relationship_history.append({
|
self.relationship_history.append({**relationship_update, "update_time": time.time()})
|
||||||
**relationship_update,
|
|
||||||
"update_time": time.time()
|
|
||||||
})
|
|
||||||
|
|
||||||
# 限制历史记录数量
|
# 限制历史记录数量
|
||||||
if len(self.relationship_history) > 100:
|
if len(self.relationship_history) > 100:
|
||||||
@@ -198,16 +220,13 @@ Bot回复: {interaction['bot_reply']}
|
|||||||
if user_id in self.tracking_users:
|
if user_id in self.tracking_users:
|
||||||
current_score = self.tracking_users[user_id]["current_relationship_score"]
|
current_score = self.tracking_users[user_id]["current_relationship_score"]
|
||||||
if self.interest_scoring_system:
|
if self.interest_scoring_system:
|
||||||
self.interest_scoring_system.update_user_relationship(
|
self.interest_scoring_system.update_user_relationship(user_id, new_score - current_score)
|
||||||
user_id,
|
|
||||||
new_score - current_score
|
|
||||||
)
|
|
||||||
|
|
||||||
update_info = {
|
update_info = {
|
||||||
"user_id": user_id,
|
"user_id": user_id,
|
||||||
"new_relationship_score": new_score,
|
"new_relationship_score": new_score,
|
||||||
"reasoning": reasoning or "手动更新",
|
"reasoning": reasoning or "手动更新",
|
||||||
"interaction_summary": "手动更新关系分"
|
"interaction_summary": "手动更新关系分",
|
||||||
}
|
}
|
||||||
self.add_to_history(update_info)
|
self.add_to_history(update_info)
|
||||||
logger.info(f"强制更新用户关系: {user_id} -> {new_score:.2f}")
|
logger.info(f"强制更新用户关系: {user_id} -> {new_score:.2f}")
|
||||||
@@ -224,5 +243,371 @@ Bot回复: {interaction['bot_reply']}
|
|||||||
"current_relationship_score": interaction["current_relationship_score"],
|
"current_relationship_score": interaction["current_relationship_score"],
|
||||||
"interaction_count": 1, # 简化版本,每次追踪只记录一次交互
|
"interaction_count": 1, # 简化版本,每次追踪只记录一次交互
|
||||||
"last_interaction": interaction["reply_timestamp"],
|
"last_interaction": interaction["reply_timestamp"],
|
||||||
"recent_message": interaction["user_message"][:100] + "..." if len(interaction["user_message"]) > 100 else interaction["user_message"]
|
"recent_message": interaction["user_message"][:100] + "..."
|
||||||
}
|
if len(interaction["user_message"]) > 100
|
||||||
|
else interaction["user_message"],
|
||||||
|
}
|
||||||
|
|
||||||
|
# ===== 数据库支持方法 =====
|
||||||
|
|
||||||
|
def get_user_relationship_score(self, user_id: str) -> float:
|
||||||
|
"""获取用户关系分"""
|
||||||
|
# 先检查缓存
|
||||||
|
if user_id in self.user_relationship_cache:
|
||||||
|
cache_data = self.user_relationship_cache[user_id]
|
||||||
|
# 检查缓存是否过期
|
||||||
|
cache_time = cache_data.get("last_tracked", 0)
|
||||||
|
if time.time() - cache_time < self.cache_expiry_hours * 3600:
|
||||||
|
return cache_data.get("relationship_score", 0.3)
|
||||||
|
|
||||||
|
# 缓存过期或不存在,从数据库获取
|
||||||
|
relationship_data = self._get_user_relationship_from_db(user_id)
|
||||||
|
if relationship_data:
|
||||||
|
# 更新缓存
|
||||||
|
self.user_relationship_cache[user_id] = {
|
||||||
|
"relationship_text": relationship_data.get("relationship_text", ""),
|
||||||
|
"relationship_score": relationship_data.get("relationship_score", 0.3),
|
||||||
|
"last_tracked": time.time(),
|
||||||
|
}
|
||||||
|
return relationship_data.get("relationship_score", 0.3)
|
||||||
|
|
||||||
|
# 数据库中也没有,返回默认值
|
||||||
|
return 0.3
|
||||||
|
|
||||||
|
def _get_user_relationship_from_db(self, user_id: str) -> Optional[Dict]:
|
||||||
|
"""从数据库获取用户关系数据"""
|
||||||
|
try:
|
||||||
|
with get_db_session() as session:
|
||||||
|
# 查询用户关系表
|
||||||
|
stmt = select(UserRelationships).where(UserRelationships.user_id == user_id)
|
||||||
|
result = session.execute(stmt).scalar_one_or_none()
|
||||||
|
|
||||||
|
if result:
|
||||||
|
return {
|
||||||
|
"relationship_text": result.relationship_text or "",
|
||||||
|
"relationship_score": float(result.relationship_score)
|
||||||
|
if result.relationship_score is not None
|
||||||
|
else 0.3,
|
||||||
|
"last_updated": result.last_updated,
|
||||||
|
}
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"从数据库获取用户关系失败: {e}")
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _update_user_relationship_in_db(self, user_id: str, relationship_text: str, relationship_score: float):
|
||||||
|
"""更新数据库中的用户关系"""
|
||||||
|
try:
|
||||||
|
current_time = time.time()
|
||||||
|
|
||||||
|
with get_db_session() as session:
|
||||||
|
# 检查是否已存在关系记录
|
||||||
|
existing = session.execute(
|
||||||
|
select(UserRelationships).where(UserRelationships.user_id == user_id)
|
||||||
|
).scalar_one_or_none()
|
||||||
|
|
||||||
|
if existing:
|
||||||
|
# 更新现有记录
|
||||||
|
existing.relationship_text = relationship_text
|
||||||
|
existing.relationship_score = relationship_score
|
||||||
|
existing.last_updated = current_time
|
||||||
|
existing.user_name = existing.user_name or user_id # 更新用户名如果为空
|
||||||
|
else:
|
||||||
|
# 插入新记录
|
||||||
|
new_relationship = UserRelationships(
|
||||||
|
user_id=user_id,
|
||||||
|
user_name=user_id,
|
||||||
|
relationship_text=relationship_text,
|
||||||
|
relationship_score=relationship_score,
|
||||||
|
last_updated=current_time,
|
||||||
|
)
|
||||||
|
session.add(new_relationship)
|
||||||
|
|
||||||
|
session.commit()
|
||||||
|
logger.info(f"已更新数据库中用户关系: {user_id} -> 分数: {relationship_score:.3f}")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"更新数据库用户关系失败: {e}")
|
||||||
|
|
||||||
|
# ===== 回复后关系追踪方法 =====
|
||||||
|
|
||||||
|
async def track_reply_relationship(
|
||||||
|
self, user_id: str, user_name: str, bot_reply_content: str, reply_timestamp: float
|
||||||
|
):
|
||||||
|
"""回复后关系追踪 - 主要入口点"""
|
||||||
|
try:
|
||||||
|
logger.info(f"🔄 开始回复后关系追踪: {user_id}")
|
||||||
|
|
||||||
|
# 检查上次追踪时间
|
||||||
|
last_tracked_time = self._get_last_tracked_time(user_id)
|
||||||
|
time_diff = reply_timestamp - last_tracked_time
|
||||||
|
|
||||||
|
if time_diff < 5 * 60: # 5分钟内不重复追踪
|
||||||
|
logger.debug(f"用户 {user_id} 距离上次追踪时间不足5分钟,跳过")
|
||||||
|
return
|
||||||
|
|
||||||
|
# 获取上次bot回复该用户的消息
|
||||||
|
last_bot_reply = await self._get_last_bot_reply_to_user(user_id)
|
||||||
|
if not last_bot_reply:
|
||||||
|
logger.debug(f"未找到上次回复用户 {user_id} 的记录")
|
||||||
|
return
|
||||||
|
|
||||||
|
# 获取用户后续的反应消息
|
||||||
|
user_reactions = await self._get_user_reactions_after_reply(user_id, last_bot_reply.time)
|
||||||
|
|
||||||
|
# 获取当前关系数据
|
||||||
|
current_relationship = self._get_user_relationship_from_db(user_id)
|
||||||
|
current_score = current_relationship.get("relationship_score", 0.3) if current_relationship else 0.3
|
||||||
|
current_text = current_relationship.get("relationship_text", "新用户") if current_relationship else "新用户"
|
||||||
|
|
||||||
|
# 使用LLM分析并更新关系
|
||||||
|
await self._analyze_and_update_relationship(
|
||||||
|
user_id, user_name, last_bot_reply, user_reactions, current_text, current_score, bot_reply_content
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"回复后关系追踪失败: {e}")
|
||||||
|
logger.debug("错误详情:", exc_info=True)
|
||||||
|
|
||||||
|
def _get_last_tracked_time(self, user_id: str) -> float:
|
||||||
|
"""获取上次追踪时间"""
|
||||||
|
# 先检查缓存
|
||||||
|
if user_id in self.user_relationship_cache:
|
||||||
|
return self.user_relationship_cache[user_id].get("last_tracked", 0)
|
||||||
|
|
||||||
|
# 从数据库获取
|
||||||
|
relationship_data = self._get_user_relationship_from_db(user_id)
|
||||||
|
if relationship_data:
|
||||||
|
return relationship_data.get("last_updated", 0)
|
||||||
|
|
||||||
|
return 0
|
||||||
|
|
||||||
|
async def _get_last_bot_reply_to_user(self, user_id: str) -> Optional[DatabaseMessages]:
|
||||||
|
"""获取上次bot回复该用户的消息"""
|
||||||
|
try:
|
||||||
|
with get_db_session() as session:
|
||||||
|
# 查询bot回复给该用户的最新消息
|
||||||
|
stmt = (
|
||||||
|
select(Messages)
|
||||||
|
.where(Messages.user_id == user_id)
|
||||||
|
.where(Messages.reply_to.isnot(None))
|
||||||
|
.order_by(desc(Messages.time))
|
||||||
|
.limit(1)
|
||||||
|
)
|
||||||
|
|
||||||
|
result = session.execute(stmt).scalar_one_or_none()
|
||||||
|
if result:
|
||||||
|
# 将SQLAlchemy模型转换为DatabaseMessages对象
|
||||||
|
return self._sqlalchemy_to_database_messages(result)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"获取上次回复消息失败: {e}")
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def _get_user_reactions_after_reply(self, user_id: str, reply_time: float) -> List[DatabaseMessages]:
|
||||||
|
"""获取用户在bot回复后的反应消息"""
|
||||||
|
try:
|
||||||
|
with get_db_session() as session:
|
||||||
|
# 查询用户在回复时间之后的5分钟内的消息
|
||||||
|
end_time = reply_time + 5 * 60 # 5分钟
|
||||||
|
|
||||||
|
stmt = (
|
||||||
|
select(Messages)
|
||||||
|
.where(Messages.user_id == user_id)
|
||||||
|
.where(Messages.time > reply_time)
|
||||||
|
.where(Messages.time <= end_time)
|
||||||
|
.order_by(Messages.time)
|
||||||
|
)
|
||||||
|
|
||||||
|
results = session.execute(stmt).scalars().all()
|
||||||
|
if results:
|
||||||
|
return [self._sqlalchemy_to_database_messages(result) for result in results]
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"获取用户反应消息失败: {e}")
|
||||||
|
|
||||||
|
return []
|
||||||
|
|
||||||
|
def _sqlalchemy_to_database_messages(self, sqlalchemy_message) -> DatabaseMessages:
|
||||||
|
"""将SQLAlchemy消息模型转换为DatabaseMessages对象"""
|
||||||
|
try:
|
||||||
|
return DatabaseMessages(
|
||||||
|
message_id=sqlalchemy_message.message_id or "",
|
||||||
|
time=float(sqlalchemy_message.time) if sqlalchemy_message.time is not None else 0.0,
|
||||||
|
chat_id=sqlalchemy_message.chat_id or "",
|
||||||
|
reply_to=sqlalchemy_message.reply_to,
|
||||||
|
processed_plain_text=sqlalchemy_message.processed_plain_text or "",
|
||||||
|
user_id=sqlalchemy_message.user_id or "",
|
||||||
|
user_nickname=sqlalchemy_message.user_nickname or "",
|
||||||
|
user_platform=sqlalchemy_message.user_platform or "",
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"SQLAlchemy消息转换失败: {e}")
|
||||||
|
# 返回一个基本的消息对象
|
||||||
|
return DatabaseMessages(
|
||||||
|
message_id="",
|
||||||
|
time=0.0,
|
||||||
|
chat_id="",
|
||||||
|
processed_plain_text="",
|
||||||
|
user_id="",
|
||||||
|
user_nickname="",
|
||||||
|
user_platform="",
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _analyze_and_update_relationship(
|
||||||
|
self,
|
||||||
|
user_id: str,
|
||||||
|
user_name: str,
|
||||||
|
last_bot_reply: DatabaseMessages,
|
||||||
|
user_reactions: List[DatabaseMessages],
|
||||||
|
current_text: str,
|
||||||
|
current_score: float,
|
||||||
|
current_reply: str,
|
||||||
|
):
|
||||||
|
"""使用LLM分析并更新用户关系"""
|
||||||
|
try:
|
||||||
|
# 构建分析提示
|
||||||
|
user_reactions_text = "\n".join([f"- {msg.processed_plain_text}" for msg in user_reactions])
|
||||||
|
|
||||||
|
prompt = f"""
|
||||||
|
分析以下用户交互,更新用户关系印象和分数:
|
||||||
|
|
||||||
|
用户信息:
|
||||||
|
- 用户ID: {user_id}
|
||||||
|
- 用户名: {user_name}
|
||||||
|
|
||||||
|
上次Bot回复: {last_bot_reply.processed_plain_text}
|
||||||
|
|
||||||
|
用户反应消息:
|
||||||
|
{user_reactions_text}
|
||||||
|
|
||||||
|
当前Bot回复: {current_reply}
|
||||||
|
|
||||||
|
当前关系印象: {current_text}
|
||||||
|
当前关系分数: {current_score:.3f}
|
||||||
|
|
||||||
|
请根据用户的反应和对话内容,分析用户性格特点、与Bot的互动模式,然后更新关系印象和分数。
|
||||||
|
|
||||||
|
分析要点:
|
||||||
|
1. 用户的情绪态度(积极/消极/中性)
|
||||||
|
2. 用户对Bot的兴趣程度
|
||||||
|
3. 用户的交流风格(主动/被动/友好/正式等)
|
||||||
|
4. 互动的质量和深度
|
||||||
|
|
||||||
|
请以JSON格式返回更新结果:
|
||||||
|
{{
|
||||||
|
"relationship_text": "更新的关系印象描述(50字以内)",
|
||||||
|
"relationship_score": 0.0~1.0的新分数,
|
||||||
|
"analysis_reasoning": "分析理由说明",
|
||||||
|
"interaction_quality": "high/medium/low"
|
||||||
|
}}
|
||||||
|
"""
|
||||||
|
|
||||||
|
# 调用LLM进行分析
|
||||||
|
llm_response, _ = await self.relationship_llm.generate_response_async(prompt=prompt)
|
||||||
|
|
||||||
|
if llm_response:
|
||||||
|
import json
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 清理LLM响应,移除可能的格式标记
|
||||||
|
cleaned_response = self._clean_llm_json_response(llm_response)
|
||||||
|
response_data = json.loads(cleaned_response)
|
||||||
|
|
||||||
|
new_text = response_data.get("relationship_text", current_text)
|
||||||
|
new_score = max(0.0, min(1.0, float(response_data.get("relationship_score", current_score))))
|
||||||
|
reasoning = response_data.get("analysis_reasoning", "")
|
||||||
|
quality = response_data.get("interaction_quality", "medium")
|
||||||
|
|
||||||
|
# 更新数据库
|
||||||
|
self._update_user_relationship_in_db(user_id, new_text, new_score)
|
||||||
|
|
||||||
|
# 更新缓存
|
||||||
|
self.user_relationship_cache[user_id] = {
|
||||||
|
"relationship_text": new_text,
|
||||||
|
"relationship_score": new_score,
|
||||||
|
"last_tracked": time.time(),
|
||||||
|
}
|
||||||
|
|
||||||
|
# 如果有兴趣度评分系统,也更新内存中的关系分
|
||||||
|
if self.interest_scoring_system:
|
||||||
|
self.interest_scoring_system.update_user_relationship(user_id, new_score - current_score)
|
||||||
|
|
||||||
|
# 记录分析历史
|
||||||
|
analysis_record = {
|
||||||
|
"user_id": user_id,
|
||||||
|
"timestamp": time.time(),
|
||||||
|
"old_score": current_score,
|
||||||
|
"new_score": new_score,
|
||||||
|
"old_text": current_text,
|
||||||
|
"new_text": new_text,
|
||||||
|
"reasoning": reasoning,
|
||||||
|
"quality": quality,
|
||||||
|
"user_reactions_count": len(user_reactions),
|
||||||
|
}
|
||||||
|
self.relationship_history.append(analysis_record)
|
||||||
|
|
||||||
|
# 限制历史记录数量
|
||||||
|
if len(self.relationship_history) > 100:
|
||||||
|
self.relationship_history = self.relationship_history[-100:]
|
||||||
|
|
||||||
|
logger.info(f"✅ 关系分析完成: {user_id}")
|
||||||
|
logger.info(f" 📝 印象: '{current_text}' -> '{new_text}'")
|
||||||
|
logger.info(f" 💝 分数: {current_score:.3f} -> {new_score:.3f}")
|
||||||
|
logger.info(f" 🎯 质量: {quality}")
|
||||||
|
|
||||||
|
except json.JSONDecodeError as e:
|
||||||
|
logger.error(f"LLM响应JSON解析失败: {e}")
|
||||||
|
logger.debug(f"LLM原始响应: {llm_response}")
|
||||||
|
else:
|
||||||
|
logger.warning("LLM未返回有效响应")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"关系分析失败: {e}")
|
||||||
|
logger.debug("错误详情:", exc_info=True)
|
||||||
|
|
||||||
|
def _clean_llm_json_response(self, response: str) -> str:
|
||||||
|
"""
|
||||||
|
清理LLM响应,移除可能的JSON格式标记
|
||||||
|
|
||||||
|
Args:
|
||||||
|
response: LLM原始响应
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
清理后的JSON字符串
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
import re
|
||||||
|
|
||||||
|
# 移除常见的JSON格式标记
|
||||||
|
cleaned = response.strip()
|
||||||
|
|
||||||
|
# 移除 ```json 或 ``` 等标记
|
||||||
|
cleaned = re.sub(r"^```(?:json)?\s*", "", cleaned, flags=re.MULTILINE | re.IGNORECASE)
|
||||||
|
cleaned = re.sub(r"\s*```$", "", cleaned, flags=re.MULTILINE)
|
||||||
|
|
||||||
|
# 移除可能的Markdown代码块标记
|
||||||
|
cleaned = re.sub(r"^`|`$", "", cleaned, flags=re.MULTILINE)
|
||||||
|
|
||||||
|
# 尝试找到JSON对象的开始和结束
|
||||||
|
json_start = cleaned.find("{")
|
||||||
|
json_end = cleaned.rfind("}")
|
||||||
|
|
||||||
|
if json_start != -1 and json_end != -1 and json_end > json_start:
|
||||||
|
# 提取JSON部分
|
||||||
|
cleaned = cleaned[json_start : json_end + 1]
|
||||||
|
|
||||||
|
# 移除多余的空白字符
|
||||||
|
cleaned = cleaned.strip()
|
||||||
|
|
||||||
|
logger.debug(f"LLM响应清理: 原始长度={len(response)}, 清理后长度={len(cleaned)}")
|
||||||
|
if cleaned != response:
|
||||||
|
logger.debug(f"清理前: {response[:200]}...")
|
||||||
|
logger.debug(f"清理后: {cleaned[:200]}...")
|
||||||
|
|
||||||
|
return cleaned
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"清理LLM响应失败: {e}")
|
||||||
|
return response # 清理失败时返回原始响应
|
||||||
|
|||||||
@@ -2,6 +2,7 @@
|
|||||||
PlanExecutor: 接收 Plan 对象并执行其中的所有动作。
|
PlanExecutor: 接收 Plan 对象并执行其中的所有动作。
|
||||||
集成用户关系追踪机制,自动记录交互并更新关系。
|
集成用户关系追踪机制,自动记录交互并更新关系。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import time
|
import time
|
||||||
from typing import Dict, List
|
from typing import Dict, List
|
||||||
@@ -94,7 +95,9 @@ class PlanExecutor:
|
|||||||
self.execution_stats["successful_executions"] += successful_count
|
self.execution_stats["successful_executions"] += successful_count
|
||||||
self.execution_stats["failed_executions"] += len(execution_results) - successful_count
|
self.execution_stats["failed_executions"] += len(execution_results) - successful_count
|
||||||
|
|
||||||
logger.info(f"动作执行完成: 总数={len(plan.decided_actions)}, 成功={successful_count}, 失败={len(execution_results) - successful_count}")
|
logger.info(
|
||||||
|
f"动作执行完成: 总数={len(plan.decided_actions)}, 成功={successful_count}, 失败={len(execution_results) - successful_count}"
|
||||||
|
)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"executed_count": len(plan.decided_actions),
|
"executed_count": len(plan.decided_actions),
|
||||||
@@ -123,7 +126,7 @@ class PlanExecutor:
|
|||||||
try:
|
try:
|
||||||
logger.info(f"执行回复动作: {action_info.action_type}, 原因: {action_info.reasoning}")
|
logger.info(f"执行回复动作: {action_info.action_type}, 原因: {action_info.reasoning}")
|
||||||
|
|
||||||
if action_info.action_message.get("user_id","") == str(global_config.bot.qq_account):
|
if action_info.action_message.get("user_id", "") == str(global_config.bot.qq_account):
|
||||||
logger.warning("尝试回复自己,跳过此动作以防止死循环。")
|
logger.warning("尝试回复自己,跳过此动作以防止死循环。")
|
||||||
return {
|
return {
|
||||||
"action_type": action_info.action_type,
|
"action_type": action_info.action_type,
|
||||||
@@ -143,8 +146,7 @@ class PlanExecutor:
|
|||||||
|
|
||||||
# 通过动作管理器执行回复
|
# 通过动作管理器执行回复
|
||||||
reply_content = await self.action_manager.execute_action(
|
reply_content = await self.action_manager.execute_action(
|
||||||
action_name=action_info.action_type,
|
action_name=action_info.action_type, **action_params
|
||||||
**action_params
|
|
||||||
)
|
)
|
||||||
|
|
||||||
success = True
|
success = True
|
||||||
@@ -185,13 +187,15 @@ class PlanExecutor:
|
|||||||
for i, result in enumerate(executed_results):
|
for i, result in enumerate(executed_results):
|
||||||
if isinstance(result, Exception):
|
if isinstance(result, Exception):
|
||||||
logger.error(f"执行动作 {other_actions[i].action_type} 时发生异常: {result}")
|
logger.error(f"执行动作 {other_actions[i].action_type} 时发生异常: {result}")
|
||||||
results.append({
|
results.append(
|
||||||
"action_type": other_actions[i].action_type,
|
{
|
||||||
"success": False,
|
"action_type": other_actions[i].action_type,
|
||||||
"error_message": str(result),
|
"success": False,
|
||||||
"execution_time": 0,
|
"error_message": str(result),
|
||||||
"reasoning": other_actions[i].reasoning,
|
"execution_time": 0,
|
||||||
})
|
"reasoning": other_actions[i].reasoning,
|
||||||
|
}
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
results.append(result)
|
results.append(result)
|
||||||
|
|
||||||
@@ -215,10 +219,7 @@ class PlanExecutor:
|
|||||||
}
|
}
|
||||||
|
|
||||||
# 通过动作管理器执行动作
|
# 通过动作管理器执行动作
|
||||||
await self.action_manager.execute_action(
|
await self.action_manager.execute_action(action_name=action_info.action_type, **action_params)
|
||||||
action_name=action_info.action_type,
|
|
||||||
**action_params
|
|
||||||
)
|
|
||||||
|
|
||||||
success = True
|
success = True
|
||||||
logger.info(f"其他动作执行成功: {action_info.action_type}")
|
logger.info(f"其他动作执行成功: {action_info.action_type}")
|
||||||
@@ -239,30 +240,49 @@ class PlanExecutor:
|
|||||||
}
|
}
|
||||||
|
|
||||||
async def _track_user_interaction(self, action_info: ActionPlannerInfo, plan: Plan, reply_content: str):
|
async def _track_user_interaction(self, action_info: ActionPlannerInfo, plan: Plan, reply_content: str):
|
||||||
"""追踪用户交互"""
|
"""追踪用户交互 - 集成回复后关系追踪"""
|
||||||
try:
|
try:
|
||||||
if not action_info.action_message:
|
if not action_info.action_message:
|
||||||
return
|
return
|
||||||
|
|
||||||
# 获取用户信息
|
# 获取用户信息 - 处理对象和字典两种情况
|
||||||
user_id = action_info.action_message.user_id
|
if hasattr(action_info.action_message, "user_id"):
|
||||||
user_name = action_info.action_message.user_nickname or user_id
|
# 对象情况
|
||||||
user_message = action_info.action_message.content
|
user_id = action_info.action_message.user_id
|
||||||
|
user_name = getattr(action_info.action_message, "user_nickname", user_id) or user_id
|
||||||
|
user_message = getattr(action_info.action_message, "content", "")
|
||||||
|
else:
|
||||||
|
# 字典情况
|
||||||
|
user_id = action_info.action_message.get("user_id", "")
|
||||||
|
user_name = action_info.action_message.get("user_nickname", user_id) or user_id
|
||||||
|
user_message = action_info.action_message.get("content", "")
|
||||||
|
|
||||||
# 如果有设置关系追踪器,添加交互记录
|
if not user_id:
|
||||||
|
logger.debug("跳过追踪:缺少用户ID")
|
||||||
|
return
|
||||||
|
|
||||||
|
# 如果有设置关系追踪器,执行回复后关系追踪
|
||||||
if self.relationship_tracker:
|
if self.relationship_tracker:
|
||||||
|
# 记录基础交互信息(保持向后兼容)
|
||||||
self.relationship_tracker.add_interaction(
|
self.relationship_tracker.add_interaction(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
user_name=user_name,
|
user_name=user_name,
|
||||||
user_message=user_message,
|
user_message=user_message,
|
||||||
bot_reply=reply_content,
|
bot_reply=reply_content,
|
||||||
reply_timestamp=time.time()
|
reply_timestamp=time.time(),
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.debug(f"已添加用户交互追踪: {user_id}")
|
# 执行新的回复后关系追踪
|
||||||
|
await self.relationship_tracker.track_reply_relationship(
|
||||||
|
user_id=user_id, user_name=user_name, bot_reply_content=reply_content, reply_timestamp=time.time()
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.debug(f"已执行用户交互追踪: {user_id}")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"追踪用户交互时出错: {e}")
|
logger.error(f"追踪用户交互时出错: {e}")
|
||||||
|
logger.debug(f"action_message类型: {type(action_info.action_message)}")
|
||||||
|
logger.debug(f"action_message内容: {action_info.action_message}")
|
||||||
|
|
||||||
def get_execution_stats(self) -> Dict[str, any]:
|
def get_execution_stats(self) -> Dict[str, any]:
|
||||||
"""获取执行统计信息"""
|
"""获取执行统计信息"""
|
||||||
@@ -308,4 +328,4 @@ class PlanExecutor:
|
|||||||
"timestamp": time.time() - (len(recent_times) - i) * 60, # 估算时间戳
|
"timestamp": time.time() - (len(recent_times) - i) * 60, # 估算时间戳
|
||||||
}
|
}
|
||||||
for i, time_val in enumerate(recent_times)
|
for i, time_val in enumerate(recent_times)
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -2,6 +2,7 @@
|
|||||||
主规划器入口,负责协调 PlanGenerator, PlanFilter, 和 PlanExecutor。
|
主规划器入口,负责协调 PlanGenerator, PlanFilter, 和 PlanExecutor。
|
||||||
集成兴趣度评分系统和用户关系追踪机制,实现智能化的聊天决策。
|
集成兴趣度评分系统和用户关系追踪机制,实现智能化的聊天决策。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from dataclasses import asdict
|
from dataclasses import asdict
|
||||||
from typing import Dict, List, Optional, Tuple
|
from typing import Dict, List, Optional, Tuple
|
||||||
|
|
||||||
@@ -17,10 +18,11 @@ from src.config.config import global_config
|
|||||||
from src.plugin_system.base.component_types import ChatMode
|
from src.plugin_system.base.component_types import ChatMode
|
||||||
|
|
||||||
# 导入提示词模块以确保其被初始化
|
# 导入提示词模块以确保其被初始化
|
||||||
from src.chat.planner_actions import planner_prompts #noqa
|
from src.chat.planner_actions import planner_prompts # noqa
|
||||||
|
|
||||||
logger = get_logger("planner")
|
logger = get_logger("planner")
|
||||||
|
|
||||||
|
|
||||||
class ActionPlanner:
|
class ActionPlanner:
|
||||||
"""
|
"""
|
||||||
增强版ActionPlanner,集成兴趣度评分和用户关系追踪机制。
|
增强版ActionPlanner,集成兴趣度评分和用户关系追踪机制。
|
||||||
@@ -49,8 +51,25 @@ class ActionPlanner:
|
|||||||
# 初始化兴趣度评分系统
|
# 初始化兴趣度评分系统
|
||||||
self.interest_scoring = InterestScoringSystem()
|
self.interest_scoring = InterestScoringSystem()
|
||||||
|
|
||||||
# 初始化用户关系追踪器
|
# 尝试获取全局关系追踪器,如果没有则创建新的
|
||||||
self.relationship_tracker = UserRelationshipTracker(self.interest_scoring)
|
try:
|
||||||
|
from src.chat.affinity_flow.relationship_integration import get_relationship_tracker
|
||||||
|
|
||||||
|
global_relationship_tracker = get_relationship_tracker()
|
||||||
|
if global_relationship_tracker:
|
||||||
|
# 使用全局关系追踪器
|
||||||
|
self.relationship_tracker = global_relationship_tracker
|
||||||
|
# 设置兴趣度评分系统的关系追踪器引用
|
||||||
|
self.interest_scoring.relationship_tracker = self.relationship_tracker
|
||||||
|
logger.info("使用全局关系追踪器")
|
||||||
|
else:
|
||||||
|
# 创建新的关系追踪器
|
||||||
|
self.relationship_tracker = UserRelationshipTracker(self.interest_scoring)
|
||||||
|
logger.info("创建新的关系追踪器实例")
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"获取全局关系追踪器失败: {e}")
|
||||||
|
# 创建新的关系追踪器
|
||||||
|
self.relationship_tracker = UserRelationshipTracker(self.interest_scoring)
|
||||||
|
|
||||||
# 设置执行器的关系追踪器
|
# 设置执行器的关系追踪器
|
||||||
self.executor.set_relationship_tracker(self.relationship_tracker)
|
self.executor.set_relationship_tracker(self.relationship_tracker)
|
||||||
@@ -64,7 +83,9 @@ class ActionPlanner:
|
|||||||
"other_actions_executed": 0,
|
"other_actions_executed": 0,
|
||||||
}
|
}
|
||||||
|
|
||||||
async def plan(self, mode: ChatMode = ChatMode.FOCUS, message_data: dict = None) -> Tuple[List[Dict], Optional[Dict]]:
|
async def plan(
|
||||||
|
self, mode: ChatMode = ChatMode.FOCUS, message_data: dict = None
|
||||||
|
) -> Tuple[List[Dict], Optional[Dict]]:
|
||||||
"""
|
"""
|
||||||
执行完整的增强版规划流程。
|
执行完整的增强版规划流程。
|
||||||
|
|
||||||
@@ -86,13 +107,14 @@ class ActionPlanner:
|
|||||||
|
|
||||||
return await self._enhanced_plan_flow(mode, unread_messages or [])
|
return await self._enhanced_plan_flow(mode, unread_messages or [])
|
||||||
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"规划流程出错: {e}")
|
logger.error(f"规划流程出错: {e}")
|
||||||
self.planner_stats["failed_plans"] += 1
|
self.planner_stats["failed_plans"] += 1
|
||||||
return [], None
|
return [], None
|
||||||
|
|
||||||
async def _enhanced_plan_flow(self, mode: ChatMode, unread_messages: List[Dict]) -> Tuple[List[Dict], Optional[Dict]]:
|
async def _enhanced_plan_flow(
|
||||||
|
self, mode: ChatMode, unread_messages: List[Dict]
|
||||||
|
) -> Tuple[List[Dict], Optional[Dict]]:
|
||||||
"""执行增强版规划流程"""
|
"""执行增强版规划流程"""
|
||||||
try:
|
try:
|
||||||
# 1. 生成初始 Plan
|
# 1. 生成初始 Plan
|
||||||
@@ -101,9 +123,7 @@ class ActionPlanner:
|
|||||||
# 2. 兴趣度评分 - 只对未读消息进行评分
|
# 2. 兴趣度评分 - 只对未读消息进行评分
|
||||||
if unread_messages:
|
if unread_messages:
|
||||||
bot_nickname = global_config.bot.nickname
|
bot_nickname = global_config.bot.nickname
|
||||||
interest_scores = await self.interest_scoring.calculate_interest_scores(
|
interest_scores = await self.interest_scoring.calculate_interest_scores(unread_messages, bot_nickname)
|
||||||
unread_messages, bot_nickname
|
|
||||||
)
|
|
||||||
|
|
||||||
# 3. 根据兴趣度调整可用动作
|
# 3. 根据兴趣度调整可用动作
|
||||||
if interest_scores:
|
if interest_scores:
|
||||||
@@ -123,6 +143,7 @@ class ActionPlanner:
|
|||||||
logger.info(f"📊 最低要求: 阈值({base_threshold:.3f}) × 0.8 = {threshold_requirement:.3f}")
|
logger.info(f"📊 最低要求: 阈值({base_threshold:.3f}) × 0.8 = {threshold_requirement:.3f}")
|
||||||
# 直接返回 no_action
|
# 直接返回 no_action
|
||||||
from src.common.data_models.info_data_model import ActionPlannerInfo
|
from src.common.data_models.info_data_model import ActionPlannerInfo
|
||||||
|
|
||||||
no_action = ActionPlannerInfo(
|
no_action = ActionPlannerInfo(
|
||||||
action_type="no_action",
|
action_type="no_action",
|
||||||
reasoning=f"兴趣度评分 {score:.3f} 未达阈值80% {threshold_requirement:.3f}",
|
reasoning=f"兴趣度评分 {score:.3f} 未达阈值80% {threshold_requirement:.3f}",
|
||||||
@@ -133,7 +154,7 @@ class ActionPlanner:
|
|||||||
filtered_plan.decided_actions = [no_action]
|
filtered_plan.decided_actions = [no_action]
|
||||||
else:
|
else:
|
||||||
# 4. 筛选 Plan
|
# 4. 筛选 Plan
|
||||||
filtered_plan = await self.filter.filter(reply_not_available,initial_plan)
|
filtered_plan = await self.filter.filter(reply_not_available, initial_plan)
|
||||||
|
|
||||||
# 检查filtered_plan是否有reply动作,以便记录reply action
|
# 检查filtered_plan是否有reply动作,以便记录reply action
|
||||||
has_reply_action = False
|
has_reply_action = False
|
||||||
@@ -158,42 +179,40 @@ class ActionPlanner:
|
|||||||
logger.error(f"增强版规划流程出错: {e}")
|
logger.error(f"增强版规划流程出错: {e}")
|
||||||
self.planner_stats["failed_plans"] += 1
|
self.planner_stats["failed_plans"] += 1
|
||||||
return [], None
|
return [], None
|
||||||
|
|
||||||
def _update_stats_from_execution_result(self, execution_result: Dict[str, any]):
|
def _update_stats_from_execution_result(self, execution_result: Dict[str, any]):
|
||||||
"""根据执行结果更新规划器统计"""
|
"""根据执行结果更新规划器统计"""
|
||||||
if not execution_result:
|
if not execution_result:
|
||||||
return
|
return
|
||||||
|
|
||||||
successful_count = execution_result.get("successful_count", 0)
|
successful_count = execution_result.get("successful_count", 0)
|
||||||
|
|
||||||
# 更新成功执行计数
|
# 更新成功执行计数
|
||||||
self.planner_stats["successful_plans"] += successful_count
|
self.planner_stats["successful_plans"] += successful_count
|
||||||
|
|
||||||
# 统计回复动作和其他动作
|
# 统计回复动作和其他动作
|
||||||
reply_count = 0
|
reply_count = 0
|
||||||
other_count = 0
|
other_count = 0
|
||||||
|
|
||||||
for result in execution_result.get("results", []):
|
for result in execution_result.get("results", []):
|
||||||
action_type = result.get("action_type", "")
|
action_type = result.get("action_type", "")
|
||||||
if action_type in ["reply", "proactive_reply"]:
|
if action_type in ["reply", "proactive_reply"]:
|
||||||
reply_count += 1
|
reply_count += 1
|
||||||
else:
|
else:
|
||||||
other_count += 1
|
other_count += 1
|
||||||
|
|
||||||
self.planner_stats["replies_generated"] += reply_count
|
self.planner_stats["replies_generated"] += reply_count
|
||||||
self.planner_stats["other_actions_executed"] += other_count
|
self.planner_stats["other_actions_executed"] += other_count
|
||||||
|
|
||||||
def _build_return_result(self, plan: Plan) -> Tuple[List[Dict], Optional[Dict]]:
|
def _build_return_result(self, plan: Plan) -> Tuple[List[Dict], Optional[Dict]]:
|
||||||
"""构建返回结果"""
|
"""构建返回结果"""
|
||||||
final_actions = plan.decided_actions or []
|
final_actions = plan.decided_actions or []
|
||||||
final_target_message = next(
|
final_target_message = next((act.action_message for act in final_actions if act.action_message), None)
|
||||||
(act.action_message for act in final_actions if act.action_message), None
|
|
||||||
)
|
|
||||||
|
|
||||||
final_actions_dict = [asdict(act) for act in final_actions]
|
final_actions_dict = [asdict(act) for act in final_actions]
|
||||||
|
|
||||||
if final_target_message:
|
if final_target_message:
|
||||||
if hasattr(final_target_message, '__dataclass_fields__'):
|
if hasattr(final_target_message, "__dataclass_fields__"):
|
||||||
final_target_message_dict = asdict(final_target_message)
|
final_target_message_dict = asdict(final_target_message)
|
||||||
else:
|
else:
|
||||||
final_target_message_dict = final_target_message
|
final_target_message_dict = final_target_message
|
||||||
@@ -234,4 +253,4 @@ class ActionPlanner:
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
# 全局兴趣度评分系统实例 - 在 individuality 模块中创建
|
# 全局兴趣度评分系统实例 - 在 individuality 模块中创建
|
||||||
|
|||||||
@@ -712,12 +712,28 @@ class DefaultReplyer:
|
|||||||
msg_id = msg.message_id
|
msg_id = msg.message_id
|
||||||
msg_time = time.strftime('%H:%M:%S', time.localtime(msg.time))
|
msg_time = time.strftime('%H:%M:%S', time.localtime(msg.time))
|
||||||
msg_content = msg.processed_plain_text
|
msg_content = msg.processed_plain_text
|
||||||
|
|
||||||
|
# 使用与已读历史消息相同的方法获取用户名
|
||||||
|
from src.person_info.person_info import PersonInfoManager, get_person_info_manager
|
||||||
|
|
||||||
|
# 获取用户信息
|
||||||
|
user_info = getattr(msg, 'user_info', {})
|
||||||
|
platform = getattr(user_info, 'platform', '') or getattr(msg, 'platform', '')
|
||||||
|
user_id = getattr(user_info, 'user_id', '') or getattr(msg, 'user_id', '')
|
||||||
|
|
||||||
|
# 获取用户名
|
||||||
|
if platform and user_id:
|
||||||
|
person_id = PersonInfoManager.get_person_id(platform, user_id)
|
||||||
|
person_info_manager = get_person_info_manager()
|
||||||
|
sender_name = person_info_manager.get_value_sync(person_id, "person_name") or "未知用户"
|
||||||
|
else:
|
||||||
|
sender_name = "未知用户"
|
||||||
|
|
||||||
# 添加兴趣度信息
|
# 添加兴趣度信息
|
||||||
interest_score = interest_scores.get(msg_id, 0.0)
|
interest_score = interest_scores.get(msg_id, 0.0)
|
||||||
interest_text = f" [兴趣度: {interest_score:.3f}]" if interest_score > 0 else ""
|
interest_text = f" [兴趣度: {interest_score:.3f}]" if interest_score > 0 else ""
|
||||||
|
|
||||||
unread_lines.append(f"{msg_time}: {msg_content}{interest_text}")
|
unread_lines.append(f"{msg_time} {sender_name}: {msg_content}{interest_text}")
|
||||||
|
|
||||||
unread_history_prompt_str = "\n".join(unread_lines)
|
unread_history_prompt_str = "\n".join(unread_lines)
|
||||||
unread_history_prompt = f"这是未读历史消息,包含兴趣度评分,请优先对兴趣值高的消息做出动作:\n{unread_history_prompt_str}"
|
unread_history_prompt = f"这是未读历史消息,包含兴趣度评分,请优先对兴趣值高的消息做出动作:\n{unread_history_prompt_str}"
|
||||||
@@ -795,12 +811,28 @@ class DefaultReplyer:
|
|||||||
msg_id = msg.get("message_id", "")
|
msg_id = msg.get("message_id", "")
|
||||||
msg_time = time.strftime('%H:%M:%S', time.localtime(msg.get("time", time.time())))
|
msg_time = time.strftime('%H:%M:%S', time.localtime(msg.get("time", time.time())))
|
||||||
msg_content = msg.get("processed_plain_text", "")
|
msg_content = msg.get("processed_plain_text", "")
|
||||||
|
|
||||||
|
# 使用与已读历史消息相同的方法获取用户名
|
||||||
|
from src.person_info.person_info import PersonInfoManager, get_person_info_manager
|
||||||
|
|
||||||
|
# 获取用户信息
|
||||||
|
user_info = msg.get("user_info", {})
|
||||||
|
platform = user_info.get("platform") or msg.get("platform", "")
|
||||||
|
user_id = user_info.get("user_id") or msg.get("user_id", "")
|
||||||
|
|
||||||
|
# 获取用户名
|
||||||
|
if platform and user_id:
|
||||||
|
person_id = PersonInfoManager.get_person_id(platform, user_id)
|
||||||
|
person_info_manager = get_person_info_manager()
|
||||||
|
sender_name = person_info_manager.get_value_sync(person_id, "person_name") or "未知用户"
|
||||||
|
else:
|
||||||
|
sender_name = "未知用户"
|
||||||
|
|
||||||
# 添加兴趣度信息
|
# 添加兴趣度信息
|
||||||
interest_score = interest_scores.get(msg_id, 0.0)
|
interest_score = interest_scores.get(msg_id, 0.0)
|
||||||
interest_text = f" [兴趣度: {interest_score:.3f}]" if interest_score > 0 else ""
|
interest_text = f" [兴趣度: {interest_score:.3f}]" if interest_score > 0 else ""
|
||||||
|
|
||||||
unread_lines.append(f"{msg_time}: {msg_content}{interest_text}")
|
unread_lines.append(f"{msg_time} {sender_name}: {msg_content}{interest_text}")
|
||||||
|
|
||||||
unread_history_prompt_str = "\n".join(unread_lines)
|
unread_history_prompt_str = "\n".join(unread_lines)
|
||||||
unread_history_prompt = f"这是未读历史消息,包含兴趣度评分,请优先对兴趣值高的消息做出动作:\n{unread_history_prompt_str}"
|
unread_history_prompt = f"这是未读历史消息,包含兴趣度评分,请优先对兴趣值高的消息做出动作:\n{unread_history_prompt_str}"
|
||||||
|
|||||||
@@ -30,6 +30,7 @@ from src.common.database.sqlalchemy_models import (
|
|||||||
Schedule,
|
Schedule,
|
||||||
MaiZoneScheduleStatus,
|
MaiZoneScheduleStatus,
|
||||||
CacheEntries,
|
CacheEntries,
|
||||||
|
UserRelationships,
|
||||||
)
|
)
|
||||||
|
|
||||||
logger = get_logger("sqlalchemy_database_api")
|
logger = get_logger("sqlalchemy_database_api")
|
||||||
@@ -53,6 +54,7 @@ MODEL_MAPPING = {
|
|||||||
"Schedule": Schedule,
|
"Schedule": Schedule,
|
||||||
"MaiZoneScheduleStatus": MaiZoneScheduleStatus,
|
"MaiZoneScheduleStatus": MaiZoneScheduleStatus,
|
||||||
"CacheEntries": CacheEntries,
|
"CacheEntries": CacheEntries,
|
||||||
|
"UserRelationships": UserRelationships,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -696,7 +696,7 @@ def get_db_session() -> Iterator[Session]:
|
|||||||
raise RuntimeError("Database session not initialized")
|
raise RuntimeError("Database session not initialized")
|
||||||
session = SessionLocal()
|
session = SessionLocal()
|
||||||
yield session
|
yield session
|
||||||
#session.commit()
|
# session.commit()
|
||||||
except Exception:
|
except Exception:
|
||||||
if session:
|
if session:
|
||||||
session.rollback()
|
session.rollback()
|
||||||
@@ -748,3 +748,23 @@ class UserPermissions(Base):
|
|||||||
Index("idx_user_permission", "platform", "user_id", "permission_node"),
|
Index("idx_user_permission", "platform", "user_id", "permission_node"),
|
||||||
Index("idx_permission_granted", "permission_node", "granted"),
|
Index("idx_permission_granted", "permission_node", "granted"),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class UserRelationships(Base):
|
||||||
|
"""用户关系模型 - 存储用户与bot的关系数据"""
|
||||||
|
|
||||||
|
__tablename__ = "user_relationships"
|
||||||
|
|
||||||
|
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||||
|
user_id = Column(get_string_field(100), nullable=False, unique=True, index=True) # 用户ID
|
||||||
|
user_name = Column(get_string_field(100), nullable=True) # 用户名
|
||||||
|
relationship_text = Column(Text, nullable=True) # 关系印象描述
|
||||||
|
relationship_score = Column(Float, nullable=False, default=0.3) # 关系分数(0-1)
|
||||||
|
last_updated = Column(Float, nullable=False, default=time.time) # 最后更新时间
|
||||||
|
created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False) # 创建时间
|
||||||
|
|
||||||
|
__table_args__ = (
|
||||||
|
Index("idx_user_relationship_id", "user_id"),
|
||||||
|
Index("idx_relationship_score", "relationship_score"),
|
||||||
|
Index("idx_relationship_updated", "last_updated"),
|
||||||
|
)
|
||||||
|
|||||||
@@ -143,7 +143,7 @@ class ModelTaskConfig(ValidatedConfigBase):
|
|||||||
monthly_plan_generator: TaskConfig = Field(..., description="月层计划生成模型配置")
|
monthly_plan_generator: TaskConfig = Field(..., description="月层计划生成模型配置")
|
||||||
emoji_vlm: TaskConfig = Field(..., description="表情包识别模型配置")
|
emoji_vlm: TaskConfig = Field(..., description="表情包识别模型配置")
|
||||||
anti_injection: TaskConfig = Field(..., description="反注入检测专用模型配置")
|
anti_injection: TaskConfig = Field(..., description="反注入检测专用模型配置")
|
||||||
|
relationship_tracker: TaskConfig = Field(..., description="关系追踪模型配置")
|
||||||
# 处理配置文件中命名不一致的问题
|
# 处理配置文件中命名不一致的问题
|
||||||
utils_video: TaskConfig = Field(..., description="视频分析模型配置(兼容配置文件中的命名)")
|
utils_video: TaskConfig = Field(..., description="视频分析模型配置(兼容配置文件中的命名)")
|
||||||
|
|
||||||
|
|||||||
@@ -8,7 +8,6 @@ from maim_message import MessageServer
|
|||||||
from src.common.remote import TelemetryHeartBeatTask
|
from src.common.remote import TelemetryHeartBeatTask
|
||||||
from src.manager.async_task_manager import async_task_manager
|
from src.manager.async_task_manager import async_task_manager
|
||||||
from src.chat.utils.statistic import OnlineTimeRecordTask, StatisticOutputTask
|
from src.chat.utils.statistic import OnlineTimeRecordTask, StatisticOutputTask
|
||||||
from src.common.remote import TelemetryHeartBeatTask
|
|
||||||
from src.chat.emoji_system.emoji_manager import get_emoji_manager
|
from src.chat.emoji_system.emoji_manager import get_emoji_manager
|
||||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||||
from src.config.config import global_config
|
from src.config.config import global_config
|
||||||
@@ -249,6 +248,14 @@ MoFox_Bot(第三方修改版)
|
|||||||
get_emoji_manager().initialize()
|
get_emoji_manager().initialize()
|
||||||
logger.info("表情包管理器初始化成功")
|
logger.info("表情包管理器初始化成功")
|
||||||
|
|
||||||
|
# 初始化回复后关系追踪系统
|
||||||
|
from src.chat.affinity_flow.relationship_integration import initialize_relationship_tracking
|
||||||
|
relationship_tracker = initialize_relationship_tracking()
|
||||||
|
if relationship_tracker:
|
||||||
|
logger.info("回复后关系追踪系统初始化成功")
|
||||||
|
else:
|
||||||
|
logger.warning("回复后关系追踪系统初始化失败")
|
||||||
|
|
||||||
# 启动情绪管理器
|
# 启动情绪管理器
|
||||||
await mood_manager.start()
|
await mood_manager.start()
|
||||||
logger.info("情绪管理器初始化成功")
|
logger.info("情绪管理器初始化成功")
|
||||||
|
|||||||
@@ -296,8 +296,13 @@ class SendHandler:
|
|||||||
return reply_seg
|
return reply_seg
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 尝试通过 message_id 获取消息详情
|
# 检查是否为缓冲消息ID(格式:buffered-{original_id}-{timestamp})
|
||||||
msg_info_response = await self.send_message_to_napcat("get_msg", {"message_id": int(id)})
|
if id.startswith('buffered-'):
|
||||||
|
# 从缓冲消息ID中提取原始消息ID
|
||||||
|
original_id = id.split('-')[1]
|
||||||
|
msg_info_response = await self.send_message_to_napcat("get_msg", {"message_id": int(original_id)})
|
||||||
|
else:
|
||||||
|
msg_info_response = await self.send_message_to_napcat("get_msg", {"message_id": int(id)})
|
||||||
|
|
||||||
replied_user_id = None
|
replied_user_id = None
|
||||||
if msg_info_response and msg_info_response.get("status") == "ok":
|
if msg_info_response and msg_info_response.get("status") == "ok":
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
[inner]
|
[inner]
|
||||||
version = "1.3.4"
|
version = "1.3.5"
|
||||||
|
|
||||||
# 配置文件版本号迭代规则同bot_config.toml
|
# 配置文件版本号迭代规则同bot_config.toml
|
||||||
|
|
||||||
@@ -195,6 +195,11 @@ model_list = ["siliconflow-deepseek-v3"]
|
|||||||
temperature = 0.7
|
temperature = 0.7
|
||||||
max_tokens = 1000
|
max_tokens = 1000
|
||||||
|
|
||||||
|
[model_task_config.relationship_tracker] # 用户关系追踪模型
|
||||||
|
model_list = ["siliconflow-deepseek-v3"]
|
||||||
|
temperature = 0.7
|
||||||
|
max_tokens = 1000
|
||||||
|
|
||||||
#嵌入模型
|
#嵌入模型
|
||||||
[model_task_config.embedding]
|
[model_task_config.embedding]
|
||||||
model_list = ["bge-m3"]
|
model_list = ["bge-m3"]
|
||||||
|
|||||||
Reference in New Issue
Block a user