ruff,私聊视为提及了bot

This commit is contained in:
Windpicker-owo
2025-09-20 22:34:22 +08:00
parent 3baf4c533a
commit df3c616d09
75 changed files with 1055 additions and 885 deletions

View File

@@ -5,4 +5,4 @@
from src.chat.affinity_flow.afc_manager import afc_manager
__all__ = ['afc_manager', 'AFCManager', 'AffinityFlowChatter']
__all__ = ["afc_manager", "AFCManager", "AffinityFlowChatter"]

View File

@@ -2,6 +2,7 @@
亲和力聊天处理流管理器
管理不同聊天流的亲和力聊天处理流,统一获取新消息并分发到对应的亲和力聊天处理流
"""
import time
import traceback
from typing import Dict, Optional, List
@@ -20,7 +21,7 @@ class AFCManager:
def __init__(self):
self.affinity_flow_chatters: Dict[str, "AffinityFlowChatter"] = {}
'''所有聊天流的亲和力聊天处理流stream_id -> affinity_flow_chatter'''
"""所有聊天流的亲和力聊天处理流stream_id -> affinity_flow_chatter"""
# 动作管理器
self.action_manager = ActionManager()
@@ -40,11 +41,7 @@ class AFCManager:
# 创建增强版规划器
planner = ActionPlanner(stream_id, self.action_manager)
chatter = AffinityFlowChatter(
stream_id=stream_id,
planner=planner,
action_manager=self.action_manager
)
chatter = AffinityFlowChatter(stream_id=stream_id, planner=planner, action_manager=self.action_manager)
self.affinity_flow_chatters[stream_id] = chatter
logger.info(f"创建新的亲和力聊天处理器: {stream_id}")
@@ -74,7 +71,6 @@ class AFCManager:
"executed_count": 0,
}
def get_chatter_stats(self, stream_id: str) -> Optional[Dict[str, any]]:
"""获取聊天处理器统计"""
if stream_id in self.affinity_flow_chatters:
@@ -131,4 +127,5 @@ class AFCManager:
self.affinity_flow_chatters[stream_id].update_interest_keywords(new_keywords)
logger.info(f"已更新聊天流 {stream_id} 的兴趣关键词: {list(new_keywords.keys())}")
afc_manager = AFCManager()
afc_manager = AFCManager()

View File

@@ -2,6 +2,7 @@
亲和力聊天处理器
单个聊天流的处理器,负责处理特定聊天流的完整交互流程
"""
import time
import traceback
from datetime import datetime
@@ -57,10 +58,7 @@ class AffinityFlowChatter:
unread_messages = context.get_unread_messages()
# 使用增强版规划器处理消息
actions, target_message = await self.planner.plan(
mode=ChatMode.FOCUS,
context=context
)
actions, target_message = await self.planner.plan(mode=ChatMode.FOCUS, context=context)
self.stats["plans_created"] += 1
# 执行动作(如果规划器返回了动作)
@@ -84,7 +82,9 @@ class AffinityFlowChatter:
**execution_result,
}
logger.info(f"聊天流 {self.stream_id} StreamContext处理成功: 动作数={result['actions_count']}, 未读消息={result['unread_messages_processed']}")
logger.info(
f"聊天流 {self.stream_id} StreamContext处理成功: 动作数={result['actions_count']}, 未读消息={result['unread_messages_processed']}"
)
return result
@@ -197,7 +197,9 @@ class AffinityFlowChatter:
def __repr__(self) -> str:
"""详细字符串表示"""
return (f"AffinityFlowChatter(stream_id={self.stream_id}, "
f"messages_processed={self.stats['messages_processed']}, "
f"plans_created={self.stats['plans_created']}, "
f"last_activity={datetime.fromtimestamp(self.last_activity_time)})")
return (
f"AffinityFlowChatter(stream_id={self.stream_id}, "
f"messages_processed={self.stats['messages_processed']}, "
f"plans_created={self.stats['plans_created']}, "
f"last_activity={datetime.fromtimestamp(self.last_activity_time)})"
)

View File

