ruff reformatted

This commit is contained in:
春河晴
2025-04-08 15:31:13 +09:00
parent 0d7068acab
commit 7840a6080d
40 changed files with 1227 additions and 1336 deletions

View File

@@ -24,10 +24,10 @@
# # 标记GUI是否运行中
# self.is_running = True
# # 程序关闭时的清理操作
# self.protocol("WM_DELETE_WINDOW", self._on_closing)
# # 初始化进程、日志队列、日志数据等变量
# self.process = None
# self.log_queue = queue.Queue()
@@ -236,7 +236,7 @@
# while not self.log_queue.empty():
# line = self.log_queue.get()
# self.process_log_line(line)
# # 仅在GUI仍在运行时继续处理队列
# if self.is_running:
# self.after(100, self.process_log_queue)
@@ -245,11 +245,11 @@
# """解析单行日志并更新日志数据和筛选器"""
# match = re.match(
# r"""^
# (?:(?P<time>\d{2}:\d{2}(?::\d{2})?)\s*\|\s*)?
# (?P<level>\w+)\s*\|\s*
# (?P<module>.*?)
# \s*[-|]\s*
# (?P<message>.*)
# (?:(?P<time>\d{2}:\d{2}(?::\d{2})?)\s*\|\s*)?
# (?P<level>\w+)\s*\|\s*
# (?P<module>.*?)
# \s*[-|]\s*
# (?P<message>.*)
# $""",
# line.strip(),
# re.VERBOSE,
@@ -354,10 +354,10 @@
# """处理窗口关闭事件,安全清理资源"""
# # 标记GUI已关闭
# self.is_running = False
# # 停止日志进程
# self.stop_process()
# # 安全清理tkinter变量
# for attr_name in list(self.__dict__.keys()):
# if isinstance(getattr(self, attr_name), (ctk.Variable, ctk.StringVar, ctk.IntVar, ctk.DoubleVar, ctk.BooleanVar)):
@@ -367,7 +367,7 @@
# except Exception:
# pass
# setattr(self, attr_name, None)
# self.quit()
# sys.exit(0)

View File

@@ -127,7 +127,7 @@
# """处理窗口关闭事件"""
# # 标记GUI已关闭防止后台线程继续访问tkinter对象
# self.is_running = False
# # 安全清理所有可能的tkinter变量
# for attr_name in list(self.__dict__.keys()):
# if isinstance(getattr(self, attr_name), (ctk.Variable, ctk.StringVar, ctk.IntVar, ctk.DoubleVar, ctk.BooleanVar)):
@@ -138,7 +138,7 @@
# except Exception:
# pass
# setattr(self, attr_name, None)
# # 退出
# self.root.quit()
# sys.exit(0)
@@ -259,7 +259,7 @@
# while True:
# if not self.is_running:
# break # 如果GUI已关闭停止线程
# try:
# # 从数据库获取最新数据,只获取启动时间之后的记录
# query = {"time": {"$gt": self.start_timestamp}}

View File

@@ -42,7 +42,6 @@ class Heartflow:
self._subheartflows = {}
self.active_subheartflows_nums = 0
async def _cleanup_inactive_subheartflows(self):
"""定期清理不活跃的子心流"""
while True:
@@ -84,25 +83,22 @@ class Heartflow:
# 开始构建prompt
prompt_personality = ""
#person
# person
individuality = Individuality.get_instance()
personality_core = individuality.personality.personality_core
prompt_personality += personality_core
personality_sides = individuality.personality.personality_sides
random.shuffle(personality_sides)
prompt_personality += f",{personality_sides[0]}"
identity_detail = individuality.identity.identity_detail
random.shuffle(identity_detail)
prompt_personality += f",{identity_detail[0]}"
personality_info = prompt_personality
current_thinking_info = self.current_mind
mood_info = self.current_state.mood
related_memory_info = "memory"
@@ -146,22 +142,20 @@ class Heartflow:
async def minds_summary(self, minds_str):
# 开始构建prompt
prompt_personality = ""
#person
# person
individuality = Individuality.get_instance()
personality_core = individuality.personality.personality_core
prompt_personality += personality_core
personality_sides = individuality.personality.personality_sides
random.shuffle(personality_sides)
prompt_personality += f",{personality_sides[0]}"
identity_detail = individuality.identity.identity_detail
random.shuffle(identity_detail)
prompt_personality += f",{identity_detail[0]}"
personality_info = prompt_personality
mood_info = self.current_state.mood
@@ -183,7 +177,7 @@ class Heartflow:
添加一个SubHeartflow实例到self._subheartflows字典中
并根据subheartflow_id为子心流创建一个观察对象
"""
try:
if subheartflow_id not in self._subheartflows:
logger.debug(f"创建 subheartflow: {subheartflow_id}")

View File

@@ -7,6 +7,7 @@ from src.common.database import db
from src.individuality.individuality import Individuality
import random
# 所有观察的基类
class Observation:
def __init__(self, observe_type, observe_id):
@@ -24,7 +25,7 @@ class ChattingObservation(Observation):
self.talking_message = []
self.talking_message_str = ""
self.name = global_config.BOT_NICKNAME
self.nick_name = global_config.BOT_ALIAS_NAMES
@@ -57,7 +58,7 @@ class ChattingObservation(Observation):
for msg in new_messages:
if "detailed_plain_text" in msg:
new_messages_str += f"{msg['detailed_plain_text']}"
# print(f"new_messages_str{new_messages_str}")
# 将新消息添加到talking_message同时保持列表长度不超过20条
@@ -117,26 +118,22 @@ class ChattingObservation(Observation):
# print(f"更新聊天总结:{self.talking_summary}")
# 开始构建prompt
prompt_personality = ""
#person
# person
individuality = Individuality.get_instance()
personality_core = individuality.personality.personality_core
prompt_personality += personality_core
personality_sides = individuality.personality.personality_sides
random.shuffle(personality_sides)
prompt_personality += f",{personality_sides[0]}"
identity_detail = individuality.identity.identity_detail
random.shuffle(identity_detail)
prompt_personality += f",{identity_detail[0]}"
personality_info = prompt_personality
prompt = ""
prompt += f"{personality_info},请注意识别你自己的聊天发言"
prompt += f"你的名字叫:{self.name},你的昵称是:{self.nick_name}\n"
@@ -148,7 +145,6 @@ class ChattingObservation(Observation):
self.observe_info, reasoning_content = await self.llm_summary.generate_response_async(prompt)
print(f"prompt{prompt}")
print(f"self.observe_info{self.observe_info}")
def translate_message_list_to_str(self):
self.talking_message_str = ""

View File

@@ -53,11 +53,10 @@ class SubHeartflow:
if not self.current_mind:
self.current_mind = "你什么也没想"
self.is_active = False
self.observations: list[Observation] = []
self.running_knowledges = []
def add_observation(self, observation: Observation):
@@ -86,7 +85,9 @@ class SubHeartflow:
async def subheartflow_start_working(self):
while True:
current_time = time.time()
if current_time - self.last_reply_time > global_config.sub_heart_flow_freeze_time: # 120秒无回复/不在场,冻结
if (
current_time - self.last_reply_time > global_config.sub_heart_flow_freeze_time
): # 120秒无回复/不在场,冻结
self.is_active = False
await asyncio.sleep(global_config.sub_heart_flow_update_interval) # 每60秒检查一次
else:
@@ -100,7 +101,9 @@ class SubHeartflow:
await asyncio.sleep(global_config.sub_heart_flow_update_interval)
# 检查是否超过10分钟没有激活
if current_time - self.last_active_time > global_config.sub_heart_flow_stop_time: # 5分钟无回复/不在场,销毁
if (
current_time - self.last_active_time > global_config.sub_heart_flow_stop_time
): # 5分钟无回复/不在场,销毁
logger.info(f"子心流 {self.subheartflow_id} 已经5分钟没有激活正在销毁...")
break # 退出循环,销毁自己
@@ -147,11 +150,11 @@ class SubHeartflow:
# self.current_mind = reponse
# logger.debug(f"prompt:\n{prompt}\n")
# logger.info(f"麦麦的脑内状态:{self.current_mind}")
async def do_observe(self):
observation = self.observations[0]
await observation.observe()
async def do_thinking_before_reply(self, message_txt):
current_thinking_info = self.current_mind
mood_info = self.current_state.mood
@@ -162,23 +165,20 @@ class SubHeartflow:
# 开始构建prompt
prompt_personality = ""
#person
# person
individuality = Individuality.get_instance()
personality_core = individuality.personality.personality_core
prompt_personality += personality_core
personality_sides = individuality.personality.personality_sides
random.shuffle(personality_sides)
prompt_personality += f",{personality_sides[0]}"
identity_detail = individuality.identity.identity_detail
random.shuffle(identity_detail)
prompt_personality += f",{identity_detail[0]}"
# 调取记忆
related_memory = await HippocampusManager.get_instance().get_memory_from_text(
text=chat_observe_info, max_memory_num=2, max_memory_length=2, max_depth=3, fast_retrieval=False
@@ -191,7 +191,7 @@ class SubHeartflow:
else:
related_memory_info = ""
related_info,grouped_results = await self.get_prompt_info(chat_observe_info + message_txt, 0.4)
related_info, grouped_results = await self.get_prompt_info(chat_observe_info + message_txt, 0.4)
# print(related_info)
for _topic, results in grouped_results.items():
for result in results:
@@ -227,25 +227,23 @@ class SubHeartflow:
async def do_thinking_after_reply(self, reply_content, chat_talking_prompt):
# print("麦麦回复之后脑袋转起来了")
# 开始构建prompt
prompt_personality = ""
#person
# person
individuality = Individuality.get_instance()
personality_core = individuality.personality.personality_core
prompt_personality += personality_core
personality_sides = individuality.personality.personality_sides
random.shuffle(personality_sides)
prompt_personality += f",{personality_sides[0]}"
identity_detail = individuality.identity.identity_detail
random.shuffle(identity_detail)
prompt_personality += f",{identity_detail[0]}"
current_thinking_info = self.current_mind
mood_info = self.current_state.mood
@@ -279,22 +277,20 @@ class SubHeartflow:
async def judge_willing(self):
# 开始构建prompt
prompt_personality = ""
#person
# person
individuality = Individuality.get_instance()
personality_core = individuality.personality.personality_core
prompt_personality += personality_core
personality_sides = individuality.personality.personality_sides
random.shuffle(personality_sides)
prompt_personality += f",{personality_sides[0]}"
identity_detail = individuality.identity.identity_detail
random.shuffle(identity_detail)
prompt_personality += f",{identity_detail[0]}"
# print("麦麦闹情绪了1")
current_thinking_info = self.current_mind
mood_info = self.current_state.mood
@@ -320,13 +316,12 @@ class SubHeartflow:
def update_current_mind(self, reponse):
self.past_mind.append(self.current_mind)
self.current_mind = reponse
async def get_prompt_info(self, message: str, threshold: float):
start_time = time.time()
related_info = ""
logger.debug(f"获取知识库内容,元消息:{message[:30]}...,消息长度: {len(message)}")
# 1. 先从LLM获取主题类似于记忆系统的做法
topics = []
# try:
@@ -334,7 +329,7 @@ class SubHeartflow:
# hippocampus = HippocampusManager.get_instance()._hippocampus
# topic_num = min(5, max(1, int(len(message) * 0.1)))
# topics_response = await hippocampus.llm_topic_judge.generate_response(hippocampus.find_topic_llm(message, topic_num))
# # 提取关键词
# topics = re.findall(r"<([^>]+)>", topics_response[0])
# if not topics:
@@ -345,7 +340,7 @@ class SubHeartflow:
# for topic in ",".join(topics).replace(",", ",").replace("、", ",").replace(" ", ",").split(",")
# if topic.strip()
# ]
# logger.info(f"从LLM提取的主题: {', '.join(topics)}")
# except Exception as e:
# logger.error(f"从LLM提取主题失败: {str(e)}")
@@ -353,7 +348,7 @@ class SubHeartflow:
# words = jieba.cut(message)
# topics = [word for word in words if len(word) > 1][:5]
# logger.info(f"使用jieba提取的主题: {', '.join(topics)}")
# 如果无法提取到主题,直接使用整个消息
if not topics:
logger.debug("未能提取到任何主题,使用整个消息进行查询")
@@ -361,26 +356,26 @@ class SubHeartflow:
if not embedding:
logger.error("获取消息嵌入向量失败")
return ""
related_info = self.get_info_from_db(embedding, limit=3, threshold=threshold)
logger.info(f"知识库检索完成,总耗时: {time.time() - start_time:.3f}")
return related_info, {}
# 2. 对每个主题进行知识库查询
logger.info(f"开始处理{len(topics)}个主题的知识库查询")
# 优化批量获取嵌入向量减少API调用
embeddings = {}
topics_batch = [topic for topic in topics if len(topic) > 0]
if message: # 确保消息非空
topics_batch.append(message)
# 批量获取嵌入向量
embed_start_time = time.time()
for text in topics_batch:
if not text or len(text.strip()) == 0:
continue
try:
embedding = await get_embedding(text, request_type="info_retrieval")
if embedding:
@@ -389,17 +384,17 @@ class SubHeartflow:
logger.warning(f"获取'{text}'的嵌入向量失败")
except Exception as e:
logger.error(f"获取'{text}'的嵌入向量时发生错误: {str(e)}")
logger.info(f"批量获取嵌入向量完成,耗时: {time.time() - embed_start_time:.3f}")
if not embeddings:
logger.error("所有嵌入向量获取失败")
return ""
# 3. 对每个主题进行知识库查询
all_results = []
query_start_time = time.time()
# 首先添加原始消息的查询结果
if message in embeddings:
original_results = self.get_info_from_db(embeddings[message], limit=3, threshold=threshold, return_raw=True)
@@ -408,12 +403,12 @@ class SubHeartflow:
result["topic"] = "原始消息"
all_results.extend(original_results)
logger.info(f"原始消息查询到{len(original_results)}条结果")
# 然后添加每个主题的查询结果
for topic in topics:
if not topic or topic not in embeddings:
continue
try:
topic_results = self.get_info_from_db(embeddings[topic], limit=3, threshold=threshold, return_raw=True)
if topic_results:
@@ -424,9 +419,9 @@ class SubHeartflow:
logger.info(f"主题'{topic}'查询到{len(topic_results)}条结果")
except Exception as e:
logger.error(f"查询主题'{topic}'时发生错误: {str(e)}")
logger.info(f"知识库查询完成,耗时: {time.time() - query_start_time:.3f}秒,共获取{len(all_results)}条结果")
# 4. 去重和过滤
process_start_time = time.time()
unique_contents = set()
@@ -436,14 +431,16 @@ class SubHeartflow:
if content not in unique_contents:
unique_contents.add(content)
filtered_results.append(result)
# 5. 按相似度排序
filtered_results.sort(key=lambda x: x["similarity"], reverse=True)
# 6. 限制总数量最多10条
filtered_results = filtered_results[:10]
logger.info(f"结果处理完成,耗时: {time.time() - process_start_time:.3f}秒,过滤后剩余{len(filtered_results)}条结果")
logger.info(
f"结果处理完成,耗时: {time.time() - process_start_time:.3f}秒,过滤后剩余{len(filtered_results)}条结果"
)
# 7. 格式化输出
if filtered_results:
format_start_time = time.time()
@@ -453,7 +450,7 @@ class SubHeartflow:
if topic not in grouped_results:
grouped_results[topic] = []
grouped_results[topic].append(result)
# 按主题组织输出
for topic, results in grouped_results.items():
related_info += f"【主题: {topic}】\n"
@@ -464,13 +461,15 @@ class SubHeartflow:
# related_info += f"{i}. [{similarity:.2f}] {content}\n"
related_info += f"{content}\n"
related_info += "\n"
logger.info(f"格式化输出完成,耗时: {time.time() - format_start_time:.3f}")
logger.info(f"知识库检索总耗时: {time.time() - start_time:.3f}")
return related_info,grouped_results
def get_info_from_db(self, query_embedding: list, limit: int = 1, threshold: float = 0.5, return_raw: bool = False) -> Union[str, list]:
logger.info(f"格式化输出完成,耗时: {time.time() - format_start_time:.3f}")
logger.info(f"知识库检索总耗时: {time.time() - start_time:.3f}")
return related_info, grouped_results
def get_info_from_db(
self, query_embedding: list, limit: int = 1, threshold: float = 0.5, return_raw: bool = False
) -> Union[str, list]:
if not query_embedding:
return "" if not return_raw else []
# 使用余弦相似度计算

View File

@@ -2,27 +2,36 @@ from dataclasses import dataclass
from typing import List
import random
@dataclass
class Identity:
"""身份特征类"""
identity_detail: List[str] # 身份细节描述
height: int # 身高(厘米)
weight: int # 体重(千克)
age: int # 年龄
gender: str # 性别
appearance: str # 外貌特征
_instance = None
def __new__(cls, *args, **kwargs):
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance
def __init__(self, identity_detail: List[str] = None, height: int = 0, weight: int = 0,
age: int = 0, gender: str = "", appearance: str = ""):
def __init__(
self,
identity_detail: List[str] = None,
height: int = 0,
weight: int = 0,
age: int = 0,
gender: str = "",
appearance: str = "",
):
"""初始化身份特征
Args:
identity_detail: 身份细节描述列表
height: 身高(厘米)
@@ -39,23 +48,24 @@ class Identity:
self.age = age
self.gender = gender
self.appearance = appearance
@classmethod
def get_instance(cls) -> 'Identity':
def get_instance(cls) -> "Identity":
"""获取Identity单例实例
Returns:
Identity: 单例实例
"""
if cls._instance is None:
cls._instance = cls()
return cls._instance
@classmethod
def initialize(cls, identity_detail: List[str], height: int, weight: int,
age: int, gender: str, appearance: str) -> 'Identity':
def initialize(
cls, identity_detail: List[str], height: int, weight: int, age: int, gender: str, appearance: str
) -> "Identity":
"""初始化身份特征
Args:
identity_detail: 身份细节描述列表
height: 身高(厘米)
@@ -63,7 +73,7 @@ class Identity:
age: 年龄
gender: 性别
appearance: 外貌特征
Returns:
Identity: 初始化后的身份特征实例
"""
@@ -75,8 +85,8 @@ class Identity:
instance.gender = gender
instance.appearance = appearance
return instance
def get_prompt(self,x_person,level):
def get_prompt(self, x_person, level):
"""
获取身份特征的prompt
"""
@@ -86,7 +96,7 @@ class Identity:
prompt_identity = ""
else:
prompt_identity = ""
if level == 1:
identity_detail = self.identity_detail
random.shuffle(identity_detail)
@@ -96,7 +106,7 @@ class Identity:
prompt_identity += f",{detail}"
prompt_identity += ""
return prompt_identity
def to_dict(self) -> dict:
"""将身份特征转换为字典格式"""
return {
@@ -105,13 +115,13 @@ class Identity:
"weight": self.weight,
"age": self.age,
"gender": self.gender,
"appearance": self.appearance
"appearance": self.appearance,
}
@classmethod
def from_dict(cls, data: dict) -> 'Identity':
def from_dict(cls, data: dict) -> "Identity":
"""从字典创建身份特征实例"""
instance = cls.get_instance()
for key, value in data.items():
setattr(instance, key, value)
return instance
return instance

