fix: remove duplicate messages (CR comments)
This commit is contained in:
@@ -5,6 +5,9 @@ import threading
|
||||
import time
|
||||
from datetime import datetime
|
||||
from typing import Dict, List
|
||||
from loguru import logger
|
||||
from typing import Optional
|
||||
from pymongo import MongoClient
|
||||
|
||||
import customtkinter as ctk
|
||||
from dotenv import load_dotenv
|
||||
@@ -17,23 +20,20 @@ root_dir = os.path.abspath(os.path.join(current_dir, '..', '..'))
|
||||
# 加载环境变量
|
||||
if os.path.exists(os.path.join(root_dir, '.env.dev')):
|
||||
load_dotenv(os.path.join(root_dir, '.env.dev'))
|
||||
print("成功加载开发环境配置")
|
||||
logger.info("成功加载开发环境配置")
|
||||
elif os.path.exists(os.path.join(root_dir, '.env.prod')):
|
||||
load_dotenv(os.path.join(root_dir, '.env.prod'))
|
||||
print("成功加载生产环境配置")
|
||||
logger.info("成功加载生产环境配置")
|
||||
else:
|
||||
print("未找到环境配置文件")
|
||||
logger.error("未找到环境配置文件")
|
||||
sys.exit(1)
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from pymongo import MongoClient
|
||||
|
||||
|
||||
class Database:
|
||||
_instance: Optional["Database"] = None
|
||||
|
||||
def __init__(self, host: str, port: int, db_name: str, username: str = None, password: str = None, auth_source: str = None):
|
||||
def __init__(self, host: str, port: int, db_name: str, username: str = None, password: str = None,
|
||||
auth_source: str = None):
|
||||
if username and password:
|
||||
self.client = MongoClient(
|
||||
host=host,
|
||||
@@ -47,7 +47,8 @@ class Database:
|
||||
self.db = self.client[db_name]
|
||||
|
||||
@classmethod
def initialize(cls, host: str, port: int, db_name: str, username: Optional[str] = None,
               password: Optional[str] = None, auth_source: Optional[str] = None) -> "Database":
    """Create the shared instance on first use and return it.

    The arguments are forwarded positionally to the class constructor only
    when no instance has been stored on ``cls._instance`` yet; on later
    calls they are ignored and the already-stored instance is returned.

    Args:
        host: Passed through to the constructor.
        port: Passed through to the constructor.
        db_name: Passed through to the constructor.
        username: Optional; passed through to the constructor.
        password: Optional; passed through to the constructor.
        auth_source: Optional; passed through to the constructor.

    Returns:
        The instance stored on ``cls._instance``.
    """
    if cls._instance is None:
        cls._instance = cls(host, port, db_name, username, password, auth_source)
    return cls._instance
|
||||
@@ -59,12 +60,11 @@ class Database:
|
||||
return cls._instance
|
||||
|
||||
|
||||
|
||||
class ReasoningGUI:
|
||||
def __init__(self):
|
||||
# 记录启动时间戳,转换为Unix时间戳
|
||||
self.start_timestamp = datetime.now().timestamp()
|
||||
print(f"程序启动时间戳: {self.start_timestamp}")
|
||||
logger.info(f"程序启动时间戳: {self.start_timestamp}")
|
||||
|
||||
# 设置主题
|
||||
ctk.set_appearance_mode("dark")
|
||||
@@ -79,15 +79,15 @@ class ReasoningGUI:
|
||||
# 初始化数据库连接
|
||||
try:
|
||||
self.db = Database.get_instance().db
|
||||
print("数据库连接成功")
|
||||
logger.success("数据库连接成功")
|
||||
except RuntimeError:
|
||||
print("数据库未初始化,正在尝试初始化...")
|
||||
logger.warning("数据库未初始化,正在尝试初始化...")
|
||||
try:
|
||||
Database.initialize("127.0.0.1", 27017, "maimai_bot")
|
||||
self.db = Database.get_instance().db
|
||||
print("数据库初始化成功")
|
||||
except Exception as e:
|
||||
print(f"数据库初始化失败: {e}")
|
||||
logger.success("数据库初始化成功")
|
||||
except Exception:
|
||||
logger.exception(f"数据库初始化失败")
|
||||
sys.exit(1)
|
||||
|
||||
# 存储群组数据
|
||||
@@ -274,7 +274,7 @@ class ReasoningGUI:
|
||||
self.content_text.insert("end", f"{item.get('response', '')}\n", "response")
|
||||
|
||||
# 分隔符
|
||||
self.content_text.insert("end", f"\n{'='*50}\n\n", "separator")
|
||||
self.content_text.insert("end", f"\n{'=' * 50}\n\n", "separator")
|
||||
|
||||
# 滚动到顶部
|
||||
self.content_text.see("1.0")
|
||||
@@ -285,12 +285,12 @@ class ReasoningGUI:
|
||||
try:
|
||||
# 从数据库获取最新数据,只获取启动时间之后的记录
|
||||
query = {"time": {"$gt": self.start_timestamp}}
|
||||
print(f"查询条件: {query}")
|
||||
logger.debug(f"查询条件: {query}")
|
||||
|
||||
# 先获取一条记录检查时间格式
|
||||
sample = self.db.reasoning_logs.find_one()
|
||||
if sample:
|
||||
print(f"样本记录时间格式: {type(sample['time'])} 值: {sample['time']}")
|
||||
logger.debug(f"样本记录时间格式: {type(sample['time'])} 值: {sample['time']}")
|
||||
|
||||
cursor = self.db.reasoning_logs.find(query).sort("time", -1)
|
||||
new_data = {}
|
||||
@@ -299,7 +299,7 @@ class ReasoningGUI:
|
||||
for item in cursor:
|
||||
# 调试输出
|
||||
if total_count == 0:
|
||||
print(f"记录时间: {item['time']}, 类型: {type(item['time'])}")
|
||||
logger.debug(f"记录时间: {item['time']}, 类型: {type(item['time'])}")
|
||||
|
||||
total_count += 1
|
||||
group_id = str(item.get('group_id', 'unknown'))
|
||||
@@ -312,7 +312,7 @@ class ReasoningGUI:
|
||||
elif isinstance(item['time'], datetime):
|
||||
time_obj = item['time']
|
||||
else:
|
||||
print(f"未知的时间格式: {type(item['time'])}")
|
||||
logger.warning(f"未知的时间格式: {type(item['time'])}")
|
||||
time_obj = datetime.now() # 使用当前时间作为后备
|
||||
|
||||
new_data[group_id].append({
|
||||
@@ -325,12 +325,12 @@ class ReasoningGUI:
|
||||
'prompt': item.get('prompt', '') # 添加prompt字段
|
||||
})
|
||||
|
||||
print(f"从数据库加载了 {total_count} 条记录,分布在 {len(new_data)} 个群组中")
|
||||
logger.info(f"从数据库加载了 {total_count} 条记录,分布在 {len(new_data)} 个群组中")
|
||||
|
||||
# 更新数据
|
||||
if new_data != self.group_data:
|
||||
self.group_data = new_data
|
||||
print("数据已更新,正在刷新显示...")
|
||||
logger.info("数据已更新,正在刷新显示...")
|
||||
# 将更新任务添加到队列
|
||||
self.update_queue.put({'type': 'update_group_list'})
|
||||
if self.group_data:
|
||||
@@ -341,8 +341,8 @@ class ReasoningGUI:
|
||||
'type': 'update_display',
|
||||
'group_id': self.selected_group_id
|
||||
})
|
||||
except Exception as e:
|
||||
print(f"自动更新出错: {e}")
|
||||
except Exception:
|
||||
logger.exception(f"自动更新出错")
|
||||
|
||||
# 每5秒更新一次
|
||||
time.sleep(5)
|
||||
@@ -359,11 +359,11 @@ class ReasoningGUI:
|
||||
def main():
|
||||
"""主函数"""
|
||||
Database.initialize(
|
||||
host= os.getenv("MONGODB_HOST"),
|
||||
port= int(os.getenv("MONGODB_PORT")),
|
||||
db_name= os.getenv("DATABASE_NAME"),
|
||||
username= os.getenv("MONGODB_USERNAME"),
|
||||
password= os.getenv("MONGODB_PASSWORD"),
|
||||
host=os.getenv("MONGODB_HOST"),
|
||||
port=int(os.getenv("MONGODB_PORT")),
|
||||
db_name=os.getenv("DATABASE_NAME"),
|
||||
username=os.getenv("MONGODB_USERNAME"),
|
||||
password=os.getenv("MONGODB_PASSWORD"),
|
||||
auth_source=os.getenv("MONGODB_AUTH_SOURCE")
|
||||
)
|
||||
|
||||
@@ -371,6 +371,5 @@ def main():
|
||||
app.run()
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
@@ -135,7 +135,7 @@ class BotConfig:
|
||||
try:
|
||||
config_version: str = toml["inner"]["version"]
|
||||
except KeyError as e:
|
||||
logger.error(f"配置文件中 inner 段 不存在 {e}, 这是错误的配置文件")
|
||||
logger.error(f"配置文件中 inner 段 不存在, 这是错误的配置文件")
|
||||
raise KeyError(f"配置文件中 inner 段 不存在 {e}, 这是错误的配置文件")
|
||||
else:
|
||||
toml["inner"] = {"version": "0.0.0"}
|
||||
@@ -246,11 +246,11 @@ class BotConfig:
|
||||
try:
|
||||
cfg_target[i] = cfg_item[i]
|
||||
except KeyError as e:
|
||||
logger.error(f"{item} 中的必要字段 {e} 不存在,请检查")
|
||||
logger.error(f"{item} 中的必要字段不存在,请检查")
|
||||
raise KeyError(f"{item} 中的必要字段 {e} 不存在,请检查")
|
||||
|
||||
provider = cfg_item.get("provider")
|
||||
if provider == None:
|
||||
if provider is None:
|
||||
logger.error(f"provider 字段在模型配置 {item} 中不存在,请检查")
|
||||
raise KeyError(f"provider 字段在模型配置 {item} 中不存在,请检查")
|
||||
|
||||
|
||||
@@ -93,8 +93,8 @@ class ResponseGenerator:
|
||||
# 生成回复
|
||||
try:
|
||||
content, reasoning_content = await model.generate_response(prompt)
|
||||
except Exception as e:
|
||||
logger.exception(f"生成回复时出错: {e}")
|
||||
except Exception:
|
||||
logger.exception(f"生成回复时出错")
|
||||
return None
|
||||
|
||||
# 保存到数据库
|
||||
@@ -145,8 +145,8 @@ class ResponseGenerator:
|
||||
else:
|
||||
return ["neutral"]
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"获取情感标签时出错: {e}")
|
||||
except Exception:
|
||||
logger.exception(f"获取情感标签时出错")
|
||||
return ["neutral"]
|
||||
|
||||
async def _process_response(self, content: str) -> Tuple[List[str], List[str]]:
|
||||
|
||||
@@ -119,8 +119,8 @@ class MessageContainer:
|
||||
self.messages.remove(message)
|
||||
return True
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.exception(f"移除消息时发生错误: {e}")
|
||||
except Exception:
|
||||
logger.exception(f"移除消息时发生错误")
|
||||
return False
|
||||
|
||||
def has_messages(self) -> bool:
|
||||
@@ -213,8 +213,8 @@ class MessageManager:
|
||||
# 安全地移除消息
|
||||
if not container.remove_message(msg):
|
||||
logger.warning("尝试删除不存在的消息")
|
||||
except Exception as e:
|
||||
logger.exception(f"处理超时消息时发生错误: {e}")
|
||||
except Exception:
|
||||
logger.exception(f"处理超时消息时发生错误")
|
||||
continue
|
||||
|
||||
async def start_processor(self):
|
||||
|
||||
@@ -2,6 +2,7 @@ from typing import Optional
|
||||
|
||||
from ...common.database import Database
|
||||
from .message import Message
|
||||
from loguru import logger
|
||||
|
||||
|
||||
class MessageStorage:
|
||||
@@ -43,7 +44,7 @@ class MessageStorage:
|
||||
}
|
||||
|
||||
self.db.db.messages.insert_one(message_data)
|
||||
except Exception as e:
|
||||
print(f"\033[1;31m[错误]\033[0m 存储消息失败: {e}")
|
||||
except Exception:
|
||||
logger.exception(f"存储消息失败")
|
||||
|
||||
# 如果需要其他存储相关的函数,可以在这里添加
|
||||
@@ -7,6 +7,7 @@ import jieba
|
||||
import matplotlib.pyplot as plt
|
||||
import networkx as nx
|
||||
from dotenv import load_dotenv
|
||||
from loguru import logger
|
||||
|
||||
sys.path.append("C:/GitHub/MaiMBot") # 添加项目根目录到 Python 路径
|
||||
from src.common.database import Database # 使用正确的导入语法
|
||||
@@ -45,7 +46,7 @@ class Memory_graph:
|
||||
node_data = self.G.nodes[concept]
|
||||
# print(node_data)
|
||||
# 创建新的Memory_dot对象
|
||||
return concept,node_data
|
||||
return concept, node_data
|
||||
return None
|
||||
|
||||
def get_related_item(self, topic, depth=1):
|
||||
@@ -99,24 +100,26 @@ class Memory_graph:
|
||||
# 返回所有节点对应的 Memory_dot 对象
|
||||
return [self.get_dot(node) for node in self.G.nodes()]
|
||||
|
||||
|
||||
def get_random_chat_from_db(self, length: int, timestamp: str):
    """Return a transcript of the messages that follow the record nearest to ``timestamp``.

    Looks up the latest message whose ``time`` is at or before ``timestamp``,
    then reads up to ``length`` later messages from the same ``group_id`` in
    ascending time order and renders each one as ``[time] sender: text``
    followed by a newline.

    Args:
        length: Maximum number of messages to include.
        timestamp: Upper bound used in the ``$lte`` lookup of the anchor record.
            NOTE(review): annotated ``str``, but stored ``time`` values are fed
            to ``int()`` below, so a numeric Unix time is presumably expected
            -- confirm against callers.

    Returns:
        The concatenated transcript, or ``''`` when no anchor record exists.
    """
    messages = self.db.db.messages
    closest_record = messages.find_one({"time": {"$lte": timestamp}}, sort=[('time', -1)])

    # Guard before touching the record: find_one yields nothing when no
    # message is old enough, and the log line below subscripts the record.
    if not closest_record:
        return ''

    closest_time = closest_record['time']
    logger.info(
        f"距离time最近的消息时间: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(int(closest_time)))}")

    group_id = closest_record['group_id']
    # Up to `length` messages after the anchor, restricted to the same group.
    chat_record = list(
        messages.find({"time": {"$gt": closest_time}, "group_id": group_id})
        .sort('time', 1)
        .limit(length)
    )

    lines = []
    for record in chat_record:
        time_str = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(int(record['time'])))
        try:
            displayname = "[(%s)%s]%s" % (record["user_id"], record["user_nickname"], record["user_cardname"])
        except KeyError:
            # Records missing one of the name fields fall back to the
            # nickname, or to a label built from the user id.
            displayname = record["user_nickname"] or "用户" + str(record["user_id"])
        lines.append(f'[{time_str}] {displayname}: {record["processed_plain_text"]}\n')
    return ''.join(lines)