@@ -38,7 +38,9 @@ class InterestScoringSystem:
# 连续不回复概率提升
self.no_reply_count = 0
self.max_no_reply_count = affinity_config.max_no_reply_count
self.probability_boost_per_no_reply = affinity_config.no_reply_threshold_adjustment / affinity_config.max_no_reply_count # 每次不回复增加的概率
self.probability_boost_per_no_reply = (
affinity_config.no_reply_threshold_adjustment / affinity_config.max_no_reply_count
) # 每次不回复增加的概率
# 用户关系数据
self.user_relationships: Dict[str, float] = {} # user_id -> relationship_score
@@ -153,7 +155,9 @@ class InterestScoringSystem:
# 返回匹配分数,考虑置信度和匹配标签数量
affinity_config = global_config.affinity_flow
match_count_bonus = min(len(match_result.matched_tags) * affinity_config.match_count_bonus, affinity_config.max_match_bonus)
match_count_bonus = min(
len(match_result.matched_tags) * affinity_config.match_count_bonus, affinity_config.max_match_bonus
)
final_score = match_result.overall_score * 1.15 * match_result.confidence + match_count_bonus
logger.debug(
f"⚖️ 最终分数计算: 总分({match_result.overall_score:.3f}) × 1.3 × 置信度({match_result.confidence:.3f}) + 标签数量奖励({match_count_bonus:.3f}) = {final_score:.3f}"
@@ -263,7 +267,17 @@ class InterestScoringSystem:
if not msg.processed_plain_text:
return 0.0
if msg.is_mentioned or (bot_nickname and bot_nickname in msg.processed_plain_text):
# 检查是否被提及
is_mentioned = msg.is_mentioned or (bot_nickname and bot_nickname in msg.processed_plain_text)
# 检查是否为私聊group_info为None表示私聊
is_private_chat = msg.group_info is None
# 如果被提及或是私聊都视为提及了bot
if is_mentioned or is_private_chat:
logger.debug(f"🔍 提及检测 - 被提及: {is_mentioned}, 私聊: {is_private_chat}")
if is_private_chat and not is_mentioned:
logger.debug("💬 私聊消息自动视为提及bot")
return global_config.affinity_flow.mention_bot_interest_score
return 0.0
@@ -282,7 +296,9 @@ class InterestScoringSystem:
logger.debug(f"📋 基础阈值: {base_threshold:.3f}")
# 如果被提及,降低阈值
if score.mentioned_score >= global_config.affinity_flow.mention_bot_interest_score * 0.5: # 使用提及bot兴趣分的一半作为判断阈值
if (
score.mentioned_score >= global_config.affinity_flow.mention_bot_interest_score * 0.5
): # 使用提及bot兴趣分的一半作为判断阈值
base_threshold = self.mention_threshold
logger.debug(f"📣 消息提及了机器人,使用降低阈值: {base_threshold:.3f}")
@@ -325,7 +341,9 @@ class InterestScoringSystem:
def update_user_relationship(self, user_id: str, relationship_change: float):
"""更新用户关系"""
old_score = self.user_relationships.get(user_id, global_config.affinity_flow.base_relationship_score) # 默认新用户分数
old_score = self.user_relationships.get(
user_id, global_config.affinity_flow.base_relationship_score
) # 默认新用户分数
new_score = max(0.0, min(1.0, old_score + relationship_change))
self.user_relationships[user_id] = new_score

View File

@@ -116,6 +116,7 @@ class UserRelationshipTracker:
try:
# 获取bot人设信息
from src.individuality.individuality import Individuality
individuality = Individuality()
bot_personality = await individuality.get_personality_block()
@@ -168,7 +169,17 @@ class UserRelationshipTracker:
# 清理LLM响应移除可能的格式标记
cleaned_response = self._clean_llm_json_response(llm_response)
response_data = json.loads(cleaned_response)
new_score = max(0.0, min(1.0, float(response_data.get("new_relationship_score", global_config.affinity_flow.base_relationship_score))))
new_score = max(
0.0,
min(
1.0,
float(
response_data.get(
"new_relationship_score", global_config.affinity_flow.base_relationship_score
)
),
),
)
if self.interest_scoring_system:
self.interest_scoring_system.update_user_relationship(
@@ -295,7 +306,9 @@ class UserRelationshipTracker:
# 更新缓存
self.user_relationship_cache[user_id] = {
"relationship_text": relationship_data.get("relationship_text", ""),
"relationship_score": relationship_data.get("relationship_score", global_config.affinity_flow.base_relationship_score),
"relationship_score": relationship_data.get(
"relationship_score", global_config.affinity_flow.base_relationship_score
),
"last_tracked": time.time(),
}
return relationship_data.get("relationship_score", global_config.affinity_flow.base_relationship_score)
@@ -386,7 +399,11 @@ class UserRelationshipTracker:
# 获取当前关系数据
current_relationship = self._get_user_relationship_from_db(user_id)
current_score = current_relationship.get("relationship_score", global_config.affinity_flow.base_relationship_score) if current_relationship else global_config.affinity_flow.base_relationship_score
current_score = (
current_relationship.get("relationship_score", global_config.affinity_flow.base_relationship_score)
if current_relationship
else global_config.affinity_flow.base_relationship_score
)
current_text = current_relationship.get("relationship_text", "新用户") if current_relationship else "新用户"
# 使用LLM分析并更新关系
@@ -501,6 +518,7 @@ class UserRelationshipTracker:
# 获取bot人设信息
from src.individuality.individuality import Individuality
individuality = Individuality()
bot_personality = await individuality.get_personality_block()

View File

@@ -2,6 +2,9 @@
"""
表情包发送历史记录模块
"""
import os
from typing import List, Dict
from collections import deque
from typing import List, Dict
@@ -25,15 +28,15 @@ def add_emoji_to_history(chat_id: str, emoji_description: str):
"""
if not chat_id or not emoji_description:
return
# 如果当前聊天还没有历史记录,则创建一个新的 deque
if chat_id not in _history_cache:
_history_cache[chat_id] = deque(maxlen=MAX_HISTORY_SIZE)
# 添加新表情到历史记录
history = _history_cache[chat_id]
history.append(emoji_description)
logger.debug(f"已将表情 '{emoji_description}' 添加到聊天 {chat_id} 的内存历史中")
@@ -49,10 +52,10 @@ def get_recent_emojis(chat_id: str, limit: int = 5) -> List[str]:
return []
history = _history_cache[chat_id]
# 从 deque 的右侧(即最近添加的)开始取
num_to_get = min(limit, len(history))
recent_emojis = [history[-i] for i in range(1, num_to_get + 1)]
logger.debug(f"为聊天 {chat_id} 从内存中获取到最近 {len(recent_emojis)} 个表情: {recent_emojis}")
return recent_emojis

View File

@@ -479,7 +479,7 @@ class EmojiManager:
emoji_options_str = ""
for i, emoji in enumerate(candidate_emojis):
# 为每个表情包创建一个编号和它的详细描述
emoji_options_str += f"编号: {i+1}\n描述: {emoji.description}\n\n"
emoji_options_str += f"编号: {i + 1}\n描述: {emoji.description}\n\n"
# 精心设计的prompt引导LLM做出选择
prompt = f"""
@@ -526,10 +526,8 @@ class EmojiManager:
await self.record_usage(selected_emoji.emoji_hash)
_time_end = time.time()
logger.info(
f"找到匹配描述的表情包: {selected_emoji.description}, 耗时: {(_time_end - _time_start):.2f}s"
)
logger.info(f"找到匹配描述的表情包: {selected_emoji.description}, 耗时: {(_time_end - _time_start):.2f}s")
# 8. 返回选中的表情包信息
return selected_emoji.full_path, f"[表情包:{selected_emoji.description}]", text_emotion
@@ -629,8 +627,9 @@ class EmojiManager:
# 无论steal_emoji是否开启都检查emoji文件夹以支持手动注册
# 只有在需要腾出空间或填充表情库时,才真正执行注册
if (self.emoji_num > self.emoji_num_max and global_config.emoji.do_replace) or \
(self.emoji_num < self.emoji_num_max):
if (self.emoji_num > self.emoji_num_max and global_config.emoji.do_replace) or (
self.emoji_num < self.emoji_num_max
):
try:
# 获取目录下所有图片文件
files_to_process = [
@@ -938,19 +937,21 @@ class EmojiManager:
image_base64 = image_base64.encode("ascii", errors="ignore").decode("ascii")
image_bytes = base64.b64decode(image_base64)
image_hash = hashlib.md5(image_bytes).hexdigest()
image_format = Image.open(io.BytesIO(image_bytes)).format.lower() if Image.open(io.BytesIO(image_bytes)).format else "jpeg"
image_format = (
Image.open(io.BytesIO(image_bytes)).format.lower()
if Image.open(io.BytesIO(image_bytes)).format
else "jpeg"
)
# 2. 检查数据库中是否已存在该表情包的描述,实现复用
existing_description = None
try:
async with get_db_session() as session:
result = await session.execute(
select(Images).filter(
(Images.emoji_hash == image_hash) & (Images.type == "emoji")
)
with get_db_session() as session:
existing_image = (
session.query(Images)
.filter((Images.emoji_hash == image_hash) & (Images.type == "emoji"))
.one_or_none()
)
existing_image = result.scalar_one_or_none()
if existing_image and existing_image.description:
existing_description = existing_image.description
logger.info(f"[复用描述] 找到已有详细描述: {existing_description[:50]}...")

View File

@@ -14,6 +14,7 @@ Chat Frequency Analyzer
- MIN_CHATS_FOR_PEAK: 在一个窗口内需要多少次聊天才能被认为是高峰时段。
- MIN_GAP_BETWEEN_PEAKS_HOURS: 两个独立高峰时段之间的最小间隔(小时)。
"""
import time as time_module
from datetime import datetime, timedelta, time
from typing import List, Tuple, Optional
@@ -72,12 +73,14 @@ class ChatFrequencyAnalyzer:
current_window_end = datetimes[i]
# 合并重叠或相邻的高峰时段
if peak_windows and current_window_start - peak_windows[-1][1] < timedelta(hours=MIN_GAP_BETWEEN_PEAKS_HOURS):
if peak_windows and current_window_start - peak_windows[-1][1] < timedelta(
hours=MIN_GAP_BETWEEN_PEAKS_HOURS
):
# 扩展上一个窗口的结束时间
peak_windows[-1] = (peak_windows[-1][0], current_window_end)
else:
peak_windows.append((current_window_start, current_window_end))
return peak_windows
def get_peak_chat_times(self, chat_id: str) -> List[Tuple[time, time]]:
@@ -100,7 +103,7 @@ class ChatFrequencyAnalyzer:
return []
peak_datetime_windows = self._find_peak_windows(timestamps)
# 将 datetime 窗口转换为 time 窗口,并进行归一化处理
peak_time_windows = []
for start_dt, end_dt in peak_datetime_windows:
@@ -110,7 +113,7 @@ class ChatFrequencyAnalyzer:
# 更新缓存
self._analysis_cache[chat_id] = (time_module.time(), peak_time_windows)
return peak_time_windows
def is_in_peak_time(self, chat_id: str, now: Optional[datetime] = None) -> bool:
@@ -126,7 +129,7 @@ class ChatFrequencyAnalyzer:
"""
if now is None:
now = datetime.now()
now_time = now.time()
peak_times = self.get_peak_chat_times(chat_id)
@@ -137,7 +140,7 @@ class ChatFrequencyAnalyzer:
else: # 跨天
if now_time >= start_time or now_time <= end_time:
return True
return False

View File

@@ -56,7 +56,7 @@ class ChatFrequencyTracker:
now = time.time()
if chat_id not in self._timestamps:
self._timestamps[chat_id] = []
self._timestamps[chat_id].append(now)
logger.debug(f"为 chat_id '{chat_id}' 记录了新的聊天时间: {now}")
self._save_timestamps()

View File

@@ -14,6 +14,7 @@ Frequency-Based Proactive Trigger
- TRIGGER_CHECK_INTERVAL_SECONDS: 触发器检查的周期(秒)。
- COOLDOWN_HOURS: 在同一个高峰时段内触发一次后的冷却时间(小时)。
"""
import asyncio
import time
from datetime import datetime
@@ -21,6 +22,7 @@ from typing import Dict, Optional
from src.common.logger import get_logger
from src.chat.affinity_flow.afc_manager import afc_manager
# TODO: 需要重新实现主动思考和睡眠管理功能
from .analyzer import chat_frequency_analyzer
@@ -65,7 +67,7 @@ class FrequencyBasedTrigger:
continue
now = datetime.now()
for chat_id in all_chat_ids:
# 3. 检查是否处于冷却时间内
last_triggered_time = self._last_triggered.get(chat_id, 0)
@@ -74,7 +76,6 @@ class FrequencyBasedTrigger:
# 4. 检查当前是否是该用户的高峰聊天时间
if chat_frequency_analyzer.is_in_peak_time(chat_id, now):
# 5. 检查用户当前是否已有活跃的处理任务
# 亲和力流系统不直接提供循环状态,通过检查最后活动时间来判断是否忙碌
chatter = afc_manager.get_or_create_chatter(chat_id)
@@ -87,13 +88,13 @@ class FrequencyBasedTrigger:
if current_time - chatter.get_activity_time() < 60:
logger.debug(f"用户 {chat_id} 的亲和力处理器正忙,本次不触发。")
continue
logger.info(f"检测到用户 {chat_id} 处于聊天高峰期,且处理器空闲,准备触发主动思考。")
# 6. TODO: 亲和力流系统的主动思考机制需要另行实现
# 目前先记录日志,等待后续实现
logger.info(f"用户 {chat_id} 处于高峰期,但亲和力流的主动思考功能暂未实现")
# 7. 更新触发时间,进入冷却
self._last_triggered[chat_id] = time.time()

View File

@@ -4,14 +4,12 @@
"""
from .bot_interest_manager import BotInterestManager, bot_interest_manager
from src.common.data_models.bot_interest_data_model import (
BotInterestTag, BotPersonalityInterests, InterestMatchResult
)
from src.common.data_models.bot_interest_data_model import BotInterestTag, BotPersonalityInterests, InterestMatchResult
__all__ = [
"BotInterestManager",
"bot_interest_manager",
"BotInterestTag",
"BotPersonalityInterests",
"InterestMatchResult"
]
"InterestMatchResult",
]

View File

@@ -2,6 +2,7 @@
机器人兴趣标签管理系统
基于人设生成兴趣标签并使用embedding计算匹配度
"""
import orjson
import traceback
from typing import List, Dict, Optional, Any
@@ -10,9 +11,7 @@ import numpy as np
from src.common.logger import get_logger
from src.config.config import global_config
from src.common.data_models.bot_interest_data_model import (
BotPersonalityInterests, BotInterestTag, InterestMatchResult
)
from src.common.data_models.bot_interest_data_model import BotPersonalityInterests, BotInterestTag, InterestMatchResult
logger = get_logger("bot_interest_manager")
@@ -87,7 +86,7 @@ class BotInterestManager:
logger.debug("✅ 成功导入embedding相关模块")
# 检查embedding配置是否存在
if not hasattr(model_config.model_task_config, 'embedding'):
if not hasattr(model_config.model_task_config, "embedding"):
raise RuntimeError("❌ 未找到embedding模型配置")
logger.info("📋 找到embedding模型配置")
@@ -101,7 +100,7 @@ class BotInterestManager:
logger.info(f"🔗 客户端类型: {type(self.embedding_request).__name__}")
# 获取第一个embedding模型的ModelInfo
if hasattr(self.embedding_config, 'model_list') and self.embedding_config.model_list:
if hasattr(self.embedding_config, "model_list") and self.embedding_config.model_list:
first_model_name = self.embedding_config.model_list[0]
logger.info(f"🎯 使用embedding模型: {first_model_name}")
else:
@@ -127,7 +126,9 @@ class BotInterestManager:
# 生成新的兴趣标签
logger.info("🆕 数据库中未找到兴趣标签,开始生成新的...")
logger.info("🤖 正在调用LLM生成个性化兴趣标签...")
generated_interests = await self._generate_interests_from_personality(personality_description, personality_id)
generated_interests = await self._generate_interests_from_personality(
personality_description, personality_id
)
if generated_interests:
self.current_interests = generated_interests
@@ -140,14 +141,16 @@ class BotInterestManager:
else:
raise RuntimeError("❌ 兴趣标签生成失败")
async def _generate_interests_from_personality(self, personality_description: str, personality_id: str) -> Optional[BotPersonalityInterests]:
async def _generate_interests_from_personality(
self, personality_description: str, personality_id: str
) -> Optional[BotPersonalityInterests]:
"""根据人设生成兴趣标签"""
try:
logger.info("🎨 开始根据人设生成兴趣标签...")
logger.info(f"📝 人设长度: {len(personality_description)} 字符")
# 检查embedding客户端是否可用
if not hasattr(self, 'embedding_request'):
if not hasattr(self, "embedding_request"):
raise RuntimeError("❌ Embedding客户端未初始化无法生成兴趣标签")
# 构建提示词
@@ -190,8 +193,7 @@ class BotInterestManager:
interests_data = orjson.loads(response)
bot_interests = BotPersonalityInterests(
personality_id=personality_id,
personality_description=personality_description
personality_id=personality_id, personality_description=personality_description
)
# 解析生成的兴趣标签
@@ -202,10 +204,7 @@ class BotInterestManager:
tag_name = tag_data.get("name", f"标签_{i}")
weight = tag_data.get("weight", 0.5)
tag = BotInterestTag(
tag_name=tag_name,
weight=weight
)
tag = BotInterestTag(tag_name=tag_name, weight=weight)
bot_interests.interest_tags.append(tag)
logger.debug(f" 🏷️ {tag_name} (权重: {weight:.2f})")
@@ -225,7 +224,6 @@ class BotInterestManager:
traceback.print_exc()
raise
async def _call_llm_for_interest_generation(self, prompt: str) -> Optional[str]:
"""调用LLM生成兴趣标签"""
try:
@@ -241,10 +239,10 @@ class BotInterestManager:
{prompt}
请确保返回格式为有效的JSON不要包含任何额外的文本、解释或代码块标记。只返回JSON对象本身。"""
# 使用replyer模型配置
replyer_config = model_config.model_task_config.replyer
# 调用LLM API
logger.info("🚀 正在通过LLM API发送请求...")
success, response, reasoning_content, model_name = await llm_api.generate_with_model(
@@ -252,15 +250,17 @@ class BotInterestManager:
model_config=replyer_config,
request_type="interest_generation",
temperature=0.7,
max_tokens=2000
max_tokens=2000,
)
if success and response:
logger.info(f"✅ LLM响应成功模型: {model_name}, 响应长度: {len(response)} 字符")
logger.debug(f"📄 LLM响应内容: {response[:200]}..." if len(response) > 200 else f"📄 LLM响应内容: {response}")
logger.debug(
f"📄 LLM响应内容: {response[:200]}..." if len(response) > 200 else f"📄 LLM响应内容: {response}"
)
if reasoning_content:
logger.debug(f"🧠 推理内容: {reasoning_content[:100]}...")
# 清理响应内容,移除可能的代码块标记
cleaned_response = self._clean_llm_response(response)
return cleaned_response
@@ -277,25 +277,25 @@ class BotInterestManager:
def _clean_llm_response(self, response: str) -> str:
"""清理LLM响应移除代码块标记和其他非JSON内容"""
import re
# 移除 ```json 和 ``` 标记
cleaned = re.sub(r'```json\s*', '', response)
cleaned = re.sub(r'\s*```', '', cleaned)
cleaned = re.sub(r"```json\s*", "", response)
cleaned = re.sub(r"\s*```", "", cleaned)
# 移除可能的多余空格和换行
cleaned = cleaned.strip()
# 尝试提取JSON对象如果响应中有其他文本
json_match = re.search(r'\{.*\}', cleaned, re.DOTALL)
json_match = re.search(r"\{.*\}", cleaned, re.DOTALL)
if json_match:
cleaned = json_match.group(0)
logger.debug(f"🧹 清理后的响应: {cleaned[:200]}..." if len(cleaned) > 200 else f"🧹 清理后的响应: {cleaned}")
return cleaned
async def _generate_embeddings_for_tags(self, interests: BotPersonalityInterests):
"""为所有兴趣标签生成embedding"""
if not hasattr(self, 'embedding_request'):
if not hasattr(self, "embedding_request"):
raise RuntimeError("❌ Embedding客户端未初始化无法生成embedding")
total_tags = len(interests.interest_tags)
@@ -342,7 +342,7 @@ class BotInterestManager:
async def _get_embedding(self, text: str) -> List[float]:
"""获取文本的embedding向量"""
if not hasattr(self, 'embedding_request'):
if not hasattr(self, "embedding_request"):
raise RuntimeError("❌ Embedding请求客户端未初始化")
# 检查缓存
@@ -376,7 +376,9 @@ class BotInterestManager:
logger.debug(f"✅ 消息embedding生成成功维度: {len(embedding)}")
return embedding
async def _calculate_similarity_scores(self, result: InterestMatchResult, message_embedding: List[float], keywords: List[str]):
async def _calculate_similarity_scores(
self, result: InterestMatchResult, message_embedding: List[float], keywords: List[str]
):
"""计算消息与兴趣标签的相似度分数"""
try:
if not self.current_interests:
@@ -397,7 +399,9 @@ class BotInterestManager:
# 设置相似度阈值为0.3
if similarity > 0.3:
result.add_match(tag.tag_name, weighted_score, keywords)
logger.debug(f" 🏷️ '{tag.tag_name}': 相似度={similarity:.3f}, 权重={tag.weight:.2f}, 加权分数={weighted_score:.3f}")
logger.debug(
f" 🏷️ '{tag.tag_name}': 相似度={similarity:.3f}, 权重={tag.weight:.2f}, 加权分数={weighted_score:.3f}"
)
except Exception as e:
logger.error(f"❌ 计算相似度分数失败: {e}")
@@ -455,7 +459,9 @@ class BotInterestManager:
match_count += 1
high_similarity_count += 1
result.add_match(tag.tag_name, enhanced_score, [tag.tag_name])
logger.debug(f" 🏷️ '{tag.tag_name}': 相似度={similarity:.3f}, 权重={tag.weight:.2f}, 基础分数={weighted_score:.3f}, 增强分数={enhanced_score:.3f} [高匹配]")
logger.debug(
f" 🏷️ '{tag.tag_name}': 相似度={similarity:.3f}, 权重={tag.weight:.2f}, 基础分数={weighted_score:.3f}, 增强分数={enhanced_score:.3f} [高匹配]"
)
elif similarity > medium_threshold:
# 中相似度:中等加成
@@ -463,7 +469,9 @@ class BotInterestManager:
match_count += 1
medium_similarity_count += 1
result.add_match(tag.tag_name, enhanced_score, [tag.tag_name])
logger.debug(f" 🏷️ '{tag.tag_name}': 相似度={similarity:.3f}, 权重={tag.weight:.2f}, 基础分数={weighted_score:.3f}, 增强分数={enhanced_score:.3f} [中匹配]")
logger.debug(
f" 🏷️ '{tag.tag_name}': 相似度={similarity:.3f}, 权重={tag.weight:.2f}, 基础分数={weighted_score:.3f}, 增强分数={enhanced_score:.3f} [中匹配]"
)
elif similarity > low_threshold:
# 低相似度:轻微加成
@@ -471,7 +479,9 @@ class BotInterestManager:
match_count += 1
low_similarity_count += 1
result.add_match(tag.tag_name, enhanced_score, [tag.tag_name])
logger.debug(f" 🏷️ '{tag.tag_name}': 相似度={similarity:.3f}, 权重={tag.weight:.2f}, 基础分数={weighted_score:.3f}, 增强分数={enhanced_score:.3f} [低匹配]")
logger.debug(
f" 🏷️ '{tag.tag_name}': 相似度={similarity:.3f}, 权重={tag.weight:.2f}, 基础分数={weighted_score:.3f}, 增强分数={enhanced_score:.3f} [低匹配]"
)
logger.info(f"📈 匹配统计: {match_count}/{len(active_tags)} 个标签超过阈值")
logger.info(f"🔥 高相似度匹配(>{high_threshold}): {high_similarity_count}")
@@ -488,7 +498,9 @@ class BotInterestManager:
original_score = result.match_scores[tag_name]
bonus = keyword_bonus[tag_name]
result.match_scores[tag_name] = original_score + bonus
logger.debug(f" 🏷️ '{tag_name}': 原始分数={original_score:.3f}, 奖励={bonus:.3f}, 最终分数={result.match_scores[tag_name]:.3f}")
logger.debug(
f" 🏷️ '{tag_name}': 原始分数={original_score:.3f}, 奖励={bonus:.3f}, 最终分数={result.match_scores[tag_name]:.3f}"
)
# 计算总体分数
result.calculate_overall_score()
@@ -499,10 +511,11 @@ class BotInterestManager:
result.top_tag = top_tag_name
logger.info(f"🏆 最佳匹配标签: '{top_tag_name}' (分数: {result.match_scores[top_tag_name]:.3f})")
logger.info(f"📊 最终结果: 总分={result.overall_score:.3f}, 置信度={result.confidence:.3f}, 匹配标签数={len(result.matched_tags)}")
logger.info(
f"📊 最终结果: 总分={result.overall_score:.3f}, 置信度={result.confidence:.3f}, 匹配标签数={len(result.matched_tags)}"
)
return result
def _calculate_keyword_match_bonus(self, keywords: List[str], matched_tags: List[str]) -> Dict[str, float]:
"""计算关键词直接匹配奖励"""
if not keywords or not matched_tags:
@@ -522,17 +535,25 @@ class BotInterestManager:
# 完全匹配
if keyword_lower == tag_name_lower:
bonus += affinity_config.high_match_interest_threshold * 0.6 # 使用高匹配阈值的60%作为完全匹配奖励
logger.debug(f" 🎯 关键词完全匹配: '{keyword}' == '{tag_name}' (+{affinity_config.high_match_interest_threshold * 0.6:.3f})")
logger.debug(
f" 🎯 关键词完全匹配: '{keyword}' == '{tag_name}' (+{affinity_config.high_match_interest_threshold * 0.6:.3f})"
)
# 包含匹配
elif keyword_lower in tag_name_lower or tag_name_lower in keyword_lower:
bonus += affinity_config.medium_match_interest_threshold * 0.3 # 使用中匹配阈值的30%作为包含匹配奖励
logger.debug(f" 🎯 关键词包含匹配: '{keyword}''{tag_name}' (+{affinity_config.medium_match_interest_threshold * 0.3:.3f})")
bonus += (
affinity_config.medium_match_interest_threshold * 0.3
) # 使用中匹配阈值的30%作为包含匹配奖励
logger.debug(
f" 🎯 关键词包含匹配: '{keyword}''{tag_name}' (+{affinity_config.medium_match_interest_threshold * 0.3:.3f})"
)
# 部分匹配(编辑距离)
elif self._calculate_partial_match(keyword_lower, tag_name_lower):
bonus += affinity_config.low_match_interest_threshold * 0.4 # 使用低匹配阈值的40%作为部分匹配奖励
logger.debug(f" 🎯 关键词部分匹配: '{keyword}''{tag_name}' (+{affinity_config.low_match_interest_threshold * 0.4:.3f})")
logger.debug(
f" 🎯 关键词部分匹配: '{keyword}''{tag_name}' (+{affinity_config.low_match_interest_threshold * 0.4:.3f})"
)
if bonus > 0:
bonus_dict[tag_name] = min(bonus, affinity_config.max_match_bonus) # 使用配置的最大奖励限制
@@ -608,12 +629,12 @@ class BotInterestManager:
with get_db_session() as session:
# 查询最新的兴趣标签配置
db_interests = session.query(DBBotPersonalityInterests).filter(
DBBotPersonalityInterests.personality_id == personality_id
).order_by(
DBBotPersonalityInterests.version.desc(),
DBBotPersonalityInterests.last_updated.desc()
).first()
db_interests = (
session.query(DBBotPersonalityInterests)
.filter(DBBotPersonalityInterests.personality_id == personality_id)
.order_by(DBBotPersonalityInterests.version.desc(), DBBotPersonalityInterests.last_updated.desc())
.first()
)
if db_interests:
logger.info(f"✅ 找到数据库中的兴趣标签配置,版本: {db_interests.version}")
@@ -631,7 +652,7 @@ class BotInterestManager:
personality_description=db_interests.personality_description,
embedding_model=db_interests.embedding_model,
version=db_interests.version,
last_updated=db_interests.last_updated
last_updated=db_interests.last_updated,
)
# 解析兴趣标签
@@ -639,10 +660,14 @@ class BotInterestManager:
tag = BotInterestTag(
tag_name=tag_data.get("tag_name", ""),
weight=tag_data.get("weight", 0.5),
created_at=datetime.fromisoformat(tag_data.get("created_at", datetime.now().isoformat())),
updated_at=datetime.fromisoformat(tag_data.get("updated_at", datetime.now().isoformat())),
created_at=datetime.fromisoformat(
tag_data.get("created_at", datetime.now().isoformat())
),
updated_at=datetime.fromisoformat(
tag_data.get("updated_at", datetime.now().isoformat())
),
is_active=tag_data.get("is_active", True),
embedding=tag_data.get("embedding")
embedding=tag_data.get("embedding"),
)
interests.interest_tags.append(tag)
@@ -685,7 +710,7 @@ class BotInterestManager:
"created_at": tag.created_at.isoformat(),
"updated_at": tag.updated_at.isoformat(),
"is_active": tag.is_active,
"embedding": tag.embedding
"embedding": tag.embedding,
}
tags_data.append(tag_dict)
@@ -694,9 +719,11 @@ class BotInterestManager:
with get_db_session() as session:
# 检查是否已存在相同personality_id的记录
existing_record = session.query(DBBotPersonalityInterests).filter(
DBBotPersonalityInterests.personality_id == interests.personality_id
).first()
existing_record = (
session.query(DBBotPersonalityInterests)
.filter(DBBotPersonalityInterests.personality_id == interests.personality_id)
.first()
)
if existing_record:
# 更新现有记录
@@ -718,7 +745,7 @@ class BotInterestManager:
interest_tags=json_data,
embedding_model=interests.embedding_model,
version=interests.version,
last_updated=interests.last_updated
last_updated=interests.last_updated,
)
session.add(new_record)
session.commit()
@@ -728,9 +755,11 @@ class BotInterestManager:
# 验证保存是否成功
with get_db_session() as session:
saved_record = session.query(DBBotPersonalityInterests).filter(
DBBotPersonalityInterests.personality_id == interests.personality_id
).first()
saved_record = (
session.query(DBBotPersonalityInterests)
.filter(DBBotPersonalityInterests.personality_id == interests.personality_id)
.first()
)
session.commit()
if saved_record:
logger.info(f"✅ 验证成功数据库中存在personality_id为 {interests.personality_id} 的记录")
@@ -760,7 +789,7 @@ class BotInterestManager:
"total_tags": len(active_tags),
"embedding_model": self.current_interests.embedding_model,
"last_updated": self.current_interests.last_updated.isoformat(),
"cache_size": len(self.embedding_cache)
"cache_size": len(self.embedding_cache),
}
async def update_interest_tags(self, new_personality_description: str = None):
@@ -775,8 +804,7 @@ class BotInterestManager:
# 重新生成兴趣标签
new_interests = await self._generate_interests_from_personality(
self.current_interests.personality_description,
self.current_interests.personality_id
self.current_interests.personality_description, self.current_interests.personality_id
)
if new_interests:
@@ -791,4 +819,4 @@ class BotInterestManager:
# 创建全局实例(重新创建以包含新的属性)
bot_interest_manager = BotInterestManager()
bot_interest_manager = BotInterestManager()

