fix typo, refactor memory_system
This commit is contained in:
@@ -207,7 +207,7 @@ class HeartFChatting:
|
|||||||
timestamp_end=time.time(),
|
timestamp_end=time.time(),
|
||||||
limit=10,
|
limit=10,
|
||||||
limit_mode="earliest",
|
limit_mode="earliest",
|
||||||
fliter_bot=True,
|
filter_bot=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
if len(new_messages_data) > 4 * global_config.chat.auto_focus_threshold:
|
if len(new_messages_data) > 4 * global_config.chat.auto_focus_threshold:
|
||||||
@@ -296,13 +296,13 @@ class HeartFChatting:
|
|||||||
content = " ".join([item[1] for item in response_set if item[0] == "text"])
|
content = " ".join([item[1] for item in response_set if item[0] == "text"])
|
||||||
|
|
||||||
# 模型炸了,没有回复内容生成
|
# 模型炸了,没有回复内容生成
|
||||||
if not response_set or (action_type not in ["no_action"] and not is_parallel):
|
if not response_set:
|
||||||
if not response_set:
|
logger.warning(f"[{self.log_prefix}] 模型未生成回复内容")
|
||||||
logger.warning(f"[{self.log_prefix}] 模型未生成回复内容")
|
return False
|
||||||
elif action_type not in ["no_action"] and not is_parallel:
|
elif action_type not in ["no_action"] and not is_parallel:
|
||||||
logger.info(
|
logger.info(
|
||||||
f"[{self.log_prefix}] {global_config.bot.nickname} 原本想要回复:{content},但选择执行{action_type},不发表回复"
|
f"[{self.log_prefix}] {global_config.bot.nickname} 原本想要回复:{content},但选择执行{action_type},不发表回复"
|
||||||
)
|
)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
logger.info(f"[{self.log_prefix}] {global_config.bot.nickname} 决定的回复内容: {content}")
|
logger.info(f"[{self.log_prefix}] {global_config.bot.nickname} 决定的回复内容: {content}")
|
||||||
|
|||||||
@@ -832,10 +832,9 @@ class EntorhinalCortex:
|
|||||||
def random_get_msg_snippet(target_timestamp: float, chat_size: int, max_memorized_time_per_msg: int) -> list | None:
|
def random_get_msg_snippet(target_timestamp: float, chat_size: int, max_memorized_time_per_msg: int) -> list | None:
|
||||||
# sourcery skip: invert-any-all, use-any, use-named-expression, use-next
|
# sourcery skip: invert-any-all, use-any, use-named-expression, use-next
|
||||||
"""从数据库中随机获取指定时间戳附近的消息片段 (使用 chat_message_builder)"""
|
"""从数据库中随机获取指定时间戳附近的消息片段 (使用 chat_message_builder)"""
|
||||||
try_count = 0
|
|
||||||
time_window_seconds = random.randint(300, 1800) # 随机时间窗口,5到30分钟
|
time_window_seconds = random.randint(300, 1800) # 随机时间窗口,5到30分钟
|
||||||
|
|
||||||
while try_count < 3:
|
for _ in range(3):
|
||||||
# 定义时间范围:从目标时间戳开始,向后推移 time_window_seconds
|
# 定义时间范围:从目标时间戳开始,向后推移 time_window_seconds
|
||||||
timestamp_start = target_timestamp
|
timestamp_start = target_timestamp
|
||||||
timestamp_end = target_timestamp + time_window_seconds
|
timestamp_end = target_timestamp + time_window_seconds
|
||||||
@@ -874,8 +873,6 @@ class EntorhinalCortex:
|
|||||||
).execute()
|
).execute()
|
||||||
return messages # 直接返回原始的消息列表
|
return messages # 直接返回原始的消息列表
|
||||||
|
|
||||||
# 如果获取失败或消息无效,增加尝试次数
|
|
||||||
try_count += 1
|
|
||||||
target_timestamp -= 120 # 如果第一次尝试失败,稍微向前调整时间戳再试
|
target_timestamp -= 120 # 如果第一次尝试失败,稍微向前调整时间戳再试
|
||||||
|
|
||||||
# 三次尝试都失败,返回 None
|
# 三次尝试都失败,返回 None
|
||||||
@@ -1067,19 +1064,17 @@ class EntorhinalCortex:
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
memory_items = [str(item) for item in memory_items]
|
memory_items = [str(item) for item in memory_items]
|
||||||
memory_items_json = json.dumps(memory_items, ensure_ascii=False)
|
if memory_items_json := json.dumps(memory_items, ensure_ascii=False):
|
||||||
if not memory_items_json:
|
nodes_data.append(
|
||||||
continue
|
{
|
||||||
|
"concept": concept,
|
||||||
|
"memory_items": memory_items_json,
|
||||||
|
"hash": self.hippocampus.calculate_node_hash(concept, memory_items),
|
||||||
|
"created_time": data.get("created_time", current_time),
|
||||||
|
"last_modified": data.get("last_modified", current_time),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
nodes_data.append(
|
|
||||||
{
|
|
||||||
"concept": concept,
|
|
||||||
"memory_items": memory_items_json,
|
|
||||||
"hash": self.hippocampus.calculate_node_hash(concept, memory_items),
|
|
||||||
"created_time": data.get("created_time", current_time),
|
|
||||||
"last_modified": data.get("last_modified", current_time),
|
|
||||||
}
|
|
||||||
)
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"准备节点 {concept} 数据时发生错误: {e}")
|
logger.error(f"准备节点 {concept} 数据时发生错误: {e}")
|
||||||
continue
|
continue
|
||||||
@@ -1271,7 +1266,7 @@ class ParahippocampalGyrus:
|
|||||||
|
|
||||||
# 3. 过滤掉包含禁用关键词的topic
|
# 3. 过滤掉包含禁用关键词的topic
|
||||||
filtered_topics = [
|
filtered_topics = [
|
||||||
topic for topic in topics if not any(keyword in topic for keyword in global_config.memory.memory_ban_words)
|
topic for topic in topics if all(keyword not in topic for keyword in global_config.memory.memory_ban_words)
|
||||||
]
|
]
|
||||||
|
|
||||||
logger.debug(f"过滤后话题: {filtered_topics}")
|
logger.debug(f"过滤后话题: {filtered_topics}")
|
||||||
|
|||||||
@@ -66,61 +66,61 @@ class MemoryBuildScheduler:
|
|||||||
return [int(t.timestamp()) for t in timestamps]
|
return [int(t.timestamp()) for t in timestamps]
|
||||||
|
|
||||||
|
|
||||||
def print_time_samples(timestamps, show_distribution=True):
|
# def print_time_samples(timestamps, show_distribution=True):
|
||||||
"""打印时间样本和分布信息"""
|
# """打印时间样本和分布信息"""
|
||||||
print(f"\n生成的{len(timestamps)}个时间点分布:")
|
# print(f"\n生成的{len(timestamps)}个时间点分布:")
|
||||||
print("序号".ljust(5), "时间戳".ljust(25), "距现在(小时)")
|
# print("序号".ljust(5), "时间戳".ljust(25), "距现在(小时)")
|
||||||
print("-" * 50)
|
# print("-" * 50)
|
||||||
|
|
||||||
now = datetime.now()
|
# now = datetime.now()
|
||||||
time_diffs = []
|
# time_diffs = []
|
||||||
|
|
||||||
for i, timestamp in enumerate(timestamps, 1):
|
# for i, timestamp in enumerate(timestamps, 1):
|
||||||
hours_diff = (now - timestamp).total_seconds() / 3600
|
# hours_diff = (now - timestamp).total_seconds() / 3600
|
||||||
time_diffs.append(hours_diff)
|
# time_diffs.append(hours_diff)
|
||||||
print(f"{str(i).ljust(5)} {timestamp.strftime('%Y-%m-%d %H:%M:%S').ljust(25)} {hours_diff:.2f}")
|
# print(f"{str(i).ljust(5)} {timestamp.strftime('%Y-%m-%d %H:%M:%S').ljust(25)} {hours_diff:.2f}")
|
||||||
|
|
||||||
# 打印统计信息
|
# # 打印统计信息
|
||||||
print("\n统计信息:")
|
# print("\n统计信息:")
|
||||||
print(f"平均时间偏移:{np.mean(time_diffs):.2f}小时")
|
# print(f"平均时间偏移:{np.mean(time_diffs):.2f}小时")
|
||||||
print(f"标准差:{np.std(time_diffs):.2f}小时")
|
# print(f"标准差:{np.std(time_diffs):.2f}小时")
|
||||||
print(f"最早时间:{min(timestamps).strftime('%Y-%m-%d %H:%M:%S')} ({max(time_diffs):.2f}小时前)")
|
# print(f"最早时间:{min(timestamps).strftime('%Y-%m-%d %H:%M:%S')} ({max(time_diffs):.2f}小时前)")
|
||||||
print(f"最近时间:{max(timestamps).strftime('%Y-%m-%d %H:%M:%S')} ({min(time_diffs):.2f}小时前)")
|
# print(f"最近时间:{max(timestamps).strftime('%Y-%m-%d %H:%M:%S')} ({min(time_diffs):.2f}小时前)")
|
||||||
|
|
||||||
if show_distribution:
|
# if show_distribution:
|
||||||
# 计算时间分布的直方图
|
# # 计算时间分布的直方图
|
||||||
hist, bins = np.histogram(time_diffs, bins=40)
|
# hist, bins = np.histogram(time_diffs, bins=40)
|
||||||
print("\n时间分布(每个*代表一个时间点):")
|
# print("\n时间分布(每个*代表一个时间点):")
|
||||||
for i in range(len(hist)):
|
# for i in range(len(hist)):
|
||||||
if hist[i] > 0:
|
# if hist[i] > 0:
|
||||||
print(f"{bins[i]:6.1f}-{bins[i + 1]:6.1f}小时: {'*' * int(hist[i])}")
|
# print(f"{bins[i]:6.1f}-{bins[i + 1]:6.1f}小时: {'*' * int(hist[i])}")
|
||||||
|
|
||||||
|
|
||||||
# 使用示例
|
# # 使用示例
|
||||||
if __name__ == "__main__":
|
# if __name__ == "__main__":
|
||||||
# 创建一个双峰分布的记忆调度器
|
# # 创建一个双峰分布的记忆调度器
|
||||||
scheduler = MemoryBuildScheduler(
|
# scheduler = MemoryBuildScheduler(
|
||||||
n_hours1=12, # 第一个分布均值(12小时前)
|
# n_hours1=12, # 第一个分布均值(12小时前)
|
||||||
std_hours1=8, # 第一个分布标准差
|
# std_hours1=8, # 第一个分布标准差
|
||||||
weight1=0.7, # 第一个分布权重 70%
|
# weight1=0.7, # 第一个分布权重 70%
|
||||||
n_hours2=36, # 第二个分布均值(36小时前)
|
# n_hours2=36, # 第二个分布均值(36小时前)
|
||||||
std_hours2=24, # 第二个分布标准差
|
# std_hours2=24, # 第二个分布标准差
|
||||||
weight2=0.3, # 第二个分布权重 30%
|
# weight2=0.3, # 第二个分布权重 30%
|
||||||
total_samples=50, # 总共生成50个时间点
|
# total_samples=50, # 总共生成50个时间点
|
||||||
)
|
# )
|
||||||
|
|
||||||
# 生成时间分布
|
# # 生成时间分布
|
||||||
timestamps = scheduler.generate_time_samples()
|
# timestamps = scheduler.generate_time_samples()
|
||||||
|
|
||||||
# 打印结果,包含分布可视化
|
# # 打印结果,包含分布可视化
|
||||||
print_time_samples(timestamps, show_distribution=True)
|
# print_time_samples(timestamps, show_distribution=True)
|
||||||
|
|
||||||
# 打印时间戳数组
|
# # 打印时间戳数组
|
||||||
timestamp_array = scheduler.get_timestamp_array()
|
# timestamp_array = scheduler.get_timestamp_array()
|
||||||
print("\n时间戳数组(Unix时间戳):")
|
# print("\n时间戳数组(Unix时间戳):")
|
||||||
print("[", end="")
|
# print("[", end="")
|
||||||
for i, ts in enumerate(timestamp_array):
|
# for i, ts in enumerate(timestamp_array):
|
||||||
if i > 0:
|
# if i > 0:
|
||||||
print(", ", end="")
|
# print(", ", end="")
|
||||||
print(ts, end="")
|
# print(ts, end="")
|
||||||
print("]")
|
# print("]")
|
||||||
|
|||||||
@@ -35,7 +35,7 @@ def get_raw_msg_by_timestamp_with_chat(
|
|||||||
timestamp_end: float,
|
timestamp_end: float,
|
||||||
limit: int = 0,
|
limit: int = 0,
|
||||||
limit_mode: str = "latest",
|
limit_mode: str = "latest",
|
||||||
fliter_bot=False,
|
filter_bot=False,
|
||||||
) -> List[Dict[str, Any]]:
|
) -> List[Dict[str, Any]]:
|
||||||
"""获取在特定聊天从指定时间戳到指定时间戳的消息,按时间升序排序,返回消息列表
|
"""获取在特定聊天从指定时间戳到指定时间戳的消息,按时间升序排序,返回消息列表
|
||||||
limit: 限制返回的消息数量,0为不限制
|
limit: 限制返回的消息数量,0为不限制
|
||||||
@@ -46,7 +46,7 @@ def get_raw_msg_by_timestamp_with_chat(
|
|||||||
sort_order = [("time", 1)] if limit == 0 else None
|
sort_order = [("time", 1)] if limit == 0 else None
|
||||||
# 直接将 limit_mode 传递给 find_messages
|
# 直接将 limit_mode 传递给 find_messages
|
||||||
return find_messages(
|
return find_messages(
|
||||||
message_filter=filter_query, sort=sort_order, limit=limit, limit_mode=limit_mode, fliter_bot=fliter_bot
|
message_filter=filter_query, sort=sort_order, limit=limit, limit_mode=limit_mode, filter_bot=filter_bot
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -56,7 +56,7 @@ def get_raw_msg_by_timestamp_with_chat_inclusive(
|
|||||||
timestamp_end: float,
|
timestamp_end: float,
|
||||||
limit: int = 0,
|
limit: int = 0,
|
||||||
limit_mode: str = "latest",
|
limit_mode: str = "latest",
|
||||||
fliter_bot=False,
|
filter_bot=False,
|
||||||
) -> List[Dict[str, Any]]:
|
) -> List[Dict[str, Any]]:
|
||||||
"""获取在特定聊天从指定时间戳到指定时间戳的消息(包含边界),按时间升序排序,返回消息列表
|
"""获取在特定聊天从指定时间戳到指定时间戳的消息(包含边界),按时间升序排序,返回消息列表
|
||||||
limit: 限制返回的消息数量,0为不限制
|
limit: 限制返回的消息数量,0为不限制
|
||||||
@@ -68,7 +68,7 @@ def get_raw_msg_by_timestamp_with_chat_inclusive(
|
|||||||
# 直接将 limit_mode 传递给 find_messages
|
# 直接将 limit_mode 传递给 find_messages
|
||||||
|
|
||||||
return find_messages(
|
return find_messages(
|
||||||
message_filter=filter_query, sort=sort_order, limit=limit, limit_mode=limit_mode, fliter_bot=fliter_bot
|
message_filter=filter_query, sort=sort_order, limit=limit, limit_mode=limit_mode, filter_bot=filter_bot
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -22,7 +22,7 @@ def find_messages(
|
|||||||
sort: Optional[List[tuple[str, int]]] = None,
|
sort: Optional[List[tuple[str, int]]] = None,
|
||||||
limit: int = 0,
|
limit: int = 0,
|
||||||
limit_mode: str = "latest",
|
limit_mode: str = "latest",
|
||||||
fliter_bot=False,
|
filter_bot=False,
|
||||||
) -> List[dict[str, Any]]:
|
) -> List[dict[str, Any]]:
|
||||||
"""
|
"""
|
||||||
根据提供的过滤器、排序和限制条件查找消息。
|
根据提供的过滤器、排序和限制条件查找消息。
|
||||||
@@ -72,7 +72,7 @@ def find_messages(
|
|||||||
if conditions:
|
if conditions:
|
||||||
query = query.where(*conditions)
|
query = query.where(*conditions)
|
||||||
|
|
||||||
if fliter_bot:
|
if filter_bot:
|
||||||
query = query.where(Messages.user_id != global_config.bot.qq_account)
|
query = query.where(Messages.user_id != global_config.bot.qq_account)
|
||||||
|
|
||||||
if limit > 0:
|
if limit > 0:
|
||||||
|
|||||||
Reference in New Issue
Block a user