Merge branch 'MaiM-with-u:main' into main

Authored by 2829798842 on 2025-04-09 10:04:58 +08:00; committed by GitHub.
57 changed files with 4898 additions and 1572 deletions

View File

@@ -1,6 +1,6 @@
import time
import asyncio
from typing import Optional, Dict, Any, List
from typing import Optional, Dict, Any, List, Tuple
from src.common.logger import get_module_logger
from src.common.database import db
from ..message.message_base import UserInfo
@@ -57,6 +57,35 @@ class ChatObserver:
self._update_event = asyncio.Event() # 触发更新的事件
self._update_complete = asyncio.Event() # 更新完成的事件
def check(self) -> bool:
"""检查距离上一次观察之后是否有了新消息
Returns:
bool: 是否有新消息
"""
logger.debug(f"检查距离上一次观察之后是否有了新消息: {self.last_check_time}")
query = {
"chat_id": self.stream_id,
"time": {"$gt": self.last_check_time}
}
# 只需要查询是否存在,不需要获取具体消息
new_message_exists = db.messages.find_one(query) is not None
if new_message_exists:
logger.debug("发现新消息")
self.last_check_time = time.time()
return new_message_exists
def get_new_message(self) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
"""获取上一次观察的时间点后的新消息,插入到历史记录中,并返回新消息和历史记录两个对象"""
messages = self.get_message_history(self.last_check_time)
for message in messages:
self._add_message_to_history(message)
return messages, self.message_history
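# Editor's note — illustrative polling sketch, not part of this commit.
# Assuming a valid stream_id, a caller could combine the two methods above as:
#
#     observer = ChatObserver.get_instance(stream_id)
#     if observer.check():  # True when messages newer than last_check_time exist
#         new_messages, full_history = observer.get_new_message()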
def new_message_after(self, time_point: float) -> bool:
"""判断是否在指定时间点后有新消息
@@ -66,6 +95,7 @@ class ChatObserver:
Returns:
bool: 是否有新消息
"""
logger.debug(f"判断是否在指定时间点后有新消息: {self.last_message_time} > {time_point}")
return self.last_message_time is None or self.last_message_time > time_point
def _add_message_to_history(self, message: Dict[str, Any]):

View File

@@ -17,7 +17,8 @@ from ..storage.storage import MessageStorage
from .chat_observer import ChatObserver
from .pfc_KnowledgeFetcher import KnowledgeFetcher
from .reply_checker import ReplyChecker
import json
from .pfc_utils import get_items_from_json
from src.individuality.individuality import Individuality
import time
logger = get_module_logger("pfc")
@@ -51,7 +52,7 @@ class ActionPlanner:
max_tokens=1000,
request_type="action_planning"
)
self.personality_info = " ".join(global_config.PROMPT_PERSONALITY)
self.personality_info = Individuality.get_instance().get_prompt(type="personality", x_person=2, level=2)
self.name = global_config.BOT_NICKNAME
self.chat_observer = ChatObserver.get_instance(stream_id)
@@ -67,7 +68,6 @@ class ActionPlanner:
Args:
goal: 对话目标
method: 实现方式
reasoning: 目标原因
action_history: 行动历史记录
@@ -128,43 +128,18 @@ judge_conversation: 判断对话是否结束,当发现对话目标已经达到
content, _ = await self.llm.generate_response_async(prompt)
logger.debug(f"LLM原始返回内容: {content}")
# 清理内容,尝试提取JSON部分
content = content.strip()
try:
# 尝试直接解析
result = json.loads(content)
except json.JSONDecodeError:
# 如果直接解析失败尝试查找和提取JSON部分
import re
json_pattern = r'\{[^{}]*\}'
json_match = re.search(json_pattern, content)
if json_match:
try:
result = json.loads(json_match.group())
except json.JSONDecodeError:
logger.error("提取的JSON内容解析失败返回默认行动")
return "direct_reply", "JSON解析失败选择直接回复"
else:
# 如果找不到JSON尝试从文本中提取行动和原因
if "direct_reply" in content.lower():
return "direct_reply", "从文本中提取的行动"
elif "fetch_knowledge" in content.lower():
return "fetch_knowledge", "从文本中提取的行动"
elif "wait" in content.lower():
return "wait", "从文本中提取的行动"
elif "listening" in content.lower():
return "listening", "从文本中提取的行动"
elif "rethink_goal" in content.lower():
return "rethink_goal", "从文本中提取的行动"
elif "judge_conversation" in content.lower():
return "judge_conversation", "从文本中提取的行动"
else:
logger.error("无法从返回内容中提取行动类型")
return "direct_reply", "无法解析响应,选择直接回复"
# 使用简化函数提取JSON内容
success, result = get_items_from_json(
content,
"action", "reason",
default_values={"action": "direct_reply", "reason": "默认原因"}
)
# 验证JSON字段
action = result.get("action", "direct_reply")
reason = result.get("reason", "默认原因")
if not success:
return "direct_reply", "JSON解析失败选择直接回复"
action = result["action"]
reason = result["reason"]
# 验证action类型
if action not in ["direct_reply", "fetch_knowledge", "wait", "listening", "rethink_goal", "judge_conversation"]:
@@ -191,10 +166,15 @@ class GoalAnalyzer:
request_type="conversation_goal"
)
self.personality_info = " ".join(global_config.PROMPT_PERSONALITY)
self.personality_info = Individuality.get_instance().get_prompt(type="personality", x_person=2, level=2)
self.name = global_config.BOT_NICKNAME
self.nick_name = global_config.BOT_ALIAS_NAMES
self.chat_observer = ChatObserver.get_instance(stream_id)
# 多目标存储结构
self.goals = [] # 存储多个目标
self.max_goals = 3 # 同时保持的最大目标数量
self.current_goal_and_reason = None
async def analyze_goal(self) -> Tuple[str, str, str]:
"""分析对话历史并设定目标
@@ -220,12 +200,29 @@ class GoalAnalyzer:
chat_history_text += f"{time_str},{sender}:{msg.get('processed_plain_text', '')}\n"
personality_text = f"你的名字是{self.name}{self.personality_info}"
# 构建当前已有目标的文本
existing_goals_text = ""
if self.goals:
existing_goals_text = "当前已有的对话目标:\n"
for i, (goal, _, reason) in enumerate(self.goals):
existing_goals_text += f"{i+1}. 目标: {goal}, 原因: {reason}\n"
prompt = f"""{personality_text}。现在你在参与一场QQ聊天请分析以下聊天记录并根据你的性格特征确定个明确的对话目标。
目标应该反映出对话的意图和期望的结果
prompt = f"""{personality_text}。现在你在参与一场QQ聊天请分析以下聊天记录并根据你的性格特征确定个明确的对话目标。
目标应该反映出对话的不同方面和意图。
{existing_goals_text}
聊天记录:
{chat_history_text}
请以JSON格式输出包含以下字段
请分析当前对话并确定最适合的对话目标。你可以:
1. 保持现有目标不变
2. 修改现有目标
3. 添加新目标
4. 删除不再相关的目标
请以JSON格式输出一个当前最主要的对话目标包含以下字段
1. goal: 对话目标(简短的一句话)
2. reasoning: 对话原因,为什么设定这个目标(简要解释)
@@ -239,51 +236,32 @@ class GoalAnalyzer:
content, _ = await self.llm.generate_response_async(prompt)
logger.debug(f"LLM原始返回内容: {content}")
# 清理和验证返回内容
if not content or not isinstance(content, str):
logger.error("LLM返回内容为空或格式不正确")
continue
# 尝试提取JSON部分
content = content.strip()
try:
# 尝试直接解析
result = json.loads(content)
except json.JSONDecodeError:
# 如果直接解析失败尝试查找和提取JSON部分
import re
json_pattern = r'\{[^{}]*\}'
json_match = re.search(json_pattern, content)
if json_match:
try:
result = json.loads(json_match.group())
except json.JSONDecodeError:
logger.error(f"提取的JSON内容解析失败重试第{retry + 1}")
continue
else:
logger.error(f"无法在返回内容中找到有效的JSON重试第{retry + 1}")
continue
# 使用简化函数提取JSON内容
success, result = get_items_from_json(
content,
"goal", "reasoning",
required_types={"goal": str, "reasoning": str}
)
# 验证JSON字段
if not all(key in result for key in ["goal", "reasoning"]):
logger.error(f"JSON缺少必要字段实际内容: {result},重试第{retry + 1}")
if not success:
logger.error(f"无法解析JSON重试第{retry + 1}")
continue
goal = result["goal"]
reasoning = result["reasoning"]
# 验证字段内容
if not isinstance(goal, str) or not isinstance(reasoning, str):
logger.error(f"JSON字段类型错误goal和reasoning必须是字符串重试第{retry + 1}")
continue
if not goal.strip() or not reasoning.strip():
logger.error(f"JSON字段内容为空重试第{retry + 1}")
continue
# 使用默认的方法
method = "以友好的态度回应"
return goal, method, reasoning
# 更新目标列表
await self._update_goals(goal, method, reasoning)
# 返回当前最主要的目标
if self.goals:
current_goal, current_method, current_reasoning = self.goals[0]
return current_goal, current_method, current_reasoning
else:
return goal, method, reasoning
except Exception as e:
logger.error(f"分析对话目标时出错: {str(e)},重试第{retry + 1}")
@@ -293,8 +271,69 @@ class GoalAnalyzer:
# 所有重试都失败后的默认返回
return "保持友好的对话", "以友好的态度回应", "确保对话顺利进行"
async def _update_goals(self, new_goal: str, method: str, reasoning: str):
"""更新目标列表
Args:
new_goal: 新的目标
method: 实现目标的方法
reasoning: 目标的原因
"""
# 检查新目标是否与现有目标相似
for i, (existing_goal, _, _) in enumerate(self.goals):
if self._calculate_similarity(new_goal, existing_goal) > 0.7: # 相似度阈值
# 更新现有目标
self.goals[i] = (new_goal, method, reasoning)
# 将此目标移到列表前面(最主要的位置)
self.goals.insert(0, self.goals.pop(i))
return
# 添加新目标到列表前面
self.goals.insert(0, (new_goal, method, reasoning))
# 限制目标数量
if len(self.goals) > self.max_goals:
self.goals.pop() # 移除最老的目标
def _calculate_similarity(self, goal1: str, goal2: str) -> float:
"""简单计算两个目标之间的相似度
这里使用一个简单的实现,实际可以使用更复杂的文本相似度算法
Args:
goal1: 第一个目标
goal2: 第二个目标
Returns:
float: 相似度得分 (0-1)
"""
# 简单实现:检查重叠字数比例
words1 = set(goal1)
words2 = set(goal2)
overlap = len(words1.intersection(words2))
total = len(words1.union(words2))
return overlap / total if total > 0 else 0
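# Editor's note — worked example, not part of this commit. The score is a
# character-level Jaccard index: for the hypothetical goals "保持友好的对话"
# and "保持友好的交流", the two character sets share 5 of 9 distinct characters,
# giving 5/9 ≈ 0.56. That is below the 0.7 threshold in _update_goals, so the
# second goal would be stored as a new goal rather than merged into the first.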
async def get_all_goals(self) -> List[Tuple[str, str, str]]:
"""获取所有当前目标
Returns:
List[Tuple[str, str, str]]: 目标列表,每项为(目标, 方法, 原因)
"""
return self.goals.copy()
async def get_alternative_goals(self) -> List[Tuple[str, str, str]]:
"""获取除了当前主要目标外的其他备选目标
Returns:
List[Tuple[str, str, str]]: 备选目标列表
"""
if len(self.goals) <= 1:
return []
return self.goals[1:].copy()
async def analyze_conversation(self,goal,reasoning):
async def analyze_conversation(self, goal, reasoning):
messages = self.chat_observer.get_message_history()
chat_history_text = ""
for msg in messages:
@@ -330,58 +369,31 @@ class GoalAnalyzer:
content, _ = await self.llm.generate_response_async(prompt)
logger.debug(f"LLM原始返回内容: {content}")
# 清理和验证返回内容
if not content or not isinstance(content, str):
logger.error("LLM返回内容为空或格式不正确")
return False, False, "确保对话顺利进行"
# 尝试提取JSON部分
content = content.strip()
try:
# 尝试直接解析
result = json.loads(content)
except json.JSONDecodeError:
# 如果直接解析失败尝试查找和提取JSON部分
import re
json_pattern = r'\{[^{}]*\}'
json_match = re.search(json_pattern, content)
if json_match:
try:
result = json.loads(json_match.group())
except json.JSONDecodeError as e:
logger.error(f"提取的JSON内容解析失败: {e}")
return False, False, "确保对话顺利进行"
else:
logger.error("无法在返回内容中找到有效的JSON")
return False, False, "确保对话顺利进行"
# 使用简化函数提取JSON内容
success, result = get_items_from_json(
content,
"goal_achieved", "stop_conversation", "reason",
required_types={
"goal_achieved": bool,
"stop_conversation": bool,
"reason": str
}
)
# 验证JSON字段
if not all(key in result for key in ["goal_achieved", "stop_conversation", "reason"]):
logger.error(f"JSON缺少必要字段实际内容: {result}")
return False, False, "确保对话顺利进行"
goal_achieved = result["goal_achieved"]
stop_conversation = result["stop_conversation"]
reason = result["reason"]
# 验证字段类型
if not isinstance(goal_achieved, bool):
logger.error("goal_achieved 必须是布尔值")
return False, False, "确保对话顺利进行"
if not isinstance(stop_conversation, bool):
logger.error("stop_conversation 必须是布尔值")
return False, False, "确保对话顺利进行"
if not isinstance(reason, str):
logger.error("reason 必须是字符串")
return False, False, "确保对话顺利进行"
if not reason.strip():
logger.error("reason 不能为空")
if not success:
return False, False, "确保对话顺利进行"
return goal_achieved, stop_conversation, reason
# 如果当前目标达成,从目标列表中移除
if result["goal_achieved"] and not result["stop_conversation"]:
for i, (g, _, _) in enumerate(self.goals):
if g == goal:
self.goals.pop(i)
# 如果还有其他目标,不停止对话
if self.goals:
result["stop_conversation"] = False
break
return result["goal_achieved"], result["stop_conversation"], result["reason"]
except Exception as e:
logger.error(f"分析对话目标时出错: {str(e)}")
@@ -392,7 +404,7 @@ class Waiter:
"""快 速 等 待"""
def __init__(self, stream_id: str):
self.chat_observer = ChatObserver.get_instance(stream_id)
self.personality_info = " ".join(global_config.PROMPT_PERSONALITY)
self.personality_info = Individuality.get_instance().get_prompt(type="personality", x_person=2, level=2)
self.name = global_config.BOT_NICKNAME
async def wait(self) -> bool:
@@ -406,8 +418,8 @@ class Waiter:
await asyncio.sleep(1)
logger.info("等待中...")
# 检查是否超过60秒
if time.time() - wait_start_time > 60:
logger.info("等待超过60秒结束对话")
if time.time() - wait_start_time > 300:
logger.info("等待超过300秒结束对话")
return True
logger.info("等待结束")
return False
@@ -423,7 +435,7 @@ class ReplyGenerator:
max_tokens=300,
request_type="reply_generation"
)
self.personality_info = " ".join(global_config.PROMPT_PERSONALITY)
self.personality_info = Individuality.get_instance().get_prompt(type="personality", x_person=2, level=2)
self.name = global_config.BOT_NICKNAME
self.chat_observer = ChatObserver.get_instance(stream_id)
self.reply_checker = ReplyChecker(stream_id)
@@ -435,19 +447,18 @@ class ReplyGenerator:
knowledge_cache: Dict[str, str],
previous_reply: Optional[str] = None,
retry_count: int = 0
) -> Tuple[str, bool]:
) -> str:
"""生成回复
Args:
goal: 对话目标
method: 实现方式
chat_history: 聊天历史
knowledge_cache: 知识缓存
previous_reply: 上一次生成的回复(如果有)
retry_count: 当前重试次数
Returns:
Tuple[str, bool]: (生成的回复, 是否需要重新规划)
str: 生成的回复
"""
# 构建提示词
logger.debug(f"开始生成回复:当前目标: {goal}")
@@ -508,53 +519,105 @@ class ReplyGenerator:
try:
content, _ = await self.llm.generate_response_async(prompt)
logger.info(f"生成的回复: {content}")
is_new = self.chat_observer.check()
logger.debug(f"再看一眼聊天记录,{'' if is_new else '没有'}新消息")
# 检查生成回复是否合适
is_suitable, reason, need_replan = await self.reply_checker.check(
content, goal, retry_count
)
if not is_suitable:
logger.warning(f"生成的回复不合适,原因: {reason}")
if need_replan:
logger.info("需要重新规划对话目标")
return "让我重新思考一下...", True
else:
# 递归调用将当前回复作为previous_reply传入
return await self.generate(
goal, chat_history, knowledge_cache,
content, retry_count + 1
)
# 如果有新消息,重新生成回复
if is_new:
logger.info("检测到新消息,重新生成回复")
return await self.generate(
goal, chat_history, knowledge_cache,
None, retry_count
)
return content, False
return content
except Exception as e:
logger.error(f"生成回复时出错: {e}")
return "抱歉,我现在有点混乱,让我重新思考一下...", True
return "抱歉,我现在有点混乱,让我重新思考一下..."
async def check_reply(
self,
reply: str,
goal: str,
retry_count: int = 0
) -> Tuple[bool, str, bool]:
"""检查回复是否合适
Args:
reply: 生成的回复
goal: 对话目标
retry_count: 当前重试次数
Returns:
Tuple[bool, str, bool]: (是否合适, 原因, 是否需要重新规划)
"""
return await self.reply_checker.check(reply, goal, retry_count)
class Conversation:
# 类级别的实例管理
_instances: Dict[str, 'Conversation'] = {}
_instance_lock = asyncio.Lock() # 类级别的全局锁
_init_events: Dict[str, asyncio.Event] = {} # 初始化完成事件
_initializing: Dict[str, bool] = {} # 标记是否正在初始化
@classmethod
def get_instance(cls, stream_id: str) -> 'Conversation':
"""获取或创建对话实例"""
if stream_id not in cls._instances:
cls._instances[stream_id] = cls(stream_id)
logger.info(f"创建新的对话实例: {stream_id}")
return cls._instances[stream_id]
async def get_instance(cls, stream_id: str) -> Optional['Conversation']:
"""获取或创建对话实例
Args:
stream_id: 聊天流ID
Returns:
Optional[Conversation]: 对话实例如果创建或等待失败则返回None
"""
try:
# 使用全局锁来确保线程安全
async with cls._instance_lock:
# 如果已经在初始化中,等待初始化完成
if stream_id in cls._initializing and cls._initializing[stream_id]:
# 释放锁等待初始化
cls._instance_lock.release()
try:
await asyncio.wait_for(cls._init_events[stream_id].wait(), timeout=5.0)
except asyncio.TimeoutError:
logger.error(f"等待实例 {stream_id} 初始化超时")
return None
finally:
await cls._instance_lock.acquire()
# 如果实例不存在,创建新实例
if stream_id not in cls._instances:
cls._instances[stream_id] = cls(stream_id)
cls._init_events[stream_id] = asyncio.Event()
cls._initializing[stream_id] = True
logger.info(f"创建新的对话实例: {stream_id}")
return cls._instances[stream_id]
except Exception as e:
logger.error(f"获取对话实例失败: {e}")
return None
@classmethod
def remove_instance(cls, stream_id: str):
"""删除对话实例"""
if stream_id in cls._instances:
# 停止相关组件
instance = cls._instances[stream_id]
instance.chat_observer.stop()
# 删除实例
del cls._instances[stream_id]
logger.info(f"已删除对话实例 {stream_id}")
async def remove_instance(cls, stream_id: str):
"""删除对话实例
Args:
stream_id: 聊天流ID
"""
async with cls._instance_lock:
if stream_id in cls._instances:
# 停止相关组件
instance = cls._instances[stream_id]
instance.chat_observer.stop()
# 删除实例
del cls._instances[stream_id]
if stream_id in cls._init_events:
del cls._init_events[stream_id]
if stream_id in cls._initializing:
del cls._initializing[stream_id]
logger.info(f"已删除对话实例 {stream_id}")
def __init__(self, stream_id: str):
"""初始化对话系统"""
@@ -592,13 +655,21 @@ class Conversation:
async def start(self):
"""开始对话流程"""
logger.info("对话系统启动")
self.should_continue = True
self.chat_observer.start() # 启动观察器
await asyncio.sleep(1)
# 启动对话循环
await self._conversation_loop()
try:
logger.info("对话系统启动")
self.should_continue = True
self.chat_observer.start() # 启动观察器
await asyncio.sleep(1)
# 启动对话循环
await self._conversation_loop()
except Exception as e:
logger.error(f"启动对话系统失败: {e}")
raise
finally:
# 标记初始化完成
self._init_events[self.stream_id].set()
self._initializing[self.stream_id] = False
async def _conversation_loop(self):
"""对话循环"""
# 获取最近的消息历史
@@ -658,17 +729,101 @@ class Conversation:
if action == "direct_reply":
self.state = ConversationState.GENERATING
messages = self.chat_observer.get_message_history(limit=30)
self.generated_reply, need_replan = await self.reply_generator.generate(
self.generated_reply = await self.reply_generator.generate(
self.current_goal,
self.current_method,
[self._convert_to_message(msg) for msg in messages],
self.knowledge_cache
)
if need_replan:
self.state = ConversationState.RETHINKING
self.current_goal, self.current_method, self.goal_reasoning = await self.goal_analyzer.analyze_goal()
else:
await self._send_reply()
# 检查回复是否合适
is_suitable, reason, need_replan = await self.reply_generator.check_reply(
self.generated_reply,
self.current_goal
)
if not is_suitable:
logger.warning(f"生成的回复不合适,原因: {reason}")
if need_replan:
# 尝试切换到其他备选目标
alternative_goals = await self.goal_analyzer.get_alternative_goals()
if alternative_goals:
# 有备选目标,尝试使用下一个目标
self.current_goal, self.current_method, self.goal_reasoning = alternative_goals[0]
logger.info(f"切换到备选目标: {self.current_goal}")
# 使用新目标生成回复
self.generated_reply = await self.reply_generator.generate(
self.current_goal,
self.current_method,
[self._convert_to_message(msg) for msg in messages],
self.knowledge_cache
)
# 检查使用新目标生成的回复是否合适
is_suitable, reason, _ = await self.reply_generator.check_reply(
self.generated_reply,
self.current_goal
)
if is_suitable:
# 如果新目标的回复合适,调整目标优先级
await self.goal_analyzer._update_goals(
self.current_goal,
self.current_method,
self.goal_reasoning
)
else:
# 如果新目标还是不合适,重新思考目标
self.state = ConversationState.RETHINKING
self.current_goal, self.current_method, self.goal_reasoning = await self.goal_analyzer.analyze_goal()
return
else:
# 没有备选目标,重新分析
self.state = ConversationState.RETHINKING
self.current_goal, self.current_method, self.goal_reasoning = await self.goal_analyzer.analyze_goal()
return
else:
# 重新生成回复
self.generated_reply = await self.reply_generator.generate(
self.current_goal,
self.current_method,
[self._convert_to_message(msg) for msg in messages],
self.knowledge_cache,
self.generated_reply # 将不合适的回复作为previous_reply传入
)
while self.chat_observer.check():
if not is_suitable:
logger.warning(f"生成的回复不合适,原因: {reason}")
if need_replan:
# 尝试切换到其他备选目标
alternative_goals = await self.goal_analyzer.get_alternative_goals()
if alternative_goals:
# 有备选目标,尝试使用下一个目标
self.current_goal, self.current_method, self.goal_reasoning = alternative_goals[0]
logger.info(f"切换到备选目标: {self.current_goal}")
# 使用新目标生成回复
self.generated_reply = await self.reply_generator.generate(
self.current_goal,
self.current_method,
[self._convert_to_message(msg) for msg in messages],
self.knowledge_cache
)
is_suitable = True # 假设使用新目标后回复是合适的
else:
# 没有备选目标,重新分析
self.state = ConversationState.RETHINKING
self.current_goal, self.current_method, self.goal_reasoning = await self.goal_analyzer.analyze_goal()
return
else:
# 重新生成回复
self.generated_reply = await self.reply_generator.generate(
self.current_goal,
self.current_method,
[self._convert_to_message(msg) for msg in messages],
self.knowledge_cache,
self.generated_reply # 将不合适的回复作为previous_reply传入
)
await self._send_reply()
elif action == "fetch_knowledge":
self.state = ConversationState.GENERATING
@@ -682,17 +837,58 @@ class Conversation:
if knowledge != "未找到相关知识":
self.knowledge_cache[sources] = knowledge
self.generated_reply, need_replan = await self.reply_generator.generate(
self.generated_reply = await self.reply_generator.generate(
self.current_goal,
self.current_method,
[self._convert_to_message(msg) for msg in messages],
self.knowledge_cache
)
if need_replan:
self.state = ConversationState.RETHINKING
self.current_goal, self.current_method, self.goal_reasoning = await self.goal_analyzer.analyze_goal()
else:
await self._send_reply()
# 检查回复是否合适
is_suitable, reason, need_replan = await self.reply_generator.check_reply(
self.generated_reply,
self.current_goal
)
if not is_suitable:
logger.warning(f"生成的回复不合适,原因: {reason}")
if need_replan:
# 尝试切换到其他备选目标
alternative_goals = await self.goal_analyzer.get_alternative_goals()
if alternative_goals:
# 有备选目标,尝试使用
self.current_goal, self.current_method, self.goal_reasoning = alternative_goals[0]
logger.info(f"切换到备选目标: {self.current_goal}")
# 使用新目标获取知识并生成回复
knowledge, sources = await self.knowledge_fetcher.fetch(
self.current_goal,
[self._convert_to_message(msg) for msg in messages]
)
if knowledge != "未找到相关知识":
self.knowledge_cache[sources] = knowledge
self.generated_reply = await self.reply_generator.generate(
self.current_goal,
self.current_method,
[self._convert_to_message(msg) for msg in messages],
self.knowledge_cache
)
else:
# 没有备选目标,重新分析
self.state = ConversationState.RETHINKING
self.current_goal, self.current_method, self.goal_reasoning = await self.goal_analyzer.analyze_goal()
return
else:
# 重新生成回复
self.generated_reply = await self.reply_generator.generate(
self.current_goal,
self.current_method,
[self._convert_to_message(msg) for msg in messages],
self.knowledge_cache,
self.generated_reply # 将不合适的回复作为previous_reply传入
)
await self._send_reply()
elif action == "rethink_goal":
self.state = ConversationState.RETHINKING
@@ -701,6 +897,16 @@ class Conversation:
elif action == "judge_conversation":
self.state = ConversationState.JUDGING
self.goal_achieved, self.stop_conversation, self.reason = await self.goal_analyzer.analyze_conversation(self.current_goal, self.goal_reasoning)
# 如果当前目标达成但还有其他目标
if self.goal_achieved and not self.stop_conversation:
alternative_goals = await self.goal_analyzer.get_alternative_goals()
if alternative_goals:
# 切换到下一个目标
self.current_goal, self.current_method, self.goal_reasoning = alternative_goals[0]
logger.info(f"当前目标已达成,切换到新目标: {self.current_goal}")
return
if self.stop_conversation:
await self._stop_conversation()
@@ -724,7 +930,7 @@ class Conversation:
self.should_continue = False
self.state = ConversationState.ENDED
# 删除实例这会同时停止chat_observer
self.remove_instance(self.stream_id)
await self.remove_instance(self.stream_id)
async def _send_timeout_message(self):
"""发送超时结束消息"""
@@ -821,7 +1027,7 @@ class DirectMessageSender:
if not end_point:
raise ValueError(f"未找到平台:{chat_stream.platform} 的url配置")
await global_api.send_message(end_point, message_json)
await global_api.send_message_REST(end_point, message_json)
# 存储消息
await self.storage.store_message(message, message.chat_stream)

View File

@@ -0,0 +1,72 @@
import json
import re
from typing import Dict, Any, Optional, Tuple
from src.common.logger import get_module_logger
logger = get_module_logger("pfc_utils")
def get_items_from_json(
content: str,
*items: str,
default_values: Optional[Dict[str, Any]] = None,
required_types: Optional[Dict[str, type]] = None
) -> Tuple[bool, Dict[str, Any]]:
"""从文本中提取JSON内容并获取指定字段
Args:
content: 包含JSON的文本
*items: 要提取的字段名
default_values: 字段的默认值,格式为 {字段名: 默认值}
required_types: 字段的必需类型,格式为 {字段名: 类型}
Returns:
Tuple[bool, Dict[str, Any]]: (是否成功, 提取的字段字典)
"""
content = content.strip()
result = {}
# 设置默认值
if default_values:
result.update(default_values)
# 尝试解析JSON
try:
json_data = json.loads(content)
except json.JSONDecodeError:
# 如果直接解析失败尝试查找和提取JSON部分
json_pattern = r'\{[^{}]*\}'
json_match = re.search(json_pattern, content)
if json_match:
try:
json_data = json.loads(json_match.group())
except json.JSONDecodeError:
logger.error("提取的JSON内容解析失败")
return False, result
else:
logger.error("无法在返回内容中找到有效的JSON")
return False, result
# 提取字段
for item in items:
if item in json_data:
result[item] = json_data[item]
# 验证必需字段
if not all(item in result for item in items):
logger.error(f"JSON缺少必要字段实际内容: {json_data}")
return False, result
# 验证字段类型
if required_types:
for field, expected_type in required_types.items():
if field in result and not isinstance(result[field], expected_type):
logger.error(f"{field} 必须是 {expected_type.__name__} 类型")
return False, result
# 验证字符串字段不为空
for field in items:
if isinstance(result[field], str) and not result[field].strip():
logger.error(f"{field} 不能为空")
return False, result
return True, result
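# Editor's note — illustrative usage, not part of this commit. It exercises the
# r'\{[^{}]*\}' fallback that extracts a flat JSON object from surrounding text.
if __name__ == "__main__":
    raw = 'Decision: {"action": "wait", "reason": "waiting for more context"}'
    ok, data = get_items_from_json(
        raw,
        "action", "reason",
        default_values={"action": "direct_reply", "reason": "default reason"},
    )
    print(ok, data)  # True {'action': 'wait', 'reason': 'waiting for more context'}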

View File

@@ -9,6 +9,7 @@ from src.common.logger import get_module_logger, CHAT_STYLE_CONFIG, LogConfig
from ..chat_module.think_flow_chat.think_flow_chat import ThinkFlowChat
from ..chat_module.reasoning_chat.reasoning_chat import ReasoningChat
import asyncio
import traceback
# 定义日志配置
chat_config = LogConfig(
@@ -42,11 +43,24 @@ class ChatBot:
if global_config.enable_pfc_chatting:
# 获取或创建对话实例
conversation = Conversation.get_instance(chat_id)
conversation = await Conversation.get_instance(chat_id)
if conversation is None:
logger.error(f"创建或获取对话实例失败: {chat_id}")
return
# 如果是新创建的实例,启动对话系统
if conversation.state == ConversationState.INIT:
asyncio.create_task(conversation.start())
logger.info(f"为聊天 {chat_id} 创建新的对话实例")
elif conversation.state == ConversationState.ENDED:
# 如果实例已经结束,重新创建
await Conversation.remove_instance(chat_id)
conversation = await Conversation.get_instance(chat_id)
if conversation is None:
logger.error(f"重新创建对话实例失败: {chat_id}")
return
asyncio.create_task(conversation.start())
logger.info(f"为聊天 {chat_id} 重新创建对话实例")
except Exception as e:
logger.error(f"创建PFC聊天流失败: {e}")
@@ -78,8 +92,13 @@ class ChatBot:
try:
message = MessageRecv(message_data)
groupinfo = message.message_info.group_info
logger.debug(f"处理消息:{str(message_data)[:50]}...")
userinfo = message.message_info.user_info
logger.debug(f"处理消息:{str(message_data)[:80]}...")
if userinfo.user_id in global_config.ban_user_id:
logger.debug(f"用户{userinfo.user_id}被禁止回复")
return
if global_config.enable_pfc_chatting:
try:
if groupinfo is None and global_config.enable_friend_chat:
@@ -96,11 +115,11 @@ class ChatBot:
await self._create_PFC_chat(message)
else:
if groupinfo.group_id in global_config.talk_allowed_groups:
logger.debug(f"开始群聊模式{message_data}")
logger.debug(f"开始群聊模式{str(message_data)[:50]}...")
if global_config.response_mode == "heart_flow":
await self.think_flow_chat.process_message(message_data)
elif global_config.response_mode == "reasoning":
logger.debug(f"开始推理模式{message_data}")
logger.debug(f"开始推理模式{str(message_data)[:50]}...")
await self.reasoning_chat.process_message(message_data)
else:
logger.error(f"未知的回复模式,请检查配置文件!!: {global_config.response_mode}")
@@ -126,6 +145,7 @@ class ChatBot:
logger.error(f"未知的回复模式,请检查配置文件!!: {global_config.response_mode}")
except Exception as e:
logger.error(f"预处理消息失败: {e}")
traceback.print_exc()
# 创建全局ChatBot实例

View File

@@ -28,7 +28,7 @@ class ChatStream:
self.platform = platform
self.user_info = user_info
self.group_info = group_info
self.create_time = data.get("create_time", int(time.time())) if data else int(time.time())
self.create_time = data.get("create_time", time.time()) if data else time.time()
self.last_active_time = data.get("last_active_time", self.create_time) if data else self.create_time
self.saved = False
@@ -60,7 +60,7 @@ class ChatStream:
def update_active_time(self):
"""更新最后活跃时间"""
self.last_active_time = int(time.time())
self.last_active_time = time.time()
self.saved = False

View File

@@ -249,7 +249,22 @@ class EmojiManager:
f for f in os.listdir(emoji_dir) if f.lower().endswith((".jpg", ".jpeg", ".png", ".gif"))
]
# 检查当前表情包数量
self._update_emoji_count()
if self.emoji_num >= self.emoji_num_max:
logger.warning(f"[警告] 表情包数量已达到上限({self.emoji_num}/{self.emoji_num_max}),跳过注册")
return
# 计算还可以注册的数量
remaining_slots = self.emoji_num_max - self.emoji_num
logger.info(f"[注册] 还可以注册 {remaining_slots} 个表情包")
for filename in files_to_process:
# 如果已经达到上限,停止注册
if self.emoji_num >= self.emoji_num_max:
logger.warning(f"[警告] 表情包数量已达到上限({self.emoji_num}/{self.emoji_num_max}),停止注册")
break
image_path = os.path.join(emoji_dir, filename)
# 获取图片的base64编码和哈希值
@@ -340,6 +355,10 @@ class EmojiManager:
logger.success(f"[注册] 新表情包: {filename}")
logger.info(f"[描述] {description}")
# 更新当前表情包数量
self.emoji_num += 1
logger.info(f"[统计] 当前表情包数量: {self.emoji_num}/{self.emoji_num_max}")
# 保存到images数据库
image_doc = {
"hash": image_hash,

View File

@@ -168,7 +168,7 @@ class MessageProcessBase(Message):
# 调用父类初始化
super().__init__(
message_id=message_id,
time=int(time.time()),
time=round(time.time(), 3), # 保留3位小数
chat_stream=chat_stream,
user_info=bot_user_info,
message_segment=message_segment,

View File

@@ -0,0 +1,190 @@
from ..person_info.person_info import person_info_manager
from src.common.logger import get_module_logger
import asyncio
from dataclasses import dataclass, field
from .message import MessageRecv
from ..message.message_base import BaseMessageInfo, GroupInfo
import hashlib
from typing import Dict
from collections import OrderedDict
import random
import time
from ..config.config import global_config
logger = get_module_logger("message_buffer")
@dataclass
class CacheMessages:
message: MessageRecv
cache_determination: asyncio.Event = field(default_factory=asyncio.Event) # 判断缓冲是否产生结果
result: str = "U"
class MessageBuffer:
def __init__(self):
self.buffer_pool: Dict[str, OrderedDict[str, CacheMessages]] = {}
self.lock = asyncio.Lock()
def get_person_id_(self, platform: str, user_id: str, group_info: GroupInfo):
"""获取唯一id"""
if group_info:
group_id = group_info.group_id
else:
group_id = "私聊"
key = f"{platform}_{user_id}_{group_id}"
return hashlib.md5(key.encode()).hexdigest()
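# Editor's note — illustrative example, not part of this commit. With assumed
# values platform="qq", user_id="123456" and no group_info, the key becomes
# hashlib.md5("qq_123456_私聊".encode()).hexdigest(), so each user gets one
# buffer per group plus a separate one for private chat.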
async def start_caching_messages(self, message: MessageRecv):
"""添加消息,启动缓冲"""
if not global_config.message_buffer:
person_id = person_info_manager.get_person_id(message.message_info.user_info.platform,
message.message_info.user_info.user_id)
asyncio.create_task(self.save_message_interval(person_id, message.message_info))
return
person_id_ = self.get_person_id_(message.message_info.platform,
message.message_info.user_info.user_id,
message.message_info.group_info)
async with self.lock:
if person_id_ not in self.buffer_pool:
self.buffer_pool[person_id_] = OrderedDict()
# 标记该用户之前的未处理消息
for cache_msg in self.buffer_pool[person_id_].values():
if cache_msg.result == "U":
cache_msg.result = "F"
cache_msg.cache_determination.set()
logger.debug(f"被新消息覆盖信息id: {cache_msg.message.message_info.message_id}")
# 查找最近的处理成功消息(T)
recent_F_count = 0
for msg_id in reversed(self.buffer_pool[person_id_]):
msg = self.buffer_pool[person_id_][msg_id]
if msg.result == "T":
break
elif msg.result == "F":
recent_F_count += 1
# 判断条件最近T之后有超过3-5条F
if (recent_F_count >= random.randint(3, 5)):
new_msg = CacheMessages(message=message, result="T")
new_msg.cache_determination.set()
self.buffer_pool[person_id_][message.message_info.message_id] = new_msg
logger.debug(f"快速处理消息(已堆积{recent_F_count}条F): {message.message_info.message_id}")
return
# 添加新消息
self.buffer_pool[person_id_][message.message_info.message_id] = CacheMessages(message=message)
# 启动3秒缓冲计时器
person_id = person_info_manager.get_person_id(message.message_info.user_info.platform,
message.message_info.user_info.user_id)
asyncio.create_task(self.save_message_interval(person_id, message.message_info))
asyncio.create_task(self._debounce_processor(person_id_,
message.message_info.message_id,
person_id))
async def _debounce_processor(self, person_id_: str, message_id: str, person_id: str):
"""等待3秒无新消息"""
interval_time = await person_info_manager.get_value(person_id, "msg_interval")
if not isinstance(interval_time, (int, str)) or not str(interval_time).isdigit():
logger.debug("debounce_processor无效的时间")
return
interval_time = max(0.5, int(interval_time) / 1000)
await asyncio.sleep(interval_time)
async with self.lock:
if (person_id_ not in self.buffer_pool or
message_id not in self.buffer_pool[person_id_]):
logger.debug(f"消息已被清理msgid: {message_id}")
return
cache_msg = self.buffer_pool[person_id_][message_id]
if cache_msg.result == "U":
cache_msg.result = "T"
cache_msg.cache_determination.set()
async def query_buffer_result(self, message: MessageRecv) -> bool:
"""查询缓冲结果,并清理"""
if not global_config.message_buffer:
return True
person_id_ = self.get_person_id_(message.message_info.platform,
message.message_info.user_info.user_id,
message.message_info.group_info)
async with self.lock:
user_msgs = self.buffer_pool.get(person_id_, {})
cache_msg = user_msgs.get(message.message_info.message_id)
if not cache_msg:
logger.debug(f"查询异常消息不存在msgid: {message.message_info.message_id}")
return False # 消息不存在或已清理
try:
await asyncio.wait_for(cache_msg.cache_determination.wait(), timeout=10)
result = cache_msg.result == "T"
if result:
async with self.lock: # 再次加锁
# 清理所有早于当前消息的已处理消息, 收集所有早于当前消息的F消息的processed_plain_text
keep_msgs = OrderedDict()
combined_text = []
found = False
type = "text"
is_update = True
for msg_id, msg in self.buffer_pool[person_id_].items():
if msg_id == message.message_info.message_id:
found = True
msg_type = msg.message.message_segment.type
combined_text.append(msg.message.processed_plain_text)
continue
if found:
keep_msgs[msg_id] = msg
elif msg.result == "F":
# 收集F消息的文本内容
if (hasattr(msg.message, 'processed_plain_text')
and msg.message.processed_plain_text):
if msg.message.message_segment.type == "text":
combined_text.append(msg.message.processed_plain_text)
elif msg.message.message_segment.type != "text":
is_update = False
elif msg.result == "U":
logger.debug(f"异常未处理信息id {msg.message.message_info.message_id}")
# 更新当前消息的processed_plain_text
if combined_text and combined_text[0] != message.processed_plain_text and is_update:
if type == "text":
message.processed_plain_text = "".join(combined_text)
logger.debug(f"整合了{len(combined_text)-1}条F消息的内容到当前消息")
elif type == "emoji":
combined_text.pop()
message.processed_plain_text = "".join(combined_text)
message.is_emoji = False
logger.debug(f"整合了{len(combined_text)-1}条F消息的内容覆盖当前emoji消息")
self.buffer_pool[person_id_] = keep_msgs
return result
except asyncio.TimeoutError:
logger.debug(f"查询超时消息id {message.message_info.message_id}")
return False
async def save_message_interval(self, person_id: str, message: BaseMessageInfo):
message_interval_list = await person_info_manager.get_value(person_id, "msg_interval_list")
now_time_ms = int(round(time.time() * 1000))
if len(message_interval_list) < 1000:
message_interval_list.append(now_time_ms)
else:
message_interval_list.pop(0)
message_interval_list.append(now_time_ms)
data = {
"platform": message.platform,
"user_id": message.user_info.user_id,
"nickname": message.user_info.user_nickname,
"konw_time": int(time.time())
}
await person_info_manager.update_one_field(person_id, "msg_interval_list", message_interval_list, data)
message_buffer = MessageBuffer()
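# Editor's note — usage sketch, not part of this commit; it mirrors how the chat
# pipelines later in this diff drive the buffer:
#
#     await message_buffer.start_caching_messages(message)    # on receipt
#     ...                                                      # memory activation etc.
#     if not await message_buffer.query_buffer_result(message):
#         return  # message was merged into a newer one and should be dropped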

View File

@@ -43,6 +43,12 @@ class Message_Sender:
# 按thinking_start_time排序时间早的在前面
return recalled_messages
async def send_via_ws(self, message: MessageSending) -> None:
try:
await global_api.send_message(message)
except Exception as e:
raise ValueError(f"未找到平台:{message.message_info.platform} 的url配置请检查配置文件") from e
async def send_message(
self,
message: MessageSending,
@@ -58,8 +64,14 @@ class Message_Sender:
logger.warning(f"消息“{message.processed_plain_text}”已被撤回,不发送")
break
if not is_recalled:
typing_time = calculate_typing_time(message.processed_plain_text)
# print(message.processed_plain_text + str(message.is_emoji))
typing_time = calculate_typing_time(
input_string=message.processed_plain_text,
thinking_start_time=message.thinking_start_time,
is_emoji=message.is_emoji)
logger.debug(f"{message.processed_plain_text},{typing_time},计算输入时间结束")
await asyncio.sleep(typing_time)
logger.debug(f"{message.processed_plain_text},{typing_time},等待输入时间结束")
message_json = message.to_dict()
@@ -69,14 +81,14 @@ class Message_Sender:
if end_point:
# logger.info(f"发送消息到{end_point}")
# logger.info(message_json)
await global_api.send_message_REST(end_point, message_json)
else:
try:
await global_api.send_message(message)
await global_api.send_message_REST(end_point, message_json)
except Exception as e:
raise ValueError(
f"未找到平台:{message.message_info.platform} 的url配置请检查配置文件"
) from e
logger.error(f"REST方式发送失败出现错误: {str(e)}")
logger.info("尝试使用ws发送")
await self.send_via_ws(message)
else:
await self.send_via_ws(message)
logger.success(f"发送消息“{message_preview}”成功")
except Exception as e:
logger.error(f"发送消息“{message_preview}”失败: {str(e)}")
@@ -214,6 +226,8 @@ class MessageManager:
await message_earliest.process()
# print(f"message_earliest.thinking_start_tim22222e:{message_earliest.thinking_start_time}")
await message_sender.send_message(message_earliest)
await self.storage.store_message(message_earliest, message_earliest.chat_stream)

View File

@@ -334,26 +334,19 @@ def process_llm_response(text: str) -> List[str]:
return sentences
def calculate_typing_time(input_string: str, chinese_time: float = 0.2, english_time: float = 0.1) -> float:
def calculate_typing_time(input_string: str, thinking_start_time: float, chinese_time: float = 0.2, english_time: float = 0.1, is_emoji: bool = False) -> float:
"""
计算输入字符串所需的时间,中文和英文字符有不同的输入时间
input_string (str): 输入的字符串
chinese_time (float): 中文字符的输入时间默认为0.2秒
english_time (float): 英文字符的输入时间默认为0.1秒
is_emoji (bool): 是否为emoji默认为False
特殊情况:
- 如果只有一个中文字符将使用3倍的中文输入时间
- 在所有输入结束后额外加上回车时间0.3秒
- 如果is_emoji为True将使用固定1秒的输入时间
"""
# 如果输入是列表,将其连接成字符串
if isinstance(input_string, list):
input_string = ''.join(input_string)
# 确保现在是字符串类型
if not isinstance(input_string, str):
input_string = str(input_string)
mood_manager = MoodManager.get_instance()
# 将0-1的唤醒度映射到-1到1
mood_arousal = mood_manager.current_mood.arousal
@@ -376,7 +369,19 @@ def calculate_typing_time(input_string: str, chinese_time: float = 0.2, english_
else: # 其他字符(如英文)
total_time += english_time
return total_time + 0.3 # 加上回车时间
if is_emoji:
total_time = 1
if time.time() - thinking_start_time > 10:
total_time = 1
# print(f"thinking_start_time:{thinking_start_time}")
# print(f"nowtime:{time.time()}")
# print(f"nowtime - thinking_start_time:{time.time() - thinking_start_time}")
# print(f"{total_time}")
return total_time
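# Editor's note — hedged usage sketch, not part of this commit, with assumed
# values for the new signature (message_sender.py later in this diff calls it
# the same way):
#
#     typing_time = calculate_typing_time(
#         input_string="good evening",
#         thinking_start_time=time.time() - 2,
#         is_emoji=False,
#     )
#     await asyncio.sleep(typing_time)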
def cosine_similarity(v1, v2):

View File

@@ -17,6 +17,7 @@ from ...message import UserInfo, Seg
from src.common.logger import get_module_logger, CHAT_STYLE_CONFIG, LogConfig
from ...chat.chat_stream import chat_manager
from ...person_info.relationship_manager import relationship_manager
from ...chat.message_buffer import message_buffer
# 定义日志配置
chat_config = LogConfig(
@@ -143,6 +144,8 @@ class ReasoningChat:
userinfo = message.message_info.user_info
messageinfo = message.message_info
# 消息加入缓冲池
await message_buffer.start_caching_messages(message)
# logger.info("使用推理聊天模式")
@@ -172,6 +175,17 @@ class ReasoningChat:
timer2 = time.time()
timing_results["记忆激活"] = timer2 - timer1
# 查询缓冲器结果会整合前面跳过的消息改变processed_plain_text
buffer_result = await message_buffer.query_buffer_result(message)
if not buffer_result:
if message.message_segment.type == "text":
logger.info(f"触发缓冲,已炸飞消息:{message.processed_plain_text}")
elif message.message_segment.type == "image":
logger.info("触发缓冲,已炸飞表情包/图片")
elif message.message_segment.type == "seglist":
logger.info("触发缓冲,已炸飞消息列")
return
is_mentioned = is_mentioned_bot_in_message(message)
# 计算回复意愿

View File

@@ -1,16 +1,17 @@
import random
import time
from typing import Optional
from typing import Optional, Union
from ....common.database import db
from ...memory_system.Hippocampus import HippocampusManager
from ...moods.moods import MoodManager
from ...schedule.schedule_generator import bot_schedule
from ...config.config import global_config
from ...chat.utils import get_embedding, get_recent_group_detailed_plain_text, get_recent_group_speaker
from ...chat.chat_stream import chat_manager
from src.common.logger import get_module_logger
from ...moods.moods import MoodManager
from ....individuality.individuality import Individuality
from ...memory_system.Hippocampus import HippocampusManager
from ...schedule.schedule_generator import bot_schedule
from ...config.config import global_config
from ...person_info.relationship_manager import relationship_manager
from src.common.logger import get_module_logger
logger = get_module_logger("prompt")
@@ -25,7 +26,23 @@ class PromptBuilder:
) -> tuple[str, str]:
# 开始构建prompt
prompt_personality = ""
#person
individuality = Individuality.get_instance()
personality_core = individuality.personality.personality_core
prompt_personality += personality_core
personality_sides = individuality.personality.personality_sides
random.shuffle(personality_sides)
prompt_personality += f",{personality_sides[0]}"
identity_detail = individuality.identity.identity_detail
random.shuffle(identity_detail)
prompt_personality += f",{identity_detail[0]}"
# 关系
who_chat_in_group = [(chat_stream.user_info.platform,
chat_stream.user_info.user_id,
@@ -102,20 +119,6 @@ class PromptBuilder:
)
keywords_reaction_prompt += rule.get("reaction", "") + ""
# 人格选择
personality = global_config.PROMPT_PERSONALITY
probability_1 = global_config.PERSONALITY_1
probability_2 = global_config.PERSONALITY_2
personality_choice = random.random()
if personality_choice < probability_1: # 第一种风格
prompt_personality = personality[0]
elif personality_choice < probability_1 + probability_2: # 第二种风格
prompt_personality = personality[1]
else: # 第三种人格
prompt_personality = personality[2]
# 中文高手(新加的好玩功能)
prompt_ger = ""
if random.random() < 0.04:
@@ -128,7 +131,7 @@ class PromptBuilder:
# 知识构建
start_time = time.time()
prompt_info = ""
prompt_info = await self.get_prompt_info(message_txt, threshold=0.5)
prompt_info = await self.get_prompt_info(message_txt, threshold=0.38)
if prompt_info:
prompt_info = f"""\n你有以下这些**知识**\n{prompt_info}\n请你**记住上面的知识**,之后可能会用到。\n"""
@@ -142,12 +145,13 @@ class PromptBuilder:
logger.info("开始构建prompt")
prompt = f"""
{relation_prompt_all}
{memory_prompt}
{prompt_info}
{schedule_prompt}
{chat_target}
{chat_talking_prompt}
现在"{sender_name}"说的:{message_txt}。引起了你的注意,你想要在群里发言发言或者回复这条消息。{relation_prompt_all}\n
现在"{sender_name}"说的:{message_txt}。引起了你的注意,你想要在群里发言发言或者回复这条消息。\n
你的网名叫{global_config.BOT_NICKNAME},有人也叫你{"/".join(global_config.BOT_ALIAS_NAMES)}{prompt_personality}
你正在{chat_target_2},现在请你读读之前的聊天记录,{mood_prompt},然后给出日常且口语化的回复,平淡一些,
尽量简短一些。{keywords_reaction_prompt}请注意把握聊天内容,不要回复的太有条理,可以有个性。{prompt_ger}
@@ -158,16 +162,156 @@ class PromptBuilder:
return prompt
async def get_prompt_info(self, message: str, threshold: float):
start_time = time.time()
related_info = ""
logger.debug(f"获取知识库内容,元消息:{message[:30]}...,消息长度: {len(message)}")
embedding = await get_embedding(message, request_type="prompt_build")
related_info += self.get_info_from_db(embedding, limit=1, threshold=threshold)
# 1. 先从LLM获取主题类似于记忆系统的做法
topics = []
# try:
# # 先尝试使用记忆系统的方法获取主题
# hippocampus = HippocampusManager.get_instance()._hippocampus
# topic_num = min(5, max(1, int(len(message) * 0.1)))
# topics_response = await hippocampus.llm_topic_judge.generate_response(hippocampus.find_topic_llm(message, topic_num))
# # 提取关键词
# topics = re.findall(r"<([^>]+)>", topics_response[0])
# if not topics:
# topics = []
# else:
# topics = [
# topic.strip()
# for topic in ",".join(topics).replace("", ",").replace("、", ",").replace(" ", ",").split(",")
# if topic.strip()
# ]
# logger.info(f"从LLM提取的主题: {', '.join(topics)}")
# except Exception as e:
# logger.error(f"从LLM提取主题失败: {str(e)}")
# # 如果LLM提取失败使用jieba分词提取关键词作为备选
# words = jieba.cut(message)
# topics = [word for word in words if len(word) > 1][:5]
# logger.info(f"使用jieba提取的主题: {', '.join(topics)}")
# 如果无法提取到主题,直接使用整个消息
if not topics:
logger.info("未能提取到任何主题,使用整个消息进行查询")
embedding = await get_embedding(message, request_type="prompt_build")
if not embedding:
logger.error("获取消息嵌入向量失败")
return ""
related_info = self.get_info_from_db(embedding, limit=3, threshold=threshold)
logger.info(f"知识库检索完成,总耗时: {time.time() - start_time:.3f}")
return related_info
# 2. 对每个主题进行知识库查询
logger.info(f"开始处理{len(topics)}个主题的知识库查询")
# 优化批量获取嵌入向量减少API调用
embeddings = {}
topics_batch = [topic for topic in topics if len(topic) > 0]
if message: # 确保消息非空
topics_batch.append(message)
# 批量获取嵌入向量
embed_start_time = time.time()
for text in topics_batch:
if not text or len(text.strip()) == 0:
continue
try:
embedding = await get_embedding(text, request_type="prompt_build")
if embedding:
embeddings[text] = embedding
else:
logger.warning(f"获取'{text}'的嵌入向量失败")
except Exception as e:
logger.error(f"获取'{text}'的嵌入向量时发生错误: {str(e)}")
logger.info(f"批量获取嵌入向量完成,耗时: {time.time() - embed_start_time:.3f}")
if not embeddings:
logger.error("所有嵌入向量获取失败")
return ""
# 3. 对每个主题进行知识库查询
all_results = []
query_start_time = time.time()
# 首先添加原始消息的查询结果
if message in embeddings:
original_results = self.get_info_from_db(embeddings[message], limit=3, threshold=threshold, return_raw=True)
if original_results:
for result in original_results:
result["topic"] = "原始消息"
all_results.extend(original_results)
logger.info(f"原始消息查询到{len(original_results)}条结果")
# 然后添加每个主题的查询结果
for topic in topics:
if not topic or topic not in embeddings:
continue
try:
topic_results = self.get_info_from_db(embeddings[topic], limit=3, threshold=threshold, return_raw=True)
if topic_results:
# 添加主题标记
for result in topic_results:
result["topic"] = topic
all_results.extend(topic_results)
logger.info(f"主题'{topic}'查询到{len(topic_results)}条结果")
except Exception as e:
logger.error(f"查询主题'{topic}'时发生错误: {str(e)}")
logger.info(f"知识库查询完成,耗时: {time.time() - query_start_time:.3f}秒,共获取{len(all_results)}条结果")
# 4. 去重和过滤
process_start_time = time.time()
unique_contents = set()
filtered_results = []
for result in all_results:
content = result["content"]
if content not in unique_contents:
unique_contents.add(content)
filtered_results.append(result)
# 5. 按相似度排序
filtered_results.sort(key=lambda x: x["similarity"], reverse=True)
# 6. 限制总数量最多10条
filtered_results = filtered_results[:10]
logger.info(f"结果处理完成,耗时: {time.time() - process_start_time:.3f}秒,过滤后剩余{len(filtered_results)}条结果")
# 7. 格式化输出
if filtered_results:
format_start_time = time.time()
grouped_results = {}
for result in filtered_results:
topic = result["topic"]
if topic not in grouped_results:
grouped_results[topic] = []
grouped_results[topic].append(result)
# 按主题组织输出
for topic, results in grouped_results.items():
related_info += f"【主题: {topic}\n"
for _i, result in enumerate(results, 1):
_similarity = result["similarity"]
content = result["content"].strip()
# 调试:为内容添加序号和相似度信息
# related_info += f"{i}. [{similarity:.2f}] {content}\n"
related_info += f"{content}\n"
related_info += "\n"
logger.info(f"格式化输出完成,耗时: {time.time() - format_start_time:.3f}")
logger.info(f"知识库检索总耗时: {time.time() - start_time:.3f}")
return related_info
def get_info_from_db(self, query_embedding: list, limit: int = 1, threshold: float = 0.5) -> str:
def get_info_from_db(self, query_embedding: list, limit: int = 1, threshold: float = 0.5, return_raw: bool = False) -> Union[str, list]:
if not query_embedding:
return ""
return "" if not return_raw else []
# 使用余弦相似度计算
pipeline = [
{
@@ -221,13 +365,16 @@ class PromptBuilder:
]
results = list(db.knowledges.aggregate(pipeline))
# print(f"\033[1;34m[调试]\033[0m获取知识库内容结果: {results}")
logger.debug(f"知识库查询结果数量: {len(results)}")
if not results:
return ""
return "" if not return_raw else []
# 返回所有找到的内容,用换行分隔
return "\n".join(str(result["content"]) for result in results)
if return_raw:
return results
else:
# 返回所有找到的内容,用换行分隔
return "\n".join(str(result["content"]) for result in results)
prompt_builder = PromptBuilder()

View File

@@ -18,6 +18,7 @@ from src.heart_flow.heartflow import heartflow
from src.common.logger import get_module_logger, CHAT_STYLE_CONFIG, LogConfig
from ...chat.chat_stream import chat_manager
from ...person_info.relationship_manager import relationship_manager
from ...chat.message_buffer import message_buffer
# 定义日志配置
chat_config = LogConfig(
@@ -95,6 +96,8 @@ class ThinkFlowChat:
)
if not mark_head:
mark_head = True
# print(f"thinking_start_time:{bot_message.thinking_start_time}")
message_set.add_message(bot_message)
message_manager.add_message(message_set)
@@ -161,6 +164,8 @@ class ThinkFlowChat:
userinfo = message.message_info.user_info
messageinfo = message.message_info
# 消息加入缓冲池
await message_buffer.start_caching_messages(message)
# 创建聊天流
chat = await chat_manager.get_or_create_stream(
@@ -195,8 +200,20 @@ class ThinkFlowChat:
timing_results["记忆激活"] = timer2 - timer1
logger.debug(f"记忆激活: {interested_rate}")
# 查询缓冲器结果会整合前面跳过的消息改变processed_plain_text
buffer_result = await message_buffer.query_buffer_result(message)
if not buffer_result:
if message.message_segment.type == "text":
logger.info(f"触发缓冲,已炸飞消息:{message.processed_plain_text}")
elif message.message_segment.type == "image":
logger.info("触发缓冲,已炸飞表情包/图片")
elif message.message_segment.type == "seglist":
logger.info("触发缓冲,已炸飞消息列")
return
is_mentioned = is_mentioned_bot_in_message(message)
# 计算回复意愿
current_willing_old = willing_manager.get_willing(chat_stream=chat)
# current_willing_new = (heartflow.get_subheartflow(chat.stream_id).current_state.willing - 5) / 4
@@ -236,59 +253,84 @@ class ThinkFlowChat:
do_reply = False
if random() < reply_probability:
do_reply = True
# 创建思考消息
timer1 = time.time()
thinking_id = await self._create_thinking_message(message, chat, userinfo, messageinfo)
timer2 = time.time()
timing_results["创建思考消息"] = timer2 - timer1
# 观察
timer1 = time.time()
await heartflow.get_subheartflow(chat.stream_id).do_observe()
timer2 = time.time()
timing_results["观察"] = timer2 - timer1
# 思考前脑内状态
timer1 = time.time()
await heartflow.get_subheartflow(chat.stream_id).do_thinking_before_reply(message.processed_plain_text)
timer2 = time.time()
timing_results["思考前脑内状态"] = timer2 - timer1
# 生成回复
timer1 = time.time()
response_set = await self.gpt.generate_response(message)
timer2 = time.time()
timing_results["生成回复"] = timer2 - timer1
try:
do_reply = True
# 创建思考消息
try:
timer1 = time.time()
thinking_id = await self._create_thinking_message(message, chat, userinfo, messageinfo)
timer2 = time.time()
timing_results["创建思考消息"] = timer2 - timer1
except Exception as e:
logger.error(f"心流创建思考消息失败: {e}")
try:
# 观察
timer1 = time.time()
await heartflow.get_subheartflow(chat.stream_id).do_observe()
timer2 = time.time()
timing_results["观察"] = timer2 - timer1
except Exception as e:
logger.error(f"心流观察失败: {e}")
if not response_set:
logger.info("为什么生成回复失败?")
return
# 思考前脑内状态
try:
timer1 = time.time()
await heartflow.get_subheartflow(chat.stream_id).do_thinking_before_reply(message.processed_plain_text)
timer2 = time.time()
timing_results["思考前脑内状态"] = timer2 - timer1
except Exception as e:
logger.error(f"心流思考前脑内状态失败: {e}")
# 生成回复
timer1 = time.time()
response_set = await self.gpt.generate_response(message)
timer2 = time.time()
timing_results["生成回复"] = timer2 - timer1
# 发送消息
timer1 = time.time()
await self._send_response_messages(message, chat, response_set, thinking_id)
timer2 = time.time()
timing_results["发送消息"] = timer2 - timer1
if not response_set:
logger.info("为什么生成回复失败?")
return
# 处理表情包
timer1 = time.time()
await self._handle_emoji(message, chat, response_set)
timer2 = time.time()
timing_results["处理表情包"] = timer2 - timer1
# 发送消息
try:
timer1 = time.time()
await self._send_response_messages(message, chat, response_set, thinking_id)
timer2 = time.time()
timing_results["发送消息"] = timer2 - timer1
except Exception as e:
logger.error(f"心流发送消息失败: {e}")
# 更新心流
timer1 = time.time()
await self._update_using_response(message, response_set)
timer2 = time.time()
timing_results["更新心流"] = timer2 - timer1
# 处理表情包
try:
timer1 = time.time()
await self._handle_emoji(message, chat, response_set)
timer2 = time.time()
timing_results["处理表情包"] = timer2 - timer1
except Exception as e:
logger.error(f"心流处理表情包失败: {e}")
# 更新关系情绪
timer1 = time.time()
await self._update_relationship(message, response_set)
timer2 = time.time()
timing_results["更新关系情绪"] = timer2 - timer1
# 更新心流
try:
timer1 = time.time()
await self._update_using_response(message, response_set)
timer2 = time.time()
timing_results["更新心流"] = timer2 - timer1
except Exception as e:
logger.error(f"心流更新失败: {e}")
# 更新关系情绪
try:
timer1 = time.time()
await self._update_relationship(message, response_set)
timer2 = time.time()
timing_results["更新关系情绪"] = timer2 - timer1
except Exception as e:
logger.error(f"心流更新关系情绪失败: {e}")
except Exception as e:
logger.error(f"心流处理消息失败: {e}")
# 输出性能计时结果
if do_reply:

View File

@@ -1,16 +1,13 @@
import random
import time
from typing import Optional
from ...memory_system.Hippocampus import HippocampusManager
from ...moods.moods import MoodManager
from ...schedule.schedule_generator import bot_schedule
from ...config.config import global_config
from ...chat.utils import get_recent_group_detailed_plain_text, get_recent_group_speaker
from ...chat.chat_stream import chat_manager
from src.common.logger import get_module_logger
from ...person_info.relationship_manager import relationship_manager
from ....individuality.individuality import Individuality
from src.heart_flow.heartflow import heartflow
logger = get_module_logger("prompt")
@@ -26,9 +23,10 @@ class PromptBuilder:
) -> tuple[str, str]:
current_mind_info = heartflow.get_subheartflow(stream_id).current_mind
# 开始构建prompt
individuality = Individuality.get_instance()
prompt_personality = individuality.get_prompt(type="personality", x_person=2, level=1)
prompt_identity = individuality.get_prompt(type="identity", x_person=2, level=1)
# 关系
who_chat_in_group = [(chat_stream.user_info.platform,
chat_stream.user_info.user_id,
@@ -90,20 +88,6 @@ class PromptBuilder:
)
keywords_reaction_prompt += rule.get("reaction", "") + ""
# 人格选择
personality = global_config.PROMPT_PERSONALITY
probability_1 = global_config.PERSONALITY_1
probability_2 = global_config.PERSONALITY_2
personality_choice = random.random()
if personality_choice < probability_1: # 第一种风格
prompt_personality = personality[0]
elif personality_choice < probability_1 + probability_2: # 第二种风格
prompt_personality = personality[1]
else: # 第三种人格
prompt_personality = personality[2]
# 中文高手(新加的好玩功能)
prompt_ger = ""
if random.random() < 0.04:
@@ -123,8 +107,8 @@ class PromptBuilder:
{chat_talking_prompt}
你刚刚脑子里在想:
{current_mind_info}
现在"{sender_name}"说的:{message_txt}。引起了你的注意,你想要在群里发言发言或者回复这条消息。{relation_prompt_all}\n
你的网名叫{global_config.BOT_NICKNAME},有人也叫你{"/".join(global_config.BOT_ALIAS_NAMES)}{prompt_personality}
现在"{sender_name}"说的:{message_txt}。引起了你的注意,你想要在群里发言发言或者回复这条消息。\n
你的网名叫{global_config.BOT_NICKNAME},有人也叫你{"/".join(global_config.BOT_ALIAS_NAMES)}{prompt_personality} {prompt_identity}
你正在{chat_target_2},现在请你读读之前的聊天记录,然后给出日常且口语化的回复,平淡一些,
尽量简短一些。{keywords_reaction_prompt}请注意把握聊天内容,不要回复的太有条理,可以有个性。{prompt_ger}
请回复的平淡一些,简短一些,说中文,不要刻意突出自身学科背景,尽量不要说你说过的话
@@ -133,73 +117,5 @@ class PromptBuilder:
return prompt
def _build_initiative_prompt_select(self, group_id, probability_1=0.8, probability_2=0.1):
current_date = time.strftime("%Y-%m-%d", time.localtime())
current_time = time.strftime("%H:%M:%S", time.localtime())
bot_schedule_now_time, bot_schedule_now_activity = bot_schedule.get_current_task()
prompt_date = f"""今天是{current_date},现在是{current_time},你今天的日程是:
{bot_schedule.today_schedule}
你现在正在{bot_schedule_now_activity}
"""
chat_talking_prompt = ""
if group_id:
chat_talking_prompt = get_recent_group_detailed_plain_text(
group_id, limit=global_config.MAX_CONTEXT_SIZE, combine=True
)
chat_talking_prompt = f"以下是群里正在聊天的内容:\n{chat_talking_prompt}"
# print(f"\033[1;34m[调试]\033[0m 已从数据库获取群 {group_id} 的消息记录:{chat_talking_prompt}")
# 获取主动发言的话题
all_nodes = HippocampusManager.get_instance().memory_graph.dots
all_nodes = filter(lambda dot: len(dot[1]["memory_items"]) > 3, all_nodes)
nodes_for_select = random.sample(all_nodes, 5)
topics = [info[0] for info in nodes_for_select]
# 激活prompt构建
activate_prompt = ""
activate_prompt = "以上是群里正在进行的聊天。"
personality = global_config.PROMPT_PERSONALITY
prompt_personality = ""
personality_choice = random.random()
if personality_choice < probability_1: # 第一种人格
prompt_personality = f"""{activate_prompt}你的网名叫{global_config.BOT_NICKNAME}{personality[0]}"""
elif personality_choice < probability_1 + probability_2: # 第二种人格
prompt_personality = f"""{activate_prompt}你的网名叫{global_config.BOT_NICKNAME}{personality[1]}"""
else: # 第三种人格
prompt_personality = f"""{activate_prompt}你的网名叫{global_config.BOT_NICKNAME}{personality[2]}"""
topics_str = ",".join(f'"{topics}"')
prompt_for_select = (
f"你现在想在群里发言,回忆了一下,想到几个话题,分别是{topics_str},综合当前状态以及群内气氛,"
f"请你在其中选择一个合适的话题,注意只需要输出话题,除了话题什么也不要输出(双引号也不要输出)"
)
prompt_initiative_select = f"{prompt_date}\n{prompt_personality}\n{prompt_for_select}"
prompt_regular = f"{prompt_date}\n{prompt_personality}"
return prompt_initiative_select, nodes_for_select, prompt_regular
def _build_initiative_prompt_check(self, selected_node, prompt_regular):
memory = random.sample(selected_node["memory_items"], 3)
memory = "\n".join(memory)
prompt_for_check = (
f"{prompt_regular}你现在想在群里发言,回忆了一下,想到一个话题,是{selected_node['concept']}"
f"关于这个话题的记忆有\n{memory}\n,以这个作为主题发言合适吗?请在把握群里的聊天内容的基础上,"
f"综合群内的氛围如果认为应该发言请输出yes否则输出no请注意是决定是否需要发言而不是编写回复内容"
f"除了yes和no不要输出任何回复内容。"
)
return prompt_for_check, memory
def _build_initiative_prompt(self, selected_node, prompt_regular, memory):
prompt_for_initiative = (
f"{prompt_regular}你现在想在群里发言,回忆了一下,想到一个话题,是{selected_node['concept']}"
f"关于这个话题的记忆有\n{memory}\n,请在把握群里的聊天内容的基础上,综合群内的氛围,"
f"以日常且口语化的口吻,简短且随意一点进行发言,不要说的太有条理,可以有个性。"
f"记住不要输出多余内容(包括前后缀,冒号和引号,括号,表情,@等)"
)
return prompt_for_initiative
prompt_builder = PromptBuilder()
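A minimal standalone sketch of the topic-selection step above, showing the corrected filtering, sampling and quoting in isolation. The node shape (concept, {"memory_items": [...]}) mirrors what the code above consumes, but the helper and its names are illustrative only, not part of the project.

import random


def pick_topics(all_nodes, k=5):
    """Keep nodes with more than 3 memory items, sample up to k, and quote their concepts."""
    candidates = [node for node in all_nodes if len(node[1]["memory_items"]) > 3]
    nodes_for_select = random.sample(candidates, min(k, len(candidates)))
    # Quote each topic individually before joining, e.g. "天气","游戏"
    topics_str = ",".join(f'"{concept}"' for concept, _data in nodes_for_select)
    return nodes_for_select, topics_str


demo_nodes = [
    ("天气", {"memory_items": ["a", "b", "c", "d"]}),
    ("考试", {"memory_items": ["a", "b"]}),
    ("游戏", {"memory_items": ["a", "b", "c", "d", "e"]}),
]
print(pick_topics(demo_nodes, k=2)[1])  # e.g. "天气","游戏"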

View File

@@ -25,12 +25,19 @@ config_config = LogConfig(
logger = get_module_logger("config", config=config_config)
#考虑到实际上配置文件中的mai_version是不会自动更新的,所以采用硬编码
mai_version_main = "0.6.0"
is_test = False
mai_version_main = "0.6.1"
mai_version_fix = ""
if mai_version_fix:
mai_version = f"{mai_version_main}-{mai_version_fix}"
if is_test:
mai_version = f"test-{mai_version_main}-{mai_version_fix}"
else:
mai_version = f"{mai_version_main}-{mai_version_fix}"
else:
mai_version = mai_version_main
if is_test:
mai_version = f"test-{mai_version_main}"
else:
mai_version = mai_version_main
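The version-string rule above has four possible outcomes depending on is_test and mai_version_fix. A small sketch as a pure function (the function name is illustrative only) makes them easy to verify:

def build_version(main: str, fix: str = "", is_test: bool = False) -> str:
    """Compose the displayed version: optional '-fix' suffix, optional 'test-' prefix."""
    version = f"{main}-{fix}" if fix else main
    return f"test-{version}" if is_test else version


assert build_version("0.6.1") == "0.6.1"
assert build_version("0.6.1", is_test=True) == "test-0.6.1"
assert build_version("0.6.1", fix="rc1") == "0.6.1-rc1"
assert build_version("0.6.1", fix="rc1", is_test=True) == "test-0.6.1-rc1"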
def update_config():
# 获取根目录路径
@@ -141,14 +148,22 @@ class BotConfig:
ban_user_id = set()
# personality
PROMPT_PERSONALITY = [
"用一句话或几句话描述性格特点和其他特征",
"例如,是一个热爱国家热爱党的新时代好青年",
"例如,曾经是一个学习地质的女大学生,现在学习心理学和脑科学,你会刷贴吧",
]
PERSONALITY_1: float = 0.6 # 第一种人格概率
PERSONALITY_2: float = 0.3 # 第二种人格概率
PERSONALITY_3: float = 0.1 # 第三种人格概率
personality_core = "用一句话或几句话描述人格的核心特点" # 建议20字以内谁再写3000字小作文敲谁脑袋
personality_sides: List[str] = field(default_factory=lambda: [
"用一句话或几句话描述人格的一些侧面",
"用一句话或几句话描述人格的一些侧面",
"用一句话或几句话描述人格的一些侧面"
])
# identity
identity_detail: List[str] = field(default_factory=lambda: [
"身份特点",
"身份特点",
])
height: int = 170 # 身高 单位厘米
weight: int = 50 # 体重 单位千克
age: int = 20 # 年龄 单位岁
gender: str = "" # 性别
appearance: str = "用几句话描述外貌特征" # 外貌特征
# schedule
ENABLE_SCHEDULE_GEN: bool = False # 是否启用日程生成
@@ -162,6 +177,7 @@ class BotConfig:
emoji_chance: float = 0.2 # 发送表情包的基础概率
thinking_timeout: int = 120 # 思考时间
max_response_length: int = 1024 # 最大回复长度
message_buffer: bool = True # 消息缓冲器
ban_words = set()
ban_msgs_regex = set()
@@ -339,14 +355,19 @@ class BotConfig:
def personality(parent: dict):
personality_config = parent["personality"]
personality = personality_config.get("prompt_personality")
if len(personality) >= 2:
logger.info(f"载入自定义人格:{personality}")
config.PROMPT_PERSONALITY = personality_config.get("prompt_personality", config.PROMPT_PERSONALITY)
if config.INNER_VERSION in SpecifierSet(">=1.2.4"):
config.personality_core = personality_config.get("personality_core", config.personality_core)
config.personality_sides = personality_config.get("personality_sides", config.personality_sides)
config.PERSONALITY_1 = personality_config.get("personality_1_probability", config.PERSONALITY_1)
config.PERSONALITY_2 = personality_config.get("personality_2_probability", config.PERSONALITY_2)
config.PERSONALITY_3 = personality_config.get("personality_3_probability", config.PERSONALITY_3)
def identity(parent: dict):
identity_config = parent["identity"]
if config.INNER_VERSION in SpecifierSet(">=1.2.4"):
config.identity_detail = identity_config.get("identity_detail", config.identity_detail)
config.height = identity_config.get("height", config.height)
config.weight = identity_config.get("weight", config.weight)
config.age = identity_config.get("age", config.age)
config.gender = identity_config.get("gender", config.gender)
config.appearance = identity_config.get("appearance", config.appearance)
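The config.INNER_VERSION in SpecifierSet(">=1.2.4") guards above use the packaging library, which treats a specifier set as a container of versions. A minimal sketch of that gating pattern, with a hypothetical loader and field names (only the Version/SpecifierSet calls are real packaging APIs):

from packaging.specifiers import SpecifierSet
from packaging.version import Version


def load_identity(identity_config: dict, inner_version: Version, defaults: dict) -> dict:
    """Read the new [identity] keys only when the config schema version is new enough."""
    result = dict(defaults)
    if inner_version in SpecifierSet(">=1.2.4"):
        for key in ("identity_detail", "height", "weight", "age", "gender", "appearance"):
            result[key] = identity_config.get(key, defaults[key])
    return result


defaults = {"identity_detail": [], "height": 170, "weight": 50, "age": 20, "gender": "", "appearance": ""}
print(load_identity({"age": 18}, Version("1.2.4"), defaults)["age"])  # 18: new enough, key honoured
print(load_identity({"age": 18}, Version("1.2.3"), defaults)["age"])  # 20: too old, default kept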
def schedule(parent: dict):
schedule_config = parent["schedule"]
@@ -505,6 +526,8 @@ class BotConfig:
if config.INNER_VERSION in SpecifierSet(">=0.0.11"):
config.max_response_length = msg_config.get("max_response_length", config.max_response_length)
if config.INNER_VERSION in SpecifierSet(">=1.1.4"):
config.message_buffer = msg_config.get("message_buffer", config.message_buffer)
def memory(parent: dict):
memory_config = parent["memory"]
@@ -601,6 +624,7 @@ class BotConfig:
"bot": {"func": bot, "support": ">=0.0.0"},
"groups": {"func": groups, "support": ">=0.0.0"},
"personality": {"func": personality, "support": ">=0.0.0"},
"identity": {"func": identity, "support": ">=1.2.4"},
"schedule": {"func": schedule, "support": ">=0.0.11", "necessary": False},
"message": {"func": message, "support": ">=0.0.0"},
"willing": {"func": willing, "support": ">=0.0.9", "necessary": False},

View File

@@ -29,7 +29,10 @@ class BaseMessageHandler:
try:
tasks.append(handler(message))
except Exception as e:
raise RuntimeError(str(e)) from e
logger.error(f"消息处理出错: {str(e)}")
logger.error(traceback.format_exc())
# 不抛出异常,而是记录错误并继续处理其他消息
continue
if tasks:
await asyncio.gather(*tasks, return_exceptions=True)
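With handler failures now logged instead of raised, asyncio.gather(..., return_exceptions=True) hands back exception objects mixed in with normal results, so the caller can still see which tasks failed. A standalone sketch of that inspection pattern (names are illustrative; the format_exc() call above assumes traceback is imported at the top of that module):

import asyncio
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("handler_demo")


async def flaky_handler(i: int) -> int:
    if i == 2:
        raise ValueError(f"handler {i} failed")
    return i * 10


async def main() -> None:
    results = await asyncio.gather(*(flaky_handler(i) for i in range(4)), return_exceptions=True)
    for i, result in enumerate(results):
        if isinstance(result, Exception):
            # A single failing task no longer aborts the whole batch; it is just reported.
            logger.error("task %d failed: %s", i, result)
        else:
            logger.info("task %d ok: %s", i, result)


asyncio.run(main())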
@@ -212,9 +215,8 @@ class MessageServer(BaseMessageHandler):
try:
async with session.post(url, json=data, headers={"Content-Type": "application/json"}) as response:
return await response.json()
except Exception:
# logger.error(f"发送消息失败: {str(e)}")
pass
except Exception as e:
raise e
class BaseMessageAPI:

View File

@@ -6,6 +6,7 @@ from dataclasses import dataclass
from ..config.config import global_config
from src.common.logger import get_module_logger, LogConfig, MOOD_STYLE_CONFIG
from ..person_info.relationship_manager import relationship_manager
from src.individuality.individuality import Individuality
mood_config = LogConfig(
# 使用海马体专用样式
@@ -17,8 +18,8 @@ logger = get_module_logger("mood_manager", config=mood_config)
@dataclass
class MoodState:
valence: float # 愉悦度 (-1 到 1)
arousal: float # 唤醒度 (0 到 1)
valence: float # 愉悦度 (-1.0 到 1.0),-1 表示极度负面,1 表示极度正面
arousal: float # 唤醒度 (0.0 到 1.0),0 表示完全平静,1 表示极度兴奋
text: str # 心情文本描述
@@ -125,20 +126,48 @@ class MoodManager:
time.sleep(update_interval)
def _apply_decay(self) -> None:
"""应用情绪衰减"""
"""应用情绪衰减,正向和负向情绪分开计算"""
current_time = time.time()
time_diff = current_time - self.last_update
agreeableness_factor = 1
agreeableness_bias = 0
neuroticism_factor = 0.5
# Valence 向中性0回归
valence_target = 0
self.current_mood.valence = valence_target + (self.current_mood.valence - valence_target) * math.exp(
-self.decay_rate_valence * time_diff
)
# 获取人格特质
personality = Individuality.get_instance().personality
if personality:
# 神经质:影响情绪变化速度
neuroticism_factor = 1 + (personality.neuroticism - 0.5) * 0.5
agreeableness_factor = 1 + (personality.agreeableness - 0.5) * 0.5
# 宜人性:影响情绪基准线
if personality.agreeableness < 0.2:
agreeableness_bias = (personality.agreeableness - 0.2) * 2
elif personality.agreeableness > 0.8:
agreeableness_bias = (personality.agreeableness - 0.8) * 2
else:
agreeableness_bias = 0
# 分别计算正向和负向的衰减率
if self.current_mood.valence >= 0:
# 正向情绪衰减
decay_rate_positive = self.decay_rate_valence * (1/agreeableness_factor)
valence_target = 0 + agreeableness_bias
self.current_mood.valence = valence_target + (self.current_mood.valence - valence_target) * math.exp(
-decay_rate_positive * time_diff * neuroticism_factor
)
else:
# 负向情绪衰减
decay_rate_negative = self.decay_rate_valence * agreeableness_factor
valence_target = 0 + agreeableness_bias
self.current_mood.valence = valence_target + (self.current_mood.valence - valence_target) * math.exp(
-decay_rate_negative * time_diff * neuroticism_factor
)
# Arousal 向中性(0.5)回归
arousal_target = 0.5
self.current_mood.arousal = arousal_target + (self.current_mood.arousal - arousal_target) * math.exp(
-self.decay_rate_arousal * time_diff
-self.decay_rate_arousal * time_diff * neuroticism_factor
)
# 确保值在合理范围内
@@ -237,7 +266,7 @@ class MoodManager:
old_arousal = self.current_mood.arousal
old_mood = self.current_mood.text
valence_change *= relationship_manager.gain_coefficient[relationship_manager.positive_feedback_value]
valence_change = relationship_manager.feedback_to_mood(valence_change)
# 应用情绪强度
valence_change *= intensity

View File

@@ -2,8 +2,14 @@ from src.common.logger import get_module_logger
from ...common.database import db
import copy
import hashlib
from typing import Any, Callable, Dict, TypeVar
T = TypeVar('T') # 泛型类型
from typing import Any, Callable, Dict
import datetime
import asyncio
import numpy
# import matplotlib.pyplot as plt
# from pathlib import Path
# import pandas as pd
"""
PersonInfoManager 类方法功能摘要:
@@ -15,6 +21,7 @@ PersonInfoManager 类方法功能摘要:
6. get_values - 批量获取字段值(任一字段无效则返回空字典)
7. del_all_undefined_field - 清理全集合中未定义的字段
8. get_specific_value_list - 根据指定条件返回person_id,value字典
9. personal_habit_deduction - 定时推断个人习惯
"""
logger = get_module_logger("person_info")
@@ -30,6 +37,8 @@ person_info_default = {
# "impression" : None,
# "gender" : Unkown,
"konw_time" : 0,
"msg_interval": 3000,
"msg_interval_list": []
} # 个人信息的各项与默认值在此定义,以下处理会自动创建/补全每一项
class PersonInfoManager:
@@ -108,8 +117,9 @@ class PersonInfoManager:
if document and field_name in document:
return document[field_name]
else:
logger.debug(f"获取{person_id}{field_name}失败,已返回默认值{person_info_default[field_name]}")
return person_info_default[field_name]
default_value = copy.deepcopy(person_info_default[field_name])
logger.debug(f"获取{person_id}{field_name}失败,已返回默认值{default_value}")
return default_value
async def get_values(self, person_id: str, field_names: list) -> dict:
"""获取指定person_id文档的多个字段值若不存在该字段则返回该字段的全局默认值"""
@@ -133,7 +143,10 @@ class PersonInfoManager:
result = {}
for field in field_names:
result[field] = document.get(field, person_info_default[field]) if document else person_info_default[field]
result[field] = copy.deepcopy(
document.get(field, person_info_default[field])
if document else person_info_default[field]
)
return result
@@ -209,5 +222,47 @@ class PersonInfoManager:
except Exception as e:
logger.error(f"数据库查询失败: {str(e)}", exc_info=True)
return {}
async def personal_habit_deduction(self):
"""启动个人信息推断,每天根据一定条件推断一次"""
try:
while True:
await asyncio.sleep(60)
current_time = datetime.datetime.now()
logger.info(f"个人信息推断启动: {current_time.strftime('%Y-%m-%d %H:%M:%S')}")
person_info_manager = PersonInfoManager()
# "msg_interval"推断
msg_interval_lists = await self.get_specific_value_list(
"msg_interval_list",
lambda x: isinstance(x, list) and len(x) >= 100
)
for person_id, msg_interval_list_ in msg_interval_lists.items():
try:
time_interval = []
for t1, t2 in zip(msg_interval_list_, msg_interval_list_[1:]):
delta = t2 - t1
if 0 < delta < 8000: # 间隔小于8秒
time_interval.append(delta)
if len(time_interval) > 30:
# 移除matplotlib相关的绘图功能
filtered_intervals = [t for t in time_interval if t >= 500]
if len(filtered_intervals) > 25:
msg_interval = int(round(numpy.percentile(filtered_intervals, 80)))
await self.update_one_field(person_id, "msg_interval", msg_interval)
logger.debug(f"用户{person_id}的msg_interval已经被更新为{msg_interval}")
except Exception as e:
logger.debug(f"处理用户{person_id}msg_interval推断时出错: {str(e)}")
continue
# 其他...
logger.info(f"个人信息推断结束: {current_time.strftime('%Y-%m-%d %H:%M:%S')}")
await asyncio.sleep(86400)
except Exception as e:
logger.error(f"个人信息推断运行时出错: {str(e)}")
logger.exception("详细错误信息:")
person_info_manager = PersonInfoManager()
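personal_habit_deduction above boils down to: take consecutive message timestamps, keep gaps between 0 and 8000 ms, require enough samples, drop sub-500 ms bursts, and store the 80th percentile as the user's msg_interval. A self-contained sketch of just that calculation, assuming timestamps in milliseconds and reusing the thresholds from the loop:

from typing import List, Optional

import numpy


def deduce_msg_interval(timestamps_ms: List[int]) -> Optional[int]:
    """Return the 80th-percentile gap between consecutive messages, or None if data is too sparse."""
    gaps = [t2 - t1 for t1, t2 in zip(timestamps_ms, timestamps_ms[1:]) if 0 < t2 - t1 < 8000]
    if len(gaps) <= 30:
        return None
    filtered = [g for g in gaps if g >= 500]   # ignore near-instant bursts
    if len(filtered) <= 25:
        return None
    return int(round(numpy.percentile(filtered, 80)))


import random
random.seed(0)
ts, now = [], 0
for _ in range(120):
    now += random.choice([800, 1200, 2000, 3000])
    ts.append(now)
print(deduce_msg_interval(ts))  # a "typical" reply gap for this user, roughly 2000-3000 ms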

View File

@@ -63,7 +63,15 @@ class RelationshipManager:
value += value * mood_gain
logger.info(f"当前relationship增益系数{mood_gain:.3f}")
return value
def feedback_to_mood(self, mood_value):
"""对情绪的反馈"""
coefficient = self.gain_coefficient[abs(self.positive_feedback_value)]
if (mood_value > 0 and self.positive_feedback_value > 0
or mood_value < 0 and self.positive_feedback_value < 0):
return mood_value*coefficient
else:
return mood_value/coefficient
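feedback_to_mood scales a mood change up when it points the same way as the recent relationship feedback and scales it down otherwise. A tiny standalone illustration (the gain table values are made up for the example; in the class the coefficient comes from self.gain_coefficient):

def feedback_to_mood(mood_value: float, positive_feedback_value: int, gain_coefficient: list) -> float:
    """Amplify mood changes that agree with the feedback direction, dampen the rest."""
    coefficient = gain_coefficient[abs(positive_feedback_value)]
    if (mood_value > 0 and positive_feedback_value > 0
            or mood_value < 0 and positive_feedback_value < 0):
        return mood_value * coefficient
    return mood_value / coefficient


gains = [1.0, 1.2, 1.5]                            # index = |positive_feedback_value| (illustrative)
print(feedback_to_mood(0.5, 2, gains))             # 0.75: same direction, amplified
print(round(feedback_to_mood(-0.5, 2, gains), 3))  # -0.333: opposite direction, dampened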
async def calculate_update_relationship_value(self, chat_stream: ChatStream, label: str, stance: str) -> None:
"""计算并变更关系值

View File

@@ -1,195 +0,0 @@
"""
The definition of artificial personality in this paper follows the dispositional paradigm and adapts a definition of
personality developed for humans [17]:
Personality for a human is the "whole and organisation of relatively stable tendencies and patterns of experience and
behaviour within one person (distinguishing it from other persons)". This definition is modified for artificial
personality:
Artificial personality describes the relatively stable tendencies and patterns of behaviour of an AI-based machine that
can be designed by developers and designers via different modalities, such as language, creating the impression
of individuality of a humanized social agent when users interact with the machine."""
from typing import Dict, List
import json
import os
from pathlib import Path
from dotenv import load_dotenv
import sys
"""
第一种方案:基于情景评估的人格测定
"""
current_dir = Path(__file__).resolve().parent
project_root = current_dir.parent.parent.parent
env_path = project_root / ".env"
root_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../.."))
sys.path.append(root_path)
from src.plugins.personality.scene import get_scene_by_factor, PERSONALITY_SCENES # noqa: E402
from src.plugins.personality.questionnaire import FACTOR_DESCRIPTIONS # noqa: E402
from src.plugins.personality.offline_llm import LLMModel # noqa: E402
# 加载环境变量
if env_path.exists():
print(f"{env_path} 加载环境变量")
load_dotenv(env_path)
else:
print(f"未找到环境变量文件: {env_path}")
print("将使用默认配置")
class PersonalityEvaluator_direct:
def __init__(self):
self.personality_traits = {"开放性": 0, "严谨性": 0, "外向性": 0, "宜人性": 0, "神经质": 0}
self.scenarios = []
# 为每个人格特质获取对应的场景
for trait in PERSONALITY_SCENES:
scenes = get_scene_by_factor(trait)
if not scenes:
continue
# 从每个维度选择3个场景
import random
scene_keys = list(scenes.keys())
selected_scenes = random.sample(scene_keys, min(3, len(scene_keys)))
for scene_key in selected_scenes:
scene = scenes[scene_key]
# 为每个场景添加评估维度
# 主维度是当前特质,次维度随机选择一个其他特质
other_traits = [t for t in PERSONALITY_SCENES if t != trait]
secondary_trait = random.choice(other_traits)
self.scenarios.append(
{"场景": scene["scenario"], "评估维度": [trait, secondary_trait], "场景编号": scene_key}
)
self.llm = LLMModel()
def evaluate_response(self, scenario: str, response: str, dimensions: List[str]) -> Dict[str, float]:
"""
使用 DeepSeek AI 评估用户对特定场景的反应
"""
# 构建维度描述
dimension_descriptions = []
for dim in dimensions:
desc = FACTOR_DESCRIPTIONS.get(dim, "")
if desc:
dimension_descriptions.append(f"- {dim}{desc}")
dimensions_text = "\n".join(dimension_descriptions)
prompt = f"""请根据以下场景和用户描述评估用户在大五人格模型中的相关维度得分1-6分
场景描述:
{scenario}
用户回应:
{response}
需要评估的维度说明:
{dimensions_text}
请按照以下格式输出评估结果仅输出JSON格式
{{
"{dimensions[0]}": 分数,
"{dimensions[1]}": 分数
}}
评分标准:
1 = 非常不符合该维度特征
2 = 比较不符合该维度特征
3 = 有点不符合该维度特征
4 = 有点符合该维度特征
5 = 比较符合该维度特征
6 = 非常符合该维度特征
请根据用户的回应结合场景和维度说明进行评分。确保分数在1-6之间并给出合理的评估。"""
try:
ai_response, _ = self.llm.generate_response(prompt)
# 尝试从AI响应中提取JSON部分
start_idx = ai_response.find("{")
end_idx = ai_response.rfind("}") + 1
if start_idx != -1 and end_idx != 0:
json_str = ai_response[start_idx:end_idx]
scores = json.loads(json_str)
# 确保所有分数在1-6之间
return {k: max(1, min(6, float(v))) for k, v in scores.items()}
else:
print("AI响应格式不正确使用默认评分")
return {dim: 3.5 for dim in dimensions}
except Exception as e:
print(f"评估过程出错:{str(e)}")
return {dim: 3.5 for dim in dimensions}
def main():
print("欢迎使用人格形象创建程序!")
print("接下来您将面对一系列场景共15个。请根据您想要创建的角色形象描述在该场景下可能的反应。")
print("每个场景都会评估不同的人格维度,最终得出完整的人格特征评估。")
print("评分标准1=非常不符合2=比较不符合3=有点不符合4=有点符合5=比较符合6=非常符合")
print("\n准备好了吗?按回车键开始...")
input()
evaluator = PersonalityEvaluator_direct()
final_scores = {"开放性": 0, "严谨性": 0, "外向性": 0, "宜人性": 0, "神经质": 0}
dimension_counts = {trait: 0 for trait in final_scores.keys()}
for i, scenario_data in enumerate(evaluator.scenarios, 1):
print(f"\n场景 {i}/{len(evaluator.scenarios)} - {scenario_data['场景编号']}:")
print("-" * 50)
print(scenario_data["场景"])
print("\n请描述您的角色在这种情况下会如何反应:")
response = input().strip()
if not response:
print("反应描述不能为空!")
continue
print("\n正在评估您的描述...")
scores = evaluator.evaluate_response(scenario_data["场景"], response, scenario_data["评估维度"])
# 更新最终分数
for dimension, score in scores.items():
final_scores[dimension] += score
dimension_counts[dimension] += 1
print("\n当前评估结果:")
print("-" * 30)
for dimension, score in scores.items():
print(f"{dimension}: {score}/6")
if i < len(evaluator.scenarios):
print("\n按回车键继续下一个场景...")
input()
# 计算平均分
for dimension in final_scores:
if dimension_counts[dimension] > 0:
final_scores[dimension] = round(final_scores[dimension] / dimension_counts[dimension], 2)
print("\n最终人格特征评估结果:")
print("-" * 30)
for trait, score in final_scores.items():
print(f"{trait}: {score}/6")
print(f"测试场景数:{dimension_counts[trait]}")
# 保存结果
result = {"final_scores": final_scores, "dimension_counts": dimension_counts, "scenarios": evaluator.scenarios}
# 确保目录存在
os.makedirs("results", exist_ok=True)
# 保存到文件
with open("results/personality_result.json", "w", encoding="utf-8") as f:
json.dump(result, f, ensure_ascii=False, indent=2)
print("\n结果已保存到 results/personality_result.json")
if __name__ == "__main__":
main()
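The evaluator removed above recovered scores from a free-form LLM reply by slicing from the first '{' to the last '}', parsing that span as JSON, and clamping every value into the 1-6 range, falling back to a neutral 3.5. A minimal standalone sketch of that extraction pattern:

import json


def extract_scores(ai_response: str, dimensions: list, default: float = 3.5) -> dict:
    """Pull a JSON object out of a free-form reply and clamp each score to the 1-6 range."""
    start, end = ai_response.find("{"), ai_response.rfind("}") + 1
    if start == -1 or end == 0:
        return {dim: default for dim in dimensions}
    try:
        scores = json.loads(ai_response[start:end])
        return {k: max(1.0, min(6.0, float(v))) for k, v in scores.items()}
    except (json.JSONDecodeError, TypeError, ValueError):
        return {dim: default for dim in dimensions}


reply = '评估结果如下:{"外向性": 7, "神经质": 2}'
print(extract_scores(reply, ["外向性", "神经质"]))  # {'外向性': 6.0, '神经质': 2.0}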

View File

@@ -1,261 +0,0 @@
from typing import Dict
PERSONALITY_SCENES = {
"外向性": {
"场景1": {
"scenario": """你刚刚搬到一个新的城市工作。今天是你入职的第一天,在公司的电梯里,一位同事微笑着和你打招呼:
同事:「嗨!你是新来的同事吧?我是市场部的小林。」
同事看起来很友善,还主动介绍说:「待会午饭时间,我们部门有几个人准备一起去楼下新开的餐厅,你要一起来吗?可以认识一下其他同事。」""",
"explanation": "这个场景通过职场社交情境,观察个体对于新环境、新社交圈的态度和反应倾向。",
},
"场景2": {
"scenario": """在大学班级群里,班长发起了一个组织班级联谊活动的投票:
班长「大家好下周末我们准备举办一次班级联谊活动地点在学校附近的KTV。想请大家报名参加也欢迎大家邀请其他班级的同学
已经有几个同学在群里积极响应,有人@你问你要不要一起参加。""",
"explanation": "通过班级活动场景,观察个体对群体社交活动的参与意愿。",
},
"场景3": {
"scenario": """你在社交平台上发布了一条动态,收到了很多陌生网友的评论和私信:
网友A「你说的这个观点很有意思想和你多交流一下。」
网友B「我也对这个话题很感兴趣要不要建个群一起讨论""",
"explanation": "通过网络社交场景,观察个体对线上社交的态度。",
},
"场景4": {
"scenario": """你暗恋的对象今天主动来找你:
对方:「那个...我最近在准备一个演讲比赛,听说你口才很好。能不能请你帮我看看演讲稿,顺便给我一些建议?"""
"""如果你有时间的话,可以一起吃个饭聊聊。」""",
"explanation": "通过恋爱情境,观察个体在面对心仪对象时的社交表现。",
},
"场景5": {
"scenario": """在一次线下读书会上,主持人突然点名让你分享读后感:
主持人:「听说你对这本书很有见解,能不能和大家分享一下你的想法?」
现场有二十多个陌生的读书爱好者,都期待地看着你。""",
"explanation": "通过即兴发言场景,观察个体的社交表现欲和公众表达能力。",
},
},
"神经质": {
"场景1": {
"scenario": """你正在准备一个重要的项目演示,这关系到你的晋升机会。"""
"""就在演示前30分钟你收到了主管发来的消息
主管「临时有个变动CEO也会来听你的演示。他对这个项目特别感兴趣。」
正当你准备回复时主管又发来一条「对了能不能把演示时间压缩到15分钟CEO下午还有其他安排。你之前准备的是30分钟的版本对吧""",
"explanation": "这个场景通过突发的压力情境,观察个体在面对计划外变化时的情绪反应和调节能力。",
},
"场景2": {
"scenario": """期末考试前一天晚上,你收到了好朋友发来的消息:
好朋友:「不好意思这么晚打扰你...我看你平时成绩很好,能不能帮我解答几个问题?我真的很担心明天的考试。」
你看了看时间已经是晚上11点而你原本计划的复习还没完成。""",
"explanation": "通过考试压力场景,观察个体在时间紧张时的情绪管理。",
},
"场景3": {
"scenario": """你在社交媒体上发表的一个观点引发了争议,有不少人开始批评你:
网友A「这种观点也好意思说出来真是无知。」
网友B「建议楼主先去补补课再来发言。」
评论区里的负面评论越来越多,还有人开始人身攻击。""",
"explanation": "通过网络争议场景,观察个体面对批评时的心理承受能力。",
},
"场景4": {
"scenario": """你和恋人约好今天一起看电影,但在约定时间前半小时,对方发来消息:
恋人:「对不起,我临时有点事,可能要迟到一会儿。」
二十分钟后,对方又发来消息:「可能要再等等,抱歉!」
电影快要开始了,但对方还是没有出现。""",
"explanation": "通过恋爱情境,观察个体对不确定性的忍耐程度。",
},
"场景5": {
"scenario": """在一次重要的小组展示中,你的组员在演示途中突然卡壳了:
组员小声对你说:「我忘词了,接下来的部分是什么来着...」
台下的老师和同学都在等待,气氛有些尴尬。""",
"explanation": "通过公开场合的突发状况,观察个体的应急反应和压力处理能力。",
},
},
"严谨性": {
"场景1": {
"scenario": """你是团队的项目负责人,刚刚接手了一个为期两个月的重要项目。在第一次团队会议上:
小王:「老大,我觉得两个月时间很充裕,我们先做着看吧,遇到问题再解决。」
小张:「要不要先列个时间表?不过感觉太详细的计划也没必要,点到为止就行。」
小李:「客户那边说如果能提前完成有奖励,我觉得我们可以先做快一点的部分。」""",
"explanation": "这个场景通过项目管理情境,体现个体在工作方法、计划性和责任心方面的特征。",
},
"场景2": {
"scenario": """期末小组作业,组长让大家分工完成一份研究报告。在截止日期前三天:
组员A「我的部分大概写完了感觉还行。」
组员B「我这边可能还要一天才能完成最近太忙了。」
组员C发来一份没有任何引用出处、可能存在抄袭的内容「我写完了你们看看怎么样""",
"explanation": "通过学习场景,观察个体对学术规范和质量要求的重视程度。",
},
"场景3": {
"scenario": """你在一个兴趣小组的群聊中,大家正在讨论举办一次线下活动:
成员A「到时候见面就知道具体怎么玩了
成员B「对啊随意一点挺好的。」
成员C「人来了自然就热闹了。」""",
"explanation": "通过活动组织场景,观察个体对活动计划的态度。",
},
"场景4": {
"scenario": """你和恋人计划一起去旅游,对方说:
恋人:「我们就随心而行吧!订个目的地,其他的到了再说,这样更有意思。」
距离出发还有一周时间,但机票、住宿和具体行程都还没有确定。""",
"explanation": "通过旅行规划场景,观察个体的计划性和对不确定性的接受程度。",
},
"场景5": {
"scenario": """在一个重要的团队项目中,你发现一个同事的工作存在明显错误:
同事:「差不多就行了,反正领导也看不出来。」
这个错误可能不会立即造成问题,但长期来看可能会影响项目质量。""",
"explanation": "通过工作质量场景,观察个体对细节和标准的坚持程度。",
},
},
"开放性": {
"场景1": {
"scenario": """周末下午,你的好友小美兴致勃勃地给你打电话:
小美:「我刚发现一个特别有意思的沉浸式艺术展!不是传统那种挂画的展览,而是把整个空间都变成了艺术品。"""
"""观众要穿特制的服装还要带上VR眼镜好像还有AI实时互动
小美继续说:「虽然票价不便宜,但听说体验很独特。网上评价两极分化,有人说是前所未有的艺术革新,也有人说是哗众取宠。"""
"""要不要周末一起去体验一下?」""",
"explanation": "这个场景通过新型艺术体验,反映个体对创新事物的接受程度和尝试意愿。",
},
"场景2": {
"scenario": """在一节创意写作课上,老师提出了一个特别的作业:
老师「下周的作业是用AI写作工具协助创作一篇小说。你们可以自由探索如何与AI合作打破传统写作方式。」
班上随即展开了激烈讨论,有人认为这是对创作的亵渎,也有人对这种新形式感到兴奋。""",
"explanation": "通过新技术应用场景,观察个体对创新学习方式的态度。",
},
"场景3": {
"scenario": """在社交媒体上,你看到一个朋友分享了一种新的生活方式:
「最近我在尝试'数字游牧'生活,就是一边远程工作一边环游世界。"""
"""没有固定住所,住青旅或短租,认识来自世界各地的朋友。虽然有时会很不稳定,但这种自由的生活方式真的很棒!」
评论区里争论不断,有人向往这种生活,也有人觉得太冒险。""",
"explanation": "通过另类生活方式,观察个体对非传统选择的态度。",
},
"场景4": {
"scenario": """你的恋人突然提出了一个想法:
恋人:「我们要不要尝试一下开放式关系?就是在保持彼此关系的同时,也允许和其他人发展感情。现在国外很多年轻人都这样。」
这个提议让你感到意外,你之前从未考虑过这种可能性。""",
"explanation": "通过感情观念场景,观察个体对非传统关系模式的接受度。",
},
"场景5": {
"scenario": """在一次朋友聚会上,大家正在讨论未来职业规划:
朋友A「我准备辞职去做自媒体专门介绍一些小众的文化和艺术。」
朋友B「我想去学习生物科技准备转行做人造肉研发。」
朋友C「我在考虑加入一个区块链创业项目虽然风险很大。」""",
"explanation": "通过职业选择场景,观察个体对新兴领域的探索意愿。",
},
},
"宜人性": {
"场景1": {
"scenario": """在回家的公交车上,你遇到这样一幕:
一位老奶奶颤颤巍巍地上了车,车上座位已经坐满了。她站在你旁边,看起来很疲惫。这时你听到前排两个年轻人的对话:
年轻人A「那个老太太好像站不稳看起来挺累的。」
年轻人B「现在的老年人真是...我看她包里还有菜,肯定是去菜市场买完菜回来的,这么多人都不知道叫子女开车接送。」
就在这时,老奶奶一个趔趄,差点摔倒。她扶住了扶手,但包里的东西洒了一些出来。""",
"explanation": "这个场景通过公共场合的助人情境,体现个体的同理心和对他人需求的关注程度。",
},
"场景2": {
"scenario": """在班级群里,有同学发起为生病住院的同学捐款:
同学A「大家好小林最近得了重病住院医药费很贵家里负担很重。我们要不要一起帮帮他
同学B「我觉得这是他家里的事我们不方便参与吧。」
同学C「但是都是同学一场帮帮忙也是应该的。」""",
"explanation": "通过同学互助场景,观察个体的助人意愿和同理心。",
},
"场景3": {
"scenario": """在一个网络讨论组里,有人发布了求助信息:
求助者:「最近心情很低落,感觉生活很压抑,不知道该怎么办...」
评论区里已经有一些回复:
「生活本来就是这样,想开点!」
「你这样子太消极了,要积极面对。」
「谁还没点烦心事啊,过段时间就好了。」""",
"explanation": "通过网络互助场景,观察个体的共情能力和安慰方式。",
},
"场景4": {
"scenario": """你的恋人向你倾诉工作压力:
恋人:「最近工作真的好累,感觉快坚持不下去了...」
但今天你也遇到了很多烦心事,心情也不太好。""",
"explanation": "通过感情关系场景,观察个体在自身状态不佳时的关怀能力。",
},
"场景5": {
"scenario": """在一次团队项目中,新来的同事小王因为经验不足,造成了一个严重的错误。在部门会议上:
主管:「这个错误造成了很大的损失,是谁负责的这部分?」
小王看起来很紧张,欲言又止。你知道是他造成的错误,同时你也是这个项目的共同负责人。""",
"explanation": "通过职场情境,观察个体在面对他人过错时的态度和处理方式。",
},
},
}
def get_scene_by_factor(factor: str) -> Dict:
"""
根据人格因子获取对应的情景测试
Args:
factor (str): 人格因子名称
Returns:
Dict: 包含情景描述的字典
"""
return PERSONALITY_SCENES.get(factor, None)
def get_all_scenes() -> Dict:
"""
获取所有情景测试
Returns:
Dict: 所有情景测试的字典
"""
return PERSONALITY_SCENES

View File

@@ -62,9 +62,7 @@ class ScheduleGenerator:
self.name = name
self.behavior = behavior
self.schedule_doing_update_interval = interval
for pers in personality:
self.personality += pers + "\n"
self.personality = personality
async def mai_schedule_start(self):
"""启动日程系统每5分钟执行一次move_doing并在日期变化时重新检查日程"""

View File

@@ -41,7 +41,7 @@ class KnowledgeLibrary:
return f.read()
def split_content(self, content: str, max_length: int = 512) -> list:
"""将内容分割成适当大小的块,保持段落完整性
"""将内容分割成适当大小的块,按空行分割
Args:
content: 要分割的文本内容
@@ -50,67 +50,21 @@ class KnowledgeLibrary:
Returns:
list: 分割后的文本块列表
"""
# 首先按段落分割
# 按空行分割内容
paragraphs = [p.strip() for p in content.split("\n\n") if p.strip()]
chunks = []
current_chunk = []
current_length = 0
for para in paragraphs:
para_length = len(para)
# 如果单个段落就超过最大长度
if para_length > max_length:
# 如果当前chunk不为空先保存
if current_chunk:
chunks.append("\n".join(current_chunk))
current_chunk = []
current_length = 0
# 将长段落按句子分割
sentences = [
s.strip()
for s in para.replace("。", "\n").replace("!", "\n").replace("?", "\n").split("\n")
if s.strip()
]
temp_chunk = []
temp_length = 0
for sentence in sentences:
sentence_length = len(sentence)
if sentence_length > max_length:
# 如果单个句子超长,强制按长度分割
if temp_chunk:
chunks.append("\n".join(temp_chunk))
temp_chunk = []
temp_length = 0
for i in range(0, len(sentence), max_length):
chunks.append(sentence[i : i + max_length])
elif temp_length + sentence_length + 1 <= max_length:
temp_chunk.append(sentence)
temp_length += sentence_length + 1
else:
chunks.append("\n".join(temp_chunk))
temp_chunk = [sentence]
temp_length = sentence_length
if temp_chunk:
chunks.append("\n".join(temp_chunk))
# 如果当前段落加上现有chunk不超过最大长度
elif current_length + para_length + 1 <= max_length:
current_chunk.append(para)
current_length += para_length + 1
# 如果段落长度小于等于最大长度,直接添加
if para_length <= max_length:
chunks.append(para)
else:
# 保存当前chunk并开始新的chunk
chunks.append("\n".join(current_chunk))
current_chunk = [para]
current_length = para_length
# 添加最后一个chunk
if current_chunk:
chunks.append("\n".join(current_chunk))
# 如果段落超过最大长度,则按最大长度切分
for i in range(0, para_length, max_length):
chunks.append(para[i:i + max_length])
return chunks
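The rewritten split_content reduces to: split on blank lines, keep short paragraphs whole, and hard-slice anything longer than max_length. A self-contained sketch of that behaviour (a plain function here, mirroring the method above):

def split_content(content: str, max_length: int = 512) -> list:
    """One chunk per blank-line-separated paragraph; oversize paragraphs are sliced by length."""
    paragraphs = [p.strip() for p in content.split("\n\n") if p.strip()]
    chunks = []
    for para in paragraphs:
        if len(para) <= max_length:
            chunks.append(para)
        else:
            chunks.extend(para[i:i + max_length] for i in range(0, len(para), max_length))
    return chunks


text = "短段落。\n\n" + "长" * 1200
print([len(c) for c in split_content(text)])  # [4, 512, 512, 176]

Compared with the removed logic, this trades sentence-boundary awareness for simplicity: an oversize paragraph may now be cut mid-sentence.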
def get_embedding(self, text: str) -> list: