fix: remove duplicate message (CR comments)

This commit is contained in:
AL76
2025-03-10 11:46:59 +08:00
parent 2f2be5b3ad
commit a43f9495ea
8 changed files with 171 additions and 167 deletions

View File

@@ -135,7 +135,7 @@ class BotConfig:
try:
config_version: str = toml["inner"]["version"]
except KeyError as e:
logger.error(f"配置文件中 inner 段 不存在 {e}, 这是错误的配置文件")
logger.error(f"配置文件中 inner 段 不存在, 这是错误的配置文件")
raise KeyError(f"配置文件中 inner 段 不存在 {e}, 这是错误的配置文件")
else:
toml["inner"] = {"version": "0.0.0"}
@@ -246,11 +246,11 @@ class BotConfig:
try:
cfg_target[i] = cfg_item[i]
except KeyError as e:
logger.error(f"{item} 中的必要字段 {e} 不存在,请检查")
logger.error(f"{item} 中的必要字段不存在,请检查")
raise KeyError(f"{item} 中的必要字段 {e} 不存在,请检查")
provider = cfg_item.get("provider")
if provider == None:
if provider is None:
logger.error(f"provider 字段在模型配置 {item} 中不存在,请检查")
raise KeyError(f"provider 字段在模型配置 {item} 中不存在,请检查")

View File

@@ -93,8 +93,8 @@ class ResponseGenerator:
# 生成回复
try:
content, reasoning_content = await model.generate_response(prompt)
except Exception as e:
logger.exception(f"生成回复时出错: {e}")
except Exception:
logger.exception(f"生成回复时出错")
return None
# 保存到数据库
@@ -145,8 +145,8 @@ class ResponseGenerator:
else:
return ["neutral"]
except Exception as e:
logger.exception(f"获取情感标签时出错: {e}")
except Exception:
logger.exception(f"获取情感标签时出错")
return ["neutral"]
async def _process_response(self, content: str) -> Tuple[List[str], List[str]]:

View File

@@ -119,8 +119,8 @@ class MessageContainer:
self.messages.remove(message)
return True
return False
except Exception as e:
logger.exception(f"移除消息时发生错误: {e}")
except Exception:
logger.exception(f"移除消息时发生错误")
return False
def has_messages(self) -> bool:
@@ -213,8 +213,8 @@ class MessageManager:
# 安全地移除消息
if not container.remove_message(msg):
logger.warning("尝试删除不存在的消息")
except Exception as e:
logger.exception(f"处理超时消息时发生错误: {e}")
except Exception:
logger.exception(f"处理超时消息时发生错误")
continue
async def start_processor(self):

View File

@@ -2,12 +2,13 @@ from typing import Optional
from ...common.database import Database
from .message import Message
from loguru import logger
class MessageStorage:
def __init__(self):
self.db = Database.get_instance()
async def store_message(self, message: Message, topic: Optional[str] = None) -> None:
"""存储消息到数据库"""
try:
@@ -41,9 +42,9 @@ class MessageStorage:
"topic": topic,
"detailed_plain_text": message.detailed_plain_text,
}
self.db.db.messages.insert_one(message_data)
except Exception as e:
print(f"\033[1;31m[错误]\033[0m 存储消息失败: {e}")
# 如果需要其他存储相关的函数,可以在这里添加
self.db.db.messages.insert_one(message_data)
except Exception:
logger.exception(f"存储消息失败")
# 如果需要其他存储相关的函数,可以在这里添加

View File