|
||||
|
||||
@@ -179,30 +182,31 @@ def main():
|
||||
break
|
||||
first_layer_items, second_layer_items = memory_graph.get_related_item(query)
|
||||
if first_layer_items or second_layer_items:
|
||||
print("\n第一层记忆:")
|
||||
logger.debug("第一层记忆:")
|
||||
for item in first_layer_items:
|
||||
print(item)
|
||||
print("\n第二层记忆:")
|
||||
logger.debug(item)
|
||||
logger.debug("第二层记忆:")
|
||||
for item in second_layer_items:
|
||||
print(item)
|
||||
logger.debug(item)
|
||||
else:
|
||||
print("未找到相关记忆。")
|
||||
logger.debug("未找到相关记忆。")
|
||||
|
||||
|
||||
def segment_text(text):
    """Split ``text`` with ``jieba.cut`` and return the pieces as a list."""
    return list(jieba.cut(text))
|
||||
|
||||
|
||||
def find_topic(text, topic_num) -> str:
    """Build the prompt that asks for ``topic_num`` topics summarised from ``text``.

    Both arguments are only interpolated into the prompt template; nothing
    is sent anywhere by this function.

    Returns:
        The formatted prompt string.
    """
    return f'这是一段文字:{text}。请你从这段话中总结出{topic_num}个话题,帮我列出来,用逗号隔开,尽可能精简。只需要列举{topic_num}个话题就好,不要告诉我其他内容。'