View File

@@ -2,35 +2,46 @@ from typing import Optional
from .personality import Personality
from .identity import Identity
class Individuality:
"""个体特征管理类"""
_instance = None
def __new__(cls, *args, **kwargs):
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance
def __init__(self):
self.personality: Optional[Personality] = None
self.identity: Optional[Identity] = None
@classmethod
def get_instance(cls) -> 'Individuality':
def get_instance(cls) -> "Individuality":
"""获取Individuality单例实例
Returns:
Individuality: 单例实例
"""
if cls._instance is None:
cls._instance = cls()
return cls._instance
def initialize(self, bot_nickname: str, personality_core: str, personality_sides: list,
identity_detail: list, height: int, weight: int, age: int,
gender: str, appearance: str) -> None:
def initialize(
self,
bot_nickname: str,
personality_core: str,
personality_sides: list,
identity_detail: list,
height: int,
weight: int,
age: int,
gender: str,
appearance: str,
) -> None:
"""初始化个体特征
Args:
bot_nickname: 机器人昵称
personality_core: 人格核心特点
@@ -44,50 +55,43 @@ class Individuality:
"""
# 初始化人格
self.personality = Personality.initialize(
bot_nickname=bot_nickname,
personality_core=personality_core,
personality_sides=personality_sides
bot_nickname=bot_nickname, personality_core=personality_core, personality_sides=personality_sides
)
# 初始化身份
self.identity = Identity.initialize(
identity_detail=identity_detail,
height=height,
weight=weight,
age=age,
gender=gender,
appearance=appearance
identity_detail=identity_detail, height=height, weight=weight, age=age, gender=gender, appearance=appearance
)
def to_dict(self) -> dict:
"""将个体特征转换为字典格式"""
return {
"personality": self.personality.to_dict() if self.personality else None,
"identity": self.identity.to_dict() if self.identity else None
"identity": self.identity.to_dict() if self.identity else None,
}
@classmethod
def from_dict(cls, data: dict) -> 'Individuality':
def from_dict(cls, data: dict) -> "Individuality":
"""从字典创建个体特征实例"""
instance = cls.get_instance()
if data.get("personality"):
instance.personality = Personality.from_dict(data["personality"])
if data.get("identity"):
instance.identity = Identity.from_dict(data["identity"])
return instance
def get_prompt(self,type,x_person,level):
return instance
def get_prompt(self, type, x_person, level):
"""
获取个体特征的prompt
"""
if type == "personality":
return self.personality.get_prompt(x_person,level)
return self.personality.get_prompt(x_person, level)
elif type == "identity":
return self.identity.get_prompt(x_person,level)
return self.identity.get_prompt(x_person, level)
else:
return ""
def get_traits(self,factor):
def get_traits(self, factor):
"""
获取个体特征的特质
"""
@@ -101,5 +105,3 @@ class Individuality:
return self.personality.agreeableness
elif factor == "neuroticism":
return self.personality.neuroticism

View File

@@ -17,9 +17,9 @@ with open(config_path, "r", encoding="utf-8") as f:
config = toml.load(f)
# 现在可以导入src模块
from src.individuality.scene import get_scene_by_factor, PERSONALITY_SCENES #noqa E402
from src.individuality.questionnaire import FACTOR_DESCRIPTIONS #noqa E402
from src.individuality.offline_llm import LLM_request_off #noqa E402
from src.individuality.scene import get_scene_by_factor, PERSONALITY_SCENES # noqa E402
from src.individuality.questionnaire import FACTOR_DESCRIPTIONS # noqa E402
from src.individuality.offline_llm import LLM_request_off # noqa E402
# 加载环境变量
env_path = os.path.join(root_path, ".env")
@@ -32,13 +32,12 @@ else:
def adapt_scene(scene: str) -> str:
personality_core = config['personality']['personality_core']
personality_sides = config['personality']['personality_sides']
personality_core = config["personality"]["personality_core"]
personality_sides = config["personality"]["personality_sides"]
personality_side = random.choice(personality_sides)
identity_details = config['identity']['identity_detail']
identity_details = config["identity"]["identity_detail"]
identity_detail = random.choice(identity_details)
"""
根据config中的属性改编场景使其更适合当前角色
@@ -51,10 +50,10 @@ def adapt_scene(scene: str) -> str:
try:
prompt = f"""
这是一个参与人格测评的角色形象:
- 昵称: {config['bot']['nickname']}
- 性别: {config['identity']['gender']}
- 年龄: {config['identity']['age']}
- 外貌: {config['identity']['appearance']}
- 昵称: {config["bot"]["nickname"]}
- 性别: {config["identity"]["gender"]}
- 年龄: {config["identity"]["age"]}
- 外貌: {config["identity"]["appearance"]}
- 性格核心: {personality_core}
- 性格侧面: {personality_side}
- 身份细节: {identity_detail}
@@ -62,18 +61,18 @@ def adapt_scene(scene: str) -> str:
请根据上述形象,改编以下场景,在测评中,用户将根据该场景给出上述角色形象的反应:
{scene}
保持场景的本质不变,但最好贴近生活且具体,并且让它更适合这个角色。
改编后的场景应该自然、连贯,并考虑角色的年龄、身份和性格特点。只返回改编后的场景描述,不要包含其他说明。注意{config['bot']['nickname']}是面对这个场景的人,而不是场景的其他人。场景中不会有其描述,
改编后的场景应该自然、连贯,并考虑角色的年龄、身份和性格特点。只返回改编后的场景描述,不要包含其他说明。注意{config["bot"]["nickname"]}是面对这个场景的人,而不是场景的其他人。场景中不会有其描述,
现在,请你给出改编后的场景描述
"""
llm = LLM_request_off(model_name=config['model']['llm_normal']['name'])
llm = LLM_request_off(model_name=config["model"]["llm_normal"]["name"])
adapted_scene, _ = llm.generate_response(prompt)
# 检查返回的场景是否为空或错误信息
if not adapted_scene or "错误" in adapted_scene or "失败" in adapted_scene:
print("场景改编失败,将使用原始场景")
return scene
return adapted_scene
except Exception as e:
print(f"场景改编过程出错:{str(e)},将使用原始场景")
@@ -169,7 +168,7 @@ class PersonalityEvaluator_direct:
except Exception as e:
print(f"评估过程出错:{str(e)}")
return {dim: 3.5 for dim in dimensions}
def run_evaluation(self):
"""
运行整个评估过程
@@ -185,18 +184,23 @@ class PersonalityEvaluator_direct:
print(f"- 身份细节:{config['identity']['identity_detail']}")
print("\n准备好了吗?按回车键开始...")
input()
total_scenarios = len(self.scenarios)
progress_bar = tqdm(total=total_scenarios, desc="场景进度", ncols=100, bar_format='{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}]')
progress_bar = tqdm(
total=total_scenarios,
desc="场景进度",
ncols=100,
bar_format="{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}]",
)
for _i, scenario_data in enumerate(self.scenarios, 1):
# print(f"\n{'-' * 20} 场景 {i}/{total_scenarios} - {scenario_data['场景编号']} {'-' * 20}")
# 改编场景,使其更适合当前角色
print(f"{config['bot']['nickname']}祈祷中...")
adapted_scene = adapt_scene(scenario_data["场景"])
scenario_data["改编场景"] = adapted_scene
print(adapted_scene)
print(f"\n请描述{config['bot']['nickname']}在这种情况下会如何反应:")
response = input().strip()
@@ -220,13 +224,13 @@ class PersonalityEvaluator_direct:
# 更新进度条
progress_bar.update(1)
# if i < total_scenarios:
# print("\n按回车键继续下一个场景...")
# input()
# print("\n按回车键继续下一个场景...")
# input()
progress_bar.close()
# 计算平均分
for dimension in self.final_scores:
if self.dimension_counts[dimension] > 0:
@@ -241,26 +245,26 @@ class PersonalityEvaluator_direct:
# 返回评估结果
return self.get_result()
def get_result(self):
"""
获取评估结果
"""
return {
"final_scores": self.final_scores,
"dimension_counts": self.dimension_counts,
"final_scores": self.final_scores,
"dimension_counts": self.dimension_counts,
"scenarios": self.scenarios,
"bot_info": {
"nickname": config['bot']['nickname'],
"gender": config['identity']['gender'],
"age": config['identity']['age'],
"height": config['identity']['height'],
"weight": config['identity']['weight'],
"appearance": config['identity']['appearance'],
"personality_core": config['personality']['personality_core'],
"personality_sides": config['personality']['personality_sides'],
"identity_detail": config['identity']['identity_detail']
}
"nickname": config["bot"]["nickname"],
"gender": config["identity"]["gender"],
"age": config["identity"]["age"],
"height": config["identity"]["height"],
"weight": config["identity"]["weight"],
"appearance": config["identity"]["appearance"],
"personality_core": config["personality"]["personality_core"],
"personality_sides": config["personality"]["personality_sides"],
"identity_detail": config["identity"]["identity_detail"],
},
}
@@ -275,28 +279,28 @@ def main():
"extraversion": round(result["final_scores"]["外向性"] / 6, 1),
"agreeableness": round(result["final_scores"]["宜人性"] / 6, 1),
"neuroticism": round(result["final_scores"]["神经质"] / 6, 1),
"bot_nickname": config['bot']['nickname']
"bot_nickname": config["bot"]["nickname"],
}
# 确保目录存在
save_dir = os.path.join(root_path, "data", "personality")
os.makedirs(save_dir, exist_ok=True)
# 创建文件名,替换可能的非法字符
bot_name = config['bot']['nickname']
bot_name = config["bot"]["nickname"]
# 替换Windows文件名中不允许的字符
for char in ['\\', '/', ':', '*', '?', '"', '<', '>', '|']:
bot_name = bot_name.replace(char, '_')
for char in ["\\", "/", ":", "*", "?", '"', "<", ">", "|"]:
bot_name = bot_name.replace(char, "_")
file_name = f"{bot_name}_personality.per"
save_path = os.path.join(save_dir, file_name)
# 保存简化的结果
with open(save_path, "w", encoding="utf-8") as f:
json.dump(simplified_result, f, ensure_ascii=False, indent=4)
print(f"\n结果已保存到 {save_path}")
# 同时保存完整结果到results目录
os.makedirs("results", exist_ok=True)
with open("results/personality_result.json", "w", encoding="utf-8") as f:

View File

@@ -4,9 +4,11 @@ import json
from pathlib import Path
import random
@dataclass
class Personality:
"""人格特质类"""
openness: float # 开放性
conscientiousness: float # 尽责性
extraversion: float # 外向性
@@ -15,45 +17,45 @@ class Personality:
bot_nickname: str # 机器人昵称
personality_core: str # 人格核心特点
personality_sides: List[str] # 人格侧面描述
_instance = None
def __new__(cls, *args, **kwargs):
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance
def __init__(self, personality_core: str = "", personality_sides: List[str] = None):
if personality_sides is None:
personality_sides = []
self.personality_core = personality_core
self.personality_sides = personality_sides
@classmethod
def get_instance(cls) -> 'Personality':
def get_instance(cls) -> "Personality":
"""获取Personality单例实例
Returns:
Personality: 单例实例
"""
if cls._instance is None:
cls._instance = cls()
return cls._instance
def _init_big_five_personality(self):
"""初始化大五人格特质"""
# 构建文件路径
personality_file = Path("data/personality") / f"{self.bot_nickname}_personality.per"
# 如果文件存在,读取文件
if personality_file.exists():
with open(personality_file, 'r', encoding='utf-8') as f:
with open(personality_file, "r", encoding="utf-8") as f:
personality_data = json.load(f)
self.openness = personality_data.get('openness', 0.5)
self.conscientiousness = personality_data.get('conscientiousness', 0.5)
self.extraversion = personality_data.get('extraversion', 0.5)
self.agreeableness = personality_data.get('agreeableness', 0.5)
self.neuroticism = personality_data.get('neuroticism', 0.5)
self.openness = personality_data.get("openness", 0.5)
self.conscientiousness = personality_data.get("conscientiousness", 0.5)
self.extraversion = personality_data.get("extraversion", 0.5)
self.agreeableness = personality_data.get("agreeableness", 0.5)
self.neuroticism = personality_data.get("neuroticism", 0.5)
else:
# 如果文件不存在根据personality_core和personality_core来设置大五人格特质
if "活泼" in self.personality_core or "开朗" in self.personality_sides:
@@ -62,31 +64,31 @@ class Personality:
else:
self.extraversion = 0.3
self.neuroticism = 0.5
if "认真" in self.personality_core or "负责" in self.personality_sides:
self.conscientiousness = 0.9
else:
self.conscientiousness = 0.5
if "友善" in self.personality_core or "温柔" in self.personality_sides:
self.agreeableness = 0.9
else:
self.agreeableness = 0.5
if "创新" in self.personality_core or "开放" in self.personality_sides:
self.openness = 0.8
else:
self.openness = 0.5
@classmethod
def initialize(cls, bot_nickname: str, personality_core: str, personality_sides: List[str]) -> 'Personality':
def initialize(cls, bot_nickname: str, personality_core: str, personality_sides: List[str]) -> "Personality":
"""初始化人格特质
Args:
bot_nickname: 机器人昵称
personality_core: 人格核心特点
personality_sides: 人格侧面描述
Returns:
Personality: 初始化后的人格特质实例
"""
@@ -96,7 +98,7 @@ class Personality:
instance.personality_sides = personality_sides
instance._init_big_five_personality()
return instance
def to_dict(self) -> Dict:
"""将人格特质转换为字典格式"""
return {
@@ -107,18 +109,18 @@ class Personality:
"neuroticism": self.neuroticism,
"bot_nickname": self.bot_nickname,
"personality_core": self.personality_core,
"personality_sides": self.personality_sides
"personality_sides": self.personality_sides,
}
@classmethod
def from_dict(cls, data: Dict) -> 'Personality':
def from_dict(cls, data: Dict) -> "Personality":
"""从字典创建人格特质实例"""
instance = cls.get_instance()
for key, value in data.items():
setattr(instance, key, value)
return instance
def get_prompt(self,x_person,level):
return instance
def get_prompt(self, x_person, level):
# 开始构建prompt
if x_person == 2:
prompt_personality = ""
@@ -126,10 +128,10 @@ class Personality:
prompt_personality = ""
else:
prompt_personality = ""
#person
# person
prompt_personality += self.personality_core
if level == 2:
personality_sides = self.personality_sides
random.shuffle(personality_sides)
@@ -140,5 +142,5 @@ class Personality:
prompt_personality += f",{side}"
prompt_personality += ""
return prompt_personality

View File