@@ -7,6 +7,7 @@ import jieba
import matplotlib.pyplot as plt
import networkx as nx
from dotenv import load_dotenv
from loguru import logger
sys.path.append("C:/GitHub/MaiMBot") # 添加项目根目录到 Python 路径
from src.common.database import Database # 使用正确的导入语法
@@ -15,15 +16,15 @@ from src.common.database import Database # 使用正确的导入语法
env_path = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(__file__)))), '.env.dev')
load_dotenv(env_path)
class Memory_graph:
def __init__(self):
self.G = nx.Graph() # 使用 networkx 的图结构
self.db = Database.get_instance()
def connect_dot(self, concept1, concept2):
self.G.add_edge(concept1, concept2)
def add_dot(self, concept, memory):
if concept in self.G:
# 如果节点已存在,将新记忆添加到现有列表中
@@ -37,7 +38,7 @@ class Memory_graph:
else:
# 如果是新节点,创建新的记忆列表
self.G.add_node(concept, memory_items=[memory])
def get_dot(self, concept):
# 检查节点是否存在于图中
if concept in self.G:
@@ -45,20 +46,20 @@ class Memory_graph:
node_data = self.G.nodes[concept]
# print(node_data)
# 创建新的Memory_dot对象
return concept,node_data
return concept, node_data
return None
def get_related_item(self, topic, depth=1):
if topic not in self.G:
return [], []
first_layer_items = []
second_layer_items = []
# 获取相邻节点
neighbors = list(self.G.neighbors(topic))
# print(f"第一层: {topic}")
# 获取当前节点的记忆项
node_data = self.get_dot(topic)
if node_data:
@@ -69,7 +70,7 @@ class Memory_graph:
first_layer_items.extend(memory_items)
else:
first_layer_items.append(memory_items)
# 只在depth=2时获取第二层记忆
if depth >= 2:
# 获取相邻节点的记忆项
@@ -84,42 +85,44 @@ class Memory_graph:
second_layer_items.extend(memory_items)
else:
second_layer_items.append(memory_items)
return first_layer_items, second_layer_items
def store_memory(self):
for node in self.G.nodes():
dot_data = {
"concept": node
}
self.db.db.store_memory_dots.insert_one(dot_data)
@property
def dots(self):
# 返回所有节点对应的 Memory_dot 对象
return [self.get_dot(node) for node in self.G.nodes()]
def get_random_chat_from_db(self, length: int, timestamp: str):
# 从数据库中根据时间戳获取离其最近的聊天记录
chat_text = ''
closest_record = self.db.db.messages.find_one({"time": {"$lte": timestamp}}, sort=[('time', -1)]) # 调试输出
print(f"距离time最近的消息时间: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(int(closest_record['time'])))}")
logger.info(
f"距离time最近的消息时间: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(int(closest_record['time'])))}")
if closest_record:
closest_time = closest_record['time']
group_id = closest_record['group_id'] # 获取groupid
# 获取该时间戳之后的length条消息且groupid相同
chat_record = list(self.db.db.messages.find({"time": {"$gt": closest_time}, "group_id": group_id}).sort('time', 1).limit(length))
chat_record = list(
self.db.db.messages.find({"time": {"$gt": closest_time}, "group_id": group_id}).sort('time', 1).limit(
length))
for record in chat_record:
time_str = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(int(record['time'])))
try:
displayname="[(%s)%s]%s" % (record["user_id"],record["user_nickname"],record["user_cardname"])
displayname = "[(%s)%s]%s" % (record["user_id"], record["user_nickname"], record["user_cardname"])
except:
displayname=record["user_nickname"] or "用户" + str(record["user_id"])
displayname = record["user_nickname"] or "用户" + str(record["user_id"])
chat_text += f'[{time_str}] {displayname}: {record["processed_plain_text"]}\n' # 添加发送者和时间信息
return chat_text
return [] # 如果没有找到记录,返回空列表
def save_graph_to_db(self):
@@ -166,53 +169,54 @@ def main():
password=os.getenv("MONGODB_PASSWORD", ""),
auth_source=os.getenv("MONGODB_AUTH_SOURCE", "")
)
memory_graph = Memory_graph()
memory_graph.load_graph_from_db()
# 只显示一次优化后的图形
visualize_graph_lite(memory_graph)
while True:
query = input("请输入新的查询概念(输入'退出'以结束):")
if query.lower() == '退出':
break
first_layer_items, second_layer_items = memory_graph.get_related_item(query)
if first_layer_items or second_layer_items:
print("\n第一层记忆:")
logger.debug("第一层记忆:")
for item in first_layer_items:
print(item)
print("\n第二层记忆:")
logger.debug(item)
logger.debug("第二层记忆:")
for item in second_layer_items:
print(item)
logger.debug(item)
else:
print("未找到相关记忆。")
logger.debug("未找到相关记忆。")
def segment_text(text):
seg_text = list(jieba.cut(text))
return seg_text
return seg_text
def find_topic(text, topic_num):
prompt = f'这是一段文字:{text}。请你从这段话中总结出{topic_num}个话题,帮我列出来,用逗号隔开,尽可能精简。只需要列举{topic_num}个话题就好,不要告诉我其他内容。'
return prompt
def topic_what(text, topic):
prompt = f'这是一段文字:{text}。我想知道这记忆里有什么关于{topic}的话题,帮我总结成一句自然的话,可以包含时间和人物。只输出这句话就好'
return prompt
def visualize_graph_lite(memory_graph: Memory_graph, color_by_memory: bool = False):
# 设置中文字体
plt.rcParams['font.sans-serif'] = ['SimHei'] # 用来正常显示中文标签
plt.rcParams['axes.unicode_minus'] = False # 用来正常显示负号
G = memory_graph.G
# 创建一个新图用于可视化
H = G.copy()
# 移除只有一条记忆的节点和连接数少于3的节点
nodes_to_remove = []
for node in H.nodes():
@@ -221,14 +225,14 @@ def visualize_graph_lite(memory_graph: Memory_graph, color_by_memory: bool = Fal
degree = H.degree(node)
if memory_count < 3 or degree < 2: # 改为小于2而不是小于等于2
nodes_to_remove.append(node)
H.remove_nodes_from(nodes_to_remove)
# 如果过滤后没有节点,则返回
if len(H.nodes()) == 0:
print("过滤后没有符合条件的节点可显示")
logger.debug("过滤后没有符合条件的节点可显示")
return
# 保存图到本地
# nx.write_gml(H, "memory_graph.gml") # 保存为 GML 格式
@@ -236,7 +240,7 @@ def visualize_graph_lite(memory_graph: Memory_graph, color_by_memory: bool = Fal
node_colors = []
node_sizes = []
nodes = list(H.nodes())
# 获取最大记忆数和最大度数用于归一化
max_memories = 1
max_degree = 1
@@ -246,7 +250,7 @@ def visualize_graph_lite(memory_graph: Memory_graph, color_by_memory: bool = Fal
degree = H.degree(node)
max_memories = max(max_memories, memory_count)
max_degree = max(max_degree, degree)
# 计算每个节点的大小和颜色
for node in nodes:
# 计算节点大小(基于记忆数量)
@@ -254,9 +258,9 @@ def visualize_graph_lite(memory_graph: Memory_graph, color_by_memory: bool = Fal
memory_count = len(memory_items) if isinstance(memory_items, list) else (1 if memory_items else 0)
# 使用指数函数使变化更明显
ratio = memory_count / max_memories
size = 500 + 5000 * (ratio ) # 使用1.5次方函数使差异不那么明显
size = 500 + 5000 * (ratio) # 使用1.5次方函数使差异不那么明显
node_sizes.append(size)
# 计算节点颜色(基于连接数)
degree = H.degree(node)
# 红色分量随着度数增加而增加
@@ -267,26 +271,25 @@ def visualize_graph_lite(memory_graph: Memory_graph, color_by_memory: bool = Fal
# blue = 1
color = (red, 0.1, blue)
node_colors.append(color)
# 绘制图形
plt.figure(figsize=(12, 8))
pos = nx.spring_layout(H, k=1, iterations=50) # 增加k值使节点分布更开
nx.draw(H, pos,
with_labels=True,
node_color=node_colors,
node_size=node_sizes,
font_size=10,
font_family='SimHei',
font_weight='bold',
edge_color='gray',
width=0.5,
alpha=0.9)
nx.draw(H, pos,
with_labels=True,
node_color=node_colors,
node_size=node_sizes,
font_size=10,
font_family='SimHei',
font_weight='bold',
edge_color='gray',
width=0.5,
alpha=0.9)
title = '记忆图谱可视化 - 节点大小表示记忆数量,颜色表示连接数'
plt.title(title, fontsize=16, fontfamily='SimHei')
plt.show()
if __name__ == "__main__":
main()
main()

View File

@@ -45,7 +45,7 @@ class LLM_request:
self.db.db.llm_usage.create_index([("user_id", 1)])
self.db.db.llm_usage.create_index([("request_type", 1)])
except Exception as e:
logger.error(f"创建数据库索引失败: {e}")
logger.error(f"创建数据库索引失败")
def _record_usage(self, prompt_tokens: int, completion_tokens: int, total_tokens: int,
user_id: str = "system", request_type: str = "chat",
@@ -79,8 +79,8 @@ class LLM_request:
f"提示词: {prompt_tokens}, 完成: {completion_tokens}, "
f"总计: {total_tokens}"
)
except Exception as e:
logger.error(f"记录token使用情况失败: {e}")
except Exception:
logger.error(f"记录token使用情况失败")
def _calculate_cost(self, prompt_tokens: int, completion_tokens: int) -> float:
"""计算API调用成本
@@ -226,8 +226,8 @@ class LLM_request:
if delta_content is None:
delta_content = ""
accumulated_content += delta_content
except Exception as e:
logger.error(f"解析流式输出错误: {e}")
except Exception:
logger.exception(f"解析流式输出错误")
content = accumulated_content
reasoning_content = ""
think_match = re.search(r'<think>(.*?)</think>', content, re.DOTALL)

View File

@@ -3,6 +3,7 @@ import time
from collections import defaultdict
from datetime import datetime, timedelta
from typing import Any, Dict
from loguru import logger
from ...common.database import Database
@@ -153,8 +154,8 @@ class LLMStatistics:
try:
all_stats = self._collect_all_statistics()
self._save_statistics(all_stats)
except Exception as e:
print(f"\033[1;31m[错误]\033[0m 统计数据处理失败: {e}")
except Exception:
logger.exception(f"统计数据处理失败")
# 等待1分钟
for _ in range(60):