feat:保存关键词到message数据库

This commit is contained in:
SengokuCola
2025-08-10 21:12:49 +08:00
parent 22a625ce46
commit 69a855df8d
8 changed files with 150 additions and 109 deletions

View File

@@ -4,6 +4,7 @@ import traceback
import random import random
from typing import List, Optional, Dict, Any, Tuple from typing import List, Optional, Dict, Any, Tuple
from rich.traceback import install from rich.traceback import install
from collections import deque
from src.config.config import global_config from src.config.config import global_config
from src.common.logger import get_logger from src.common.logger import get_logger
@@ -121,6 +122,8 @@ class HeartFChatting:
self.focus_energy = 1 self.focus_energy = 1
self.no_reply_consecutive = 0 self.no_reply_consecutive = 0
# 最近三次no_reply的新消息兴趣度记录
self.recent_interest_records: deque = deque(maxlen=3)
async def start(self): async def start(self):
"""检查是否需要启动主循环,如果未激活则启动。""" """检查是否需要启动主循环,如果未激活则启动。"""
@@ -210,13 +213,10 @@ class HeartFChatting:
self.focus_energy = 1 self.focus_energy = 1
else: else:
# 计算最近三次记录的兴趣度总和 # 计算最近三次记录的兴趣度总和
total_recent_interest = sum(NoReplyAction._recent_interest_records) total_recent_interest = sum(self.recent_interest_records)
# 获取当前聊天频率和意愿系数
talk_frequency = global_config.chat.get_current_talk_frequency(self.stream_id)
# 计算调整后的阈值 # 计算调整后的阈值
adjusted_threshold = 3 / talk_frequency adjusted_threshold = 3 / global_config.chat.get_current_talk_frequency(self.stream_id)
logger.info(f"{self.log_prefix} 最近三次兴趣度总和: {total_recent_interest:.2f}, 调整后阈值: {adjusted_threshold:.2f}") logger.info(f"{self.log_prefix} 最近三次兴趣度总和: {total_recent_interest:.2f}, 调整后阈值: {adjusted_threshold:.2f}")
@@ -228,57 +228,74 @@ class HeartFChatting:
logger.info(f"{self.log_prefix} 兴趣度充足") logger.info(f"{self.log_prefix} 兴趣度充足")
self.focus_energy = 1 self.focus_energy = 1
async def _execute_no_reply(self, new_message:List[Dict[str, Any]]) -> Tuple[bool, str]: async def _should_process_messages(self, new_message: List[Dict[str, Any]], mode: ChatMode) -> bool:
"""执行breaking形式的no_reply原有逻辑""" """
new_message_count = len(new_message) 判断是否应该处理消息
# 检查消息数量是否达到阈值
talk_frequency = global_config.chat.get_current_talk_frequency(self.stream_id)
modified_exit_count_threshold = self.focus_energy / talk_frequency
if new_message_count >= modified_exit_count_threshold: Args:
# 记录兴趣度到列表 new_message: 新消息列表
total_interest = 0.0 mode: 当前聊天模式
for msg_dict in new_message:
interest_value = msg_dict.get("interest_value", 0.0)
if msg_dict.get("processed_plain_text", ""):
total_interest += interest_value
NoReplyAction._recent_interest_records.append(total_interest) Returns:
bool: 是否应该处理消息
"""
new_message_count = len(new_message)
if mode == ChatMode.NORMAL:
# Normal模式简单的消息数量判断
return new_message_count >= self.focus_energy
logger.info( elif mode == ChatMode.FOCUS:
f"{self.log_prefix} 累计消息数量达到{new_message_count}条(>{modified_exit_count_threshold}),结束等待" # Focus模式原有的breaking形式no_reply逻辑
) talk_frequency = global_config.chat.get_current_talk_frequency(self.stream_id)
modified_exit_count_threshold = self.focus_energy / talk_frequency
return True
# 检查累计兴趣值
if new_message_count > 0:
accumulated_interest = 0.0
for msg_dict in new_message:
text = msg_dict.get("processed_plain_text", "")
interest_value = msg_dict.get("interest_value", 0.0)
if text:
accumulated_interest += interest_value
# 只在兴趣值变化时输出log if new_message_count >= modified_exit_count_threshold:
if not hasattr(self, "_last_accumulated_interest") or accumulated_interest != self._last_accumulated_interest:
logger.info(f"{self.log_prefix} breaking形式当前累计兴趣值: {accumulated_interest:.2f}, 当前聊天频率: {talk_frequency:.2f}")
self._last_accumulated_interest = accumulated_interest
if accumulated_interest >= 3 / talk_frequency:
# 记录兴趣度到列表 # 记录兴趣度到列表
NoReplyAction._recent_interest_records.append(accumulated_interest) total_interest = 0.0
for msg_dict in new_message:
interest_value = msg_dict.get("interest_value", 0.0)
if msg_dict.get("processed_plain_text", ""):
total_interest += interest_value
self.recent_interest_records.append(total_interest)
logger.info( logger.info(
f"{self.log_prefix} 累计兴趣值达到{accumulated_interest:.2f}(>{5 / talk_frequency}),结束等待" f"{self.log_prefix} 累计消息数量达到{new_message_count}(>{modified_exit_count_threshold}),结束等待"
) )
return True return True
# 每10秒输出一次等待状态 # 检查累计兴趣值
if int(time.time() - self.last_read_time) > 0 and int(time.time() - self.last_read_time) % 10 == 0: if new_message_count > 0:
logger.info( accumulated_interest = 0.0
f"{self.log_prefix} 已等待{time.time() - self.last_read_time:.0f}秒,累计{new_message_count}条消息,继续等待..." for msg_dict in new_message:
) text = msg_dict.get("processed_plain_text", "")
interest_value = msg_dict.get("interest_value", 0.0)
if text:
accumulated_interest += interest_value
# 只在兴趣值变化时输出log
if not hasattr(self, "_last_accumulated_interest") or accumulated_interest != self._last_accumulated_interest:
logger.info(f"{self.log_prefix} breaking形式当前累计兴趣值: {accumulated_interest:.2f}, 当前聊天频率: {talk_frequency:.2f}")
self._last_accumulated_interest = accumulated_interest
if accumulated_interest >= 3 / talk_frequency:
# 记录兴趣度到列表
self.recent_interest_records.append(accumulated_interest)
logger.info(
f"{self.log_prefix} 累计兴趣值达到{accumulated_interest:.2f}(>{5 / talk_frequency}),结束等待"
)
return True
# 每10秒输出一次等待状态
if int(time.time() - self.last_read_time) > 0 and int(time.time() - self.last_read_time) % 10 == 0:
logger.info(
f"{self.log_prefix} 已等待{time.time() - self.last_read_time:.0f}秒,累计{new_message_count}条消息,继续等待..."
)
await asyncio.sleep(0.5)
return False
async def _loopbody(self): async def _loopbody(self):
@@ -291,51 +308,50 @@ class HeartFChatting:
filter_mai=True, filter_mai=True,
filter_command=True, filter_command=True,
) )
new_message_count = len(recent_messages_dict)
# 先进行focus判定
if self.loop_mode == ChatMode.FOCUS: if self.loop_mode == ChatMode.FOCUS:
if self.last_action == "no_reply":
if not await self._execute_no_reply(recent_messages_dict):
self.energy_value -= 0.3 / global_config.chat.focus_value
logger.info(f"{self.log_prefix} 能量值减少,当前能量值:{self.energy_value:.1f}")
await asyncio.sleep(0.5)
return True
self.last_read_time = time.time()
if await self._observe():
self.energy_value += 1 / global_config.chat.focus_value
logger.info(f"{self.log_prefix} 能量值增加,当前能量值:{self.energy_value:.1f}")
if self.energy_value <= 1: if self.energy_value <= 1:
logger.info(f"{self.log_prefix} 能量值过低进入normal模式")
self.energy_value = 1 self.energy_value = 1
self.loop_mode = ChatMode.NORMAL self.loop_mode = ChatMode.NORMAL
return True return True
return True
elif self.loop_mode == ChatMode.NORMAL: elif self.loop_mode == ChatMode.NORMAL:
if global_config.chat.focus_value != 0: if global_config.chat.focus_value != 0 and self.energy_value >= 30:
if new_message_count > 3 / pow(global_config.chat.focus_value, 0.5): self.loop_mode = ChatMode.FOCUS
self.loop_mode = ChatMode.FOCUS return True
self.energy_value = (
10 + (new_message_count / (3 / pow(global_config.chat.focus_value, 0.5))) * 10 # 统一的消息处理逻辑
) should_process = await self._should_process_messages(recent_messages_dict, self.loop_mode)
return True
if self.loop_mode == ChatMode.FOCUS:
if self.energy_value >= 30: # Focus模式处理
self.loop_mode = ChatMode.FOCUS if self.last_action == "no_reply" and not should_process:
return True # 需要继续等待
self.energy_value -= 0.3 / global_config.chat.focus_value
if new_message_count >= self.focus_energy: logger.info(f"{self.log_prefix} 能量值减少,当前能量值:{self.energy_value:.1f}")
earliest_messages_data = recent_messages_dict[0] await asyncio.sleep(0.5)
self.last_read_time = earliest_messages_data.get("time") return True
if_think = await self.normal_response(earliest_messages_data) if should_process:
# Focus模式设置last_read_time并执行observe
self.last_read_time = time.time()
if await self._observe():
self.energy_value += 1 / global_config.chat.focus_value
logger.info(f"{self.log_prefix} 能量值增加,当前能量值:{self.energy_value:.1f}")
return True
elif self.loop_mode == ChatMode.NORMAL:
# Normal模式处理
if should_process:
# Normal模式设置last_read_time为最早消息的时间并调用normal_response
earliest_message_data = recent_messages_dict[0]
self.last_read_time = earliest_message_data.get("time")
if_think = await self.normal_response(earliest_message_data)
if if_think: if if_think:
factor = max(global_config.chat.focus_value, 0.1) factor = max(global_config.chat.focus_value, 0.1)
self.energy_value *= 1.1 * factor self.energy_value *= 1.1 * pow(factor, 0.5)
logger.info(f"{self.log_prefix} 进行了思考,能量值按倍数增加,当前能量值:{self.energy_value:.1f}") logger.info(f"{self.log_prefix} 进行了思考,能量值按倍数增加,当前能量值:{self.energy_value:.1f}")
else: else:
self.energy_value += 0.1 * global_config.chat.focus_value self.energy_value += 0.1 * global_config.chat.focus_value
@@ -343,10 +359,12 @@ class HeartFChatting:
logger.debug(f"{self.log_prefix} 当前能量值:{self.energy_value:.1f}") logger.debug(f"{self.log_prefix} 当前能量值:{self.energy_value:.1f}")
return True return True
else:
# Normal模式消息数量不足等待
await asyncio.sleep(0.5)
return True
await asyncio.sleep(0.5) return True
return True
async def build_reply_to_str(self, message_data: dict): async def build_reply_to_str(self, message_data: dict):
person_info_manager = get_person_info_manager() person_info_manager = get_person_info_manager()
@@ -705,13 +723,13 @@ class HeartFChatting:
# 管理no_reply计数器当执行了非no_reply动作时重置计数器 # 管理no_reply计数器当执行了非no_reply动作时重置计数器
if action_type != "no_reply" and action_type != "no_action": if action_type != "no_reply" and action_type != "no_action":
# 导入NoReplyAction并重置计数器 # 导入NoReplyAction并重置计数器
NoReplyAction.reset_consecutive_count() self.recent_interest_records.clear()
self.no_reply_consecutive = 0 self.no_reply_consecutive = 0
logger.info(f"{self.log_prefix} 执行了{action_type}动作重置no_reply计数器") logger.info(f"{self.log_prefix} 执行了{action_type}动作重置no_reply计数器")
return True return True
elif action_type == "no_action": elif action_type == "no_action":
# 当执行回复动作时也重置no_reply计数 # 当执行回复动作时也重置no_reply计数
NoReplyAction.reset_consecutive_count() self.recent_interest_records.clear()
self.no_reply_consecutive = 0 self.no_reply_consecutive = 0
logger.info(f"{self.log_prefix} 执行了回复动作重置no_reply计数器") logger.info(f"{self.log_prefix} 执行了回复动作重置no_reply计数器")

