refactor: format all code

Rikki
2025-03-30 04:56:46 +08:00
parent 7adaa2f5a8
commit b2fc824afd
21 changed files with 491 additions and 514 deletions

View File

@@ -81,7 +81,9 @@ MEMORY_STYLE_CONFIG = {
        "file_format": ("{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | 海马体 | {message}"),
    },
    "simple": {
-        "console_format": ("<green>{time:MM-DD HH:mm}</green> | <light-yellow>海马体</light-yellow> | <light-yellow>{message}</light-yellow>"),
+        "console_format": (
+            "<green>{time:MM-DD HH:mm}</green> | <light-yellow>海马体</light-yellow> | <light-yellow>{message}</light-yellow>"
+        ),
        "file_format": ("{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | 海马体 | {message}"),
    },
}
@@ -152,7 +154,9 @@ HEARTFLOW_STYLE_CONFIG = {
        "file_format": ("{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | 麦麦大脑袋 | {message}"),
    },
    "simple": {
-        "console_format": ("<green>{time:MM-DD HH:mm}</green> | <light-green>麦麦大脑袋</light-green> | <light-green>{message}</light-green>"),  # noqa: E501
+        "console_format": (
+            "<green>{time:MM-DD HH:mm}</green> | <light-green>麦麦大脑袋</light-green> | <light-green>{message}</light-green>"
+        ),  # noqa: E501
        "file_format": ("{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | 麦麦大脑袋 | {message}"),
    },
}
@@ -223,7 +227,9 @@ CHAT_STYLE_CONFIG = {
        "file_format": ("{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | 见闻 | {message}"),
    },
    "simple": {
-        "console_format": ("<green>{time:MM-DD HH:mm}</green> | <light-blue>见闻</light-blue> | <green>{message}</green>"),  # noqa: E501
+        "console_format": (
+            "<green>{time:MM-DD HH:mm}</green> | <light-blue>见闻</light-blue> | <green>{message}</green>"
+        ),  # noqa: E501
        "file_format": ("{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | 见闻 | {message}"),
    },
}
@@ -240,7 +246,9 @@ SUB_HEARTFLOW_STYLE_CONFIG = {
        "file_format": ("{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | 麦麦小脑袋 | {message}"),
    },
    "simple": {
-        "console_format": ("<green>{time:MM-DD HH:mm}</green> | <light-blue>麦麦小脑袋</light-blue> | <light-blue>{message}</light-blue>"),  # noqa: E501
+        "console_format": (
+            "<green>{time:MM-DD HH:mm}</green> | <light-blue>麦麦小脑袋</light-blue> | <light-blue>{message}</light-blue>"
+        ),  # noqa: E501
        "file_format": ("{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | 麦麦小脑袋 | {message}"),
    },
}
@@ -257,14 +265,14 @@ WILLING_STYLE_CONFIG = {
        "file_format": ("{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | 意愿 | {message}"),
    },
    "simple": {
-        "console_format": ("<green>{time:MM-DD HH:mm}</green> | <light-blue>意愿</light-blue> | <light-blue>{message}</light-blue>"),  # noqa: E501
+        "console_format": (
+            "<green>{time:MM-DD HH:mm}</green> | <light-blue>意愿</light-blue> | <light-blue>{message}</light-blue>"
+        ),  # noqa: E501
        "file_format": ("{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | 意愿 | {message}"),
    },
}

# 根据SIMPLE_OUTPUT选择配置
MEMORY_STYLE_CONFIG = MEMORY_STYLE_CONFIG["simple"] if SIMPLE_OUTPUT else MEMORY_STYLE_CONFIG["advanced"]
TOPIC_STYLE_CONFIG = TOPIC_STYLE_CONFIG["simple"] if SIMPLE_OUTPUT else TOPIC_STYLE_CONFIG["advanced"]
@@ -275,7 +283,9 @@ MOOD_STYLE_CONFIG = MOOD_STYLE_CONFIG["simple"] if SIMPLE_OUTPUT else MOOD_STYLE
RELATION_STYLE_CONFIG = RELATION_STYLE_CONFIG["simple"] if SIMPLE_OUTPUT else RELATION_STYLE_CONFIG["advanced"]
SCHEDULE_STYLE_CONFIG = SCHEDULE_STYLE_CONFIG["simple"] if SIMPLE_OUTPUT else SCHEDULE_STYLE_CONFIG["advanced"]
HEARTFLOW_STYLE_CONFIG = HEARTFLOW_STYLE_CONFIG["simple"] if SIMPLE_OUTPUT else HEARTFLOW_STYLE_CONFIG["advanced"]
-SUB_HEARTFLOW_STYLE_CONFIG = SUB_HEARTFLOW_STYLE_CONFIG["simple"] if SIMPLE_OUTPUT else SUB_HEARTFLOW_STYLE_CONFIG["advanced"]  # noqa: E501
+SUB_HEARTFLOW_STYLE_CONFIG = (
+    SUB_HEARTFLOW_STYLE_CONFIG["simple"] if SIMPLE_OUTPUT else SUB_HEARTFLOW_STYLE_CONFIG["advanced"]
+)  # noqa: E501
WILLING_STYLE_CONFIG = WILLING_STYLE_CONFIG["simple"] if SIMPLE_OUTPUT else WILLING_STYLE_CONFIG["advanced"]
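For orientation, a minimal sketch (not part of the diff) of how a style dict like the ones above can be fed to loguru; the markup tags such as <green> are loguru color tags. The handler wiring and the SIMPLE_OUTPUT value are assumptions, since this commit only reformats the dicts.

import sys

from loguru import logger

SIMPLE_OUTPUT = True  # assumed: in the project this flag comes from configuration

MEMORY_STYLE_CONFIG = {
    "advanced": {
        "console_format": "<green>{time:YYYY-MM-DD HH:mm:ss}</green> | {level: <8} | 海马体 | {message}",
    },
    "simple": {
        "console_format": "<green>{time:MM-DD HH:mm}</green> | <light-yellow>海马体</light-yellow> | <light-yellow>{message}</light-yellow>",
    },
}

style = MEMORY_STYLE_CONFIG["simple"] if SIMPLE_OUTPUT else MEMORY_STYLE_CONFIG["advanced"]
logger.remove()  # drop loguru's default handler so only this format is active
logger.add(sys.stderr, format=style["console_format"], colorize=True)
logger.info("memory subsystem online")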

View File

@@ -6,6 +6,7 @@ import time
from datetime import datetime
from typing import Dict, List
from typing import Optional

sys.path.insert(0, sys.path[0] + "/../")
sys.path.insert(0, sys.path[0] + "/../")
from src.common.logger import get_module_logger

View File

@@ -81,16 +81,13 @@ class ChatBot:
        logger.debug(f"2消息处理时间: {timer2 - timer1}")

        # 过滤词/正则表达式过滤
-        if (
-            self._check_ban_words(message.processed_plain_text, chat, userinfo)
-            or self._check_ban_regex(message.raw_message, chat, userinfo)
+        if self._check_ban_words(message.processed_plain_text, chat, userinfo) or self._check_ban_regex(
+            message.raw_message, chat, userinfo
        ):
            return

        await self.storage.store_message(message, chat)

        timer1 = time.time()
        interested_rate = 0
        interested_rate = await HippocampusManager.get_instance().get_activate_from_text(
@@ -99,7 +96,6 @@ class ChatBot:
        timer2 = time.time()
        logger.debug(f"3记忆激活时间: {timer2 - timer1}")

        is_mentioned = is_mentioned_bot_in_message(message)
        if global_config.enable_think_flow:
@@ -128,11 +124,11 @@ class ChatBot:
        if chat.group_info:
            if chat.group_info.group_name:
                mes_name_dict = chat.group_info.group_name
-                mes_name = mes_name_dict.get('group_name', '无名群聊')
+                mes_name = mes_name_dict.get("group_name", "无名群聊")
            else:
-                mes_name = '群聊'
+                mes_name = "群聊"
        else:
-            mes_name = '私聊'
+            mes_name = "私聊"

        # 打印收到的信息的信息
        current_time = time.strftime("%H:%M:%S", time.localtime(messageinfo.time))
@@ -146,7 +142,6 @@ class ChatBot:
        if "maimcore_reply_probability_gain" in message.message_info.additional_config.keys():
            reply_probability += message.message_info.additional_config["maimcore_reply_probability_gain"]

        # 开始组织语言
        if random() < reply_probability:
            timer1 = time.time()
@@ -224,7 +219,6 @@ class ChatBot:
            heartflow.get_subheartflow(stream_id).do_after_reply(response_set, chat_talking_prompt)

    async def _send_response_messages(self, message, chat, response_set, thinking_id):
        container = message_manager.get_container(chat.stream_id)
        thinking_message = None
@@ -313,9 +307,7 @@ class ChatBot:
        """
        stance, emotion = await self.gpt._get_emotion_tags(raw_content, message.processed_plain_text)
        logger.debug(f"'{response}' 立场为:{stance} 获取到的情感标签为:{emotion}")
-        await relationship_manager.calculate_update_relationship_value(
-            chat_stream=chat, label=emotion, stance=stance
-        )
+        await relationship_manager.calculate_update_relationship_value(chat_stream=chat, label=emotion, stance=stance)
        self.mood_manager.update_mood_from_emotion(emotion, global_config.mood_intensity_factor)

    def _check_ban_words(self, text: str, chat, userinfo) -> bool:
@@ -332,8 +324,7 @@ class ChatBot:
        for word in global_config.ban_words:
            if word in text:
                logger.info(
-                    f"[{chat.group_info.group_name if chat.group_info else '私聊'}]"
-                    f"{userinfo.user_nickname}:{text}"
+                    f"[{chat.group_info.group_name if chat.group_info else '私聊'}]{userinfo.user_nickname}:{text}"
                )
                logger.info(f"[过滤词识别]消息中含有{word}filtered")
                return True
@@ -353,12 +344,12 @@ class ChatBot:
        for pattern in global_config.ban_msgs_regex:
            if re.search(pattern, text):
                logger.info(
-                    f"[{chat.group_info.group_name if chat.group_info else '私聊'}]"
-                    f"{userinfo.user_nickname}:{text}"
+                    f"[{chat.group_info.group_name if chat.group_info else '私聊'}]{userinfo.user_nickname}:{text}"
                )
                logger.info(f"[正则表达式过滤]消息匹配到{pattern}filtered")
                return True
        return False


# 创建全局ChatBot实例
chat_bot = ChatBot()
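The two filter hunks above only reflow f-strings; for reference, a standalone sketch (not part of the diff) of the filtering pattern itself: a plain substring check against a ban-word list plus a regex check against the raw message. Function names and sample inputs are made up.

import re
from typing import Iterable


def check_ban_words(text: str, ban_words: Iterable[str]) -> bool:
    # True if any banned word appears verbatim in the processed text
    return any(word in text for word in ban_words)


def check_ban_regex(raw_text: str, ban_patterns: Iterable[str]) -> bool:
    # True if any banned pattern matches anywhere in the raw message
    return any(re.search(pattern, raw_text) for pattern in ban_patterns)


if check_ban_words("hello spam world", ["spam"]) or check_ban_regex("buy now!!!", [r"buy\s+now"]):
    print("message dropped by filter")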

View File

@@ -31,10 +31,7 @@ class ResponseGenerator:
            request_type="response",
        )
        self.model_normal = LLM_request(
-            model=global_config.llm_normal,
-            temperature=0.7,
-            max_tokens=3000,
-            request_type="response"
+            model=global_config.llm_normal, temperature=0.7, max_tokens=3000, request_type="response"
        )
        self.model_sum = LLM_request(
@@ -53,8 +50,9 @@ class ResponseGenerator:
        self.current_model_type = "浅浅的"
        current_model = self.model_normal

-        logger.info(f"{self.current_model_type}思考:{message.processed_plain_text[:30] + '...' if len(message.processed_plain_text) > 30 else message.processed_plain_text}")  # noqa: E501
+        logger.info(
+            f"{self.current_model_type}思考:{message.processed_plain_text[:30] + '...' if len(message.processed_plain_text) > 30 else message.processed_plain_text}"
+        )  # noqa: E501

        model_response = await self._generate_response_with_model(message, current_model)
@@ -64,7 +62,6 @@ class ResponseGenerator:
            logger.info(f"{global_config.BOT_NICKNAME}的回复是:{model_response}")
            model_response = await self._process_response(model_response)
            return model_response
        else:
            logger.info(f"{self.current_model_type}思考,失败")

View File

@@ -37,7 +37,6 @@ class PromptBuilder:
        current_mind_info = heartflow.get_subheartflow(stream_id).current_mind

        # relation_prompt = ""
        # for person in who_chat_in_group:
        #     relation_prompt += relationship_manager.build_relationship_info(person)
@@ -73,7 +72,6 @@ class PromptBuilder:
        chat_talking_prompt = chat_talking_prompt
        # print(f"\033[1;34m[调试]\033[0m 已从数据库获取群 {group_id} 的消息记录:{chat_talking_prompt}")

        # 使用新的记忆获取方法
        memory_prompt = ""
        start_time = time.time()
@@ -165,11 +163,8 @@ class PromptBuilder:
请注意不要输出多余内容(包括前后缀,冒号和引号,括号,表情等),只输出回复内容。
{moderation_prompt}不要输出多余内容(包括前后缀冒号和引号括号表情包at或 @等 )。"""
        return prompt

    def _build_initiative_prompt_select(self, group_id, probability_1=0.8, probability_2=0.1):
        current_date = time.strftime("%Y-%m-%d", time.localtime())
        current_time = time.strftime("%H:%M:%S", time.localtime())

View File

@@ -9,9 +9,7 @@ logger = get_module_logger("message_storage")
class MessageStorage:
-    async def store_message(
-        self, message: Union[MessageSending, MessageRecv], chat_stream: ChatStream
-    ) -> None:
+    async def store_message(self, message: Union[MessageSending, MessageRecv], chat_stream: ChatStream) -> None:
        """存储消息到数据库"""
        try:
            message_data = {

View File

@@ -56,6 +56,7 @@ def get_closest_chat_from_db(length: int, timestamp: str):
    return []

+
def calculate_information_content(text):
    """计算文本的信息量(熵)"""
    char_count = Counter(text)
@@ -68,6 +69,7 @@ def calculate_information_content(text):
    return entropy

+
def cosine_similarity(v1, v2):
    """计算余弦相似度"""
    dot_product = np.dot(v1, v2)
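The two hunks above only add blank lines around these helpers; as a quick reference (not part of the diff), a self-contained version of both calculations, character entropy in bits and cosine similarity, with a zero-vector guard added as an assumption.

import math
from collections import Counter

import numpy as np


def calculate_information_content(text: str) -> float:
    """Shannon entropy of the character distribution, in bits."""
    char_count = Counter(text)
    total = len(text)
    return -sum((n / total) * math.log2(n / total) for n in char_count.values())


def cosine_similarity(v1, v2) -> float:
    """Cosine of the angle between two vectors; 0.0 if either is a zero vector."""
    v1, v2 = np.asarray(v1, dtype=float), np.asarray(v2, dtype=float)
    denom = np.linalg.norm(v1) * np.linalg.norm(v2)
    return float(np.dot(v1, v2) / denom) if denom else 0.0


print(calculate_information_content("aabbc"))   # ~1.52 bits
print(cosine_similarity([1, 0, 1], [1, 1, 0]))  # 0.5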
@@ -223,6 +225,7 @@ class Memory_graph:
        return None

+
# 负责海马体与其他部分的交互
class EntorhinalCortex:
    def __init__(self, hippocampus):
@@ -243,7 +246,7 @@ class EntorhinalCortex:
            n_hours2=self.config.memory_build_distribution[3],
            std_hours2=self.config.memory_build_distribution[4],
            weight2=self.config.memory_build_distribution[5],
-            total_samples=self.config.build_memory_sample_num
+            total_samples=self.config.build_memory_sample_num,
        )
        timestamps = sample_scheduler.get_timestamp_array()
@@ -251,9 +254,7 @@ class EntorhinalCortex:
        chat_samples = []
        for timestamp in timestamps:
            messages = self.random_get_msg_snippet(
-                timestamp,
-                self.config.build_memory_sample_length,
-                max_memorized_time_per_msg
+                timestamp, self.config.build_memory_sample_length, max_memorized_time_per_msg
            )
            if messages:
                time_diff = (datetime.datetime.now().timestamp() - timestamp) / 3600
@@ -504,6 +505,7 @@ class EntorhinalCortex:
        logger.success(f"[数据库] 重新同步完成,总耗时: {end_time - start_time:.2f}")
        logger.success(f"[数据库] 同步了 {len(memory_nodes)} 个节点和 {len(memory_edges)} 条边")

+
# 负责整合,遗忘,合并记忆
class ParahippocampalGyrus:
    def __init__(self, hippocampus):
@@ -567,26 +569,26 @@ class ParahippocampalGyrus:
        topic_num = self.hippocampus.calculate_topic_num(input_text, compress_rate)
        topics_response = await self.hippocampus.llm_topic_judge.generate_response(
-            self.hippocampus.find_topic_llm(input_text, topic_num))
+            self.hippocampus.find_topic_llm(input_text, topic_num)
+        )

        # 使用正则表达式提取<>中的内容
-        topics = re.findall(r'<([^>]+)>', topics_response[0])
+        topics = re.findall(r"<([^>]+)>", topics_response[0])

        # 如果没有找到<>包裹的内容,返回['none']
        if not topics:
-            topics = ['none']
+            topics = ["none"]
        else:
            # 处理提取出的话题
            topics = [
                topic.strip()
-                for topic in ','.join(topics).replace("，", ",").replace("、", ",").replace(" ", ",").split(",")
+                for topic in ",".join(topics).replace("，", ",").replace("、", ",").replace(" ", ",").split(",")
                if topic.strip()
            ]

        # 过滤掉包含禁用关键词的topic
        filtered_topics = [
-            topic for topic in topics
-            if not any(keyword in topic for keyword in self.config.memory_ban_words)
+            topic for topic in topics if not any(keyword in topic for keyword in self.config.memory_ban_words)
        ]

        logger.debug(f"过滤后话题: {filtered_topics}")
@@ -689,10 +691,7 @@ class ParahippocampalGyrus:
        await self.hippocampus.entorhinal_cortex.sync_memory_to_db()

        end_time = time.time()
-        logger.success(
-            f"---------------------记忆构建耗时: {end_time - start_time:.2f} "
-            "秒---------------------"
-        )
+        logger.success(f"---------------------记忆构建耗时: {end_time - start_time:.2f} 秒---------------------")

    async def operation_forget_topic(self, percentage=0.005):
        start_time = time.time()
@@ -714,11 +713,11 @@ class ParahippocampalGyrus:
        # 使用列表存储变化信息
        edge_changes = {
            "weakened": [],  # 存储减弱的边
-            "removed": []  # 存储移除的边
+            "removed": [],  # 存储移除的边
        }
        node_changes = {
            "reduced": [],  # 存储减少记忆的节点
-            "removed": []  # 存储移除的节点
+            "removed": [],  # 存储移除的节点
        }

        current_time = datetime.datetime.now().timestamp()
@@ -781,25 +780,30 @@ class ParahippocampalGyrus:
            logger.info("[遗忘] 遗忘操作统计:")
            if edge_changes["weakened"]:
                logger.info(
-                    f"[遗忘] 减弱的连接 ({len(edge_changes['weakened'])}个): {', '.join(edge_changes['weakened'])}")
+                    f"[遗忘] 减弱的连接 ({len(edge_changes['weakened'])}个): {', '.join(edge_changes['weakened'])}"
+                )
            if edge_changes["removed"]:
                logger.info(
-                    f"[遗忘] 移除的连接 ({len(edge_changes['removed'])}个): {', '.join(edge_changes['removed'])}")
+                    f"[遗忘] 移除的连接 ({len(edge_changes['removed'])}个): {', '.join(edge_changes['removed'])}"
+                )
            if node_changes["reduced"]:
                logger.info(
-                    f"[遗忘] 减少记忆的节点 ({len(node_changes['reduced'])}个): {', '.join(node_changes['reduced'])}")
+                    f"[遗忘] 减少记忆的节点 ({len(node_changes['reduced'])}个): {', '.join(node_changes['reduced'])}"
+                )
            if node_changes["removed"]:
                logger.info(
-                    f"[遗忘] 移除的节点 ({len(node_changes['removed'])}个): {', '.join(node_changes['removed'])}")
+                    f"[遗忘] 移除的节点 ({len(node_changes['removed'])}个): {', '.join(node_changes['removed'])}"
+                )
        else:
            logger.info("[遗忘] 本次检查没有节点或连接满足遗忘条件")

        end_time = time.time()
        logger.info(f"[遗忘] 总耗时: {end_time - start_time:.2f}")

+
# 海马体
class Hippocampus:
    def __init__(self):
@@ -908,9 +912,14 @@ class Hippocampus:
        memories.sort(key=lambda x: x[2], reverse=True)
        return memories

-    async def get_memory_from_text(self, text: str, max_memory_num: int = 3, max_memory_length: int = 2,
+    async def get_memory_from_text(
+        self,
+        text: str,
+        max_memory_num: int = 3,
+        max_memory_length: int = 2,
        max_depth: int = 3,
-        fast_retrieval: bool = False) -> list:
+        fast_retrieval: bool = False,
+    ) -> list:
        """从文本中提取关键词并获取相关记忆。

        Args:
@@ -943,18 +952,16 @@ class Hippocampus:
        # 使用LLM提取关键词
        topic_num = min(5, max(1, int(len(text) * 0.1)))  # 根据文本长度动态调整关键词数量
        # logger.info(f"提取关键词数量: {topic_num}")
-        topics_response = await self.llm_topic_judge.generate_response(
-            self.find_topic_llm(text, topic_num)
-        )
+        topics_response = await self.llm_topic_judge.generate_response(self.find_topic_llm(text, topic_num))

        # 提取关键词
-        keywords = re.findall(r'<([^>]+)>', topics_response[0])
+        keywords = re.findall(r"<([^>]+)>", topics_response[0])
        if not keywords:
            keywords = []
        else:
            keywords = [
                keyword.strip()
-                for keyword in ','.join(keywords).replace("，", ",").replace("、", ",").replace(" ", ",").split(",")
+                for keyword in ",".join(keywords).replace("，", ",").replace("、", ",").replace(" ", ",").split(",")
                if keyword.strip()
            ]
@@ -1008,7 +1015,8 @@ class Hippocampus:
                        visited_nodes.add(neighbor)
                        nodes_to_process.append((neighbor, new_activation, current_depth + 1))
                        logger.debug(
-                            f"节点 '{neighbor}' 被激活,激活值: {new_activation:.2f} (通过 '{current_node}' 连接,强度: {strength}, 深度: {current_depth + 1})")  # noqa: E501
+                            f"节点 '{neighbor}' 被激活,激活值: {new_activation:.2f} (通过 '{current_node}' 连接,强度: {strength}, 深度: {current_depth + 1})"
+                        )  # noqa: E501

        # 更新激活映射
        for node, activation_value in activation_values.items():
@@ -1032,22 +1040,18 @@ class Hippocampus:
        if total_squared_activation > 0:
            # 计算归一化的激活值
            normalized_activations = {
-                node: (activation ** 2) / total_squared_activation
-                for node, activation in activate_map.items()
+                node: (activation**2) / total_squared_activation for node, activation in activate_map.items()
            }
            # 按归一化激活值排序并选择前max_memory_num个
-            sorted_nodes = sorted(
-                normalized_activations.items(),
-                key=lambda x: x[1],
-                reverse=True
-            )[:max_memory_num]
+            sorted_nodes = sorted(normalized_activations.items(), key=lambda x: x[1], reverse=True)[:max_memory_num]

            # 将选中的节点添加到remember_map
            for node, normalized_activation in sorted_nodes:
                remember_map[node] = activate_map[node]  # 使用原始激活值
                logger.debug(
-                    f"节点 '{node}' (归一化激活值: {normalized_activation:.2f}, 激活值: {activate_map[node]:.2f})")
+                    f"节点 '{node}' (归一化激活值: {normalized_activation:.2f}, 激活值: {activate_map[node]:.2f})"
+                )
        else:
            logger.info("没有有效的激活值")
@@ -1109,8 +1113,7 @@ class Hippocampus:
        return result

-    async def get_activate_from_text(self, text: str, max_depth: int = 3,
-        fast_retrieval: bool = False) -> float:
+    async def get_activate_from_text(self, text: str, max_depth: int = 3, fast_retrieval: bool = False) -> float:
        """从文本中提取关键词并获取相关记忆。

        Args:
@@ -1140,18 +1143,16 @@ class Hippocampus:
        # 使用LLM提取关键词
        topic_num = min(5, max(1, int(len(text) * 0.1)))  # 根据文本长度动态调整关键词数量
        # logger.info(f"提取关键词数量: {topic_num}")
-        topics_response = await self.llm_topic_judge.generate_response(
-            self.find_topic_llm(text, topic_num)
-        )
+        topics_response = await self.llm_topic_judge.generate_response(self.find_topic_llm(text, topic_num))

        # 提取关键词
-        keywords = re.findall(r'<([^>]+)>', topics_response[0])
+        keywords = re.findall(r"<([^>]+)>", topics_response[0])
        if not keywords:
            keywords = []
        else:
            keywords = [
                keyword.strip()
-                for keyword in ','.join(keywords).replace("，", ",").replace("、", ",").replace(" ", ",").split(",")
+                for keyword in ",".join(keywords).replace("，", ",").replace("、", ",").replace(" ", ",").split(",")
                if keyword.strip()
            ]
@@ -1230,6 +1231,7 @@ class Hippocampus:
        return activation_ratio

+
class HippocampusManager:
    _instance = None
    _hippocampus = None
@@ -1266,14 +1268,13 @@ class HippocampusManager:
        node_count = len(memory_graph.nodes())
        edge_count = len(memory_graph.edges())

-        logger.success(f'''--------------------------------
+        logger.success(f"""--------------------------------
记忆系统参数配置:
构建间隔: {global_config.build_memory_interval}秒|样本数: {config.build_memory_sample_num},长度: {config.build_memory_sample_length}|压缩率: {config.memory_compress_rate}
记忆构建分布: {config.memory_build_distribution}
遗忘间隔: {global_config.forget_memory_interval}秒|遗忘比例: {global_config.memory_forget_percentage}|遗忘: {config.memory_forget_time}小时之后
记忆图统计信息: 节点数量: {node_count}, 连接数量: {edge_count}
---------------------------------''')  #noqa: E501
+--------------------------------""")  # noqa: E501

        return self._hippocampus
@@ -1289,17 +1290,22 @@ class HippocampusManager:
            raise RuntimeError("HippocampusManager 尚未初始化,请先调用 initialize 方法")
        return await self._hippocampus.parahippocampal_gyrus.operation_forget_topic(percentage)

-    async def get_memory_from_text(self, text: str, max_memory_num: int = 3,
-        max_memory_length: int = 2, max_depth: int = 3,
-        fast_retrieval: bool = False) -> list:
+    async def get_memory_from_text(
+        self,
+        text: str,
+        max_memory_num: int = 3,
+        max_memory_length: int = 2,
+        max_depth: int = 3,
+        fast_retrieval: bool = False,
+    ) -> list:
        """从文本中获取相关记忆的公共接口"""
        if not self._initialized:
            raise RuntimeError("HippocampusManager 尚未初始化,请先调用 initialize 方法")
        return await self._hippocampus.get_memory_from_text(
-            text, max_memory_num, max_memory_length, max_depth, fast_retrieval)
+            text, max_memory_num, max_memory_length, max_depth, fast_retrieval
+        )

-    async def get_activate_from_text(self, text: str, max_depth: int = 3,
-        fast_retrieval: bool = False) -> float:
+    async def get_activate_from_text(self, text: str, max_depth: int = 3, fast_retrieval: bool = False) -> float:
        """从文本中获取激活值的公共接口"""
        if not self._initialized:
            raise RuntimeError("HippocampusManager 尚未初始化,请先调用 initialize 方法")
@@ -1316,5 +1322,3 @@ class HippocampusManager:
        if not self._initialized:
            raise RuntimeError("HippocampusManager 尚未初始化,请先调用 initialize 方法")
        return self._hippocampus.get_all_node_names()

View File

@@ -3,11 +3,13 @@ import asyncio
import time
import sys
import os

# 添加项目根目录到系统路径
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(__file__)))))

from src.plugins.memory_system.Hippocampus import HippocampusManager
from src.plugins.config.config import global_config


async def test_memory_system():
    """测试记忆系统的主要功能"""
    try:
@@ -24,7 +26,7 @@ async def test_memory_system():
        # 测试记忆检索
        test_text = "千石可乐在群里聊天"

-        test_text = '''[03-24 10:39:37] 麦麦(ta的id:2814567326): 早说散步结果下雨改成室内运动啊
+        test_text = """[03-24 10:39:37] 麦麦(ta的id:2814567326): 早说散步结果下雨改成室内运动啊
[03-24 10:39:37] 麦麦(ta的id:2814567326): [回复:变量] 变量就像今天计划总变
[03-24 10:39:44] 状态异常(ta的id:535554838): 要把本地文件改成弹出来的路径吗
[03-24 10:40:35] 状态异常(ta的id:535554838): [图片这张图片显示的是Windows系统的环境变量设置界面。界面左侧列出了多个环境变量的值包括Intel Dev Redist、Windows、Windows PowerShell、OpenSSH、NVIDIA Corporation的目录等。右侧有新建、编辑、浏览、删除、上移、下移和编辑文本等操作按钮。图片下方有一个错误提示框显示"Windows找不到文件'mongodb\\bin\\mongod.exe'。请确定文件名是否正确后,再试一次。"这意味着用户试图运行MongoDB的mongod.exe程序时系统找不到该文件。这可能是因为MongoDB的安装路径未正确添加到系统环境变量中或者文件路径有误。
@@ -39,17 +41,12 @@ async def test_memory_system():
[03-24 10:46:12] (ta的id:3229291803): [表情包:这张表情包显示了一只手正在做"点赞"的动作,通常表示赞同、喜欢或支持。这个表情包所表达的情感是积极的、赞同的或支持的。]
[03-24 10:46:37] 星野風禾(ta的id:2890165435): 还能思考高达
[03-24 10:46:39] 星野風禾(ta的id:2890165435): 什么知识库
-[03-24 10:46:49] ❦幻凌慌てない(ta的id:2459587037): 为什么改了回复系数麦麦还是不怎么回复?大佬们'''  # noqa: E501
+[03-24 10:46:49] ❦幻凌慌てない(ta的id:2459587037): 为什么改了回复系数麦麦还是不怎么回复?大佬们"""  # noqa: E501
        # test_text = '''千石可乐分不清AI的陪伴和人类的陪伴,是这样吗?'''

        print(f"开始测试记忆检索,测试文本: {test_text}\n")

        memories = await hippocampus_manager.get_memory_from_text(
-            text=test_text,
-            max_memory_num=3,
-            max_memory_length=2,
-            max_depth=3,
-            fast_retrieval=False
+            text=test_text, max_memory_num=3, max_memory_length=2, max_depth=3, fast_retrieval=False
        )

        await asyncio.sleep(1)
@@ -59,8 +56,6 @@ async def test_memory_system():
            print(f"主题: {topic}")
            print(f"- {memory_items}")

        # 测试记忆遗忘
        # forget_start_time = time.time()
        # # print("开始测试记忆遗忘...")
@@ -80,6 +75,7 @@ async def test_memory_system():
        print(f"测试过程中出现错误: {e}")
        raise


async def main():
    """主函数"""
    try:
@@ -91,5 +87,6 @@ async def main():
        print(f"程序执行出错: {e}")
        raise


if __name__ == "__main__":
    asyncio.run(main())

View File

@@ -1,9 +1,11 @@
from dataclasses import dataclass
from typing import List


@dataclass
class MemoryConfig:
    """记忆系统配置类"""

    # 记忆构建相关配置
    memory_build_distribution: List[float]  # 记忆构建的时间分布参数
    build_memory_sample_num: int  # 每次构建记忆的样本数量
@@ -30,5 +32,5 @@ class MemoryConfig:
            memory_forget_time=global_config.memory_forget_time,
            memory_ban_words=global_config.memory_ban_words,
            llm_topic_judge=global_config.llm_topic_judge,
-            llm_summary_by_topic=global_config.llm_summary_by_topic
+            llm_summary_by_topic=global_config.llm_summary_by_topic,
        )

View File

@@ -2,6 +2,7 @@ import numpy as np
from scipy import stats
from datetime import datetime, timedelta

+
class DistributionVisualizer:
    def __init__(self, mean=0, std=1, skewness=0, sample_size=10):
        """
@@ -26,10 +27,7 @@ class DistributionVisualizer:
            self.samples = np.random.normal(loc=self.mean, scale=self.std, size=self.sample_size)
        else:
            # 使用 scipy.stats 生成具有偏度的分布
-            self.samples = stats.skewnorm.rvs(a=self.skewness,
-                loc=self.mean,
-                scale=self.std,
-                size=self.sample_size)
+            self.samples = stats.skewnorm.rvs(a=self.skewness, loc=self.mean, scale=self.std, size=self.sample_size)

    def get_weighted_samples(self):
        """获取加权后的样本数列"""
@@ -43,17 +41,11 @@ class DistributionVisualizer:
        if self.samples is None:
            self.generate_samples()

-        return {
-            "均值": np.mean(self.samples),
-            "标准差": np.std(self.samples),
-            "实际偏度": stats.skew(self.samples)
-        }
+        return {"均值": np.mean(self.samples), "标准差": np.std(self.samples), "实际偏度": stats.skew(self.samples)}


class MemoryBuildScheduler:
-    def __init__(self,
-        n_hours1, std_hours1, weight1,
-        n_hours2, std_hours2, weight2,
-        total_samples=50):
+    def __init__(self, n_hours1, std_hours1, weight1, n_hours2, std_hours2, weight2, total_samples=50):
        """
        初始化记忆构建调度器
@@ -85,17 +77,9 @@ class MemoryBuildScheduler:
        samples2 = self.total_samples - samples1

        # 生成两个正态分布的小时偏移
-        hours_offset1 = np.random.normal(
-            loc=self.n_hours1,
-            scale=self.std_hours1,
-            size=samples1
-        )
+        hours_offset1 = np.random.normal(loc=self.n_hours1, scale=self.std_hours1, size=samples1)

-        hours_offset2 = np.random.normal(
-            loc=self.n_hours2,
-            scale=self.std_hours2,
-            size=samples2
-        )
+        hours_offset2 = np.random.normal(loc=self.n_hours2, scale=self.std_hours2, size=samples2)

        # 合并两个分布的偏移
        hours_offset = np.concatenate([hours_offset1, hours_offset2])
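A self-contained sketch (not part of the diff) of the bimodal sampling this method performs: two normal distributions of "hours before now", mixed by weight and converted to Unix timestamps. The parameters mirror the example at the bottom of this file; the rounding of the sample split is an assumption.

from datetime import datetime, timedelta

import numpy as np

total_samples = 50
weight1, weight2 = 0.7, 0.3
samples1 = int(round(total_samples * weight1 / (weight1 + weight2)))
samples2 = total_samples - samples1

# two clusters of "hours ago": around 12h and around 36h before now
hours_offset1 = np.random.normal(loc=12, scale=8, size=samples1)
hours_offset2 = np.random.normal(loc=36, scale=24, size=samples2)
hours_offset = np.concatenate([hours_offset1, hours_offset2])

now = datetime.now()
timestamps = sorted(int((now - timedelta(hours=float(h))).timestamp()) for h in hours_offset)
print(len(timestamps), timestamps[:3])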
@@ -111,6 +95,7 @@ class MemoryBuildScheduler:
        timestamps = self.generate_time_samples()
        return [int(t.timestamp()) for t in timestamps]

+
def print_time_samples(timestamps, show_distribution=True):
    """打印时间样本和分布信息"""
    print(f"\n生成的{len(timestamps)}个时间点分布:")
@@ -140,6 +125,7 @@ def print_time_samples(timestamps, show_distribution=True):
        if hist[i] > 0:
            print(f"{bins[i]:6.1f}-{bins[i + 1]:6.1f}小时: {'*' * int(hist[i])}")

+
# 使用示例
if __name__ == "__main__":
    # 创建一个双峰分布的记忆调度器
@@ -150,7 +136,7 @@ if __name__ == "__main__":
        n_hours2=36,  # 第二个分布均值36小时前
        std_hours2=24,  # 第二个分布标准差
        weight2=0.3,  # 第二个分布权重 30%
-        total_samples=50  # 总共生成50个时间点
+        total_samples=50,  # 总共生成50个时间点
    )

    # 生成时间分布

View File

@@ -54,9 +54,7 @@ class TestLiveAPI(unittest.IsolatedAsyncioTestCase):
        # 准备测试消息
        user_info = UserInfo(user_id=12345678, user_nickname="测试用户", platform="qq")
        group_info = GroupInfo(group_id=12345678, group_name="测试群", platform="qq")
-        format_info = FormatInfo(
-            content_format=["text"], accept_format=["text", "emoji", "reply"]
-        )
+        format_info = FormatInfo(content_format=["text"], accept_format=["text", "emoji", "reply"])
        template_info = None
        message_info = BaseMessageInfo(
            platform="qq",

View File

@@ -35,6 +35,7 @@ else:
    print(f"未找到环境变量文件: {env_path}")
    print("将使用默认配置")

+
class ChatBasedPersonalityEvaluator:
    def __init__(self):
        self.personality_traits = {"开放性": 0, "严谨性": 0, "外向性": 0, "宜人性": 0, "神经质": 0}
@@ -55,11 +56,9 @@ class ChatBasedPersonalityEvaluator:
            scene = scenes[scene_key]
            other_traits = [t for t in PERSONALITY_SCENES if t != trait]
            secondary_trait = random.choice(other_traits)
-            self.scenarios.append({
-                "场景": scene["scenario"],
-                "评估维度": [trait, secondary_trait],
-                "场景编号": scene_key
-            })
+            self.scenarios.append(
+                {"场景": scene["scenario"], "评估维度": [trait, secondary_trait], "场景编号": scene_key}
+            )

    def analyze_chat_context(self, messages: List[Dict]) -> str:
        """
@@ -67,14 +66,15 @@ class ChatBasedPersonalityEvaluator:
        """
        context = ""
        for msg in messages:
-            nickname = msg.get('user_info', {}).get('user_nickname', '未知用户')
-            content = msg.get('processed_plain_text', msg.get('detailed_plain_text', ''))
+            nickname = msg.get("user_info", {}).get("user_nickname", "未知用户")
+            content = msg.get("processed_plain_text", msg.get("detailed_plain_text", ""))
            if content:
                context += f"{nickname}: {content}\n"
        return context

    def evaluate_chat_response(
-        self, user_nickname: str, chat_context: str, dimensions: List[str] = None) -> Dict[str, float]:
+        self, user_nickname: str, chat_context: str, dimensions: List[str] = None
+    ) -> Dict[str, float]:
        """
        评估聊天内容在各个人格维度上的得分
        """
@@ -147,7 +147,8 @@ class ChatBasedPersonalityEvaluator:
        """
        # 获取用户的随机消息及其上下文
        chat_contexts, user_nickname = self.message_analyzer.get_user_random_contexts(
-            qq_id, num_messages=num_samples, context_length=context_length)
+            qq_id, num_messages=num_samples, context_length=context_length
+        )

        if not chat_contexts:
            return {"error": f"没有找到QQ号 {qq_id} 的消息记录"}
@@ -165,11 +166,9 @@ class ChatBasedPersonalityEvaluator:
            scores = self.evaluate_chat_response(user_nickname, chat_context)

            # 记录样本
-            chat_samples.append({
-                "聊天内容": chat_context,
-                "评估维度": list(self.personality_traits.keys()),
-                "评分": scores
-            })
+            chat_samples.append(
+                {"聊天内容": chat_context, "评估维度": list(self.personality_traits.keys()), "评分": scores}
+            )

            # 更新总分和历史记录
            for dimension, score in scores.items():
@@ -196,7 +195,7 @@ class ChatBasedPersonalityEvaluator:
            "人格特征评分": average_scores,
            "维度评估次数": dict(dimension_counts),
            "详细样本": chat_samples,
-            "特质得分历史": {k: v for k, v in self.trait_scores_history.items()}
+            "特质得分历史": {k: v for k, v in self.trait_scores_history.items()},
        }

        # 保存结果
@@ -215,32 +214,33 @@ class ChatBasedPersonalityEvaluator:
        chinese_fonts = []
        for f in fm.fontManager.ttflist:
            try:
-                if '' in f.name or 'SC' in f.name or '' in f.name or '' in f.name or '微软' in f.name:
+                if "" in f.name or "SC" in f.name or "" in f.name or "" in f.name or "微软" in f.name:
                    chinese_fonts.append(f.name)
            except Exception:
                continue

        if chinese_fonts:
-            plt.rcParams['font.sans-serif'] = chinese_fonts + ['SimHei', 'Microsoft YaHei', 'Arial Unicode MS']
+            plt.rcParams["font.sans-serif"] = chinese_fonts + ["SimHei", "Microsoft YaHei", "Arial Unicode MS"]
        else:
            # 如果没有找到中文字体,使用默认字体,并将中文昵称转换为拼音或英文
            try:
                from pypinyin import lazy_pinyin
-                user_nickname = ''.join(lazy_pinyin(user_nickname))
+
+                user_nickname = "".join(lazy_pinyin(user_nickname))
            except ImportError:
                user_nickname = "User"  # 如果无法转换为拼音,使用默认英文

-        plt.rcParams['axes.unicode_minus'] = False  # 解决负号显示问题
+        plt.rcParams["axes.unicode_minus"] = False  # 解决负号显示问题

        plt.figure(figsize=(12, 6))
-        plt.style.use('bmh')  # 使用内置的bmh样式它有类似seaborn的美观效果
+        plt.style.use("bmh")  # 使用内置的bmh样式它有类似seaborn的美观效果

        colors = {
            "开放性": "#FF9999",
            "严谨性": "#66B2FF",
            "外向性": "#99FF99",
            "宜人性": "#FFCC99",
-            "神经质": "#FF99CC"
+            "神经质": "#FF99CC",
        }

        # 计算每个维度在每个时间点的累计平均分
@@ -271,18 +271,18 @@ class ChatBasedPersonalityEvaluator:
        # 绘制每个维度的累计平均分变化趋势
        for trait, averages in cumulative_averages.items():
            x = range(1, len(averages) + 1)
-            plt.plot(x, averages, 'o-', label=trait, color=colors.get(trait), linewidth=2, markersize=8)
+            plt.plot(x, averages, "o-", label=trait, color=colors.get(trait), linewidth=2, markersize=8)

            # 添加趋势线
            z = np.polyfit(x, averages, 1)
            p = np.poly1d(z)
-            plt.plot(x, p(x), '--', color=colors.get(trait), alpha=0.5)
+            plt.plot(x, p(x), "--", color=colors.get(trait), alpha=0.5)

        plt.title(f"{user_nickname} 的人格特质累计平均分变化趋势", fontsize=14, pad=20)
        plt.xlabel("评估次数", fontsize=12)
        plt.ylabel("累计平均分", fontsize=12)
-        plt.grid(True, linestyle='--', alpha=0.7)
-        plt.legend(loc='center left', bbox_to_anchor=(1, 0.5))
+        plt.grid(True, linestyle="--", alpha=0.7)
+        plt.legend(loc="center left", bbox_to_anchor=(1, 0.5))
        plt.ylim(0, 7)

        plt.tight_layout()
@@ -290,9 +290,10 @@ class ChatBasedPersonalityEvaluator:
        os.makedirs("results/plots", exist_ok=True)
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        plot_file = f"results/plots/personality_trend_{qq_id}_{timestamp}.png"
-        plt.savefig(plot_file, dpi=300, bbox_inches='tight')
+        plt.savefig(plot_file, dpi=300, bbox_inches="tight")
        plt.close()


def analyze_user_personality(qq_id: str, num_samples: int = 10, context_length: int = 5) -> str:
    """
    分析用户人格特征的便捷函数
@@ -341,6 +342,7 @@ def analyze_user_personality(qq_id: str, num_samples: int = 10, context_length:
    return output

+
if __name__ == "__main__":
    # 测试代码
    # test_qq = ""  # 替换为要测试的QQ号
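Most hunks in this file only swap quote styles; for context, a minimal sketch (not part of the diff) of the underlying plotting technique, cumulative averages per trait plus a linear trend line from np.polyfit, using invented scores and a throwaway output path.

import matplotlib.pyplot as plt
import numpy as np

scores = {"外向性": [4, 5, 3, 4, 5], "神经质": [2, 3, 2, 2, 1]}

plt.figure(figsize=(8, 4))
for trait, values in scores.items():
    # cumulative mean after each evaluation round
    averages = np.cumsum(values) / np.arange(1, len(values) + 1)
    x = range(1, len(averages) + 1)
    plt.plot(x, averages, "o-", label=trait, linewidth=2, markersize=8)
    z = np.polyfit(x, averages, 1)          # degree-1 fit = linear trend
    plt.plot(x, np.poly1d(z)(x), "--", alpha=0.5)

plt.xlabel("评估次数")
plt.ylabel("累计平均分")
plt.ylim(0, 7)
plt.legend(loc="center left", bbox_to_anchor=(1, 0.5))
plt.tight_layout()
plt.savefig("personality_trend_demo.png", dpi=150, bbox_inches="tight")
plt.close()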

View File

@@ -82,7 +82,6 @@ class PersonalityEvaluator_direct:
        dimensions_text = "\n".join(dimension_descriptions)

        prompt = f"""请根据以下场景和用户描述评估用户在大五人格模型中的相关维度得分1-6分

场景描述:

View File

@@ -14,6 +14,7 @@ sys.path.append(root_path)
from src.common.database import db  # noqa: E402

+
class MessageAnalyzer:
    def __init__(self):
        self.messages_collection = db["messages"]
@@ -35,19 +36,17 @@ class MessageAnalyzer:
                return None

            # 获取该消息的stream_id
-            stream_id = target_message.get('chat_info', {}).get('stream_id')
+            stream_id = target_message.get("chat_info", {}).get("stream_id")
            if not stream_id:
                return None

            # 获取同一stream_id的所有消息
-            stream_messages = list(self.messages_collection.find({
-                "chat_info.stream_id": stream_id
-            }).sort("time", 1))
+            stream_messages = list(self.messages_collection.find({"chat_info.stream_id": stream_id}).sort("time", 1))

            # 找到目标消息在列表中的位置
            target_index = None
            for i, msg in enumerate(stream_messages):
-                if msg['message_id'] == message_id:
+                if msg["message_id"] == message_id:
                    target_index = i
                    break
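A small sketch (not part of the diff) of the lookup this hunk compresses onto one line: fetch a stream's messages from MongoDB sorted by time, locate the target message, and slice a context window around it. Field names follow the diff; the connection string and database name are assumptions.

from pymongo import MongoClient

db = MongoClient("mongodb://localhost:27017")["maimbot"]  # assumed connection and database name


def get_message_context(message_id: str, context_length: int = 5) -> list:
    target = db["messages"].find_one({"message_id": message_id})
    if not target:
        return []
    stream_id = target.get("chat_info", {}).get("stream_id")
    if not stream_id:
        return []
    stream_messages = list(db["messages"].find({"chat_info.stream_id": stream_id}).sort("time", 1))
    idx = next((i for i, m in enumerate(stream_messages) if m["message_id"] == message_id), None)
    if idx is None:
        return []
    # context_length messages on each side of the target message
    return stream_messages[max(0, idx - context_length): idx + context_length + 1]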
@@ -77,24 +76,25 @@ class MessageAnalyzer:
        reply = ""
        for msg in messages:
            # 消息时间
-            msg_time = datetime.datetime.fromtimestamp(int(msg['time'])).strftime("%Y-%m-%d %H:%M:%S")
+            msg_time = datetime.datetime.fromtimestamp(int(msg["time"])).strftime("%Y-%m-%d %H:%M:%S")

            # 获取消息内容
-            message_text = msg.get('processed_plain_text', msg.get('detailed_plain_text', '无消息内容'))
-            nickname = msg.get('user_info', {}).get('user_nickname', '未知用户')
+            message_text = msg.get("processed_plain_text", msg.get("detailed_plain_text", "无消息内容"))
+            nickname = msg.get("user_info", {}).get("user_nickname", "未知用户")

            # 标记当前消息
-            is_target = "" if target_message_id and msg['message_id'] == target_message_id else " "
+            is_target = "" if target_message_id and msg["message_id"] == target_message_id else " "

            reply += f"{is_target}[{msg_time}] {nickname}: {message_text}\n"
-            if target_message_id and msg['message_id'] == target_message_id:
+            if target_message_id and msg["message_id"] == target_message_id:
                reply += " " + "-" * 50 + "\n"

        return reply

    def get_user_random_contexts(
-        self, qq_id: str, num_messages: int = 10, context_length: int = 5) -> tuple[List[str], str]:  # noqa: E501
+        self, qq_id: str, num_messages: int = 10, context_length: int = 5
+    ) -> tuple[List[str], str]:  # noqa: E501
        """
        获取用户的随机消息及其上下文
@@ -115,19 +115,19 @@ class MessageAnalyzer:
            return [], ""

        # 获取用户昵称
-        user_nickname = all_messages[0].get('chat_info', {}).get('user_info', {}).get('user_nickname', '未知用户')
+        user_nickname = all_messages[0].get("chat_info", {}).get("user_info", {}).get("user_nickname", "未知用户")

        # 随机选择指定数量的消息
        selected_messages = random.sample(all_messages, min(num_messages, len(all_messages)))

        # 按时间排序
-        selected_messages.sort(key=lambda x: int(x['time']))
+        selected_messages.sort(key=lambda x: int(x["time"]))

        # 存储所有上下文消息
        context_list = []

        # 获取每条消息的上下文
        for msg in selected_messages:
-            message_id = msg['message_id']
+            message_id = msg["message_id"]

            # 获取消息上下文
            context_messages = self.get_message_context(message_id, context_length)
@@ -137,6 +137,7 @@ class MessageAnalyzer:
        return context_list, user_nickname

+
if __name__ == "__main__":
    # 测试代码
    analyzer = MessageAnalyzer()

View File

@@ -46,17 +46,15 @@ class LLMStatistics:
        """记录在线时间"""
        current_time = datetime.now()
        # 检查5分钟内是否已有记录
-        recent_record = db.online_time.find_one({
-            "timestamp": {
-                "$gte": current_time - timedelta(minutes=5)
-            }
-        })
+        recent_record = db.online_time.find_one({"timestamp": {"$gte": current_time - timedelta(minutes=5)}})

        if not recent_record:
-            db.online_time.insert_one({
+            db.online_time.insert_one(
+                {
                    "timestamp": current_time,
-                "duration": 5  # 5分钟
-            })
+                    "duration": 5,  # 5分钟
+                }
+            )

    def _collect_statistics_for_period(self, start_time: datetime) -> Dict[str, Any]:
        """收集指定时间段的LLM请求统计数据

View File

@@ -41,7 +41,6 @@ class WillingManager:
        interested_rate = interested_rate * config.response_interested_rate_amplifier
        if interested_rate > 0.4:
            current_willing += interested_rate - 0.3

View File

@@ -15,6 +15,7 @@ heartflow_config = LogConfig(
)
logger = get_module_logger("heartflow", config=heartflow_config)

+
class CuttentState:
    def __init__(self):
        self.willing = 0
@@ -26,13 +27,15 @@ class CuttentState:
    def update_current_state_info(self):
        self.current_state_info = self.mood_manager.get_current_mood()

+
class Heartflow:
    def __init__(self):
        self.current_mind = "你什么也没想"
        self.past_mind = []
        self.current_state: CuttentState = CuttentState()
        self.llm_model = LLM_request(
-            model=global_config.llm_heartflow, temperature=0.6, max_tokens=1000, request_type="heart_flow")
+            model=global_config.llm_heartflow, temperature=0.6, max_tokens=1000, request_type="heart_flow"
+        )

        self._subheartflows = {}
        self.active_subheartflows_nums = 0
@@ -79,7 +82,7 @@ class Heartflow:
        personality_info = self.personality_info
        current_thinking_info = self.current_mind
        mood_info = self.current_state.mood
-        related_memory_info = 'memory'
+        related_memory_info = "memory"
        sub_flows_info = await self.get_all_subheartflows_minds()

        schedule_info = bot_schedule.get_current_num_task(num=4, time_info=True)
@@ -103,7 +106,6 @@ class Heartflow:
        # logger.info("麦麦想了想,当前活动:")
        await bot_schedule.move_doing(self.current_mind)

        for _, subheartflow in self._subheartflows.items():
            subheartflow.main_heartflow_info = reponse
@@ -111,8 +113,6 @@ class Heartflow:
        self.past_mind.append(self.current_mind)
        self.current_mind = reponse

    async def get_all_subheartflows_minds(self):
        sub_minds = ""
        for _, subheartflow in self._subheartflows.items():
@@ -129,8 +129,8 @@ class Heartflow:
        prompt += f"现在{global_config.BOT_NICKNAME}的想法是:{self.current_mind}\n"
        prompt += f"现在{global_config.BOT_NICKNAME}在qq群里进行聊天聊天的话题如下{minds_str}\n"
        prompt += f"你现在{mood_info}\n"
-        prompt += '''现在请你总结这些聊天内容,注意关注聊天内容对原有的想法的影响,输出连贯的内心独白
-        不要太长,但是记得结合上述的消息,要记得你的人设,关注新内容:'''
+        prompt += """现在请你总结这些聊天内容,注意关注聊天内容对原有的想法的影响,输出连贯的内心独白
+        不要太长,但是记得结合上述的消息,要记得你的人设,关注新内容:"""

        reponse, reasoning_content = await self.llm_model.generate_response_async(prompt)

View File

@@ -6,6 +6,7 @@ from src.plugins.models.utils_model import LLM_request
from src.plugins.config.config import global_config
from src.common.database import db

+
# 所有观察的基类
class Observation:
    def __init__(self, observe_type, observe_id):
@@ -14,6 +15,7 @@ class Observation:
        self.observe_id = observe_id
        self.last_observe_time = datetime.now().timestamp()  # 初始化为当前时间

+
# 聊天观察
class ChattingObservation(Observation):
    def __init__(self, chat_id):
@@ -32,15 +34,17 @@ class ChattingObservation(Observation):
        self.sub_observe = None

        self.llm_summary = LLM_request(
-            model=global_config.llm_outer_world, temperature=0.7, max_tokens=300, request_type="outer_world")
+            model=global_config.llm_outer_world, temperature=0.7, max_tokens=300, request_type="outer_world"
+        )

    # 进行一次观察 返回观察结果observe_info
    async def observe(self):
        # 查找新消息限制最多30条
-        new_messages = list(db.messages.find({
-            "chat_id": self.chat_id,
-            "time": {"$gt": self.last_observe_time}
-        }).sort("time", 1).limit(20))  # 按时间正序排列最多20条
+        new_messages = list(
+            db.messages.find({"chat_id": self.chat_id, "time": {"$gt": self.last_observe_time}})
+            .sort("time", 1)
+            .limit(20)
+        )  # 按时间正序排列最多20条

        if not new_messages:
            return self.observe_info  # 没有新消息,返回上次观察结果
@@ -75,10 +79,11 @@ class ChattingObservation(Observation):
    async def carefully_observe(self):
        # 查找新消息限制最多40条
-        new_messages = list(db.messages.find({
-            "chat_id": self.chat_id,
-            "time": {"$gt": self.last_observe_time}
-        }).sort("time", 1).limit(30))  # 按时间正序排列最多30条
+        new_messages = list(
+            db.messages.find({"chat_id": self.chat_id, "time": {"$gt": self.last_observe_time}})
+            .sort("time", 1)
+            .limit(30)
+        )  # 按时间正序排列最多30条

        if not new_messages:
            return self.observe_info  # 没有新消息,返回上次观察结果
@@ -102,15 +107,14 @@ class ChattingObservation(Observation):
        await self.update_talking_summary(new_messages_str)
        return self.observe_info

    async def update_talking_summary(self, new_messages_str):
        # 基于已经有的talking_summary和新的talking_message生成一个summary
        # print(f"更新聊天总结:{self.talking_summary}")
        prompt = ""
        prompt = f"你正在参与一个qq群聊的讨论这个群之前在聊的内容是{self.observe_info}\n"
        prompt += f"现在群里的群友们产生了新的讨论,有了新的发言,具体内容如下:{new_messages_str}\n"
-        prompt += '''以上是群里在进行的聊天,请你对这个聊天内容进行总结,总结内容要包含聊天的大致内容,
-        以及聊天中的一些重要信息,记得不要分点,不要太长,精简的概括成一段文本\n'''
+        prompt += """以上是群里在进行的聊天,请你对这个聊天内容进行总结,总结内容要包含聊天的大致内容,
+        以及聊天中的一些重要信息,记得不要分点,不要太长,精简的概括成一段文本\n"""
        prompt += "总结概括:"

        self.observe_info, reasoning_content = await self.llm_summary.generate_response_async(prompt)

View File

@@ -37,7 +37,8 @@ class SubHeartflow:
        self.past_mind = []
        self.current_state: CuttentState = CuttentState()
        self.llm_model = LLM_request(
-            model=global_config.llm_sub_heartflow, temperature=0.7, max_tokens=600, request_type="sub_heart_flow")
+            model=global_config.llm_sub_heartflow, temperature=0.7, max_tokens=600, request_type="sub_heart_flow"
+        )

        self.main_heartflow_info = ""
@@ -101,7 +102,6 @@ class SubHeartflow:
                break  # 退出循环,销毁自己

    async def do_a_thinking(self):
        current_thinking_info = self.current_mind
        mood_info = self.current_state.mood
@@ -111,11 +111,7 @@ class SubHeartflow:
        # 调取记忆
        related_memory = await HippocampusManager.get_instance().get_memory_from_text(
-            text=chat_observe_info,
-            max_memory_num=2,
-            max_memory_length=2,
-            max_depth=3,
-            fast_retrieval=False
+            text=chat_observe_info, max_memory_num=2, max_memory_length=2, max_depth=3, fast_retrieval=False
        )

        if related_memory:
@@ -123,7 +119,7 @@ class SubHeartflow:
            for memory in related_memory:
                related_memory_info += memory[1]
        else:
-            related_memory_info = ''
+            related_memory_info = ""

        # print(f"相关记忆:{related_memory_info}")
@@ -196,7 +192,7 @@ class SubHeartflow:
        response, reasoning_content = await self.llm_model.generate_response_async(prompt)

        # 解析willing值
-        willing_match = re.search(r'<(\d+)>', response)
+        willing_match = re.search(r"<(\d+)>", response)
        if willing_match:
            self.current_state.willing = int(willing_match.group(1))
        else:
@@ -210,4 +206,3 @@ class SubHeartflow:
# subheartflow = SubHeartflow()