|
||||
|
||||
|
||||
def topic_what(text, topic) -> str:
    """Build the prompt that asks what ``text`` says about ``topic``.

    Both arguments are only interpolated into the prompt template; nothing
    is sent anywhere by this function.

    Returns:
        The formatted prompt string.
    """
    return f'这是一段文字:{text}。我想知道这记忆里有什么关于{topic}的话题,帮我总结成一句自然的话,可以包含时间和人物。只输出这句话就好'
|
||||
|
||||
|
||||
|
||||
def visualize_graph_lite(memory_graph: Memory_graph, color_by_memory: bool = False):
|
||||
# 设置中文字体
|
||||
plt.rcParams['font.sans-serif'] = ['SimHei'] # 用来正常显示中文标签
|
||||
@@ -226,7 +230,7 @@ def visualize_graph_lite(memory_graph: Memory_graph, color_by_memory: bool = Fal
|
||||
|
||||
# 如果过滤后没有节点,则返回
|
||||
if len(H.nodes()) == 0:
|
||||
print("过滤后没有符合条件的节点可显示")
|
||||
logger.debug("过滤后没有符合条件的节点可显示")
|
||||
return
|
||||
|
||||
# 保存图到本地
|
||||
@@ -254,7 +258,7 @@ def visualize_graph_lite(memory_graph: Memory_graph, color_by_memory: bool = Fal
|
||||
memory_count = len(memory_items) if isinstance(memory_items, list) else (1 if memory_items else 0)
|
||||
# 使用指数函数使变化更明显
|
||||
ratio = memory_count / max_memories
|
||||
size = 500 + 5000 * (ratio ) # 使用1.5次方函数使差异不那么明显
|
||||
size = 500 + 5000 * (ratio) # 使用1.5次方函数使差异不那么明显
|
||||
node_sizes.append(size)
|
||||
|
||||
# 计算节点颜色(基于连接数)
|
||||
@@ -272,21 +276,20 @@ def visualize_graph_lite(memory_graph: Memory_graph, color_by_memory: bool = Fal
|
||||
plt.figure(figsize=(12, 8))
|
||||
pos = nx.spring_layout(H, k=1, iterations=50) # 增加k值使节点分布更开
|
||||
nx.draw(H, pos,
|
||||
with_labels=True,
|
||||
node_color=node_colors,
|
||||
node_size=node_sizes,
|
||||
font_size=10,
|
||||
font_family='SimHei',
|
||||
font_weight='bold',
|
||||
edge_color='gray',
|
||||
width=0.5,
|
||||
alpha=0.9)
|
||||
with_labels=True,
|
||||
node_color=node_colors,
|
||||
node_size=node_sizes,
|
||||
font_size=10,
|
||||
font_family='SimHei',
|
||||
font_weight='bold',
|
||||
edge_color='gray',
|
||||
width=0.5,
|
||||
alpha=0.9)
|
||||
|
||||
title = '记忆图谱可视化 - 节点大小表示记忆数量,颜色表示连接数'
|
||||
plt.title(title, fontsize=16, fontfamily='SimHei')
|
||||
plt.show()
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -45,7 +45,7 @@ class LLM_request:
|
||||
self.db.db.llm_usage.create_index([("user_id", 1)])
|
||||
self.db.db.llm_usage.create_index([("request_type", 1)])
|
||||
except Exception as e:
|
||||
logger.error(f"创建数据库索引失败: {e}")
|
||||
logger.error(f"创建数据库索引失败")
|
||||
|
||||
def _record_usage(self, prompt_tokens: int, completion_tokens: int, total_tokens: int,
|
||||
user_id: str = "system", request_type: str = "chat",
|
||||
@@ -79,8 +79,8 @@ class LLM_request:
|
||||
f"提示词: {prompt_tokens}, 完成: {completion_tokens}, "
|
||||
f"总计: {total_tokens}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"记录token使用情况失败: {e}")
|
||||
except Exception:
|
||||
logger.error(f"记录token使用情况失败")
|
||||
|
||||
def _calculate_cost(self, prompt_tokens: int, completion_tokens: int) -> float:
|
||||
"""计算API调用成本
|
||||
@@ -226,8 +226,8 @@ class LLM_request:
|
||||
if delta_content is None:
|
||||
delta_content = ""
|
||||
accumulated_content += delta_content
|
||||
except Exception as e:
|
||||
logger.error(f"解析流式输出错误: {e}")
|
||||
except Exception:
|
||||
logger.exception(f"解析流式输出错")
|
||||
content = accumulated_content
|
||||
reasoning_content = ""
|
||||
think_match = re.search(r'<think>(.*?)</think>', content, re.DOTALL)
|
||||
|
||||
@@ -3,6 +3,7 @@ import time
|
||||
from collections import defaultdict
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any, Dict
|
||||
from loguru import logger
|
||||
|
||||
from ...common.database import Database
|
||||
|
||||
@@ -153,8 +154,8 @@ class LLMStatistics:
|
||||
try:
|
||||
all_stats = self._collect_all_statistics()
|
||||
self._save_statistics(all_stats)
|
||||
except Exception as e:
|
||||
print(f"\033[1;31m[错误]\033[0m 统计数据处理失败: {e}")
|
||||
except Exception:
|
||||
logger.exception(f"统计数据处理失败")
|
||||
|
||||
# 等待1分钟
|
||||
for _ in range(60):
|
||||
|
||||
Reference in New Issue
Block a user