View File

@@ -57,9 +57,11 @@ async def _calculate_interest(message: MessageRecv) -> Tuple[float, bool, list[s
with Timer("记忆激活"): with Timer("记忆激活"):
interested_rate, keywords = await hippocampus_manager.get_activate_from_text( interested_rate, keywords = await hippocampus_manager.get_activate_from_text(
message.processed_plain_text, message.processed_plain_text,
max_depth= 5, max_depth= 4,
fast_retrieval=False, fast_retrieval=False,
) )
message.key_words = keywords
message.key_words_lite = keywords
logger.debug(f"记忆激活率: {interested_rate:.2f}, 关键词: {keywords}") logger.debug(f"记忆激活率: {interested_rate:.2f}, 关键词: {keywords}")
text_len = len(message.processed_plain_text) text_len = len(message.processed_plain_text)

View File

@@ -322,14 +322,14 @@ class Hippocampus:
# 使用LLM提取关键词 - 根据详细文本长度分布优化topic_num计算 # 使用LLM提取关键词 - 根据详细文本长度分布优化topic_num计算
text_length = len(text) text_length = len(text)
topic_num: int | list[int] = 0 topic_num: int | list[int] = 0
if text_length <= 5: if text_length <= 6:
words = jieba.cut(text) words = jieba.cut(text)
keywords = [word for word in words if len(word) > 1] keywords = [word for word in words if len(word) > 1]
keywords = list(set(keywords))[:3] # 限制最多3个关键词 keywords = list(set(keywords))[:3] # 限制最多3个关键词
if keywords: if keywords:
logger.debug(f"提取关键词: {keywords}") logger.debug(f"提取关键词: {keywords}")
return keywords return keywords
elif text_length <= 10: elif text_length <= 12:
topic_num = [1, 3] # 6-10字符: 1个关键词 (27.18%的文本) topic_num = [1, 3] # 6-10字符: 1个关键词 (27.18%的文本)
elif text_length <= 20: elif text_length <= 20:
topic_num = [2, 4] # 11-20字符: 2个关键词 (22.76%的文本) topic_num = [2, 4] # 11-20字符: 2个关键词 (22.76%的文本)
@@ -776,7 +776,7 @@ class Hippocampus:
total_nodes = len(self.memory_graph.G.nodes()) total_nodes = len(self.memory_graph.G.nodes())
# activated_nodes = len(activate_map) # activated_nodes = len(activate_map)
activation_ratio = total_activation / total_nodes if total_nodes > 0 else 0 activation_ratio = total_activation / total_nodes if total_nodes > 0 else 0
activation_ratio = activation_ratio * 60 activation_ratio = activation_ratio * 50
logger.debug(f"总激活值: {total_activation:.2f}, 总节点数: {total_nodes}, 激活: {activation_ratio}") logger.debug(f"总激活值: {total_activation:.2f}, 总节点数: {total_nodes}, 激活: {activation_ratio}")
return activation_ratio, keywords return activation_ratio, keywords

View File

@@ -116,6 +116,9 @@ class MessageRecv(Message):
self.priority_mode = "interest" self.priority_mode = "interest"
self.priority_info = None self.priority_info = None
self.interest_value: float = None # type: ignore self.interest_value: float = None # type: ignore
self.key_words = []
self.key_words_lite = []
def update_chat_stream(self, chat_stream: "ChatStream"): def update_chat_stream(self, chat_stream: "ChatStream"):
self.chat_stream = chat_stream self.chat_stream = chat_stream

View File

@@ -1,4 +1,5 @@
import re import re
import json
import traceback import traceback
from typing import Union from typing import Union
@@ -11,6 +12,23 @@ logger = get_logger("message_storage")
class MessageStorage: class MessageStorage:
@staticmethod
def _serialize_keywords(keywords) -> str:
    """Encode a keyword list as a JSON array string.

    Non-ASCII characters are emitted as-is (``ensure_ascii=False``).
    Any value that is not a ``list`` produces the empty array ``"[]"``.
    """
    # Guard first: only real lists are serialized, everything else is "empty".
    if not isinstance(keywords, list):
        return "[]"
    return json.dumps(keywords, ensure_ascii=False)
@staticmethod
def _deserialize_keywords(keywords_str: str) -> list:
    """Decode a JSON string back into a keyword list.

    Args:
        keywords_str: JSON text expected to hold an array; may be empty
            or ``None``.

    Returns:
        The decoded list, or ``[]`` when the input is empty, is not valid
        JSON, or decodes to something other than a list.
    """
    if not keywords_str:
        return []
    try:
        decoded = json.loads(keywords_str)
    except (json.JSONDecodeError, TypeError):
        return []
    # Valid JSON is not necessarily an array ("null", "{}", "3" all parse);
    # honor the declared list return type instead of leaking other types.
    return decoded if isinstance(decoded, list) else []
@staticmethod @staticmethod
async def store_message(message: Union[MessageSending, MessageRecv], chat_stream: ChatStream) -> None: async def store_message(message: Union[MessageSending, MessageRecv], chat_stream: ChatStream) -> None:
"""存储消息到数据库""" """存储消息到数据库"""
@@ -45,6 +63,8 @@ class MessageStorage:
is_picid = False is_picid = False
is_notify = False is_notify = False
is_command = False is_command = False
key_words = ""
key_words_lite = ""
else: else:
filtered_display_message = "" filtered_display_message = ""
interest_value = message.interest_value interest_value = message.interest_value
@@ -56,7 +76,10 @@ class MessageStorage:
is_picid = message.is_picid is_picid = message.is_picid
is_notify = message.is_notify is_notify = message.is_notify
is_command = message.is_command is_command = message.is_command
# 序列化关键词列表为JSON字符串
key_words = MessageStorage._serialize_keywords(message.key_words)
key_words_lite = MessageStorage._serialize_keywords(message.key_words_lite)
chat_info_dict = chat_stream.to_dict() chat_info_dict = chat_stream.to_dict()
user_info_dict = message.message_info.user_info.to_dict() # type: ignore user_info_dict = message.message_info.user_info.to_dict() # type: ignore
@@ -102,6 +125,8 @@ class MessageStorage:
is_picid=is_picid, is_picid=is_picid,
is_notify=is_notify, is_notify=is_notify,
is_command=is_command, is_command=is_command,
key_words=key_words,
key_words_lite=key_words_lite,
) )
except Exception: except Exception:
logger.exception("存储消息失败") logger.exception("存储消息失败")

