refactor: 全部代码格式化
@@ -81,13 +81,15 @@ MEMORY_STYLE_CONFIG = {
"file_format": ("{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | 海马体 | {message}"),
},
"simple": {
-"console_format": ("<green>{time:MM-DD HH:mm}</green> | <light-yellow>海马体</light-yellow> | <light-yellow>{message}</light-yellow>"),
+"console_format": (
+"<green>{time:MM-DD HH:mm}</green> | <light-yellow>海马体</light-yellow> | <light-yellow>{message}</light-yellow>"
+),
"file_format": ("{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | 海马体 | {message}"),
},
}


-#MOOD
+# MOOD
MOOD_STYLE_CONFIG = {
"advanced": {
"console_format": (
@@ -152,7 +154,9 @@ HEARTFLOW_STYLE_CONFIG = {
"file_format": ("{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | 麦麦大脑袋 | {message}"),
},
"simple": {
-"console_format": ("<green>{time:MM-DD HH:mm}</green> | <light-green>麦麦大脑袋</light-green> | <light-green>{message}</light-green>"), # noqa: E501
+"console_format": (
+"<green>{time:MM-DD HH:mm}</green> | <light-green>麦麦大脑袋</light-green> | <light-green>{message}</light-green>"
+), # noqa: E501
"file_format": ("{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | 麦麦大脑袋 | {message}"),
},
}
@@ -223,7 +227,9 @@ CHAT_STYLE_CONFIG = {
"file_format": ("{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | 见闻 | {message}"),
},
"simple": {
-"console_format": ("<green>{time:MM-DD HH:mm}</green> | <light-blue>见闻</light-blue> | <green>{message}</green>"), # noqa: E501
+"console_format": (
+"<green>{time:MM-DD HH:mm}</green> | <light-blue>见闻</light-blue> | <green>{message}</green>"
+), # noqa: E501
"file_format": ("{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | 见闻 | {message}"),
},
}
@@ -240,7 +246,9 @@ SUB_HEARTFLOW_STYLE_CONFIG = {
"file_format": ("{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | 麦麦小脑袋 | {message}"),
},
"simple": {
-"console_format": ("<green>{time:MM-DD HH:mm}</green> | <light-blue>麦麦小脑袋</light-blue> | <light-blue>{message}</light-blue>"), # noqa: E501
+"console_format": (
+"<green>{time:MM-DD HH:mm}</green> | <light-blue>麦麦小脑袋</light-blue> | <light-blue>{message}</light-blue>"
+), # noqa: E501
"file_format": ("{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | 麦麦小脑袋 | {message}"),
},
}
@@ -257,17 +265,17 @@ WILLING_STYLE_CONFIG = {
"file_format": ("{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | 意愿 | {message}"),
},
"simple": {
-"console_format": ("<green>{time:MM-DD HH:mm}</green> | <light-blue>意愿</light-blue> | <light-blue>{message}</light-blue>"), # noqa: E501
+"console_format": (
+"<green>{time:MM-DD HH:mm}</green> | <light-blue>意愿</light-blue> | <light-blue>{message}</light-blue>"
+), # noqa: E501
"file_format": ("{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | 意愿 | {message}"),
},
}




# 根据SIMPLE_OUTPUT选择配置
MEMORY_STYLE_CONFIG = MEMORY_STYLE_CONFIG["simple"] if SIMPLE_OUTPUT else MEMORY_STYLE_CONFIG["advanced"]
TOPIC_STYLE_CONFIG = TOPIC_STYLE_CONFIG["simple"] if SIMPLE_OUTPUT else TOPIC_STYLE_CONFIG["advanced"]
SENDER_STYLE_CONFIG = SENDER_STYLE_CONFIG["simple"] if SIMPLE_OUTPUT else SENDER_STYLE_CONFIG["advanced"]
LLM_STYLE_CONFIG = LLM_STYLE_CONFIG["simple"] if SIMPLE_OUTPUT else LLM_STYLE_CONFIG["advanced"]
CHAT_STYLE_CONFIG = CHAT_STYLE_CONFIG["simple"] if SIMPLE_OUTPUT else CHAT_STYLE_CONFIG["advanced"]
@@ -275,7 +283,9 @@ MOOD_STYLE_CONFIG = MOOD_STYLE_CONFIG["simple"] if SIMPLE_OUTPUT else MOOD_STYLE
RELATION_STYLE_CONFIG = RELATION_STYLE_CONFIG["simple"] if SIMPLE_OUTPUT else RELATION_STYLE_CONFIG["advanced"]
SCHEDULE_STYLE_CONFIG = SCHEDULE_STYLE_CONFIG["simple"] if SIMPLE_OUTPUT else SCHEDULE_STYLE_CONFIG["advanced"]
HEARTFLOW_STYLE_CONFIG = HEARTFLOW_STYLE_CONFIG["simple"] if SIMPLE_OUTPUT else HEARTFLOW_STYLE_CONFIG["advanced"]
-SUB_HEARTFLOW_STYLE_CONFIG = SUB_HEARTFLOW_STYLE_CONFIG["simple"] if SIMPLE_OUTPUT else SUB_HEARTFLOW_STYLE_CONFIG["advanced"] # noqa: E501
+SUB_HEARTFLOW_STYLE_CONFIG = (
+SUB_HEARTFLOW_STYLE_CONFIG["simple"] if SIMPLE_OUTPUT else SUB_HEARTFLOW_STYLE_CONFIG["advanced"]
+) # noqa: E501
WILLING_STYLE_CONFIG = WILLING_STYLE_CONFIG["simple"] if SIMPLE_OUTPUT else WILLING_STYLE_CONFIG["advanced"]

@@ -6,8 +6,9 @@ import time
from datetime import datetime
from typing import Dict, List
from typing import Optional
-sys.path.insert(0, sys.path[0]+"/../")
-sys.path.insert(0, sys.path[0]+"/../")
+sys.path.insert(0, sys.path[0] + "/../")
+sys.path.insert(0, sys.path[0] + "/../")
from src.common.logger import get_module_logger

import customtkinter as ctk
@@ -90,8 +90,8 @@ class MainSystem:
# 启动心流系统
asyncio.create_task(heartflow.heartflow_start_working())
logger.success("心流系统启动成功")

-init_time = int(1000*(time.time()- init_start_time))
+init_time = int(1000 * (time.time() - init_start_time))
logger.success(f"初始化完成,神经元放电{init_time}次")
except Exception as e:
logger.error(f"启动大脑和外部世界失败: {e}")
@@ -56,7 +56,7 @@ class ChatBot:
5. 更新关系
6. 更新情绪
"""

message = MessageRecv(message_data)
groupinfo = message.message_info.group_info
userinfo = message.message_info.user_info
@@ -68,7 +68,7 @@ class ChatBot:
chat = await chat_manager.get_or_create_stream(
platform=messageinfo.platform,
user_info=userinfo,
group_info=groupinfo,
)
message.update_chat_stream(chat)

@@ -81,15 +81,12 @@ class ChatBot:
logger.debug(f"2消息处理时间: {timer2 - timer1}秒")

# 过滤词/正则表达式过滤
-if (
-self._check_ban_words(message.processed_plain_text, chat, userinfo)
-or self._check_ban_regex(message.raw_message, chat, userinfo)
+if self._check_ban_words(message.processed_plain_text, chat, userinfo) or self._check_ban_regex(
+message.raw_message, chat, userinfo
):
return


await self.storage.store_message(message, chat)


timer1 = time.time()
interested_rate = 0
@@ -99,7 +96,6 @@ class ChatBot:
timer2 = time.time()
logger.debug(f"3记忆激活时间: {timer2 - timer1}秒")


is_mentioned = is_mentioned_bot_in_message(message)

if global_config.enable_think_flow:
@@ -124,17 +120,17 @@ class ChatBot:
timer2 = time.time()
logger.debug(f"4计算意愿激活时间: {timer2 - timer1}秒")

-#神秘的消息流数据结构处理
+# 神秘的消息流数据结构处理
if chat.group_info:
if chat.group_info.group_name:
mes_name_dict = chat.group_info.group_name
-mes_name = mes_name_dict.get('group_name', '无名群聊')
+mes_name = mes_name_dict.get("group_name", "无名群聊")
else:
-mes_name = '群聊'
+mes_name = "群聊"
else:
-mes_name = '私聊'
+mes_name = "私聊"

-#打印收到的信息的信息
+# 打印收到的信息的信息
current_time = time.strftime("%H:%M:%S", time.localtime(messageinfo.time))
logger.info(
f"[{current_time}][{mes_name}]"
@@ -145,48 +141,47 @@ class ChatBot:
if message.message_info.additional_config:
if "maimcore_reply_probability_gain" in message.message_info.additional_config.keys():
reply_probability += message.message_info.additional_config["maimcore_reply_probability_gain"]


# 开始组织语言
if random() < reply_probability:
timer1 = time.time()
response_set, thinking_id = await self._generate_response_from_message(message, chat, userinfo, messageinfo)
timer2 = time.time()
logger.info(f"5生成回复时间: {timer2 - timer1}秒")

if not response_set:
logger.info("为什么生成回复失败?")
return

# 发送消息
timer1 = time.time()
await self._send_response_messages(message, chat, response_set, thinking_id)
timer2 = time.time()
logger.info(f"7发送消息时间: {timer2 - timer1}秒")

# 处理表情包
timer1 = time.time()
await self._handle_emoji(message, chat, response_set)
timer2 = time.time()
logger.debug(f"8处理表情包时间: {timer2 - timer1}秒")

timer1 = time.time()
await self._update_using_response(message, chat, response_set)
timer2 = time.time()
logger.info(f"6更新htfl时间: {timer2 - timer1}秒")

# 更新情绪和关系
# await self._update_emotion_and_relationship(message, chat, response_set)

async def _generate_response_from_message(self, message, chat, userinfo, messageinfo):
"""生成回复内容

Args:
message: 接收到的消息
chat: 聊天流对象
userinfo: 用户信息对象
messageinfo: 消息信息对象

Returns:
tuple: (response, raw_content) 回复内容和原始内容
"""
@@ -195,7 +190,7 @@ class ChatBot:
user_nickname=global_config.BOT_NICKNAME,
platform=messageinfo.platform,
)

thinking_time_point = round(time.time(), 2)
thinking_id = "mt" + str(thinking_time_point)
thinking_message = MessageThinking(
@@ -208,9 +203,9 @@ class ChatBot:

message_manager.add_message(thinking_message)
willing_manager.change_reply_willing_sent(chat)

response_set = await self.gpt.generate_response(message)

return response_set, thinking_id

async def _update_using_response(self, message, response_set):
@@ -221,14 +216,13 @@ class ChatBot:
chat_talking_prompt = get_recent_group_detailed_plain_text(
stream_id, limit=global_config.MAX_CONTEXT_SIZE, combine=True
)

-heartflow.get_subheartflow(stream_id).do_after_reply(response_set, chat_talking_prompt)

+heartflow.get_subheartflow(stream_id).do_after_reply(response_set, chat_talking_prompt)

async def _send_response_messages(self, message, chat, response_set, thinking_id):
container = message_manager.get_container(chat.stream_id)
thinking_message = None

# logger.info(f"开始发送消息准备")
for msg in container.messages:
if isinstance(msg, MessageThinking) and msg.message_info.message_id == thinking_id:
@@ -243,7 +237,7 @@ class ChatBot:
# logger.info(f"开始发送消息")
thinking_start_time = thinking_message.thinking_start_time
message_set = MessageSet(chat, thinking_id)

mark_head = False
for msg in response_set:
message_segment = Seg(type="text", data=msg)
@@ -270,7 +264,7 @@ class ChatBot:

async def _handle_emoji(self, message, chat, response):
"""处理表情包

Args:
message: 接收到的消息
chat: 聊天流对象
@@ -281,10 +275,10 @@ class ChatBot:
if emoji_raw:
emoji_path, description = emoji_raw
emoji_cq = image_path_to_base64(emoji_path)

thinking_time_point = round(message.message_info.time, 2)
bot_response_time = thinking_time_point + (1 if random() < 0.5 else -1)

message_segment = Seg(type="emoji", data=emoji_cq)
bot_message = MessageSending(
message_id="mt" + str(thinking_time_point),
@@ -304,7 +298,7 @@ class ChatBot:

async def _update_emotion_and_relationship(self, message, chat, response, raw_content):
"""更新情绪和关系

Args:
message: 接收到的消息
chat: 聊天流对象
@@ -313,27 +307,24 @@ class ChatBot:
"""
stance, emotion = await self.gpt._get_emotion_tags(raw_content, message.processed_plain_text)
logger.debug(f"为 '{response}' 立场为:{stance} 获取到的情感标签为:{emotion}")
-await relationship_manager.calculate_update_relationship_value(
-chat_stream=chat, label=emotion, stance=stance
-)
+await relationship_manager.calculate_update_relationship_value(chat_stream=chat, label=emotion, stance=stance)
self.mood_manager.update_mood_from_emotion(emotion, global_config.mood_intensity_factor)

def _check_ban_words(self, text: str, chat, userinfo) -> bool:
"""检查消息中是否包含过滤词

Args:
text: 要检查的文本
chat: 聊天流对象
userinfo: 用户信息对象

Returns:
bool: 如果包含过滤词返回True,否则返回False
"""
for word in global_config.ban_words:
if word in text:
logger.info(
-f"[{chat.group_info.group_name if chat.group_info else '私聊'}]"
-f"{userinfo.user_nickname}:{text}"
+f"[{chat.group_info.group_name if chat.group_info else '私聊'}]{userinfo.user_nickname}:{text}"
)
logger.info(f"[过滤词识别]消息中含有{word},filtered")
return True
@@ -341,24 +332,24 @@ class ChatBot:

def _check_ban_regex(self, text: str, chat, userinfo) -> bool:
"""检查消息是否匹配过滤正则表达式

Args:
text: 要检查的文本
chat: 聊天流对象
userinfo: 用户信息对象

Returns:
bool: 如果匹配过滤正则返回True,否则返回False
"""
for pattern in global_config.ban_msgs_regex:
if re.search(pattern, text):
logger.info(
-f"[{chat.group_info.group_name if chat.group_info else '私聊'}]"
-f"{userinfo.user_nickname}:{text}"
+f"[{chat.group_info.group_name if chat.group_info else '私聊'}]{userinfo.user_nickname}:{text}"
)
logger.info(f"[正则表达式过滤]消息匹配到{pattern},filtered")
return True
return False


# 创建全局ChatBot实例
chat_bot = ChatBot()
@@ -343,7 +343,7 @@ class EmojiManager:
while True:
logger.info("[扫描] 开始扫描新表情包...")
await self.scan_new_emojis()
await asyncio.sleep(global_config.EMOJI_CHECK_INTERVAL * 60)

def check_emoji_file_integrity(self):
"""检查表情包文件完整性
@@ -31,12 +31,9 @@ class ResponseGenerator:
request_type="response",
)
self.model_normal = LLM_request(
-model=global_config.llm_normal,
-temperature=0.7,
-max_tokens=3000,
-request_type="response"
+model=global_config.llm_normal, temperature=0.7, max_tokens=3000, request_type="response"
)

self.model_sum = LLM_request(
model=global_config.llm_summary_by_topic, temperature=0.7, max_tokens=3000, request_type="relation"
)
@@ -53,8 +50,9 @@ class ResponseGenerator:
self.current_model_type = "浅浅的"
current_model = self.model_normal

-logger.info(f"{self.current_model_type}思考:{message.processed_plain_text[:30] + '...' if len(message.processed_plain_text) > 30 else message.processed_plain_text}") # noqa: E501
+logger.info(
+f"{self.current_model_type}思考:{message.processed_plain_text[:30] + '...' if len(message.processed_plain_text) > 30 else message.processed_plain_text}"
+) # noqa: E501

model_response = await self._generate_response_with_model(message, current_model)

@@ -64,7 +62,6 @@ class ResponseGenerator:
logger.info(f"{global_config.BOT_NICKNAME}的回复是:{model_response}")
model_response = await self._process_response(model_response)


return model_response
else:
logger.info(f"{self.current_model_type}思考,失败")
@@ -93,7 +90,7 @@ class ResponseGenerator:
)
timer2 = time.time()
logger.info(f"构建prompt时间: {timer2 - timer1}秒")

try:
content, reasoning_content, self.current_model_name = await model.generate_response(prompt)
except Exception:
@@ -37,7 +37,6 @@ class PromptBuilder:

current_mind_info = heartflow.get_subheartflow(stream_id).current_mind


# relation_prompt = ""
# for person in who_chat_in_group:
# relation_prompt += relationship_manager.build_relationship_info(person)
@@ -52,7 +51,7 @@ class PromptBuilder:
# 心情
mood_manager = MoodManager.get_instance()
mood_prompt = mood_manager.get_prompt()

logger.info(f"心情prompt: {mood_prompt}")

# 日程构建
@@ -72,13 +71,12 @@ class PromptBuilder:
chat_in_group = False
chat_talking_prompt = chat_talking_prompt
# print(f"\033[1;34m[调试]\033[0m 已从数据库获取群 {group_id} 的消息记录:{chat_talking_prompt}")


# 使用新的记忆获取方法
memory_prompt = ""
start_time = time.time()

-#调用 hippocampus 的 get_relevant_memories 方法
+# 调用 hippocampus 的 get_relevant_memories 方法
relevant_memories = await HippocampusManager.get_instance().get_memory_from_text(
text=message_txt, max_memory_num=3, max_memory_length=2, max_depth=2, fast_retrieval=False
)
@@ -165,11 +163,8 @@ class PromptBuilder:
请注意不要输出多余内容(包括前后缀,冒号和引号,括号,表情等),只输出回复内容。
{moderation_prompt}不要输出多余内容(包括前后缀,冒号和引号,括号,表情包,at或 @等 )。"""


return prompt



def _build_initiative_prompt_select(self, group_id, probability_1=0.8, probability_2=0.1):
current_date = time.strftime("%Y-%m-%d", time.localtime())
current_time = time.strftime("%H:%M:%S", time.localtime())
@@ -9,9 +9,7 @@ logger = get_module_logger("message_storage")


class MessageStorage:
-async def store_message(
-self, message: Union[MessageSending, MessageRecv], chat_stream: ChatStream
-) -> None:
+async def store_message(self, message: Union[MessageSending, MessageRecv], chat_stream: ChatStream) -> None:
"""存储消息到数据库"""
try:
message_data = {
@@ -11,7 +11,7 @@ from collections import Counter
from ...common.database import db
from ...plugins.models.utils_model import LLM_request
from src.common.logger import get_module_logger, LogConfig, MEMORY_STYLE_CONFIG
-from src.plugins.memory_system.sample_distribution import MemoryBuildScheduler #分布生成器
+from src.plugins.memory_system.sample_distribution import MemoryBuildScheduler # 分布生成器
from .memory_config import MemoryConfig


@@ -56,6 +56,7 @@ def get_closest_chat_from_db(length: int, timestamp: str):

return []


def calculate_information_content(text):
"""计算文本的信息量(熵)"""
char_count = Counter(text)
@@ -68,6 +69,7 @@ def calculate_information_content(text):

return entropy


def cosine_similarity(v1, v2):
"""计算余弦相似度"""
dot_product = np.dot(v1, v2)
@@ -223,7 +225,8 @@ class Memory_graph:

return None

-#负责海马体与其他部分的交互
+# 负责海马体与其他部分的交互
class EntorhinalCortex:
def __init__(self, hippocampus):
self.hippocampus = hippocampus
@@ -243,7 +246,7 @@ class EntorhinalCortex:
n_hours2=self.config.memory_build_distribution[3],
std_hours2=self.config.memory_build_distribution[4],
weight2=self.config.memory_build_distribution[5],
-total_samples=self.config.build_memory_sample_num
+total_samples=self.config.build_memory_sample_num,
)

timestamps = sample_scheduler.get_timestamp_array()
@@ -251,9 +254,7 @@ class EntorhinalCortex:
chat_samples = []
for timestamp in timestamps:
messages = self.random_get_msg_snippet(
-timestamp,
-self.config.build_memory_sample_length,
-max_memorized_time_per_msg
+timestamp, self.config.build_memory_sample_length, max_memorized_time_per_msg
)
if messages:
time_diff = (datetime.datetime.now().timestamp() - timestamp) / 3600
@@ -455,25 +456,25 @@ class EntorhinalCortex:
"""清空数据库并重新同步所有记忆数据"""
start_time = time.time()
logger.info("[数据库] 开始重新同步所有记忆数据...")

# 清空数据库
clear_start = time.time()
db.graph_data.nodes.delete_many({})
db.graph_data.edges.delete_many({})
clear_end = time.time()
logger.info(f"[数据库] 清空数据库耗时: {clear_end - clear_start:.2f}秒")

# 获取所有节点和边
memory_nodes = list(self.memory_graph.G.nodes(data=True))
memory_edges = list(self.memory_graph.G.edges(data=True))

# 重新写入节点
node_start = time.time()
for concept, data in memory_nodes:
memory_items = data.get("memory_items", [])
if not isinstance(memory_items, list):
memory_items = [memory_items] if memory_items else []

node_data = {
"concept": concept,
"memory_items": memory_items,
@@ -484,7 +485,7 @@ class EntorhinalCortex:
db.graph_data.nodes.insert_one(node_data)
node_end = time.time()
logger.info(f"[数据库] 写入 {len(memory_nodes)} 个节点耗时: {node_end - node_start:.2f}秒")

# 重新写入边
edge_start = time.time()
for source, target, data in memory_edges:
@@ -499,12 +500,13 @@ class EntorhinalCortex:
db.graph_data.edges.insert_one(edge_data)
edge_end = time.time()
logger.info(f"[数据库] 写入 {len(memory_edges)} 条边耗时: {edge_end - edge_start:.2f}秒")

end_time = time.time()
logger.success(f"[数据库] 重新同步完成,总耗时: {end_time - start_time:.2f}秒")
logger.success(f"[数据库] 同步了 {len(memory_nodes)} 个节点和 {len(memory_edges)} 条边")

-#负责整合,遗忘,合并记忆
+# 负责整合,遗忘,合并记忆
class ParahippocampalGyrus:
def __init__(self, hippocampus):
self.hippocampus = hippocampus
@@ -567,26 +569,26 @@ class ParahippocampalGyrus:

topic_num = self.hippocampus.calculate_topic_num(input_text, compress_rate)
topics_response = await self.hippocampus.llm_topic_judge.generate_response(
-self.hippocampus.find_topic_llm(input_text, topic_num))
+self.hippocampus.find_topic_llm(input_text, topic_num)
+)

# 使用正则表达式提取<>中的内容
-topics = re.findall(r'<([^>]+)>', topics_response[0])
+topics = re.findall(r"<([^>]+)>", topics_response[0])

# 如果没有找到<>包裹的内容,返回['none']
if not topics:
-topics = ['none']
+topics = ["none"]
else:
# 处理提取出的话题
topics = [
topic.strip()
-for topic in ','.join(topics).replace(",", ",").replace("、", ",").replace(" ", ",").split(",")
+for topic in ",".join(topics).replace(",", ",").replace("、", ",").replace(" ", ",").split(",")
if topic.strip()
]

# 过滤掉包含禁用关键词的topic
filtered_topics = [
-topic for topic in topics
-if not any(keyword in topic for keyword in self.config.memory_ban_words)
+topic for topic in topics if not any(keyword in topic for keyword in self.config.memory_ban_words)
]

logger.debug(f"过滤后话题: {filtered_topics}")
@@ -601,12 +603,12 @@ class ParahippocampalGyrus:
# 等待所有任务完成
compressed_memory = set()
similar_topics_dict = {}

for topic, task in tasks:
response = await task
if response:
compressed_memory.add((topic, response[0]))

existing_topics = list(self.memory_graph.G.nodes())
similar_topics = []

@@ -651,7 +653,7 @@ class ParahippocampalGyrus:
current_time = datetime.datetime.now().timestamp()
logger.debug(f"添加节点: {', '.join(topic for topic, _ in compressed_memory)}")
all_added_nodes.extend(topic for topic, _ in compressed_memory)

for topic, memory in compressed_memory:
self.memory_graph.add_dot(topic, memory)
all_topics.append(topic)
@@ -661,13 +663,13 @@ class ParahippocampalGyrus:
for similar_topic, similarity in similar_topics:
if topic != similar_topic:
strength = int(similarity * 10)

logger.debug(f"连接相似节点: {topic} 和 {similar_topic} (强度: {strength})")
all_added_edges.append(f"{topic}-{similar_topic}")

all_connected_nodes.append(topic)
all_connected_nodes.append(similar_topic)

self.memory_graph.G.add_edge(
topic,
similar_topic,
@@ -685,14 +687,11 @@ class ParahippocampalGyrus:
logger.success(f"更新记忆: {', '.join(all_added_nodes)}")
logger.debug(f"强化连接: {', '.join(all_added_edges)}")
logger.info(f"强化连接节点: {', '.join(all_connected_nodes)}")

await self.hippocampus.entorhinal_cortex.sync_memory_to_db()

end_time = time.time()
-logger.success(
-f"---------------------记忆构建耗时: {end_time - start_time:.2f} "
-"秒---------------------"
-)
+logger.success(f"---------------------记忆构建耗时: {end_time - start_time:.2f} 秒---------------------")

async def operation_forget_topic(self, percentage=0.005):
start_time = time.time()
@@ -714,11 +713,11 @@ class ParahippocampalGyrus:
# 使用列表存储变化信息
edge_changes = {
"weakened": [], # 存储减弱的边
-"removed": [] # 存储移除的边
+"removed": [], # 存储移除的边
}
node_changes = {
"reduced": [], # 存储减少记忆的节点
-"removed": [] # 存储移除的节点
+"removed": [], # 存储移除的节点
}

current_time = datetime.datetime.now().timestamp()
@@ -771,35 +770,40 @@ class ParahippocampalGyrus:

if any(edge_changes.values()) or any(node_changes.values()):
sync_start = time.time()

await self.hippocampus.entorhinal_cortex.resync_memory_to_db()

sync_end = time.time()
logger.info(f"[遗忘] 数据库同步耗时: {sync_end - sync_start:.2f}秒")

# 汇总输出所有变化
logger.info("[遗忘] 遗忘操作统计:")
if edge_changes["weakened"]:
logger.info(
-f"[遗忘] 减弱的连接 ({len(edge_changes['weakened'])}个): {', '.join(edge_changes['weakened'])}")
+f"[遗忘] 减弱的连接 ({len(edge_changes['weakened'])}个): {', '.join(edge_changes['weakened'])}"
+)

if edge_changes["removed"]:
logger.info(
-f"[遗忘] 移除的连接 ({len(edge_changes['removed'])}个): {', '.join(edge_changes['removed'])}")
+f"[遗忘] 移除的连接 ({len(edge_changes['removed'])}个): {', '.join(edge_changes['removed'])}"
+)

if node_changes["reduced"]:
logger.info(
-f"[遗忘] 减少记忆的节点 ({len(node_changes['reduced'])}个): {', '.join(node_changes['reduced'])}")
+f"[遗忘] 减少记忆的节点 ({len(node_changes['reduced'])}个): {', '.join(node_changes['reduced'])}"
+)

if node_changes["removed"]:
logger.info(
-f"[遗忘] 移除的节点 ({len(node_changes['removed'])}个): {', '.join(node_changes['removed'])}")
+f"[遗忘] 移除的节点 ({len(node_changes['removed'])}个): {', '.join(node_changes['removed'])}"
+)
else:
logger.info("[遗忘] 本次检查没有节点或连接满足遗忘条件")

end_time = time.time()
logger.info(f"[遗忘] 总耗时: {end_time - start_time:.2f}秒")


# 海马体
class Hippocampus:
def __init__(self):
@@ -817,8 +821,8 @@ class Hippocampus:
self.parahippocampal_gyrus = ParahippocampalGyrus(self)
# 从数据库加载记忆图
self.entorhinal_cortex.sync_memory_from_db()
-self.llm_topic_judge = LLM_request(self.config.llm_topic_judge,request_type="memory")
-self.llm_summary_by_topic = LLM_request(self.config.llm_summary_by_topic,request_type="memory")
+self.llm_topic_judge = LLM_request(self.config.llm_topic_judge, request_type="memory")
+self.llm_summary_by_topic = LLM_request(self.config.llm_summary_by_topic, request_type="memory")

def get_all_node_names(self) -> list:
"""获取记忆图中所有节点的名字列表"""
@@ -901,16 +905,21 @@ class Hippocampus:
memory_items = node_data.get("memory_items", [])
if not isinstance(memory_items, list):
memory_items = [memory_items] if memory_items else []

memories.append((node, memory_items, similarity))

# 按相似度降序排序
memories.sort(key=lambda x: x[2], reverse=True)
return memories

-async def get_memory_from_text(self, text: str, max_memory_num: int = 3, max_memory_length: int = 2,
-max_depth: int = 3,
-fast_retrieval: bool = False) -> list:
+async def get_memory_from_text(
+self,
+text: str,
+max_memory_num: int = 3,
+max_memory_length: int = 2,
+max_depth: int = 3,
+fast_retrieval: bool = False,
+) -> list:
"""从文本中提取关键词并获取相关记忆。

Args:
@@ -943,18 +952,16 @@ class Hippocampus:
# 使用LLM提取关键词
topic_num = min(5, max(1, int(len(text) * 0.1))) # 根据文本长度动态调整关键词数量
# logger.info(f"提取关键词数量: {topic_num}")
-topics_response = await self.llm_topic_judge.generate_response(
-self.find_topic_llm(text, topic_num)
-)
+topics_response = await self.llm_topic_judge.generate_response(self.find_topic_llm(text, topic_num))

# 提取关键词
-keywords = re.findall(r'<([^>]+)>', topics_response[0])
+keywords = re.findall(r"<([^>]+)>", topics_response[0])
if not keywords:
keywords = []
else:
keywords = [
keyword.strip()
-for keyword in ','.join(keywords).replace(",", ",").replace("、", ",").replace(" ", ",").split(",")
+for keyword in ",".join(keywords).replace(",", ",").replace("、", ",").replace(" ", ",").split(",")
if keyword.strip()
]

@@ -965,7 +972,7 @@ class Hippocampus:
if not valid_keywords:
logger.info("没有找到有效的关键词节点")
return []

logger.info(f"有效的关键词: {', '.join(valid_keywords)}")

# 从每个关键词获取记忆
@@ -981,35 +988,36 @@ class Hippocampus:
visited_nodes = {keyword}
# 待处理的节点队列,每个元素是(节点, 激活值, 当前深度)
nodes_to_process = [(keyword, 1.0, 0)]

while nodes_to_process:
current_node, current_activation, current_depth = nodes_to_process.pop(0)

# 如果激活值小于0或超过最大深度,停止扩散
if current_activation <= 0 or current_depth >= max_depth:
continue

# 获取当前节点的所有邻居
neighbors = list(self.memory_graph.G.neighbors(current_node))

for neighbor in neighbors:
if neighbor in visited_nodes:
continue

# 获取连接强度
edge_data = self.memory_graph.G[current_node][neighbor]
strength = edge_data.get("strength", 1)

# 计算新的激活值
new_activation = current_activation - (1 / strength)

if new_activation > 0:
activation_values[neighbor] = new_activation
visited_nodes.add(neighbor)
nodes_to_process.append((neighbor, new_activation, current_depth + 1))
logger.debug(
-f"节点 '{neighbor}' 被激活,激活值: {new_activation:.2f} (通过 '{current_node}' 连接,强度: {strength}, 深度: {current_depth + 1})") # noqa: E501
+f"节点 '{neighbor}' 被激活,激活值: {new_activation:.2f} (通过 '{current_node}' 连接,强度: {strength}, 深度: {current_depth + 1})"
+) # noqa: E501

# 更新激活映射
for node, activation_value in activation_values.items():
if activation_value > 0:
@@ -1017,7 +1025,7 @@ class Hippocampus:
activate_map[node] += activation_value
else:
activate_map[node] = activation_value

# 输出激活映射
# logger.info("激活映射统计:")
# for node, total_activation in sorted(activate_map.items(), key=lambda x: x[1], reverse=True):
@@ -1026,28 +1034,24 @@ class Hippocampus:
# 基于激活值平方的独立概率选择
remember_map = {}
# logger.info("基于激活值平方的归一化选择:")

# 计算所有激活值的平方和
-total_squared_activation = sum(activation ** 2 for activation in activate_map.values())
+total_squared_activation = sum(activation**2 for activation in activate_map.values())
if total_squared_activation > 0:
# 计算归一化的激活值
normalized_activations = {
-node: (activation ** 2) / total_squared_activation
-for node, activation in activate_map.items()
+node: (activation**2) / total_squared_activation for node, activation in activate_map.items()
}

# 按归一化激活值排序并选择前max_memory_num个
-sorted_nodes = sorted(
-normalized_activations.items(),
-key=lambda x: x[1],
-reverse=True
-)[:max_memory_num]
+sorted_nodes = sorted(normalized_activations.items(), key=lambda x: x[1], reverse=True)[:max_memory_num]

# 将选中的节点添加到remember_map
for node, normalized_activation in sorted_nodes:
remember_map[node] = activate_map[node] # 使用原始激活值
logger.debug(
-f"节点 '{node}' (归一化激活值: {normalized_activation:.2f}, 激活值: {activate_map[node]:.2f})")
+f"节点 '{node}' (归一化激活值: {normalized_activation:.2f}, 激活值: {activate_map[node]:.2f})"
+)
else:
logger.info("没有有效的激活值")

@@ -1060,7 +1064,7 @@ class Hippocampus:
memory_items = node_data.get("memory_items", [])
if not isinstance(memory_items, list):
memory_items = [memory_items] if memory_items else []

if memory_items:
logger.debug(f"节点包含 {len(memory_items)} 条记忆")
# 计算每条记忆与输入文本的相似度
@@ -1079,7 +1083,7 @@ class Hippocampus:
memory_similarities.sort(key=lambda x: x[1], reverse=True)
# 获取最匹配的记忆
top_memories = memory_similarities[:max_memory_length]

# 添加到结果中
for memory, similarity in top_memories:
all_memories.append((node, [memory], similarity))
@@ -1106,11 +1110,10 @@ class Hippocampus:
memory = memory_items[0] # 因为每个topic只有一条记忆
result.append((topic, memory))
logger.info(f"选中记忆: {memory} (来自节点: {topic})")

return result

-async def get_activate_from_text(self, text: str, max_depth: int = 3,
-fast_retrieval: bool = False) -> float:
+async def get_activate_from_text(self, text: str, max_depth: int = 3, fast_retrieval: bool = False) -> float:
"""从文本中提取关键词并获取相关记忆。

Args:
@@ -1140,18 +1143,16 @@ class Hippocampus:
# 使用LLM提取关键词
topic_num = min(5, max(1, int(len(text) * 0.1))) # 根据文本长度动态调整关键词数量
# logger.info(f"提取关键词数量: {topic_num}")
-topics_response = await self.llm_topic_judge.generate_response(
-self.find_topic_llm(text, topic_num)
-)
+topics_response = await self.llm_topic_judge.generate_response(self.find_topic_llm(text, topic_num))

# 提取关键词
-keywords = re.findall(r'<([^>]+)>', topics_response[0])
+keywords = re.findall(r"<([^>]+)>", topics_response[0])
if not keywords:
keywords = []
else:
keywords = [
keyword.strip()
-for keyword in ','.join(keywords).replace(",", ",").replace("、", ",").replace(" ", ",").split(",")
+for keyword in ",".join(keywords).replace(",", ",").replace("、", ",").replace(" ", ",").split(",")
if keyword.strip()
]

@@ -1162,7 +1163,7 @@ class Hippocampus:
if not valid_keywords:
logger.info("没有找到有效的关键词节点")
return 0

logger.info(f"有效的关键词: {', '.join(valid_keywords)}")

# 从每个关键词获取记忆
@@ -1177,35 +1178,35 @@ class Hippocampus:
|
|||||||
visited_nodes = {keyword}
|
visited_nodes = {keyword}
|
||||||
# 待处理的节点队列,每个元素是(节点, 激活值, 当前深度)
|
# 待处理的节点队列,每个元素是(节点, 激活值, 当前深度)
|
||||||
nodes_to_process = [(keyword, 1.0, 0)]
|
nodes_to_process = [(keyword, 1.0, 0)]
|
||||||
|
|
||||||
while nodes_to_process:
|
while nodes_to_process:
|
||||||
current_node, current_activation, current_depth = nodes_to_process.pop(0)
|
current_node, current_activation, current_depth = nodes_to_process.pop(0)
|
||||||
|
|
||||||
# 如果激活值小于0或超过最大深度,停止扩散
|
# 如果激活值小于0或超过最大深度,停止扩散
|
||||||
if current_activation <= 0 or current_depth >= max_depth:
|
if current_activation <= 0 or current_depth >= max_depth:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# 获取当前节点的所有邻居
|
# 获取当前节点的所有邻居
|
||||||
neighbors = list(self.memory_graph.G.neighbors(current_node))
|
neighbors = list(self.memory_graph.G.neighbors(current_node))
|
||||||
|
|
||||||
for neighbor in neighbors:
|
for neighbor in neighbors:
|
||||||
if neighbor in visited_nodes:
|
if neighbor in visited_nodes:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# 获取连接强度
|
# 获取连接强度
|
||||||
edge_data = self.memory_graph.G[current_node][neighbor]
|
edge_data = self.memory_graph.G[current_node][neighbor]
|
||||||
strength = edge_data.get("strength", 1)
|
strength = edge_data.get("strength", 1)
|
||||||
|
|
||||||
# 计算新的激活值
|
# 计算新的激活值
|
||||||
new_activation = current_activation - (1 / strength)
|
new_activation = current_activation - (1 / strength)
|
||||||
|
|
||||||
if new_activation > 0:
|
if new_activation > 0:
|
||||||
activation_values[neighbor] = new_activation
|
activation_values[neighbor] = new_activation
|
||||||
visited_nodes.add(neighbor)
|
visited_nodes.add(neighbor)
|
||||||
nodes_to_process.append((neighbor, new_activation, current_depth + 1))
|
nodes_to_process.append((neighbor, new_activation, current_depth + 1))
|
||||||
# logger.debug(
|
# logger.debug(
|
||||||
# f"节点 '{neighbor}' 被激活,激活值: {new_activation:.2f} (通过 '{current_node}' 连接,强度: {strength}, 深度: {current_depth + 1})") # noqa: E501
|
# f"节点 '{neighbor}' 被激活,激活值: {new_activation:.2f} (通过 '{current_node}' 连接,强度: {strength}, 深度: {current_depth + 1})") # noqa: E501
|
||||||
|
|
||||||
# 更新激活映射
|
# 更新激活映射
|
||||||
for node, activation_value in activation_values.items():
|
for node, activation_value in activation_values.items():
|
||||||
if activation_value > 0:
|
if activation_value > 0:
|
||||||
@@ -1213,23 +1214,24 @@ class Hippocampus:
|
|||||||
activate_map[node] += activation_value
|
activate_map[node] += activation_value
|
||||||
else:
|
else:
|
||||||
activate_map[node] = activation_value
|
activate_map[node] = activation_value
|
||||||
|
|
||||||
# 输出激活映射
|
# 输出激活映射
|
||||||
# logger.info("激活映射统计:")
|
# logger.info("激活映射统计:")
|
||||||
# for node, total_activation in sorted(activate_map.items(), key=lambda x: x[1], reverse=True):
|
# for node, total_activation in sorted(activate_map.items(), key=lambda x: x[1], reverse=True):
|
||||||
# logger.info(f"节点 '{node}': 累计激活值 = {total_activation:.2f}")
|
# logger.info(f"节点 '{node}': 累计激活值 = {total_activation:.2f}")
|
||||||
|
|
||||||
# 计算激活节点数与总节点数的比值
|
# 计算激活节点数与总节点数的比值
|
||||||
total_activation = sum(activate_map.values())
|
total_activation = sum(activate_map.values())
|
||||||
logger.info(f"总激活值: {total_activation:.2f}")
|
logger.info(f"总激活值: {total_activation:.2f}")
|
||||||
total_nodes = len(self.memory_graph.G.nodes())
|
total_nodes = len(self.memory_graph.G.nodes())
|
||||||
# activated_nodes = len(activate_map)
|
# activated_nodes = len(activate_map)
|
||||||
activation_ratio = total_activation / total_nodes if total_nodes > 0 else 0
|
activation_ratio = total_activation / total_nodes if total_nodes > 0 else 0
|
||||||
activation_ratio = activation_ratio*60
|
activation_ratio = activation_ratio * 60
|
||||||
logger.info(f"总激活值: {total_activation:.2f}, 总节点数: {total_nodes}, 激活: {activation_ratio}")
|
logger.info(f"总激活值: {total_activation:.2f}, 总节点数: {total_nodes}, 激活: {activation_ratio}")
|
||||||
|
|
||||||
return activation_ratio
|
return activation_ratio
|
||||||
|
|
||||||
|
|
||||||
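The hunk above is the core of get_activate_from_text: a breadth-first spreading-activation pass in which every seed keyword starts at 1.0, each hop to a neighbour costs 1/strength, and the walk stops at max_depth or when activation reaches zero; the per-node totals are then summed and scaled by the graph size. A self-contained sketch of the same traversal on a plain networkx graph (node names and strengths are made up):

import networkx as nx

def spread_activation(G: nx.Graph, seed: str, max_depth: int = 3) -> dict:
    """BFS from `seed`; activation decays by 1/strength per edge."""
    activation, visited = {}, {seed}
    queue = [(seed, 1.0, 0)]
    while queue:
        node, act, depth = queue.pop(0)
        if act <= 0 or depth >= max_depth:
            continue
        for neighbor in G.neighbors(node):
            if neighbor in visited:
                continue
            strength = G[node][neighbor].get("strength", 1)
            new_act = act - 1 / strength
            if new_act > 0:
                activation[neighbor] = new_act
                visited.add(neighbor)
                queue.append((neighbor, new_act, depth + 1))
    return activation

G = nx.Graph()
G.add_edge("下雨", "散步", strength=4)
G.add_edge("散步", "运动", strength=2)
print(spread_activation(G, "下雨"))  # {'散步': 0.75, '运动': 0.25}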
class HippocampusManager:
|
class HippocampusManager:
|
||||||
_instance = None
|
_instance = None
|
||||||
_hippocampus = None
|
_hippocampus = None
|
||||||
@@ -1252,12 +1254,12 @@ class HippocampusManager:
|
|||||||
"""初始化海马体实例"""
|
"""初始化海马体实例"""
|
||||||
if self._initialized:
|
if self._initialized:
|
||||||
return self._hippocampus
|
return self._hippocampus
|
||||||
|
|
||||||
self._global_config = global_config
|
self._global_config = global_config
|
||||||
self._hippocampus = Hippocampus()
|
self._hippocampus = Hippocampus()
|
||||||
self._hippocampus.initialize(global_config)
|
self._hippocampus.initialize(global_config)
|
||||||
self._initialized = True
|
self._initialized = True
|
||||||
|
|
||||||
# 输出记忆系统参数信息
|
# 输出记忆系统参数信息
|
||||||
config = self._hippocampus.config
|
config = self._hippocampus.config
|
||||||
|
|
||||||
@@ -1265,16 +1267,15 @@ class HippocampusManager:
|
|||||||
memory_graph = self._hippocampus.memory_graph.G
|
memory_graph = self._hippocampus.memory_graph.G
|
||||||
node_count = len(memory_graph.nodes())
|
node_count = len(memory_graph.nodes())
|
||||||
edge_count = len(memory_graph.edges())
|
edge_count = len(memory_graph.edges())
|
||||||
|
|
||||||
logger.success(f'''--------------------------------
|
logger.success(f"""--------------------------------
|
||||||
记忆系统参数配置:
|
记忆系统参数配置:
|
||||||
构建间隔: {global_config.build_memory_interval}秒|样本数: {config.build_memory_sample_num},长度: {config.build_memory_sample_length}|压缩率: {config.memory_compress_rate}
|
构建间隔: {global_config.build_memory_interval}秒|样本数: {config.build_memory_sample_num},长度: {config.build_memory_sample_length}|压缩率: {config.memory_compress_rate}
|
||||||
记忆构建分布: {config.memory_build_distribution}
|
记忆构建分布: {config.memory_build_distribution}
|
||||||
遗忘间隔: {global_config.forget_memory_interval}秒|遗忘比例: {global_config.memory_forget_percentage}|遗忘: {config.memory_forget_time}小时之后
|
遗忘间隔: {global_config.forget_memory_interval}秒|遗忘比例: {global_config.memory_forget_percentage}|遗忘: {config.memory_forget_time}小时之后
|
||||||
记忆图统计信息: 节点数量: {node_count}, 连接数量: {edge_count}
|
记忆图统计信息: 节点数量: {node_count}, 连接数量: {edge_count}
|
||||||
--------------------------------''') #noqa: E501
|
--------------------------------""") # noqa: E501
|
||||||
|
|
||||||
|
|
||||||
return self._hippocampus
|
return self._hippocampus
|
||||||
|
|
||||||
async def build_memory(self):
|
async def build_memory(self):
|
||||||
@@ -1289,17 +1290,22 @@ class HippocampusManager:
|
|||||||
raise RuntimeError("HippocampusManager 尚未初始化,请先调用 initialize 方法")
|
raise RuntimeError("HippocampusManager 尚未初始化,请先调用 initialize 方法")
|
||||||
return await self._hippocampus.parahippocampal_gyrus.operation_forget_topic(percentage)
|
return await self._hippocampus.parahippocampal_gyrus.operation_forget_topic(percentage)
|
||||||
|
|
||||||
async def get_memory_from_text(self, text: str, max_memory_num: int = 3,
|
async def get_memory_from_text(
|
||||||
max_memory_length: int = 2, max_depth: int = 3,
|
self,
|
||||||
fast_retrieval: bool = False) -> list:
|
text: str,
|
||||||
|
max_memory_num: int = 3,
|
||||||
|
max_memory_length: int = 2,
|
||||||
|
max_depth: int = 3,
|
||||||
|
fast_retrieval: bool = False,
|
||||||
|
) -> list:
|
||||||
"""从文本中获取相关记忆的公共接口"""
|
"""从文本中获取相关记忆的公共接口"""
|
||||||
if not self._initialized:
|
if not self._initialized:
|
||||||
raise RuntimeError("HippocampusManager 尚未初始化,请先调用 initialize 方法")
|
raise RuntimeError("HippocampusManager 尚未初始化,请先调用 initialize 方法")
|
||||||
return await self._hippocampus.get_memory_from_text(
|
return await self._hippocampus.get_memory_from_text(
|
||||||
text, max_memory_num, max_memory_length, max_depth, fast_retrieval)
|
text, max_memory_num, max_memory_length, max_depth, fast_retrieval
|
||||||
|
)
|
||||||
|
|
||||||
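Both wrapper methods above share one shape: refuse to run until initialize has been called, then delegate to the underlying Hippocampus. A bare sketch of that manager-facade pattern with generic names (none of them come from the project):

import asyncio

class Engine:
    async def get_memory(self, text: str) -> list:
        return [("topic", text)]  # stand-in for the real retrieval

class EngineManager:
    def __init__(self):
        self._initialized = False
        self._engine = None

    def initialize(self) -> None:
        self._engine = Engine()
        self._initialized = True

    async def get_memory(self, text: str) -> list:
        if not self._initialized:
            raise RuntimeError("manager not initialized; call initialize() first")
        return await self._engine.get_memory(text)

mgr = EngineManager()
mgr.initialize()
print(asyncio.run(mgr.get_memory("下雨")))  # [('topic', '下雨')]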
async def get_activate_from_text(self, text: str, max_depth: int = 3,
|
async def get_activate_from_text(self, text: str, max_depth: int = 3, fast_retrieval: bool = False) -> float:
|
||||||
fast_retrieval: bool = False) -> float:
|
|
||||||
"""从文本中获取激活值的公共接口"""
|
"""从文本中获取激活值的公共接口"""
|
||||||
if not self._initialized:
|
if not self._initialized:
|
||||||
raise RuntimeError("HippocampusManager 尚未初始化,请先调用 initialize 方法")
|
raise RuntimeError("HippocampusManager 尚未初始化,请先调用 initialize 方法")
|
||||||
@@ -1316,5 +1322,3 @@ class HippocampusManager:
|
|||||||
if not self._initialized:
|
if not self._initialized:
|
||||||
raise RuntimeError("HippocampusManager 尚未初始化,请先调用 initialize 方法")
|
raise RuntimeError("HippocampusManager 尚未初始化,请先调用 initialize 方法")
|
||||||
return self._hippocampus.get_all_node_names()
|
return self._hippocampus.get_all_node_names()
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -3,11 +3,13 @@ import asyncio
|
|||||||
import time
|
import time
|
||||||
import sys
|
import sys
|
||||||
import os
|
import os
|
||||||
|
|
||||||
# 添加项目根目录到系统路径
|
# 添加项目根目录到系统路径
|
||||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(__file__)))))
|
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(__file__)))))
|
||||||
from src.plugins.memory_system.Hippocampus import HippocampusManager
|
from src.plugins.memory_system.Hippocampus import HippocampusManager
|
||||||
from src.plugins.config.config import global_config
|
from src.plugins.config.config import global_config
|
||||||
|
|
||||||
|
|
||||||
async def test_memory_system():
|
async def test_memory_system():
|
||||||
"""测试记忆系统的主要功能"""
|
"""测试记忆系统的主要功能"""
|
||||||
try:
|
try:
|
||||||
@@ -24,7 +26,7 @@ async def test_memory_system():
|
|||||||
|
|
||||||
# 测试记忆检索
|
# 测试记忆检索
|
||||||
test_text = "千石可乐在群里聊天"
|
test_text = "千石可乐在群里聊天"
|
||||||
test_text = '''[03-24 10:39:37] 麦麦(ta的id:2814567326): 早说散步结果下雨改成室内运动啊
|
test_text = """[03-24 10:39:37] 麦麦(ta的id:2814567326): 早说散步结果下雨改成室内运动啊
|
||||||
[03-24 10:39:37] 麦麦(ta的id:2814567326): [回复:变量] 变量就像今天计划总变
|
[03-24 10:39:37] 麦麦(ta的id:2814567326): [回复:变量] 变量就像今天计划总变
|
||||||
[03-24 10:39:44] 状态异常(ta的id:535554838): 要把本地文件改成弹出来的路径吗
|
[03-24 10:39:44] 状态异常(ta的id:535554838): 要把本地文件改成弹出来的路径吗
|
||||||
[03-24 10:40:35] 状态异常(ta的id:535554838): [图片:这张图片显示的是Windows系统的环境变量设置界面。界面左侧列出了多个环境变量的值,包括Intel Dev Redist、Windows、Windows PowerShell、OpenSSH、NVIDIA Corporation的目录等。右侧有新建、编辑、浏览、删除、上移、下移和编辑文本等操作按钮。图片下方有一个错误提示框,显示"Windows找不到文件'mongodb\\bin\\mongod.exe'。请确定文件名是否正确后,再试一次。"这意味着用户试图运行MongoDB的mongod.exe程序时,系统找不到该文件。这可能是因为MongoDB的安装路径未正确添加到系统环境变量中,或者文件路径有误。
|
[03-24 10:40:35] 状态异常(ta的id:535554838): [图片:这张图片显示的是Windows系统的环境变量设置界面。界面左侧列出了多个环境变量的值,包括Intel Dev Redist、Windows、Windows PowerShell、OpenSSH、NVIDIA Corporation的目录等。右侧有新建、编辑、浏览、删除、上移、下移和编辑文本等操作按钮。图片下方有一个错误提示框,显示"Windows找不到文件'mongodb\\bin\\mongod.exe'。请确定文件名是否正确后,再试一次。"这意味着用户试图运行MongoDB的mongod.exe程序时,系统找不到该文件。这可能是因为MongoDB的安装路径未正确添加到系统环境变量中,或者文件路径有误。
|
||||||
@@ -39,28 +41,21 @@ async def test_memory_system():
|
|||||||
[03-24 10:46:12] (ta的id:3229291803): [表情包:这张表情包显示了一只手正在做"点赞"的动作,通常表示赞同、喜欢或支持。这个表情包所表达的情感是积极的、赞同的或支持的。]
|
[03-24 10:46:12] (ta的id:3229291803): [表情包:这张表情包显示了一只手正在做"点赞"的动作,通常表示赞同、喜欢或支持。这个表情包所表达的情感是积极的、赞同的或支持的。]
|
||||||
[03-24 10:46:37] 星野風禾(ta的id:2890165435): 还能思考高达
|
[03-24 10:46:37] 星野風禾(ta的id:2890165435): 还能思考高达
|
||||||
[03-24 10:46:39] 星野風禾(ta的id:2890165435): 什么知识库
|
[03-24 10:46:39] 星野風禾(ta的id:2890165435): 什么知识库
|
||||||
[03-24 10:46:49] ❦幻凌慌てない(ta的id:2459587037): 为什么改了回复系数麦麦还是不怎么回复?大佬们''' # noqa: E501
|
[03-24 10:46:49] ❦幻凌慌てない(ta的id:2459587037): 为什么改了回复系数麦麦还是不怎么回复?大佬们""" # noqa: E501
|
||||||
|
|
||||||
|
|
||||||
# test_text = '''千石可乐:分不清AI的陪伴和人类的陪伴,是这样吗?'''
|
# test_text = '''千石可乐:分不清AI的陪伴和人类的陪伴,是这样吗?'''
|
||||||
print(f"开始测试记忆检索,测试文本: {test_text}\n")
|
print(f"开始测试记忆检索,测试文本: {test_text}\n")
|
||||||
memories = await hippocampus_manager.get_memory_from_text(
|
memories = await hippocampus_manager.get_memory_from_text(
|
||||||
text=test_text,
|
text=test_text, max_memory_num=3, max_memory_length=2, max_depth=3, fast_retrieval=False
|
||||||
max_memory_num=3,
|
|
||||||
max_memory_length=2,
|
|
||||||
max_depth=3,
|
|
||||||
fast_retrieval=False
|
|
||||||
)
|
)
|
||||||
|
|
||||||
await asyncio.sleep(1)
|
await asyncio.sleep(1)
|
||||||
|
|
||||||
print("检索到的记忆:")
|
print("检索到的记忆:")
|
||||||
for topic, memory_items in memories:
|
for topic, memory_items in memories:
|
||||||
print(f"主题: {topic}")
|
print(f"主题: {topic}")
|
||||||
print(f"- {memory_items}")
|
print(f"- {memory_items}")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# 测试记忆遗忘
|
# 测试记忆遗忘
|
||||||
# forget_start_time = time.time()
|
# forget_start_time = time.time()
|
||||||
# # print("开始测试记忆遗忘...")
|
# # print("开始测试记忆遗忘...")
|
||||||
@@ -80,6 +75,7 @@ async def test_memory_system():
|
|||||||
print(f"测试过程中出现错误: {e}")
|
print(f"测试过程中出现错误: {e}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
|
||||||
async def main():
|
async def main():
|
||||||
"""主函数"""
|
"""主函数"""
|
||||||
try:
|
try:
|
||||||
@@ -91,5 +87,6 @@ async def main():
|
|||||||
print(f"程序执行出错: {e}")
|
print(f"程序执行出错: {e}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
asyncio.run(main())
|
asyncio.run(main())
|
||||||
|
|||||||
@@ -1,24 +1,26 @@
|
|||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class MemoryConfig:
|
class MemoryConfig:
|
||||||
"""记忆系统配置类"""
|
"""记忆系统配置类"""
|
||||||
|
|
||||||
# 记忆构建相关配置
|
# 记忆构建相关配置
|
||||||
memory_build_distribution: List[float] # 记忆构建的时间分布参数
|
memory_build_distribution: List[float] # 记忆构建的时间分布参数
|
||||||
build_memory_sample_num: int # 每次构建记忆的样本数量
|
build_memory_sample_num: int # 每次构建记忆的样本数量
|
||||||
build_memory_sample_length: int # 每个样本的消息长度
|
build_memory_sample_length: int # 每个样本的消息长度
|
||||||
memory_compress_rate: float # 记忆压缩率
|
memory_compress_rate: float # 记忆压缩率
|
||||||
|
|
||||||
# 记忆遗忘相关配置
|
# 记忆遗忘相关配置
|
||||||
memory_forget_time: int # 记忆遗忘时间(小时)
|
memory_forget_time: int # 记忆遗忘时间(小时)
|
||||||
|
|
||||||
# 记忆过滤相关配置
|
# 记忆过滤相关配置
|
||||||
memory_ban_words: List[str] # 记忆过滤词列表
|
memory_ban_words: List[str] # 记忆过滤词列表
|
||||||
|
|
||||||
llm_topic_judge: str # 话题判断模型
|
llm_topic_judge: str # 话题判断模型
|
||||||
llm_summary_by_topic: str # 话题总结模型
|
llm_summary_by_topic: str # 话题总结模型
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_global_config(cls, global_config):
|
def from_global_config(cls, global_config):
|
||||||
"""从全局配置创建记忆系统配置"""
|
"""从全局配置创建记忆系统配置"""
|
||||||
@@ -30,5 +32,5 @@ class MemoryConfig:
|
|||||||
memory_forget_time=global_config.memory_forget_time,
|
memory_forget_time=global_config.memory_forget_time,
|
||||||
memory_ban_words=global_config.memory_ban_words,
|
memory_ban_words=global_config.memory_ban_words,
|
||||||
llm_topic_judge=global_config.llm_topic_judge,
|
llm_topic_judge=global_config.llm_topic_judge,
|
||||||
llm_summary_by_topic=global_config.llm_summary_by_topic
|
llm_summary_by_topic=global_config.llm_summary_by_topic,
|
||||||
)
|
)
|
||||||
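The change here is only a trailing comma, but the surrounding class is worth a note: MemoryConfig is a dataclass whose from_global_config classmethod copies just the fields the memory subsystem needs out of the global config object. A reduced sketch of the same factory pattern (field names are invented):

from dataclasses import dataclass
from typing import List

@dataclass
class SampleConfig:
    sample_num: int        # samples per build pass
    ban_words: List[str]   # words that block memory creation

    @classmethod
    def from_global_config(cls, global_config) -> "SampleConfig":
        # Pull only the fields this subsystem needs.
        return cls(
            sample_num=global_config.sample_num,
            ban_words=global_config.ban_words,
        )

class _FakeGlobal:  # stand-in for the real global_config, demo only
    sample_num = 5
    ban_words = ["广告"]

print(SampleConfig.from_global_config(_FakeGlobal()))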
|
|||||||
@@ -2,11 +2,12 @@ import numpy as np
|
|||||||
from scipy import stats
|
from scipy import stats
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
|
|
||||||
|
|
||||||
class DistributionVisualizer:
|
class DistributionVisualizer:
|
||||||
def __init__(self, mean=0, std=1, skewness=0, sample_size=10):
|
def __init__(self, mean=0, std=1, skewness=0, sample_size=10):
|
||||||
"""
|
"""
|
||||||
初始化分布可视化器
|
初始化分布可视化器
|
||||||
|
|
||||||
参数:
|
参数:
|
||||||
mean (float): 期望均值
|
mean (float): 期望均值
|
||||||
std (float): 标准差
|
std (float): 标准差
|
||||||
@@ -18,7 +19,7 @@ class DistributionVisualizer:
|
|||||||
self.skewness = skewness
|
self.skewness = skewness
|
||||||
self.sample_size = sample_size
|
self.sample_size = sample_size
|
||||||
self.samples = None
|
self.samples = None
|
||||||
|
|
||||||
def generate_samples(self):
|
def generate_samples(self):
|
||||||
"""生成具有指定参数的样本"""
|
"""生成具有指定参数的样本"""
|
||||||
if self.skewness == 0:
|
if self.skewness == 0:
|
||||||
@@ -26,37 +27,28 @@ class DistributionVisualizer:
|
|||||||
self.samples = np.random.normal(loc=self.mean, scale=self.std, size=self.sample_size)
|
self.samples = np.random.normal(loc=self.mean, scale=self.std, size=self.sample_size)
|
||||||
else:
|
else:
|
||||||
# 使用 scipy.stats 生成具有偏度的分布
|
# 使用 scipy.stats 生成具有偏度的分布
|
||||||
self.samples = stats.skewnorm.rvs(a=self.skewness,
|
self.samples = stats.skewnorm.rvs(a=self.skewness, loc=self.mean, scale=self.std, size=self.sample_size)
|
||||||
loc=self.mean,
|
|
||||||
scale=self.std,
|
|
||||||
size=self.sample_size)
|
|
||||||
|
|
||||||
def get_weighted_samples(self):
|
def get_weighted_samples(self):
|
||||||
"""获取加权后的样本数列"""
|
"""获取加权后的样本数列"""
|
||||||
if self.samples is None:
|
if self.samples is None:
|
||||||
self.generate_samples()
|
self.generate_samples()
|
||||||
# 将样本值乘以样本大小
|
# 将样本值乘以样本大小
|
||||||
return self.samples * self.sample_size
|
return self.samples * self.sample_size
|
||||||
|
|
||||||
def get_statistics(self):
|
def get_statistics(self):
|
||||||
"""获取分布的统计信息"""
|
"""获取分布的统计信息"""
|
||||||
if self.samples is None:
|
if self.samples is None:
|
||||||
self.generate_samples()
|
self.generate_samples()
|
||||||
|
|
||||||
return {
|
return {"均值": np.mean(self.samples), "标准差": np.std(self.samples), "实际偏度": stats.skew(self.samples)}
|
||||||
"均值": np.mean(self.samples),
|
|
||||||
"标准差": np.std(self.samples),
|
|
||||||
"实际偏度": stats.skew(self.samples)
|
|
||||||
}
|
|
||||||
|
|
||||||
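generate_samples above uses a plain normal distribution when skewness == 0 and otherwise draws from scipy.stats.skewnorm with the skewness as the shape parameter a. The same branch in isolation, with arbitrary parameters:

import numpy as np
from scipy import stats

def generate_samples(mean=0.0, std=1.0, skewness=0.0, size=10):
    if skewness == 0:
        # Symmetric case: ordinary normal distribution.
        return np.random.normal(loc=mean, scale=std, size=size)
    # Skewed case: skew-normal with shape parameter a = skewness.
    return stats.skewnorm.rvs(a=skewness, loc=mean, scale=std, size=size)

samples = generate_samples(mean=12, std=8, skewness=3, size=1000)
print(round(float(stats.skew(samples)), 2))  # positive, since skewness > 0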
class MemoryBuildScheduler:
|
class MemoryBuildScheduler:
|
||||||
def __init__(self,
|
def __init__(self, n_hours1, std_hours1, weight1, n_hours2, std_hours2, weight2, total_samples=50):
|
||||||
n_hours1, std_hours1, weight1,
|
|
||||||
n_hours2, std_hours2, weight2,
|
|
||||||
total_samples=50):
|
|
||||||
"""
|
"""
|
||||||
初始化记忆构建调度器
|
初始化记忆构建调度器
|
||||||
|
|
||||||
参数:
|
参数:
|
||||||
n_hours1 (float): 第一个分布的均值(距离现在的小时数)
|
n_hours1 (float): 第一个分布的均值(距离现在的小时数)
|
||||||
std_hours1 (float): 第一个分布的标准差(小时)
|
std_hours1 (float): 第一个分布的标准差(小时)
|
||||||
@@ -70,39 +62,31 @@ class MemoryBuildScheduler:
|
|||||||
total_weight = weight1 + weight2
|
total_weight = weight1 + weight2
|
||||||
self.weight1 = weight1 / total_weight
|
self.weight1 = weight1 / total_weight
|
||||||
self.weight2 = weight2 / total_weight
|
self.weight2 = weight2 / total_weight
|
||||||
|
|
||||||
self.n_hours1 = n_hours1
|
self.n_hours1 = n_hours1
|
||||||
self.std_hours1 = std_hours1
|
self.std_hours1 = std_hours1
|
||||||
self.n_hours2 = n_hours2
|
self.n_hours2 = n_hours2
|
||||||
self.std_hours2 = std_hours2
|
self.std_hours2 = std_hours2
|
||||||
self.total_samples = total_samples
|
self.total_samples = total_samples
|
||||||
self.base_time = datetime.now()
|
self.base_time = datetime.now()
|
||||||
|
|
||||||
def generate_time_samples(self):
|
def generate_time_samples(self):
|
||||||
"""生成混合分布的时间采样点"""
|
"""生成混合分布的时间采样点"""
|
||||||
# 根据权重计算每个分布的样本数
|
# 根据权重计算每个分布的样本数
|
||||||
samples1 = int(self.total_samples * self.weight1)
|
samples1 = int(self.total_samples * self.weight1)
|
||||||
samples2 = self.total_samples - samples1
|
samples2 = self.total_samples - samples1
|
||||||
|
|
||||||
# 生成两个正态分布的小时偏移
|
# 生成两个正态分布的小时偏移
|
||||||
hours_offset1 = np.random.normal(
|
hours_offset1 = np.random.normal(loc=self.n_hours1, scale=self.std_hours1, size=samples1)
|
||||||
loc=self.n_hours1,
|
|
||||||
scale=self.std_hours1,
|
hours_offset2 = np.random.normal(loc=self.n_hours2, scale=self.std_hours2, size=samples2)
|
||||||
size=samples1
|
|
||||||
)
|
|
||||||
|
|
||||||
hours_offset2 = np.random.normal(
|
|
||||||
loc=self.n_hours2,
|
|
||||||
scale=self.std_hours2,
|
|
||||||
size=samples2
|
|
||||||
)
|
|
||||||
|
|
||||||
# 合并两个分布的偏移
|
# 合并两个分布的偏移
|
||||||
hours_offset = np.concatenate([hours_offset1, hours_offset2])
|
hours_offset = np.concatenate([hours_offset1, hours_offset2])
|
||||||
|
|
||||||
# 将偏移转换为实际时间戳(使用绝对值确保时间点在过去)
|
# 将偏移转换为实际时间戳(使用绝对值确保时间点在过去)
|
||||||
timestamps = [self.base_time - timedelta(hours=abs(offset)) for offset in hours_offset]
|
timestamps = [self.base_time - timedelta(hours=abs(offset)) for offset in hours_offset]
|
||||||
|
|
||||||
# 按时间排序(从最早到最近)
|
# 按时间排序(从最早到最近)
|
||||||
return sorted(timestamps)
|
return sorted(timestamps)
|
||||||
|
|
||||||
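generate_time_samples above mixes two normal distributions of hour offsets according to their weights, converts each offset into a timestamp in the past, and returns them sorted. The same idea reduced to one function (the 12h/36h parameters mirror the example at the bottom of this file, but any values work):

from datetime import datetime, timedelta
import numpy as np

def bimodal_timestamps(mean1, std1, w1, mean2, std2, w2, total=50):
    n1 = int(total * w1 / (w1 + w2))   # samples from the first peak
    n2 = total - n1                    # the rest from the second peak
    offsets = np.concatenate([
        np.random.normal(mean1, std1, n1),
        np.random.normal(mean2, std2, n2),
    ])
    now = datetime.now()
    # abs() keeps every sampled point strictly in the past.
    return sorted(now - timedelta(hours=abs(h)) for h in offsets)

ts = bimodal_timestamps(12, 8, 0.7, 36, 24, 0.3)
print(len(ts), ts[0] <= ts[-1])  # 50 True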
@@ -111,54 +95,56 @@ class MemoryBuildScheduler:
|
|||||||
timestamps = self.generate_time_samples()
|
timestamps = self.generate_time_samples()
|
||||||
return [int(t.timestamp()) for t in timestamps]
|
return [int(t.timestamp()) for t in timestamps]
|
||||||
|
|
||||||
|
|
||||||
def print_time_samples(timestamps, show_distribution=True):
|
def print_time_samples(timestamps, show_distribution=True):
|
||||||
"""打印时间样本和分布信息"""
|
"""打印时间样本和分布信息"""
|
||||||
print(f"\n生成的{len(timestamps)}个时间点分布:")
|
print(f"\n生成的{len(timestamps)}个时间点分布:")
|
||||||
print("序号".ljust(5), "时间戳".ljust(25), "距现在(小时)")
|
print("序号".ljust(5), "时间戳".ljust(25), "距现在(小时)")
|
||||||
print("-" * 50)
|
print("-" * 50)
|
||||||
|
|
||||||
now = datetime.now()
|
now = datetime.now()
|
||||||
time_diffs = []
|
time_diffs = []
|
||||||
|
|
||||||
for i, timestamp in enumerate(timestamps, 1):
|
for i, timestamp in enumerate(timestamps, 1):
|
||||||
hours_diff = (now - timestamp).total_seconds() / 3600
|
hours_diff = (now - timestamp).total_seconds() / 3600
|
||||||
time_diffs.append(hours_diff)
|
time_diffs.append(hours_diff)
|
||||||
print(f"{str(i).ljust(5)} {timestamp.strftime('%Y-%m-%d %H:%M:%S').ljust(25)} {hours_diff:.2f}")
|
print(f"{str(i).ljust(5)} {timestamp.strftime('%Y-%m-%d %H:%M:%S').ljust(25)} {hours_diff:.2f}")
|
||||||
|
|
||||||
# 打印统计信息
|
# 打印统计信息
|
||||||
print("\n统计信息:")
|
print("\n统计信息:")
|
||||||
print(f"平均时间偏移:{np.mean(time_diffs):.2f}小时")
|
print(f"平均时间偏移:{np.mean(time_diffs):.2f}小时")
|
||||||
print(f"标准差:{np.std(time_diffs):.2f}小时")
|
print(f"标准差:{np.std(time_diffs):.2f}小时")
|
||||||
print(f"最早时间:{min(timestamps).strftime('%Y-%m-%d %H:%M:%S')} ({max(time_diffs):.2f}小时前)")
|
print(f"最早时间:{min(timestamps).strftime('%Y-%m-%d %H:%M:%S')} ({max(time_diffs):.2f}小时前)")
|
||||||
print(f"最近时间:{max(timestamps).strftime('%Y-%m-%d %H:%M:%S')} ({min(time_diffs):.2f}小时前)")
|
print(f"最近时间:{max(timestamps).strftime('%Y-%m-%d %H:%M:%S')} ({min(time_diffs):.2f}小时前)")
|
||||||
|
|
||||||
if show_distribution:
|
if show_distribution:
|
||||||
# 计算时间分布的直方图
|
# 计算时间分布的直方图
|
||||||
hist, bins = np.histogram(time_diffs, bins=40)
|
hist, bins = np.histogram(time_diffs, bins=40)
|
||||||
print("\n时间分布(每个*代表一个时间点):")
|
print("\n时间分布(每个*代表一个时间点):")
|
||||||
for i in range(len(hist)):
|
for i in range(len(hist)):
|
||||||
if hist[i] > 0:
|
if hist[i] > 0:
|
||||||
print(f"{bins[i]:6.1f}-{bins[i+1]:6.1f}小时: {'*' * int(hist[i])}")
|
print(f"{bins[i]:6.1f}-{bins[i + 1]:6.1f}小时: {'*' * int(hist[i])}")
|
||||||
|
|
||||||
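Only the spacing in bins[i + 1] changes above, but the line is a small trick worth isolating: np.histogram buckets the hour offsets and each non-empty bucket is printed as a row of *. A standalone version of that ASCII histogram:

import numpy as np

def ascii_histogram(values, bins=10):
    hist, edges = np.histogram(values, bins=bins)
    for count, lo, hi in zip(hist, edges[:-1], edges[1:]):
        if count > 0:
            print(f"{lo:6.1f}-{hi:6.1f}: {'*' * int(count)}")

ascii_histogram(np.random.normal(12, 8, 50), bins=8)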
|
|
||||||
# 使用示例
|
# 使用示例
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# 创建一个双峰分布的记忆调度器
|
# 创建一个双峰分布的记忆调度器
|
||||||
scheduler = MemoryBuildScheduler(
|
scheduler = MemoryBuildScheduler(
|
||||||
n_hours1=12, # 第一个分布均值(12小时前)
|
n_hours1=12, # 第一个分布均值(12小时前)
|
||||||
std_hours1=8, # 第一个分布标准差
|
std_hours1=8, # 第一个分布标准差
|
||||||
weight1=0.7, # 第一个分布权重 70%
|
weight1=0.7, # 第一个分布权重 70%
|
||||||
n_hours2=36, # 第二个分布均值(36小时前)
|
n_hours2=36, # 第二个分布均值(36小时前)
|
||||||
std_hours2=24, # 第二个分布标准差
|
std_hours2=24, # 第二个分布标准差
|
||||||
weight2=0.3, # 第二个分布权重 30%
|
weight2=0.3, # 第二个分布权重 30%
|
||||||
total_samples=50 # 总共生成50个时间点
|
total_samples=50, # 总共生成50个时间点
|
||||||
)
|
)
|
||||||
|
|
||||||
# 生成时间分布
|
# 生成时间分布
|
||||||
timestamps = scheduler.generate_time_samples()
|
timestamps = scheduler.generate_time_samples()
|
||||||
|
|
||||||
# 打印结果,包含分布可视化
|
# 打印结果,包含分布可视化
|
||||||
print_time_samples(timestamps, show_distribution=True)
|
print_time_samples(timestamps, show_distribution=True)
|
||||||
|
|
||||||
# 打印时间戳数组
|
# 打印时间戳数组
|
||||||
timestamp_array = scheduler.get_timestamp_array()
|
timestamp_array = scheduler.get_timestamp_array()
|
||||||
print("\n时间戳数组(Unix时间戳):")
|
print("\n时间戳数组(Unix时间戳):")
|
||||||
@@ -167,4 +153,4 @@ if __name__ == "__main__":
|
|||||||
if i > 0:
|
if i > 0:
|
||||||
print(", ", end="")
|
print(", ", end="")
|
||||||
print(ts, end="")
|
print(ts, end="")
|
||||||
print("]")
|
print("]")
|
||||||
|
|||||||
@@ -54,9 +54,7 @@ class TestLiveAPI(unittest.IsolatedAsyncioTestCase):
|
|||||||
# 准备测试消息
|
# 准备测试消息
|
||||||
user_info = UserInfo(user_id=12345678, user_nickname="测试用户", platform="qq")
|
user_info = UserInfo(user_id=12345678, user_nickname="测试用户", platform="qq")
|
||||||
group_info = GroupInfo(group_id=12345678, group_name="测试群", platform="qq")
|
group_info = GroupInfo(group_id=12345678, group_name="测试群", platform="qq")
|
||||||
format_info = FormatInfo(
|
format_info = FormatInfo(content_format=["text"], accept_format=["text", "emoji", "reply"])
|
||||||
content_format=["text"], accept_format=["text", "emoji", "reply"]
|
|
||||||
)
|
|
||||||
template_info = None
|
template_info = None
|
||||||
message_info = BaseMessageInfo(
|
message_info = BaseMessageInfo(
|
||||||
platform="qq",
|
platform="qq",
|
||||||
|
|||||||
@@ -35,6 +35,7 @@ else:
|
|||||||
print(f"未找到环境变量文件: {env_path}")
|
print(f"未找到环境变量文件: {env_path}")
|
||||||
print("将使用默认配置")
|
print("将使用默认配置")
|
||||||
|
|
||||||
|
|
||||||
class ChatBasedPersonalityEvaluator:
|
class ChatBasedPersonalityEvaluator:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.personality_traits = {"开放性": 0, "严谨性": 0, "外向性": 0, "宜人性": 0, "神经质": 0}
|
self.personality_traits = {"开放性": 0, "严谨性": 0, "外向性": 0, "宜人性": 0, "神经质": 0}
|
||||||
@@ -50,16 +51,14 @@ class ChatBasedPersonalityEvaluator:
|
|||||||
continue
|
continue
|
||||||
scene_keys = list(scenes.keys())
|
scene_keys = list(scenes.keys())
|
||||||
selected_scenes = random.sample(scene_keys, min(3, len(scene_keys)))
|
selected_scenes = random.sample(scene_keys, min(3, len(scene_keys)))
|
||||||
|
|
||||||
for scene_key in selected_scenes:
|
for scene_key in selected_scenes:
|
||||||
scene = scenes[scene_key]
|
scene = scenes[scene_key]
|
||||||
other_traits = [t for t in PERSONALITY_SCENES if t != trait]
|
other_traits = [t for t in PERSONALITY_SCENES if t != trait]
|
||||||
secondary_trait = random.choice(other_traits)
|
secondary_trait = random.choice(other_traits)
|
||||||
self.scenarios.append({
|
self.scenarios.append(
|
||||||
"场景": scene["scenario"],
|
{"场景": scene["scenario"], "评估维度": [trait, secondary_trait], "场景编号": scene_key}
|
||||||
"评估维度": [trait, secondary_trait],
|
)
|
||||||
"场景编号": scene_key
|
|
||||||
})
|
|
||||||
|
|
||||||
def analyze_chat_context(self, messages: List[Dict]) -> str:
|
def analyze_chat_context(self, messages: List[Dict]) -> str:
|
||||||
"""
|
"""
|
||||||
@@ -67,20 +66,21 @@ class ChatBasedPersonalityEvaluator:
|
|||||||
"""
|
"""
|
||||||
context = ""
|
context = ""
|
||||||
for msg in messages:
|
for msg in messages:
|
||||||
nickname = msg.get('user_info', {}).get('user_nickname', '未知用户')
|
nickname = msg.get("user_info", {}).get("user_nickname", "未知用户")
|
||||||
content = msg.get('processed_plain_text', msg.get('detailed_plain_text', ''))
|
content = msg.get("processed_plain_text", msg.get("detailed_plain_text", ""))
|
||||||
if content:
|
if content:
|
||||||
context += f"{nickname}: {content}\n"
|
context += f"{nickname}: {content}\n"
|
||||||
return context
|
return context
|
||||||
|
|
||||||
def evaluate_chat_response(
|
def evaluate_chat_response(
|
||||||
self, user_nickname: str, chat_context: str, dimensions: List[str] = None) -> Dict[str, float]:
|
self, user_nickname: str, chat_context: str, dimensions: List[str] = None
|
||||||
|
) -> Dict[str, float]:
|
||||||
"""
|
"""
|
||||||
评估聊天内容在各个人格维度上的得分
|
评估聊天内容在各个人格维度上的得分
|
||||||
"""
|
"""
|
||||||
# 使用所有维度进行评估
|
# 使用所有维度进行评估
|
||||||
dimensions = list(self.personality_traits.keys())
|
dimensions = list(self.personality_traits.keys())
|
||||||
|
|
||||||
dimension_descriptions = []
|
dimension_descriptions = []
|
||||||
for dim in dimensions:
|
for dim in dimensions:
|
||||||
desc = FACTOR_DESCRIPTIONS.get(dim, "")
|
desc = FACTOR_DESCRIPTIONS.get(dim, "")
|
||||||
@@ -136,18 +136,19 @@ class ChatBasedPersonalityEvaluator:
|
|||||||
def evaluate_user_personality(self, qq_id: str, num_samples: int = 10, context_length: int = 5) -> Dict:
|
def evaluate_user_personality(self, qq_id: str, num_samples: int = 10, context_length: int = 5) -> Dict:
|
||||||
"""
|
"""
|
||||||
基于用户的聊天记录评估人格特征
|
基于用户的聊天记录评估人格特征
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
qq_id (str): 用户QQ号
|
qq_id (str): 用户QQ号
|
||||||
num_samples (int): 要分析的聊天片段数量
|
num_samples (int): 要分析的聊天片段数量
|
||||||
context_length (int): 每个聊天片段的上下文长度
|
context_length (int): 每个聊天片段的上下文长度
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Dict: 评估结果
|
Dict: 评估结果
|
||||||
"""
|
"""
|
||||||
# 获取用户的随机消息及其上下文
|
# 获取用户的随机消息及其上下文
|
||||||
chat_contexts, user_nickname = self.message_analyzer.get_user_random_contexts(
|
chat_contexts, user_nickname = self.message_analyzer.get_user_random_contexts(
|
||||||
qq_id, num_messages=num_samples, context_length=context_length)
|
qq_id, num_messages=num_samples, context_length=context_length
|
||||||
|
)
|
||||||
if not chat_contexts:
|
if not chat_contexts:
|
||||||
return {"error": f"没有找到QQ号 {qq_id} 的消息记录"}
|
return {"error": f"没有找到QQ号 {qq_id} 的消息记录"}
|
||||||
|
|
||||||
@@ -155,7 +156,7 @@ class ChatBasedPersonalityEvaluator:
|
|||||||
final_scores = defaultdict(float)
|
final_scores = defaultdict(float)
|
||||||
dimension_counts = defaultdict(int)
|
dimension_counts = defaultdict(int)
|
||||||
chat_samples = []
|
chat_samples = []
|
||||||
|
|
||||||
# 清空历史记录
|
# 清空历史记录
|
||||||
self.trait_scores_history.clear()
|
self.trait_scores_history.clear()
|
||||||
|
|
||||||
@@ -163,13 +164,11 @@ class ChatBasedPersonalityEvaluator:
|
|||||||
for chat_context in chat_contexts:
|
for chat_context in chat_contexts:
|
||||||
# 评估这段聊天内容的所有维度
|
# 评估这段聊天内容的所有维度
|
||||||
scores = self.evaluate_chat_response(user_nickname, chat_context)
|
scores = self.evaluate_chat_response(user_nickname, chat_context)
|
||||||
|
|
||||||
# 记录样本
|
# 记录样本
|
||||||
chat_samples.append({
|
chat_samples.append(
|
||||||
"聊天内容": chat_context,
|
{"聊天内容": chat_context, "评估维度": list(self.personality_traits.keys()), "评分": scores}
|
||||||
"评估维度": list(self.personality_traits.keys()),
|
)
|
||||||
"评分": scores
|
|
||||||
})
|
|
||||||
|
|
||||||
# 更新总分和历史记录
|
# 更新总分和历史记录
|
||||||
for dimension, score in scores.items():
|
for dimension, score in scores.items():
|
||||||
@@ -196,7 +195,7 @@ class ChatBasedPersonalityEvaluator:
|
|||||||
"人格特征评分": average_scores,
|
"人格特征评分": average_scores,
|
||||||
"维度评估次数": dict(dimension_counts),
|
"维度评估次数": dict(dimension_counts),
|
||||||
"详细样本": chat_samples,
|
"详细样本": chat_samples,
|
||||||
"特质得分历史": {k: v for k, v in self.trait_scores_history.items()}
|
"特质得分历史": {k: v for k, v in self.trait_scores_history.items()},
|
||||||
}
|
}
|
||||||
|
|
||||||
# 保存结果
|
# 保存结果
|
||||||
@@ -215,40 +214,41 @@ class ChatBasedPersonalityEvaluator:
|
|||||||
chinese_fonts = []
|
chinese_fonts = []
|
||||||
for f in fm.fontManager.ttflist:
|
for f in fm.fontManager.ttflist:
|
||||||
try:
|
try:
|
||||||
if '简' in f.name or 'SC' in f.name or '黑' in f.name or '宋' in f.name or '微软' in f.name:
|
if "简" in f.name or "SC" in f.name or "黑" in f.name or "宋" in f.name or "微软" in f.name:
|
||||||
chinese_fonts.append(f.name)
|
chinese_fonts.append(f.name)
|
||||||
except Exception:
|
except Exception:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if chinese_fonts:
|
if chinese_fonts:
|
||||||
plt.rcParams['font.sans-serif'] = chinese_fonts + ['SimHei', 'Microsoft YaHei', 'Arial Unicode MS']
|
plt.rcParams["font.sans-serif"] = chinese_fonts + ["SimHei", "Microsoft YaHei", "Arial Unicode MS"]
|
||||||
else:
|
else:
|
||||||
# 如果没有找到中文字体,使用默认字体,并将中文昵称转换为拼音或英文
|
# 如果没有找到中文字体,使用默认字体,并将中文昵称转换为拼音或英文
|
||||||
try:
|
try:
|
||||||
from pypinyin import lazy_pinyin
|
from pypinyin import lazy_pinyin
|
||||||
user_nickname = ''.join(lazy_pinyin(user_nickname))
|
|
||||||
|
user_nickname = "".join(lazy_pinyin(user_nickname))
|
||||||
except ImportError:
|
except ImportError:
|
||||||
user_nickname = "User" # 如果无法转换为拼音,使用默认英文
|
user_nickname = "User" # 如果无法转换为拼音,使用默认英文
|
||||||
|
|
||||||
plt.rcParams['axes.unicode_minus'] = False # 解决负号显示问题
|
plt.rcParams["axes.unicode_minus"] = False # 解决负号显示问题
|
||||||
|
|
||||||
plt.figure(figsize=(12, 6))
|
plt.figure(figsize=(12, 6))
|
||||||
plt.style.use('bmh') # 使用内置的bmh样式,它有类似seaborn的美观效果
|
plt.style.use("bmh") # 使用内置的bmh样式,它有类似seaborn的美观效果
|
||||||
|
|
||||||
colors = {
|
colors = {
|
||||||
"开放性": "#FF9999",
|
"开放性": "#FF9999",
|
||||||
"严谨性": "#66B2FF",
|
"严谨性": "#66B2FF",
|
||||||
"外向性": "#99FF99",
|
"外向性": "#99FF99",
|
||||||
"宜人性": "#FFCC99",
|
"宜人性": "#FFCC99",
|
||||||
"神经质": "#FF99CC"
|
"神经质": "#FF99CC",
|
||||||
}
|
}
|
||||||
|
|
||||||
# 计算每个维度在每个时间点的累计平均分
|
# 计算每个维度在每个时间点的累计平均分
|
||||||
cumulative_averages = {}
|
cumulative_averages = {}
|
||||||
for trait, scores in self.trait_scores_history.items():
|
for trait, scores in self.trait_scores_history.items():
|
||||||
if not scores:
|
if not scores:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
averages = []
|
averages = []
|
||||||
total = 0
|
total = 0
|
||||||
valid_count = 0
|
valid_count = 0
|
||||||
@@ -264,25 +264,25 @@ class ChatBasedPersonalityEvaluator:
|
|||||||
averages.append(averages[-1])
|
averages.append(averages[-1])
|
||||||
else:
|
else:
|
||||||
continue # 跳过无效分数
|
continue # 跳过无效分数
|
||||||
|
|
||||||
if averages: # 只有在有有效分数的情况下才添加到累计平均中
|
if averages: # 只有在有有效分数的情况下才添加到累计平均中
|
||||||
cumulative_averages[trait] = averages
|
cumulative_averages[trait] = averages
|
||||||
|
|
||||||
# 绘制每个维度的累计平均分变化趋势
|
# 绘制每个维度的累计平均分变化趋势
|
||||||
for trait, averages in cumulative_averages.items():
|
for trait, averages in cumulative_averages.items():
|
||||||
x = range(1, len(averages) + 1)
|
x = range(1, len(averages) + 1)
|
||||||
plt.plot(x, averages, 'o-', label=trait, color=colors.get(trait), linewidth=2, markersize=8)
|
plt.plot(x, averages, "o-", label=trait, color=colors.get(trait), linewidth=2, markersize=8)
|
||||||
|
|
||||||
# 添加趋势线
|
# 添加趋势线
|
||||||
z = np.polyfit(x, averages, 1)
|
z = np.polyfit(x, averages, 1)
|
||||||
p = np.poly1d(z)
|
p = np.poly1d(z)
|
||||||
plt.plot(x, p(x), '--', color=colors.get(trait), alpha=0.5)
|
plt.plot(x, p(x), "--", color=colors.get(trait), alpha=0.5)
|
||||||
|
|
||||||
plt.title(f"{user_nickname} 的人格特质累计平均分变化趋势", fontsize=14, pad=20)
|
plt.title(f"{user_nickname} 的人格特质累计平均分变化趋势", fontsize=14, pad=20)
|
||||||
plt.xlabel("评估次数", fontsize=12)
|
plt.xlabel("评估次数", fontsize=12)
|
||||||
plt.ylabel("累计平均分", fontsize=12)
|
plt.ylabel("累计平均分", fontsize=12)
|
||||||
plt.grid(True, linestyle='--', alpha=0.7)
|
plt.grid(True, linestyle="--", alpha=0.7)
|
||||||
plt.legend(loc='center left', bbox_to_anchor=(1, 0.5))
|
plt.legend(loc="center left", bbox_to_anchor=(1, 0.5))
|
||||||
plt.ylim(0, 7)
|
plt.ylim(0, 7)
|
||||||
plt.tight_layout()
|
plt.tight_layout()
|
||||||
|
|
||||||
@@ -290,38 +290,39 @@ class ChatBasedPersonalityEvaluator:
|
|||||||
os.makedirs("results/plots", exist_ok=True)
|
os.makedirs("results/plots", exist_ok=True)
|
||||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||||
plot_file = f"results/plots/personality_trend_{qq_id}_{timestamp}.png"
|
plot_file = f"results/plots/personality_trend_{qq_id}_{timestamp}.png"
|
||||||
plt.savefig(plot_file, dpi=300, bbox_inches='tight')
|
plt.savefig(plot_file, dpi=300, bbox_inches="tight")
|
||||||
plt.close()
|
plt.close()
|
||||||
|
|
||||||
|
|
||||||
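The plotting code above tracks, for each trait, the running (cumulative) average of its scores across evaluations and overlays a least-squares trend line via np.polyfit; zero scores are treated as "no data" and skipped or carried forward. A minimal sketch of the numeric part only, without the matplotlib styling and without the zero-score handling (the score list is invented):

import numpy as np

scores = [4, 5, 3, 6, 5]  # hypothetical per-evaluation scores for one trait
cumulative = np.cumsum(scores) / np.arange(1, len(scores) + 1)

# Slope of the least-squares line through the cumulative averages:
slope = np.polyfit(range(1, len(cumulative) + 1), cumulative, 1)[0]
print([round(float(c), 2) for c in cumulative], round(float(slope), 2))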
def analyze_user_personality(qq_id: str, num_samples: int = 10, context_length: int = 5) -> str:
|
def analyze_user_personality(qq_id: str, num_samples: int = 10, context_length: int = 5) -> str:
|
||||||
"""
|
"""
|
||||||
分析用户人格特征的便捷函数
|
分析用户人格特征的便捷函数
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
qq_id (str): 用户QQ号
|
qq_id (str): 用户QQ号
|
||||||
num_samples (int): 要分析的聊天片段数量
|
num_samples (int): 要分析的聊天片段数量
|
||||||
context_length (int): 每个聊天片段的上下文长度
|
context_length (int): 每个聊天片段的上下文长度
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
str: 格式化的分析结果
|
str: 格式化的分析结果
|
||||||
"""
|
"""
|
||||||
evaluator = ChatBasedPersonalityEvaluator()
|
evaluator = ChatBasedPersonalityEvaluator()
|
||||||
result = evaluator.evaluate_user_personality(qq_id, num_samples, context_length)
|
result = evaluator.evaluate_user_personality(qq_id, num_samples, context_length)
|
||||||
|
|
||||||
if "error" in result:
|
if "error" in result:
|
||||||
return result["error"]
|
return result["error"]
|
||||||
|
|
||||||
# 格式化输出
|
# 格式化输出
|
||||||
output = f"QQ号 {qq_id} ({result['用户昵称']}) 的人格特征分析结果:\n"
|
output = f"QQ号 {qq_id} ({result['用户昵称']}) 的人格特征分析结果:\n"
|
||||||
output += "=" * 50 + "\n\n"
|
output += "=" * 50 + "\n\n"
|
||||||
|
|
||||||
output += "人格特征评分:\n"
|
output += "人格特征评分:\n"
|
||||||
for trait, score in result["人格特征评分"].items():
|
for trait, score in result["人格特征评分"].items():
|
||||||
if score == 0:
|
if score == 0:
|
||||||
output += f"{trait}: 数据不足,无法判断 (评估次数: {result['维度评估次数'].get(trait, 0)})\n"
|
output += f"{trait}: 数据不足,无法判断 (评估次数: {result['维度评估次数'].get(trait, 0)})\n"
|
||||||
else:
|
else:
|
||||||
output += f"{trait}: {score}/6 (评估次数: {result['维度评估次数'].get(trait, 0)})\n"
|
output += f"{trait}: {score}/6 (评估次数: {result['维度评估次数'].get(trait, 0)})\n"
|
||||||
|
|
||||||
# 添加变化趋势描述
|
# 添加变化趋势描述
|
||||||
if trait in result["特质得分历史"] and len(result["特质得分历史"][trait]) > 1:
|
if trait in result["特质得分历史"] and len(result["特质得分历史"][trait]) > 1:
|
||||||
scores = [s for s in result["特质得分历史"][trait] if s != 0] # 过滤掉无效分数
|
scores = [s for s in result["特质得分历史"][trait] if s != 0] # 过滤掉无效分数
|
||||||
@@ -334,13 +335,14 @@ def analyze_user_personality(qq_id: str, num_samples: int = 10, context_length:
|
|||||||
else:
|
else:
|
||||||
trend_desc = "呈下降趋势"
|
trend_desc = "呈下降趋势"
|
||||||
output += f" 变化趋势: {trend_desc} (斜率: {trend:.2f})\n"
|
output += f" 变化趋势: {trend_desc} (斜率: {trend:.2f})\n"
|
||||||
|
|
||||||
output += f"\n分析样本数量:{result['样本数量']}\n"
|
output += f"\n分析样本数量:{result['样本数量']}\n"
|
||||||
output += f"结果已保存至:results/personality_result_{qq_id}.json\n"
|
output += f"结果已保存至:results/personality_result_{qq_id}.json\n"
|
||||||
output += "变化趋势图已保存至:results/plots/目录\n"
|
output += "变化趋势图已保存至:results/plots/目录\n"
|
||||||
|
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# 测试代码
|
# 测试代码
|
||||||
# test_qq = "" # 替换为要测试的QQ号
|
# test_qq = "" # 替换为要测试的QQ号
|
||||||
|
|||||||
@@ -82,7 +82,6 @@ class PersonalityEvaluator_direct:
|
|||||||
|
|
||||||
dimensions_text = "\n".join(dimension_descriptions)
|
dimensions_text = "\n".join(dimension_descriptions)
|
||||||
|
|
||||||
|
|
||||||
prompt = f"""请根据以下场景和用户描述,评估用户在大五人格模型中的相关维度得分(1-6分)。
|
prompt = f"""请根据以下场景和用户描述,评估用户在大五人格模型中的相关维度得分(1-6分)。
|
||||||
|
|
||||||
场景描述:
|
场景描述:
|
||||||
|
|||||||
@@ -14,18 +14,19 @@ sys.path.append(root_path)
|
|||||||
|
|
||||||
from src.common.database import db # noqa: E402
|
from src.common.database import db # noqa: E402
|
||||||
|
|
||||||
|
|
||||||
class MessageAnalyzer:
|
class MessageAnalyzer:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.messages_collection = db["messages"]
|
self.messages_collection = db["messages"]
|
||||||
|
|
||||||
def get_message_context(self, message_id: int, context_length: int = 5) -> Optional[List[Dict]]:
|
def get_message_context(self, message_id: int, context_length: int = 5) -> Optional[List[Dict]]:
|
||||||
"""
|
"""
|
||||||
获取指定消息ID的上下文消息列表
|
获取指定消息ID的上下文消息列表
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
message_id (int): 消息ID
|
message_id (int): 消息ID
|
||||||
context_length (int): 上下文长度(单侧,总长度为 2*context_length + 1)
|
context_length (int): 上下文长度(单侧,总长度为 2*context_length + 1)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Optional[List[Dict]]: 消息列表,如果未找到则返回None
|
Optional[List[Dict]]: 消息列表,如果未找到则返回None
|
||||||
"""
|
"""
|
||||||
@@ -33,110 +34,110 @@ class MessageAnalyzer:
|
|||||||
target_message = self.messages_collection.find_one({"message_id": message_id})
|
target_message = self.messages_collection.find_one({"message_id": message_id})
|
||||||
if not target_message:
|
if not target_message:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# 获取该消息的stream_id
|
# 获取该消息的stream_id
|
||||||
stream_id = target_message.get('chat_info', {}).get('stream_id')
|
stream_id = target_message.get("chat_info", {}).get("stream_id")
|
||||||
if not stream_id:
|
if not stream_id:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# 获取同一stream_id的所有消息
|
# 获取同一stream_id的所有消息
|
||||||
stream_messages = list(self.messages_collection.find({
|
stream_messages = list(self.messages_collection.find({"chat_info.stream_id": stream_id}).sort("time", 1))
|
||||||
"chat_info.stream_id": stream_id
|
|
||||||
}).sort("time", 1))
|
|
||||||
|
|
||||||
# 找到目标消息在列表中的位置
|
# 找到目标消息在列表中的位置
|
||||||
target_index = None
|
target_index = None
|
||||||
for i, msg in enumerate(stream_messages):
|
for i, msg in enumerate(stream_messages):
|
||||||
if msg['message_id'] == message_id:
|
if msg["message_id"] == message_id:
|
||||||
target_index = i
|
target_index = i
|
||||||
break
|
break
|
||||||
|
|
||||||
if target_index is None:
|
if target_index is None:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# 获取目标消息前后的消息
|
# 获取目标消息前后的消息
|
||||||
start_index = max(0, target_index - context_length)
|
start_index = max(0, target_index - context_length)
|
||||||
end_index = min(len(stream_messages), target_index + context_length + 1)
|
end_index = min(len(stream_messages), target_index + context_length + 1)
|
||||||
|
|
||||||
return stream_messages[start_index:end_index]
|
return stream_messages[start_index:end_index]
|
||||||
|
|
||||||
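get_message_context above finds the target message, loads every message with the same stream_id sorted by time, locates the target's index, and slices context_length messages on each side. A condensed sketch of that window-slicing step against a pymongo collection (field names follow the code above; the usage line is a placeholder):

def context_window(messages, message_id: int, context_length: int = 5):
    target = messages.find_one({"message_id": message_id})
    if not target:
        return None
    stream_id = target.get("chat_info", {}).get("stream_id")
    if not stream_id:
        return None
    stream = list(messages.find({"chat_info.stream_id": stream_id}).sort("time", 1))
    idx = next((i for i, m in enumerate(stream) if m["message_id"] == message_id), None)
    if idx is None:
        return None
    # context_length messages on each side; Python slicing clamps at the boundaries.
    return stream[max(0, idx - context_length): idx + context_length + 1]

# e.g. (placeholder connection): context_window(db["messages"], message_id=123, context_length=3)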
def format_messages(self, messages: List[Dict], target_message_id: Optional[int] = None) -> str:
|
def format_messages(self, messages: List[Dict], target_message_id: Optional[int] = None) -> str:
|
||||||
"""
|
"""
|
||||||
格式化消息列表为可读字符串
|
格式化消息列表为可读字符串
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
messages (List[Dict]): 消息列表
|
messages (List[Dict]): 消息列表
|
||||||
target_message_id (Optional[int]): 目标消息ID,用于标记
|
target_message_id (Optional[int]): 目标消息ID,用于标记
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
str: 格式化的消息字符串
|
str: 格式化的消息字符串
|
||||||
"""
|
"""
|
||||||
if not messages:
|
if not messages:
|
||||||
return "没有消息记录"
|
return "没有消息记录"
|
||||||
|
|
||||||
reply = ""
|
reply = ""
|
||||||
for msg in messages:
|
for msg in messages:
|
||||||
# 消息时间
|
# 消息时间
|
||||||
msg_time = datetime.datetime.fromtimestamp(int(msg['time'])).strftime("%Y-%m-%d %H:%M:%S")
|
msg_time = datetime.datetime.fromtimestamp(int(msg["time"])).strftime("%Y-%m-%d %H:%M:%S")
|
||||||
|
|
||||||
# 获取消息内容
|
# 获取消息内容
|
||||||
message_text = msg.get('processed_plain_text', msg.get('detailed_plain_text', '无消息内容'))
|
message_text = msg.get("processed_plain_text", msg.get("detailed_plain_text", "无消息内容"))
|
||||||
nickname = msg.get('user_info', {}).get('user_nickname', '未知用户')
|
nickname = msg.get("user_info", {}).get("user_nickname", "未知用户")
|
||||||
|
|
||||||
# 标记当前消息
|
# 标记当前消息
|
||||||
is_target = "→ " if target_message_id and msg['message_id'] == target_message_id else " "
|
is_target = "→ " if target_message_id and msg["message_id"] == target_message_id else " "
|
||||||
|
|
||||||
reply += f"{is_target}[{msg_time}] {nickname}: {message_text}\n"
|
reply += f"{is_target}[{msg_time}] {nickname}: {message_text}\n"
|
||||||
|
|
||||||
if target_message_id and msg['message_id'] == target_message_id:
|
if target_message_id and msg["message_id"] == target_message_id:
|
||||||
reply += " " + "-" * 50 + "\n"
|
reply += " " + "-" * 50 + "\n"
|
||||||
|
|
||||||
return reply
|
return reply
|
||||||
|
|
||||||
def get_user_random_contexts(
|
def get_user_random_contexts(
|
||||||
self, qq_id: str, num_messages: int = 10, context_length: int = 5) -> tuple[List[str], str]: # noqa: E501
|
self, qq_id: str, num_messages: int = 10, context_length: int = 5
|
||||||
|
) -> tuple[List[str], str]: # noqa: E501
|
||||||
"""
|
"""
|
||||||
获取用户的随机消息及其上下文
|
获取用户的随机消息及其上下文
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
qq_id (str): QQ号
|
qq_id (str): QQ号
|
||||||
num_messages (int): 要获取的随机消息数量
|
num_messages (int): 要获取的随机消息数量
|
||||||
context_length (int): 每条消息的上下文长度(单侧)
|
context_length (int): 每条消息的上下文长度(单侧)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
tuple[List[str], str]: (每个消息上下文的格式化字符串列表, 用户昵称)
|
tuple[List[str], str]: (每个消息上下文的格式化字符串列表, 用户昵称)
|
||||||
"""
|
"""
|
||||||
if not qq_id:
|
if not qq_id:
|
||||||
return [], ""
|
return [], ""
|
||||||
|
|
||||||
# 获取用户所有消息
|
# 获取用户所有消息
|
||||||
all_messages = list(self.messages_collection.find({"user_info.user_id": int(qq_id)}))
|
all_messages = list(self.messages_collection.find({"user_info.user_id": int(qq_id)}))
|
||||||
if not all_messages:
|
if not all_messages:
|
||||||
return [], ""
|
return [], ""
|
||||||
|
|
||||||
# 获取用户昵称
|
# 获取用户昵称
|
||||||
user_nickname = all_messages[0].get('chat_info', {}).get('user_info', {}).get('user_nickname', '未知用户')
|
user_nickname = all_messages[0].get("chat_info", {}).get("user_info", {}).get("user_nickname", "未知用户")
|
||||||
|
|
||||||
# 随机选择指定数量的消息
|
# 随机选择指定数量的消息
|
||||||
selected_messages = random.sample(all_messages, min(num_messages, len(all_messages)))
|
selected_messages = random.sample(all_messages, min(num_messages, len(all_messages)))
|
||||||
# 按时间排序
|
# 按时间排序
|
||||||
selected_messages.sort(key=lambda x: int(x['time']))
|
selected_messages.sort(key=lambda x: int(x["time"]))
|
||||||
|
|
||||||
# 存储所有上下文消息
|
# 存储所有上下文消息
|
||||||
context_list = []
|
context_list = []
|
||||||
|
|
||||||
# 获取每条消息的上下文
|
# 获取每条消息的上下文
|
||||||
for msg in selected_messages:
|
for msg in selected_messages:
|
||||||
message_id = msg['message_id']
|
message_id = msg["message_id"]
|
||||||
|
|
||||||
# 获取消息上下文
|
# 获取消息上下文
|
||||||
context_messages = self.get_message_context(message_id, context_length)
|
context_messages = self.get_message_context(message_id, context_length)
|
||||||
if context_messages:
|
if context_messages:
|
||||||
formatted_context = self.format_messages(context_messages, message_id)
|
formatted_context = self.format_messages(context_messages, message_id)
|
||||||
context_list.append(formatted_context)
|
context_list.append(formatted_context)
|
||||||
|
|
||||||
return context_list, user_nickname
|
return context_list, user_nickname
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# 测试代码
|
# 测试代码
|
||||||
analyzer = MessageAnalyzer()
|
analyzer = MessageAnalyzer()
|
||||||
@@ -145,7 +146,7 @@ if __name__ == "__main__":
|
|||||||
print("-" * 50)
|
print("-" * 50)
|
||||||
# 获取5条消息,每条消息前后各3条上下文
|
# 获取5条消息,每条消息前后各3条上下文
|
||||||
contexts, nickname = analyzer.get_user_random_contexts(test_qq, num_messages=5, context_length=3)
|
contexts, nickname = analyzer.get_user_random_contexts(test_qq, num_messages=5, context_length=3)
|
||||||
|
|
||||||
print(f"用户昵称: {nickname}\n")
|
print(f"用户昵称: {nickname}\n")
|
||||||
# 打印每个上下文
|
# 打印每个上下文
|
||||||
for i, context in enumerate(contexts, 1):
|
for i, context in enumerate(contexts, 1):
|
||||||
|
|||||||
@@ -46,17 +46,15 @@ class LLMStatistics:
|
|||||||
"""记录在线时间"""
|
"""记录在线时间"""
|
||||||
current_time = datetime.now()
|
current_time = datetime.now()
|
||||||
# 检查5分钟内是否已有记录
|
# 检查5分钟内是否已有记录
|
||||||
recent_record = db.online_time.find_one({
|
recent_record = db.online_time.find_one({"timestamp": {"$gte": current_time - timedelta(minutes=5)}})
|
||||||
"timestamp": {
|
|
||||||
"$gte": current_time - timedelta(minutes=5)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
if not recent_record:
|
if not recent_record:
|
||||||
db.online_time.insert_one({
|
db.online_time.insert_one(
|
||||||
"timestamp": current_time,
|
{
|
||||||
"duration": 5 # 5分钟
|
"timestamp": current_time,
|
||||||
})
|
"duration": 5, # 5分钟
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
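The hunk above only reflows the insert, but the logic is a simple dedup guard: before writing a new 5-minute "online" record it checks whether one already exists within the last five minutes, so repeated calls inside the window collapse into a single document. The same guard in isolation (the collection is passed in; names are placeholders):

from datetime import datetime, timedelta

def record_online_time(online_time_collection):
    now = datetime.now()
    # Skip the insert if a record was already written in the last 5 minutes.
    recent = online_time_collection.find_one(
        {"timestamp": {"$gte": now - timedelta(minutes=5)}}
    )
    if not recent:
        online_time_collection.insert_one({"timestamp": now, "duration": 5})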
def _collect_statistics_for_period(self, start_time: datetime) -> Dict[str, Any]:
|
def _collect_statistics_for_period(self, start_time: datetime) -> Dict[str, Any]:
|
||||||
"""收集指定时间段的LLM请求统计数据
|
"""收集指定时间段的LLM请求统计数据
|
||||||
|
|||||||
@@ -41,10 +41,9 @@ class WillingManager:
|
|||||||
|
|
||||||
interested_rate = interested_rate * config.response_interested_rate_amplifier
|
interested_rate = interested_rate * config.response_interested_rate_amplifier
|
||||||
|
|
||||||
|
|
||||||
if interested_rate > 0.4:
|
if interested_rate > 0.4:
|
||||||
current_willing += interested_rate - 0.3
|
current_willing += interested_rate - 0.3
|
||||||
|
|
||||||
if is_mentioned_bot and current_willing < 1.0:
|
if is_mentioned_bot and current_willing < 1.0:
|
||||||
current_willing += 1
|
current_willing += 1
|
||||||
elif is_mentioned_bot:
|
elif is_mentioned_bot:
|
||||||
|
|||||||
@@ -5,38 +5,41 @@ from src.plugins.models.utils_model import LLM_request
|
|||||||
from src.plugins.config.config import global_config
|
from src.plugins.config.config import global_config
|
||||||
from src.plugins.schedule.schedule_generator import bot_schedule
|
from src.plugins.schedule.schedule_generator import bot_schedule
|
||||||
import asyncio
|
import asyncio
|
||||||
from src.common.logger import get_module_logger, LogConfig, HEARTFLOW_STYLE_CONFIG # noqa: E402
|
from src.common.logger import get_module_logger, LogConfig, HEARTFLOW_STYLE_CONFIG # noqa: E402
|
||||||
import time
|
import time
|
||||||
|
|
||||||
heartflow_config = LogConfig(
|
heartflow_config = LogConfig(
|
||||||
# 使用海马体专用样式
|
# 使用海马体专用样式
|
||||||
console_format=HEARTFLOW_STYLE_CONFIG["console_format"],
|
console_format=HEARTFLOW_STYLE_CONFIG["console_format"],
|
||||||
file_format=HEARTFLOW_STYLE_CONFIG["file_format"],
|
file_format=HEARTFLOW_STYLE_CONFIG["file_format"],
|
||||||
)
|
)
|
||||||
logger = get_module_logger("heartflow", config=heartflow_config)
|
logger = get_module_logger("heartflow", config=heartflow_config)
|
||||||
|
|
||||||
|
|
||||||
class CuttentState:
|
class CuttentState:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.willing = 0
|
self.willing = 0
|
||||||
self.current_state_info = ""
|
self.current_state_info = ""
|
||||||
|
|
||||||
self.mood_manager = MoodManager()
|
self.mood_manager = MoodManager()
|
||||||
self.mood = self.mood_manager.get_prompt()
|
self.mood = self.mood_manager.get_prompt()
|
||||||
|
|
||||||
def update_current_state_info(self):
|
def update_current_state_info(self):
|
||||||
self.current_state_info = self.mood_manager.get_current_mood()
|
self.current_state_info = self.mood_manager.get_current_mood()
|
||||||
|
|
||||||
|
|
||||||
class Heartflow:
|
class Heartflow:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.current_mind = "你什么也没想"
|
self.current_mind = "你什么也没想"
|
||||||
self.past_mind = []
|
self.past_mind = []
|
||||||
self.current_state : CuttentState = CuttentState()
|
self.current_state: CuttentState = CuttentState()
|
||||||
self.llm_model = LLM_request(
|
self.llm_model = LLM_request(
|
||||||
model=global_config.llm_heartflow, temperature=0.6, max_tokens=1000, request_type="heart_flow")
|
model=global_config.llm_heartflow, temperature=0.6, max_tokens=1000, request_type="heart_flow"
|
||||||
|
)
|
||||||
|
|
||||||
self._subheartflows = {}
|
self._subheartflows = {}
|
||||||
self.active_subheartflows_nums = 0
|
self.active_subheartflows_nums = 0
|
||||||
|
|
||||||
self.personality_info = " ".join(global_config.PROMPT_PERSONALITY)
|
self.personality_info = " ".join(global_config.PROMPT_PERSONALITY)
|
||||||
|
|
||||||
async def _cleanup_inactive_subheartflows(self):
|
async def _cleanup_inactive_subheartflows(self):
|
||||||
@@ -44,46 +47,46 @@ class Heartflow:
|
|||||||
while True:
|
while True:
|
||||||
current_time = time.time()
|
current_time = time.time()
|
||||||
inactive_subheartflows = []
|
inactive_subheartflows = []
|
||||||
|
|
||||||
# 检查所有子心流
|
# 检查所有子心流
|
||||||
for subheartflow_id, subheartflow in self._subheartflows.items():
|
for subheartflow_id, subheartflow in self._subheartflows.items():
|
||||||
if current_time - subheartflow.last_active_time > 600: # 10分钟 = 600秒
|
if current_time - subheartflow.last_active_time > 600: # 10分钟 = 600秒
|
||||||
inactive_subheartflows.append(subheartflow_id)
|
inactive_subheartflows.append(subheartflow_id)
|
||||||
logger.info(f"发现不活跃的子心流: {subheartflow_id}")
|
logger.info(f"发现不活跃的子心流: {subheartflow_id}")
|
||||||
|
|
||||||
# 清理不活跃的子心流
|
# 清理不活跃的子心流
|
||||||
for subheartflow_id in inactive_subheartflows:
|
for subheartflow_id in inactive_subheartflows:
|
||||||
del self._subheartflows[subheartflow_id]
|
del self._subheartflows[subheartflow_id]
|
||||||
logger.info(f"已清理不活跃的子心流: {subheartflow_id}")
|
logger.info(f"已清理不活跃的子心流: {subheartflow_id}")
|
||||||
|
|
||||||
await asyncio.sleep(30)  # 每30秒检查一次
|
await asyncio.sleep(30)  # 每30秒检查一次
|
||||||
|
|
||||||
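_cleanup_inactive_subheartflows above is a long-running asyncio task: on each pass it collects the IDs whose last_active_time is older than 10 minutes and deletes them, collecting first so the dict is not mutated while being iterated, then sleeps and repeats. A stripped-down, runnable version of the same loop (timings shortened so the demo terminates):

import asyncio
import time

flows = {"chat_a": {"last_active_time": time.time() - 999},
         "chat_b": {"last_active_time": time.time()}}

async def cleanup(ttl: float = 600, interval: float = 0.1, rounds: int = 2):
    for _ in range(rounds):  # the real loop runs forever
        now = time.time()
        stale = [k for k, v in flows.items() if now - v["last_active_time"] > ttl]
        for k in stale:      # delete after iterating, not during
            del flows[k]
        await asyncio.sleep(interval)

asyncio.run(cleanup())
print(flows)  # only "chat_b" remains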
async def heartflow_start_working(self):
|
async def heartflow_start_working(self):
|
||||||
# 启动清理任务
|
# 启动清理任务
|
||||||
asyncio.create_task(self._cleanup_inactive_subheartflows())
|
asyncio.create_task(self._cleanup_inactive_subheartflows())
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
# 检查是否存在子心流
|
# 检查是否存在子心流
|
||||||
if not self._subheartflows:
|
if not self._subheartflows:
|
||||||
logger.info("当前没有子心流,等待新的子心流创建...")
|
logger.info("当前没有子心流,等待新的子心流创建...")
|
||||||
await asyncio.sleep(60) # 每分钟检查一次是否有新的子心流
|
await asyncio.sleep(60) # 每分钟检查一次是否有新的子心流
|
||||||
continue
|
continue
|
||||||
|
|
||||||
await self.do_a_thinking()
|
await self.do_a_thinking()
|
||||||
await asyncio.sleep(300) # 5分钟思考一次
|
await asyncio.sleep(300) # 5分钟思考一次
|
||||||
|
|
||||||
async def do_a_thinking(self):
|
async def do_a_thinking(self):
|
||||||
logger.debug("麦麦大脑袋转起来了")
|
logger.debug("麦麦大脑袋转起来了")
|
||||||
self.current_state.update_current_state_info()
|
self.current_state.update_current_state_info()
|
||||||
|
|
||||||
personality_info = self.personality_info
|
personality_info = self.personality_info
|
||||||
current_thinking_info = self.current_mind
|
current_thinking_info = self.current_mind
|
||||||
mood_info = self.current_state.mood
|
mood_info = self.current_state.mood
|
||||||
related_memory_info = 'memory'
|
related_memory_info = "memory"
|
||||||
sub_flows_info = await self.get_all_subheartflows_minds()
|
sub_flows_info = await self.get_all_subheartflows_minds()
|
||||||
|
|
||||||
schedule_info = bot_schedule.get_current_num_task(num = 4,time_info = True)
|
schedule_info = bot_schedule.get_current_num_task(num=4, time_info=True)
|
||||||
|
|
||||||
prompt = ""
|
prompt = ""
|
||||||
prompt += f"你刚刚在做的事情是:{schedule_info}\n"
|
prompt += f"你刚刚在做的事情是:{schedule_info}\n"
|
||||||
prompt += f"{personality_info}\n"
|
prompt += f"{personality_info}\n"
|
||||||
@@ -93,49 +96,46 @@ class Heartflow:
|
|||||||
prompt += f"你现在{mood_info}。"
|
prompt += f"你现在{mood_info}。"
|
||||||
prompt += "现在你接下去继续思考,产生新的想法,但是要基于原有的主要想法,不要分点输出,"
|
prompt += "现在你接下去继续思考,产生新的想法,但是要基于原有的主要想法,不要分点输出,"
|
||||||
prompt += "输出连贯的内心独白,不要太长,但是记得结合上述的消息,关注新内容:"
|
prompt += "输出连贯的内心独白,不要太长,但是记得结合上述的消息,关注新内容:"
|
||||||
|
|
||||||
reponse, reasoning_content = await self.llm_model.generate_response_async(prompt)
|
reponse, reasoning_content = await self.llm_model.generate_response_async(prompt)
|
||||||
|
|
||||||
self.update_current_mind(reponse)
|
self.update_current_mind(reponse)
|
||||||
|
|
||||||
self.current_mind = reponse
|
self.current_mind = reponse
|
||||||
logger.info(f"麦麦的总体脑内状态:{self.current_mind}")
|
logger.info(f"麦麦的总体脑内状态:{self.current_mind}")
|
||||||
# logger.info("麦麦想了想,当前活动:")
|
# logger.info("麦麦想了想,当前活动:")
|
||||||
await bot_schedule.move_doing(self.current_mind)
|
await bot_schedule.move_doing(self.current_mind)
|
||||||
|
|
||||||
|
|
||||||
for _, subheartflow in self._subheartflows.items():
|
for _, subheartflow in self._subheartflows.items():
|
||||||
subheartflow.main_heartflow_info = reponse
|
subheartflow.main_heartflow_info = reponse
|
||||||
|
|
||||||
def update_current_mind(self,reponse):
|
def update_current_mind(self, reponse):
|
||||||
self.past_mind.append(self.current_mind)
|
self.past_mind.append(self.current_mind)
|
||||||
self.current_mind = reponse
|
self.current_mind = reponse
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
async def get_all_subheartflows_minds(self):
|
async def get_all_subheartflows_minds(self):
|
||||||
sub_minds = ""
|
sub_minds = ""
|
||||||
for _, subheartflow in self._subheartflows.items():
|
for _, subheartflow in self._subheartflows.items():
|
||||||
sub_minds += subheartflow.current_mind
|
sub_minds += subheartflow.current_mind
|
||||||
|
|
||||||
return await self.minds_summary(sub_minds)
|
return await self.minds_summary(sub_minds)
|
||||||
|
|
||||||
async def minds_summary(self,minds_str):
|
async def minds_summary(self, minds_str):
|
||||||
personality_info = self.personality_info
|
personality_info = self.personality_info
|
||||||
mood_info = self.current_state.mood
|
mood_info = self.current_state.mood
|
||||||
|
|
||||||
prompt = ""
|
prompt = ""
|
||||||
prompt += f"{personality_info}\n"
|
prompt += f"{personality_info}\n"
|
||||||
prompt += f"现在{global_config.BOT_NICKNAME}的想法是:{self.current_mind}\n"
|
prompt += f"现在{global_config.BOT_NICKNAME}的想法是:{self.current_mind}\n"
|
||||||
prompt += f"现在{global_config.BOT_NICKNAME}在qq群里进行聊天,聊天的话题如下:{minds_str}\n"
|
prompt += f"现在{global_config.BOT_NICKNAME}在qq群里进行聊天,聊天的话题如下:{minds_str}\n"
|
||||||
prompt += f"你现在{mood_info}\n"
|
prompt += f"你现在{mood_info}\n"
|
||||||
prompt += '''现在请你总结这些聊天内容,注意关注聊天内容对原有的想法的影响,输出连贯的内心独白
|
prompt += """现在请你总结这些聊天内容,注意关注聊天内容对原有的想法的影响,输出连贯的内心独白
|
||||||
不要太长,但是记得结合上述的消息,要记得你的人设,关注新内容:'''
|
不要太长,但是记得结合上述的消息,要记得你的人设,关注新内容:"""
|
||||||
|
|
||||||
reponse, reasoning_content = await self.llm_model.generate_response_async(prompt)
|
reponse, reasoning_content = await self.llm_model.generate_response_async(prompt)
|
||||||
|
|
||||||
return reponse
|
return reponse
|
||||||
|
|
||||||
def create_subheartflow(self, subheartflow_id):
|
def create_subheartflow(self, subheartflow_id):
|
||||||
"""
|
"""
|
||||||
创建一个新的SubHeartflow实例
|
创建一个新的SubHeartflow实例
|
||||||
@@ -145,10 +145,10 @@ class Heartflow:
|
|||||||
if subheartflow_id not in self._subheartflows:
|
if subheartflow_id not in self._subheartflows:
|
||||||
logger.debug(f"创建 subheartflow: {subheartflow_id}")
|
logger.debug(f"创建 subheartflow: {subheartflow_id}")
|
||||||
subheartflow = SubHeartflow(subheartflow_id)
|
subheartflow = SubHeartflow(subheartflow_id)
|
||||||
#创建一个观察对象,目前只可以用chat_id创建观察对象
|
# 创建一个观察对象,目前只可以用chat_id创建观察对象
|
||||||
logger.debug(f"创建 observation: {subheartflow_id}")
|
logger.debug(f"创建 observation: {subheartflow_id}")
|
||||||
observation = ChattingObservation(subheartflow_id)
|
observation = ChattingObservation(subheartflow_id)
|
||||||
|
|
||||||
logger.debug(f"添加 observation ")
|
logger.debug(f"添加 observation ")
|
||||||
subheartflow.add_observation(observation)
|
subheartflow.add_observation(observation)
|
||||||
logger.debug(f"添加 observation 成功")
|
logger.debug(f"添加 observation 成功")
|
||||||
@@ -159,11 +159,11 @@ class Heartflow:
|
|||||||
self._subheartflows[subheartflow_id] = subheartflow
|
self._subheartflows[subheartflow_id] = subheartflow
|
||||||
logger.info(f"添加 subheartflow 成功")
|
logger.info(f"添加 subheartflow 成功")
|
||||||
return self._subheartflows[subheartflow_id]
|
return self._subheartflows[subheartflow_id]
|
||||||
|
|
||||||
def get_subheartflow(self, observe_chat_id):
|
def get_subheartflow(self, observe_chat_id):
|
||||||
"""获取指定ID的SubHeartflow实例"""
|
"""获取指定ID的SubHeartflow实例"""
|
||||||
return self._subheartflows.get(observe_chat_id)
|
return self._subheartflows.get(observe_chat_id)
|
||||||
|
|
||||||
|
|
||||||
# 创建一个全局的管理器实例
|
# 创建一个全局的管理器实例
|
||||||
heartflow = Heartflow()
|
heartflow = Heartflow()
|
||||||
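
The Heartflow class shown above follows a fan-in/fan-out pattern: get_all_subheartflows_minds() concatenates every sub-flow's current_mind, minds_summary() asks the LLM to condense that into one overall state, and the result is written back into each sub-flow's main_heartflow_info. A minimal runnable sketch of that cycle, with a stubbed LLM call and a hypothetical SubFlow stand-in (not the project's actual classes):

import asyncio


class SubFlow:
    """Hypothetical stand-in for SubHeartflow, kept only to make the sketch runnable."""

    def __init__(self, name: str):
        self.name = name
        self.current_mind = f"thoughts from {name}. "
        self.main_heartflow_info = ""


async def fake_summarize(prompt: str) -> str:
    # stand-in for llm_model.generate_response_async(prompt)
    await asyncio.sleep(0)
    return f"summary of: {prompt[:60]}..."


async def heartflow_cycle(subflows: dict[str, SubFlow]) -> str:
    # fan-in: concatenate every sub-flow's current mind
    sub_minds = "".join(sf.current_mind for sf in subflows.values())
    summary = await fake_summarize(sub_minds)
    # fan-out: every sub-flow sees the shared, summarized state
    for sf in subflows.values():
        sf.main_heartflow_info = summary
    return summary


flows = {name: SubFlow(name) for name in ("group_a", "group_b")}
print(asyncio.run(heartflow_cycle(flows)))
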
|
|||||||
@@ -1,119 +1,123 @@
|
|||||||
#定义了来自外部世界的信息
|
# 定义了来自外部世界的信息
|
||||||
#外部世界可以是某个聊天 不同平台的聊天 也可以是任意媒体
|
# 外部世界可以是某个聊天 不同平台的聊天 也可以是任意媒体
|
||||||
import asyncio
|
import asyncio
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from src.plugins.models.utils_model import LLM_request
|
from src.plugins.models.utils_model import LLM_request
|
||||||
from src.plugins.config.config import global_config
|
from src.plugins.config.config import global_config
|
||||||
from src.common.database import db
|
from src.common.database import db
|
||||||
|
|
||||||
|
|
||||||
# 所有观察的基类
|
# 所有观察的基类
|
||||||
class Observation:
|
class Observation:
|
||||||
def __init__(self,observe_type,observe_id):
|
def __init__(self, observe_type, observe_id):
|
||||||
self.observe_info = ""
|
self.observe_info = ""
|
||||||
self.observe_type = observe_type
|
self.observe_type = observe_type
|
||||||
self.observe_id = observe_id
|
self.observe_id = observe_id
|
||||||
self.last_observe_time = datetime.now().timestamp() # 初始化为当前时间
|
self.last_observe_time = datetime.now().timestamp() # 初始化为当前时间
|
||||||
|
|
||||||
|
|
||||||
# 聊天观察
|
# 聊天观察
|
||||||
class ChattingObservation(Observation):
|
class ChattingObservation(Observation):
|
||||||
def __init__(self,chat_id):
|
def __init__(self, chat_id):
|
||||||
super().__init__("chat",chat_id)
|
super().__init__("chat", chat_id)
|
||||||
self.chat_id = chat_id
|
self.chat_id = chat_id
|
||||||
|
|
||||||
self.talking_message = []
|
self.talking_message = []
|
||||||
self.talking_message_str = ""
|
self.talking_message_str = ""
|
||||||
|
|
||||||
self.observe_times = 0
|
self.observe_times = 0
|
||||||
|
|
||||||
self.summary_count = 0 # 30秒内的更新次数
|
self.summary_count = 0 # 30秒内的更新次数
|
||||||
self.max_update_in_30s = 2 #30秒内最多更新2次
|
self.max_update_in_30s = 2 # 30秒内最多更新2次
|
||||||
self.last_summary_time = 0 #上次更新summary的时间
|
self.last_summary_time = 0 # 上次更新summary的时间
|
||||||
|
|
||||||
self.sub_observe = None
|
self.sub_observe = None
|
||||||
|
|
||||||
self.llm_summary = LLM_request(
|
self.llm_summary = LLM_request(
|
||||||
model=global_config.llm_outer_world, temperature=0.7, max_tokens=300, request_type="outer_world")
|
model=global_config.llm_outer_world, temperature=0.7, max_tokens=300, request_type="outer_world"
|
||||||
|
)
|
||||||
|
|
||||||
# 进行一次观察 返回观察结果observe_info
|
# 进行一次观察 返回观察结果observe_info
|
||||||
async def observe(self):
|
async def observe(self):
|
||||||
# 查找新消息,限制最多20条
|
# 查找新消息,限制最多20条
|
||||||
new_messages = list(db.messages.find({
|
new_messages = list(
|
||||||
"chat_id": self.chat_id,
|
db.messages.find({"chat_id": self.chat_id, "time": {"$gt": self.last_observe_time}})
|
||||||
"time": {"$gt": self.last_observe_time}
|
.sort("time", 1)
|
||||||
}).sort("time", 1).limit(20)) # 按时间正序排列,最多20条
|
.limit(20)
|
||||||
|
) # 按时间正序排列,最多20条
|
||||||
|
|
||||||
if not new_messages:
|
if not new_messages:
|
||||||
return self.observe_info #没有新消息,返回上次观察结果
|
return self.observe_info # 没有新消息,返回上次观察结果
|
||||||
|
|
||||||
# 将新消息转换为字符串格式
|
# 将新消息转换为字符串格式
|
||||||
new_messages_str = ""
|
new_messages_str = ""
|
||||||
for msg in new_messages:
|
for msg in new_messages:
|
||||||
if "sender_name" in msg and "content" in msg:
|
if "sender_name" in msg and "content" in msg:
|
||||||
new_messages_str += f"{msg['sender_name']}: {msg['content']}\n"
|
new_messages_str += f"{msg['sender_name']}: {msg['content']}\n"
|
||||||
|
|
||||||
# 将新消息添加到talking_message,同时保持列表长度不超过20条
|
# 将新消息添加到talking_message,同时保持列表长度不超过20条
|
||||||
self.talking_message.extend(new_messages)
|
self.talking_message.extend(new_messages)
|
||||||
if len(self.talking_message) > 20:
|
if len(self.talking_message) > 20:
|
||||||
self.talking_message = self.talking_message[-20:] # 只保留最新的20条
|
self.talking_message = self.talking_message[-20:] # 只保留最新的20条
|
||||||
self.translate_message_list_to_str()
|
self.translate_message_list_to_str()
|
||||||
|
|
||||||
# 更新观察次数
|
# 更新观察次数
|
||||||
self.observe_times += 1
|
self.observe_times += 1
|
||||||
self.last_observe_time = new_messages[-1]["time"]
|
self.last_observe_time = new_messages[-1]["time"]
|
||||||
|
|
||||||
# 检查是否需要更新summary
|
# 检查是否需要更新summary
|
||||||
current_time = int(datetime.now().timestamp())
|
current_time = int(datetime.now().timestamp())
|
||||||
if current_time - self.last_summary_time >= 30: # 如果超过30秒,重置计数
|
if current_time - self.last_summary_time >= 30: # 如果超过30秒,重置计数
|
||||||
self.summary_count = 0
|
self.summary_count = 0
|
||||||
self.last_summary_time = current_time
|
self.last_summary_time = current_time
|
||||||
|
|
||||||
if self.summary_count < self.max_update_in_30s: # 如果30秒内更新次数小于2次
|
if self.summary_count < self.max_update_in_30s: # 如果30秒内更新次数小于2次
|
||||||
await self.update_talking_summary(new_messages_str)
|
await self.update_talking_summary(new_messages_str)
|
||||||
self.summary_count += 1
|
self.summary_count += 1
|
||||||
|
|
||||||
return self.observe_info
|
return self.observe_info
|
||||||
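
observe() guards the expensive summary refresh with a fixed 30-second window: summary_count resets once 30 seconds have passed since last_summary_time, and update_talking_summary() only runs while the counter stays below max_update_in_30s. A standalone sketch of that windowed rate limit (the class and method names are illustrative, not the project's API):

import time


class WindowedLimiter:
    """At most `max_per_window` expensive updates per `window`-second window."""

    def __init__(self, window: float = 30.0, max_per_window: int = 2):
        self.window = window
        self.max_per_window = max_per_window
        self.count = 0
        self.window_start = 0.0

    def allow(self) -> bool:
        now = time.time()
        if now - self.window_start >= self.window:
            # window expired: reset the counter and start a new window
            self.count = 0
            self.window_start = now
        if self.count < self.max_per_window:
            self.count += 1
            return True
        return False


limiter = WindowedLimiter()
print([limiter.allow() for _ in range(3)])  # [True, True, False]
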
|
|
||||||
async def carefully_observe(self):
|
async def carefully_observe(self):
|
||||||
# 查找新消息,限制最多30条
|
# 查找新消息,限制最多30条
|
||||||
new_messages = list(db.messages.find({
|
new_messages = list(
|
||||||
"chat_id": self.chat_id,
|
db.messages.find({"chat_id": self.chat_id, "time": {"$gt": self.last_observe_time}})
|
||||||
"time": {"$gt": self.last_observe_time}
|
.sort("time", 1)
|
||||||
}).sort("time", 1).limit(30)) # 按时间正序排列,最多30条
|
.limit(30)
|
||||||
|
) # 按时间正序排列,最多30条
|
||||||
|
|
||||||
if not new_messages:
|
if not new_messages:
|
||||||
return self.observe_info #没有新消息,返回上次观察结果
|
return self.observe_info # 没有新消息,返回上次观察结果
|
||||||
|
|
||||||
# 将新消息转换为字符串格式
|
# 将新消息转换为字符串格式
|
||||||
new_messages_str = ""
|
new_messages_str = ""
|
||||||
for msg in new_messages:
|
for msg in new_messages:
|
||||||
if "sender_name" in msg and "content" in msg:
|
if "sender_name" in msg and "content" in msg:
|
||||||
new_messages_str += f"{msg['sender_name']}: {msg['content']}\n"
|
new_messages_str += f"{msg['sender_name']}: {msg['content']}\n"
|
||||||
|
|
||||||
# 将新消息添加到talking_message,同时保持列表长度不超过30条
|
# 将新消息添加到talking_message,同时保持列表长度不超过30条
|
||||||
self.talking_message.extend(new_messages)
|
self.talking_message.extend(new_messages)
|
||||||
if len(self.talking_message) > 30:
|
if len(self.talking_message) > 30:
|
||||||
self.talking_message = self.talking_message[-30:] # 只保留最新的30条
|
self.talking_message = self.talking_message[-30:] # 只保留最新的30条
|
||||||
self.translate_message_list_to_str()
|
self.translate_message_list_to_str()
|
||||||
|
|
||||||
# 更新观察次数
|
# 更新观察次数
|
||||||
self.observe_times += 1
|
self.observe_times += 1
|
||||||
self.last_observe_time = new_messages[-1]["time"]
|
self.last_observe_time = new_messages[-1]["time"]
|
||||||
|
|
||||||
await self.update_talking_summary(new_messages_str)
|
await self.update_talking_summary(new_messages_str)
|
||||||
return self.observe_info
|
return self.observe_info
|
||||||
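
Both observe() and carefully_observe() append the newly fetched documents to talking_message and then slice the list back down to the newest 20 or 30 entries. collections.deque with maxlen gives the same bounded window without the manual trim; a small sketch (not what the diff itself uses):

from collections import deque

# maxlen drops the oldest entries automatically, equivalent to `list[-20:]`
talking_message = deque(maxlen=20)

for i in range(25):
    talking_message.append({"sender_name": "user", "content": f"msg {i}"})

print(len(talking_message))           # 20
print(talking_message[0]["content"])  # 'msg 5' -- oldest surviving message
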
|
|
||||||
|
async def update_talking_summary(self, new_messages_str):
|
||||||
async def update_talking_summary(self,new_messages_str):
|
# 基于已经有的talking_summary,和新的talking_message,生成一个summary
|
||||||
#基于已经有的talking_summary,和新的talking_message,生成一个summary
|
|
||||||
# print(f"更新聊天总结:{self.talking_summary}")
|
# print(f"更新聊天总结:{self.talking_summary}")
|
||||||
prompt = ""
|
prompt = ""
|
||||||
prompt = f"你正在参与一个qq群聊的讨论,这个群之前在聊的内容是:{self.observe_info}\n"
|
prompt = f"你正在参与一个qq群聊的讨论,这个群之前在聊的内容是:{self.observe_info}\n"
|
||||||
prompt += f"现在群里的群友们产生了新的讨论,有了新的发言,具体内容如下:{new_messages_str}\n"
|
prompt += f"现在群里的群友们产生了新的讨论,有了新的发言,具体内容如下:{new_messages_str}\n"
|
||||||
prompt += '''以上是群里在进行的聊天,请你对这个聊天内容进行总结,总结内容要包含聊天的大致内容,
|
prompt += """以上是群里在进行的聊天,请你对这个聊天内容进行总结,总结内容要包含聊天的大致内容,
|
||||||
以及聊天中的一些重要信息,记得不要分点,不要太长,精简的概括成一段文本\n'''
|
以及聊天中的一些重要信息,记得不要分点,不要太长,精简的概括成一段文本\n"""
|
||||||
prompt += "总结概括:"
|
prompt += "总结概括:"
|
||||||
self.observe_info, reasoning_content = await self.llm_summary.generate_response_async(prompt)
|
self.observe_info, reasoning_content = await self.llm_summary.generate_response_async(prompt)
|
||||||
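
update_talking_summary() maintains a rolling summary: the prompt pairs the previous observe_info with the newest messages, and the model's reply overwrites the old summary, so storage stays constant no matter how long the chat runs. A sketch of that incremental loop with a stubbed model call (the stub and the shortened prompt text are illustrative only):

import asyncio


async def stub_llm(prompt: str) -> str:
    # stand-in for llm_summary.generate_response_async(prompt)
    await asyncio.sleep(0)
    return prompt[-80:]  # pretend the model condensed the prompt


async def update_summary(old_summary: str, new_messages: str) -> str:
    # fold the new messages into the previous summary, then replace it
    prompt = (
        f"这个群之前在聊的内容是:{old_summary}\n"
        f"新的发言如下:{new_messages}\n"
        "请精简地概括成一段文本。总结概括:"
    )
    return await stub_llm(prompt)


async def demo() -> None:
    summary = ""
    for batch in ("A: 你好\nB: 在吗\n", "B: 今天下雨\n"):
        summary = await update_summary(summary, batch)
    print(summary)


asyncio.run(demo())
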
|
|
||||||
def translate_message_list_to_str(self):
|
def translate_message_list_to_str(self):
|
||||||
self.talking_message_str = ""
|
self.talking_message_str = ""
|
||||||
for message in self.talking_message:
|
for message in self.talking_message:
|
||||||
|
|||||||
@@ -7,13 +7,13 @@ import re
|
|||||||
import time
|
import time
|
||||||
from src.plugins.schedule.schedule_generator import bot_schedule
|
from src.plugins.schedule.schedule_generator import bot_schedule
|
||||||
from src.plugins.memory_system.Hippocampus import HippocampusManager
|
from src.plugins.memory_system.Hippocampus import HippocampusManager
|
||||||
from src.common.logger import get_module_logger, LogConfig, SUB_HEARTFLOW_STYLE_CONFIG # noqa: E402
|
from src.common.logger import get_module_logger, LogConfig, SUB_HEARTFLOW_STYLE_CONFIG # noqa: E402
|
||||||
|
|
||||||
subheartflow_config = LogConfig(
|
subheartflow_config = LogConfig(
|
||||||
# 使用子心流专用样式
|
# 使用子心流专用样式
|
||||||
console_format=SUB_HEARTFLOW_STYLE_CONFIG["console_format"],
|
console_format=SUB_HEARTFLOW_STYLE_CONFIG["console_format"],
|
||||||
file_format=SUB_HEARTFLOW_STYLE_CONFIG["file_format"],
|
file_format=SUB_HEARTFLOW_STYLE_CONFIG["file_format"],
|
||||||
)
|
)
|
||||||
logger = get_module_logger("subheartflow", config=subheartflow_config)
|
logger = get_module_logger("subheartflow", config=subheartflow_config)
|
||||||
|
|
||||||
|
|
||||||
@@ -21,38 +21,39 @@ class CuttentState:
|
|||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.willing = 0
|
self.willing = 0
|
||||||
self.current_state_info = ""
|
self.current_state_info = ""
|
||||||
|
|
||||||
self.mood_manager = MoodManager()
|
self.mood_manager = MoodManager()
|
||||||
self.mood = self.mood_manager.get_prompt()
|
self.mood = self.mood_manager.get_prompt()
|
||||||
|
|
||||||
def update_current_state_info(self):
|
def update_current_state_info(self):
|
||||||
self.current_state_info = self.mood_manager.get_current_mood()
|
self.current_state_info = self.mood_manager.get_current_mood()
|
||||||
|
|
||||||
|
|
||||||
class SubHeartflow:
|
class SubHeartflow:
|
||||||
def __init__(self,subheartflow_id):
|
def __init__(self, subheartflow_id):
|
||||||
self.subheartflow_id = subheartflow_id
|
self.subheartflow_id = subheartflow_id
|
||||||
|
|
||||||
self.current_mind = ""
|
self.current_mind = ""
|
||||||
self.past_mind = []
|
self.past_mind = []
|
||||||
self.current_state : CuttentState = CuttentState()
|
self.current_state: CuttentState = CuttentState()
|
||||||
self.llm_model = LLM_request(
|
self.llm_model = LLM_request(
|
||||||
model=global_config.llm_sub_heartflow, temperature=0.7, max_tokens=600, request_type="sub_heart_flow")
|
model=global_config.llm_sub_heartflow, temperature=0.7, max_tokens=600, request_type="sub_heart_flow"
|
||||||
|
)
|
||||||
|
|
||||||
self.main_heartflow_info = ""
|
self.main_heartflow_info = ""
|
||||||
|
|
||||||
self.last_reply_time = time.time()
|
self.last_reply_time = time.time()
|
||||||
self.last_active_time = time.time() # 添加最后激活时间
|
self.last_active_time = time.time() # 添加最后激活时间
|
||||||
|
|
||||||
if not self.current_mind:
|
if not self.current_mind:
|
||||||
self.current_mind = "你什么也没想"
|
self.current_mind = "你什么也没想"
|
||||||
|
|
||||||
self.personality_info = " ".join(global_config.PROMPT_PERSONALITY)
|
self.personality_info = " ".join(global_config.PROMPT_PERSONALITY)
|
||||||
|
|
||||||
self.is_active = False
|
self.is_active = False
|
||||||
|
|
||||||
self.observations : list[Observation] = []
|
self.observations: list[Observation] = []
|
||||||
|
|
||||||
def add_observation(self, observation: Observation):
|
def add_observation(self, observation: Observation):
|
||||||
"""添加一个新的observation对象到列表中,如果已存在相同id的observation则不添加"""
|
"""添加一个新的observation对象到列表中,如果已存在相同id的observation则不添加"""
|
||||||
# 查找是否存在相同id的observation
|
# 查找是否存在相同id的observation
|
||||||
@@ -62,16 +63,16 @@ class SubHeartflow:
|
|||||||
return
|
return
|
||||||
# 如果没有找到相同id的observation,则添加新的
|
# 如果没有找到相同id的observation,则添加新的
|
||||||
self.observations.append(observation)
|
self.observations.append(observation)
|
||||||
|
|
||||||
def remove_observation(self, observation: Observation):
|
def remove_observation(self, observation: Observation):
|
||||||
"""从列表中移除一个observation对象"""
|
"""从列表中移除一个observation对象"""
|
||||||
if observation in self.observations:
|
if observation in self.observations:
|
||||||
self.observations.remove(observation)
|
self.observations.remove(observation)
|
||||||
|
|
||||||
def get_all_observations(self) -> list[Observation]:
|
def get_all_observations(self) -> list[Observation]:
|
||||||
"""获取所有observation对象"""
|
"""获取所有observation对象"""
|
||||||
return self.observations
|
return self.observations
|
||||||
|
|
||||||
def clear_observations(self):
|
def clear_observations(self):
|
||||||
"""清空所有observation对象"""
|
"""清空所有observation对象"""
|
||||||
self.observations.clear()
|
self.observations.clear()
|
||||||
@@ -85,50 +86,45 @@ class SubHeartflow:
|
|||||||
else:
|
else:
|
||||||
self.is_active = True
|
self.is_active = True
|
||||||
self.last_active_time = current_time # 更新最后激活时间
|
self.last_active_time = current_time # 更新最后激活时间
|
||||||
|
|
||||||
observation = self.observations[0]
|
observation = self.observations[0]
|
||||||
await observation.observe()
|
await observation.observe()
|
||||||
|
|
||||||
self.current_state.update_current_state_info()
|
self.current_state.update_current_state_info()
|
||||||
|
|
||||||
await self.do_a_thinking()
|
await self.do_a_thinking()
|
||||||
await self.judge_willing()
|
await self.judge_willing()
|
||||||
await asyncio.sleep(60)
|
await asyncio.sleep(60)
|
||||||
|
|
||||||
# 检查是否超过10分钟没有激活
|
# 检查是否超过10分钟没有激活
|
||||||
if current_time - self.last_active_time > 600: # 10分钟无回复/不在场,销毁
|
if current_time - self.last_active_time > 600: # 10分钟无回复/不在场,销毁
|
||||||
logger.info(f"子心流 {self.subheartflow_id} 已经10分钟没有激活,正在销毁...")
|
logger.info(f"子心流 {self.subheartflow_id} 已经10分钟没有激活,正在销毁...")
|
||||||
break # 退出循环,销毁自己
|
break # 退出循环,销毁自己
|
||||||
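
The loop above wakes every 60 seconds, observes, thinks, judges willingness, and breaks out once last_active_time is more than 600 seconds old, letting the sub-flow be destroyed. A self-contained sketch of that keep-alive pattern with shortened intervals so it finishes quickly (check_for_new_messages is a placeholder, not a project function):

import asyncio
import time


async def check_for_new_messages() -> bool:
    # placeholder observation step; pretend the chat stays quiet
    return False


async def subflow_loop(poll_interval: float = 0.1, max_idle: float = 0.35) -> None:
    last_active_time = time.time()
    while True:
        if await check_for_new_messages():
            last_active_time = time.time()
            # ... observe / think / reply would happen here ...
        await asyncio.sleep(poll_interval)
        if time.time() - last_active_time > max_idle:
            print("idle too long, shutting down sub-flow")
            break


asyncio.run(subflow_loop())
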
|
|
||||||
async def do_a_thinking(self):
|
async def do_a_thinking(self):
|
||||||
|
|
||||||
current_thinking_info = self.current_mind
|
current_thinking_info = self.current_mind
|
||||||
mood_info = self.current_state.mood
|
mood_info = self.current_state.mood
|
||||||
|
|
||||||
observation = self.observations[0]
|
observation = self.observations[0]
|
||||||
chat_observe_info = observation.observe_info
|
chat_observe_info = observation.observe_info
|
||||||
print(f"chat_observe_info:{chat_observe_info}")
|
print(f"chat_observe_info:{chat_observe_info}")
|
||||||
|
|
||||||
# 调取记忆
|
# 调取记忆
|
||||||
related_memory = await HippocampusManager.get_instance().get_memory_from_text(
|
related_memory = await HippocampusManager.get_instance().get_memory_from_text(
|
||||||
text=chat_observe_info,
|
text=chat_observe_info, max_memory_num=2, max_memory_length=2, max_depth=3, fast_retrieval=False
|
||||||
max_memory_num=2,
|
|
||||||
max_memory_length=2,
|
|
||||||
max_depth=3,
|
|
||||||
fast_retrieval=False
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if related_memory:
|
if related_memory:
|
||||||
related_memory_info = ""
|
related_memory_info = ""
|
||||||
for memory in related_memory:
|
for memory in related_memory:
|
||||||
related_memory_info += memory[1]
|
related_memory_info += memory[1]
|
||||||
else:
|
else:
|
||||||
related_memory_info = ''
|
related_memory_info = ""
|
||||||
|
|
||||||
# print(f"相关记忆:{related_memory_info}")
|
# print(f"相关记忆:{related_memory_info}")
|
||||||
|
|
||||||
schedule_info = bot_schedule.get_current_num_task(num = 1,time_info = False)
|
schedule_info = bot_schedule.get_current_num_task(num=1, time_info=False)
|
||||||
|
|
||||||
prompt = ""
|
prompt = ""
|
||||||
prompt += f"你刚刚在做的事情是:{schedule_info}\n"
|
prompt += f"你刚刚在做的事情是:{schedule_info}\n"
|
||||||
# prompt += f"麦麦的总体想法是:{self.main_heartflow_info}\n\n"
|
# prompt += f"麦麦的总体想法是:{self.main_heartflow_info}\n\n"
|
||||||
@@ -142,25 +138,25 @@ class SubHeartflow:
|
|||||||
prompt += "现在你接下去继续思考,产生新的想法,不要分点输出,输出连贯的内心独白,不要太长,"
|
prompt += "现在你接下去继续思考,产生新的想法,不要分点输出,输出连贯的内心独白,不要太长,"
|
||||||
prompt += "但是记得结合上述的消息,要记得维持住你的人设,关注聊天和新内容,不要思考太多:"
|
prompt += "但是记得结合上述的消息,要记得维持住你的人设,关注聊天和新内容,不要思考太多:"
|
||||||
reponse, reasoning_content = await self.llm_model.generate_response_async(prompt)
|
reponse, reasoning_content = await self.llm_model.generate_response_async(prompt)
|
||||||
|
|
||||||
self.update_current_mind(reponse)
|
self.update_current_mind(reponse)
|
||||||
|
|
||||||
self.current_mind = reponse
|
self.current_mind = reponse
|
||||||
logger.debug(f"prompt:\n{prompt}\n")
|
logger.debug(f"prompt:\n{prompt}\n")
|
||||||
logger.info(f"麦麦的脑内状态:{self.current_mind}")
|
logger.info(f"麦麦的脑内状态:{self.current_mind}")
|
||||||
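
do_a_thinking() assembles its prompt from several context layers: the current schedule task, memories retrieved for the observed chat, the personality string, the previous current_mind, the mood, and the latest observation. One compact way to express that composition is to join only the non-empty sections; the section labels below loosely imitate the ones in the diff and are not the exact prompt text:

def build_thinking_prompt(schedule: str, memory: str, personality: str,
                          prior_mind: str, mood: str, observation: str) -> str:
    # each context source contributes one optional section; empty ones are dropped
    sections = [
        f"你刚刚在做的事情是:{schedule}" if schedule else "",
        f"你想起来:{memory}" if memory else "",
        f"你{personality}" if personality else "",
        f"刚刚你的主要想法是:{prior_mind}" if prior_mind else "",
        f"你现在{mood}" if mood else "",
        f"群里正在聊:{observation}" if observation else "",
        "现在你接下去继续思考,输出连贯的内心独白:",
    ]
    return "\n".join(s for s in sections if s)


print(build_thinking_prompt("写代码", "", "是一个友善的网友", "想摸会儿鱼", "心情不错", "大家在聊天气"))
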
|
|
||||||
async def do_after_reply(self,reply_content,chat_talking_prompt):
|
async def do_after_reply(self, reply_content, chat_talking_prompt):
|
||||||
# print("麦麦脑袋转起来了")
|
# print("麦麦脑袋转起来了")
|
||||||
current_thinking_info = self.current_mind
|
current_thinking_info = self.current_mind
|
||||||
mood_info = self.current_state.mood
|
mood_info = self.current_state.mood
|
||||||
|
|
||||||
observation = self.observations[0]
|
observation = self.observations[0]
|
||||||
chat_observe_info = observation.observe_info
|
chat_observe_info = observation.observe_info
|
||||||
|
|
||||||
message_new_info = chat_talking_prompt
|
message_new_info = chat_talking_prompt
|
||||||
reply_info = reply_content
|
reply_info = reply_content
|
||||||
schedule_info = bot_schedule.get_current_num_task(num = 1,time_info = False)
|
schedule_info = bot_schedule.get_current_num_task(num=1, time_info=False)
|
||||||
|
|
||||||
prompt = ""
|
prompt = ""
|
||||||
prompt += f"你现在正在做的事情是:{schedule_info}\n"
|
prompt += f"你现在正在做的事情是:{schedule_info}\n"
|
||||||
prompt += f"你{self.personality_info}\n"
|
prompt += f"你{self.personality_info}\n"
|
||||||
@@ -171,16 +167,16 @@ class SubHeartflow:
|
|||||||
prompt += f"你现在{mood_info}"
|
prompt += f"你现在{mood_info}"
|
||||||
prompt += "现在你接下去继续思考,产生新的想法,记得保留你刚刚的想法,不要分点输出,输出连贯的内心独白"
|
prompt += "现在你接下去继续思考,产生新的想法,记得保留你刚刚的想法,不要分点输出,输出连贯的内心独白"
|
||||||
prompt += "不要太长,但是记得结合上述的消息,要记得你的人设,关注聊天和新内容,关注你回复的内容,不要思考太多:"
|
prompt += "不要太长,但是记得结合上述的消息,要记得你的人设,关注聊天和新内容,关注你回复的内容,不要思考太多:"
|
||||||
|
|
||||||
reponse, reasoning_content = await self.llm_model.generate_response_async(prompt)
|
reponse, reasoning_content = await self.llm_model.generate_response_async(prompt)
|
||||||
|
|
||||||
self.update_current_mind(reponse)
|
self.update_current_mind(reponse)
|
||||||
|
|
||||||
self.current_mind = reponse
|
self.current_mind = reponse
|
||||||
logger.info(f"麦麦回复后的脑内状态:{self.current_mind}")
|
logger.info(f"麦麦回复后的脑内状态:{self.current_mind}")
|
||||||
|
|
||||||
self.last_reply_time = time.time()
|
self.last_reply_time = time.time()
|
||||||
|
|
||||||
async def judge_willing(self):
|
async def judge_willing(self):
|
||||||
# print("麦麦闹情绪了1")
|
# print("麦麦闹情绪了1")
|
||||||
current_thinking_info = self.current_mind
|
current_thinking_info = self.current_mind
|
||||||
@@ -193,21 +189,20 @@ class SubHeartflow:
|
|||||||
prompt += f"你现在{mood_info}。"
|
prompt += f"你现在{mood_info}。"
|
||||||
prompt += "现在请你思考,你想不想发言或者回复,请你输出一个数字,1-10,1表示非常不想,10表示非常想。"
|
prompt += "现在请你思考,你想不想发言或者回复,请你输出一个数字,1-10,1表示非常不想,10表示非常想。"
|
||||||
prompt += "请你用<>包裹你的回复意愿,输出<1>表示不想回复,输出<10>表示非常想回复。请你考虑,你完全可以不回复"
|
prompt += "请你用<>包裹你的回复意愿,输出<1>表示不想回复,输出<10>表示非常想回复。请你考虑,你完全可以不回复"
|
||||||
|
|
||||||
response, reasoning_content = await self.llm_model.generate_response_async(prompt)
|
response, reasoning_content = await self.llm_model.generate_response_async(prompt)
|
||||||
# 解析willing值
|
# 解析willing值
|
||||||
willing_match = re.search(r'<(\d+)>', response)
|
willing_match = re.search(r"<(\d+)>", response)
|
||||||
if willing_match:
|
if willing_match:
|
||||||
self.current_state.willing = int(willing_match.group(1))
|
self.current_state.willing = int(willing_match.group(1))
|
||||||
else:
|
else:
|
||||||
self.current_state.willing = 0
|
self.current_state.willing = 0
|
||||||
|
|
||||||
return self.current_state.willing
|
return self.current_state.willing
|
||||||
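
judge_willing() asks the model for a 1-10 willingness score wrapped in angle brackets and falls back to 0 when no <N> marker is found. A sketch of that parsing step, with an added clamp to keep out-of-range numbers inside 1-10 (the clamp is this sketch's addition, not something the diff does):

import re


def parse_willing(response: str, low: int = 1, high: int = 10) -> int:
    # extract the `<N>` willingness marker from a free-form model reply
    match = re.search(r"<(\d+)>", response)
    if not match:
        return 0
    return max(low, min(high, int(match.group(1))))


print(parse_willing("嗯……我有点想说话 <7>"))  # 7
print(parse_willing("今天不太想聊 <0>"))       # 1 (clamped)
print(parse_willing("没有给出数字"))           # 0
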
|
|
||||||
def update_current_mind(self,reponse):
|
def update_current_mind(self, reponse):
|
||||||
self.past_mind.append(self.current_mind)
|
self.past_mind.append(self.current_mind)
|
||||||
self.current_mind = reponse
|
self.current_mind = reponse
|
||||||
|
|
||||||
|
|
||||||
# subheartflow = SubHeartflow()
|
# subheartflow = SubHeartflow()
|
||||||
|
|
||||||
|
|||||||