View File

@@ -4,13 +4,11 @@
"""
from .message_manager import MessageManager, message_manager
from src.common.data_models.message_manager_data_model import StreamContext, MessageStatus, MessageManagerStats, StreamStats
from src.common.data_models.message_manager_data_model import (
StreamContext,
MessageStatus,
MessageManagerStats,
StreamStats,
)
__all__ = [
"MessageManager",
"message_manager",
"StreamContext",
"MessageStatus",
"MessageManagerStats",
"StreamStats"
]
__all__ = ["MessageManager", "message_manager", "StreamContext", "MessageStatus", "MessageManagerStats", "StreamStats"]

View File

@@ -2,6 +2,7 @@
消息管理模块
管理每个聊天流的上下文信息,包含历史记录和未读消息,定期检查并处理新消息
"""
import asyncio
import time
import traceback
@@ -100,9 +101,7 @@ class MessageManager:
# 如果没有处理任务,创建一个
if not context.processing_task or context.processing_task.done():
context.processing_task = asyncio.create_task(
self._process_stream_messages(stream_id)
)
context.processing_task = asyncio.create_task(self._process_stream_messages(stream_id))
# 更新统计
self.stats.active_streams = active_streams
@@ -128,11 +127,11 @@ class MessageManager:
try:
# 发送到AFC处理器传递StreamContext对象
results = await afc_manager.process_stream_context(stream_id, context)
# 处理结果,标记消息为已读
if results.get("success", False):
self._clear_all_unread_messages(context)
except Exception as e:
logger.error(f"处理聊天流 {stream_id} 时发生异常,将清除所有未读消息: {e}")
raise
@@ -175,7 +174,7 @@ class MessageManager:
unread_count=len(context.get_unread_messages()),
history_count=len(context.history_messages),
last_check_time=context.last_check_time,
has_active_task=context.processing_task and not context.processing_task.done()
has_active_task=context.processing_task and not context.processing_task.done(),
)
def get_manager_stats(self) -> Dict[str, Any]:
@@ -186,7 +185,7 @@ class MessageManager:
"total_unread_messages": self.stats.total_unread_messages,
"total_processed_messages": self.stats.total_processed_messages,
"uptime": self.stats.uptime,
"start_time": self.stats.start_time
"start_time": self.stats.start_time,
}
def cleanup_inactive_streams(self, max_inactive_hours: int = 24):
@@ -196,8 +195,7 @@ class MessageManager:
inactive_streams = []
for stream_id, context in self.stream_contexts.items():
if (current_time - context.last_check_time > max_inactive_seconds and
not context.get_unread_messages()):
if current_time - context.last_check_time > max_inactive_seconds and not context.get_unread_messages():
inactive_streams.append(stream_id)
for stream_id in inactive_streams:
@@ -210,9 +208,9 @@ class MessageManager:
unread_messages = context.get_unread_messages()
if not unread_messages:
return
logger.warning(f"正在清除 {len(unread_messages)} 条未读消息")
# 将所有未读消息标记为已读并移动到历史记录
for msg in unread_messages[:]: # 使用切片复制避免迭代时修改列表
try:
@@ -224,4 +222,4 @@ class MessageManager:
# 创建全局消息管理器实例
message_manager = MessageManager()
message_manager = MessageManager()