@@ -2,6 +2,7 @@ import json
from typing import Dict
import os
def load_scenes() -> Dict:
"""
从JSON文件加载场景数据
@@ -10,13 +11,15 @@ def load_scenes() -> Dict:
Dict: 包含所有场景的字典
"""
current_dir = os.path.dirname(os.path.abspath(__file__))
json_path = os.path.join(current_dir, 'template_scene.json')
with open(json_path, 'r', encoding='utf-8') as f:
json_path = os.path.join(current_dir, "template_scene.json")
with open(json_path, "r", encoding="utf-8") as f:
return json.load(f)
PERSONALITY_SCENES = load_scenes()
def get_scene_by_factor(factor: str) -> Dict:
"""
根据人格因子获取对应的情景测试

View File

@@ -100,7 +100,7 @@ class MainSystem:
weight=global_config.weight,
age=global_config.age,
gender=global_config.gender,
appearance=global_config.appearance
appearance=global_config.appearance,
)
logger.success("个体特征初始化成功")
@@ -135,7 +135,6 @@ class MainSystem:
await asyncio.sleep(global_config.build_memory_interval)
logger.info("正在进行记忆构建")
await HippocampusManager.get_instance().build_memory()
async def forget_memory_task(self):
"""记忆遗忘任务"""
@@ -144,7 +143,6 @@ class MainSystem:
print("\033[1;32m[记忆遗忘]\033[0m 开始遗忘记忆...")
await HippocampusManager.get_instance().forget_memory(percentage=global_config.memory_forget_percentage)
print("\033[1;32m[记忆遗忘]\033[0m 记忆遗忘完成")
async def print_mood_task(self):
"""打印情绪状态"""

View File

@@ -1,6 +1,6 @@
import time
import asyncio
from typing import Optional, Dict, Any, List, Tuple
from typing import Optional, Dict, Any, List, Tuple
from src.common.logger import get_module_logger
from src.common.database import db
from ..message.message_base import UserInfo
@@ -8,99 +8,97 @@ from ..config.config import global_config
logger = get_module_logger("chat_observer")
class ChatObserver:
"""聊天状态观察器"""
# 类级别的实例管理
_instances: Dict[str, 'ChatObserver'] = {}
_instances: Dict[str, "ChatObserver"] = {}
@classmethod
def get_instance(cls, stream_id: str) -> 'ChatObserver':
def get_instance(cls, stream_id: str) -> "ChatObserver":
"""获取或创建观察器实例
Args:
stream_id: 聊天流ID
Returns:
ChatObserver: 观察器实例
"""
if stream_id not in cls._instances:
cls._instances[stream_id] = cls(stream_id)
return cls._instances[stream_id]
def __init__(self, stream_id: str):
"""初始化观察器
Args:
stream_id: 聊天流ID
"""
if stream_id in self._instances:
raise RuntimeError(f"ChatObserver for {stream_id} already exists. Use get_instance() instead.")
self.stream_id = stream_id
self.last_user_speak_time: Optional[float] = None # 对方上次发言时间
self.last_bot_speak_time: Optional[float] = None # 机器人上次发言时间
self.last_check_time: float = time.time() # 上次查看聊天记录时间
self.last_message_read: Optional[str] = None # 最后读取的消息ID
self.last_message_time: Optional[float] = None # 最后一条消息的时间戳
self.waiting_start_time: Optional[float] = None # 等待开始时间
self.last_bot_speak_time: Optional[float] = None # 机器人上次发言时间
self.last_check_time: float = time.time() # 上次查看聊天记录时间
self.last_message_read: Optional[str] = None # 最后读取的消息ID
self.last_message_time: Optional[float] = None # 最后一条消息的时间戳
self.waiting_start_time: Optional[float] = None # 等待开始时间
# 消息历史记录
self.message_history: List[Dict[str, Any]] = [] # 所有消息历史
self.last_message_id: Optional[str] = None # 最后一条消息的ID
self.message_count: int = 0 # 消息计数
self.last_message_id: Optional[str] = None # 最后一条消息的ID
self.message_count: int = 0 # 消息计数
# 运行状态
self._running: bool = False
self._task: Optional[asyncio.Task] = None
self._update_event = asyncio.Event() # 触发更新的事件
self._update_complete = asyncio.Event() # 更新完成的事件
def check(self) -> bool:
"""检查距离上一次观察之后是否有了新消息
Returns:
bool: 是否有新消息
"""
logger.debug(f"检查距离上一次观察之后是否有了新消息: {self.last_check_time}")
query = {
"chat_id": self.stream_id,
"time": {"$gt": self.last_check_time}
}
query = {"chat_id": self.stream_id, "time": {"$gt": self.last_check_time}}
# 只需要查询是否存在,不需要获取具体消息
new_message_exists = db.messages.find_one(query) is not None
if new_message_exists:
logger.debug("发现新消息")
self.last_check_time = time.time()
return new_message_exists
def get_new_message(self) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
"""获取上一次观察的时间点后的新消息,插入到历史记录中,并返回新消息和历史记录两个对象"""
messages = self.get_message_history(self.last_check_time)
for message in messages:
self._add_message_to_history(message)
return messages, self.message_history
def new_message_after(self, time_point: float) -> bool:
"""判断是否在指定时间点后有新消息
Args:
time_point: 时间戳
Returns:
bool: 是否有新消息
"""
logger.debug(f"判断是否在指定时间点后有新消息: {self.last_message_time} > {time_point}")
return self.last_message_time is None or self.last_message_time > time_point
def _add_message_to_history(self, message: Dict[str, Any]):
"""添加消息到历史记录
Args:
message: 消息数据
"""
@@ -108,54 +106,53 @@ class ChatObserver:
self.last_message_id = message["message_id"]
self.last_message_time = message["time"] # 更新最后消息时间
self.message_count += 1
# 更新说话时间
user_info = UserInfo.from_dict(message.get("user_info", {}))
if user_info.user_id == global_config.BOT_QQ:
self.last_bot_speak_time = message["time"]
else:
self.last_user_speak_time = message["time"]
def get_message_history(
self,
start_time: Optional[float] = None,
end_time: Optional[float] = None,
limit: Optional[int] = None,
user_id: Optional[str] = None
user_id: Optional[str] = None,
) -> List[Dict[str, Any]]:
"""获取消息历史
Args:
start_time: 开始时间戳
end_time: 结束时间戳
limit: 限制返回消息数量
user_id: 指定用户ID
Returns:
List[Dict[str, Any]]: 消息列表
"""
filtered_messages = self.message_history
if start_time is not None:
filtered_messages = [m for m in filtered_messages if m["time"] >= start_time]
if end_time is not None:
filtered_messages = [m for m in filtered_messages if m["time"] <= end_time]
if user_id is not None:
filtered_messages = [
m for m in filtered_messages
if UserInfo.from_dict(m.get("user_info", {})).user_id == user_id
m for m in filtered_messages if UserInfo.from_dict(m.get("user_info", {})).user_id == user_id
]
if limit is not None:
filtered_messages = filtered_messages[-limit:]
return filtered_messages
async def _fetch_new_messages(self) -> List[Dict[str, Any]]:
"""获取新消息
Returns:
List[Dict[str, Any]]: 新消息列表
"""
@@ -165,42 +162,37 @@ class ChatObserver:
last_message = db.messages.find_one({"message_id": self.last_message_read})
if last_message:
query["time"] = {"$gt": last_message["time"]}
new_messages = list(
db.messages.find(query).sort("time", 1)
)
new_messages = list(db.messages.find(query).sort("time", 1))
if new_messages:
self.last_message_read = new_messages[-1]["message_id"]
return new_messages
async def _fetch_new_messages_before(self, time_point: float) -> List[Dict[str, Any]]:
"""获取指定时间点之前的消息
Args:
time_point: 时间戳
Returns:
List[Dict[str, Any]]: 最多5条消息
"""
query = {
"chat_id": self.stream_id,
"time": {"$lt": time_point}
}
query = {"chat_id": self.stream_id, "time": {"$lt": time_point}}
new_messages = list(
db.messages.find(query).sort("time", -1).limit(5) # 倒序获取5条
)
# 将消息按时间正序排列
new_messages.reverse()
if new_messages:
self.last_message_read = new_messages[-1]["message_id"]
return new_messages
async def _update_loop(self):
"""更新循环"""
try:
@@ -210,7 +202,7 @@ class ChatObserver:
self._add_message_to_history(message)
except Exception as e:
logger.error(f"缓冲消息出错: {e}")
while self._running:
try:
# 等待事件或超时1秒
@@ -218,35 +210,35 @@ class ChatObserver:
await asyncio.wait_for(self._update_event.wait(), timeout=1)
except asyncio.TimeoutError:
pass # 超时后也执行一次检查
self._update_event.clear() # 重置触发事件
self._update_complete.clear() # 重置完成事件
# 获取新消息
new_messages = await self._fetch_new_messages()
if new_messages:
# 处理新消息
for message in new_messages:
self._add_message_to_history(message)
# 设置完成事件
self._update_complete.set()
except Exception as e:
logger.error(f"更新循环出错: {e}")
self._update_complete.set() # 即使出错也要设置完成事件
def trigger_update(self):
"""触发一次立即更新"""
self._update_event.set()
async def wait_for_update(self, timeout: float = 5.0) -> bool:
"""等待更新完成
Args:
timeout: 超时时间(秒)
Returns:
bool: 是否成功完成更新False表示超时
"""
@@ -256,16 +248,16 @@ class ChatObserver:
except asyncio.TimeoutError:
logger.warning(f"等待更新完成超时({timeout}秒)")
return False
def start(self):
"""启动观察器"""
if self._running:
return
self._running = True
self._task = asyncio.create_task(self._update_loop())
logger.info(f"ChatObserver for {self.stream_id} started")
def stop(self):
"""停止观察器"""
self._running = False
@@ -274,15 +266,15 @@ class ChatObserver:
if self._task:
self._task.cancel()
logger.info(f"ChatObserver for {self.stream_id} stopped")
async def process_chat_history(self, messages: list):
"""处理聊天历史
Args:
messages: 消息列表
"""
self.update_check_time()
for msg in messages:
try:
user_info = UserInfo.from_dict(msg.get("user_info", {}))
@@ -292,31 +284,31 @@ class ChatObserver:
self.update_user_speak_time(msg["time"])
except Exception as e:
logger.warning(f"处理消息时间时出错: {e}")
continue
continue
def update_check_time(self):
"""更新查看时间"""
self.last_check_time = time.time()
def update_bot_speak_time(self, speak_time: Optional[float] = None):
"""更新机器人说话时间"""
self.last_bot_speak_time = speak_time or time.time()
def update_user_speak_time(self, speak_time: Optional[float] = None):
"""更新用户说话时间"""
self.last_user_speak_time = speak_time or time.time()
def get_time_info(self) -> str:
"""获取时间信息文本"""
current_time = time.time()
time_info = ""
if self.last_bot_speak_time:
bot_speak_ago = current_time - self.last_bot_speak_time
time_info += f"\n距离你上次发言已经过去了{int(bot_speak_ago)}"
if self.last_user_speak_time:
user_speak_ago = current_time - self.last_user_speak_time
time_info += f"\n距离对方上次发言已经过去了{int(user_speak_ago)}"
return time_info

View File

@@ -1,5 +1,5 @@
#Programmable Friendly Conversationalist
#Prefrontal cortex
# Programmable Friendly Conversationalist
# Prefrontal cortex
import datetime
import asyncio
from typing import List, Optional, Dict, Any, Tuple, Literal
@@ -26,6 +26,7 @@ logger = get_module_logger("pfc")
class ConversationState(Enum):
"""对话状态"""
INIT = "初始化"
RETHINKING = "重新思考"
ANALYZING = "分析历史"
@@ -44,40 +45,37 @@ ActionType = Literal["direct_reply", "fetch_knowledge", "wait"]
class ActionPlanner:
"""行动规划器"""
def __init__(self, stream_id: str):
self.llm = LLM_request(
model=global_config.llm_normal,
temperature=0.7,
max_tokens=1000,
request_type="action_planning"
model=global_config.llm_normal, temperature=0.7, max_tokens=1000, request_type="action_planning"
)
self.personality_info = Individuality.get_instance().get_prompt(type = "personality", x_person = 2, level = 2)
self.personality_info = Individuality.get_instance().get_prompt(type="personality", x_person=2, level=2)
self.name = global_config.BOT_NICKNAME
self.chat_observer = ChatObserver.get_instance(stream_id)
async def plan(
self,
goal: str,
method: str,
self,
goal: str,
method: str,
reasoning: str,
action_history: List[Dict[str, str]] = None,
chat_observer: Optional[ChatObserver] = None, # 添加chat_observer参数
) -> Tuple[str, str]:
"""规划下一步行动
Args:
goal: 对话目标
reasoning: 目标原因
action_history: 行动历史记录
Returns:
Tuple[str, str]: (行动类型, 行动原因)
"""
# 构建提示词
# 获取最近20条消息
self.chat_observer.waiting_start_time = time.time()
messages = self.chat_observer.get_message_history(limit=20)
chat_history_text = ""
for msg in messages:
@@ -87,13 +85,13 @@ class ActionPlanner:
if sender == self.name:
sender = "你说"
chat_history_text += f"{time_str},{sender}:{msg.get('processed_plain_text', '')}\n"
personality_text = f"你的名字是{self.name}{self.personality_info}"
# 构建action历史文本
action_history_text = ""
if action_history:
if action_history[-1]['action'] == "direct_reply":
if action_history[-1]["action"] == "direct_reply":
action_history_text = "你刚刚发言回复了对方"
# 获取时间信息
@@ -127,29 +125,34 @@ judge_conversation: 判断对话是否结束,当发现对话目标已经达到
try:
content, _ = await self.llm.generate_response_async(prompt)
logger.debug(f"LLM原始返回内容: {content}")
# 使用简化函数提取JSON内容
success, result = get_items_from_json(
content,
"action", "reason",
default_values={"action": "direct_reply", "reason": "默认原因"}
content, "action", "reason", default_values={"action": "direct_reply", "reason": "默认原因"}
)
if not success:
return "direct_reply", "JSON解析失败选择直接回复"
action = result["action"]
reason = result["reason"]
# 验证action类型
if action not in ["direct_reply", "fetch_knowledge", "wait", "listening", "rethink_goal", "judge_conversation"]:
if action not in [
"direct_reply",
"fetch_knowledge",
"wait",
"listening",
"rethink_goal",
"judge_conversation",
]:
logger.warning(f"未知的行动类型: {action}默认使用listening")
action = "listening"
logger.info(f"规划的行动: {action}")
logger.info(f"行动原因: {reason}")
return action, reason
except Exception as e:
logger.error(f"规划行动时出错: {str(e)}")
return "direct_reply", "发生错误,选择直接回复"
@@ -157,20 +160,17 @@ judge_conversation: 判断对话是否结束,当发现对话目标已经达到
class GoalAnalyzer:
"""对话目标分析器"""
def __init__(self, stream_id: str):
self.llm = LLM_request(
model=global_config.llm_normal,
temperature=0.7,
max_tokens=1000,
request_type="conversation_goal"
model=global_config.llm_normal, temperature=0.7, max_tokens=1000, request_type="conversation_goal"
)
self.personality_info = Individuality.get_instance().get_prompt(type = "personality", x_person = 2, level = 2)
self.personality_info = Individuality.get_instance().get_prompt(type="personality", x_person=2, level=2)
self.name = global_config.BOT_NICKNAME
self.nick_name = global_config.BOT_ALIAS_NAMES
self.chat_observer = ChatObserver.get_instance(stream_id)
# 多目标存储结构
self.goals = [] # 存储多个目标
self.max_goals = 3 # 同时保持的最大目标数量
@@ -178,10 +178,10 @@ class GoalAnalyzer:
async def analyze_goal(self) -> Tuple[str, str, str]:
"""分析对话历史并设定目标
Args:
chat_history: 聊天历史记录列表
Returns:
Tuple[str, str, str]: (目标, 方法, 原因)
"""
@@ -198,16 +198,16 @@ class GoalAnalyzer:
if sender == self.name:
sender = "你说"
chat_history_text += f"{time_str},{sender}:{msg.get('processed_plain_text', '')}\n"
personality_text = f"你的名字是{self.name}{self.personality_info}"
# 构建当前已有目标的文本
existing_goals_text = ""
if self.goals:
existing_goals_text = "当前已有的对话目标:\n"
for i, (goal, _, reason) in enumerate(self.goals):
existing_goals_text += f"{i+1}. 目标: {goal}, 原因: {reason}\n"
existing_goals_text += f"{i + 1}. 目标: {goal}, 原因: {reason}\n"
prompt = f"""{personality_text}。现在你在参与一场QQ聊天请分析以下聊天记录并根据你的性格特征确定多个明确的对话目标。
这些目标应该反映出对话的不同方面和意图。
@@ -235,46 +235,44 @@ class GoalAnalyzer:
logger.debug(f"发送到LLM的提示词: {prompt}")
content, _ = await self.llm.generate_response_async(prompt)
logger.debug(f"LLM原始返回内容: {content}")
# 使用简化函数提取JSON内容
success, result = get_items_from_json(
content,
"goal", "reasoning",
required_types={"goal": str, "reasoning": str}
content, "goal", "reasoning", required_types={"goal": str, "reasoning": str}
)
if not success:
logger.error(f"无法解析JSON重试第{retry + 1}")
continue
goal = result["goal"]
reasoning = result["reasoning"]
# 使用默认的方法
method = "以友好的态度回应"
# 更新目标列表
await self._update_goals(goal, method, reasoning)
# 返回当前最主要的目标
if self.goals:
current_goal, current_method, current_reasoning = self.goals[0]
return current_goal, current_method, current_reasoning
else:
return goal, method, reasoning
except Exception as e:
logger.error(f"分析对话目标时出错: {str(e)},重试第{retry + 1}")
if retry == max_retries - 1:
return "保持友好的对话", "以友好的态度回应", "确保对话顺利进行"
continue
# 所有重试都失败后的默认返回
return "保持友好的对话", "以友好的态度回应", "确保对话顺利进行"
async def _update_goals(self, new_goal: str, method: str, reasoning: str):
"""更新目标列表
Args:
new_goal: 新的目标
method: 实现目标的方法
@@ -288,23 +286,23 @@ class GoalAnalyzer:
# 将此目标移到列表前面(最主要的位置)
self.goals.insert(0, self.goals.pop(i))
return
# 添加新目标到列表前面
self.goals.insert(0, (new_goal, method, reasoning))
# 限制目标数量
if len(self.goals) > self.max_goals:
self.goals.pop() # 移除最老的目标
def _calculate_similarity(self, goal1: str, goal2: str) -> float:
"""简单计算两个目标之间的相似度
这里使用一个简单的实现,实际可以使用更复杂的文本相似度算法
Args:
goal1: 第一个目标
goal2: 第二个目标
Returns:
float: 相似度得分 (0-1)
"""
@@ -314,18 +312,18 @@ class GoalAnalyzer:
overlap = len(words1.intersection(words2))
total = len(words1.union(words2))
return overlap / total if total > 0 else 0
async def get_all_goals(self) -> List[Tuple[str, str, str]]:
"""获取所有当前目标
Returns:
List[Tuple[str, str, str]]: 目标列表,每项为(目标, 方法, 原因)
"""
return self.goals.copy()
async def get_alternative_goals(self) -> List[Tuple[str, str, str]]:
"""获取除了当前主要目标外的其他备选目标
Returns:
List[Tuple[str, str, str]]: 备选目标列表
"""
@@ -343,9 +341,9 @@ class GoalAnalyzer:
if sender == self.name:
sender = "你说"
chat_history_text += f"{time_str},{sender}:{msg.get('processed_plain_text', '')}\n"
personality_text = f"你的名字是{self.name}{self.personality_info}"
prompt = f"""{personality_text}。现在你在参与一场QQ聊天
当前对话目标:{goal}
产生该对话目标的原因:{reasoning}
@@ -368,21 +366,19 @@ class GoalAnalyzer:
try:
content, _ = await self.llm.generate_response_async(prompt)
logger.debug(f"LLM原始返回内容: {content}")
# 使用简化函数提取JSON内容
success, result = get_items_from_json(
content,
"goal_achieved", "stop_conversation", "reason",
required_types={
"goal_achieved": bool,
"stop_conversation": bool,
"reason": str
}
"goal_achieved",
"stop_conversation",
"reason",
required_types={"goal_achieved": bool, "stop_conversation": bool, "reason": str},
)
if not success:
return False, False, "确保对话顺利进行"
# 如果当前目标达成,从目标列表中移除
if result["goal_achieved"] and not result["stop_conversation"]:
for i, (g, _, _) in enumerate(self.goals):
@@ -392,9 +388,9 @@ class GoalAnalyzer:
if self.goals:
result["stop_conversation"] = False
break
return result["goal_achieved"], result["stop_conversation"], result["reason"]
except Exception as e:
logger.error(f"分析对话目标时出错: {str(e)}")
return False, False, "确保对话顺利进行"
@@ -402,14 +398,15 @@ class GoalAnalyzer:
class Waiter:
"""快 速 等 待"""
def __init__(self, stream_id: str):
self.chat_observer = ChatObserver.get_instance(stream_id)
self.personality_info = Individuality.get_instance().get_prompt(type = "personality", x_person = 2, level = 2)
self.personality_info = Individuality.get_instance().get_prompt(type="personality", x_person=2, level=2)
self.name = global_config.BOT_NICKNAME
async def wait(self) -> bool:
"""等待
Returns:
bool: 是否超时True表示超时
"""
@@ -424,39 +421,36 @@ class Waiter:
logger.info("等待结束")
return False
class ReplyGenerator:
"""回复生成器"""
def __init__(self, stream_id: str):
self.llm = LLM_request(
model=global_config.llm_normal,
temperature=0.7,
max_tokens=300,
request_type="reply_generation"
model=global_config.llm_normal, temperature=0.7, max_tokens=300, request_type="reply_generation"
)
self.personality_info = Individuality.get_instance().get_prompt(type = "personality", x_person = 2, level = 2)
self.personality_info = Individuality.get_instance().get_prompt(type="personality", x_person=2, level=2)
self.name = global_config.BOT_NICKNAME
self.chat_observer = ChatObserver.get_instance(stream_id)
self.reply_checker = ReplyChecker(stream_id)
async def generate(
self,
goal: str,
chat_history: List[Message],
knowledge_cache: Dict[str, str],
previous_reply: Optional[str] = None,
retry_count: int = 0
retry_count: int = 0,
) -> str:
"""生成回复
Args:
goal: 对话目标
chat_history: 聊天历史
knowledge_cache: 知识缓存
previous_reply: 上一次生成的回复(如果有)
retry_count: 当前重试次数
Returns:
str: 生成的回复
"""
@@ -465,7 +459,7 @@ class ReplyGenerator:
self.chat_observer.trigger_update() # 触发立即更新
if not await self.chat_observer.wait_for_update():
logger.warning("等待消息更新超时")
messages = self.chat_observer.get_message_history(limit=20)
chat_history_text = ""
for msg in messages:
@@ -475,7 +469,7 @@ class ReplyGenerator:
if sender == self.name:
sender = "你说"
chat_history_text += f"{time_str},{sender}:{msg.get('processed_plain_text', '')}\n"
# 整理知识缓存
knowledge_text = ""
if knowledge_cache:
@@ -486,14 +480,14 @@ class ReplyGenerator:
elif isinstance(knowledge_cache, list):
for item in knowledge_cache:
knowledge_text += f"\n{item}"
# 添加上一次生成的回复信息
previous_reply_text = ""
if previous_reply:
previous_reply_text = f"\n上一次生成的回复(需要改进):\n{previous_reply}"
personality_text = f"你的名字是{self.name}{self.personality_info}"
prompt = f"""{personality_text}。现在你在参与一场QQ聊天请根据以下信息生成回复
当前对话目标:{goal}
@@ -507,7 +501,7 @@ class ReplyGenerator:
2. 体现你的性格特征
3. 自然流畅,像正常聊天一样,简短
4. 适当利用相关知识,但不要生硬引用
{'5. 改进上一次回复中的问题' if previous_reply else ''}
{"5. 改进上一次回复中的问题" if previous_reply else ""}
请注意把握聊天内容,不要回复的太有条理,可以有个性。请分清""和对方说的话,不要把""说的话当做对方说的话,这是你自己说的话。
请你回复的平淡一些,简短一些,说中文,不要刻意突出自身学科背景,尽量不要说你说过的话
@@ -521,34 +515,26 @@ class ReplyGenerator:
logger.info(f"生成的回复: {content}")
is_new = self.chat_observer.check()
logger.debug(f"再看一眼聊天记录,{'' if is_new else '没有'}新消息")
# 如果有新消息,重新生成回复
if is_new:
logger.info("检测到新消息,重新生成回复")
return await self.generate(
goal, chat_history, knowledge_cache,
None, retry_count
)
return await self.generate(goal, chat_history, knowledge_cache, None, retry_count)
return content
except Exception as e:
logger.error(f"生成回复时出错: {e}")
return "抱歉,我现在有点混乱,让我重新思考一下..."
async def check_reply(
self,
reply: str,
goal: str,
retry_count: int = 0
) -> Tuple[bool, str, bool]:
async def check_reply(self, reply: str, goal: str, retry_count: int = 0) -> Tuple[bool, str, bool]:
"""检查回复是否合适
Args:
reply: 生成的回复
goal: 对话目标
retry_count: 当前重试次数
Returns:
Tuple[bool, str, bool]: (是否合适, 原因, 是否需要重新规划)
"""
@@ -557,18 +543,18 @@ class ReplyGenerator:
class Conversation:
# 类级别的实例管理
_instances: Dict[str, 'Conversation'] = {}
_instances: Dict[str, "Conversation"] = {}
_instance_lock = asyncio.Lock() # 类级别的全局锁
_init_events: Dict[str, asyncio.Event] = {} # 初始化完成事件
_initializing: Dict[str, bool] = {} # 标记是否正在初始化
@classmethod
async def get_instance(cls, stream_id: str) -> Optional['Conversation']:
async def get_instance(cls, stream_id: str) -> Optional["Conversation"]:
"""获取或创建对话实例
Args:
stream_id: 聊天流ID
Returns:
Optional[Conversation]: 对话实例如果创建或等待失败则返回None
"""
@@ -586,23 +572,23 @@ class Conversation:
return None
finally:
await cls._instance_lock.acquire()
# 如果实例不存在,创建新实例
if stream_id not in cls._instances:
cls._instances[stream_id] = cls(stream_id)
cls._init_events[stream_id] = asyncio.Event()
cls._initializing[stream_id] = True
logger.info(f"创建新的对话实例: {stream_id}")
return cls._instances[stream_id]
except Exception as e:
logger.error(f"获取对话实例失败: {e}")
return None
@classmethod
async def remove_instance(cls, stream_id: str):
"""删除对话实例
Args:
stream_id: 聊天流ID
"""
@@ -628,16 +614,16 @@ class Conversation:
self.goal_reasoning: Optional[str] = None
self.generated_reply: Optional[str] = None
self.should_continue = True
# 初始化聊天观察器
self.chat_observer = ChatObserver.get_instance(stream_id)
# 添加action历史记录
self.action_history: List[Dict[str, str]] = []
# 知识缓存
self.knowledge_cache: Dict[str, str] = {} # 确保初始化为字典
# 初始化各个组件
self.goal_analyzer = GoalAnalyzer(self.stream_id)
self.action_planner = ActionPlanner(self.stream_id)
@@ -645,14 +631,14 @@ class Conversation:
self.knowledge_fetcher = KnowledgeFetcher()
self.direct_sender = DirectMessageSender()
self.waiter = Waiter(self.stream_id)
# 创建聊天流
self.chat_stream = chat_manager.get_stream(self.stream_id)
def _clear_knowledge_cache(self):
"""清空知识缓存"""
self.knowledge_cache.clear() # 使用clear方法清空字典
async def start(self):
"""开始对话流程"""
try:
@@ -674,38 +660,38 @@ class Conversation:
"""对话循环"""
# 获取最近的消息历史
self.current_goal, self.current_method, self.goal_reasoning = await self.goal_analyzer.analyze_goal()
while self.should_continue:
# 执行行动
self.chat_observer.trigger_update() # 触发立即更新
if not await self.chat_observer.wait_for_update():
logger.warning("等待消息更新超时")
action, reason = await self.action_planner.plan(
self.current_goal,
self.current_method,
self.goal_reasoning,
self.action_history, # 传入action历史
self.chat_observer # 传入chat_observer
self.chat_observer, # 传入chat_observer
)
# 执行行动
await self._handle_action(action, reason)
def _convert_to_message(self, msg_dict: Dict[str, Any]) -> Message:
"""将消息字典转换为Message对象"""
try:
chat_info = msg_dict.get("chat_info", {})
chat_stream = ChatStream.from_dict(chat_info)
user_info = UserInfo.from_dict(msg_dict.get("user_info", {}))
return Message(
message_id=msg_dict["message_id"],
chat_stream=chat_stream,
time=msg_dict["time"],
user_info=user_info,
processed_plain_text=msg_dict.get("processed_plain_text", ""),
detailed_plain_text=msg_dict.get("detailed_plain_text", "")
detailed_plain_text=msg_dict.get("detailed_plain_text", ""),
)
except Exception as e:
logger.warning(f"转换消息时出错: {e}")
@@ -714,18 +700,16 @@ class Conversation:
async def _handle_action(self, action: str, reason: str):
"""处理规划的行动"""
logger.info(f"执行行动: {action}, 原因: {reason}")
# 记录action历史
self.action_history.append({
"action": action,
"reason": reason,
"time": datetime.datetime.now().strftime("%H:%M:%S")
})
self.action_history.append(
{"action": action, "reason": reason, "time": datetime.datetime.now().strftime("%H:%M:%S")}
)
# 只保留最近的10条记录
if len(self.action_history) > 10:
self.action_history = self.action_history[-10:]
if action == "direct_reply":
self.state = ConversationState.GENERATING
messages = self.chat_observer.get_message_history(limit=30)
@@ -733,15 +717,14 @@ class Conversation:
self.current_goal,
self.current_method,
[self._convert_to_message(msg) for msg in messages],
self.knowledge_cache
self.knowledge_cache,
)
# 检查回复是否合适
is_suitable, reason, need_replan = await self.reply_generator.check_reply(
self.generated_reply,
self.current_goal
self.generated_reply, self.current_goal
)
if not is_suitable:
logger.warning(f"生成的回复不合适,原因: {reason}")
if need_replan:
@@ -756,29 +739,34 @@ class Conversation:
self.current_goal,
self.current_method,
[self._convert_to_message(msg) for msg in messages],
self.knowledge_cache
self.knowledge_cache,
)
# 检查使用新目标生成的回复是否合适
is_suitable, reason, _ = await self.reply_generator.check_reply(
self.generated_reply,
self.current_goal
self.generated_reply, self.current_goal
)
if is_suitable:
# 如果新目标的回复合适,调整目标优先级
await self.goal_analyzer._update_goals(
self.current_goal,
self.current_method,
self.goal_reasoning
self.current_goal, self.current_method, self.goal_reasoning
)
else:
# 如果新目标还是不合适,重新思考目标
self.state = ConversationState.RETHINKING
self.current_goal, self.current_method, self.goal_reasoning = await self.goal_analyzer.analyze_goal()
(
self.current_goal,
self.current_method,
self.goal_reasoning,
) = await self.goal_analyzer.analyze_goal()
return
else:
# 没有备选目标,重新分析
self.state = ConversationState.RETHINKING
self.current_goal, self.current_method, self.goal_reasoning = await self.goal_analyzer.analyze_goal()
(
self.current_goal,
self.current_method,
self.goal_reasoning,
) = await self.goal_analyzer.analyze_goal()
return
else:
# 重新生成回复
@@ -787,9 +775,9 @@ class Conversation:
self.current_method,
[self._convert_to_message(msg) for msg in messages],
self.knowledge_cache,
self.generated_reply # 将不合适的回复作为previous_reply传入
self.generated_reply, # 将不合适的回复作为previous_reply传入
)
while self.chat_observer.check():
if not is_suitable:
logger.warning(f"生成的回复不合适,原因: {reason}")
@@ -805,13 +793,17 @@ class Conversation:
self.current_goal,
self.current_method,
[self._convert_to_message(msg) for msg in messages],
self.knowledge_cache
self.knowledge_cache,
)
is_suitable = True # 假设使用新目标后回复是合适的
else:
# 没有备选目标,重新分析
self.state = ConversationState.RETHINKING
self.current_goal, self.current_method, self.goal_reasoning = await self.goal_analyzer.analyze_goal()
(
self.current_goal,
self.current_method,
self.goal_reasoning,
) = await self.goal_analyzer.analyze_goal()
return
else:
# 重新生成回复
@@ -820,36 +812,34 @@ class Conversation:
self.current_method,
[self._convert_to_message(msg) for msg in messages],
self.knowledge_cache,
self.generated_reply # 将不合适的回复作为previous_reply传入
self.generated_reply, # 将不合适的回复作为previous_reply传入
)
await self._send_reply()
elif action == "fetch_knowledge":
self.state = ConversationState.GENERATING
messages = self.chat_observer.get_message_history(limit=30)
knowledge, sources = await self.knowledge_fetcher.fetch(
self.current_goal,
[self._convert_to_message(msg) for msg in messages]
self.current_goal, [self._convert_to_message(msg) for msg in messages]
)
logger.info(f"获取到知识,来源: {sources}")
if knowledge != "未找到相关知识":
self.knowledge_cache[sources] = knowledge
self.generated_reply = await self.reply_generator.generate(
self.current_goal,
self.current_method,
[self._convert_to_message(msg) for msg in messages],
self.knowledge_cache
self.knowledge_cache,
)
# 检查回复是否合适
is_suitable, reason, need_replan = await self.reply_generator.check_reply(
self.generated_reply,
self.current_goal
self.generated_reply, self.current_goal
)
if not is_suitable:
logger.warning(f"生成的回复不合适,原因: {reason}")
if need_replan:
@@ -861,22 +851,25 @@ class Conversation:
logger.info(f"切换到备选目标: {self.current_goal}")
# 使用新目标获取知识并生成回复
knowledge, sources = await self.knowledge_fetcher.fetch(
self.current_goal,
[self._convert_to_message(msg) for msg in messages]
self.current_goal, [self._convert_to_message(msg) for msg in messages]
)
if knowledge != "未找到相关知识":
self.knowledge_cache[sources] = knowledge
self.generated_reply = await self.reply_generator.generate(
self.current_goal,
self.current_method,
[self._convert_to_message(msg) for msg in messages],
self.knowledge_cache
self.knowledge_cache,
)
else:
# 没有备选目标,重新分析
self.state = ConversationState.RETHINKING
self.current_goal, self.current_method, self.goal_reasoning = await self.goal_analyzer.analyze_goal()
(
self.current_goal,
self.current_method,
self.goal_reasoning,
) = await self.goal_analyzer.analyze_goal()
return
else:
# 重新生成回复
@@ -885,19 +878,21 @@ class Conversation:
self.current_method,
[self._convert_to_message(msg) for msg in messages],
self.knowledge_cache,
self.generated_reply # 将不合适的回复作为previous_reply传入
self.generated_reply, # 将不合适的回复作为previous_reply传入
)
await self._send_reply()
elif action == "rethink_goal":
self.state = ConversationState.RETHINKING
self.current_goal, self.current_method, self.goal_reasoning = await self.goal_analyzer.analyze_goal()
elif action == "judge_conversation":
self.state = ConversationState.JUDGING
self.goal_achieved, self.stop_conversation, self.reason = await self.goal_analyzer.analyze_conversation(self.current_goal, self.goal_reasoning)
self.goal_achieved, self.stop_conversation, self.reason = await self.goal_analyzer.analyze_conversation(
self.current_goal, self.goal_reasoning
)
# 如果当前目标达成但还有其他目标
if self.goal_achieved and not self.stop_conversation:
alternative_goals = await self.goal_analyzer.get_alternative_goals()
@@ -906,17 +901,17 @@ class Conversation:
self.current_goal, self.current_method, self.goal_reasoning = alternative_goals[0]
logger.info(f"当前目标已达成,切换到新目标: {self.current_goal}")
return
if self.stop_conversation:
await self._stop_conversation()
elif action == "listening":
self.state = ConversationState.LISTENING
logger.info("倾听对方发言...")
if await self.waiter.wait(): # 如果返回True表示超时
await self._send_timeout_message()
await self._stop_conversation()
else: # wait
self.state = ConversationState.WAITING
logger.info("等待更多信息...")
@@ -938,12 +933,12 @@ class Conversation:
messages = self.chat_observer.get_message_history(limit=1)
if not messages:
return
latest_message = self._convert_to_message(messages[0])
await self.direct_sender.send_message(
chat_stream=self.chat_stream,
content="抱歉,由于等待时间过长,我需要先去忙别的了。下次再聊吧~",
reply_to_message=latest_message
reply_to_message=latest_message,
)
except Exception as e:
logger.error(f"发送超时消息失败: {str(e)}")
@@ -953,23 +948,21 @@ class Conversation:
if not self.generated_reply:
logger.warning("没有生成回复")
return
messages = self.chat_observer.get_message_history(limit=1)
if not messages:
logger.warning("没有最近的消息可以回复")
return
latest_message = self._convert_to_message(messages[0])
try:
await self.direct_sender.send_message(
chat_stream=self.chat_stream,
content=self.generated_reply,
reply_to_message=latest_message
chat_stream=self.chat_stream, content=self.generated_reply, reply_to_message=latest_message
)
self.chat_observer.trigger_update() # 触发立即更新
if not await self.chat_observer.wait_for_update():
logger.warning("等待消息更新超时")
self.state = ConversationState.ANALYZING
except Exception as e:
logger.error(f"发送消息失败: {str(e)}")
@@ -978,7 +971,7 @@ class Conversation:
class DirectMessageSender:
"""直接发送消息到平台的发送器"""
def __init__(self):
self.logger = get_module_logger("direct_sender")
self.storage = MessageStorage()
@@ -990,7 +983,7 @@ class DirectMessageSender:
reply_to_message: Optional[Message] = None,
) -> None:
"""直接发送消息到平台
Args:
chat_stream: 聊天流
content: 消息内容
@@ -1003,7 +996,7 @@ class DirectMessageSender:
user_nickname=global_config.BOT_NICKNAME,
platform=chat_stream.platform,
)
message = MessageSending(
message_id=f"dm{round(time.time(), 2)}",
chat_stream=chat_stream,
@@ -1023,18 +1016,17 @@ class DirectMessageSender:
try:
message_json = message.to_dict()
end_point = global_config.api_urls.get(chat_stream.platform, None)
if not end_point:
raise ValueError(f"未找到平台:{chat_stream.platform} 的url配置")
await global_api.send_message_REST(end_point, message_json)
# 存储消息
await self.storage.store_message(message, message.chat_stream)
self.logger.info(f"直接发送消息成功: {content[:30]}...")
except Exception as e:
self.logger.error(f"直接发送消息失败: {str(e)}")
raise

View File

@@ -7,24 +7,22 @@ from ..chat.message import Message
logger = get_module_logger("knowledge_fetcher")
class KnowledgeFetcher:
"""知识调取器"""
def __init__(self):
self.llm = LLM_request(
model=global_config.llm_normal,
temperature=0.7,
max_tokens=1000,
request_type="knowledge_fetch"
model=global_config.llm_normal, temperature=0.7, max_tokens=1000, request_type="knowledge_fetch"
)
async def fetch(self, query: str, chat_history: List[Message]) -> Tuple[str, str]:
"""获取相关知识
Args:
query: 查询内容
chat_history: 聊天历史
Returns:
Tuple[str, str]: (获取的知识, 知识来源)
"""
@@ -33,16 +31,16 @@ class KnowledgeFetcher:
for msg in chat_history:
# sender = msg.message_info.user_info.user_nickname or f"用户{msg.message_info.user_info.user_id}"
chat_history_text += f"{msg.detailed_plain_text}\n"
# 从记忆中获取相关知识
related_memory = await HippocampusManager.get_instance().get_memory_from_text(
text=f"{query}\n{chat_history_text}",
max_memory_num=3,
max_memory_length=2,
max_depth=3,
fast_retrieval=False
fast_retrieval=False,
)
if related_memory:
knowledge = ""
sources = []
@@ -50,5 +48,5 @@ class KnowledgeFetcher:
knowledge += memory[1] + "\n"
sources.append(f"记忆片段{memory[0]}")
return knowledge.strip(), "".join(sources)
return "未找到相关知识", "无记忆匹配"
return "未找到相关知识", "无记忆匹配"

View File

@@ -5,36 +5,37 @@ from src.common.logger import get_module_logger
logger = get_module_logger("pfc_utils")
def get_items_from_json(
content: str,
*items: str,
default_values: Optional[Dict[str, Any]] = None,
required_types: Optional[Dict[str, type]] = None
required_types: Optional[Dict[str, type]] = None,
) -> Tuple[bool, Dict[str, Any]]:
"""从文本中提取JSON内容并获取指定字段
Args:
content: 包含JSON的文本
*items: 要提取的字段名
default_values: 字段的默认值,格式为 {字段名: 默认值}
required_types: 字段的必需类型,格式为 {字段名: 类型}
Returns:
Tuple[bool, Dict[str, Any]]: (是否成功, 提取的字段字典)
"""
content = content.strip()
result = {}
# 设置默认值
if default_values:
result.update(default_values)
# 尝试解析JSON
try:
json_data = json.loads(content)
except json.JSONDecodeError:
# 如果直接解析失败尝试查找和提取JSON部分
json_pattern = r'\{[^{}]*\}'
json_pattern = r"\{[^{}]*\}"
json_match = re.search(json_pattern, content)
if json_match:
try:
@@ -45,28 +46,28 @@ def get_items_from_json(
else:
logger.error("无法在返回内容中找到有效的JSON")
return False, result
# 提取字段
for item in items:
if item in json_data:
result[item] = json_data[item]
# 验证必需字段
if not all(item in result for item in items):
logger.error(f"JSON缺少必要字段实际内容: {json_data}")
return False, result
# 验证字段类型
if required_types:
for field, expected_type in required_types.items():
if field in result and not isinstance(result[field], expected_type):
logger.error(f"{field} 必须是 {expected_type.__name__} 类型")
return False, result
# 验证字符串字段不为空
for field in items:
if isinstance(result[field], str) and not result[field].strip():
logger.error(f"{field} 不能为空")
return False, result
return True, result
return True, result

View File

@@ -9,33 +9,26 @@ from ..message.message_base import UserInfo
logger = get_module_logger("reply_checker")
class ReplyChecker:
"""回复检查器"""
def __init__(self, stream_id: str):
self.llm = LLM_request(
model=global_config.llm_normal,
temperature=0.7,
max_tokens=1000,
request_type="reply_check"
model=global_config.llm_normal, temperature=0.7, max_tokens=1000, request_type="reply_check"
)
self.name = global_config.BOT_NICKNAME
self.chat_observer = ChatObserver.get_instance(stream_id)
self.max_retries = 2 # 最大重试次数
async def check(
self,
reply: str,
goal: str,
retry_count: int = 0
) -> Tuple[bool, str, bool]:
async def check(self, reply: str, goal: str, retry_count: int = 0) -> Tuple[bool, str, bool]:
"""检查生成的回复是否合适
Args:
reply: 生成的回复
goal: 对话目标
retry_count: 当前重试次数
Returns:
Tuple[bool, str, bool]: (是否合适, 原因, 是否需要重新规划)
"""
@@ -49,7 +42,7 @@ class ReplyChecker:
if sender == self.name:
sender = "你说"
chat_history_text += f"{time_str},{sender}:{msg.get('processed_plain_text', '')}\n"
prompt = f"""请检查以下回复是否合适:
当前对话目标:{goal}
@@ -83,7 +76,7 @@ class ReplyChecker:
try:
content, _ = await self.llm.generate_response_async(prompt)
logger.debug(f"检查回复的原始返回: {content}")
# 清理内容尝试提取JSON部分
content = content.strip()
try:
@@ -92,7 +85,8 @@ class ReplyChecker:
except json.JSONDecodeError:
# 如果直接解析失败尝试查找和提取JSON部分
import re
json_pattern = r'\{[^{}]*\}'
json_pattern = r"\{[^{}]*\}"
json_match = re.search(json_pattern, content)
if json_match:
try:
@@ -109,33 +103,33 @@ class ReplyChecker:
reason = content[:100] if content else "无法解析响应"
need_replan = "重新规划" in content.lower() or "目标不适合" in content.lower()
return is_suitable, reason, need_replan
# 验证JSON字段
suitable = result.get("suitable", None)
reason = result.get("reason", "未提供原因")
need_replan = result.get("need_replan", False)
# 如果suitable字段是字符串转换为布尔值
if isinstance(suitable, str):
suitable = suitable.lower() == "true"
# 如果suitable字段不存在或不是布尔值从reason中判断
if suitable is None:
suitable = "不合适" not in reason.lower() and "违规" not in reason.lower()
# 如果不合适且未达到最大重试次数,返回需要重试
if not suitable and retry_count < self.max_retries:
return False, reason, False
# 如果不合适且已达到最大重试次数,返回需要重新规划
if not suitable and retry_count >= self.max_retries:
return False, f"多次重试后仍不合适: {reason}", True
return suitable, reason, need_replan
except Exception as e:
logger.error(f"检查回复时出错: {e}")
# 如果出错且已达到最大重试次数,建议重新规划
if retry_count >= self.max_retries:
return False, "多次检查失败,建议重新规划", True
return False, f"检查过程出错,建议重试: {str(e)}", False
return False, f"检查过程出错,建议重试: {str(e)}", False

View File

@@ -12,5 +12,5 @@ __all__ = [
"chat_manager",
"message_manager",
"MessageStorage",
"auto_speak_manager"
"auto_speak_manager",
]

View File

@@ -40,14 +40,14 @@ class ChatBot:
async def _create_PFC_chat(self, message: MessageRecv):
try:
chat_id = str(message.chat_stream.stream_id)
if global_config.enable_pfc_chatting:
# 获取或创建对话实例
conversation = await Conversation.get_instance(chat_id)
if conversation is None:
logger.error(f"创建或获取对话实例失败: {chat_id}")
return
# 如果是新创建的实例,启动对话系统
if conversation.state == ConversationState.INIT:
asyncio.create_task(conversation.start())
@@ -71,16 +71,16 @@ class ChatBot:
- 包含思维流状态管理
- 在回复前进行观察和状态更新
- 回复后更新思维流状态
2. reasoning模式使用推理系统进行回复
- 直接使用意愿管理器计算回复概率
- 没有思维流相关的状态管理
- 更简单直接的回复逻辑
3. pfc_chatting模式仅进行消息处理
- 不进行任何回复
- 只处理和存储消息
所有模式都包含:
- 消息过滤
- 记忆激活
@@ -98,7 +98,7 @@ class ChatBot:
if userinfo.user_id in global_config.ban_user_id:
logger.debug(f"用户{userinfo.user_id}被禁止回复")
return
if global_config.enable_pfc_chatting:
try:
if groupinfo is None and global_config.enable_friend_chat:
@@ -127,7 +127,7 @@ class ChatBot:
logger.error(f"处理PFC消息失败: {e}")
else:
if groupinfo is None and global_config.enable_friend_chat:
# 私聊处理流程
# 私聊处理流程
# await self._handle_private_chat(message)
if global_config.response_mode == "heart_flow":
await self.think_flow_chat.process_message(message_data)

View File

@@ -38,11 +38,11 @@ class EmojiManager:
self.llm_emotion_judge = LLM_request(
model=global_config.llm_emotion_judge, max_tokens=600, temperature=0.8, request_type="emoji"
) # 更高的温度更少的token后续可以根据情绪来调整温度
self.emoji_num = 0
self.emoji_num_max = global_config.max_emoji_num
self.emoji_num_max_reach_deletion = global_config.max_reach_deletion
logger.info("启动表情包管理器")
def _ensure_emoji_dir(self):
@@ -51,7 +51,7 @@ class EmojiManager:
def _update_emoji_count(self):
"""更新表情包数量统计
检查数据库中的表情包数量并更新到 self.emoji_num
"""
try:
@@ -376,7 +376,6 @@ class EmojiManager:
except Exception:
logger.exception("[错误] 扫描表情包失败")
def check_emoji_file_integrity(self):
"""检查表情包文件完整性
@@ -451,7 +450,7 @@ class EmojiManager:
def check_emoji_file_full(self):
"""检查表情包文件是否完整,如果数量超出限制且允许删除,则删除多余的表情包
删除规则:
1. 优先删除创建时间更早的表情包
2. 优先删除使用次数少的表情包,但使用次数多的也有小概率被删除
@@ -460,23 +459,23 @@ class EmojiManager:
self._ensure_db()
# 更新表情包数量
self._update_emoji_count()
# 检查是否超出限制
if self.emoji_num <= self.emoji_num_max:
return
# 如果超出限制但不允许删除,则只记录警告
if not global_config.max_reach_deletion:
logger.warning(f"[警告] 表情包数量({self.emoji_num})超出限制({self.emoji_num_max}),但未开启自动删除")
return
# 计算需要删除的数量
delete_count = self.emoji_num - self.emoji_num_max
logger.info(f"[清理] 需要删除 {delete_count} 个表情包")
# 获取所有表情包,按时间戳升序(旧的在前)排序
all_emojis = list(db.emoji.find().sort([("timestamp", 1)]))
# 计算权重:使用次数越多,被删除的概率越小
weights = []
max_usage = max((emoji.get("usage_count", 0) for emoji in all_emojis), default=1)
@@ -485,11 +484,11 @@ class EmojiManager:
# 使用指数衰减函数计算权重,使用次数越多权重越小
weight = 1.0 / (1.0 + usage_count / max(1, max_usage))
weights.append(weight)
# 根据权重随机选择要删除的表情包
to_delete = []
remaining_indices = list(range(len(all_emojis)))
while len(to_delete) < delete_count and remaining_indices:
# 计算当前剩余表情包的权重
current_weights = [weights[i] for i in remaining_indices]
@@ -497,13 +496,13 @@ class EmojiManager:
total_weight = sum(current_weights)
if total_weight == 0:
break
normalized_weights = [w/total_weight for w in current_weights]
normalized_weights = [w / total_weight for w in current_weights]
# 随机选择一个表情包
selected_idx = random.choices(remaining_indices, weights=normalized_weights, k=1)[0]
to_delete.append(all_emojis[selected_idx])
remaining_indices.remove(selected_idx)
# 删除选中的表情包
deleted_count = 0
for emoji in to_delete:
@@ -512,26 +511,26 @@ class EmojiManager:
if "path" in emoji and os.path.exists(emoji["path"]):
os.remove(emoji["path"])
logger.info(f"[删除] 文件: {emoji['path']} (使用次数: {emoji.get('usage_count', 0)})")
# 删除数据库记录
db.emoji.delete_one({"_id": emoji["_id"]})
deleted_count += 1
# 同时从images集合中删除
if "hash" in emoji:
db.images.delete_one({"hash": emoji["hash"]})
except Exception as e:
logger.error(f"[错误] 删除表情包失败: {str(e)}")
continue
# 更新表情包数量
self._update_emoji_count()
logger.success(f"[清理] 已删除 {deleted_count} 个表情包,当前数量: {self.emoji_num}")
except Exception as e:
logger.error(f"[错误] 检查表情包数量失败: {str(e)}")
async def start_periodic_check_register(self):
"""定期检查表情包完整性和数量"""
while True:
@@ -542,7 +541,7 @@ class EmojiManager:
logger.info("[扫描] 开始扫描新表情包...")
if self.emoji_num < self.emoji_num_max:
await self.scan_new_emojis()
if (self.emoji_num > self.emoji_num_max):
if self.emoji_num > self.emoji_num_max:
logger.warning(f"[警告] 表情包数量超过最大限制: {self.emoji_num} > {self.emoji_num_max},跳过注册")
if not global_config.max_reach_deletion:
logger.warning("表情包数量超过最大限制,终止注册")
@@ -551,7 +550,7 @@ class EmojiManager:
logger.warning("表情包数量超过最大限制,开始删除表情包")
self.check_emoji_file_full()
await asyncio.sleep(global_config.EMOJI_CHECK_INTERVAL * 60)
async def delete_all_images(self):
"""删除 data/image 目录下的所有文件"""
try:
@@ -559,10 +558,10 @@ class EmojiManager:
if not os.path.exists(image_dir):
logger.warning(f"[警告] 目录不存在: {image_dir}")
return
deleted_count = 0
failed_count = 0
# 遍历目录下的所有文件
for filename in os.listdir(image_dir):
file_path = os.path.join(image_dir, filename)
@@ -574,11 +573,12 @@ class EmojiManager:
except Exception as e:
failed_count += 1
logger.error(f"[错误] 删除文件失败 {file_path}: {str(e)}")
logger.success(f"[清理] 已删除 {deleted_count} 个文件,失败 {failed_count}")
except Exception as e:
logger.error(f"[错误] 删除图片目录失败: {str(e)}")
# 创建全局单例
emoji_manager = EmojiManager()

View File

@@ -13,9 +13,10 @@ from ..config.config import global_config
logger = get_module_logger("message_buffer")
@dataclass
class CacheMessages:
message: MessageRecv
message: MessageRecv
cache_determination: asyncio.Event = field(default_factory=asyncio.Event) # 判断缓冲是否产生结果
result: str = "U"
@@ -25,7 +26,7 @@ class MessageBuffer:
self.buffer_pool: Dict[str, OrderedDict[str, CacheMessages]] = {}
self.lock = asyncio.Lock()
def get_person_id_(self, platform:str, user_id:str, group_info:GroupInfo):
def get_person_id_(self, platform: str, user_id: str, group_info: GroupInfo):
"""获取唯一id"""
if group_info:
group_id = group_info.group_id
@@ -34,16 +35,17 @@ class MessageBuffer:
key = f"{platform}_{user_id}_{group_id}"
return hashlib.md5(key.encode()).hexdigest()
async def start_caching_messages(self, message:MessageRecv):
async def start_caching_messages(self, message: MessageRecv):
"""添加消息,启动缓冲"""
if not global_config.message_buffer:
person_id = person_info_manager.get_person_id(message.message_info.user_info.platform,
message.message_info.user_info.user_id)
person_id = person_info_manager.get_person_id(
message.message_info.user_info.platform, message.message_info.user_info.user_id
)
asyncio.create_task(self.save_message_interval(person_id, message.message_info))
return
person_id_ = self.get_person_id_(message.message_info.platform,
message.message_info.user_info.user_id,
message.message_info.group_info)
person_id_ = self.get_person_id_(
message.message_info.platform, message.message_info.user_info.user_id, message.message_info.group_info
)
async with self.lock:
if person_id_ not in self.buffer_pool:
@@ -64,25 +66,24 @@ class MessageBuffer:
break
elif msg.result == "F":
recent_F_count += 1
# 判断条件最近T之后有超过3-5条F
if (recent_F_count >= random.randint(3, 5)):
if recent_F_count >= random.randint(3, 5):
new_msg = CacheMessages(message=message, result="T")
new_msg.cache_determination.set()
self.buffer_pool[person_id_][message.message_info.message_id] = new_msg
logger.debug(f"快速处理消息(已堆积{recent_F_count}条F): {message.message_info.message_id}")
return
# 添加新消息
self.buffer_pool[person_id_][message.message_info.message_id] = CacheMessages(message=message)
# 启动3秒缓冲计时器
person_id = person_info_manager.get_person_id(message.message_info.user_info.platform,
message.message_info.user_info.user_id)
person_id = person_info_manager.get_person_id(
message.message_info.user_info.platform, message.message_info.user_info.user_id
)
asyncio.create_task(self.save_message_interval(person_id, message.message_info))
asyncio.create_task(self._debounce_processor(person_id_,
message.message_info.message_id,
person_id))
asyncio.create_task(self._debounce_processor(person_id_, message.message_info.message_id, person_id))
async def _debounce_processor(self, person_id_: str, message_id: str, person_id: str):
"""等待3秒无新消息"""
@@ -92,36 +93,33 @@ class MessageBuffer:
return
interval_time = max(0.5, int(interval_time) / 1000)
await asyncio.sleep(interval_time)
async with self.lock:
if (person_id_ not in self.buffer_pool or
message_id not in self.buffer_pool[person_id_]):
if person_id_ not in self.buffer_pool or message_id not in self.buffer_pool[person_id_]:
logger.debug(f"消息已被清理msgid: {message_id}")
return
cache_msg = self.buffer_pool[person_id_][message_id]
if cache_msg.result == "U":
cache_msg.result = "T"
cache_msg.cache_determination.set()
async def query_buffer_result(self, message:MessageRecv) -> bool:
async def query_buffer_result(self, message: MessageRecv) -> bool:
"""查询缓冲结果,并清理"""
if not global_config.message_buffer:
return True
person_id_ = self.get_person_id_(message.message_info.platform,
message.message_info.user_info.user_id,
message.message_info.group_info)
person_id_ = self.get_person_id_(
message.message_info.platform, message.message_info.user_info.user_id, message.message_info.group_info
)
async with self.lock:
user_msgs = self.buffer_pool.get(person_id_, {})
cache_msg = user_msgs.get(message.message_info.message_id)
if not cache_msg:
logger.debug(f"查询异常消息不存在msgid: {message.message_info.message_id}")
return False # 消息不存在或已清理
try:
await asyncio.wait_for(cache_msg.cache_determination.wait(), timeout=10)
result = cache_msg.result == "T"
@@ -144,9 +142,8 @@ class MessageBuffer:
keep_msgs[msg_id] = msg
elif msg.result == "F":
# 收集F消息的文本内容
if (hasattr(msg.message, 'processed_plain_text')
and msg.message.processed_plain_text):
if msg.message.message_segment.type == "text":
if hasattr(msg.message, "processed_plain_text") and msg.message.processed_plain_text:
if msg.message.message_segment.type == "text":
combined_text.append(msg.message.processed_plain_text)
elif msg.message.message_segment.type != "text":
is_update = False
@@ -157,20 +154,20 @@ class MessageBuffer:
if combined_text and combined_text[0] != message.processed_plain_text and is_update:
if type == "text":
message.processed_plain_text = "".join(combined_text)
logger.debug(f"整合了{len(combined_text)-1}条F消息的内容到当前消息")
logger.debug(f"整合了{len(combined_text) - 1}条F消息的内容到当前消息")
elif type == "emoji":
combined_text.pop()
message.processed_plain_text = "".join(combined_text)
message.is_emoji = False
logger.debug(f"整合了{len(combined_text)-1}条F消息的内容覆盖当前emoji消息")
logger.debug(f"整合了{len(combined_text) - 1}条F消息的内容覆盖当前emoji消息")
self.buffer_pool[person_id_] = keep_msgs
return result
except asyncio.TimeoutError:
logger.debug(f"查询超时消息id {message.message_info.message_id}")
return False
async def save_message_interval(self, person_id:str, message:BaseMessageInfo):
async def save_message_interval(self, person_id: str, message: BaseMessageInfo):
message_interval_list = await person_info_manager.get_value(person_id, "msg_interval_list")
now_time_ms = int(round(time.time() * 1000))
if len(message_interval_list) < 1000:
@@ -179,12 +176,12 @@ class MessageBuffer:
message_interval_list.pop(0)
message_interval_list.append(now_time_ms)
data = {
"platform" : message.platform,
"user_id" : message.user_info.user_id,
"nickname" : message.user_info.user_nickname,
"konw_time" : int(time.time())
"platform": message.platform,
"user_id": message.user_info.user_id,
"nickname": message.user_info.user_nickname,
"konw_time": int(time.time()),
}
await person_info_manager.update_one_field(person_id, "msg_interval_list", message_interval_list, data)
message_buffer = MessageBuffer()
message_buffer = MessageBuffer()

View File

@@ -68,7 +68,8 @@ class Message_Sender:
typing_time = calculate_typing_time(
input_string=message.processed_plain_text,
thinking_start_time=message.thinking_start_time,
is_emoji=message.is_emoji)
is_emoji=message.is_emoji,
)
logger.debug(f"{message.processed_plain_text},{typing_time},计算输入时间结束")
await asyncio.sleep(typing_time)
logger.debug(f"{message.processed_plain_text},{typing_time},等待输入时间结束")
@@ -227,7 +228,7 @@ class MessageManager:
await message_earliest.process()
# print(f"message_earliest.thinking_start_tim22222e:{message_earliest.thinking_start_time}")
await message_sender.send_message(message_earliest)
await self.storage.store_message(message_earliest, message_earliest.chat_stream)

View File

@@ -56,14 +56,13 @@ def is_mentioned_bot_in_message(message: MessageRecv) -> bool:
logger.info("被@回复概率设置为100%")
else:
if not is_mentioned:
# 判断是否被回复
if re.match(f"回复[\s\S]*?\({global_config.BOT_QQ}\)的消息,说:", message.processed_plain_text):
is_mentioned = True
# 判断内容中是否被提及
message_content = re.sub(r'\@[\s\S]*?(\d+)','', message.processed_plain_text)
message_content = re.sub(r'回复[\s\S]*?\((\d+)\)的消息,说: ','', message_content)
message_content = re.sub(r"\@[\s\S]*?(\d+)", "", message.processed_plain_text)
message_content = re.sub(r"回复[\s\S]*?\((\d+)\)的消息,说: ", "", message_content)
for keyword in keywords:
if keyword in message_content:
is_mentioned = True
@@ -359,7 +358,13 @@ def process_llm_response(text: str) -> List[str]:
return sentences
def calculate_typing_time(input_string: str, thinking_start_time: float, chinese_time: float = 0.2, english_time: float = 0.1, is_emoji: bool = False) -> float:
def calculate_typing_time(
input_string: str,
thinking_start_time: float,
chinese_time: float = 0.2,
english_time: float = 0.1,
is_emoji: bool = False,
) -> float:
"""
计算输入字符串所需的时间,中文和英文字符有不同的输入时间
input_string (str): 输入的字符串
@@ -393,19 +398,18 @@ def calculate_typing_time(input_string: str, thinking_start_time: float, chinese
total_time += chinese_time
else: # 其他字符(如英文)
total_time += english_time
if is_emoji:
total_time = 1
if time.time() - thinking_start_time > 10:
total_time = 1
# print(f"thinking_start_time:{thinking_start_time}")
# print(f"nowtime:{time.time()}")
# print(f"nowtime - thinking_start_time:{time.time() - thinking_start_time}")
# print(f"{total_time}")
return total_time # 加上回车时间
@@ -535,39 +539,32 @@ def count_messages_between(start_time: float, end_time: float, stream_id: str) -
try:
# 获取开始时间之前最新的一条消息
start_message = db.messages.find_one(
{
"chat_id": stream_id,
"time": {"$lte": start_time}
},
sort=[("time", -1), ("_id", -1)] # 按时间倒序_id倒序最后插入的在前
{"chat_id": stream_id, "time": {"$lte": start_time}},
sort=[("time", -1), ("_id", -1)], # 按时间倒序_id倒序最后插入的在前
)
# 获取结束时间最近的一条消息
# 先找到结束时间点的所有消息
end_time_messages = list(db.messages.find(
{
"chat_id": stream_id,
"time": {"$lte": end_time}
},
sort=[("time", -1)] # 先按时间倒序
).limit(10)) # 限制查询数量,避免性能问题
end_time_messages = list(
db.messages.find(
{"chat_id": stream_id, "time": {"$lte": end_time}},
sort=[("time", -1)], # 先按时间倒序
).limit(10)
) # 限制查询数量,避免性能问题
if not end_time_messages:
logger.warning(f"未找到结束时间 {end_time} 之前的消息")
return 0, 0
# 找到最大时间
max_time = end_time_messages[0]["time"]
# 在最大时间的消息中找最后插入的_id最大的
end_message = max(
[msg for msg in end_time_messages if msg["time"] == max_time],
key=lambda x: x["_id"]
)
end_message = max([msg for msg in end_time_messages if msg["time"] == max_time], key=lambda x: x["_id"])
if not start_message:
logger.warning(f"未找到开始时间 {start_time} 之前的消息")
return 0, 0
# 调试输出
# print("\n=== 消息范围信息 ===")
# print("Start message:", {
@@ -587,20 +584,16 @@ def count_messages_between(start_time: float, end_time: float, stream_id: str) -
# 如果结束消息的时间等于开始时间返回0
if end_message["time"] == start_message["time"]:
return 0, 0
# 获取并打印这个时间范围内的所有消息
# print("\n=== 时间范围内的所有消息 ===")
all_messages = list(db.messages.find(
{
"chat_id": stream_id,
"time": {
"$gte": start_message["time"],
"$lte": end_message["time"]
}
},
sort=[("time", 1), ("_id", 1)] # 按时间正序_id正序
))
all_messages = list(
db.messages.find(
{"chat_id": stream_id, "time": {"$gte": start_message["time"], "$lte": end_message["time"]}},
sort=[("time", 1), ("_id", 1)], # 按时间正序_id正序
)
)
count = 0
total_length = 0
for msg in all_messages:
@@ -615,10 +608,10 @@ def count_messages_between(start_time: float, end_time: float, stream_id: str) -
# "text_length": text_length,
# "_id": str(msg.get("_id"))
# })
# 如果时间不同需要把end_message本身也计入
return count - 1, total_length
except Exception as e:
logger.error(f"计算消息数量时出错: {str(e)}")
return 0, 0

View File

@@ -239,13 +239,13 @@ class ImageManager:
# 解码base64
gif_data = base64.b64decode(gif_base64)
gif = Image.open(io.BytesIO(gif_data))
# 收集所有帧
frames = []
try:
while True:
gif.seek(len(frames))
frame = gif.convert('RGB')
frame = gif.convert("RGB")
frames.append(frame.copy())
except EOFError:
pass
@@ -264,18 +264,19 @@ class ImageManager:
# 获取单帧的尺寸
frame_width, frame_height = selected_frames[0].size
# 计算目标尺寸,保持宽高比
target_height = 200 # 固定高度
target_width = int((target_height / frame_height) * frame_width)
# 调整所有帧的大小
resized_frames = [frame.resize((target_width, target_height), Image.Resampling.LANCZOS)
for frame in selected_frames]
resized_frames = [
frame.resize((target_width, target_height), Image.Resampling.LANCZOS) for frame in selected_frames
]
# 创建拼接图像
total_width = target_width * len(resized_frames)
combined_image = Image.new('RGB', (total_width, target_height))
combined_image = Image.new("RGB", (total_width, target_height))
# 水平拼接图像
for idx, frame in enumerate(resized_frames):
@@ -283,11 +284,11 @@ class ImageManager:
# 转换为base64
buffer = io.BytesIO()
combined_image.save(buffer, format='JPEG', quality=85)
result_base64 = base64.b64encode(buffer.getvalue()).decode('utf-8')
combined_image.save(buffer, format="JPEG", quality=85)
result_base64 = base64.b64encode(buffer.getvalue()).decode("utf-8")
return result_base64
except Exception as e:
logger.error(f"GIF转换失败: {str(e)}")
return None

View File

@@ -7,12 +7,13 @@ from datetime import datetime
logger = get_module_logger("pfc_message_processor")
class MessageProcessor:
"""消息处理器,负责处理接收到的消息并存储"""
def __init__(self):
self.storage = MessageStorage()
def _check_ban_words(self, text: str, chat, userinfo) -> bool:
"""检查消息中是否包含过滤词"""
for word in global_config.ban_words:
@@ -34,10 +35,10 @@ class MessageProcessor:
logger.info(f"[正则表达式过滤]消息匹配到{pattern}filtered")
return True
return False
async def process_message(self, message: MessageRecv) -> None:
"""处理消息并存储
Args:
message: 消息对象
"""
@@ -55,12 +56,9 @@ class MessageProcessor:
# 存储消息
await self.storage.store_message(message, chat)
# 打印消息信息
mes_name = chat.group_info.group_name if chat.group_info else "私聊"
# 将时间戳转换为datetime对象
current_time = datetime.fromtimestamp(message.message_info.time).strftime("%H:%M:%S")
logger.info(
f"[{current_time}][{mes_name}]"
f"{chat.user_info.user_nickname}: {message.processed_plain_text}"
)
logger.info(f"[{current_time}][{mes_name}]{chat.user_info.user_nickname}: {message.processed_plain_text}")

View File

@@ -27,6 +27,7 @@ chat_config = LogConfig(
logger = get_module_logger("reasoning_chat", config=chat_config)
class ReasoningChat:
def __init__(self):
self.storage = MessageStorage()
@@ -224,13 +225,13 @@ class ReasoningChat:
do_reply = False
if random() < reply_probability:
do_reply = True
# 创建思考消息
timer1 = time.time()
thinking_id = await self._create_thinking_message(message, chat, userinfo, messageinfo)
timer2 = time.time()
timing_results["创建思考消息"] = timer2 - timer1
# 生成回复
timer1 = time.time()
response_set = await self.gpt.generate_response(message)

View File

@@ -40,7 +40,7 @@ class ResponseGenerator:
async def generate_response(self, message: MessageThinking) -> Optional[Union[str, List[str]]]:
"""根据当前模型类型选择对应的生成函数"""
#从global_config中获取模型概率值并选择模型
# 从global_config中获取模型概率值并选择模型
if random.random() < global_config.MODEL_R1_PROBABILITY:
self.current_model_type = "深深地"
current_model = self.model_reasoning
@@ -51,7 +51,6 @@ class ResponseGenerator:
logger.info(
f"{self.current_model_type}思考:{message.processed_plain_text[:30] + '...' if len(message.processed_plain_text) > 30 else message.processed_plain_text}"
) # noqa: E501
model_response = await self._generate_response_with_model(message, current_model)
@@ -189,4 +188,4 @@ class ResponseGenerator:
# print(f"得到了处理后的llm返回{processed_response}")
return processed_response
return processed_response

View File

@@ -24,35 +24,32 @@ class PromptBuilder:
async def _build_prompt(
self, chat_stream, message_txt: str, sender_name: str = "某人", stream_id: Optional[int] = None
) -> tuple[str, str]:
# 开始构建prompt
prompt_personality = ""
#person
# person
individuality = Individuality.get_instance()
personality_core = individuality.personality.personality_core
prompt_personality += personality_core
personality_sides = individuality.personality.personality_sides
random.shuffle(personality_sides)
prompt_personality += f",{personality_sides[0]}"
identity_detail = individuality.identity.identity_detail
random.shuffle(identity_detail)
prompt_personality += f",{identity_detail[0]}"
# 关系
who_chat_in_group = [(chat_stream.user_info.platform,
chat_stream.user_info.user_id,
chat_stream.user_info.user_nickname)]
who_chat_in_group = [
(chat_stream.user_info.platform, chat_stream.user_info.user_id, chat_stream.user_info.user_nickname)
]
who_chat_in_group += get_recent_group_speaker(
stream_id,
(chat_stream.user_info.platform, chat_stream.user_info.user_id),
limit=global_config.MAX_CONTEXT_SIZE,
)
relation_prompt = ""
for person in who_chat_in_group:
relation_prompt += await relationship_manager.build_relationship_info(person)
@@ -67,7 +64,7 @@ class PromptBuilder:
mood_prompt = mood_manager.get_prompt()
# logger.info(f"心情prompt: {mood_prompt}")
# 调取记忆
memory_prompt = ""
related_memory = await HippocampusManager.get_instance().get_memory_from_text(
@@ -84,7 +81,7 @@ class PromptBuilder:
# print(f"相关记忆:{related_memory_info}")
# 日程构建
schedule_prompt = f'''你现在正在做的事情是:{bot_schedule.get_current_num_task(num = 1,time_info = False)}'''
schedule_prompt = f"""你现在正在做的事情是:{bot_schedule.get_current_num_task(num=1, time_info=False)}"""
# 获取聊天上下文
chat_in_group = True
@@ -143,7 +140,7 @@ class PromptBuilder:
涉及政治敏感以及违法违规的内容请规避。"""
logger.info("开始构建prompt")
prompt = f"""
{relation_prompt_all}
{memory_prompt}
@@ -165,7 +162,7 @@ class PromptBuilder:
start_time = time.time()
related_info = ""
logger.debug(f"获取知识库内容,元消息:{message[:30]}...,消息长度: {len(message)}")
# 1. 先从LLM获取主题类似于记忆系统的做法
topics = []
# try:
@@ -173,7 +170,7 @@ class PromptBuilder:
# hippocampus = HippocampusManager.get_instance()._hippocampus
# topic_num = min(5, max(1, int(len(message) * 0.1)))
# topics_response = await hippocampus.llm_topic_judge.generate_response(hippocampus.find_topic_llm(message, topic_num))
# # 提取关键词
# topics = re.findall(r"<([^>]+)>", topics_response[0])
# if not topics:
@@ -184,7 +181,7 @@ class PromptBuilder:
# for topic in ",".join(topics).replace("", ",").replace("、", ",").replace(" ", ",").split(",")
# if topic.strip()
# ]
# logger.info(f"从LLM提取的主题: {', '.join(topics)}")
# except Exception as e:
# logger.error(f"从LLM提取主题失败: {str(e)}")
@@ -192,7 +189,7 @@ class PromptBuilder:
# words = jieba.cut(message)
# topics = [word for word in words if len(word) > 1][:5]
# logger.info(f"使用jieba提取的主题: {', '.join(topics)}")
# 如果无法提取到主题,直接使用整个消息
if not topics:
logger.info("未能提取到任何主题,使用整个消息进行查询")
@@ -200,26 +197,26 @@ class PromptBuilder:
if not embedding:
logger.error("获取消息嵌入向量失败")
return ""
related_info = self.get_info_from_db(embedding, limit=3, threshold=threshold)
logger.info(f"知识库检索完成,总耗时: {time.time() - start_time:.3f}")
return related_info
# 2. 对每个主题进行知识库查询
logger.info(f"开始处理{len(topics)}个主题的知识库查询")
# 优化批量获取嵌入向量减少API调用
embeddings = {}
topics_batch = [topic for topic in topics if len(topic) > 0]
if message: # 确保消息非空
topics_batch.append(message)
# 批量获取嵌入向量
embed_start_time = time.time()
for text in topics_batch:
if not text or len(text.strip()) == 0:
continue
try:
embedding = await get_embedding(text, request_type="prompt_build")
if embedding:
@@ -228,17 +225,17 @@ class PromptBuilder:
logger.warning(f"获取'{text}'的嵌入向量失败")
except Exception as e:
logger.error(f"获取'{text}'的嵌入向量时发生错误: {str(e)}")
logger.info(f"批量获取嵌入向量完成,耗时: {time.time() - embed_start_time:.3f}")
if not embeddings:
logger.error("所有嵌入向量获取失败")
return ""
# 3. 对每个主题进行知识库查询
all_results = []
query_start_time = time.time()
# 首先添加原始消息的查询结果
if message in embeddings:
original_results = self.get_info_from_db(embeddings[message], limit=3, threshold=threshold, return_raw=True)
@@ -247,12 +244,12 @@ class PromptBuilder:
result["topic"] = "原始消息"
all_results.extend(original_results)
logger.info(f"原始消息查询到{len(original_results)}条结果")
# 然后添加每个主题的查询结果
for topic in topics:
if not topic or topic not in embeddings:
continue
try:
topic_results = self.get_info_from_db(embeddings[topic], limit=3, threshold=threshold, return_raw=True)
if topic_results:
@@ -263,9 +260,9 @@ class PromptBuilder:
logger.info(f"主题'{topic}'查询到{len(topic_results)}条结果")
except Exception as e:
logger.error(f"查询主题'{topic}'时发生错误: {str(e)}")
logger.info(f"知识库查询完成,耗时: {time.time() - query_start_time:.3f}秒,共获取{len(all_results)}条结果")
# 4. 去重和过滤
process_start_time = time.time()
unique_contents = set()
@@ -275,14 +272,16 @@ class PromptBuilder:
if content not in unique_contents:
unique_contents.add(content)
filtered_results.append(result)
# 5. 按相似度排序
filtered_results.sort(key=lambda x: x["similarity"], reverse=True)
# 6. 限制总数量最多10条
filtered_results = filtered_results[:10]
logger.info(f"结果处理完成,耗时: {time.time() - process_start_time:.3f}秒,过滤后剩余{len(filtered_results)}条结果")
logger.info(
f"结果处理完成,耗时: {time.time() - process_start_time:.3f}秒,过滤后剩余{len(filtered_results)}条结果"
)
# 7. 格式化输出
if filtered_results:
format_start_time = time.time()
@@ -292,7 +291,7 @@ class PromptBuilder:
if topic not in grouped_results:
grouped_results[topic] = []
grouped_results[topic].append(result)
# 按主题组织输出
for topic, results in grouped_results.items():
related_info += f"【主题: {topic}\n"
@@ -303,13 +302,15 @@ class PromptBuilder:
# related_info += f"{i}. [{similarity:.2f}] {content}\n"
related_info += f"{content}\n"
related_info += "\n"
logger.info(f"格式化输出完成,耗时: {time.time() - format_start_time:.3f}")
logger.info(f"知识库检索总耗时: {time.time() - start_time:.3f}")
return related_info
def get_info_from_db(self, query_embedding: list, limit: int = 1, threshold: float = 0.5, return_raw: bool = False) -> Union[str, list]:
def get_info_from_db(
self, query_embedding: list, limit: int = 1, threshold: float = 0.5, return_raw: bool = False
) -> Union[str, list]:
if not query_embedding:
return "" if not return_raw else []
# 使用余弦相似度计算

View File

@@ -28,6 +28,7 @@ chat_config = LogConfig(
logger = get_module_logger("think_flow_chat", config=chat_config)
class ThinkFlowChat:
def __init__(self):
self.storage = MessageStorage()
@@ -96,7 +97,7 @@ class ThinkFlowChat:
)
if not mark_head:
mark_head = True
# print(f"thinking_start_time:{bot_message.thinking_start_time}")
message_set.add_message(bot_message)
message_manager.add_message(message_set)
@@ -110,7 +111,7 @@ class ThinkFlowChat:
if emoji_raw:
emoji_path, description = emoji_raw
emoji_cq = image_path_to_base64(emoji_path)
# logger.info(emoji_cq)
thinking_time_point = round(message.message_info.time, 2)
@@ -130,7 +131,7 @@ class ThinkFlowChat:
is_head=False,
is_emoji=True,
)
# logger.info("22222222222222")
message_manager.add_message(bot_message)
@@ -180,7 +181,7 @@ class ThinkFlowChat:
await message.process()
logger.debug(f"消息处理成功{message.processed_plain_text}")
# 过滤词/正则表达式过滤
if self._check_ban_words(message.processed_plain_text, chat, userinfo) or self._check_ban_regex(
message.raw_message, chat, userinfo
@@ -190,7 +191,7 @@ class ThinkFlowChat:
await self.storage.store_message(message, chat)
logger.debug(f"存储成功{message.processed_plain_text}")
# 记忆激活
timer1 = time.time()
interested_rate = await HippocampusManager.get_instance().get_activate_from_text(
@@ -214,15 +215,13 @@ class ThinkFlowChat:
# 处理提及
is_mentioned, reply_probability = is_mentioned_bot_in_message(message)
# 计算回复意愿
current_willing_old = willing_manager.get_willing(chat_stream=chat)
# current_willing_new = (heartflow.get_subheartflow(chat.stream_id).current_state.willing - 5) / 4
# current_willing = (current_willing_old + current_willing_new) / 2
# current_willing = (current_willing_old + current_willing_new) / 2
# 有点bug
current_willing = current_willing_old
willing_manager.set_willing(chat.stream_id, current_willing)
# 意愿激活
@@ -258,7 +257,7 @@ class ThinkFlowChat:
if random() < reply_probability:
try:
do_reply = True
# 创建思考消息
try:
timer1 = time.time()
@@ -267,9 +266,9 @@ class ThinkFlowChat:
timing_results["创建思考消息"] = timer2 - timer1
except Exception as e:
logger.error(f"心流创建思考消息失败: {e}")
try:
# 观察
# 观察
timer1 = time.time()
await heartflow.get_subheartflow(chat.stream_id).do_observe()
timer2 = time.time()
@@ -280,12 +279,14 @@ class ThinkFlowChat:
# 思考前脑内状态
try:
timer1 = time.time()
await heartflow.get_subheartflow(chat.stream_id).do_thinking_before_reply(message.processed_plain_text)
await heartflow.get_subheartflow(chat.stream_id).do_thinking_before_reply(
message.processed_plain_text
)
timer2 = time.time()
timing_results["思考前脑内状态"] = timer2 - timer1
except Exception as e:
logger.error(f"心流思考前脑内状态失败: {e}")
# 生成回复
timer1 = time.time()
response_set = await self.gpt.generate_response(message)

View File

@@ -35,7 +35,6 @@ class ResponseGenerator:
async def generate_response(self, message: MessageThinking) -> Optional[Union[str, List[str]]]:
"""根据当前模型类型选择对应的生成函数"""
logger.info(
f"思考:{message.processed_plain_text[:30] + '...' if len(message.processed_plain_text) > 30 else message.processed_plain_text}"
)
@@ -178,4 +177,3 @@ class ResponseGenerator:
# print(f"得到了处理后的llm返回{processed_response}")
return processed_response

View File

@@ -21,22 +21,21 @@ class PromptBuilder:
async def _build_prompt(
self, chat_stream, message_txt: str, sender_name: str = "某人", stream_id: Optional[int] = None
) -> tuple[str, str]:
current_mind_info = heartflow.get_subheartflow(stream_id).current_mind
individuality = Individuality.get_instance()
prompt_personality = individuality.get_prompt(type = "personality",x_person = 2,level = 1)
prompt_identity = individuality.get_prompt(type = "identity",x_person = 2,level = 1)
prompt_personality = individuality.get_prompt(type="personality", x_person=2, level=1)
prompt_identity = individuality.get_prompt(type="identity", x_person=2, level=1)
# 关系
who_chat_in_group = [(chat_stream.user_info.platform,
chat_stream.user_info.user_id,
chat_stream.user_info.user_nickname)]
who_chat_in_group = [
(chat_stream.user_info.platform, chat_stream.user_info.user_id, chat_stream.user_info.user_nickname)
]
who_chat_in_group += get_recent_group_speaker(
stream_id,
(chat_stream.user_info.platform, chat_stream.user_info.user_id),
limit=global_config.MAX_CONTEXT_SIZE,
)
relation_prompt = ""
for person in who_chat_in_group:
relation_prompt += await relationship_manager.build_relationship_info(person)
@@ -100,7 +99,7 @@ class PromptBuilder:
涉及政治敏感以及违法违规的内容请规避。"""
logger.info("开始构建prompt")
prompt = f"""
{relation_prompt_all}\n
{chat_target}
@@ -114,7 +113,7 @@ class PromptBuilder:
请回复的平淡一些,简短一些,说中文,不要刻意突出自身学科背景,尽量不要说你说过的话
请注意不要输出多余内容(包括前后缀,冒号和引号,括号,表情等),只输出回复内容。
{moderation_prompt}不要输出多余内容(包括前后缀冒号和引号括号表情包at或 @等 )。"""
return prompt

View File

@@ -3,6 +3,7 @@ import tomlkit
from pathlib import Path
from datetime import datetime
def update_config():
print("开始更新配置文件...")
# 获取根目录路径
@@ -25,11 +26,11 @@ def update_config():
print(f"发现旧配置文件: {old_config_path}")
with open(old_config_path, "r", encoding="utf-8") as f:
old_config = tomlkit.load(f)
# 生成带时间戳的新文件名
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
old_backup_path = old_config_dir / f"bot_config_{timestamp}.toml"
# 移动旧配置文件到old目录
shutil.move(old_config_path, old_backup_path)
print(f"已备份旧配置文件到: {old_backup_path}")

View File

@@ -24,7 +24,7 @@ config_config = LogConfig(
# 配置主程序日志格式
logger = get_module_logger("config", config=config_config)
#考虑到实际上配置文件中的mai_version是不会自动更新的,所以采用硬编码
# 考虑到实际上配置文件中的mai_version是不会自动更新的,所以采用硬编码
is_test = False
mai_version_main = "0.6.1"
mai_version_fix = ""
@@ -39,6 +39,7 @@ else:
else:
mai_version = mai_version_main
def update_config():
# 获取根目录路径
root_dir = Path(__file__).parent.parent.parent.parent
@@ -54,7 +55,7 @@ def update_config():
# 检查配置文件是否存在
if not old_config_path.exists():
logger.info("配置文件不存在,从模板创建新配置")
#创建文件夹
# 创建文件夹
old_config_dir.mkdir(parents=True, exist_ok=True)
shutil.copy2(template_path, old_config_path)
logger.info(f"已创建新配置文件,请填写后重新运行: {old_config_path}")
@@ -84,7 +85,7 @@ def update_config():
# 生成带时间戳的新文件名
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
old_backup_path = old_config_dir / f"bot_config_{timestamp}.toml"
# 移动旧配置文件到old目录
shutil.move(old_config_path, old_backup_path)
logger.info(f"已备份旧配置文件到: {old_backup_path}")
@@ -127,6 +128,7 @@ def update_config():
f.write(tomlkit.dumps(new_config))
logger.info("配置文件更新完成")
logger = get_module_logger("config")
@@ -148,17 +150,21 @@ class BotConfig:
ban_user_id = set()
# personality
personality_core = "用一句话或几句话描述人格的核心特点" # 建议20字以内谁再写3000字小作文敲谁脑袋
personality_sides: List[str] = field(default_factory=lambda: [
"用一句话或几句话描述人格的一些侧面",
"用一句话或几句话描述人格的一些侧面",
"用一句话或几句话描述人格的一些侧面"
])
personality_core = "用一句话或几句话描述人格的核心特点" # 建议20字以内谁再写3000字小作文敲谁脑袋
personality_sides: List[str] = field(
default_factory=lambda: [
"用一句话或几句话描述人格的一些侧面",
"用一句话或几句话描述人格的一些侧面",
"用一句话或几句话描述人格的一些侧面",
]
)
# identity
identity_detail: List[str] = field(default_factory=lambda: [
"身份特点",
"身份特点",
])
identity_detail: List[str] = field(
default_factory=lambda: [
"身份特点",
"身份特点",
]
)
height: int = 170 # 身高 单位厘米
weight: int = 50 # 体重 单位千克
age: int = 20 # 年龄 单位岁
@@ -181,22 +187,22 @@ class BotConfig:
ban_words = set()
ban_msgs_regex = set()
#heartflow
# heartflow
# enable_heartflow: bool = False # 是否启用心流
sub_heart_flow_update_interval: int = 60 # 子心流更新频率,间隔 单位秒
sub_heart_flow_freeze_time: int = 120 # 子心流冻结时间,超过这个时间没有回复,子心流会冻结,间隔 单位秒
sub_heart_flow_stop_time: int = 600 # 子心流停止时间,超过这个时间没有回复,子心流会停止,间隔 单位秒
heart_flow_update_interval: int = 300 # 心流更新频率,间隔 单位秒
# willing
willing_mode: str = "classical" # 意愿模式
response_willing_amplifier: float = 1.0 # 回复意愿放大系数
response_interested_rate_amplifier: float = 1.0 # 回复兴趣度放大系数
down_frequency_rate: float = 3 # 降低回复频率的群组回复意愿降低系数
emoji_response_penalty: float = 0.0 # 表情包回复惩罚
mentioned_bot_inevitable_reply: bool = False # 提及 bot 必然回复
at_bot_inevitable_reply: bool = False # @bot 必然回复
mentioned_bot_inevitable_reply: bool = False # 提及 bot 必然回复
at_bot_inevitable_reply: bool = False # @bot 必然回复
# response
response_mode: str = "heart_flow" # 回复策略
@@ -354,7 +360,6 @@ class BotConfig:
"""从TOML配置文件加载配置"""
config = cls()
def personality(parent: dict):
personality_config = parent["personality"]
if config.INNER_VERSION in SpecifierSet(">=1.2.4"):
@@ -418,13 +423,21 @@ class BotConfig:
config.max_response_length = response_config.get("max_response_length", config.max_response_length)
if config.INNER_VERSION in SpecifierSet(">=1.0.4"):
config.response_mode = response_config.get("response_mode", config.response_mode)
def heartflow(parent: dict):
heartflow_config = parent["heartflow"]
config.sub_heart_flow_update_interval = heartflow_config.get("sub_heart_flow_update_interval", config.sub_heart_flow_update_interval)
config.sub_heart_flow_freeze_time = heartflow_config.get("sub_heart_flow_freeze_time", config.sub_heart_flow_freeze_time)
config.sub_heart_flow_stop_time = heartflow_config.get("sub_heart_flow_stop_time", config.sub_heart_flow_stop_time)
config.heart_flow_update_interval = heartflow_config.get("heart_flow_update_interval", config.heart_flow_update_interval)
config.sub_heart_flow_update_interval = heartflow_config.get(
"sub_heart_flow_update_interval", config.sub_heart_flow_update_interval
)
config.sub_heart_flow_freeze_time = heartflow_config.get(
"sub_heart_flow_freeze_time", config.sub_heart_flow_freeze_time
)
config.sub_heart_flow_stop_time = heartflow_config.get(
"sub_heart_flow_stop_time", config.sub_heart_flow_stop_time
)
config.heart_flow_update_interval = heartflow_config.get(
"heart_flow_update_interval", config.heart_flow_update_interval
)
def willing(parent: dict):
willing_config = parent["willing"]

View File

@@ -14,6 +14,7 @@ from src.common.logger import get_module_logger, LogConfig, MEMORY_STYLE_CONFIG
from src.plugins.memory_system.sample_distribution import MemoryBuildScheduler # 分布生成器
from .memory_config import MemoryConfig
def get_closest_chat_from_db(length: int, timestamp: str):
# print(f"获取最接近指定时间戳的聊天记录,长度: {length}, 时间戳: {timestamp}")
# print(f"当前时间: {timestamp},转换后时间: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(timestamp))}")

View File

@@ -179,7 +179,6 @@ class LLM_request:
# logger.debug(f"{logger_msg}发送请求到URL: {api_url}")
# logger.info(f"使用模型: {self.model_name}")
# 构建请求体
if image_base64:
payload = await self._build_payload(prompt, image_base64, image_format)
@@ -205,13 +204,17 @@ class LLM_request:
# 处理需要重试的状态码
if response.status in policy["retry_codes"]:
wait_time = policy["base_wait"] * (2**retry)
logger.warning(f"模型 {self.model_name} 错误码: {response.status}, 等待 {wait_time}秒后重试")
logger.warning(
f"模型 {self.model_name} 错误码: {response.status}, 等待 {wait_time}秒后重试"
)
if response.status == 413:
logger.warning("请求体过大,尝试压缩...")
image_base64 = compress_base64_image_by_scale(image_base64)
payload = await self._build_payload(prompt, image_base64, image_format)
elif response.status in [500, 503]:
logger.error(f"模型 {self.model_name} 错误码: {response.status} - {error_code_mapping.get(response.status)}")
logger.error(
f"模型 {self.model_name} 错误码: {response.status} - {error_code_mapping.get(response.status)}"
)
raise RuntimeError("服务器负载过高模型恢复失败QAQ")
else:
logger.warning(f"模型 {self.model_name} 请求限制(429),等待{wait_time}秒后重试...")
@@ -219,7 +222,9 @@ class LLM_request:
await asyncio.sleep(wait_time)
continue
elif response.status in policy["abort_codes"]:
logger.error(f"模型 {self.model_name} 错误码: {response.status} - {error_code_mapping.get(response.status)}")
logger.error(
f"模型 {self.model_name} 错误码: {response.status} - {error_code_mapping.get(response.status)}"
)
# 尝试获取并记录服务器返回的详细错误信息
try:
error_json = await response.json()
@@ -257,7 +262,9 @@ class LLM_request:
):
old_model_name = self.model_name
self.model_name = self.model_name[4:] # 移除"Pro/"前缀
logger.warning(f"检测到403错误模型从 {old_model_name} 降级为 {self.model_name}")
logger.warning(
f"检测到403错误模型从 {old_model_name} 降级为 {self.model_name}"
)
# 对全局配置进行更新
if global_config.llm_normal.get("name") == old_model_name:
@@ -266,7 +273,9 @@ class LLM_request:
if global_config.llm_reasoning.get("name") == old_model_name:
global_config.llm_reasoning["name"] = self.model_name
logger.warning(f"将全局配置中的 llm_reasoning 模型临时降级至{self.model_name}")
logger.warning(
f"将全局配置中的 llm_reasoning 模型临时降级至{self.model_name}"
)
# 更新payload中的模型名
if payload and "model" in payload:
@@ -328,7 +337,14 @@ class LLM_request:
await response.release()
# 返回已经累积的内容
result = {
"choices": [{"message": {"content": accumulated_content, "reasoning_content": reasoning_content}}],
"choices": [
{
"message": {
"content": accumulated_content,
"reasoning_content": reasoning_content,
}
}
],
"usage": usage,
}
return (
@@ -345,7 +361,14 @@ class LLM_request:
logger.error(f"清理资源时发生错误: {cleanup_error}")
# 返回已经累积的内容
result = {
"choices": [{"message": {"content": accumulated_content, "reasoning_content": reasoning_content}}],
"choices": [
{
"message": {
"content": accumulated_content,
"reasoning_content": reasoning_content,
}
}
],
"usage": usage,
}
return (
@@ -360,7 +383,9 @@ class LLM_request:
content = re.sub(r"<think>.*?</think>", "", content, flags=re.DOTALL).strip()
# 构造一个伪result以便调用自定义响应处理器或默认处理器
result = {
"choices": [{"message": {"content": content, "reasoning_content": reasoning_content}}],
"choices": [
{"message": {"content": content, "reasoning_content": reasoning_content}}
],
"usage": usage,
}
return (
@@ -394,7 +419,9 @@ class LLM_request:
# 处理aiohttp抛出的响应错误
if retry < policy["max_retries"] - 1:
wait_time = policy["base_wait"] * (2**retry)
logger.error(f"模型 {self.model_name} HTTP响应错误等待{wait_time}秒后重试... 状态码: {e.status}, 错误: {e.message}")
logger.error(
f"模型 {self.model_name} HTTP响应错误等待{wait_time}秒后重试... 状态码: {e.status}, 错误: {e.message}"
)
try:
if hasattr(e, "response") and e.response and hasattr(e.response, "text"):
error_text = await e.response.text()
@@ -419,13 +446,17 @@ class LLM_request:
else:
logger.error(f"模型 {self.model_name} 服务器错误响应: {error_json}")
except (json.JSONDecodeError, TypeError) as json_err:
logger.warning(f"模型 {self.model_name} 响应不是有效的JSON: {str(json_err)}, 原始内容: {error_text[:200]}")
logger.warning(
f"模型 {self.model_name} 响应不是有效的JSON: {str(json_err)}, 原始内容: {error_text[:200]}"
)
except (AttributeError, TypeError, ValueError) as parse_err:
logger.warning(f"模型 {self.model_name} 无法解析响应错误内容: {str(parse_err)}")
await asyncio.sleep(wait_time)
else:
logger.critical(f"模型 {self.model_name} HTTP响应错误达到最大重试次数: 状态码: {e.status}, 错误: {e.message}")
logger.critical(
f"模型 {self.model_name} HTTP响应错误达到最大重试次数: 状态码: {e.status}, 错误: {e.message}"
)
# 安全地检查和记录请求详情
if (
image_base64

View File

@@ -139,7 +139,7 @@ class MoodManager:
# 神经质:影响情绪变化速度
neuroticism_factor = 1 + (personality.neuroticism - 0.5) * 0.5
agreeableness_factor = 1 + (personality.agreeableness - 0.5) * 0.5
# 宜人性:影响情绪基准线
if personality.agreeableness < 0.2:
agreeableness_bias = (personality.agreeableness - 0.2) * 2
@@ -151,7 +151,7 @@ class MoodManager:
# 分别计算正向和负向的衰减率
if self.current_mood.valence >= 0:
# 正向情绪衰减
decay_rate_positive = self.decay_rate_valence * (1/agreeableness_factor)
decay_rate_positive = self.decay_rate_valence * (1 / agreeableness_factor)
valence_target = 0 + agreeableness_bias
self.current_mood.valence = valence_target + (self.current_mood.valence - valence_target) * math.exp(
-decay_rate_positive * time_diff * neuroticism_factor
@@ -279,8 +279,9 @@ class MoodManager:
# 限制范围
self.current_mood.valence = max(-1.0, min(1.0, self.current_mood.valence))
self.current_mood.arousal = max(0.0, min(1.0, self.current_mood.arousal))
self._update_mood_text()
logger.info(f"[情绪变化] {emotion}(强度:{intensity:.2f}) | 愉悦度:{old_valence:.2f}->{self.current_mood.valence:.2f}, 唤醒度:{old_arousal:.2f}->{self.current_mood.arousal:.2f} | 心情:{old_mood}->{self.current_mood.text}")
logger.info(
f"[情绪变化] {emotion}(强度:{intensity:.2f}) | 愉悦度:{old_valence:.2f}->{self.current_mood.valence:.2f}, 唤醒度:{old_arousal:.2f}->{self.current_mood.arousal:.2f} | 心情:{old_mood}->{self.current_mood.text}"
)

View File

@@ -8,7 +8,8 @@ import asyncio
import numpy as np
import matplotlib
matplotlib.use('Agg')
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from pathlib import Path
import pandas as pd
@@ -30,38 +31,39 @@ PersonInfoManager 类方法功能摘要:
logger = get_module_logger("person_info")
person_info_default = {
"person_id" : None,
"platform" : None,
"user_id" : None,
"nickname" : None,
"person_id": None,
"platform": None,
"user_id": None,
"nickname": None,
# "age" : 0,
"relationship_value" : 0,
"relationship_value": 0,
# "saved" : True,
# "impression" : None,
# "gender" : Unkown,
"konw_time" : 0,
"konw_time": 0,
"msg_interval": 3000,
"msg_interval_list": []
"msg_interval_list": [],
} # 个人信息的各项与默认值在此定义,以下处理会自动创建/补全每一项
class PersonInfoManager:
def __init__(self):
if "person_info" not in db.list_collection_names():
db.create_collection("person_info")
db.person_info.create_index("person_id", unique=True)
def get_person_id(self, platform:str, user_id:int):
def get_person_id(self, platform: str, user_id: int):
"""获取唯一id"""
components = [platform, str(user_id)]
key = "_".join(components)
return hashlib.md5(key.encode()).hexdigest()
async def create_person_info(self, person_id:str, data:dict = None):
async def create_person_info(self, person_id: str, data: dict = None):
"""创建一个项"""
if not person_id:
logger.debug("创建失败personid不存在")
return
_person_info_default = copy.deepcopy(person_info_default)
_person_info_default["person_id"] = person_id
@@ -72,19 +74,16 @@ class PersonInfoManager:
db.person_info.insert_one(_person_info_default)
async def update_one_field(self, person_id:str, field_name:str, value, Data:dict = None):
async def update_one_field(self, person_id: str, field_name: str, value, Data: dict = None):
"""更新某一个字段,会补全"""
if field_name not in person_info_default.keys():
logger.debug(f"更新'{field_name}'失败,未定义的字段")
return
document = db.person_info.find_one({"person_id": person_id})
if document:
db.person_info.update_one(
{"person_id": person_id},
{"$set": {field_name: value}}
)
db.person_info.update_one({"person_id": person_id}, {"$set": {field_name: value}})
else:
Data[field_name] = value
logger.debug(f"更新时{person_id}不存在,已新建")
@@ -107,23 +106,20 @@ class PersonInfoManager:
if not person_id:
logger.debug("get_value获取失败person_id不能为空")
return None
if field_name not in person_info_default:
logger.debug(f"get_value获取失败字段'{field_name}'未定义")
return None
document = db.person_info.find_one(
{"person_id": person_id},
{field_name: 1}
)
document = db.person_info.find_one({"person_id": person_id}, {field_name: 1})
if document and field_name in document:
return document[field_name]
else:
default_value = copy.deepcopy(person_info_default[field_name])
logger.debug(f"获取{person_id}{field_name}失败,已返回默认值{default_value}")
return default_value
async def get_values(self, person_id: str, field_names: list) -> dict:
"""获取指定person_id文档的多个字段值若不存在该字段则返回该字段的全局默认值"""
if not person_id:
@@ -139,62 +135,57 @@ class PersonInfoManager:
# 构建查询投影(所有字段都有效才会执行到这里)
projection = {field: 1 for field in field_names}
document = db.person_info.find_one(
{"person_id": person_id},
projection
)
document = db.person_info.find_one({"person_id": person_id}, projection)
result = {}
for field in field_names:
result[field] = copy.deepcopy(
document.get(field, person_info_default[field])
if document else person_info_default[field]
document.get(field, person_info_default[field]) if document else person_info_default[field]
)
return result
async def del_all_undefined_field(self):
"""删除所有项里的未定义字段"""
# 获取所有已定义的字段名
defined_fields = set(person_info_default.keys())
try:
# 遍历集合中的所有文档
for document in db.person_info.find({}):
# 找出文档中未定义的字段
undefined_fields = set(document.keys()) - defined_fields - {'_id'}
undefined_fields = set(document.keys()) - defined_fields - {"_id"}
if undefined_fields:
# 构建更新操作,使用$unset删除未定义字段
update_result = db.person_info.update_one(
{'_id': document['_id']},
{'$unset': {field: 1 for field in undefined_fields}}
{"_id": document["_id"]}, {"$unset": {field: 1 for field in undefined_fields}}
)
if update_result.modified_count > 0:
logger.debug(f"已清理文档 {document['_id']} 的未定义字段: {undefined_fields}")
return
except Exception as e:
logger.error(f"清理未定义字段时出错: {e}")
return
async def get_specific_value_list(
self,
field_name: str,
way: Callable[[Any], bool], # 接受任意类型值
) ->Dict[str, Any]:
self,
field_name: str,
way: Callable[[Any], bool], # 接受任意类型值
) -> Dict[str, Any]:
"""
获取满足条件的字段值字典
Args:
field_name: 目标字段名
way: 判断函数 (value: Any) -> bool
Returns:
{person_id: value} | {}
Example:
# 查找所有nickname包含"admin"的用户
result = manager.specific_value_list(
@@ -208,10 +199,7 @@ class PersonInfoManager:
try:
result = {}
for doc in db.person_info.find(
{field_name: {"$exists": True}},
{"person_id": 1, field_name: 1, "_id": 0}
):
for doc in db.person_info.find({field_name: {"$exists": True}}, {"person_id": 1, field_name: 1, "_id": 0}):
try:
value = doc[field_name]
if way(value):
@@ -225,11 +213,11 @@ class PersonInfoManager:
except Exception as e:
logger.error(f"数据库查询失败: {str(e)}", exc_info=True)
return {}
async def personal_habit_deduction(self):
"""启动个人信息推断,每天根据一定条件推断一次"""
try:
while(1):
while 1:
await asyncio.sleep(60)
current_time = datetime.datetime.now()
logger.info(f"个人信息推断启动: {current_time.strftime('%Y-%m-%d %H:%M:%S')}")
@@ -237,8 +225,7 @@ class PersonInfoManager:
# "msg_interval"推断
msg_interval_map = False
msg_interval_lists = await self.get_specific_value_list(
"msg_interval_list",
lambda x: isinstance(x, list) and len(x) >= 100
"msg_interval_list", lambda x: isinstance(x, list) and len(x) >= 100
)
for person_id, msg_interval_list_ in msg_interval_lists.items():
try:
@@ -258,23 +245,23 @@ class PersonInfoManager:
log_dir.mkdir(parents=True, exist_ok=True)
plt.figure(figsize=(10, 6))
time_series = pd.Series(time_interval)
plt.hist(time_series, bins=50, density=True, alpha=0.4, color='pink', label='Histogram')
time_series.plot(kind='kde', color='mediumpurple', linewidth=1, label='Density')
plt.hist(time_series, bins=50, density=True, alpha=0.4, color="pink", label="Histogram")
time_series.plot(kind="kde", color="mediumpurple", linewidth=1, label="Density")
plt.grid(True, alpha=0.2)
plt.xlim(0, 8000)
plt.title(f"Message Interval Distribution (User: {person_id[:8]}...)")
plt.xlabel("Interval (ms)")
plt.ylabel("Density")
plt.legend(framealpha=0.9, facecolor='white')
plt.legend(framealpha=0.9, facecolor="white")
img_path = log_dir / f"interval_distribution_{person_id[:8]}.png"
plt.savefig(img_path)
plt.close()
# 画图
q25, q75 = np.percentile(time_interval, [25, 75])
iqr = q75 - q25
filtered = [x for x in time_interval if (q25 - 1.5*iqr) <= x <= (q75 + 1.5*iqr)]
filtered = [x for x in time_interval if (q25 - 1.5 * iqr) <= x <= (q75 + 1.5 * iqr)]
msg_interval = int(round(np.percentile(filtered, 80)))
await self.update_one_field(person_id, "msg_interval", msg_interval)
logger.debug(f"用户{person_id}的msg_interval已经被更新为{msg_interval}")

View File

@@ -12,6 +12,7 @@ relationship_config = LogConfig(
)
logger = get_module_logger("rel_manager", config=relationship_config)
class RelationshipManager:
def __init__(self):
self.positive_feedback_value = 0 # 正反馈系统
@@ -22,6 +23,7 @@ class RelationshipManager:
def mood_manager(self):
if self._mood_manager is None:
from ..moods.moods import MoodManager # 延迟导入
self._mood_manager = MoodManager.get_instance()
return self._mood_manager
@@ -51,27 +53,27 @@ class RelationshipManager:
self.positive_feedback_value -= 1
elif self.positive_feedback_value > 0:
self.positive_feedback_value = 0
if abs(self.positive_feedback_value) > 1:
logger.info(f"触发mood变更增益当前增益系数{self.gain_coefficient[abs(self.positive_feedback_value)]}")
def mood_feedback(self, value):
"""情绪反馈"""
mood_manager = self.mood_manager
mood_gain = (mood_manager.get_current_mood().valence) ** 2 \
* math.copysign(1, value * mood_manager.get_current_mood().valence)
mood_gain = (mood_manager.get_current_mood().valence) ** 2 * math.copysign(
1, value * mood_manager.get_current_mood().valence
)
value += value * mood_gain
logger.info(f"当前relationship增益系数{mood_gain:.3f}")
return value
def feedback_to_mood(self, mood_value):
"""对情绪的反馈"""
coefficient = self.gain_coefficient[abs(self.positive_feedback_value)]
if (mood_value > 0 and self.positive_feedback_value > 0
or mood_value < 0 and self.positive_feedback_value < 0):
return mood_value*coefficient
if mood_value > 0 and self.positive_feedback_value > 0 or mood_value < 0 and self.positive_feedback_value < 0:
return mood_value * coefficient
else:
return mood_value/coefficient
return mood_value / coefficient
async def calculate_update_relationship_value(self, chat_stream: ChatStream, label: str, stance: str) -> None:
"""计算并变更关系值
@@ -88,7 +90,7 @@ class RelationshipManager:
"中立": 1,
"反对": 2,
}
valuedict = {
"开心": 1.5,
"愤怒": -2.0,
@@ -103,10 +105,10 @@ class RelationshipManager:
person_id = person_info_manager.get_person_id(chat_stream.user_info.platform, chat_stream.user_info.user_id)
data = {
"platform" : chat_stream.user_info.platform,
"user_id" : chat_stream.user_info.user_id,
"nickname" : chat_stream.user_info.user_nickname,
"konw_time" : int(time.time())
"platform": chat_stream.user_info.platform,
"user_id": chat_stream.user_info.user_id,
"nickname": chat_stream.user_info.user_nickname,
"konw_time": int(time.time()),
}
old_value = await person_info_manager.get_value(person_id, "relationship_value")
old_value = self.ensure_float(old_value, person_id)
@@ -200,4 +202,5 @@ class RelationshipManager:
logger.warning(f"[关系管理] {person_id}值转换失败(原始值:{value}已重置为0")
return 0.0
relationship_manager = RelationshipManager()

View File

@@ -14,7 +14,7 @@ from src.common.logger import get_module_logger, SCHEDULE_STYLE_CONFIG, LogConfi
from src.plugins.models.utils_model import LLM_request # noqa: E402
from src.plugins.config.config import global_config # noqa: E402
TIME_ZONE = tz.gettz(global_config.TIME_ZONE) # 设置时区
TIME_ZONE = tz.gettz(global_config.TIME_ZONE) # 设置时区
schedule_config = LogConfig(
@@ -31,10 +31,16 @@ class ScheduleGenerator:
def __init__(self):
# 使用离线LLM模型
self.llm_scheduler_all = LLM_request(
model=global_config.llm_reasoning, temperature=global_config.SCHEDULE_TEMPERATURE, max_tokens=7000, request_type="schedule"
model=global_config.llm_reasoning,
temperature=global_config.SCHEDULE_TEMPERATURE,
max_tokens=7000,
request_type="schedule",
)
self.llm_scheduler_doing = LLM_request(
model=global_config.llm_normal, temperature=global_config.SCHEDULE_TEMPERATURE, max_tokens=2048, request_type="schedule"
model=global_config.llm_normal,
temperature=global_config.SCHEDULE_TEMPERATURE,
max_tokens=2048,
request_type="schedule",
)
self.today_schedule_text = ""

View File

@@ -53,18 +53,18 @@ class KnowledgeLibrary:
# 按空行分割内容
paragraphs = [p.strip() for p in content.split("\n\n") if p.strip()]
chunks = []
for para in paragraphs:
para_length = len(para)
# 如果段落长度小于等于最大长度,直接添加
if para_length <= max_length:
chunks.append(para)
else:
# 如果段落超过最大长度,则按最大长度切分
for i in range(0, para_length, max_length):
chunks.append(para[i:i + max_length])
chunks.append(para[i : i + max_length])
return chunks
def get_embedding(self, text: str) -> list:

File diff suppressed because it is too large Load Diff