fix: remove duplicate message(CR comments)
This commit is contained in:
@@ -5,6 +5,9 @@ import threading
|
||||
import time
|
||||
from datetime import datetime
|
||||
from typing import Dict, List
|
||||
from loguru import logger
|
||||
from typing import Optional
|
||||
from pymongo import MongoClient
|
||||
|
||||
import customtkinter as ctk
|
||||
from dotenv import load_dotenv
|
||||
@@ -17,23 +20,20 @@ root_dir = os.path.abspath(os.path.join(current_dir, '..', '..'))
|
||||
# 加载环境变量
|
||||
if os.path.exists(os.path.join(root_dir, '.env.dev')):
|
||||
load_dotenv(os.path.join(root_dir, '.env.dev'))
|
||||
print("成功加载开发环境配置")
|
||||
logger.info("成功加载开发环境配置")
|
||||
elif os.path.exists(os.path.join(root_dir, '.env.prod')):
|
||||
load_dotenv(os.path.join(root_dir, '.env.prod'))
|
||||
print("成功加载生产环境配置")
|
||||
logger.info("成功加载生产环境配置")
|
||||
else:
|
||||
print("未找到环境配置文件")
|
||||
logger.error("未找到环境配置文件")
|
||||
sys.exit(1)
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from pymongo import MongoClient
|
||||
|
||||
|
||||
class Database:
|
||||
_instance: Optional["Database"] = None
|
||||
|
||||
def __init__(self, host: str, port: int, db_name: str, username: str = None, password: str = None, auth_source: str = None):
|
||||
|
||||
def __init__(self, host: str, port: int, db_name: str, username: str = None, password: str = None,
|
||||
auth_source: str = None):
|
||||
if username and password:
|
||||
self.client = MongoClient(
|
||||
host=host,
|
||||
@@ -45,96 +45,96 @@ class Database:
|
||||
else:
|
||||
self.client = MongoClient(host, port)
|
||||
self.db = self.client[db_name]
|
||||
|
||||
|
||||
@classmethod
|
||||
def initialize(cls, host: str, port: int, db_name: str, username: str = None, password: str = None, auth_source: str = None) -> "Database":
|
||||
def initialize(cls, host: str, port: int, db_name: str, username: str = None, password: str = None,
|
||||
auth_source: str = None) -> "Database":
|
||||
if cls._instance is None:
|
||||
cls._instance = cls(host, port, db_name, username, password, auth_source)
|
||||
return cls._instance
|
||||
|
||||
|
||||
@classmethod
|
||||
def get_instance(cls) -> "Database":
|
||||
if cls._instance is None:
|
||||
raise RuntimeError("Database not initialized")
|
||||
return cls._instance
|
||||
|
||||
return cls._instance
|
||||
|
||||
|
||||
class ReasoningGUI:
|
||||
def __init__(self):
|
||||
# 记录启动时间戳,转换为Unix时间戳
|
||||
self.start_timestamp = datetime.now().timestamp()
|
||||
print(f"程序启动时间戳: {self.start_timestamp}")
|
||||
|
||||
logger.info(f"程序启动时间戳: {self.start_timestamp}")
|
||||
|
||||
# 设置主题
|
||||
ctk.set_appearance_mode("dark")
|
||||
ctk.set_default_color_theme("blue")
|
||||
|
||||
|
||||
# 创建主窗口
|
||||
self.root = ctk.CTk()
|
||||
self.root.title('麦麦推理')
|
||||
self.root.geometry('800x600')
|
||||
self.root.protocol("WM_DELETE_WINDOW", self._on_closing)
|
||||
|
||||
|
||||
# 初始化数据库连接
|
||||
try:
|
||||
self.db = Database.get_instance().db
|
||||
print("数据库连接成功")
|
||||
logger.success("数据库连接成功")
|
||||
except RuntimeError:
|
||||
print("数据库未初始化,正在尝试初始化...")
|
||||
logger.warning("数据库未初始化,正在尝试初始化...")
|
||||
try:
|
||||
Database.initialize("127.0.0.1", 27017, "maimai_bot")
|
||||
self.db = Database.get_instance().db
|
||||
print("数据库初始化成功")
|
||||
except Exception as e:
|
||||
print(f"数据库初始化失败: {e}")
|
||||
logger.success("数据库初始化成功")
|
||||
except Exception:
|
||||
logger.exception(f"数据库初始化失败")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
# 存储群组数据
|
||||
self.group_data: Dict[str, List[dict]] = {}
|
||||
|
||||
|
||||
# 创建更新队列
|
||||
self.update_queue = queue.Queue()
|
||||
|
||||
|
||||
# 创建主框架
|
||||
self.frame = ctk.CTkFrame(self.root)
|
||||
self.frame.pack(pady=20, padx=20, fill="both", expand=True)
|
||||
|
||||
|
||||
# 添加标题
|
||||
self.title = ctk.CTkLabel(self.frame, text="麦麦的脑内所想", font=("Arial", 24))
|
||||
self.title.pack(pady=10, padx=10)
|
||||
|
||||
|
||||
# 创建左右分栏
|
||||
self.paned = ctk.CTkFrame(self.frame)
|
||||
self.paned.pack(fill="both", expand=True, padx=10, pady=10)
|
||||
|
||||
|
||||
# 左侧群组列表
|
||||
self.left_frame = ctk.CTkFrame(self.paned, width=200)
|
||||
self.left_frame.pack(side="left", fill="y", padx=5, pady=5)
|
||||
|
||||
|
||||
self.group_label = ctk.CTkLabel(self.left_frame, text="群组列表", font=("Arial", 16))
|
||||
self.group_label.pack(pady=5)
|
||||
|
||||
|
||||
# 创建可滚动框架来容纳群组按钮
|
||||
self.group_scroll_frame = ctk.CTkScrollableFrame(self.left_frame, width=180, height=400)
|
||||
self.group_scroll_frame.pack(pady=5, padx=5, fill="both", expand=True)
|
||||
|
||||
|
||||
# 存储群组按钮的字典
|
||||
self.group_buttons: Dict[str, ctk.CTkButton] = {}
|
||||
# 当前选中的群组ID
|
||||
self.selected_group_id: Optional[str] = None
|
||||
|
||||
|
||||
# 右侧内容显示
|
||||
self.right_frame = ctk.CTkFrame(self.paned)
|
||||
self.right_frame.pack(side="right", fill="both", expand=True, padx=5, pady=5)
|
||||
|
||||
|
||||
self.content_label = ctk.CTkLabel(self.right_frame, text="推理内容", font=("Arial", 16))
|
||||
self.content_label.pack(pady=5)
|
||||
|
||||
|
||||
# 创建富文本显示框
|
||||
self.content_text = ctk.CTkTextbox(self.right_frame, width=500, height=400)
|
||||
self.content_text.pack(pady=5, padx=5, fill="both", expand=True)
|
||||
|
||||
|
||||
# 配置文本标签 - 只使用颜色
|
||||
self.content_text.tag_config("timestamp", foreground="#888888") # 时间戳使用灰色
|
||||
self.content_text.tag_config("user", foreground="#4CAF50") # 用户名使用绿色
|
||||
@@ -144,11 +144,11 @@ class ReasoningGUI:
|
||||
self.content_text.tag_config("reasoning", foreground="#FF9800") # 推理过程使用橙色
|
||||
self.content_text.tag_config("response", foreground="#E91E63") # 回复使用粉色
|
||||
self.content_text.tag_config("separator", foreground="#666666") # 分隔符使用深灰色
|
||||
|
||||
|
||||
# 底部控制栏
|
||||
self.control_frame = ctk.CTkFrame(self.frame)
|
||||
self.control_frame.pack(fill="x", padx=10, pady=5)
|
||||
|
||||
|
||||
self.clear_button = ctk.CTkButton(
|
||||
self.control_frame,
|
||||
text="清除显示",
|
||||
@@ -156,19 +156,19 @@ class ReasoningGUI:
|
||||
width=120
|
||||
)
|
||||
self.clear_button.pack(side="left", padx=5)
|
||||
|
||||
|
||||
# 启动自动更新线程
|
||||
self.update_thread = threading.Thread(target=self._auto_update, daemon=True)
|
||||
self.update_thread.start()
|
||||
|
||||
|
||||
# 启动GUI更新检查
|
||||
self.root.after(100, self._process_queue)
|
||||
|
||||
|
||||
def _on_closing(self):
|
||||
"""处理窗口关闭事件"""
|
||||
self.root.quit()
|
||||
sys.exit(0)
|
||||
|
||||
|
||||
def _process_queue(self):
|
||||
"""处理更新队列中的任务"""
|
||||
try:
|
||||
@@ -183,14 +183,14 @@ class ReasoningGUI:
|
||||
finally:
|
||||
# 继续检查队列
|
||||
self.root.after(100, self._process_queue)
|
||||
|
||||
|
||||
def _update_group_list_gui(self):
|
||||
"""在主线程中更新群组列表"""
|
||||
# 清除现有按钮
|
||||
for button in self.group_buttons.values():
|
||||
button.destroy()
|
||||
self.group_buttons.clear()
|
||||
|
||||
|
||||
# 创建新的群组按钮
|
||||
for group_id in self.group_data.keys():
|
||||
button = ctk.CTkButton(
|
||||
@@ -203,16 +203,16 @@ class ReasoningGUI:
|
||||
)
|
||||
button.pack(pady=2, padx=5)
|
||||
self.group_buttons[group_id] = button
|
||||
|
||||
|
||||
# 如果有选中的群组,保持其高亮状态
|
||||
if self.selected_group_id and self.selected_group_id in self.group_buttons:
|
||||
self._highlight_selected_group(self.selected_group_id)
|
||||
|
||||
|
||||
def _on_group_select(self, group_id: str):
|
||||
"""处理群组选择事件"""
|
||||
self._highlight_selected_group(group_id)
|
||||
self._update_display_gui(group_id)
|
||||
|
||||
|
||||
def _highlight_selected_group(self, group_id: str):
|
||||
"""高亮显示选中的群组按钮"""
|
||||
# 重置所有按钮的颜色
|
||||
@@ -223,9 +223,9 @@ class ReasoningGUI:
|
||||
else:
|
||||
# 恢复其他按钮的默认颜色
|
||||
button.configure(fg_color="#2B2B2B", hover_color="#404040")
|
||||
|
||||
|
||||
self.selected_group_id = group_id
|
||||
|
||||
|
||||
def _update_display_gui(self, group_id: str):
|
||||
"""在主线程中更新显示内容"""
|
||||
if group_id in self.group_data:
|
||||
@@ -234,19 +234,19 @@ class ReasoningGUI:
|
||||
# 时间戳
|
||||
time_str = item['time'].strftime("%Y-%m-%d %H:%M:%S")
|
||||
self.content_text.insert("end", f"[{time_str}]\n", "timestamp")
|
||||
|
||||
|
||||
# 用户信息
|
||||
self.content_text.insert("end", "用户: ", "timestamp")
|
||||
self.content_text.insert("end", f"{item.get('user', '未知')}\n", "user")
|
||||
|
||||
|
||||
# 消息内容
|
||||
self.content_text.insert("end", "消息: ", "timestamp")
|
||||
self.content_text.insert("end", f"{item.get('message', '')}\n", "message")
|
||||
|
||||
|
||||
# 模型信息
|
||||
self.content_text.insert("end", "模型: ", "timestamp")
|
||||
self.content_text.insert("end", f"{item.get('model', '')}\n", "model")
|
||||
|
||||
|
||||
# Prompt内容
|
||||
self.content_text.insert("end", "Prompt内容:\n", "timestamp")
|
||||
prompt_text = item.get('prompt', '')
|
||||
@@ -257,7 +257,7 @@ class ReasoningGUI:
|
||||
self.content_text.insert("end", " " + line + "\n", "prompt")
|
||||
else:
|
||||
self.content_text.insert("end", " 无Prompt内容\n", "prompt")
|
||||
|
||||
|
||||
# 推理过程
|
||||
self.content_text.insert("end", "推理过程:\n", "timestamp")
|
||||
reasoning_text = item.get('reasoning', '')
|
||||
@@ -268,53 +268,53 @@ class ReasoningGUI:
|
||||
self.content_text.insert("end", " " + line + "\n", "reasoning")
|
||||
else:
|
||||
self.content_text.insert("end", " 无推理过程\n", "reasoning")
|
||||
|
||||
|
||||
# 回复内容
|
||||
self.content_text.insert("end", "回复: ", "timestamp")
|
||||
self.content_text.insert("end", f"{item.get('response', '')}\n", "response")
|
||||
|
||||
|
||||
# 分隔符
|
||||
self.content_text.insert("end", f"\n{'='*50}\n\n", "separator")
|
||||
|
||||
self.content_text.insert("end", f"\n{'=' * 50}\n\n", "separator")
|
||||
|
||||
# 滚动到顶部
|
||||
self.content_text.see("1.0")
|
||||
|
||||
|
||||
def _auto_update(self):
|
||||
"""自动更新函数"""
|
||||
while True:
|
||||
try:
|
||||
# 从数据库获取最新数据,只获取启动时间之后的记录
|
||||
query = {"time": {"$gt": self.start_timestamp}}
|
||||
print(f"查询条件: {query}")
|
||||
|
||||
logger.debug(f"查询条件: {query}")
|
||||
|
||||
# 先获取一条记录检查时间格式
|
||||
sample = self.db.reasoning_logs.find_one()
|
||||
if sample:
|
||||
print(f"样本记录时间格式: {type(sample['time'])} 值: {sample['time']}")
|
||||
|
||||
logger.debug(f"样本记录时间格式: {type(sample['time'])} 值: {sample['time']}")
|
||||
|
||||
cursor = self.db.reasoning_logs.find(query).sort("time", -1)
|
||||
new_data = {}
|
||||
total_count = 0
|
||||
|
||||
|
||||
for item in cursor:
|
||||
# 调试输出
|
||||
if total_count == 0:
|
||||
print(f"记录时间: {item['time']}, 类型: {type(item['time'])}")
|
||||
|
||||
logger.debug(f"记录时间: {item['time']}, 类型: {type(item['time'])}")
|
||||
|
||||
total_count += 1
|
||||
group_id = str(item.get('group_id', 'unknown'))
|
||||
if group_id not in new_data:
|
||||
new_data[group_id] = []
|
||||
|
||||
|
||||
# 转换时间戳为datetime对象
|
||||
if isinstance(item['time'], (int, float)):
|
||||
time_obj = datetime.fromtimestamp(item['time'])
|
||||
elif isinstance(item['time'], datetime):
|
||||
time_obj = item['time']
|
||||
else:
|
||||
print(f"未知的时间格式: {type(item['time'])}")
|
||||
logger.warning(f"未知的时间格式: {type(item['time'])}")
|
||||
time_obj = datetime.now() # 使用当前时间作为后备
|
||||
|
||||
|
||||
new_data[group_id].append({
|
||||
'time': time_obj,
|
||||
'user': item.get('user', '未知'),
|
||||
@@ -324,13 +324,13 @@ class ReasoningGUI:
|
||||
'response': item.get('response', ''),
|
||||
'prompt': item.get('prompt', '') # 添加prompt字段
|
||||
})
|
||||
|
||||
print(f"从数据库加载了 {total_count} 条记录,分布在 {len(new_data)} 个群组中")
|
||||
|
||||
|
||||
logger.info(f"从数据库加载了 {total_count} 条记录,分布在 {len(new_data)} 个群组中")
|
||||
|
||||
# 更新数据
|
||||
if new_data != self.group_data:
|
||||
self.group_data = new_data
|
||||
print("数据已更新,正在刷新显示...")
|
||||
logger.info("数据已更新,正在刷新显示...")
|
||||
# 将更新任务添加到队列
|
||||
self.update_queue.put({'type': 'update_group_list'})
|
||||
if self.group_data:
|
||||
@@ -341,16 +341,16 @@ class ReasoningGUI:
|
||||
'type': 'update_display',
|
||||
'group_id': self.selected_group_id
|
||||
})
|
||||
except Exception as e:
|
||||
print(f"自动更新出错: {e}")
|
||||
|
||||
except Exception:
|
||||
logger.exception(f"自动更新出错")
|
||||
|
||||
# 每5秒更新一次
|
||||
time.sleep(5)
|
||||
|
||||
|
||||
def clear_display(self):
|
||||
"""清除显示内容"""
|
||||
self.content_text.delete("1.0", "end")
|
||||
|
||||
|
||||
def run(self):
|
||||
"""运行GUI"""
|
||||
self.root.mainloop()
|
||||
@@ -359,18 +359,17 @@ class ReasoningGUI:
|
||||
def main():
|
||||
"""主函数"""
|
||||
Database.initialize(
|
||||
host= os.getenv("MONGODB_HOST"),
|
||||
port= int(os.getenv("MONGODB_PORT")),
|
||||
db_name= os.getenv("DATABASE_NAME"),
|
||||
username= os.getenv("MONGODB_USERNAME"),
|
||||
password= os.getenv("MONGODB_PASSWORD"),
|
||||
host=os.getenv("MONGODB_HOST"),
|
||||
port=int(os.getenv("MONGODB_PORT")),
|
||||
db_name=os.getenv("DATABASE_NAME"),
|
||||
username=os.getenv("MONGODB_USERNAME"),
|
||||
password=os.getenv("MONGODB_PASSWORD"),
|
||||
auth_source=os.getenv("MONGODB_AUTH_SOURCE")
|
||||
)
|
||||
|
||||
|
||||
app = ReasoningGUI()
|
||||
app.run()
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
@@ -135,7 +135,7 @@ class BotConfig:
|
||||
try:
|
||||
config_version: str = toml["inner"]["version"]
|
||||
except KeyError as e:
|
||||
logger.error(f"配置文件中 inner 段 不存在 {e}, 这是错误的配置文件")
|
||||
logger.error(f"配置文件中 inner 段 不存在, 这是错误的配置文件")
|
||||
raise KeyError(f"配置文件中 inner 段 不存在 {e}, 这是错误的配置文件")
|
||||
else:
|
||||
toml["inner"] = {"version": "0.0.0"}
|
||||
@@ -246,11 +246,11 @@ class BotConfig:
|
||||
try:
|
||||
cfg_target[i] = cfg_item[i]
|
||||
except KeyError as e:
|
||||
logger.error(f"{item} 中的必要字段 {e} 不存在,请检查")
|
||||
logger.error(f"{item} 中的必要字段不存在,请检查")
|
||||
raise KeyError(f"{item} 中的必要字段 {e} 不存在,请检查")
|
||||
|
||||
provider = cfg_item.get("provider")
|
||||
if provider == None:
|
||||
if provider is None:
|
||||
logger.error(f"provider 字段在模型配置 {item} 中不存在,请检查")
|
||||
raise KeyError(f"provider 字段在模型配置 {item} 中不存在,请检查")
|
||||
|
||||
|
||||
@@ -93,8 +93,8 @@ class ResponseGenerator:
|
||||
# 生成回复
|
||||
try:
|
||||
content, reasoning_content = await model.generate_response(prompt)
|
||||
except Exception as e:
|
||||
logger.exception(f"生成回复时出错: {e}")
|
||||
except Exception:
|
||||
logger.exception(f"生成回复时出错")
|
||||
return None
|
||||
|
||||
# 保存到数据库
|
||||
@@ -145,8 +145,8 @@ class ResponseGenerator:
|
||||
else:
|
||||
return ["neutral"]
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"获取情感标签时出错: {e}")
|
||||
except Exception:
|
||||
logger.exception(f"获取情感标签时出错")
|
||||
return ["neutral"]
|
||||
|
||||
async def _process_response(self, content: str) -> Tuple[List[str], List[str]]:
|
||||
|
||||
@@ -119,8 +119,8 @@ class MessageContainer:
|
||||
self.messages.remove(message)
|
||||
return True
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.exception(f"移除消息时发生错误: {e}")
|
||||
except Exception:
|
||||
logger.exception(f"移除消息时发生错误")
|
||||
return False
|
||||
|
||||
def has_messages(self) -> bool:
|
||||
@@ -213,8 +213,8 @@ class MessageManager:
|
||||
# 安全地移除消息
|
||||
if not container.remove_message(msg):
|
||||
logger.warning("尝试删除不存在的消息")
|
||||
except Exception as e:
|
||||
logger.exception(f"处理超时消息时发生错误: {e}")
|
||||
except Exception:
|
||||
logger.exception(f"处理超时消息时发生错误")
|
||||
continue
|
||||
|
||||
async def start_processor(self):
|
||||
|
||||
@@ -2,12 +2,13 @@ from typing import Optional
|
||||
|
||||
from ...common.database import Database
|
||||
from .message import Message
|
||||
from loguru import logger
|
||||
|
||||
|
||||
class MessageStorage:
|
||||
def __init__(self):
|
||||
self.db = Database.get_instance()
|
||||
|
||||
|
||||
async def store_message(self, message: Message, topic: Optional[str] = None) -> None:
|
||||
"""存储消息到数据库"""
|
||||
try:
|
||||
@@ -41,9 +42,9 @@ class MessageStorage:
|
||||
"topic": topic,
|
||||
"detailed_plain_text": message.detailed_plain_text,
|
||||
}
|
||||
|
||||
self.db.db.messages.insert_one(message_data)
|
||||
except Exception as e:
|
||||
print(f"\033[1;31m[错误]\033[0m 存储消息失败: {e}")
|
||||
|
||||
# 如果需要其他存储相关的函数,可以在这里添加
|
||||
self.db.db.messages.insert_one(message_data)
|
||||
except Exception:
|
||||
logger.exception(f"存储消息失败")
|
||||
|
||||
# 如果需要其他存储相关的函数,可以在这里添加
|
||||
|
||||
@@ -7,6 +7,7 @@ import jieba
|
||||
import matplotlib.pyplot as plt
|
||||
import networkx as nx
|
||||
from dotenv import load_dotenv
|
||||
from loguru import logger
|
||||
|
||||
sys.path.append("C:/GitHub/MaiMBot") # 添加项目根目录到 Python 路径
|
||||
from src.common.database import Database # 使用正确的导入语法
|
||||
@@ -15,15 +16,15 @@ from src.common.database import Database # 使用正确的导入语法
|
||||
env_path = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(__file__)))), '.env.dev')
|
||||
load_dotenv(env_path)
|
||||
|
||||
|
||||
|
||||
class Memory_graph:
|
||||
def __init__(self):
|
||||
self.G = nx.Graph() # 使用 networkx 的图结构
|
||||
self.db = Database.get_instance()
|
||||
|
||||
|
||||
def connect_dot(self, concept1, concept2):
|
||||
self.G.add_edge(concept1, concept2)
|
||||
|
||||
|
||||
def add_dot(self, concept, memory):
|
||||
if concept in self.G:
|
||||
# 如果节点已存在,将新记忆添加到现有列表中
|
||||
@@ -37,7 +38,7 @@ class Memory_graph:
|
||||
else:
|
||||
# 如果是新节点,创建新的记忆列表
|
||||
self.G.add_node(concept, memory_items=[memory])
|
||||
|
||||
|
||||
def get_dot(self, concept):
|
||||
# 检查节点是否存在于图中
|
||||
if concept in self.G:
|
||||
@@ -45,20 +46,20 @@ class Memory_graph:
|
||||
node_data = self.G.nodes[concept]
|
||||
# print(node_data)
|
||||
# 创建新的Memory_dot对象
|
||||
return concept,node_data
|
||||
return concept, node_data
|
||||
return None
|
||||
|
||||
def get_related_item(self, topic, depth=1):
|
||||
if topic not in self.G:
|
||||
return [], []
|
||||
|
||||
|
||||
first_layer_items = []
|
||||
second_layer_items = []
|
||||
|
||||
|
||||
# 获取相邻节点
|
||||
neighbors = list(self.G.neighbors(topic))
|
||||
# print(f"第一层: {topic}")
|
||||
|
||||
|
||||
# 获取当前节点的记忆项
|
||||
node_data = self.get_dot(topic)
|
||||
if node_data:
|
||||
@@ -69,7 +70,7 @@ class Memory_graph:
|
||||
first_layer_items.extend(memory_items)
|
||||
else:
|
||||
first_layer_items.append(memory_items)
|
||||
|
||||
|
||||
# 只在depth=2时获取第二层记忆
|
||||
if depth >= 2:
|
||||
# 获取相邻节点的记忆项
|
||||
@@ -84,42 +85,44 @@ class Memory_graph:
|
||||
second_layer_items.extend(memory_items)
|
||||
else:
|
||||
second_layer_items.append(memory_items)
|
||||
|
||||
|
||||
return first_layer_items, second_layer_items
|
||||
|
||||
|
||||
def store_memory(self):
|
||||
for node in self.G.nodes():
|
||||
dot_data = {
|
||||
"concept": node
|
||||
}
|
||||
self.db.db.store_memory_dots.insert_one(dot_data)
|
||||
|
||||
|
||||
@property
|
||||
def dots(self):
|
||||
# 返回所有节点对应的 Memory_dot 对象
|
||||
return [self.get_dot(node) for node in self.G.nodes()]
|
||||
|
||||
|
||||
|
||||
def get_random_chat_from_db(self, length: int, timestamp: str):
|
||||
# 从数据库中根据时间戳获取离其最近的聊天记录
|
||||
chat_text = ''
|
||||
closest_record = self.db.db.messages.find_one({"time": {"$lte": timestamp}}, sort=[('time', -1)]) # 调试输出
|
||||
print(f"距离time最近的消息时间: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(int(closest_record['time'])))}")
|
||||
|
||||
logger.info(
|
||||
f"距离time最近的消息时间: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(int(closest_record['time'])))}")
|
||||
|
||||
if closest_record:
|
||||
closest_time = closest_record['time']
|
||||
group_id = closest_record['group_id'] # 获取groupid
|
||||
# 获取该时间戳之后的length条消息,且groupid相同
|
||||
chat_record = list(self.db.db.messages.find({"time": {"$gt": closest_time}, "group_id": group_id}).sort('time', 1).limit(length))
|
||||
chat_record = list(
|
||||
self.db.db.messages.find({"time": {"$gt": closest_time}, "group_id": group_id}).sort('time', 1).limit(
|
||||
length))
|
||||
for record in chat_record:
|
||||
time_str = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(int(record['time'])))
|
||||
try:
|
||||
displayname="[(%s)%s]%s" % (record["user_id"],record["user_nickname"],record["user_cardname"])
|
||||
displayname = "[(%s)%s]%s" % (record["user_id"], record["user_nickname"], record["user_cardname"])
|
||||
except:
|
||||
displayname=record["user_nickname"] or "用户" + str(record["user_id"])
|
||||
displayname = record["user_nickname"] or "用户" + str(record["user_id"])
|
||||
chat_text += f'[{time_str}] {displayname}: {record["processed_plain_text"]}\n' # 添加发送者和时间信息
|
||||
return chat_text
|
||||
|
||||
|
||||
return [] # 如果没有找到记录,返回空列表
|
||||
|
||||
def save_graph_to_db(self):
|
||||
@@ -166,53 +169,54 @@ def main():
|
||||
password=os.getenv("MONGODB_PASSWORD", ""),
|
||||
auth_source=os.getenv("MONGODB_AUTH_SOURCE", "")
|
||||
)
|
||||
|
||||
|
||||
memory_graph = Memory_graph()
|
||||
memory_graph.load_graph_from_db()
|
||||
|
||||
|
||||
# 只显示一次优化后的图形
|
||||
visualize_graph_lite(memory_graph)
|
||||
|
||||
|
||||
while True:
|
||||
query = input("请输入新的查询概念(输入'退出'以结束):")
|
||||
if query.lower() == '退出':
|
||||
break
|
||||
first_layer_items, second_layer_items = memory_graph.get_related_item(query)
|
||||
if first_layer_items or second_layer_items:
|
||||
print("\n第一层记忆:")
|
||||
logger.debug("第一层记忆:")
|
||||
for item in first_layer_items:
|
||||
print(item)
|
||||
print("\n第二层记忆:")
|
||||
logger.debug(item)
|
||||
logger.debug("第二层记忆:")
|
||||
for item in second_layer_items:
|
||||
print(item)
|
||||
logger.debug(item)
|
||||
else:
|
||||
print("未找到相关记忆。")
|
||||
|
||||
logger.debug("未找到相关记忆。")
|
||||
|
||||
|
||||
def segment_text(text):
|
||||
seg_text = list(jieba.cut(text))
|
||||
return seg_text
|
||||
return seg_text
|
||||
|
||||
|
||||
def find_topic(text, topic_num):
|
||||
prompt = f'这是一段文字:{text}。请你从这段话中总结出{topic_num}个话题,帮我列出来,用逗号隔开,尽可能精简。只需要列举{topic_num}个话题就好,不要告诉我其他内容。'
|
||||
return prompt
|
||||
|
||||
|
||||
def topic_what(text, topic):
|
||||
prompt = f'这是一段文字:{text}。我想知道这记忆里有什么关于{topic}的话题,帮我总结成一句自然的话,可以包含时间和人物。只输出这句话就好'
|
||||
return prompt
|
||||
|
||||
|
||||
|
||||
def visualize_graph_lite(memory_graph: Memory_graph, color_by_memory: bool = False):
|
||||
# 设置中文字体
|
||||
plt.rcParams['font.sans-serif'] = ['SimHei'] # 用来正常显示中文标签
|
||||
plt.rcParams['axes.unicode_minus'] = False # 用来正常显示负号
|
||||
|
||||
|
||||
G = memory_graph.G
|
||||
|
||||
|
||||
# 创建一个新图用于可视化
|
||||
H = G.copy()
|
||||
|
||||
|
||||
# 移除只有一条记忆的节点和连接数少于3的节点
|
||||
nodes_to_remove = []
|
||||
for node in H.nodes():
|
||||
@@ -221,14 +225,14 @@ def visualize_graph_lite(memory_graph: Memory_graph, color_by_memory: bool = Fal
|
||||
degree = H.degree(node)
|
||||
if memory_count < 3 or degree < 2: # 改为小于2而不是小于等于2
|
||||
nodes_to_remove.append(node)
|
||||
|
||||
|
||||
H.remove_nodes_from(nodes_to_remove)
|
||||
|
||||
|
||||
# 如果过滤后没有节点,则返回
|
||||
if len(H.nodes()) == 0:
|
||||
print("过滤后没有符合条件的节点可显示")
|
||||
logger.debug("过滤后没有符合条件的节点可显示")
|
||||
return
|
||||
|
||||
|
||||
# 保存图到本地
|
||||
# nx.write_gml(H, "memory_graph.gml") # 保存为 GML 格式
|
||||
|
||||
@@ -236,7 +240,7 @@ def visualize_graph_lite(memory_graph: Memory_graph, color_by_memory: bool = Fal
|
||||
node_colors = []
|
||||
node_sizes = []
|
||||
nodes = list(H.nodes())
|
||||
|
||||
|
||||
# 获取最大记忆数和最大度数用于归一化
|
||||
max_memories = 1
|
||||
max_degree = 1
|
||||
@@ -246,7 +250,7 @@ def visualize_graph_lite(memory_graph: Memory_graph, color_by_memory: bool = Fal
|
||||
degree = H.degree(node)
|
||||
max_memories = max(max_memories, memory_count)
|
||||
max_degree = max(max_degree, degree)
|
||||
|
||||
|
||||
# 计算每个节点的大小和颜色
|
||||
for node in nodes:
|
||||
# 计算节点大小(基于记忆数量)
|
||||
@@ -254,9 +258,9 @@ def visualize_graph_lite(memory_graph: Memory_graph, color_by_memory: bool = Fal
|
||||
memory_count = len(memory_items) if isinstance(memory_items, list) else (1 if memory_items else 0)
|
||||
# 使用指数函数使变化更明显
|
||||
ratio = memory_count / max_memories
|
||||
size = 500 + 5000 * (ratio ) # 使用1.5次方函数使差异不那么明显
|
||||
size = 500 + 5000 * (ratio) # 使用1.5次方函数使差异不那么明显
|
||||
node_sizes.append(size)
|
||||
|
||||
|
||||
# 计算节点颜色(基于连接数)
|
||||
degree = H.degree(node)
|
||||
# 红色分量随着度数增加而增加
|
||||
@@ -267,26 +271,25 @@ def visualize_graph_lite(memory_graph: Memory_graph, color_by_memory: bool = Fal
|
||||
# blue = 1
|
||||
color = (red, 0.1, blue)
|
||||
node_colors.append(color)
|
||||
|
||||
|
||||
# 绘制图形
|
||||
plt.figure(figsize=(12, 8))
|
||||
pos = nx.spring_layout(H, k=1, iterations=50) # 增加k值使节点分布更开
|
||||
nx.draw(H, pos,
|
||||
with_labels=True,
|
||||
node_color=node_colors,
|
||||
node_size=node_sizes,
|
||||
font_size=10,
|
||||
font_family='SimHei',
|
||||
font_weight='bold',
|
||||
edge_color='gray',
|
||||
width=0.5,
|
||||
alpha=0.9)
|
||||
|
||||
nx.draw(H, pos,
|
||||
with_labels=True,
|
||||
node_color=node_colors,
|
||||
node_size=node_sizes,
|
||||
font_size=10,
|
||||
font_family='SimHei',
|
||||
font_weight='bold',
|
||||
edge_color='gray',
|
||||
width=0.5,
|
||||
alpha=0.9)
|
||||
|
||||
title = '记忆图谱可视化 - 节点大小表示记忆数量,颜色表示连接数'
|
||||
plt.title(title, fontsize=16, fontfamily='SimHei')
|
||||
plt.show()
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
main()
|
||||
|
||||
@@ -45,7 +45,7 @@ class LLM_request:
|
||||
self.db.db.llm_usage.create_index([("user_id", 1)])
|
||||
self.db.db.llm_usage.create_index([("request_type", 1)])
|
||||
except Exception as e:
|
||||
logger.error(f"创建数据库索引失败: {e}")
|
||||
logger.error(f"创建数据库索引失败")
|
||||
|
||||
def _record_usage(self, prompt_tokens: int, completion_tokens: int, total_tokens: int,
|
||||
user_id: str = "system", request_type: str = "chat",
|
||||
@@ -79,8 +79,8 @@ class LLM_request:
|
||||
f"提示词: {prompt_tokens}, 完成: {completion_tokens}, "
|
||||
f"总计: {total_tokens}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"记录token使用情况失败: {e}")
|
||||
except Exception:
|
||||
logger.error(f"记录token使用情况失败")
|
||||
|
||||
def _calculate_cost(self, prompt_tokens: int, completion_tokens: int) -> float:
|
||||
"""计算API调用成本
|
||||
@@ -226,8 +226,8 @@ class LLM_request:
|
||||
if delta_content is None:
|
||||
delta_content = ""
|
||||
accumulated_content += delta_content
|
||||
except Exception as e:
|
||||
logger.error(f"解析流式输出错误: {e}")
|
||||
except Exception:
|
||||
logger.exception(f"解析流式输出错")
|
||||
content = accumulated_content
|
||||
reasoning_content = ""
|
||||
think_match = re.search(r'<think>(.*?)</think>', content, re.DOTALL)
|
||||
|
||||
@@ -3,6 +3,7 @@ import time
|
||||
from collections import defaultdict
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any, Dict
|
||||
from loguru import logger
|
||||
|
||||
from ...common.database import Database
|
||||
|
||||
@@ -153,8 +154,8 @@ class LLMStatistics:
|
||||
try:
|
||||
all_stats = self._collect_all_statistics()
|
||||
self._save_statistics(all_stats)
|
||||
except Exception as e:
|
||||
print(f"\033[1;31m[错误]\033[0m 统计数据处理失败: {e}")
|
||||
except Exception:
|
||||
logger.exception(f"统计数据处理失败")
|
||||
|
||||
# 等待1分钟
|
||||
for _ in range(60):
|
||||
|
||||
Reference in New Issue
Block a user