View File

@@ -17,6 +17,7 @@ from src.plugin_system.core import component_registry, event_manager, global_ann
from src.plugin_system.base import BaseCommand, EventType
from src.mais4u.mais4u_chat.s4u_msg_processor import S4UMessageProcessor
from src.chat.utils.utils import is_mentioned_bot_in_message
# 导入反注入系统
from src.chat.antipromptinjector import initialize_anti_injector
@@ -515,7 +516,7 @@ class ChatBot:
chat_info_user_id=message.chat_stream.user_info.user_id,
chat_info_user_nickname=message.chat_stream.user_info.user_nickname,
chat_info_user_cardname=message.chat_stream.user_info.user_cardname,
chat_info_user_platform=message.chat_stream.user_info.platform
chat_info_user_platform=message.chat_stream.user_info.platform,
)
# 如果是群聊,添加群组信息

View File

@@ -84,7 +84,9 @@ class ChatStream:
self.saved = False
self.context: ChatMessageContext = None # type: ignore # 用于存储该聊天的上下文信息
# 从配置文件中读取focus_value如果没有则使用默认值1.0
self.focus_energy = data.get("focus_energy", global_config.chat.focus_value) if data else global_config.chat.focus_value
self.focus_energy = (
data.get("focus_energy", global_config.chat.focus_value) if data else global_config.chat.focus_value
)
self.no_reply_consecutive = 0
self.breaking_accumulated_interest = 0.0

View File

@@ -165,10 +165,15 @@ class ActionManager:
# 通过chat_id获取chat_stream
chat_manager = get_chat_manager()
chat_stream = chat_manager.get_stream(chat_id)
if not chat_stream:
logger.error(f"{log_prefix} 无法找到chat_id对应的chat_stream: {chat_id}")
return {"action_type": action_name, "success": False, "reply_text": "", "error": "chat_stream not found"}
return {
"action_type": action_name,
"success": False,
"reply_text": "",
"error": "chat_stream not found",
}
if action_name == "no_action":
return {"action_type": "no_action", "success": True, "reply_text": "", "command": ""}
@@ -177,7 +182,7 @@ class ActionManager:
# 直接处理no_reply逻辑不再通过动作系统
reason = reasoning or "选择不回复"
logger.info(f"{log_prefix} 选择不回复,原因: {reason}")
# 存储no_reply信息到数据库
await database_api.store_action_info(
chat_stream=chat_stream,
@@ -396,7 +401,7 @@ class ActionManager:
}
return loop_info, reply_text, cycle_timers
async def send_response(self, chat_stream, reply_set, thinking_start_time, message_data) -> str:
"""
发送回复内容的具体实现
@@ -471,4 +476,4 @@ class ActionManager:
typing=True,
)
return reply_text
return reply_text

View File

@@ -1,6 +1,7 @@
"""
PlanGenerator: 负责搜集和汇总所有决策所需的信息,生成一个未经筛选的“原始计划” (Plan)。
"""
import time
from typing import Dict
@@ -35,6 +36,7 @@ class PlanGenerator:
chat_id (str): 当前聊天的 ID。
"""
from src.chat.planner_actions.action_manager import ActionManager
self.chat_id = chat_id
# 注意ActionManager 可能需要根据实际情况初始化
self.action_manager = ActionManager()
@@ -51,8 +53,8 @@ class PlanGenerator:
Returns:
Plan: 一个填充了初始上下文信息的 Plan 对象。
"""
_is_group_chat, chat_target_info_dict = await get_chat_type_and_target_info(self.chat_id)
_is_group_chat, chat_target_info_dict = get_chat_type_and_target_info(self.chat_id)
target_info = None
if chat_target_info_dict:
target_info = TargetPersonInfo(**chat_target_info_dict)
@@ -65,7 +67,6 @@ class PlanGenerator:
)
chat_history = [DatabaseMessages(**msg) for msg in await chat_history_raw]
plan = Plan(
chat_id=self.chat_id,
mode=mode,
@@ -86,10 +87,10 @@ class PlanGenerator:
Dict[str, "ActionInfo"]: 一个字典,键是动作名称,值是 ActionInfo 对象。
"""
current_available_actions_dict = self.action_manager.get_using_actions()
all_registered_actions: Dict[str, ActionInfo] = component_registry.get_components_by_type( # type: ignore
all_registered_actions: Dict[str, ActionInfo] = component_registry.get_components_by_type( # type: ignore
ComponentType.ACTION
)
current_available_actions = {}
for action_name in current_available_actions_dict:
if action_name in all_registered_actions:
@@ -99,16 +100,13 @@ class PlanGenerator:
name="reply",
component_type=ComponentType.ACTION,
description="系统级动作:选择回复消息的决策",
action_parameters={
"content": "回复的文本内容",
"reply_to_message_id": "要回复的消息ID"
},
action_parameters={"content": "回复的文本内容", "reply_to_message_id": "要回复的消息ID"},
action_require=[
"你想要闲聊或者随便附和",
"当用户提到你或艾特你时",
"当需要回答用户的问题时",
"当你想参与对话时",
"当用户分享有趣的内容时"
"当用户分享有趣的内容时",
],
activation_type=ActionActivationType.ALWAYS,
activation_keywords=[],
@@ -131,4 +129,4 @@ class PlanGenerator:
)
current_available_actions["no_reply"] = no_reply_info
current_available_actions["reply"] = reply_info
return current_available_actions
return current_available_actions

View File

@@ -109,9 +109,7 @@ class ActionPlanner:
self.planner_stats["failed_plans"] += 1
return [], None
async def _enhanced_plan_flow(
self, mode: ChatMode, context: StreamContext
) -> Tuple[List[Dict], Optional[Dict]]:
async def _enhanced_plan_flow(self, mode: ChatMode, context: StreamContext) -> Tuple[List[Dict], Optional[Dict]]:
"""执行增强版规划流程"""
try:
# 1. 生成初始 Plan
@@ -137,7 +135,9 @@ class ActionPlanner:
# 检查兴趣度是否达到非回复动作阈值
non_reply_action_interest_threshold = global_config.affinity_flow.non_reply_action_interest_threshold
if score < non_reply_action_interest_threshold:
logger.info(f"❌ 兴趣度不足非回复动作阈值: {score:.3f} < {non_reply_action_interest_threshold:.3f}直接返回no_action")
logger.info(
f"❌ 兴趣度不足非回复动作阈值: {score:.3f} < {non_reply_action_interest_threshold:.3f}直接返回no_action"
)
logger.info(f"📊 最低要求: {non_reply_action_interest_threshold:.3f}")
# 直接返回 no_action
from src.common.data_models.info_data_model import ActionPlannerInfo

View File