View File

@@ -79,10 +79,13 @@ def init_prompt():
{identity} {identity}
{action_descriptions} {action_descriptions}
你现在的主要任务是和 {sender_name} 聊天。同时,也有其他用户会参与你们的聊天,你可以参考他们的回复内容,但是你主要还是关注你和{sender_name}的聊天内容。 你现在的主要任务是和 {sender_name} 聊天。同时,也有其他用户会参与聊天,你可以参考他们的回复内容,但是你主要还是关注你和{sender_name}的聊天内容。
{time_block}
这是所有聊天内容:
{background_dialogue_prompt} {background_dialogue_prompt}
-------------------------------- --------------------------------
{time_block} {time_block}
这是你和{sender_name}的对话,你们正在交流中: 这是你和{sender_name}的对话,你们正在交流中:
@@ -585,8 +588,8 @@ class DefaultReplyer:
# 构建背景对话 prompt # 构建背景对话 prompt
background_dialogue_prompt = "" background_dialogue_prompt = ""
if background_dialogue_list: if message_list_before_now:
latest_25_msgs = background_dialogue_list[-int(global_config.chat.max_context_size * 0.5) :] latest_25_msgs = message_list_before_now[-int(global_config.chat.max_context_size * 0.5) :]
background_dialogue_prompt_str = build_readable_messages( background_dialogue_prompt_str = build_readable_messages(
latest_25_msgs, latest_25_msgs,
replace_bot_name=True, replace_bot_name=True,

View File

@@ -130,6 +130,9 @@ class Messages(BaseModel):
reply_to = TextField(null=True) reply_to = TextField(null=True)
interest_value = DoubleField(null=True) interest_value = DoubleField(null=True)
key_words = TextField(null=True)
key_words_lite = TextField(null=True)
is_mentioned = BooleanField(null=True) is_mentioned = BooleanField(null=True)
# 从 chat_info 扁平化而来的字段 # 从 chat_info 扁平化而来的字段

View File

@@ -32,8 +32,7 @@ class NoReplyAction(BaseAction):
action_name = "no_reply" action_name = "no_reply"
action_description = "暂时不回复消息" action_description = "暂时不回复消息"
# 最近三次no_reply的新消息兴趣度记录
_recent_interest_records: deque = deque(maxlen=3)
# 兴趣值退出阈值 # 兴趣值退出阈值
_interest_exit_threshold = 3.0 _interest_exit_threshold = 3.0
@@ -75,15 +74,3 @@ class NoReplyAction(BaseAction):
action_done=True, action_done=True,
) )
return False, f"不回复动作执行失败: {e}" return False, f"不回复动作执行失败: {e}"
@classmethod
def reset_consecutive_count(cls):
    """Empty the class-level recent-interest record and log that it was reset."""
    records = cls._recent_interest_records
    records.clear()
    logger.debug("NoReplyAction连续计数器和兴趣度记录已重置")
@classmethod
def get_recent_interest_records(cls) -> List[float]:
    """Return a new list holding the values in the class-level recent-interest record."""
    # Unpack into a fresh list so callers cannot mutate the stored record.
    return [*cls._recent_interest_records]