@@ -326,7 +326,7 @@ class DefaultReplyer:
"model": model_name,
"tool_calls": tool_call,
}
# 触发 AFTER_LLM 事件
if not from_plugin:
result = await event_manager.trigger_event(
@@ -618,6 +618,7 @@ class DefaultReplyer:
def _parse_reply_target(target_message: str) -> Tuple[str, str]:
"""解析回复目标消息 - 使用共享工具"""
from src.chat.utils.prompt import Prompt
if target_message is None:
logger.warning("target_message为None返回默认值")
return "未知用户", "(无消息内容)"
@@ -726,22 +727,24 @@ class DefaultReplyer:
unread_history_prompt = ""
if unread_messages:
# 尝试获取兴趣度评分
interest_scores = await self._get_interest_scores_for_messages([msg.flatten() for msg in unread_messages])
interest_scores = await self._get_interest_scores_for_messages(
[msg.flatten() for msg in unread_messages]
)
unread_lines = []
for msg in unread_messages:
msg_id = msg.message_id
msg_time = time.strftime('%H:%M:%S', time.localtime(msg.time))
msg_time = time.strftime("%H:%M:%S", time.localtime(msg.time))
msg_content = msg.processed_plain_text
# 使用与已读历史消息相同的方法获取用户名
from src.person_info.person_info import PersonInfoManager, get_person_info_manager
# 获取用户信息
user_info = getattr(msg, 'user_info', {})
platform = getattr(user_info, 'platform', '') or getattr(msg, 'platform', '')
user_id = getattr(user_info, 'user_id', '') or getattr(msg, 'user_id', '')
user_info = getattr(msg, "user_info", {})
platform = getattr(user_info, "platform", "") or getattr(msg, "platform", "")
user_id = getattr(user_info, "user_id", "") or getattr(msg, "user_id", "")
# 获取用户名
if platform and user_id:
person_id = PersonInfoManager.get_person_id(platform, user_id)
@@ -749,11 +752,11 @@ class DefaultReplyer:
sender_name = person_info_manager.get_value_sync(person_id, "person_name") or "未知用户"
else:
sender_name = "未知用户"
# 添加兴趣度信息
interest_score = interest_scores.get(msg_id, 0.0)
interest_text = f" [兴趣度: {interest_score:.3f}]" if interest_score > 0 else ""
unread_lines.append(f"{msg_time} {sender_name}: {msg_content}{interest_text}")
unread_history_prompt_str = "\n".join(unread_lines)
@@ -830,17 +833,17 @@ class DefaultReplyer:
unread_lines = []
for msg in unread_messages:
msg_id = msg.get("message_id", "")
msg_time = time.strftime('%H:%M:%S', time.localtime(msg.get("time", time.time())))
msg_time = time.strftime("%H:%M:%S", time.localtime(msg.get("time", time.time())))
msg_content = msg.get("processed_plain_text", "")
# 使用与已读历史消息相同的方法获取用户名
from src.person_info.person_info import PersonInfoManager, get_person_info_manager
# 获取用户信息
user_info = msg.get("user_info", {})
platform = user_info.get("platform") or msg.get("platform", "")
user_id = user_info.get("user_id") or msg.get("user_id", "")
# 获取用户名
if platform and user_id:
person_id = PersonInfoManager.get_person_id(platform, user_id)
@@ -856,7 +859,9 @@ class DefaultReplyer:
unread_lines.append(f"{msg_time} {sender_name}: {msg_content}{interest_text}")
unread_history_prompt_str = "\n".join(unread_lines)
unread_history_prompt = f"这是未读历史消息,包含兴趣度评分,请优先对兴趣值高的消息做出动作:\n{unread_history_prompt_str}"
unread_history_prompt = (
f"这是未读历史消息,包含兴趣度评分,请优先对兴趣值高的消息做出动作:\n{unread_history_prompt_str}"
)
else:
unread_history_prompt = "暂无未读历史消息"
@@ -1042,7 +1047,7 @@ class DefaultReplyer:
reply_message.get("user_id"), # type: ignore
)
person_name = await person_info_manager.get_value(person_id, "person_name")
# 如果person_name为None使用fallback值
if person_name is None:
# 尝试从reply_message获取用户名
@@ -1050,13 +1055,14 @@ class DefaultReplyer:
logger.warning(f"未知用户,将存储用户信息:{fallback_name}")
person_name = str(fallback_name)
person_info_manager.set_value(person_id, "person_name", fallback_name)
# 检查是否是bot自己的名字如果是则替换为"(你)"
bot_user_id = str(global_config.bot.qq_account)
# 优先使用传入的 reply_message 如果它不是 bot
candidate_msg = None
if reply_message and str(reply_message.get("user_id")) != bot_user_id:
candidate_msg = reply_message
current_user_id = person_info_manager.get_value_sync(person_id, "user_id")
current_platform = reply_message.get("chat_info_platform")
if current_user_id == bot_user_id and current_platform == global_config.bot.platform:
sender = f"{person_name}(你)"
else:
try:
recent_msgs = await get_raw_msg_before_timestamp_with_chat(
@@ -1129,8 +1135,9 @@ class DefaultReplyer:
target_user_info = None
if sender:
target_user_info = await person_info_manager.get_person_info_by_name(sender)
from src.chat.utils.prompt import Prompt
# 并行执行六个构建任务
task_results = await asyncio.gather(
self._time_and_run_task(
@@ -1207,6 +1214,7 @@ class DefaultReplyer:
schedule_block = ""
if global_config.planning_system.schedule_enable:
from src.schedule.schedule_manager import schedule_manager
current_activity = schedule_manager.get_current_activity()
if current_activity:
schedule_block = f"你当前正在:{current_activity}"
@@ -1219,7 +1227,7 @@ class DefaultReplyer:
safety_guidelines = global_config.personality.safety_guidelines
safety_guidelines_block = ""
if safety_guidelines:
guidelines_text = "\n".join(f"{i+1}. {line}" for i, line in enumerate(safety_guidelines))
guidelines_text = "\n".join(f"{i + 1}. {line}" for i, line in enumerate(safety_guidelines))
safety_guidelines_block = f"""### 安全与互动底线
在任何情况下,你都必须遵守以下由你的设定者为你定义的原则:
{guidelines_text}
@@ -1314,7 +1322,7 @@ class DefaultReplyer:
template_name = "normal_style_prompt"
elif current_prompt_mode == "minimal":
template_name = "default_expressor_prompt"
# 获取模板内容
template_prompt = await global_prompt_manager.get_prompt_async(template_name)
prompt = Prompt(template=template_prompt.template, parameters=prompt_parameters)
@@ -1594,19 +1602,19 @@ class DefaultReplyer:
# 使用AFC关系追踪器获取关系信息
try:
from src.chat.affinity_flow.relationship_integration import get_relationship_tracker
relationship_tracker = get_relationship_tracker()
if relationship_tracker:
# 获取用户信息以获取真实的user_id
user_info = await person_info_manager.get_values(person_id, ["user_id", "platform"])
user_id = user_info.get("user_id", "unknown")
# 从数据库获取关系数据
relationship_data = relationship_tracker._get_user_relationship_from_db(user_id)
if relationship_data:
relationship_text = relationship_data.get("relationship_text", "")
relationship_score = relationship_data.get("relationship_score", 0.3)
# 构建丰富的关系信息描述
if relationship_text:
# 转换关系分数为描述性文本
@@ -1620,7 +1628,7 @@ class DefaultReplyer:
relationship_level = "认识的人"
else:
relationship_level = "陌生人"
return f"你与{sender}的关系:{relationship_level}(关系分:{relationship_score:.2f}/1.0)。{relationship_text}"
else:
return f"你与{sender}是初次见面,关系分:{relationship_score:.2f}/1.0。"
@@ -1629,7 +1637,7 @@ class DefaultReplyer:
else:
logger.warning("AFC关系追踪器未初始化使用默认关系信息")
return f"你与{sender}是普通朋友关系。"
except Exception as e:
logger.error(f"获取AFC关系信息失败: {e}")
return f"你与{sender}是普通朋友关系。"

View File

@@ -37,7 +37,7 @@ def replace_user_references_sync(
"""
if not content:
return ""
if name_resolver is None:
def default_resolver(platform: str, user_id: str) -> str:
# 检查是否是机器人自己
@@ -828,8 +828,8 @@ async def build_pic_mapping_info(pic_id_mapping: Dict[str, str]) -> str:
# 从数据库中获取图片描述
description = "[图片内容未知]" # 默认描述
try:
async with get_db_session() as session:
image = (await session.execute(select(Images).where(Images.image_id == pic_id))).scalar_one_or_none()
with get_db_session() as session:
image = session.execute(select(Images).where(Images.image_id == pic_id)).scalar_one_or_none()
if image and image.description: # type: ignore
description = image.description
except Exception:

View File

@@ -25,7 +25,7 @@ logger = get_logger("unified_prompt")
@dataclass
class PromptParameters:
"""统一提示词参数系统"""
# 基础参数
chat_id: str = ""
is_group_chat: bool = False
@@ -34,7 +34,7 @@ class PromptParameters:
reply_to: str = ""
extra_info: str = ""
prompt_mode: Literal["s4u", "normal", "minimal"] = "s4u"
# 功能开关
enable_tool: bool = True
enable_memory: bool = True
@@ -42,20 +42,20 @@ class PromptParameters:
enable_relation: bool = True
enable_cross_context: bool = True
enable_knowledge: bool = True
# 性能控制
max_context_messages: int = 50
# 调试选项
debug_mode: bool = False
# 聊天历史和上下文
chat_target_info: Optional[Dict[str, Any]] = None
message_list_before_now_long: List[Dict[str, Any]] = field(default_factory=list)
message_list_before_short: List[Dict[str, Any]] = field(default_factory=list)
chat_talking_prompt_short: str = ""
target_user_info: Optional[Dict[str, Any]] = None
# 已构建的内容块
expression_habits_block: str = ""
relation_info_block: str = ""
@@ -63,7 +63,7 @@ class PromptParameters:
tool_info_block: str = ""
knowledge_prompt: str = ""
cross_context_block: str = ""
# 其他内容块
keywords_reaction_prompt: str = ""
extra_info_block: str = ""
@@ -75,11 +75,10 @@ class PromptParameters:
reply_target_block: str = ""
mood_prompt: str = ""
action_descriptions: str = ""
# 可用动作信息
available_actions: Optional[Dict[str, Any]] = None
read_mark: float = 0.0
def validate(self) -> List[str]:
"""参数验证"""
errors = []
@@ -94,22 +93,22 @@ class PromptParameters:
class PromptContext:
"""提示词上下文管理器"""
def __init__(self):
self._context_prompts: Dict[str, Dict[str, "Prompt"]] = {}
self._current_context_var = contextvars.ContextVar("current_context", default=None)
self._context_lock = asyncio.Lock()
@property
def _current_context(self) -> Optional[str]:
"""获取当前协程的上下文ID"""
return self._current_context_var.get()
@_current_context.setter
def _current_context(self, value: Optional[str]):
"""设置当前协程的上下文ID"""
self._current_context_var.set(value) # type: ignore
@asynccontextmanager
async def async_scope(self, context_id: Optional[str] = None):
"""创建一个异步的临时提示模板作用域"""
@@ -124,13 +123,13 @@ class PromptContext:
except asyncio.TimeoutError:
logger.warning(f"获取上下文锁超时context_id: {context_id}")
context_id = None
previous_context = self._current_context
token = self._current_context_var.set(context_id) if context_id else None
else:
previous_context = self._current_context
token = None
try:
yield self
finally:
@@ -143,7 +142,7 @@ class PromptContext:
self._current_context = previous_context
except Exception:
...
async def get_prompt_async(self, name: str) -> Optional["Prompt"]:
"""异步获取当前作用域中的提示模板"""
async with self._context_lock:
@@ -156,7 +155,7 @@ class PromptContext:
):
return self._context_prompts[current_context][name]
return None
async def register_async(self, prompt: "Prompt", context_id: Optional[str] = None) -> None:
"""异步注册提示模板到指定作用域"""
async with self._context_lock:
@@ -167,49 +166,49 @@ class PromptContext:
class PromptManager:
"""统一提示词管理器"""
def __init__(self):
self._prompts = {}
self._counter = 0
self._context = PromptContext()
self._lock = asyncio.Lock()
@asynccontextmanager
async def async_message_scope(self, message_id: Optional[str] = None):
"""为消息处理创建异步临时作用域"""
async with self._context.async_scope(message_id):
yield self
async def get_prompt_async(self, name: str) -> "Prompt":
"""异步获取提示模板"""
context_prompt = await self._context.get_prompt_async(name)
if context_prompt is not None:
logger.debug(f"从上下文中获取提示词: {name} {context_prompt}")
return context_prompt
async with self._lock:
if name not in self._prompts:
raise KeyError(f"Prompt '{name}' not found")
return self._prompts[name]
def generate_name(self, template: str) -> str:
"""为未命名的prompt生成名称"""
self._counter += 1
return f"prompt_{self._counter}"
def register(self, prompt: "Prompt") -> None:
"""注册一个prompt"""
if not prompt.name:
prompt.name = self.generate_name(prompt.template)
self._prompts[prompt.name] = prompt
def add_prompt(self, name: str, fstr: str) -> "Prompt":
"""添加新提示模板"""
prompt = Prompt(fstr, name=name)
if prompt.name:
self._prompts[prompt.name] = prompt
return prompt
async def format_prompt(self, name: str, **kwargs) -> str:
"""格式化提示模板"""
prompt = await self.get_prompt_async(name)
@@ -230,21 +229,21 @@ class Prompt:
统一提示词类 - 合并模板管理和智能构建功能
真正的Prompt类支持模板管理和智能上下文构建
"""
# 临时标记,作为类常量
_TEMP_LEFT_BRACE = "__ESCAPED_LEFT_BRACE__"
_TEMP_RIGHT_BRACE = "__ESCAPED_RIGHT_BRACE__"
def __init__(
self,
template: str,
name: Optional[str] = None,
parameters: Optional[PromptParameters] = None,
should_register: bool = True
should_register: bool = True,
):
"""
初始化统一提示词
Args:
template: 提示词模板字符串
name: 提示词名称
@@ -256,14 +255,14 @@ class Prompt:
self.parameters = parameters or PromptParameters()
self.args = self._parse_template_args(template)
self._formatted_result = ""
# 预处理模板中的转义花括号
self._processed_template = self._process_escaped_braces(template)
# 自动注册
if should_register and not global_prompt_manager.context._current_context:
global_prompt_manager.register(self)
@staticmethod
def _process_escaped_braces(template) -> str:
"""处理模板中的转义花括号"""
@@ -271,14 +270,14 @@ class Prompt:
template = "\n".join(str(item) for item in template)
elif not isinstance(template, str):
template = str(template)
return template.replace("\\{", Prompt._TEMP_LEFT_BRACE).replace("\\}", Prompt._TEMP_RIGHT_BRACE)
@staticmethod
def _restore_escaped_braces(template: str) -> str:
"""将临时标记还原为实际的花括号字符"""
return template.replace(Prompt._TEMP_LEFT_BRACE, "{").replace(Prompt._TEMP_RIGHT_BRACE, "}")
def _parse_template_args(self, template: str) -> List[str]:
"""解析模板参数"""
template_args = []
@@ -288,11 +287,11 @@ class Prompt:
if expr and expr not in template_args:
template_args.append(expr)
return template_args
async def build(self) -> str:
"""
构建完整的提示词,包含智能上下文
Returns:
str: 构建完成的提示词文本
"""
@@ -301,38 +300,38 @@ class Prompt:
if errors:
logger.error(f"参数验证失败: {', '.join(errors)}")
raise ValueError(f"参数验证失败: {', '.join(errors)}")
start_time = time.time()
try:
# 构建上下文数据
context_data = await self._build_context_data()
# 格式化模板
result = await self._format_with_context(context_data)
total_time = time.time() - start_time
logger.debug(f"Prompt构建完成模式: {self.parameters.prompt_mode}, 耗时: {total_time:.2f}s")
self._formatted_result = result
return result
except asyncio.TimeoutError as e:
logger.error(f"构建Prompt超时: {e}")
raise TimeoutError(f"构建Prompt超时: {e}") from e
except Exception as e:
logger.error(f"构建Prompt失败: {e}")
raise RuntimeError(f"构建Prompt失败: {e}") from e
async def _build_context_data(self) -> Dict[str, Any]:
"""构建智能上下文数据"""
# 并行执行所有构建任务
start_time = time.time()
try:
# 准备构建任务
tasks = []
task_names = []
# 初始化预构建参数
pre_built_params = {}
if self.parameters.expression_habits_block:
@@ -347,32 +346,32 @@ class Prompt:
pre_built_params["knowledge_prompt"] = self.parameters.knowledge_prompt
if self.parameters.cross_context_block:
pre_built_params["cross_context_block"] = self.parameters.cross_context_block
# 根据参数确定要构建的项
if self.parameters.enable_expression and not pre_built_params.get("expression_habits_block"):
tasks.append(self._build_expression_habits())
task_names.append("expression_habits")
if self.parameters.enable_memory and not pre_built_params.get("memory_block"):
tasks.append(self._build_memory_block())
task_names.append("memory_block")
if self.parameters.enable_relation and not pre_built_params.get("relation_info_block"):
tasks.append(self._build_relation_info())
task_names.append("relation_info")
if self.parameters.enable_tool and not pre_built_params.get("tool_info_block"):
tasks.append(self._build_tool_info())
task_names.append("tool_info")
if self.parameters.enable_knowledge and not pre_built_params.get("knowledge_prompt"):
tasks.append(self._build_knowledge_info())
task_names.append("knowledge_info")
if self.parameters.enable_cross_context and not pre_built_params.get("cross_context_block"):
tasks.append(self._build_cross_context())
task_names.append("cross_context")
# 性能优化
base_timeout = 20.0
task_timeout = 2.0
@@ -380,13 +379,13 @@ class Prompt:
max(base_timeout, len(tasks) * task_timeout),
30.0,
)
max_concurrent_tasks = 5
if len(tasks) > max_concurrent_tasks:
results = []
for i in range(0, len(tasks), max_concurrent_tasks):
batch_tasks = tasks[i : i + max_concurrent_tasks]
batch_results = await asyncio.wait_for(
asyncio.gather(*batch_tasks, return_exceptions=True), timeout=timeout_seconds
)
@@ -395,53 +394,55 @@ class Prompt:
results = await asyncio.wait_for(
asyncio.gather(*tasks, return_exceptions=True), timeout=timeout_seconds
)
# 处理结果
context_data = {}
for i, result in enumerate(results):
task_name = task_names[i] if i < len(task_names) else f"task_{i}"
if isinstance(result, Exception):
logger.error(f"构建任务{task_name}失败: {str(result)}")
elif isinstance(result, dict):
context_data.update(result)
# 添加预构建的参数
for key, value in pre_built_params.items():
if value:
context_data[key] = value
except asyncio.TimeoutError:
logger.error(f"构建超时 ({timeout_seconds}s)")
context_data = {}
for key, value in pre_built_params.items():
if value:
context_data[key] = value
# 构建聊天历史
if self.parameters.prompt_mode == "s4u":
await self._build_s4u_chat_context(context_data)
else:
await self._build_normal_chat_context(context_data)
# 补充基础信息
context_data.update({
"keywords_reaction_prompt": self.parameters.keywords_reaction_prompt,
"extra_info_block": self.parameters.extra_info_block,
"time_block": self.parameters.time_block or f"当前时间:{time.strftime('%Y-%m-%d %H:%M:%S')}",
"identity": self.parameters.identity_block,
"schedule_block": self.parameters.schedule_block,
"moderation_prompt": self.parameters.moderation_prompt_block,
"reply_target_block": self.parameters.reply_target_block,
"mood_state": self.parameters.mood_prompt,
"action_descriptions": self.parameters.action_descriptions,
})
context_data.update(
{
"keywords_reaction_prompt": self.parameters.keywords_reaction_prompt,
"extra_info_block": self.parameters.extra_info_block,
"time_block": self.parameters.time_block or f"当前时间:{time.strftime('%Y-%m-%d %H:%M:%S')}",
"identity": self.parameters.identity_block,
"schedule_block": self.parameters.schedule_block,
"moderation_prompt": self.parameters.moderation_prompt_block,
"reply_target_block": self.parameters.reply_target_block,
"mood_state": self.parameters.mood_prompt,
"action_descriptions": self.parameters.action_descriptions,
}
)
total_time = time.time() - start_time
logger.debug(f"上下文构建完成,总耗时: {total_time:.2f}s")
return context_data
async def _build_s4u_chat_context(self, context_data: Dict[str, Any]) -> None:
"""构建S4U模式的聊天上下文"""
if not self.parameters.message_list_before_now_long:
@@ -451,21 +452,20 @@ class Prompt:
self.parameters.message_list_before_now_long,
self.parameters.target_user_info.get("user_id") if self.parameters.target_user_info else "",
self.parameters.sender,
self.parameters.chat_id
self.parameters.chat_id,
)
context_data["read_history_prompt"] = read_history_prompt
context_data["unread_history_prompt"] = unread_history_prompt
async def _build_normal_chat_context(self, context_data: Dict[str, Any]) -> None:
"""构建normal模式的聊天上下文"""
if not self.parameters.chat_talking_prompt_short:
return
context_data["chat_info"] = f"""群里的聊天内容:
{self.parameters.chat_talking_prompt_short}"""
@staticmethod
async def _build_s4u_chat_history_prompts(
self, message_list_before_now: List[Dict[str, Any]], target_user_id: str, sender: str, chat_id: str
) -> Tuple[str, str]:
@@ -482,97 +482,92 @@ class Prompt:
except Exception as e:
logger.error(f"构建S4U历史消息prompt失败: {e}")
async def _build_expression_habits(self) -> Dict[str, Any]:
"""构建表达习惯"""
if not global_config.expression.enable_expression:
return {"expression_habits_block": ""}
try:
from src.chat.express.expression_selector import ExpressionSelector
# 获取聊天历史用于表情选择
chat_history = ""
if self.parameters.message_list_before_now_long:
recent_messages = self.parameters.message_list_before_now_long[-10:]
chat_history = await build_readable_messages(
recent_messages,
replace_bot_name=True,
timestamp_mode="normal",
truncate=True
chat_history = build_readable_messages(
recent_messages, replace_bot_name=True, timestamp_mode="normal", truncate=True
)
# 创建表情选择器
expression_selector = ExpressionSelector()
expression_selector = ExpressionSelector(self.parameters.chat_id)
# 选择合适的表情
selected_expressions = await expression_selector.select_suitable_expressions_llm(
chat_history=chat_history,
current_message=self.parameters.target,
emotional_tone="neutral",
topic_type="general",
)
# 构建表达习惯块
if selected_expressions:
style_habits_str = "\n".join([f"- {expr}" for expr in selected_expressions])
expression_habits_block = f"- 你可以参考以下的语言习惯,当情景合适就使用,但不要生硬使用,以合理的方式结合到你的回复中:\n{style_habits_str}"
else:
expression_habits_block = ""
return {"expression_habits_block": expression_habits_block}
except Exception as e:
logger.error(f"构建表达习惯失败: {e}")
return {"expression_habits_block": ""}
async def _build_memory_block(self) -> Dict[str, Any]:
"""构建记忆块"""
if not global_config.memory.enable_memory:
return {"memory_block": ""}
try:
from src.chat.memory_system.memory_activator import MemoryActivator
from src.chat.memory_system.async_instant_memory_wrapper import get_async_instant_memory
# 获取聊天历史
chat_history = ""
if self.parameters.message_list_before_now_long:
recent_messages = self.parameters.message_list_before_now_long[-20:]
chat_history = await build_readable_messages(
recent_messages,
replace_bot_name=True,
timestamp_mode="normal",
truncate=True
chat_history = build_readable_messages(
recent_messages, replace_bot_name=True, timestamp_mode="normal", truncate=True
)
# 激活长期记忆
memory_activator = MemoryActivator()
running_memories = await memory_activator.activate_memory_with_chat_history(
target_message=self.parameters.target,
chat_history_prompt=chat_history
target_message=self.parameters.target, chat_history_prompt=chat_history
)
# 获取即时记忆
async_memory_wrapper = get_async_instant_memory(self.parameters.chat_id)
instant_memory = await async_memory_wrapper.get_memory_with_fallback(self.parameters.target)
# 构建记忆块
memory_parts = []
if running_memories:
memory_parts.append("以下是当前在聊天中,你回忆起的记忆:")
for memory in running_memories:
memory_parts.append(f"- {memory['content']}")
if instant_memory:
memory_parts.append(f"- {instant_memory}")
memory_block = "\n".join(memory_parts) if memory_parts else ""
return {"memory_block": memory_block}
except Exception as e:
logger.error(f"构建记忆块失败: {e}")
return {"memory_block": ""}
async def _build_relation_info(self) -> Dict[str, Any]:
"""构建关系信息"""
try:
@@ -581,106 +576,104 @@ class Prompt:
except Exception as e:
logger.error(f"构建关系信息失败: {e}")
return {"relation_info_block": ""}
async def _build_tool_info(self) -> Dict[str, Any]:
"""构建工具信息"""
if not global_config.tool.enable_tool:
return {"tool_info_block": ""}
try:
from src.plugin_system.core.tool_use import ToolExecutor
# 获取聊天历史
chat_history = ""
if self.parameters.message_list_before_now_long:
recent_messages = self.parameters.message_list_before_now_long[-15:]
chat_history = await build_readable_messages(
recent_messages,
replace_bot_name=True,
timestamp_mode="normal",
truncate=True
chat_history = build_readable_messages(
recent_messages, replace_bot_name=True, timestamp_mode="normal", truncate=True
)
# 创建工具执行器
tool_executor = ToolExecutor(chat_id=self.parameters.chat_id)
# 执行工具获取信息
tool_results, _, _ = await tool_executor.execute_from_chat_message(
sender=self.parameters.sender,
target_message=self.parameters.target,
chat_history=chat_history,
return_details=False
return_details=False,
)
# 构建工具信息块
if tool_results:
tool_info_parts = ["## 工具信息","以下是你通过工具获取到的实时信息:"]
tool_info_parts = ["## 工具信息", "以下是你通过工具获取到的实时信息:"]
for tool_result in tool_results:
tool_name = tool_result.get("tool_name", "unknown")
content = tool_result.get("content", "")
result_type = tool_result.get("type", "tool_result")
tool_info_parts.append(f"- 【{tool_name}{result_type}: {content}")
tool_info_parts.append("以上是你获取到的实时信息,请在回复时参考这些信息。")
tool_info_block = "\n".join(tool_info_parts)
else:
tool_info_block = ""
return {"tool_info_block": tool_info_block}
except Exception as e:
logger.error(f"构建工具信息失败: {e}")
return {"tool_info_block": ""}
async def _build_knowledge_info(self) -> Dict[str, Any]:
"""构建知识信息"""
if not global_config.lpmm_knowledge.enable:
return {"knowledge_prompt": ""}
try:
from src.chat.knowledge.knowledge_lib import qa_manager
from src.chat.knowledge.knowledge_lib import QAManager
# 获取问题文本(当前消息)
question = self.parameters.target or ""
if not question:
return {"knowledge_prompt": ""}
# 检查QA管理器是否已成功初始化
if not qa_manager:
logger.warning("QA管理器未初始化 (可能lpmm_knowledge被禁用),跳过知识库搜索。")
return {"knowledge_prompt": ""}
# 创建QA管理器
qa_manager = QAManager()
# 搜索相关知识
knowledge_results = await qa_manager.get_knowledge(
question=question
question=question, chat_id=self.parameters.chat_id, max_results=5, min_similarity=0.5
)
# 构建知识块
if knowledge_results and knowledge_results.get("knowledge_items"):
knowledge_parts = ["## 知识库信息","以下是与你当前对话相关的知识信息:"]
knowledge_parts = ["## 知识库信息", "以下是与你当前对话相关的知识信息:"]
for item in knowledge_results["knowledge_items"]:
content = item.get("content", "")
source = item.get("source", "")
relevance = item.get("relevance", 0.0)
if content:
knowledge_parts.append(f"- [相关度: {relevance}] {content}")
if summary := knowledge_results.get("summary"):
knowledge_parts.append(f"\n知识总结: {summary}")
if source:
knowledge_parts.append(f"- [{relevance:.2f}] {content} (来源: {source})")
else:
knowledge_parts.append(f"- [{relevance:.2f}] {content}")
if knowledge_results.get("summary"):
knowledge_parts.append(f"\n知识总结: {knowledge_results['summary']}")
knowledge_prompt = "\n".join(knowledge_parts)
else:
knowledge_prompt = ""
return {"knowledge_prompt": knowledge_prompt}
except Exception as e:
logger.error(f"构建知识信息失败: {e}")
return {"knowledge_prompt": ""}
async def _build_cross_context(self) -> Dict[str, Any]:
"""构建跨群上下文"""
try:
@@ -691,7 +684,7 @@ class Prompt:
except Exception as e:
logger.error(f"构建跨群上下文失败: {e}")
return {"cross_context_block": ""}
async def _format_with_context(self, context_data: Dict[str, Any]) -> str:
"""使用上下文数据格式化模板"""
if self.parameters.prompt_mode == "s4u":
@@ -700,9 +693,9 @@ class Prompt:
params = self._prepare_normal_params(context_data)
else:
params = self._prepare_default_params(context_data)
return await global_prompt_manager.format_prompt(self.name, **params) if self.name else self.format(**params)
def _prepare_s4u_params(self, context_data: Dict[str, Any]) -> Dict[str, Any]:
"""准备S4U模式的参数"""
return {
@@ -723,12 +716,13 @@ class Prompt:
"time_block": context_data.get("time_block", ""),
"reply_target_block": context_data.get("reply_target_block", ""),
"reply_style": global_config.personality.reply_style,
"keywords_reaction_prompt": self.parameters.keywords_reaction_prompt or context_data.get("keywords_reaction_prompt", ""),
"keywords_reaction_prompt": self.parameters.keywords_reaction_prompt
or context_data.get("keywords_reaction_prompt", ""),
"moderation_prompt": self.parameters.moderation_prompt_block or context_data.get("moderation_prompt", ""),
"safety_guidelines_block": self.parameters.safety_guidelines_block or context_data.get("safety_guidelines_block", ""),
"chat_context_type": "群聊" if self.parameters.is_group_chat else "私聊",
"safety_guidelines_block": self.parameters.safety_guidelines_block
or context_data.get("safety_guidelines_block", ""),
}
def _prepare_normal_params(self, context_data: Dict[str, Any]) -> Dict[str, Any]:
"""准备Normal模式的参数"""
return {
@@ -748,11 +742,13 @@ class Prompt:
"reply_target_block": context_data.get("reply_target_block", ""),
"config_expression_style": global_config.personality.reply_style,
"mood_state": self.parameters.mood_prompt or context_data.get("mood_state", ""),
"keywords_reaction_prompt": self.parameters.keywords_reaction_prompt or context_data.get("keywords_reaction_prompt", ""),
"keywords_reaction_prompt": self.parameters.keywords_reaction_prompt
or context_data.get("keywords_reaction_prompt", ""),
"moderation_prompt": self.parameters.moderation_prompt_block or context_data.get("moderation_prompt", ""),
"safety_guidelines_block": self.parameters.safety_guidelines_block or context_data.get("safety_guidelines_block", ""),
"safety_guidelines_block": self.parameters.safety_guidelines_block
or context_data.get("safety_guidelines_block", ""),
}
def _prepare_default_params(self, context_data: Dict[str, Any]) -> Dict[str, Any]:
"""准备默认模式的参数"""
return {
@@ -768,11 +764,13 @@ class Prompt:
"reason": "",
"mood_state": self.parameters.mood_prompt or context_data.get("mood_state", ""),
"reply_style": global_config.personality.reply_style,
"keywords_reaction_prompt": self.parameters.keywords_reaction_prompt or context_data.get("keywords_reaction_prompt", ""),
"keywords_reaction_prompt": self.parameters.keywords_reaction_prompt
or context_data.get("keywords_reaction_prompt", ""),
"moderation_prompt": self.parameters.moderation_prompt_block or context_data.get("moderation_prompt", ""),
"safety_guidelines_block": self.parameters.safety_guidelines_block or context_data.get("safety_guidelines_block", ""),
"safety_guidelines_block": self.parameters.safety_guidelines_block
or context_data.get("safety_guidelines_block", ""),
}
def format(self, *args, **kwargs) -> str:
"""格式化模板,支持位置参数和关键字参数"""
try:
@@ -785,21 +783,21 @@ class Prompt:
processed_template = self._processed_template.format(**formatted_args)
else:
processed_template = self._processed_template
# 再用关键字参数格式化
if kwargs:
processed_template = processed_template.format(**kwargs)
# 将临时标记还原为实际的花括号
result = self._restore_escaped_braces(processed_template)
return result
except (IndexError, KeyError) as e:
raise ValueError(f"格式化模板失败: {self.template}, args={args}, kwargs={kwargs} {str(e)}") from e
def __str__(self) -> str:
"""返回格式化后的结果或原始模板"""
return self._formatted_result if self._formatted_result else self.template
def __repr__(self) -> str:
"""返回提示词的表示形式"""
return f"Prompt(template='{self.template}', name='{self.name}')"
@@ -871,9 +869,7 @@ class Prompt:
return await relationship_fetcher.build_relation_info(person_id, points_num=5)
@staticmethod
async def build_cross_context(
chat_id: str, prompt_mode: str, target_user_info: Optional[Dict[str, Any]]
) -> str:
async def build_cross_context(chat_id: str, prompt_mode: str, target_user_info: Optional[Dict[str, Any]]) -> str:
"""
构建跨群聊上下文 - 统一实现
@@ -889,7 +885,7 @@ class Prompt:
return ""
from src.plugin_system.apis import cross_context_api
other_chat_raw_ids = cross_context_api.get_context_groups(chat_id)
if not other_chat_raw_ids:
return ""
@@ -936,10 +932,7 @@ class Prompt:
# 工厂函数
def create_prompt(
template: str,
name: Optional[str] = None,
parameters: Optional[PromptParameters] = None,
**kwargs
template: str, name: Optional[str] = None, parameters: Optional[PromptParameters] = None, **kwargs
) -> Prompt:
"""快速创建Prompt实例的工厂函数"""
if parameters is None:
@@ -948,14 +941,10 @@ def create_prompt(
async def create_prompt_async(
template: str,
name: Optional[str] = None,
parameters: Optional[PromptParameters] = None,
**kwargs
template: str, name: Optional[str] = None, parameters: Optional[PromptParameters] = None, **kwargs
) -> Prompt:
"""异步创建Prompt实例"""
prompt = create_prompt(template, name, parameters, **kwargs)
if global_prompt_manager.context._current_context:
await global_prompt_manager.context.register_async(prompt)
return prompt

View File

@@ -332,17 +332,17 @@ def process_llm_response(text: str, enable_splitter: bool = True, enable_chinese
if global_config.response_splitter.enable and enable_splitter:
logger.info(f"回复分割器已启用,模式: {global_config.response_splitter.split_mode}")
split_mode = global_config.response_splitter.split_mode
if split_mode == "llm" and "[SPLIT]" in cleaned_text:
logger.debug("检测到 [SPLIT] 标记,使用 LLM 自定义分割。")
split_sentences_raw = cleaned_text.split("[SPLIT]")
split_sentences = [s.strip() for s in split_sentences_raw if s.strip()]
else:
if split_mode == "llm":
logger.debug("未检测到 [SPLIT] 标记,回退到基于标点的传统模式进行分割。")
split_sentences = split_into_sentences_w_remove_punctuation(cleaned_text)
logger.debug("未检测到 [SPLIT] 标记,本次不进行分割。")
split_sentences = [cleaned_text]
else: # mode == "punctuation"
logger.debug("使用基于标点的传统模式进行分割。")
split_sentences = split_into_sentences_w_remove_punctuation(cleaned_text)