Merge pull request #155 from KX76/fix/20250310-logger-optimize

Fix/20250310 logger optimize
This commit is contained in:
HYY
2025-03-10 11:51:48 +08:00
committed by GitHub
23 changed files with 834 additions and 798 deletions

3
bot.py
View File

@@ -100,7 +100,7 @@ def load_logger():
"#777777>|</> <cyan>{name:.<8}</cyan>:<cyan>{function:.<8}</cyan>:<cyan>{line: >4}</cyan> <fg " "#777777>|</> <cyan>{name:.<8}</cyan>:<cyan>{function:.<8}</cyan>:<cyan>{line: >4}</cyan> <fg "
"#777777>-</> <level>{message}</level>", "#777777>-</> <level>{message}</level>",
colorize=True, colorize=True,
level=os.getenv("LOG_LEVEL", "INFO") # 根据环境设置日志级别默认为INFO level=os.getenv("LOG_LEVEL", "DEBUG") # 根据环境设置日志级别默认为INFO
) )
@@ -149,6 +149,7 @@ if __name__ == "__main__":
init_config() init_config()
init_env() init_env()
load_env() load_env()
load_logger()
env_config = {key: os.getenv(key) for key in os.environ} env_config = {key: os.getenv(key) for key in os.environ}
scan_provider(env_config) scan_provider(env_config)

View File

@@ -5,6 +5,9 @@ import threading
import time import time
from datetime import datetime from datetime import datetime
from typing import Dict, List from typing import Dict, List
from loguru import logger
from typing import Optional
from pymongo import MongoClient
import customtkinter as ctk import customtkinter as ctk
from dotenv import load_dotenv from dotenv import load_dotenv
@@ -17,23 +20,20 @@ root_dir = os.path.abspath(os.path.join(current_dir, '..', '..'))
# 加载环境变量 # 加载环境变量
if os.path.exists(os.path.join(root_dir, '.env.dev')): if os.path.exists(os.path.join(root_dir, '.env.dev')):
load_dotenv(os.path.join(root_dir, '.env.dev')) load_dotenv(os.path.join(root_dir, '.env.dev'))
print("成功加载开发环境配置") logger.info("成功加载开发环境配置")
elif os.path.exists(os.path.join(root_dir, '.env.prod')): elif os.path.exists(os.path.join(root_dir, '.env.prod')):
load_dotenv(os.path.join(root_dir, '.env.prod')) load_dotenv(os.path.join(root_dir, '.env.prod'))
print("成功加载生产环境配置") logger.info("成功加载生产环境配置")
else: else:
print("未找到环境配置文件") logger.error("未找到环境配置文件")
sys.exit(1) sys.exit(1)
from typing import Optional
from pymongo import MongoClient
class Database: class Database:
_instance: Optional["Database"] = None _instance: Optional["Database"] = None
def __init__(self, host: str, port: int, db_name: str, username: str = None, password: str = None, auth_source: str = None): def __init__(self, host: str, port: int, db_name: str, username: str = None, password: str = None,
auth_source: str = None):
if username and password: if username and password:
self.client = MongoClient( self.client = MongoClient(
host=host, host=host,
@@ -45,96 +45,96 @@ class Database:
else: else:
self.client = MongoClient(host, port) self.client = MongoClient(host, port)
self.db = self.client[db_name] self.db = self.client[db_name]
@classmethod @classmethod
def initialize(cls, host: str, port: int, db_name: str, username: str = None, password: str = None, auth_source: str = None) -> "Database": def initialize(cls, host: str, port: int, db_name: str, username: str = None, password: str = None,
auth_source: str = None) -> "Database":
if cls._instance is None: if cls._instance is None:
cls._instance = cls(host, port, db_name, username, password, auth_source) cls._instance = cls(host, port, db_name, username, password, auth_source)
return cls._instance return cls._instance
@classmethod @classmethod
def get_instance(cls) -> "Database": def get_instance(cls) -> "Database":
if cls._instance is None: if cls._instance is None:
raise RuntimeError("Database not initialized") raise RuntimeError("Database not initialized")
return cls._instance return cls._instance
class ReasoningGUI: class ReasoningGUI:
def __init__(self): def __init__(self):
# 记录启动时间戳转换为Unix时间戳 # 记录启动时间戳转换为Unix时间戳
self.start_timestamp = datetime.now().timestamp() self.start_timestamp = datetime.now().timestamp()
print(f"程序启动时间戳: {self.start_timestamp}") logger.info(f"程序启动时间戳: {self.start_timestamp}")
# 设置主题 # 设置主题
ctk.set_appearance_mode("dark") ctk.set_appearance_mode("dark")
ctk.set_default_color_theme("blue") ctk.set_default_color_theme("blue")
# 创建主窗口 # 创建主窗口
self.root = ctk.CTk() self.root = ctk.CTk()
self.root.title('麦麦推理') self.root.title('麦麦推理')
self.root.geometry('800x600') self.root.geometry('800x600')
self.root.protocol("WM_DELETE_WINDOW", self._on_closing) self.root.protocol("WM_DELETE_WINDOW", self._on_closing)
# 初始化数据库连接 # 初始化数据库连接
try: try:
self.db = Database.get_instance().db self.db = Database.get_instance().db
print("数据库连接成功") logger.success("数据库连接成功")
except RuntimeError: except RuntimeError:
print("数据库未初始化,正在尝试初始化...") logger.warning("数据库未初始化,正在尝试初始化...")
try: try:
Database.initialize("127.0.0.1", 27017, "maimai_bot") Database.initialize("127.0.0.1", 27017, "maimai_bot")
self.db = Database.get_instance().db self.db = Database.get_instance().db
print("数据库初始化成功") logger.success("数据库初始化成功")
except Exception as e: except Exception:
print(f"数据库初始化失败: {e}") logger.exception(f"数据库初始化失败")
sys.exit(1) sys.exit(1)
# 存储群组数据 # 存储群组数据
self.group_data: Dict[str, List[dict]] = {} self.group_data: Dict[str, List[dict]] = {}
# 创建更新队列 # 创建更新队列
self.update_queue = queue.Queue() self.update_queue = queue.Queue()
# 创建主框架 # 创建主框架
self.frame = ctk.CTkFrame(self.root) self.frame = ctk.CTkFrame(self.root)
self.frame.pack(pady=20, padx=20, fill="both", expand=True) self.frame.pack(pady=20, padx=20, fill="both", expand=True)
# 添加标题 # 添加标题
self.title = ctk.CTkLabel(self.frame, text="麦麦的脑内所想", font=("Arial", 24)) self.title = ctk.CTkLabel(self.frame, text="麦麦的脑内所想", font=("Arial", 24))
self.title.pack(pady=10, padx=10) self.title.pack(pady=10, padx=10)
# 创建左右分栏 # 创建左右分栏
self.paned = ctk.CTkFrame(self.frame) self.paned = ctk.CTkFrame(self.frame)
self.paned.pack(fill="both", expand=True, padx=10, pady=10) self.paned.pack(fill="both", expand=True, padx=10, pady=10)
# 左侧群组列表 # 左侧群组列表
self.left_frame = ctk.CTkFrame(self.paned, width=200) self.left_frame = ctk.CTkFrame(self.paned, width=200)
self.left_frame.pack(side="left", fill="y", padx=5, pady=5) self.left_frame.pack(side="left", fill="y", padx=5, pady=5)
self.group_label = ctk.CTkLabel(self.left_frame, text="群组列表", font=("Arial", 16)) self.group_label = ctk.CTkLabel(self.left_frame, text="群组列表", font=("Arial", 16))
self.group_label.pack(pady=5) self.group_label.pack(pady=5)
# 创建可滚动框架来容纳群组按钮 # 创建可滚动框架来容纳群组按钮
self.group_scroll_frame = ctk.CTkScrollableFrame(self.left_frame, width=180, height=400) self.group_scroll_frame = ctk.CTkScrollableFrame(self.left_frame, width=180, height=400)
self.group_scroll_frame.pack(pady=5, padx=5, fill="both", expand=True) self.group_scroll_frame.pack(pady=5, padx=5, fill="both", expand=True)
# 存储群组按钮的字典 # 存储群组按钮的字典
self.group_buttons: Dict[str, ctk.CTkButton] = {} self.group_buttons: Dict[str, ctk.CTkButton] = {}
# 当前选中的群组ID # 当前选中的群组ID
self.selected_group_id: Optional[str] = None self.selected_group_id: Optional[str] = None
# 右侧内容显示 # 右侧内容显示
self.right_frame = ctk.CTkFrame(self.paned) self.right_frame = ctk.CTkFrame(self.paned)
self.right_frame.pack(side="right", fill="both", expand=True, padx=5, pady=5) self.right_frame.pack(side="right", fill="both", expand=True, padx=5, pady=5)
self.content_label = ctk.CTkLabel(self.right_frame, text="推理内容", font=("Arial", 16)) self.content_label = ctk.CTkLabel(self.right_frame, text="推理内容", font=("Arial", 16))
self.content_label.pack(pady=5) self.content_label.pack(pady=5)
# 创建富文本显示框 # 创建富文本显示框
self.content_text = ctk.CTkTextbox(self.right_frame, width=500, height=400) self.content_text = ctk.CTkTextbox(self.right_frame, width=500, height=400)
self.content_text.pack(pady=5, padx=5, fill="both", expand=True) self.content_text.pack(pady=5, padx=5, fill="both", expand=True)
# 配置文本标签 - 只使用颜色 # 配置文本标签 - 只使用颜色
self.content_text.tag_config("timestamp", foreground="#888888") # 时间戳使用灰色 self.content_text.tag_config("timestamp", foreground="#888888") # 时间戳使用灰色
self.content_text.tag_config("user", foreground="#4CAF50") # 用户名使用绿色 self.content_text.tag_config("user", foreground="#4CAF50") # 用户名使用绿色
@@ -144,11 +144,11 @@ class ReasoningGUI:
self.content_text.tag_config("reasoning", foreground="#FF9800") # 推理过程使用橙色 self.content_text.tag_config("reasoning", foreground="#FF9800") # 推理过程使用橙色
self.content_text.tag_config("response", foreground="#E91E63") # 回复使用粉色 self.content_text.tag_config("response", foreground="#E91E63") # 回复使用粉色
self.content_text.tag_config("separator", foreground="#666666") # 分隔符使用深灰色 self.content_text.tag_config("separator", foreground="#666666") # 分隔符使用深灰色
# 底部控制栏 # 底部控制栏
self.control_frame = ctk.CTkFrame(self.frame) self.control_frame = ctk.CTkFrame(self.frame)
self.control_frame.pack(fill="x", padx=10, pady=5) self.control_frame.pack(fill="x", padx=10, pady=5)
self.clear_button = ctk.CTkButton( self.clear_button = ctk.CTkButton(
self.control_frame, self.control_frame,
text="清除显示", text="清除显示",
@@ -156,19 +156,19 @@ class ReasoningGUI:
width=120 width=120
) )
self.clear_button.pack(side="left", padx=5) self.clear_button.pack(side="left", padx=5)
# 启动自动更新线程 # 启动自动更新线程
self.update_thread = threading.Thread(target=self._auto_update, daemon=True) self.update_thread = threading.Thread(target=self._auto_update, daemon=True)
self.update_thread.start() self.update_thread.start()
# 启动GUI更新检查 # 启动GUI更新检查
self.root.after(100, self._process_queue) self.root.after(100, self._process_queue)
def _on_closing(self): def _on_closing(self):
"""处理窗口关闭事件""" """处理窗口关闭事件"""
self.root.quit() self.root.quit()
sys.exit(0) sys.exit(0)
def _process_queue(self): def _process_queue(self):
"""处理更新队列中的任务""" """处理更新队列中的任务"""
try: try:
@@ -183,14 +183,14 @@ class ReasoningGUI:
finally: finally:
# 继续检查队列 # 继续检查队列
self.root.after(100, self._process_queue) self.root.after(100, self._process_queue)
def _update_group_list_gui(self): def _update_group_list_gui(self):
"""在主线程中更新群组列表""" """在主线程中更新群组列表"""
# 清除现有按钮 # 清除现有按钮
for button in self.group_buttons.values(): for button in self.group_buttons.values():
button.destroy() button.destroy()
self.group_buttons.clear() self.group_buttons.clear()
# 创建新的群组按钮 # 创建新的群组按钮
for group_id in self.group_data.keys(): for group_id in self.group_data.keys():
button = ctk.CTkButton( button = ctk.CTkButton(
@@ -203,16 +203,16 @@ class ReasoningGUI:
) )
button.pack(pady=2, padx=5) button.pack(pady=2, padx=5)
self.group_buttons[group_id] = button self.group_buttons[group_id] = button
# 如果有选中的群组,保持其高亮状态 # 如果有选中的群组,保持其高亮状态
if self.selected_group_id and self.selected_group_id in self.group_buttons: if self.selected_group_id and self.selected_group_id in self.group_buttons:
self._highlight_selected_group(self.selected_group_id) self._highlight_selected_group(self.selected_group_id)
def _on_group_select(self, group_id: str): def _on_group_select(self, group_id: str):
"""处理群组选择事件""" """处理群组选择事件"""
self._highlight_selected_group(group_id) self._highlight_selected_group(group_id)
self._update_display_gui(group_id) self._update_display_gui(group_id)
def _highlight_selected_group(self, group_id: str): def _highlight_selected_group(self, group_id: str):
"""高亮显示选中的群组按钮""" """高亮显示选中的群组按钮"""
# 重置所有按钮的颜色 # 重置所有按钮的颜色
@@ -223,9 +223,9 @@ class ReasoningGUI:
else: else:
# 恢复其他按钮的默认颜色 # 恢复其他按钮的默认颜色
button.configure(fg_color="#2B2B2B", hover_color="#404040") button.configure(fg_color="#2B2B2B", hover_color="#404040")
self.selected_group_id = group_id self.selected_group_id = group_id
def _update_display_gui(self, group_id: str): def _update_display_gui(self, group_id: str):
"""在主线程中更新显示内容""" """在主线程中更新显示内容"""
if group_id in self.group_data: if group_id in self.group_data:
@@ -234,19 +234,19 @@ class ReasoningGUI:
# 时间戳 # 时间戳
time_str = item['time'].strftime("%Y-%m-%d %H:%M:%S") time_str = item['time'].strftime("%Y-%m-%d %H:%M:%S")
self.content_text.insert("end", f"[{time_str}]\n", "timestamp") self.content_text.insert("end", f"[{time_str}]\n", "timestamp")
# 用户信息 # 用户信息
self.content_text.insert("end", "用户: ", "timestamp") self.content_text.insert("end", "用户: ", "timestamp")
self.content_text.insert("end", f"{item.get('user', '未知')}\n", "user") self.content_text.insert("end", f"{item.get('user', '未知')}\n", "user")
# 消息内容 # 消息内容
self.content_text.insert("end", "消息: ", "timestamp") self.content_text.insert("end", "消息: ", "timestamp")
self.content_text.insert("end", f"{item.get('message', '')}\n", "message") self.content_text.insert("end", f"{item.get('message', '')}\n", "message")
# 模型信息 # 模型信息
self.content_text.insert("end", "模型: ", "timestamp") self.content_text.insert("end", "模型: ", "timestamp")
self.content_text.insert("end", f"{item.get('model', '')}\n", "model") self.content_text.insert("end", f"{item.get('model', '')}\n", "model")
# Prompt内容 # Prompt内容
self.content_text.insert("end", "Prompt内容:\n", "timestamp") self.content_text.insert("end", "Prompt内容:\n", "timestamp")
prompt_text = item.get('prompt', '') prompt_text = item.get('prompt', '')
@@ -257,7 +257,7 @@ class ReasoningGUI:
self.content_text.insert("end", " " + line + "\n", "prompt") self.content_text.insert("end", " " + line + "\n", "prompt")
else: else:
self.content_text.insert("end", " 无Prompt内容\n", "prompt") self.content_text.insert("end", " 无Prompt内容\n", "prompt")
# 推理过程 # 推理过程
self.content_text.insert("end", "推理过程:\n", "timestamp") self.content_text.insert("end", "推理过程:\n", "timestamp")
reasoning_text = item.get('reasoning', '') reasoning_text = item.get('reasoning', '')
@@ -268,53 +268,53 @@ class ReasoningGUI:
self.content_text.insert("end", " " + line + "\n", "reasoning") self.content_text.insert("end", " " + line + "\n", "reasoning")
else: else:
self.content_text.insert("end", " 无推理过程\n", "reasoning") self.content_text.insert("end", " 无推理过程\n", "reasoning")
# 回复内容 # 回复内容
self.content_text.insert("end", "回复: ", "timestamp") self.content_text.insert("end", "回复: ", "timestamp")
self.content_text.insert("end", f"{item.get('response', '')}\n", "response") self.content_text.insert("end", f"{item.get('response', '')}\n", "response")
# 分隔符 # 分隔符
self.content_text.insert("end", f"\n{'='*50}\n\n", "separator") self.content_text.insert("end", f"\n{'=' * 50}\n\n", "separator")
# 滚动到顶部 # 滚动到顶部
self.content_text.see("1.0") self.content_text.see("1.0")
def _auto_update(self): def _auto_update(self):
"""自动更新函数""" """自动更新函数"""
while True: while True:
try: try:
# 从数据库获取最新数据,只获取启动时间之后的记录 # 从数据库获取最新数据,只获取启动时间之后的记录
query = {"time": {"$gt": self.start_timestamp}} query = {"time": {"$gt": self.start_timestamp}}
print(f"查询条件: {query}") logger.debug(f"查询条件: {query}")
# 先获取一条记录检查时间格式 # 先获取一条记录检查时间格式
sample = self.db.reasoning_logs.find_one() sample = self.db.reasoning_logs.find_one()
if sample: if sample:
print(f"样本记录时间格式: {type(sample['time'])} 值: {sample['time']}") logger.debug(f"样本记录时间格式: {type(sample['time'])} 值: {sample['time']}")
cursor = self.db.reasoning_logs.find(query).sort("time", -1) cursor = self.db.reasoning_logs.find(query).sort("time", -1)
new_data = {} new_data = {}
total_count = 0 total_count = 0
for item in cursor: for item in cursor:
# 调试输出 # 调试输出
if total_count == 0: if total_count == 0:
print(f"记录时间: {item['time']}, 类型: {type(item['time'])}") logger.debug(f"记录时间: {item['time']}, 类型: {type(item['time'])}")
total_count += 1 total_count += 1
group_id = str(item.get('group_id', 'unknown')) group_id = str(item.get('group_id', 'unknown'))
if group_id not in new_data: if group_id not in new_data:
new_data[group_id] = [] new_data[group_id] = []
# 转换时间戳为datetime对象 # 转换时间戳为datetime对象
if isinstance(item['time'], (int, float)): if isinstance(item['time'], (int, float)):
time_obj = datetime.fromtimestamp(item['time']) time_obj = datetime.fromtimestamp(item['time'])
elif isinstance(item['time'], datetime): elif isinstance(item['time'], datetime):
time_obj = item['time'] time_obj = item['time']
else: else:
print(f"未知的时间格式: {type(item['time'])}") logger.warning(f"未知的时间格式: {type(item['time'])}")
time_obj = datetime.now() # 使用当前时间作为后备 time_obj = datetime.now() # 使用当前时间作为后备
new_data[group_id].append({ new_data[group_id].append({
'time': time_obj, 'time': time_obj,
'user': item.get('user', '未知'), 'user': item.get('user', '未知'),
@@ -324,13 +324,13 @@ class ReasoningGUI:
'response': item.get('response', ''), 'response': item.get('response', ''),
'prompt': item.get('prompt', '') # 添加prompt字段 'prompt': item.get('prompt', '') # 添加prompt字段
}) })
print(f"从数据库加载了 {total_count} 条记录,分布在 {len(new_data)} 个群组中") logger.info(f"从数据库加载了 {total_count} 条记录,分布在 {len(new_data)} 个群组中")
# 更新数据 # 更新数据
if new_data != self.group_data: if new_data != self.group_data:
self.group_data = new_data self.group_data = new_data
print("数据已更新,正在刷新显示...") logger.info("数据已更新,正在刷新显示...")
# 将更新任务添加到队列 # 将更新任务添加到队列
self.update_queue.put({'type': 'update_group_list'}) self.update_queue.put({'type': 'update_group_list'})
if self.group_data: if self.group_data:
@@ -341,16 +341,16 @@ class ReasoningGUI:
'type': 'update_display', 'type': 'update_display',
'group_id': self.selected_group_id 'group_id': self.selected_group_id
}) })
except Exception as e: except Exception:
print(f"自动更新出错: {e}") logger.exception(f"自动更新出错")
# 每5秒更新一次 # 每5秒更新一次
time.sleep(5) time.sleep(5)
def clear_display(self): def clear_display(self):
"""清除显示内容""" """清除显示内容"""
self.content_text.delete("1.0", "end") self.content_text.delete("1.0", "end")
def run(self): def run(self):
"""运行GUI""" """运行GUI"""
self.root.mainloop() self.root.mainloop()
@@ -359,18 +359,17 @@ class ReasoningGUI:
def main(): def main():
"""主函数""" """主函数"""
Database.initialize( Database.initialize(
host= os.getenv("MONGODB_HOST"), host=os.getenv("MONGODB_HOST"),
port= int(os.getenv("MONGODB_PORT")), port=int(os.getenv("MONGODB_PORT")),
db_name= os.getenv("DATABASE_NAME"), db_name=os.getenv("DATABASE_NAME"),
username= os.getenv("MONGODB_USERNAME"), username=os.getenv("MONGODB_USERNAME"),
password= os.getenv("MONGODB_PASSWORD"), password=os.getenv("MONGODB_PASSWORD"),
auth_source=os.getenv("MONGODB_AUTH_SOURCE") auth_source=os.getenv("MONGODB_AUTH_SOURCE")
) )
app = ReasoningGUI() app = ReasoningGUI()
app.run() app.run()
if __name__ == "__main__": if __name__ == "__main__":
main() main()

View File

@@ -2,9 +2,8 @@ import asyncio
import time import time
from loguru import logger from loguru import logger
from nonebot import get_driver, on_command, on_message, require from nonebot import get_driver, on_message, require
from nonebot.adapters.onebot.v11 import Bot, GroupMessageEvent, Message, MessageSegment from nonebot.adapters.onebot.v11 import Bot, GroupMessageEvent, Message, MessageSegment
from nonebot.rule import to_me
from nonebot.typing import T_State from nonebot.typing import T_State
from ...common.database import Database from ...common.database import Database
@@ -16,6 +15,10 @@ from .config import global_config
from .emoji_manager import emoji_manager from .emoji_manager import emoji_manager
from .relationship_manager import relationship_manager from .relationship_manager import relationship_manager
from .willing_manager import willing_manager from .willing_manager import willing_manager
from ..memory_system.memory import hippocampus, memory_graph
from .bot import ChatBot
from .message_sender import message_manager, message_sender
# 创建LLM统计实例 # 创建LLM统计实例
llm_stats = LLMStatistics("llm_statistics.txt") llm_stats = LLMStatistics("llm_statistics.txt")
@@ -35,19 +38,13 @@ Database.initialize(
password=config.MONGODB_PASSWORD, password=config.MONGODB_PASSWORD,
auth_source=config.MONGODB_AUTH_SOURCE auth_source=config.MONGODB_AUTH_SOURCE
) )
print("\033[1;32m[初始化数据库完成]\033[0m") logger.success("初始化数据库成功")
# 导入其他模块
from ..memory_system.memory import hippocampus, memory_graph
from .bot import ChatBot
# from .message_send_control import message_sender
from .message_sender import message_manager, message_sender
# 初始化表情管理器 # 初始化表情管理器
emoji_manager.initialize() emoji_manager.initialize()
print(f"\033[1;32m正在唤醒{global_config.BOT_NICKNAME}......\033[0m") logger.debug(f"正在唤醒{global_config.BOT_NICKNAME}......")
# 创建机器人实例 # 创建机器人实例
chat_bot = ChatBot() chat_bot = ChatBot()
# 注册群消息处理器 # 注册群消息处理器
@@ -61,12 +58,12 @@ async def start_background_tasks():
"""启动后台任务""" """启动后台任务"""
# 启动LLM统计 # 启动LLM统计
llm_stats.start() llm_stats.start()
logger.success("[初始化]LLM统计功能启动") logger.success("LLM统计功能启动成功")
# 初始化并启动情绪管理器 # 初始化并启动情绪管理器
mood_manager = MoodManager.get_instance() mood_manager = MoodManager.get_instance()
mood_manager.start_mood_update(update_interval=global_config.mood_update_interval) mood_manager.start_mood_update(update_interval=global_config.mood_update_interval)
logger.success("[初始化]情绪管理器启动") logger.success("情绪管理器启动成功")
# 只启动表情包管理任务 # 只启动表情包管理任务
asyncio.create_task(emoji_manager.start_periodic_check(interval_MINS=global_config.EMOJI_CHECK_INTERVAL)) asyncio.create_task(emoji_manager.start_periodic_check(interval_MINS=global_config.EMOJI_CHECK_INTERVAL))
@@ -77,7 +74,7 @@ async def start_background_tasks():
@driver.on_startup @driver.on_startup
async def init_relationships(): async def init_relationships():
"""在 NoneBot2 启动时初始化关系管理器""" """在 NoneBot2 启动时初始化关系管理器"""
print("\033[1;32m[初始化]\033[0m 正在加载用户关系数据...") logger.debug("正在加载用户关系数据...")
await relationship_manager.load_all_relationships() await relationship_manager.load_all_relationships()
asyncio.create_task(relationship_manager._start_relationship_manager()) asyncio.create_task(relationship_manager._start_relationship_manager())
@@ -86,19 +83,19 @@ async def init_relationships():
async def _(bot: Bot): async def _(bot: Bot):
"""Bot连接成功时的处理""" """Bot连接成功时的处理"""
global _message_manager_started global _message_manager_started
print(f"\033[1;38;5;208m-----------{global_config.BOT_NICKNAME}成功连接!-----------\033[0m") logger.debug(f"-----------{global_config.BOT_NICKNAME}成功连接!-----------")
await willing_manager.ensure_started() await willing_manager.ensure_started()
message_sender.set_bot(bot) message_sender.set_bot(bot)
print("\033[1;38;5;208m-----------消息发送器已启动!-----------\033[0m") logger.success("-----------消息发送器已启动!-----------")
if not _message_manager_started: if not _message_manager_started:
asyncio.create_task(message_manager.start_processor()) asyncio.create_task(message_manager.start_processor())
_message_manager_started = True _message_manager_started = True
print("\033[1;38;5;208m-----------消息处理器已启动!-----------\033[0m") logger.success("-----------消息处理器已启动!-----------")
asyncio.create_task(emoji_manager._periodic_scan(interval_MINS=global_config.EMOJI_REGISTER_INTERVAL)) asyncio.create_task(emoji_manager._periodic_scan(interval_MINS=global_config.EMOJI_REGISTER_INTERVAL))
print("\033[1;38;5;208m-----------开始偷表情包!-----------\033[0m") logger.success("-----------开始偷表情包!-----------")
@group_msg.handle() @group_msg.handle()
@@ -110,13 +107,15 @@ async def _(bot: Bot, event: GroupMessageEvent, state: T_State):
@scheduler.scheduled_job("interval", seconds=global_config.build_memory_interval, id="build_memory") @scheduler.scheduled_job("interval", seconds=global_config.build_memory_interval, id="build_memory")
async def build_memory_task(): async def build_memory_task():
"""每build_memory_interval秒执行一次记忆构建""" """每build_memory_interval秒执行一次记忆构建"""
print( logger.debug(
"\033[1;32m[记忆构建]\033[0m -------------------------------------------开始构建记忆-------------------------------------------") "[记忆构建]"
"------------------------------------开始构建记忆--------------------------------------")
start_time = time.time() start_time = time.time()
await hippocampus.operation_build_memory(chat_size=20) await hippocampus.operation_build_memory(chat_size=20)
end_time = time.time() end_time = time.time()
print( logger.success(
f"\033[1;32m[记忆构建]\033[0m -------------------------------------------记忆构建完成:耗时: {end_time - start_time:.2f} 秒-------------------------------------------") f"[记忆构建]--------------------------记忆构建完成:耗时: {end_time - start_time:.2f} "
"秒-------------------------------------------")
@scheduler.scheduled_job("interval", seconds=global_config.forget_memory_interval, id="forget_memory") @scheduler.scheduled_job("interval", seconds=global_config.forget_memory_interval, id="forget_memory")

View File

@@ -31,10 +31,10 @@ class ChatBot:
self._started = False self._started = False
self.mood_manager = MoodManager.get_instance() # 获取情绪管理器单例 self.mood_manager = MoodManager.get_instance() # 获取情绪管理器单例
self.mood_manager.start_mood_update() # 启动情绪更新 self.mood_manager.start_mood_update() # 启动情绪更新
self.emoji_chance = 0.2 # 发送表情包的基础概率 self.emoji_chance = 0.2 # 发送表情包的基础概率
# self.message_streams = MessageStreamContainer() # self.message_streams = MessageStreamContainer()
async def _ensure_started(self): async def _ensure_started(self):
"""确保所有任务已启动""" """确保所有任务已启动"""
if not self._started: if not self._started:
@@ -42,26 +42,26 @@ class ChatBot:
async def handle_message(self, event: GroupMessageEvent, bot: Bot) -> None: async def handle_message(self, event: GroupMessageEvent, bot: Bot) -> None:
"""处理收到的群消息""" """处理收到的群消息"""
if event.group_id not in global_config.talk_allowed_groups: if event.group_id not in global_config.talk_allowed_groups:
return return
self.bot = bot # 更新 bot 实例 self.bot = bot # 更新 bot 实例
if event.user_id in global_config.ban_user_id: if event.user_id in global_config.ban_user_id:
return return
group_info = await bot.get_group_info(group_id=event.group_id) group_info = await bot.get_group_info(group_id=event.group_id)
sender_info = await bot.get_group_member_info(group_id=event.group_id, user_id=event.user_id, no_cache=True) sender_info = await bot.get_group_member_info(group_id=event.group_id, user_id=event.user_id, no_cache=True)
await relationship_manager.update_relationship(user_id = event.user_id, data = sender_info) await relationship_manager.update_relationship(user_id=event.user_id, data=sender_info)
await relationship_manager.update_relationship_value(user_id = event.user_id, relationship_value = 0.5) await relationship_manager.update_relationship_value(user_id=event.user_id, relationship_value=0.5)
message = Message( message = Message(
group_id=event.group_id, group_id=event.group_id,
user_id=event.user_id, user_id=event.user_id,
message_id=event.message_id, message_id=event.message_id,
user_cardname=sender_info['card'], user_cardname=sender_info['card'],
raw_message=str(event.original_message), raw_message=str(event.original_message),
plain_text=event.get_plaintext(), plain_text=event.get_plaintext(),
reply_message=event.reply, reply_message=event.reply,
) )
@@ -70,26 +70,26 @@ class ChatBot:
# 过滤词 # 过滤词
for word in global_config.ban_words: for word in global_config.ban_words:
if word in message.detailed_plain_text: if word in message.detailed_plain_text:
logger.info(f"\033[1;32m[{message.group_name}]{message.user_nickname}:\033[0m {message.processed_plain_text}") logger.info(
logger.info(f"\033[1;32m[过滤词识别]\033[0m 消息中含有{word}filtered") f"[{message.group_name}]{message.user_nickname}:{message.processed_plain_text}")
logger.info(f"[过滤词识别]消息中含有{word}filtered")
return return
current_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(message.time)) current_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(message.time))
# topic=await topic_identifier.identify_topic_llm(message.processed_plain_text) # topic=await topic_identifier.identify_topic_llm(message.processed_plain_text)
topic = '' topic = ''
interested_rate = 0 interested_rate = 0
interested_rate = await hippocampus.memory_activate_value(message.processed_plain_text)/100 interested_rate = await hippocampus.memory_activate_value(message.processed_plain_text) / 100
print(f"\033[1;32m[记忆激活]\033[0m 对{message.processed_plain_text}的激活度:---------------------------------------{interested_rate}\n") logger.debug(f"{message.processed_plain_text}"
f"的激活度:{interested_rate}")
# logger.info(f"\033[1;32m[主题识别]\033[0m 使用{global_config.topic_extract}主题: {topic}") # logger.info(f"\033[1;32m[主题识别]\033[0m 使用{global_config.topic_extract}主题: {topic}")
await self.storage.store_message(message, topic[0] if topic else None) await self.storage.store_message(message, topic[0] if topic else None)
is_mentioned = is_mentioned_bot_in_txt(message.processed_plain_text) is_mentioned = is_mentioned_bot_in_txt(message.processed_plain_text)
reply_probability = willing_manager.change_reply_willing_received( reply_probability = willing_manager.change_reply_willing_received(
event.group_id, event.group_id,
topic[0] if topic else None, topic[0] if topic else None,
is_mentioned, is_mentioned,
global_config, global_config,
@@ -98,25 +98,24 @@ class ChatBot:
interested_rate interested_rate
) )
current_willing = willing_manager.get_willing(event.group_id) current_willing = willing_manager.get_willing(event.group_id)
logger.info(
print(f"\033[1;32m[{current_time}][{message.group_name}]{message.user_nickname}:\033[0m {message.processed_plain_text}\033[1;36m[回复意愿:{current_willing:.2f}][概率:{reply_probability * 100:.1f}%]\033[0m") f"[{current_time}][{message.group_name}]{message.user_nickname}:"
f"{message.processed_plain_text}[回复意愿:{current_willing:.2f}][概率:{reply_probability * 100:.1f}%]")
response = "" response = ""
if random() < reply_probability: if random() < reply_probability:
tinking_time_point = round(time.time(), 2) tinking_time_point = round(time.time(), 2)
think_id = 'mt' + str(tinking_time_point) think_id = 'mt' + str(tinking_time_point)
thinking_message = Message_Thinking(message=message,message_id=think_id) thinking_message = Message_Thinking(message=message, message_id=think_id)
message_manager.add_message(thinking_message) message_manager.add_message(thinking_message)
willing_manager.change_reply_willing_sent(thinking_message.group_id) willing_manager.change_reply_willing_sent(thinking_message.group_id)
response,raw_content = await self.gpt.generate_response(message) response, raw_content = await self.gpt.generate_response(message)
if response: if response:
container = message_manager.get_container(event.group_id) container = message_manager.get_container(event.group_id)
thinking_message = None thinking_message = None
@@ -127,27 +126,28 @@ class ChatBot:
container.messages.remove(msg) container.messages.remove(msg)
# print(f"\033[1;32m[思考消息删除]\033[0m 已找到思考消息对象,开始删除") # print(f"\033[1;32m[思考消息删除]\033[0m 已找到思考消息对象,开始删除")
break break
# 如果找不到思考消息,直接返回 # 如果找不到思考消息,直接返回
if not thinking_message: if not thinking_message:
print(f"\033[1;33m[警告]\033[0m 未找到对应的思考消息,可能已超时被移除") logger.warning(f"未找到对应的思考消息,可能已超时被移除")
return return
#记录开始思考的时间,避免从思考到回复的时间太久 # 记录开始思考的时间,避免从思考到回复的时间太久
thinking_start_time = thinking_message.thinking_start_time thinking_start_time = thinking_message.thinking_start_time
message_set = MessageSet(event.group_id, global_config.BOT_QQ, think_id) # 发送消息的id和产生发送消息的message_thinking是一致的 message_set = MessageSet(event.group_id, global_config.BOT_QQ,
#计算打字时间1是为了模拟打字2是避免多条回复乱序 think_id) # 发送消息的id和产生发送消息的message_thinking是一致的
# 计算打字时间1是为了模拟打字2是避免多条回复乱序
accu_typing_time = 0 accu_typing_time = 0
# print(f"\033[1;32m[开始回复]\033[0m 开始将回复1载入发送容器") # print(f"\033[1;32m[开始回复]\033[0m 开始将回复1载入发送容器")
mark_head = False mark_head = False
for msg in response: for msg in response:
# print(f"\033[1;32m[回复内容]\033[0m {msg}") # print(f"\033[1;32m[回复内容]\033[0m {msg}")
#通过时间改变时间戳 # 通过时间改变时间戳
typing_time = calculate_typing_time(msg) typing_time = calculate_typing_time(msg)
accu_typing_time += typing_time accu_typing_time += typing_time
timepoint = tinking_time_point + accu_typing_time timepoint = tinking_time_point + accu_typing_time
bot_message = Message_Sending( bot_message = Message_Sending(
group_id=event.group_id, group_id=event.group_id,
user_id=global_config.BOT_QQ, user_id=global_config.BOT_QQ,
@@ -157,8 +157,8 @@ class ChatBot:
processed_plain_text=msg, processed_plain_text=msg,
user_nickname=global_config.BOT_NICKNAME, user_nickname=global_config.BOT_NICKNAME,
group_name=message.group_name, group_name=message.group_name,
time=timepoint, #记录了回复生成的时间 time=timepoint, # 记录了回复生成的时间
thinking_start_time=thinking_start_time, #记录了思考开始的时间 thinking_start_time=thinking_start_time, # 记录了思考开始的时间
reply_message_id=message.message_id reply_message_id=message.message_id
) )
await bot_message.initialize() await bot_message.initialize()
@@ -166,27 +166,27 @@ class ChatBot:
bot_message.is_head = True bot_message.is_head = True
mark_head = True mark_head = True
message_set.add_message(bot_message) message_set.add_message(bot_message)
#message_set 可以直接加入 message_manager # message_set 可以直接加入 message_manager
# print(f"\033[1;32m[回复]\033[0m 将回复载入发送容器") # print(f"\033[1;32m[回复]\033[0m 将回复载入发送容器")
message_manager.add_message(message_set) message_manager.add_message(message_set)
bot_response_time = tinking_time_point bot_response_time = tinking_time_point
if random() < global_config.emoji_chance: if random() < global_config.emoji_chance:
emoji_raw = await emoji_manager.get_emoji_for_text(response) emoji_raw = await emoji_manager.get_emoji_for_text(response)
# 检查是否 <没有找到> emoji # 检查是否 <没有找到> emoji
if emoji_raw != None: if emoji_raw != None:
emoji_path,discription = emoji_raw emoji_path, discription = emoji_raw
emoji_cq = CQCode.create_emoji_cq(emoji_path) emoji_cq = CQCode.create_emoji_cq(emoji_path)
if random() < 0.5: if random() < 0.5:
bot_response_time = tinking_time_point - 1 bot_response_time = tinking_time_point - 1
else: else:
bot_response_time = bot_response_time + 1 bot_response_time = bot_response_time + 1
bot_message = Message_Sending( bot_message = Message_Sending(
group_id=event.group_id, group_id=event.group_id,
user_id=global_config.BOT_QQ, user_id=global_config.BOT_QQ,
@@ -206,8 +206,8 @@ class ChatBot:
await bot_message.initialize() await bot_message.initialize()
message_manager.add_message(bot_message) message_manager.add_message(bot_message)
emotion = await self.gpt._get_emotion_tags(raw_content) emotion = await self.gpt._get_emotion_tags(raw_content)
print(f"'{response}' 获取到的情感标签为:{emotion}") logger.debug(f"'{response}' 获取到的情感标签为:{emotion}")
valuedict={ valuedict = {
'happy': 0.5, 'happy': 0.5,
'angry': -1, 'angry': -1,
'sad': -0.5, 'sad': -0.5,
@@ -216,11 +216,13 @@ class ChatBot:
'fearful': -0.7, 'fearful': -0.7,
'neutral': 0.1 'neutral': 0.1
} }
await relationship_manager.update_relationship_value(message.user_id, relationship_value=valuedict[emotion[0]]) await relationship_manager.update_relationship_value(message.user_id,
relationship_value=valuedict[emotion[0]])
# 使用情绪管理器更新情绪 # 使用情绪管理器更新情绪
self.mood_manager.update_mood_from_emotion(emotion[0], global_config.mood_intensity_factor) self.mood_manager.update_mood_from_emotion(emotion[0], global_config.mood_intensity_factor)
# willing_manager.change_reply_willing_after_sent(event.group_id) # willing_manager.change_reply_willing_after_sent(event.group_id)
# 创建全局ChatBot实例 # 创建全局ChatBot实例
chat_bot = ChatBot() chat_bot = ChatBot()

View File

@@ -135,7 +135,7 @@ class BotConfig:
try: try:
config_version: str = toml["inner"]["version"] config_version: str = toml["inner"]["version"]
except KeyError as e: except KeyError as e:
logger.error(f"配置文件中 inner 段 不存在 {e}, 这是错误的配置文件") logger.error(f"配置文件中 inner 段 不存在, 这是错误的配置文件")
raise KeyError(f"配置文件中 inner 段 不存在 {e}, 这是错误的配置文件") raise KeyError(f"配置文件中 inner 段 不存在 {e}, 这是错误的配置文件")
else: else:
toml["inner"] = {"version": "0.0.0"} toml["inner"] = {"version": "0.0.0"}
@@ -162,7 +162,7 @@ class BotConfig:
personality_config = parent['personality'] personality_config = parent['personality']
personality = personality_config.get('prompt_personality') personality = personality_config.get('prompt_personality')
if len(personality) >= 2: if len(personality) >= 2:
logger.info(f"载入自定义人格:{personality}") logger.debug(f"载入自定义人格:{personality}")
config.PROMPT_PERSONALITY = personality_config.get('prompt_personality', config.PROMPT_PERSONALITY) config.PROMPT_PERSONALITY = personality_config.get('prompt_personality', config.PROMPT_PERSONALITY)
logger.info(f"载入自定义日程prompt:{personality_config.get('prompt_schedule', config.PROMPT_SCHEDULE_GEN)}") logger.info(f"载入自定义日程prompt:{personality_config.get('prompt_schedule', config.PROMPT_SCHEDULE_GEN)}")
config.PROMPT_SCHEDULE_GEN = personality_config.get('prompt_schedule', config.PROMPT_SCHEDULE_GEN) config.PROMPT_SCHEDULE_GEN = personality_config.get('prompt_schedule', config.PROMPT_SCHEDULE_GEN)
@@ -246,11 +246,11 @@ class BotConfig:
try: try:
cfg_target[i] = cfg_item[i] cfg_target[i] = cfg_item[i]
except KeyError as e: except KeyError as e:
logger.error(f"{item} 中的必要字段 {e} 不存在,请检查") logger.error(f"{item} 中的必要字段不存在,请检查")
raise KeyError(f"{item} 中的必要字段 {e} 不存在,请检查") raise KeyError(f"{item} 中的必要字段 {e} 不存在,请检查")
provider = cfg_item.get("provider") provider = cfg_item.get("provider")
if provider == None: if provider is None:
logger.error(f"provider 字段在模型配置 {item} 中不存在,请检查") logger.error(f"provider 字段在模型配置 {item} 中不存在,请检查")
raise KeyError(f"provider 字段在模型配置 {item} 中不存在,请检查") raise KeyError(f"provider 字段在模型配置 {item} 中不存在,请检查")

View File

@@ -4,6 +4,7 @@ import os
import time import time
from dataclasses import dataclass from dataclasses import dataclass
from typing import Dict, Optional from typing import Dict, Optional
from loguru import logger
import requests import requests
@@ -151,11 +152,11 @@ class CQCode:
except (requests.exceptions.SSLError, requests.exceptions.HTTPError) as e: except (requests.exceptions.SSLError, requests.exceptions.HTTPError) as e:
if retry == max_retries - 1: if retry == max_retries - 1:
print(f"\033[1;31m[致命错误]\033[0m 最终请求失败: {str(e)}") logger.error(f"最终请求失败: {str(e)}")
time.sleep(1.5 ** retry) # 指数退避 time.sleep(1.5 ** retry) # 指数退避
except Exception as e: except Exception as e:
print(f"\033[1;33m[未知错误]\033[0m {str(e)}") logger.exception(f"[未知错误]")
return None return None
return None return None
@@ -194,7 +195,7 @@ class CQCode:
description, _ = await self._llm.generate_response_for_image(prompt, image_base64) description, _ = await self._llm.generate_response_for_image(prompt, image_base64)
return f"[表情包:{description}]" return f"[表情包:{description}]"
except Exception as e: except Exception as e:
print(f"\033[1;31m[错误]\033[0m AI接口调用失败: {str(e)}") logger.exception(f"AI接口调用失败: {str(e)}")
return "[表情包]" return "[表情包]"
async def get_image_description(self, image_base64: str) -> str: async def get_image_description(self, image_base64: str) -> str:
@@ -205,7 +206,7 @@ class CQCode:
description, _ = await self._llm.generate_response_for_image(prompt, image_base64) description, _ = await self._llm.generate_response_for_image(prompt, image_base64)
return f"[图片:{description}]" return f"[图片:{description}]"
except Exception as e: except Exception as e:
print(f"\033[1;31m[错误]\033[0m AI接口调用失败: {str(e)}") logger.exception(f"AI接口调用失败: {str(e)}")
return "[图片]" return "[图片]"
async def translate_forward(self) -> str: async def translate_forward(self) -> str:
@@ -222,7 +223,7 @@ class CQCode:
try: try:
messages = ast.literal_eval(content) messages = ast.literal_eval(content)
except ValueError as e: except ValueError as e:
print(f"\033[1;31m[错误]\033[0m 解析转发消息内容失败: {str(e)}") logger.error(f"解析转发消息内容失败: {str(e)}")
return '[转发消息]' return '[转发消息]'
# 处理每条消息 # 处理每条消息
@@ -277,11 +278,11 @@ class CQCode:
# 合并所有消息 # 合并所有消息
combined_messages = '\n'.join(formatted_messages) combined_messages = '\n'.join(formatted_messages)
print(f"\033[1;34m[调试信息]\033[0m 合并后的转发消息: {combined_messages}") logger.debug(f"合并后的转发消息: {combined_messages}")
return f"[转发消息:\n{combined_messages}]" return f"[转发消息:\n{combined_messages}]"
except Exception as e: except Exception as e:
print(f"\033[1;31m[错误]\033[0m 处理转发消息失败: {str(e)}") logger.exception("处理转发消息失败")
return '[转发消息]' return '[转发消息]'
async def translate_reply(self) -> str: async def translate_reply(self) -> str:
@@ -307,7 +308,7 @@ class CQCode:
return f"[回复 {self.reply_message.sender.nickname} 的消息: {message_obj.processed_plain_text}]" return f"[回复 {self.reply_message.sender.nickname} 的消息: {message_obj.processed_plain_text}]"
else: else:
print("\033[1;31m[错误]\033[0m 回复消息的sender.user_id为空") logger.error("回复消息的sender.user_id为空")
return '[回复某人消息]' return '[回复某人消息]'
@staticmethod @staticmethod

View File

@@ -21,24 +21,26 @@ config = driver.config
class EmojiManager: class EmojiManager:
_instance = None _instance = None
EMOJI_DIR = "data/emoji" # 表情包存储目录 EMOJI_DIR = "data/emoji" # 表情包存储目录
def __new__(cls): def __new__(cls):
if cls._instance is None: if cls._instance is None:
cls._instance = super().__new__(cls) cls._instance = super().__new__(cls)
cls._instance.db = None cls._instance.db = None
cls._instance._initialized = False cls._instance._initialized = False
return cls._instance return cls._instance
def __init__(self): def __init__(self):
self.db = Database.get_instance() self.db = Database.get_instance()
self._scan_task = None self._scan_task = None
self.vlm = LLM_request(model=global_config.vlm, temperature=0.3, max_tokens=1000) self.vlm = LLM_request(model=global_config.vlm, temperature=0.3, max_tokens=1000)
self.llm_emotion_judge = LLM_request(model=global_config.llm_emotion_judge, max_tokens=60,temperature=0.8) #更高的温度更少的token后续可以根据情绪来调整温度 self.llm_emotion_judge = LLM_request(model=global_config.llm_normal_minor, max_tokens=60,
temperature=0.8) # 更高的温度更少的token后续可以根据情绪来调整温度
def _ensure_emoji_dir(self): def _ensure_emoji_dir(self):
"""确保表情存储目录存在""" """确保表情存储目录存在"""
os.makedirs(self.EMOJI_DIR, exist_ok=True) os.makedirs(self.EMOJI_DIR, exist_ok=True)
def initialize(self): def initialize(self):
"""初始化数据库连接和表情目录""" """初始化数据库连接和表情目录"""
if not self._initialized: if not self._initialized:
@@ -50,15 +52,15 @@ class EmojiManager:
# 启动时执行一次完整性检查 # 启动时执行一次完整性检查
self.check_emoji_file_integrity() self.check_emoji_file_integrity()
except Exception as e: except Exception as e:
logger.error(f"初始化表情管理器失败: {str(e)}") logger.exception(f"初始化表情管理器失败")
def _ensure_db(self): def _ensure_db(self):
"""确保数据库已初始化""" """确保数据库已初始化"""
if not self._initialized: if not self._initialized:
self.initialize() self.initialize()
if not self._initialized: if not self._initialized:
raise RuntimeError("EmojiManager not initialized") raise RuntimeError("EmojiManager not initialized")
def _ensure_emoji_collection(self): def _ensure_emoji_collection(self):
"""确保emoji集合存在并创建索引 """确保emoji集合存在并创建索引
@@ -76,7 +78,7 @@ class EmojiManager:
self.db.db.emoji.create_index([('embedding', '2dsphere')]) self.db.db.emoji.create_index([('embedding', '2dsphere')])
self.db.db.emoji.create_index([('tags', 1)]) self.db.db.emoji.create_index([('tags', 1)])
self.db.db.emoji.create_index([('filename', 1)], unique=True) self.db.db.emoji.create_index([('filename', 1)], unique=True)
def record_usage(self, emoji_id: str): def record_usage(self, emoji_id: str):
"""记录表情使用次数""" """记录表情使用次数"""
try: try:
@@ -86,8 +88,8 @@ class EmojiManager:
{'$inc': {'usage_count': 1}} {'$inc': {'usage_count': 1}}
) )
except Exception as e: except Exception as e:
logger.error(f"记录表情使用失败: {str(e)}") logger.exception(f"记录表情使用失败")
async def get_emoji_for_text(self, text: str) -> Optional[str]: async def get_emoji_for_text(self, text: str) -> Optional[str]:
"""根据文本内容获取相关表情包 """根据文本内容获取相关表情包
Args: Args:
@@ -102,9 +104,9 @@ class EmojiManager:
""" """
try: try:
self._ensure_db() self._ensure_db()
# 获取文本的embedding # 获取文本的embedding
text_for_search= await self._get_kimoji_for_text(text) text_for_search = await self._get_kimoji_for_text(text)
if not text_for_search: if not text_for_search:
logger.error("无法获取文本的情绪") logger.error("无法获取文本的情绪")
return None return None
@@ -112,15 +114,15 @@ class EmojiManager:
if not text_embedding: if not text_embedding:
logger.error("无法获取文本的embedding") logger.error("无法获取文本的embedding")
return None return None
try: try:
# 获取所有表情包 # 获取所有表情包
all_emojis = list(self.db.db.emoji.find({}, {'_id': 1, 'path': 1, 'embedding': 1, 'discription': 1})) all_emojis = list(self.db.db.emoji.find({}, {'_id': 1, 'path': 1, 'embedding': 1, 'discription': 1}))
if not all_emojis: if not all_emojis:
logger.warning("数据库中没有任何表情包") logger.warning("数据库中没有任何表情包")
return None return None
# 计算余弦相似度并排序 # 计算余弦相似度并排序
def cosine_similarity(v1, v2): def cosine_similarity(v1, v2):
if not v1 or not v2: if not v1 or not v2:
@@ -131,42 +133,43 @@ class EmojiManager:
if norm_v1 == 0 or norm_v2 == 0: if norm_v1 == 0 or norm_v2 == 0:
return 0 return 0
return dot_product / (norm_v1 * norm_v2) return dot_product / (norm_v1 * norm_v2)
# 计算所有表情包与输入文本的相似度 # 计算所有表情包与输入文本的相似度
emoji_similarities = [ emoji_similarities = [
(emoji, cosine_similarity(text_embedding, emoji.get('embedding', []))) (emoji, cosine_similarity(text_embedding, emoji.get('embedding', [])))
for emoji in all_emojis for emoji in all_emojis
] ]
# 按相似度降序排序 # 按相似度降序排序
emoji_similarities.sort(key=lambda x: x[1], reverse=True) emoji_similarities.sort(key=lambda x: x[1], reverse=True)
# 获取前3个最相似的表情包 # 获取前3个最相似的表情包
top_3_emojis = emoji_similarities[:3] top_3_emojis = emoji_similarities[:3]
if not top_3_emojis: if not top_3_emojis:
logger.warning("未找到匹配的表情包") logger.warning("未找到匹配的表情包")
return None return None
# 从前3个中随机选择一个 # 从前3个中随机选择一个
selected_emoji, similarity = random.choice(top_3_emojis) selected_emoji, similarity = random.choice(top_3_emojis)
if selected_emoji and 'path' in selected_emoji: if selected_emoji and 'path' in selected_emoji:
# 更新使用次数 # 更新使用次数
self.db.db.emoji.update_one( self.db.db.emoji.update_one(
{'_id': selected_emoji['_id']}, {'_id': selected_emoji['_id']},
{'$inc': {'usage_count': 1}} {'$inc': {'usage_count': 1}}
) )
logger.success(f"找到匹配的表情包: {selected_emoji.get('discription', '无描述')} (相似度: {similarity:.4f})") logger.success(
f"找到匹配的表情包: {selected_emoji.get('discription', '无描述')} (相似度: {similarity:.4f})")
# 稍微改一下文本描述,不然容易产生幻觉,描述已经包含 表情包 了 # 稍微改一下文本描述,不然容易产生幻觉,描述已经包含 表情包 了
return selected_emoji['path'],"[ %s ]" % selected_emoji.get('discription', '无描述') return selected_emoji['path'], "[ %s ]" % selected_emoji.get('discription', '无描述')
except Exception as search_error: except Exception as search_error:
logger.error(f"搜索表情包失败: {str(search_error)}") logger.error(f"搜索表情包失败: {str(search_error)}")
return None return None
return None return None
except Exception as e: except Exception as e:
logger.error(f"获取表情包失败: {str(e)}") logger.error(f"获取表情包失败: {str(e)}")
return None return None
@@ -175,39 +178,39 @@ class EmojiManager:
"""获取表情包的标签""" """获取表情包的标签"""
try: try:
prompt = '这是一个表情包,使用中文简洁的描述一下表情包的内容和表情包所表达的情感' prompt = '这是一个表情包,使用中文简洁的描述一下表情包的内容和表情包所表达的情感'
content, _ = await self.vlm.generate_response_for_image(prompt, image_base64) content, _ = await self.vlm.generate_response_for_image(prompt, image_base64)
logger.debug(f"输出描述: {content}") logger.debug(f"输出描述: {content}")
return content return content
except Exception as e: except Exception as e:
logger.error(f"获取标签失败: {str(e)}") logger.error(f"获取标签失败: {str(e)}")
return None return None
async def _check_emoji(self, image_base64: str) -> str: async def _check_emoji(self, image_base64: str) -> str:
try: try:
prompt = f'这是一个表情包,请回答这个表情包是否满足\"{global_config.EMOJI_CHECK_PROMPT}\"的要求,是则回答是,否则回答否,不要出现任何其他内容' prompt = f'这是一个表情包,请回答这个表情包是否满足\"{global_config.EMOJI_CHECK_PROMPT}\"的要求,是则回答是,否则回答否,不要出现任何其他内容'
content, _ = await self.vlm.generate_response_for_image(prompt, image_base64) content, _ = await self.vlm.generate_response_for_image(prompt, image_base64)
logger.debug(f"输出描述: {content}") logger.debug(f"输出描述: {content}")
return content return content
except Exception as e: except Exception as e:
logger.error(f"获取标签失败: {str(e)}") logger.error(f"获取标签失败: {str(e)}")
return None return None
async def _get_kimoji_for_text(self, text:str): async def _get_kimoji_for_text(self, text: str):
try: try:
prompt = f'这是{global_config.BOT_NICKNAME}将要发送的消息内容:\n{text}\n若要为其配上表情包,请你输出这个表情包应该表达怎样的情感,应该给人什么样的感觉,不要太简洁也不要太长,注意不要输出任何对消息内容的分析内容,只输出\"一种什么样的感觉\"中间的形容词部分。' prompt = f'这是{global_config.BOT_NICKNAME}将要发送的消息内容:\n{text}\n若要为其配上表情包,请你输出这个表情包应该表达怎样的情感,应该给人什么样的感觉,不要太简洁也不要太长,注意不要输出任何对消息内容的分析内容,只输出\"一种什么样的感觉\"中间的形容词部分。'
content, _ = await self.llm_emotion_judge.generate_response_async(prompt) content, _ = await self.llm_emotion_judge.generate_response_async(prompt)
logger.info(f"输出描述: {content}") logger.info(f"输出描述: {content}")
return content return content
except Exception as e: except Exception as e:
logger.error(f"获取标签失败: {str(e)}") logger.error(f"获取标签失败: {str(e)}")
return None return None
async def scan_new_emojis(self): async def scan_new_emojis(self):
"""扫描新的表情包""" """扫描新的表情包"""
try: try:
@@ -215,22 +218,23 @@ class EmojiManager:
os.makedirs(emoji_dir, exist_ok=True) os.makedirs(emoji_dir, exist_ok=True)
# 获取所有支持的图片文件 # 获取所有支持的图片文件
files_to_process = [f for f in os.listdir(emoji_dir) if f.lower().endswith(('.jpg', '.jpeg', '.png', '.gif'))] files_to_process = [f for f in os.listdir(emoji_dir) if
f.lower().endswith(('.jpg', '.jpeg', '.png', '.gif'))]
for filename in files_to_process: for filename in files_to_process:
image_path = os.path.join(emoji_dir, filename) image_path = os.path.join(emoji_dir, filename)
# 检查是否已经注册过 # 检查是否已经注册过
existing_emoji = self.db.db['emoji'].find_one({'filename': filename}) existing_emoji = self.db.db['emoji'].find_one({'filename': filename})
if existing_emoji: if existing_emoji:
continue continue
# 压缩图片并获取base64编码 # 压缩图片并获取base64编码
image_base64 = image_path_to_base64(image_path) image_base64 = image_path_to_base64(image_path)
if image_base64 is None: if image_base64 is None:
os.remove(image_path) os.remove(image_path)
continue continue
# 获取表情包的描述 # 获取表情包的描述
discription = await self._get_emoji_discription(image_base64) discription = await self._get_emoji_discription(image_base64)
if global_config.EMOJI_CHECK: if global_config.EMOJI_CHECK:
@@ -247,30 +251,28 @@ class EmojiManager:
emoji_record = { emoji_record = {
'filename': filename, 'filename': filename,
'path': image_path, 'path': image_path,
'embedding':embedding, 'embedding': embedding,
'discription': discription, 'discription': discription,
'timestamp': int(time.time()) 'timestamp': int(time.time())
} }
# 保存到数据库 # 保存到数据库
self.db.db['emoji'].insert_one(emoji_record) self.db.db['emoji'].insert_one(emoji_record)
logger.success(f"注册新表情包: {filename}") logger.success(f"注册新表情包: {filename}")
logger.info(f"描述: {discription}") logger.info(f"描述: {discription}")
else: else:
logger.warning(f"跳过表情包: {filename}") logger.warning(f"跳过表情包: {filename}")
except Exception as e: except Exception as e:
logger.error(f"扫描表情包失败: {str(e)}") logger.exception(f"扫描表情包失败")
logger.error(traceback.format_exc())
async def _periodic_scan(self, interval_MINS: int = 10): async def _periodic_scan(self, interval_MINS: int = 10):
"""定期扫描新表情包""" """定期扫描新表情包"""
while True: while True:
print("\033[1;36m[表情包]\033[0m 开始扫描新表情包...") logger.info("开始扫描新表情包...")
await self.scan_new_emojis() await self.scan_new_emojis()
await asyncio.sleep(interval_MINS * 60) # 每600秒扫描一次 await asyncio.sleep(interval_MINS * 60) # 每600秒扫描一次
def check_emoji_file_integrity(self): def check_emoji_file_integrity(self):
"""检查表情包文件完整性 """检查表情包文件完整性
如果文件已被删除,则从数据库中移除对应记录 如果文件已被删除,则从数据库中移除对应记录
@@ -281,7 +283,7 @@ class EmojiManager:
all_emojis = list(self.db.db.emoji.find()) all_emojis = list(self.db.db.emoji.find())
removed_count = 0 removed_count = 0
total_count = len(all_emojis) total_count = len(all_emojis)
for emoji in all_emojis: for emoji in all_emojis:
try: try:
if 'path' not in emoji: if 'path' not in emoji:
@@ -289,27 +291,27 @@ class EmojiManager:
self.db.db.emoji.delete_one({'_id': emoji['_id']}) self.db.db.emoji.delete_one({'_id': emoji['_id']})
removed_count += 1 removed_count += 1
continue continue
if 'embedding' not in emoji: if 'embedding' not in emoji:
logger.warning(f"发现过时记录缺少embedding字段ID: {emoji.get('_id', 'unknown')}") logger.warning(f"发现过时记录缺少embedding字段ID: {emoji.get('_id', 'unknown')}")
self.db.db.emoji.delete_one({'_id': emoji['_id']}) self.db.db.emoji.delete_one({'_id': emoji['_id']})
removed_count += 1 removed_count += 1
continue continue
# 检查文件是否存在 # 检查文件是否存在
if not os.path.exists(emoji['path']): if not os.path.exists(emoji['path']):
logger.warning(f"表情包文件已被删除: {emoji['path']}") logger.warning(f"表情包文件已被删除: {emoji['path']}")
# 从数据库中删除记录 # 从数据库中删除记录
result = self.db.db.emoji.delete_one({'_id': emoji['_id']}) result = self.db.db.emoji.delete_one({'_id': emoji['_id']})
if result.deleted_count > 0: if result.deleted_count > 0:
logger.success(f"成功删除数据库记录: {emoji['_id']}") logger.debug(f"成功删除数据库记录: {emoji['_id']}")
removed_count += 1 removed_count += 1
else: else:
logger.error(f"删除数据库记录失败: {emoji['_id']}") logger.error(f"删除数据库记录失败: {emoji['_id']}")
except Exception as item_error: except Exception as item_error:
logger.error(f"处理表情包记录时出错: {str(item_error)}") logger.error(f"处理表情包记录时出错: {str(item_error)}")
continue continue
# 验证清理结果 # 验证清理结果
remaining_count = self.db.db.emoji.count_documents({}) remaining_count = self.db.db.emoji.count_documents({})
if removed_count > 0: if removed_count > 0:
@@ -317,7 +319,7 @@ class EmojiManager:
logger.info(f"清理前总数: {total_count} | 清理后总数: {remaining_count}") logger.info(f"清理前总数: {total_count} | 清理后总数: {remaining_count}")
else: else:
logger.info(f"已检查 {total_count} 个表情包记录") logger.info(f"已检查 {total_count} 个表情包记录")
except Exception as e: except Exception as e:
logger.error(f"检查表情包完整性失败: {str(e)}") logger.error(f"检查表情包完整性失败: {str(e)}")
logger.error(traceback.format_exc()) logger.error(traceback.format_exc())
@@ -328,6 +330,6 @@ class EmojiManager:
await asyncio.sleep(interval_MINS * 60) await asyncio.sleep(interval_MINS * 60)
# 创建全局单例 # 创建全局单例
emoji_manager = EmojiManager() emoji_manager = EmojiManager()

View File

@@ -3,6 +3,7 @@ import time
from typing import List, Optional, Tuple, Union from typing import List, Optional, Tuple, Union
from nonebot import get_driver from nonebot import get_driver
from loguru import logger
from ...common.database import Database from ...common.database import Database
from ..models.utils_model import LLM_request from ..models.utils_model import LLM_request
@@ -39,13 +40,13 @@ class ResponseGenerator:
self.current_model_type = 'r1_distill' self.current_model_type = 'r1_distill'
current_model = self.model_r1_distill current_model = self.model_r1_distill
print(f"+++++++++++++++++{global_config.BOT_NICKNAME}{self.current_model_type}思考中+++++++++++++++++") logger.info(f"{global_config.BOT_NICKNAME}{self.current_model_type}思考中")
model_response = await self._generate_response_with_model(message, current_model) model_response = await self._generate_response_with_model(message, current_model)
raw_content=model_response raw_content=model_response
if model_response: if model_response:
print(f'{global_config.BOT_NICKNAME}的回复是:{model_response}') logger.info(f'{global_config.BOT_NICKNAME}的回复是:{model_response}')
model_response = await self._process_response(model_response) model_response = await self._process_response(model_response)
if model_response: if model_response:
@@ -92,8 +93,8 @@ class ResponseGenerator:
# 生成回复 # 生成回复
try: try:
content, reasoning_content = await model.generate_response(prompt) content, reasoning_content = await model.generate_response(prompt)
except Exception as e: except Exception:
print(f"生成回复时出错: {e}") logger.exception(f"生成回复时出错")
return None return None
# 保存到数据库 # 保存到数据库
@@ -144,8 +145,8 @@ class ResponseGenerator:
else: else:
return ["neutral"] return ["neutral"]
except Exception as e: except Exception:
print(f"获取情感标签时出错: {e}") logger.exception(f"获取情感标签时出错")
return ["neutral"] return ["neutral"]
async def _process_response(self, content: str) -> Tuple[List[str], List[str]]: async def _process_response(self, content: str) -> Tuple[List[str], List[str]]:
@@ -172,7 +173,7 @@ class InitiativeMessageGenerate:
prompt_builder._build_initiative_prompt_select(message.group_id) prompt_builder._build_initiative_prompt_select(message.group_id)
) )
content_select, reasoning = self.model_v3.generate_response(topic_select_prompt) content_select, reasoning = self.model_v3.generate_response(topic_select_prompt)
print(f"[DEBUG] {content_select} {reasoning}") logger.debug(f"{content_select} {reasoning}")
topics_list = [dot[0] for dot in dots_for_select] topics_list = [dot[0] for dot in dots_for_select]
if content_select: if content_select:
if content_select in topics_list: if content_select in topics_list:
@@ -185,12 +186,12 @@ class InitiativeMessageGenerate:
select_dot[1], prompt_template select_dot[1], prompt_template
) )
content_check, reasoning_check = self.model_v3.generate_response(prompt_check) content_check, reasoning_check = self.model_v3.generate_response(prompt_check)
print(f"[DEBUG] {content_check} {reasoning_check}") logger.info(f"{content_check} {reasoning_check}")
if "yes" not in content_check.lower(): if "yes" not in content_check.lower():
return None return None
prompt = prompt_builder._build_initiative_prompt( prompt = prompt_builder._build_initiative_prompt(
select_dot, prompt_template, memory select_dot, prompt_template, memory
) )
content, reasoning = self.model_r1.generate_response_async(prompt) content, reasoning = self.model_r1.generate_response_async(prompt)
print(f"[DEBUG] {content} {reasoning}") logger.debug(f"[DEBUG] {content} {reasoning}")
return content return content

View File

@@ -2,6 +2,7 @@ import asyncio
import time import time
from typing import Dict, List, Optional, Union from typing import Dict, List, Optional, Union
from loguru import logger
from nonebot.adapters.onebot.v11 import Bot from nonebot.adapters.onebot.v11 import Bot
from .cq_code import cq_code_tool from .cq_code import cq_code_tool
@@ -13,45 +14,45 @@ from .config import global_config
class Message_Sender: class Message_Sender:
"""发送器""" """发送器"""
def __init__(self): def __init__(self):
self.message_interval = (0.5, 1) # 消息间隔时间范围(秒) self.message_interval = (0.5, 1) # 消息间隔时间范围(秒)
self.last_send_time = 0 self.last_send_time = 0
self._current_bot = None self._current_bot = None
def set_bot(self, bot: Bot): def set_bot(self, bot: Bot):
"""设置当前bot实例""" """设置当前bot实例"""
self._current_bot = bot self._current_bot = bot
async def send_group_message( async def send_group_message(
self, self,
group_id: int, group_id: int,
send_text: str, send_text: str,
auto_escape: bool = False, auto_escape: bool = False,
reply_message_id: int = None, reply_message_id: int = None,
at_user_id: int = None at_user_id: int = None
) -> None: ) -> None:
if not self._current_bot: if not self._current_bot:
raise RuntimeError("Bot未设置请先调用set_bot方法设置bot实例") raise RuntimeError("Bot未设置请先调用set_bot方法设置bot实例")
message = send_text message = send_text
# 如果需要回复 # 如果需要回复
if reply_message_id: if reply_message_id:
reply_cq = cq_code_tool.create_reply_cq(reply_message_id) reply_cq = cq_code_tool.create_reply_cq(reply_message_id)
message = reply_cq + message message = reply_cq + message
# 如果需要at # 如果需要at
# if at_user_id: # if at_user_id:
# at_cq = cq_code_tool.create_at_cq(at_user_id) # at_cq = cq_code_tool.create_at_cq(at_user_id)
# message = at_cq + " " + message # message = at_cq + " " + message
typing_time = calculate_typing_time(message) typing_time = calculate_typing_time(message)
if typing_time > 10: if typing_time > 10:
typing_time = 10 typing_time = 10
await asyncio.sleep(typing_time) await asyncio.sleep(typing_time)
# 发送消息 # 发送消息
try: try:
await self._current_bot.send_group_msg( await self._current_bot.send_group_msg(
@@ -59,49 +60,49 @@ class Message_Sender:
message=message, message=message,
auto_escape=auto_escape auto_escape=auto_escape
) )
print(f"\033[1;34m[调试]\033[0m 发送消息{message}成功") logger.debug(f"发送消息{message}成功")
except Exception as e: except Exception as e:
print(f"生错误 {e}") logger.exception(f"送消息{message}失败")
print(f"\033[1;34m[调试]\033[0m 发送消息{message}失败")
class MessageContainer: class MessageContainer:
"""单个群的发送/思考消息容器""" """单个群的发送/思考消息容器"""
def __init__(self, group_id: int, max_size: int = 100): def __init__(self, group_id: int, max_size: int = 100):
self.group_id = group_id self.group_id = group_id
self.max_size = max_size self.max_size = max_size
self.messages = [] self.messages = []
self.last_send_time = 0 self.last_send_time = 0
self.thinking_timeout = 20 # 思考超时时间(秒) self.thinking_timeout = 20 # 思考超时时间(秒)
def get_timeout_messages(self) -> List[Message_Sending]: def get_timeout_messages(self) -> List[Message_Sending]:
"""获取所有超时的Message_Sending对象思考时间超过30秒按thinking_start_time排序""" """获取所有超时的Message_Sending对象思考时间超过30秒按thinking_start_time排序"""
current_time = time.time() current_time = time.time()
timeout_messages = [] timeout_messages = []
for msg in self.messages: for msg in self.messages:
if isinstance(msg, Message_Sending): if isinstance(msg, Message_Sending):
if current_time - msg.thinking_start_time > self.thinking_timeout: if current_time - msg.thinking_start_time > self.thinking_timeout:
timeout_messages.append(msg) timeout_messages.append(msg)
# 按thinking_start_time排序时间早的在前面 # 按thinking_start_time排序时间早的在前面
timeout_messages.sort(key=lambda x: x.thinking_start_time) timeout_messages.sort(key=lambda x: x.thinking_start_time)
return timeout_messages return timeout_messages
def get_earliest_message(self) -> Optional[Union[Message_Thinking, Message_Sending]]: def get_earliest_message(self) -> Optional[Union[Message_Thinking, Message_Sending]]:
"""获取thinking_start_time最早的消息对象""" """获取thinking_start_time最早的消息对象"""
if not self.messages: if not self.messages:
return None return None
earliest_time = float('inf') earliest_time = float('inf')
earliest_message = None earliest_message = None
for msg in self.messages: for msg in self.messages:
msg_time = msg.thinking_start_time msg_time = msg.thinking_start_time
if msg_time < earliest_time: if msg_time < earliest_time:
earliest_time = msg_time earliest_time = msg_time
earliest_message = msg earliest_message = msg
return earliest_message return earliest_message
def add_message(self, message: Union[Message_Thinking, Message_Sending]) -> None: def add_message(self, message: Union[Message_Thinking, Message_Sending]) -> None:
"""添加消息到队列""" """添加消息到队列"""
# print(f"\033[1;32m[添加消息]\033[0m 添加消息到对应群") # print(f"\033[1;32m[添加消息]\033[0m 添加消息到对应群")
@@ -110,7 +111,7 @@ class MessageContainer:
self.messages.append(single_message) self.messages.append(single_message)
else: else:
self.messages.append(message) self.messages.append(message)
def remove_message(self, message: Union[Message_Thinking, Message_Sending]) -> bool: def remove_message(self, message: Union[Message_Thinking, Message_Sending]) -> bool:
"""移除消息如果消息存在则返回True否则返回False""" """移除消息如果消息存在则返回True否则返回False"""
try: try:
@@ -118,98 +119,104 @@ class MessageContainer:
self.messages.remove(message) self.messages.remove(message)
return True return True
return False return False
except Exception as e: except Exception:
print(f"\033[1;31m[错误]\033[0m 移除消息时发生错误: {e}") logger.exception(f"移除消息时发生错误")
return False return False
def has_messages(self) -> bool: def has_messages(self) -> bool:
"""检查是否有待发送的消息""" """检查是否有待发送的消息"""
return bool(self.messages) return bool(self.messages)
def get_all_messages(self) -> List[Union[Message, Message_Thinking]]: def get_all_messages(self) -> List[Union[Message, Message_Thinking]]:
"""获取所有消息""" """获取所有消息"""
return list(self.messages) return list(self.messages)
class MessageManager: class MessageManager:
"""管理所有群的消息容器""" """管理所有群的消息容器"""
def __init__(self): def __init__(self):
self.containers: Dict[int, MessageContainer] = {} self.containers: Dict[int, MessageContainer] = {}
self.storage = MessageStorage() self.storage = MessageStorage()
self._running = True self._running = True
def get_container(self, group_id: int) -> MessageContainer: def get_container(self, group_id: int) -> MessageContainer:
"""获取或创建群的消息容器""" """获取或创建群的消息容器"""
if group_id not in self.containers: if group_id not in self.containers:
self.containers[group_id] = MessageContainer(group_id) self.containers[group_id] = MessageContainer(group_id)
return self.containers[group_id] return self.containers[group_id]
def add_message(self, message: Union[Message_Thinking, Message_Sending, MessageSet]) -> None: def add_message(self, message: Union[Message_Thinking, Message_Sending, MessageSet]) -> None:
container = self.get_container(message.group_id) container = self.get_container(message.group_id)
container.add_message(message) container.add_message(message)
async def process_group_messages(self, group_id: int): async def process_group_messages(self, group_id: int):
"""处理群消息""" """处理群消息"""
# if int(time.time() / 3) == time.time() / 3: # if int(time.time() / 3) == time.time() / 3:
# print(f"\033[1;34m[调试]\033[0m 开始处理群{group_id}的消息") # print(f"\033[1;34m[调试]\033[0m 开始处理群{group_id}的消息")
container = self.get_container(group_id) container = self.get_container(group_id)
if container.has_messages(): if container.has_messages():
#最早的对象,可能是思考消息,也可能是发送消息 # 最早的对象,可能是思考消息,也可能是发送消息
message_earliest = container.get_earliest_message() #一个message_thinking or message_sending message_earliest = container.get_earliest_message() # 一个message_thinking or message_sending
#如果是思考消息 # 如果是思考消息
if isinstance(message_earliest, Message_Thinking): if isinstance(message_earliest, Message_Thinking):
#优先等待这条消息 # 优先等待这条消息
message_earliest.update_thinking_time() message_earliest.update_thinking_time()
thinking_time = message_earliest.thinking_time thinking_time = message_earliest.thinking_time
print(f"\033[1;34m[调试]\033[0m 消息正在思考中,已思考{int(thinking_time)}\033[K\r", end='', flush=True) print(f"消息正在思考中,已思考{int(thinking_time)}\r", end='', flush=True)
# 检查是否超时 # 检查是否超时
if thinking_time > global_config.thinking_timeout: if thinking_time > global_config.thinking_timeout:
print(f"\033[1;33m[警告]\033[0m 消息思考超时({thinking_time}秒),移除该消息") logger.warning(f"消息思考超时({thinking_time}秒),移除该消息")
container.remove_message(message_earliest) container.remove_message(message_earliest)
else:# 如果不是message_thinking就只能是message_sending else: # 如果不是message_thinking就只能是message_sending
print(f"\033[1;34m[调试]\033[0m 消息'{message_earliest.processed_plain_text}'正在发送中") logger.debug(f"消息'{message_earliest.processed_plain_text}'正在发送中")
#直接发,等什么呢 # 直接发,等什么呢
if message_earliest.is_head and message_earliest.update_thinking_time() >30: if message_earliest.is_head and message_earliest.update_thinking_time() > 30:
await message_sender.send_group_message(group_id, message_earliest.processed_plain_text, auto_escape=False, reply_message_id=message_earliest.reply_message_id) await message_sender.send_group_message(group_id, message_earliest.processed_plain_text,
auto_escape=False,
reply_message_id=message_earliest.reply_message_id)
else: else:
await message_sender.send_group_message(group_id, message_earliest.processed_plain_text, auto_escape=False) await message_sender.send_group_message(group_id, message_earliest.processed_plain_text,
#移除消息 auto_escape=False)
# 移除消息
if message_earliest.is_emoji: if message_earliest.is_emoji:
message_earliest.processed_plain_text = "[表情包]" message_earliest.processed_plain_text = "[表情包]"
await self.storage.store_message(message_earliest, None) await self.storage.store_message(message_earliest, None)
container.remove_message(message_earliest) container.remove_message(message_earliest)
#获取并处理超时消息 # 获取并处理超时消息
message_timeout = container.get_timeout_messages() #也许是一堆message_sending message_timeout = container.get_timeout_messages() # 也许是一堆message_sending
if message_timeout: if message_timeout:
print(f"\033[1;34m[调试]\033[0m 发现{len(message_timeout)}条超时消息") logger.warning(f"发现{len(message_timeout)}条超时消息")
for msg in message_timeout: for msg in message_timeout:
if msg == message_earliest: if msg == message_earliest:
continue # 跳过已经处理过的消息 continue # 跳过已经处理过的消息
try: try:
#发送 # 发送
if msg.is_head and msg.update_thinking_time() >30: if msg.is_head and msg.update_thinking_time() > 30:
await message_sender.send_group_message(group_id, msg.processed_plain_text, auto_escape=False, reply_message_id=msg.reply_message_id) await message_sender.send_group_message(group_id, msg.processed_plain_text,
auto_escape=False,
reply_message_id=msg.reply_message_id)
else: else:
await message_sender.send_group_message(group_id, msg.processed_plain_text, auto_escape=False) await message_sender.send_group_message(group_id, msg.processed_plain_text,
auto_escape=False)
#如果是表情包,则替换为"[表情包]" # 如果是表情包,则替换为"[表情包]"
if msg.is_emoji: if msg.is_emoji:
msg.processed_plain_text = "[表情包]" msg.processed_plain_text = "[表情包]"
await self.storage.store_message(msg, None) await self.storage.store_message(msg, None)
# 安全地移除消息 # 安全地移除消息
if not container.remove_message(msg): if not container.remove_message(msg):
print("\033[1;33m[警告]\033[0m 尝试删除不存在的消息") logger.warning("尝试删除不存在的消息")
except Exception as e: except Exception:
print(f"\033[1;31m[错误]\033[0m 处理超时消息时发生错误: {e}") logger.exception(f"处理超时消息时发生错误")
continue continue
async def start_processor(self): async def start_processor(self):
"""启动消息处理器""" """启动消息处理器"""
while self._running: while self._running:
@@ -217,9 +224,10 @@ class MessageManager:
tasks = [] tasks = []
for group_id in self.containers.keys(): for group_id in self.containers.keys():
tasks.append(self.process_group_messages(group_id)) tasks.append(self.process_group_messages(group_id))
await asyncio.gather(*tasks) await asyncio.gather(*tasks)
# 创建全局消息管理器实例 # 创建全局消息管理器实例
message_manager = MessageManager() message_manager = MessageManager()
# 创建全局发送器实例 # 创建全局发送器实例

View File

@@ -1,6 +1,7 @@
import random import random
import time import time
from typing import Optional from typing import Optional
from loguru import logger
from ...common.database import Database from ...common.database import Database
from ..memory_system.memory import hippocampus, memory_graph from ..memory_system.memory import hippocampus, memory_graph
@@ -16,13 +17,11 @@ class PromptBuilder:
self.activate_messages = '' self.activate_messages = ''
self.db = Database.get_instance() self.db = Database.get_instance()
async def _build_prompt(self,
message_txt: str,
async def _build_prompt(self, sender_name: str = "某人",
message_txt: str, relationship_value: float = 0.0,
sender_name: str = "某人", group_id: Optional[int] = None) -> tuple[str, str]:
relationship_value: float = 0.0,
group_id: Optional[int] = None) -> tuple[str, str]:
"""构建prompt """构建prompt
Args: Args:
@@ -33,57 +32,56 @@ class PromptBuilder:
Returns: Returns:
str: 构建好的prompt str: 构建好的prompt
""" """
#先禁用关系 # 先禁用关系
if 0 > 30: if 0 > 30:
relation_prompt = "关系特别特别好,你很喜欢喜欢他" relation_prompt = "关系特别特别好,你很喜欢喜欢他"
relation_prompt_2 = "热情发言或者回复" relation_prompt_2 = "热情发言或者回复"
elif 0 <-20: elif 0 < -20:
relation_prompt = "关系很差,你很讨厌他" relation_prompt = "关系很差,你很讨厌他"
relation_prompt_2 = "骂他" relation_prompt_2 = "骂他"
else: else:
relation_prompt = "关系一般" relation_prompt = "关系一般"
relation_prompt_2 = "发言或者回复" relation_prompt_2 = "发言或者回复"
#开始构建prompt # 开始构建prompt
# 心情
#心情
mood_manager = MoodManager.get_instance() mood_manager = MoodManager.get_instance()
mood_prompt = mood_manager.get_prompt() mood_prompt = mood_manager.get_prompt()
# 日程构建
#日程构建
current_date = time.strftime("%Y-%m-%d", time.localtime()) current_date = time.strftime("%Y-%m-%d", time.localtime())
current_time = time.strftime("%H:%M:%S", time.localtime()) current_time = time.strftime("%H:%M:%S", time.localtime())
bot_schedule_now_time,bot_schedule_now_activity = bot_schedule.get_current_task() bot_schedule_now_time, bot_schedule_now_activity = bot_schedule.get_current_task()
prompt_date = f'''今天是{current_date},现在是{current_time},你今天的日程是:\n{bot_schedule.today_schedule}\n你现在正在{bot_schedule_now_activity}\n''' prompt_date = f'''今天是{current_date},现在是{current_time},你今天的日程是:\n{bot_schedule.today_schedule}\n你现在正在{bot_schedule_now_activity}\n'''
#知识构建 # 知识构建
start_time = time.time() start_time = time.time()
prompt_info = '' prompt_info = ''
promt_info_prompt = '' promt_info_prompt = ''
prompt_info = await self.get_prompt_info(message_txt,threshold=0.5) prompt_info = await self.get_prompt_info(message_txt, threshold=0.5)
if prompt_info: if prompt_info:
prompt_info = f'''\n----------------------------------------------------\n你有以下这些[知识]\n{prompt_info}\n请你记住上面的[知识],之后可能会用到\n----------------------------------------------------\n''' prompt_info = f'''你有以下这些[知识]{prompt_info}请你记住上面的[
知识],之后可能会用到-'''
end_time = time.time() end_time = time.time()
print(f"\033[1;32m[知识检索]\033[0m 耗时: {(end_time - start_time):.3f}") logger.debug(f"知识检索耗时: {(end_time - start_time):.3f}")
# 获取聊天上下文 # 获取聊天上下文
chat_talking_prompt = '' chat_talking_prompt = ''
if group_id: if group_id:
chat_talking_prompt = get_recent_group_detailed_plain_text(self.db, group_id, limit=global_config.MAX_CONTEXT_SIZE,combine = True) chat_talking_prompt = get_recent_group_detailed_plain_text(self.db, group_id,
limit=global_config.MAX_CONTEXT_SIZE,
combine=True)
chat_talking_prompt = f"以下是群里正在聊天的内容:\n{chat_talking_prompt}" chat_talking_prompt = f"以下是群里正在聊天的内容:\n{chat_talking_prompt}"
# 使用新的记忆获取方法 # 使用新的记忆获取方法
memory_prompt = '' memory_prompt = ''
start_time = time.time() start_time = time.time()
# 调用 hippocampus 的 get_relevant_memories 方法 # 调用 hippocampus 的 get_relevant_memories 方法
relevant_memories = await hippocampus.get_relevant_memories( relevant_memories = await hippocampus.get_relevant_memories(
text=message_txt, text=message_txt,
@@ -91,30 +89,28 @@ class PromptBuilder:
similarity_threshold=0.4, similarity_threshold=0.4,
max_memory_num=5 max_memory_num=5
) )
if relevant_memories: if relevant_memories:
# 格式化记忆内容 # 格式化记忆内容
memory_items = [] memory_items = []
for memory in relevant_memories: for memory in relevant_memories:
memory_items.append(f"关于「{memory['topic']}」的记忆:{memory['content']}") memory_items.append(f"关于「{memory['topic']}」的记忆:{memory['content']}")
memory_prompt = "看到这些聊天,你想起来:\n" + "\n".join(memory_items) + "\n" memory_prompt = "看到这些聊天,你想起来:\n" + "\n".join(memory_items) + "\n"
# 打印调试信息 # 打印调试信息
print("\n\033[1;32m[记忆检索]\033[0m 找到以下相关记忆:") logger.debug("[记忆检索]找到以下相关记忆:")
for memory in relevant_memories: for memory in relevant_memories:
print(f"- 主题「{memory['topic']}」[相似度: {memory['similarity']:.2f}]: {memory['content']}") logger.debug(f"- 主题「{memory['topic']}」[相似度: {memory['similarity']:.2f}]: {memory['content']}")
end_time = time.time() end_time = time.time()
print(f"\033[1;32m[回忆耗时]\033[0m 耗时: {(end_time - start_time):.3f}") logger.info(f"回忆耗时: {(end_time - start_time):.3f}")
# 激活prompt构建
#激活prompt构建
activate_prompt = '' activate_prompt = ''
activate_prompt = f"以上是群里正在进行的聊天,{memory_prompt} 现在昵称为 '{sender_name}' 的用户说的:{message_txt}。引起了你的注意,你和他{relation_prompt},{mood_prompt},你想要{relation_prompt_2}" activate_prompt = f"以上是群里正在进行的聊天,{memory_prompt} 现在昵称为 '{sender_name}' 的用户说的:{message_txt}。引起了你的注意,你和他{relation_prompt},{mood_prompt},你想要{relation_prompt_2}"
#检测机器人相关词汇,改为关键词检测与反应功能了,提取到全局配置中 # 检测机器人相关词汇,改为关键词检测与反应功能了,提取到全局配置中
# bot_keywords = ['人机', 'bot', '机器', '入机', 'robot', '机器人'] # bot_keywords = ['人机', 'bot', '机器', '入机', 'robot', '机器人']
# is_bot = any(keyword in message_txt.lower() for keyword in bot_keywords) # is_bot = any(keyword in message_txt.lower() for keyword in bot_keywords)
# if is_bot: # if is_bot:
@@ -127,12 +123,11 @@ class PromptBuilder:
for rule in global_config.keywords_reaction_rules: for rule in global_config.keywords_reaction_rules:
if rule.get("enable", False): if rule.get("enable", False):
if any(keyword in message_txt.lower() for keyword in rule.get("keywords", [])): if any(keyword in message_txt.lower() for keyword in rule.get("keywords", [])):
print(f"检测到以下关键词之一:{rule.get('keywords', [])},触发反应:{rule.get('reaction', '')}") logger.info(f"检测到以下关键词之一:{rule.get('keywords', [])},触发反应:{rule.get('reaction', '')}")
keywords_reaction_prompt += rule.get("reaction", "") + '' keywords_reaction_prompt += rule.get("reaction", "") + ''
# 人格选择
#人格选择 personality = global_config.PROMPT_PERSONALITY
personality=global_config.PROMPT_PERSONALITY
probability_1 = global_config.PERSONALITY_1 probability_1 = global_config.PERSONALITY_1
probability_2 = global_config.PERSONALITY_2 probability_2 = global_config.PERSONALITY_2
probability_3 = global_config.PERSONALITY_3 probability_3 = global_config.PERSONALITY_3
@@ -150,8 +145,8 @@ class PromptBuilder:
prompt_personality = f'''{activate_prompt}你的网名叫{global_config.BOT_NICKNAME}{personality[2]}, 你正在浏览qq群{promt_info_prompt}, prompt_personality = f'''{activate_prompt}你的网名叫{global_config.BOT_NICKNAME}{personality[2]}, 你正在浏览qq群{promt_info_prompt},
现在请你给出日常且口语化的回复,请表现你自己的见解,不要一昧迎合,尽量简短一些。{keywords_reaction_prompt} 现在请你给出日常且口语化的回复,请表现你自己的见解,不要一昧迎合,尽量简短一些。{keywords_reaction_prompt}
请你表达自己的见解和观点。可以有个性。''' 请你表达自己的见解和观点。可以有个性。'''
#中文高手(新加的好玩功能) # 中文高手(新加的好玩功能)
prompt_ger = '' prompt_ger = ''
if random.random() < 0.04: if random.random() < 0.04:
prompt_ger += '你喜欢用倒装句' prompt_ger += '你喜欢用倒装句'
@@ -159,23 +154,23 @@ class PromptBuilder:
prompt_ger += '你喜欢用反问句' prompt_ger += '你喜欢用反问句'
if random.random() < 0.01: if random.random() < 0.01:
prompt_ger += '你喜欢用文言文' prompt_ger += '你喜欢用文言文'
#额外信息要求 # 额外信息要求
extra_info = '''但是记得回复平淡一些,简短一些,尤其注意在没明确提到时不要过多提及自身的背景, 不要直接回复别人发的表情包,记住不要输出多余内容(包括前后缀,冒号和引号,括号,表情等),只需要输出回复内容就好,不要输出其他任何内容''' extra_info = '''但是记得回复平淡一些,简短一些,尤其注意在没明确提到时不要过多提及自身的背景, 不要直接回复别人发的表情包,记住不要输出多余内容(包括前后缀,冒号和引号,括号,表情等),只需要输出回复内容就好,不要输出其他任何内容'''
#合并prompt # 合并prompt
prompt = "" prompt = ""
prompt += f"{prompt_info}\n" prompt += f"{prompt_info}\n"
prompt += f"{prompt_date}\n" prompt += f"{prompt_date}\n"
prompt += f"{chat_talking_prompt}\n" prompt += f"{chat_talking_prompt}\n"
prompt += f"{prompt_personality}\n" prompt += f"{prompt_personality}\n"
prompt += f"{prompt_ger}\n" prompt += f"{prompt_ger}\n"
prompt += f"{extra_info}\n" prompt += f"{extra_info}\n"
'''读空气prompt处理''' '''读空气prompt处理'''
activate_prompt_check=f"以上是群里正在进行的聊天,昵称为 '{sender_name}' 的用户说的:{message_txt}。引起了你的注意,你和他{relation_prompt},你想要{relation_prompt_2},但是这不一定是合适的时机,请你决定是否要回应这条消息。" activate_prompt_check = f"以上是群里正在进行的聊天,昵称为 '{sender_name}' 的用户说的:{message_txt}。引起了你的注意,你和他{relation_prompt},你想要{relation_prompt_2},但是这不一定是合适的时机,请你决定是否要回应这条消息。"
prompt_personality_check = '' prompt_personality_check = ''
extra_check_info=f"请注意把握群里的聊天内容的基础上,综合群内的氛围,例如,和{global_config.BOT_NICKNAME}相关的话题要积极回复,如果是at自己的消息一定要回复如果自己正在和别人聊天一定要回复其他话题如果合适搭话也可以回复如果认为应该回复请输出yes否则输出no请注意是决定是否需要回复而不是编写回复内容除了yes和no不要输出任何回复内容。" extra_check_info = f"请注意把握群里的聊天内容的基础上,综合群内的氛围,例如,和{global_config.BOT_NICKNAME}相关的话题要积极回复,如果是at自己的消息一定要回复如果自己正在和别人聊天一定要回复其他话题如果合适搭话也可以回复如果认为应该回复请输出yes否则输出no请注意是决定是否需要回复而不是编写回复内容除了yes和no不要输出任何回复内容。"
if personality_choice < probability_1: # 第一种人格 if personality_choice < probability_1: # 第一种人格
prompt_personality_check = f'''你的网名叫{global_config.BOT_NICKNAME}{personality[0]}, 你正在浏览qq群{promt_info_prompt} {activate_prompt_check} {extra_check_info}''' prompt_personality_check = f'''你的网名叫{global_config.BOT_NICKNAME}{personality[0]}, 你正在浏览qq群{promt_info_prompt} {activate_prompt_check} {extra_check_info}'''
elif personality_choice < probability_1 + probability_2: # 第二种人格 elif personality_choice < probability_1 + probability_2: # 第二种人格
@@ -183,34 +178,36 @@ class PromptBuilder:
else: # 第三种人格 else: # 第三种人格
prompt_personality_check = f'''你的网名叫{global_config.BOT_NICKNAME}{personality[2]}, 你正在浏览qq群{promt_info_prompt} {activate_prompt_check} {extra_check_info}''' prompt_personality_check = f'''你的网名叫{global_config.BOT_NICKNAME}{personality[2]}, 你正在浏览qq群{promt_info_prompt} {activate_prompt_check} {extra_check_info}'''
prompt_check_if_response=f"{prompt_info}\n{prompt_date}\n{chat_talking_prompt}\n{prompt_personality_check}" prompt_check_if_response = f"{prompt_info}\n{prompt_date}\n{chat_talking_prompt}\n{prompt_personality_check}"
return prompt,prompt_check_if_response return prompt, prompt_check_if_response
def _build_initiative_prompt_select(self,group_id): def _build_initiative_prompt_select(self, group_id, probability_1=0.8, probability_2=0.1):
current_date = time.strftime("%Y-%m-%d", time.localtime()) current_date = time.strftime("%Y-%m-%d", time.localtime())
current_time = time.strftime("%H:%M:%S", time.localtime()) current_time = time.strftime("%H:%M:%S", time.localtime())
bot_schedule_now_time,bot_schedule_now_activity = bot_schedule.get_current_task() bot_schedule_now_time, bot_schedule_now_activity = bot_schedule.get_current_task()
prompt_date = f'''今天是{current_date},现在是{current_time},你今天的日程是:\n{bot_schedule.today_schedule}\n你现在正在{bot_schedule_now_activity}\n''' prompt_date = f'''今天是{current_date},现在是{current_time},你今天的日程是:\n{bot_schedule.today_schedule}\n你现在正在{bot_schedule_now_activity}\n'''
chat_talking_prompt = '' chat_talking_prompt = ''
if group_id: if group_id:
chat_talking_prompt = get_recent_group_detailed_plain_text(self.db, group_id, limit=global_config.MAX_CONTEXT_SIZE,combine = True) chat_talking_prompt = get_recent_group_detailed_plain_text(self.db, group_id,
limit=global_config.MAX_CONTEXT_SIZE,
combine=True)
chat_talking_prompt = f"以下是群里正在聊天的内容:\n{chat_talking_prompt}" chat_talking_prompt = f"以下是群里正在聊天的内容:\n{chat_talking_prompt}"
# print(f"\033[1;34m[调试]\033[0m 已从数据库获取群 {group_id} 的消息记录:{chat_talking_prompt}") # print(f"\033[1;34m[调试]\033[0m 已从数据库获取群 {group_id} 的消息记录:{chat_talking_prompt}")
# 获取主动发言的话题 # 获取主动发言的话题
all_nodes=memory_graph.dots all_nodes = memory_graph.dots
all_nodes=filter(lambda dot:len(dot[1]['memory_items'])>3,all_nodes) all_nodes = filter(lambda dot: len(dot[1]['memory_items']) > 3, all_nodes)
nodes_for_select=random.sample(all_nodes,5) nodes_for_select = random.sample(all_nodes, 5)
topics=[info[0] for info in nodes_for_select] topics = [info[0] for info in nodes_for_select]
infos=[info[1] for info in nodes_for_select] infos = [info[1] for info in nodes_for_select]
#激活prompt构建 # 激活prompt构建
activate_prompt = '' activate_prompt = ''
activate_prompt = "以上是群里正在进行的聊天。" activate_prompt = "以上是群里正在进行的聊天。"
personality=global_config.PROMPT_PERSONALITY personality = global_config.PROMPT_PERSONALITY
prompt_personality = '' prompt_personality = ''
personality_choice = random.random() personality_choice = random.random()
if personality_choice < probability_1: # 第一种人格 if personality_choice < probability_1: # 第一种人格
@@ -219,32 +216,31 @@ class PromptBuilder:
prompt_personality = f'''{activate_prompt}你的网名叫{global_config.BOT_NICKNAME}{personality[1]}''' prompt_personality = f'''{activate_prompt}你的网名叫{global_config.BOT_NICKNAME}{personality[1]}'''
else: # 第三种人格 else: # 第三种人格
prompt_personality = f'''{activate_prompt}你的网名叫{global_config.BOT_NICKNAME}{personality[2]}''' prompt_personality = f'''{activate_prompt}你的网名叫{global_config.BOT_NICKNAME}{personality[2]}'''
topics_str=','.join(f"\"{topics}\"")
prompt_for_select=f"你现在想在群里发言,回忆了一下,想到几个话题,分别是{topics_str},综合当前状态以及群内气氛,请你在其中选择一个合适的话题,注意只需要输出话题,除了话题什么也不要输出(双引号也不要输出)"
prompt_initiative_select=f"{prompt_date}\n{prompt_personality}\n{prompt_for_select}"
prompt_regular=f"{prompt_date}\n{prompt_personality}"
return prompt_initiative_select,nodes_for_select,prompt_regular topics_str = ','.join(f"\"{topics}\"")
prompt_for_select = f"你现在想在群里发言,回忆了一下,想到几个话题,分别是{topics_str},综合当前状态以及群内气氛,请你在其中选择一个合适的话题,注意只需要输出话题,除了话题什么也不要输出(双引号也不要输出)"
def _build_initiative_prompt_check(self,selected_node,prompt_regular):
memory=random.sample(selected_node['memory_items'],3) prompt_initiative_select = f"{prompt_date}\n{prompt_personality}\n{prompt_for_select}"
memory='\n'.join(memory) prompt_regular = f"{prompt_date}\n{prompt_personality}"
prompt_for_check=f"{prompt_regular}你现在想在群里发言,回忆了一下,想到一个话题,是{selected_node['concept']},关于这个话题的记忆有\n{memory}\n以这个作为主题发言合适吗请在把握群里的聊天内容的基础上综合群内的氛围如果认为应该发言请输出yes否则输出no请注意是决定是否需要发言而不是编写回复内容除了yes和no不要输出任何回复内容。"
return prompt_for_check,memory return prompt_initiative_select, nodes_for_select, prompt_regular
def _build_initiative_prompt(self,selected_node,prompt_regular,memory): def _build_initiative_prompt_check(self, selected_node, prompt_regular):
prompt_for_initiative=f"{prompt_regular}你现在想在群里发言,回忆了一下,想到一个话题,是{selected_node['concept']},关于这个话题的记忆有\n{memory}\n,请在把握群里的聊天内容的基础上,综合群内的氛围,以日常且口语化的口吻,简短且随意一点进行发言,不要说的太有条理,可以有个性。记住不要输出多余内容(包括前后缀,冒号和引号,括号,表情等)" memory = random.sample(selected_node['memory_items'], 3)
memory = '\n'.join(memory)
prompt_for_check = f"{prompt_regular}你现在想在群里发言,回忆了一下,想到一个话题,是{selected_node['concept']},关于这个话题的记忆有\n{memory}\n以这个作为主题发言合适吗请在把握群里的聊天内容的基础上综合群内的氛围如果认为应该发言请输出yes否则输出no请注意是决定是否需要发言而不是编写回复内容除了yes和no不要输出任何回复内容。"
return prompt_for_check, memory
def _build_initiative_prompt(self, selected_node, prompt_regular, memory):
prompt_for_initiative = f"{prompt_regular}你现在想在群里发言,回忆了一下,想到一个话题,是{selected_node['concept']},关于这个话题的记忆有\n{memory}\n,请在把握群里的聊天内容的基础上,综合群内的氛围,以日常且口语化的口吻,简短且随意一点进行发言,不要说的太有条理,可以有个性。记住不要输出多余内容(包括前后缀,冒号和引号,括号,表情等)"
return prompt_for_initiative return prompt_for_initiative
async def get_prompt_info(self, message: str, threshold: float):
    """Fetch knowledge-base text related to a message.

    The message is embedded with ``get_embedding`` and the embedding is
    handed to ``get_info_from_db`` together with ``threshold``.

    Args:
        message: Text to look up.
        threshold: Forwarded unchanged to ``get_info_from_db``.

    Returns:
        The string produced by ``get_info_from_db``.
    """
    logger.debug(f"获取知识库内容,元消息:{message[:30]}...,消息长度: {len(message)}")
    query_vector = await get_embedding(message)
    return self.get_info_from_db(query_vector, threshold=threshold)
def get_info_from_db(self, query_embedding: list, limit: int = 1, threshold: float = 0.5) -> str: def get_info_from_db(self, query_embedding: list, limit: int = 1, threshold: float = 0.5) -> str:
@@ -305,14 +301,15 @@ class PromptBuilder:
{"$limit": limit}, {"$limit": limit},
{"$project": {"content": 1, "similarity": 1}} {"$project": {"content": 1, "similarity": 1}}
] ]
results = list(self.db.db.knowledges.aggregate(pipeline)) results = list(self.db.db.knowledges.aggregate(pipeline))
# print(f"\033[1;34m[调试]\033[0m获取知识库内容结果: {results}") # print(f"\033[1;34m[调试]\033[0m获取知识库内容结果: {results}")
if not results: if not results:
return '' return ''
# 返回所有找到的内容,用换行分隔 # 返回所有找到的内容,用换行分隔
return '\n'.join(str(result['content']) for result in results) return '\n'.join(str(result['content']) for result in results)
prompt_builder = PromptBuilder()
prompt_builder = PromptBuilder()

View File

@@ -1,4 +1,5 @@
import asyncio import asyncio
from loguru import logger
from typing import Optional from typing import Optional
from ...common.database import Database from ...common.database import Database
@@ -8,9 +9,10 @@ class Impression:
traits: str = None traits: str = None
called: str = None called: str = None
know_time: float = None know_time: float = None
relationship_value: float = None relationship_value: float = None
class Relationship: class Relationship:
user_id: int = None user_id: int = None
# impression: Impression = None # impression: Impression = None
@@ -21,7 +23,7 @@ class Relationship:
nickname: str = None nickname: str = None
relationship_value: float = None relationship_value: float = None
saved = False saved = False
def __init__(self, user_id: int, data=None, **kwargs): def __init__(self, user_id: int, data=None, **kwargs):
if isinstance(data, dict): if isinstance(data, dict):
# 如果输入是字典,使用字典解析 # 如果输入是字典,使用字典解析
@@ -39,14 +41,12 @@ class Relationship:
self.nickname = kwargs.get('nickname') self.nickname = kwargs.get('nickname')
self.relationship_value = kwargs.get('relationship_value', 0.0) self.relationship_value = kwargs.get('relationship_value', 0.0)
self.saved = kwargs.get('saved', False) self.saved = kwargs.get('saved', False)
class RelationshipManager: class RelationshipManager:
def __init__(self): def __init__(self):
self.relationships: dict[int, Relationship] = {} self.relationships: dict[int, Relationship] = {}
async def update_relationship(self, user_id: int, data=None, **kwargs): async def update_relationship(self, user_id: int, data=None, **kwargs):
# 检查是否在内存中已存在 # 检查是否在内存中已存在
relationship = self.relationships.get(user_id) relationship = self.relationships.get(user_id)
@@ -62,7 +62,8 @@ class RelationshipManager:
setattr(relationship, key, value) setattr(relationship, key, value)
else: else:
# 如果不存在,创建新对象 # 如果不存在,创建新对象
relationship = Relationship(user_id, data=data) if isinstance(data, dict) else Relationship(user_id, **kwargs) relationship = Relationship(user_id, data=data) if isinstance(data, dict) else Relationship(user_id,
**kwargs)
self.relationships[user_id] = relationship self.relationships[user_id] = relationship
# 更新 id_name_nickname_table # 更新 id_name_nickname_table
@@ -71,9 +72,9 @@ class RelationshipManager:
# 保存到数据库 # 保存到数据库
await self.storage_relationship(relationship) await self.storage_relationship(relationship)
relationship.saved = True relationship.saved = True
return relationship return relationship
async def update_relationship_value(self, user_id: int, **kwargs): async def update_relationship_value(self, user_id: int, **kwargs):
# 检查是否在内存中已存在 # 检查是否在内存中已存在
relationship = self.relationships.get(user_id) relationship = self.relationships.get(user_id)
@@ -85,31 +86,30 @@ class RelationshipManager:
relationship.saved = True relationship.saved = True
return relationship return relationship
else: else:
print(f"\033[1;31m[关系管理]\033[0m 用户 {user_id} 不存在,无法更新") logger.warning(f"用户 {user_id} 不存在,无法更新")
return None return None
def get_relationship(self, user_id: int) -> Optional["Relationship"]:
    """Look up the in-memory relationship of a user.

    Args:
        user_id: ID of the user to look up.

    Returns:
        The cached ``Relationship`` for ``user_id``, or ``None`` when the
        user has no entry in ``self.relationships``.
    """
    # A miss used to yield the int 0, which contradicted the Optional return
    # annotation. None is just as falsy, so truthiness checks are unaffected.
    return self.relationships.get(user_id)
async def load_relationship(self, data: dict) -> Relationship:
    """Build a relationship object from a stored record and cache it.

    Args:
        data: Record to load; its ``user_id`` entry is required.

    Returns:
        The new ``Relationship``, marked as saved and registered in
        ``self.relationships`` under its ``user_id``.
    """
    loaded = Relationship(user_id=data['user_id'], data=data)
    # The record was passed in as already-stored data, so nothing is pending.
    loaded.saved = True
    self.relationships[loaded.user_id] = loaded
    return loaded
async def load_all_relationships(self):
    """Load every document of the ``relationships`` collection into memory."""
    records = Database.get_instance().db.relationships.find({})
    for record in records:
        await self.load_relationship(record)
async def _start_relationship_manager(self): async def _start_relationship_manager(self):
"""每5分钟自动保存一次关系数据""" """每5分钟自动保存一次关系数据"""
db = Database.get_instance() db = Database.get_instance()
@@ -119,23 +119,23 @@ class RelationshipManager:
for data in all_relationships: for data in all_relationships:
user_id = data['user_id'] user_id = data['user_id']
relationship = await self.load_relationship(data) relationship = await self.load_relationship(data)
self.relationships[user_id] = relationship self.relationships[user_id] = relationship
print(f"\033[1;32m[关系管理]\033[0m 已加载 {len(self.relationships)} 条关系记录") logger.debug(f"已加载 {len(self.relationships)} 条关系记录")
while True: while True:
print("\033[1;32m[关系管理]\033[0m 正在自动保存关系") logger.debug("正在自动保存关系")
await asyncio.sleep(300) # 等待300秒(5分钟) await asyncio.sleep(300) # 等待300秒(5分钟)
await self._save_all_relationships() await self._save_all_relationships()
async def _save_all_relationships(self): async def _save_all_relationships(self):
"""将所有关系数据保存到数据库""" """将所有关系数据保存到数据库"""
# 保存所有关系数据 # 保存所有关系数据
for userid, relationship in self.relationships.items(): for userid, relationship in self.relationships.items():
if not relationship.saved: if not relationship.saved:
relationship.saved = True relationship.saved = True
await self.storage_relationship(relationship) await self.storage_relationship(relationship)
async def storage_relationship(self,relationship: Relationship): async def storage_relationship(self, relationship: Relationship):
""" """
将关系记录存储到数据库中 将关系记录存储到数据库中
""" """
@@ -145,7 +145,7 @@ class RelationshipManager:
gender = relationship.gender gender = relationship.gender
age = relationship.age age = relationship.age
saved = relationship.saved saved = relationship.saved
db = Database.get_instance() db = Database.get_instance()
db.db.relationships.update_one( db.db.relationships.update_one(
{'user_id': user_id}, {'user_id': user_id},
@@ -158,7 +158,7 @@ class RelationshipManager:
}}, }},
upsert=True upsert=True
) )
def get_name(self, user_id: int) -> str: def get_name(self, user_id: int) -> str:
# 确保user_id是整数类型 # 确保user_id是整数类型
user_id = int(user_id) user_id = int(user_id)
@@ -169,4 +169,4 @@ class RelationshipManager:
return "某人" return "某人"
relationship_manager = RelationshipManager() relationship_manager = RelationshipManager()

View File

@@ -2,12 +2,13 @@ from typing import Optional
from ...common.database import Database from ...common.database import Database
from .message import Message from .message import Message
from loguru import logger
class MessageStorage: class MessageStorage:
def __init__(self): def __init__(self):
self.db = Database.get_instance() self.db = Database.get_instance()
async def store_message(self, message: Message, topic: Optional[str] = None) -> None: async def store_message(self, message: Message, topic: Optional[str] = None) -> None:
"""存储消息到数据库""" """存储消息到数据库"""
try: try:
@@ -41,9 +42,9 @@ class MessageStorage:
"topic": topic, "topic": topic,
"detailed_plain_text": message.detailed_plain_text, "detailed_plain_text": message.detailed_plain_text,
} }
self.db.db.messages.insert_one(message_data)
except Exception as e:
print(f"\033[1;31m[错误]\033[0m 存储消息失败: {e}")
# 如果需要其他存储相关的函数,可以在这里添加 self.db.db.messages.insert_one(message_data)
except Exception:
logger.exception(f"存储消息失败")
# 如果需要其他存储相关的函数,可以在这里添加

View File

@@ -4,9 +4,11 @@ from nonebot import get_driver
from ..models.utils_model import LLM_request from ..models.utils_model import LLM_request
from .config import global_config from .config import global_config
from loguru import logger
driver = get_driver() driver = get_driver()
config = driver.config config = driver.config
class TopicIdentifier: class TopicIdentifier:
def __init__(self): def __init__(self):
@@ -23,19 +25,20 @@ class TopicIdentifier:
# 使用 LLM_request 类进行请求 # 使用 LLM_request 类进行请求
topic, _ = await self.llm_topic_judge.generate_response(prompt) topic, _ = await self.llm_topic_judge.generate_response(prompt)
if not topic: if not topic:
print("\033[1;31m[错误]\033[0m LLM API 返回为空") logger.error("LLM API 返回为空")
return None return None
# 直接在这里处理主题解析 # 直接在这里处理主题解析
if not topic or topic == "无主题": if not topic or topic == "无主题":
return None return None
# 解析主题字符串为列表 # 解析主题字符串为列表
topic_list = [t.strip() for t in topic.split(",") if t.strip()] topic_list = [t.strip() for t in topic.split(",") if t.strip()]
print(f"\033[1;32m[主题识别]\033[0m 主题: {topic_list}") logger.info(f"主题: {topic_list}")
return topic_list if topic_list else None return topic_list if topic_list else None
topic_identifier = TopicIdentifier()
topic_identifier = TopicIdentifier()

View File

@@ -7,6 +7,7 @@ from typing import Dict, List
import jieba import jieba
import numpy as np import numpy as np
from nonebot import get_driver from nonebot import get_driver
from loguru import logger
from ..models.utils_model import LLM_request from ..models.utils_model import LLM_request
from ..utils.typo_generator import ChineseTypoGenerator from ..utils.typo_generator import ChineseTypoGenerator
@@ -39,16 +40,16 @@ def combine_messages(messages: List[Message]) -> str:
def db_message_to_str(message_dict: Dict) -> str:
    """Render one stored message record as a single chat-log line.

    Args:
        message_dict: Message record. ``time`` is required and is passed to
            ``time.localtime``. ``user_id``, ``user_nickname``,
            ``user_cardname`` and ``processed_plain_text`` are used when
            present.

    Returns:
        A line of the form ``[MM-DD HH:MM:SS] name: content`` terminated by
        a newline.

    Raises:
        KeyError: If ``time`` is missing from the record.
    """
    logger.debug(f"message_dict: {message_dict}")
    time_str = time.strftime("%m-%d %H:%M:%S", time.localtime(message_dict["time"]))
    nickname = message_dict.get("user_nickname", "")
    try:
        name = "[(%s)%s]%s" % (
            message_dict['user_id'], nickname, message_dict.get("user_cardname", ""))
    except KeyError:
        # Only a missing user_id can land here. The old bare except also
        # swallowed KeyboardInterrupt/SystemExit, and its fallback indexed
        # the same missing key again, so it could raise a second KeyError.
        name = nickname or "用户"
    content = message_dict.get("processed_plain_text", "")
    result = f"[{time_str}] {name}: {content}\n"
    logger.debug(f"result: {result}")
    return result
@@ -182,7 +183,7 @@ async def get_recent_group_messages(db, group_id: int, limit: int = 12) -> list:
await msg.initialize() await msg.initialize()
message_objects.append(msg) message_objects.append(msg)
except KeyError: except KeyError:
print("[WARNING] 数据库中存在无效的消息") logger.warning("数据库中存在无效的消息")
continue continue
# 按时间正序排列 # 按时间正序排列
@@ -298,11 +299,10 @@ def split_into_sentences_w_remove_punctuation(text: str) -> List[str]:
sentence = sentence.replace('', ' ').replace(',', ' ') sentence = sentence.replace('', ' ').replace(',', ' ')
sentences_done.append(sentence) sentences_done.append(sentence)
print(f"处理后的句子: {sentences_done}") logger.info(f"处理后的句子: {sentences_done}")
return sentences_done return sentences_done
def random_remove_punctuation(text: str) -> str: def random_remove_punctuation(text: str) -> str:
"""随机处理标点符号,模拟人类打字习惯 """随机处理标点符号,模拟人类打字习惯
@@ -330,11 +330,10 @@ def random_remove_punctuation(text: str) -> str:
return result return result
def process_llm_response(text: str) -> List[str]: def process_llm_response(text: str) -> List[str]:
# processed_response = process_text_with_typos(content) # processed_response = process_text_with_typos(content)
if len(text) > 200: if len(text) > 200:
print(f"回复过长 ({len(text)} 字符),返回默认回复") logger.warning(f"回复过长 ({len(text)} 字符),返回默认回复")
return ['懒得说'] return ['懒得说']
# 处理长消息 # 处理长消息
typo_generator = ChineseTypoGenerator( typo_generator = ChineseTypoGenerator(
@@ -354,9 +353,9 @@ def process_llm_response(text: str) -> List[str]:
else: else:
sentences.append(sentence) sentences.append(sentence)
# 检查分割后的消息数量是否过多超过3条 # 检查分割后的消息数量是否过多超过3条
if len(sentences) > 5: if len(sentences) > 5:
print(f"分割后消息数量过多 ({len(sentences)} 条),返回默认回复") logger.warning(f"分割后消息数量过多 ({len(sentences)} 条),返回默认回复")
return [f'{global_config.BOT_NICKNAME}不知道哦'] return [f'{global_config.BOT_NICKNAME}不知道哦']
return sentences return sentences
@@ -378,15 +377,15 @@ def calculate_typing_time(input_string: str, chinese_time: float = 0.4, english_
mood_arousal = mood_manager.current_mood.arousal mood_arousal = mood_manager.current_mood.arousal
# 映射到0.5到2倍的速度系数 # 映射到0.5到2倍的速度系数
typing_speed_multiplier = 1.5 ** mood_arousal # 唤醒度为1时速度翻倍,为-1时速度减半 typing_speed_multiplier = 1.5 ** mood_arousal # 唤醒度为1时速度翻倍,为-1时速度减半
chinese_time *= 1/typing_speed_multiplier chinese_time *= 1 / typing_speed_multiplier
english_time *= 1/typing_speed_multiplier english_time *= 1 / typing_speed_multiplier
# 计算中文字符数 # 计算中文字符数
chinese_chars = sum(1 for char in input_string if '\u4e00' <= char <= '\u9fff') chinese_chars = sum(1 for char in input_string if '\u4e00' <= char <= '\u9fff')
# 如果只有一个中文字符使用3倍时间 # 如果只有一个中文字符使用3倍时间
if chinese_chars == 1 and len(input_string.strip()) == 1: if chinese_chars == 1 and len(input_string.strip()) == 1:
return chinese_time * 3 + 0.3 # 加上回车时间 return chinese_time * 3 + 0.3 # 加上回车时间
# 正常计算所有字符的输入时间 # 正常计算所有字符的输入时间
total_time = 0.0 total_time = 0.0
for char in input_string: for char in input_string:

View File

@@ -1,5 +1,6 @@
import asyncio import asyncio
from .config import global_config from .config import global_config
from loguru import logger
class WillingManager: class WillingManager:
@@ -30,16 +31,16 @@ class WillingManager:
# print(f"初始意愿: {current_willing}") # print(f"初始意愿: {current_willing}")
if is_mentioned_bot and current_willing < 1.0: if is_mentioned_bot and current_willing < 1.0:
current_willing += 0.9 current_willing += 0.9
print(f"被提及, 当前意愿: {current_willing}") logger.info(f"被提及, 当前意愿: {current_willing}")
elif is_mentioned_bot: elif is_mentioned_bot:
current_willing += 0.05 current_willing += 0.05
print(f"被重复提及, 当前意愿: {current_willing}") logger.info(f"被重复提及, 当前意愿: {current_willing}")
if is_emoji: if is_emoji:
current_willing *= 0.1 current_willing *= 0.1
print(f"表情包, 当前意愿: {current_willing}") logger.info(f"表情包, 当前意愿: {current_willing}")
print(f"放大系数_interested_rate: {global_config.response_interested_rate_amplifier}") logger.debug(f"放大系数_interested_rate: {global_config.response_interested_rate_amplifier}")
interested_rate *= global_config.response_interested_rate_amplifier #放大回复兴趣度 interested_rate *= global_config.response_interested_rate_amplifier #放大回复兴趣度
if interested_rate > 0.4: if interested_rate > 0.4:
# print(f"兴趣度: {interested_rate}, 当前意愿: {current_willing}") # print(f"兴趣度: {interested_rate}, 当前意愿: {current_willing}")

View File

@@ -7,6 +7,7 @@ import jieba
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import networkx as nx import networkx as nx
from dotenv import load_dotenv from dotenv import load_dotenv
from loguru import logger
sys.path.append("C:/GitHub/MaiMBot") # 添加项目根目录到 Python 路径 sys.path.append("C:/GitHub/MaiMBot") # 添加项目根目录到 Python 路径
from src.common.database import Database # 使用正确的导入语法 from src.common.database import Database # 使用正确的导入语法
@@ -15,15 +16,15 @@ from src.common.database import Database # 使用正确的导入语法
env_path = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(__file__)))), '.env.dev') env_path = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(__file__)))), '.env.dev')
load_dotenv(env_path) load_dotenv(env_path)
class Memory_graph: class Memory_graph:
def __init__(self): def __init__(self):
self.G = nx.Graph() # 使用 networkx 的图结构 self.G = nx.Graph() # 使用 networkx 的图结构
self.db = Database.get_instance() self.db = Database.get_instance()
def connect_dot(self, concept1, concept2): def connect_dot(self, concept1, concept2):
self.G.add_edge(concept1, concept2) self.G.add_edge(concept1, concept2)
def add_dot(self, concept, memory): def add_dot(self, concept, memory):
if concept in self.G: if concept in self.G:
# 如果节点已存在,将新记忆添加到现有列表中 # 如果节点已存在,将新记忆添加到现有列表中
@@ -37,7 +38,7 @@ class Memory_graph:
else: else:
# 如果是新节点,创建新的记忆列表 # 如果是新节点,创建新的记忆列表
self.G.add_node(concept, memory_items=[memory]) self.G.add_node(concept, memory_items=[memory])
def get_dot(self, concept): def get_dot(self, concept):
# 检查节点是否存在于图中 # 检查节点是否存在于图中
if concept in self.G: if concept in self.G:
@@ -45,20 +46,20 @@ class Memory_graph:
node_data = self.G.nodes[concept] node_data = self.G.nodes[concept]
# print(node_data) # print(node_data)
# 创建新的Memory_dot对象 # 创建新的Memory_dot对象
return concept,node_data return concept, node_data
return None return None
def get_related_item(self, topic, depth=1): def get_related_item(self, topic, depth=1):
if topic not in self.G: if topic not in self.G:
return [], [] return [], []
first_layer_items = [] first_layer_items = []
second_layer_items = [] second_layer_items = []
# 获取相邻节点 # 获取相邻节点
neighbors = list(self.G.neighbors(topic)) neighbors = list(self.G.neighbors(topic))
# print(f"第一层: {topic}") # print(f"第一层: {topic}")
# 获取当前节点的记忆项 # 获取当前节点的记忆项
node_data = self.get_dot(topic) node_data = self.get_dot(topic)
if node_data: if node_data:
@@ -69,7 +70,7 @@ class Memory_graph:
first_layer_items.extend(memory_items) first_layer_items.extend(memory_items)
else: else:
first_layer_items.append(memory_items) first_layer_items.append(memory_items)
# 只在depth=2时获取第二层记忆 # 只在depth=2时获取第二层记忆
if depth >= 2: if depth >= 2:
# 获取相邻节点的记忆项 # 获取相邻节点的记忆项
@@ -84,42 +85,44 @@ class Memory_graph:
second_layer_items.extend(memory_items) second_layer_items.extend(memory_items)
else: else:
second_layer_items.append(memory_items) second_layer_items.append(memory_items)
return first_layer_items, second_layer_items return first_layer_items, second_layer_items
def store_memory(self): def store_memory(self):
for node in self.G.nodes(): for node in self.G.nodes():
dot_data = { dot_data = {
"concept": node "concept": node
} }
self.db.db.store_memory_dots.insert_one(dot_data) self.db.db.store_memory_dots.insert_one(dot_data)
@property @property
def dots(self): def dots(self):
# 返回所有节点对应的 Memory_dot 对象 # 返回所有节点对应的 Memory_dot 对象
return [self.get_dot(node) for node in self.G.nodes()] return [self.get_dot(node) for node in self.G.nodes()]
def get_random_chat_from_db(self, length: int, timestamp: str): def get_random_chat_from_db(self, length: int, timestamp: str):
# 从数据库中根据时间戳获取离其最近的聊天记录 # 从数据库中根据时间戳获取离其最近的聊天记录
chat_text = '' chat_text = ''
closest_record = self.db.db.messages.find_one({"time": {"$lte": timestamp}}, sort=[('time', -1)]) # 调试输出 closest_record = self.db.db.messages.find_one({"time": {"$lte": timestamp}}, sort=[('time', -1)]) # 调试输出
print(f"距离time最近的消息时间: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(int(closest_record['time'])))}") logger.info(
f"距离time最近的消息时间: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(int(closest_record['time'])))}")
if closest_record: if closest_record:
closest_time = closest_record['time'] closest_time = closest_record['time']
group_id = closest_record['group_id'] # 获取groupid group_id = closest_record['group_id'] # 获取groupid
# 获取该时间戳之后的length条消息且groupid相同 # 获取该时间戳之后的length条消息且groupid相同
chat_record = list(self.db.db.messages.find({"time": {"$gt": closest_time}, "group_id": group_id}).sort('time', 1).limit(length)) chat_record = list(
self.db.db.messages.find({"time": {"$gt": closest_time}, "group_id": group_id}).sort('time', 1).limit(
length))
for record in chat_record: for record in chat_record:
time_str = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(int(record['time']))) time_str = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(int(record['time'])))
try: try:
displayname="[(%s)%s]%s" % (record["user_id"],record["user_nickname"],record["user_cardname"]) displayname = "[(%s)%s]%s" % (record["user_id"], record["user_nickname"], record["user_cardname"])
except: except:
displayname=record["user_nickname"] or "用户" + str(record["user_id"]) displayname = record["user_nickname"] or "用户" + str(record["user_id"])
chat_text += f'[{time_str}] {displayname}: {record["processed_plain_text"]}\n' # 添加发送者和时间信息 chat_text += f'[{time_str}] {displayname}: {record["processed_plain_text"]}\n' # 添加发送者和时间信息
return chat_text return chat_text
return [] # 如果没有找到记录,返回空列表 return [] # 如果没有找到记录,返回空列表
def save_graph_to_db(self): def save_graph_to_db(self):
@@ -166,53 +169,54 @@ def main():
password=os.getenv("MONGODB_PASSWORD", ""), password=os.getenv("MONGODB_PASSWORD", ""),
auth_source=os.getenv("MONGODB_AUTH_SOURCE", "") auth_source=os.getenv("MONGODB_AUTH_SOURCE", "")
) )
memory_graph = Memory_graph() memory_graph = Memory_graph()
memory_graph.load_graph_from_db() memory_graph.load_graph_from_db()
# 只显示一次优化后的图形 # 只显示一次优化后的图形
visualize_graph_lite(memory_graph) visualize_graph_lite(memory_graph)
while True: while True:
query = input("请输入新的查询概念(输入'退出'以结束):") query = input("请输入新的查询概念(输入'退出'以结束):")
if query.lower() == '退出': if query.lower() == '退出':
break break
first_layer_items, second_layer_items = memory_graph.get_related_item(query) first_layer_items, second_layer_items = memory_graph.get_related_item(query)
if first_layer_items or second_layer_items: if first_layer_items or second_layer_items:
print("\n第一层记忆:") logger.debug("第一层记忆:")
for item in first_layer_items: for item in first_layer_items:
print(item) logger.debug(item)
print("\n第二层记忆:") logger.debug("第二层记忆:")
for item in second_layer_items: for item in second_layer_items:
print(item) logger.debug(item)
else: else:
print("未找到相关记忆。") logger.debug("未找到相关记忆。")
def segment_text(text): def segment_text(text):
seg_text = list(jieba.cut(text)) seg_text = list(jieba.cut(text))
return seg_text return seg_text
def find_topic(text, topic_num): def find_topic(text, topic_num):
prompt = f'这是一段文字:{text}。请你从这段话中总结出{topic_num}个话题,帮我列出来,用逗号隔开,尽可能精简。只需要列举{topic_num}个话题就好,不要告诉我其他内容。' prompt = f'这是一段文字:{text}。请你从这段话中总结出{topic_num}个话题,帮我列出来,用逗号隔开,尽可能精简。只需要列举{topic_num}个话题就好,不要告诉我其他内容。'
return prompt return prompt
def topic_what(text, topic): def topic_what(text, topic):
prompt = f'这是一段文字:{text}。我想知道这记忆里有什么关于{topic}的话题,帮我总结成一句自然的话,可以包含时间和人物。只输出这句话就好' prompt = f'这是一段文字:{text}。我想知道这记忆里有什么关于{topic}的话题,帮我总结成一句自然的话,可以包含时间和人物。只输出这句话就好'
return prompt return prompt
def visualize_graph_lite(memory_graph: Memory_graph, color_by_memory: bool = False): def visualize_graph_lite(memory_graph: Memory_graph, color_by_memory: bool = False):
# 设置中文字体 # 设置中文字体
plt.rcParams['font.sans-serif'] = ['SimHei'] # 用来正常显示中文标签 plt.rcParams['font.sans-serif'] = ['SimHei'] # 用来正常显示中文标签
plt.rcParams['axes.unicode_minus'] = False # 用来正常显示负号 plt.rcParams['axes.unicode_minus'] = False # 用来正常显示负号
G = memory_graph.G G = memory_graph.G
# 创建一个新图用于可视化 # 创建一个新图用于可视化
H = G.copy() H = G.copy()
# 移除只有一条记忆的节点和连接数少于3的节点 # 移除只有一条记忆的节点和连接数少于3的节点
nodes_to_remove = [] nodes_to_remove = []
for node in H.nodes(): for node in H.nodes():
@@ -221,14 +225,14 @@ def visualize_graph_lite(memory_graph: Memory_graph, color_by_memory: bool = Fal
degree = H.degree(node) degree = H.degree(node)
if memory_count < 3 or degree < 2: # 改为小于2而不是小于等于2 if memory_count < 3 or degree < 2: # 改为小于2而不是小于等于2
nodes_to_remove.append(node) nodes_to_remove.append(node)
H.remove_nodes_from(nodes_to_remove) H.remove_nodes_from(nodes_to_remove)
# 如果过滤后没有节点,则返回 # 如果过滤后没有节点,则返回
if len(H.nodes()) == 0: if len(H.nodes()) == 0:
print("过滤后没有符合条件的节点可显示") logger.debug("过滤后没有符合条件的节点可显示")
return return
# 保存图到本地 # 保存图到本地
# nx.write_gml(H, "memory_graph.gml") # 保存为 GML 格式 # nx.write_gml(H, "memory_graph.gml") # 保存为 GML 格式
@@ -236,7 +240,7 @@ def visualize_graph_lite(memory_graph: Memory_graph, color_by_memory: bool = Fal
node_colors = [] node_colors = []
node_sizes = [] node_sizes = []
nodes = list(H.nodes()) nodes = list(H.nodes())
# 获取最大记忆数和最大度数用于归一化 # 获取最大记忆数和最大度数用于归一化
max_memories = 1 max_memories = 1
max_degree = 1 max_degree = 1
@@ -246,7 +250,7 @@ def visualize_graph_lite(memory_graph: Memory_graph, color_by_memory: bool = Fal
degree = H.degree(node) degree = H.degree(node)
max_memories = max(max_memories, memory_count) max_memories = max(max_memories, memory_count)
max_degree = max(max_degree, degree) max_degree = max(max_degree, degree)
# 计算每个节点的大小和颜色 # 计算每个节点的大小和颜色
for node in nodes: for node in nodes:
# 计算节点大小(基于记忆数量) # 计算节点大小(基于记忆数量)
@@ -254,9 +258,9 @@ def visualize_graph_lite(memory_graph: Memory_graph, color_by_memory: bool = Fal
memory_count = len(memory_items) if isinstance(memory_items, list) else (1 if memory_items else 0) memory_count = len(memory_items) if isinstance(memory_items, list) else (1 if memory_items else 0)
# 使用指数函数使变化更明显 # 使用指数函数使变化更明显
ratio = memory_count / max_memories ratio = memory_count / max_memories
size = 500 + 5000 * (ratio ) # 使用1.5次方函数使差异不那么明显 size = 500 + 5000 * (ratio) # 使用1.5次方函数使差异不那么明显
node_sizes.append(size) node_sizes.append(size)
# 计算节点颜色(基于连接数) # 计算节点颜色(基于连接数)
degree = H.degree(node) degree = H.degree(node)
# 红色分量随着度数增加而增加 # 红色分量随着度数增加而增加
@@ -267,26 +271,25 @@ def visualize_graph_lite(memory_graph: Memory_graph, color_by_memory: bool = Fal
# blue = 1 # blue = 1
color = (red, 0.1, blue) color = (red, 0.1, blue)
node_colors.append(color) node_colors.append(color)
# 绘制图形 # 绘制图形
plt.figure(figsize=(12, 8)) plt.figure(figsize=(12, 8))
pos = nx.spring_layout(H, k=1, iterations=50) # 增加k值使节点分布更开 pos = nx.spring_layout(H, k=1, iterations=50) # 增加k值使节点分布更开
nx.draw(H, pos, nx.draw(H, pos,
with_labels=True, with_labels=True,
node_color=node_colors, node_color=node_colors,
node_size=node_sizes, node_size=node_sizes,
font_size=10, font_size=10,
font_family='SimHei', font_family='SimHei',
font_weight='bold', font_weight='bold',
edge_color='gray', edge_color='gray',
width=0.5, width=0.5,
alpha=0.9) alpha=0.9)
title = '记忆图谱可视化 - 节点大小表示记忆数量,颜色表示连接数' title = '记忆图谱可视化 - 节点大小表示记忆数量,颜色表示连接数'
plt.title(title, fontsize=16, fontfamily='SimHei') plt.title(title, fontsize=16, fontfamily='SimHei')
plt.show() plt.show()
if __name__ == "__main__": if __name__ == "__main__":
main() main()

View File

@@ -7,6 +7,7 @@ import time
import jieba import jieba
import networkx as nx import networkx as nx
from loguru import logger
from ...common.database import Database # 使用正确的导入语法 from ...common.database import Database # 使用正确的导入语法
from ..chat.config import global_config from ..chat.config import global_config
from ..chat.utils import ( from ..chat.utils import (
@@ -22,7 +23,7 @@ class Memory_graph:
def __init__(self): def __init__(self):
self.G = nx.Graph() # 使用 networkx 的图结构 self.G = nx.Graph() # 使用 networkx 的图结构
self.db = Database.get_instance() self.db = Database.get_instance()
def connect_dot(self, concept1, concept2): def connect_dot(self, concept1, concept2):
# 如果边已存在,增加 strength # 如果边已存在,增加 strength
if self.G.has_edge(concept1, concept2): if self.G.has_edge(concept1, concept2):
@@ -30,7 +31,7 @@ class Memory_graph:
else: else:
# 如果是新边,初始化 strength 为 1 # 如果是新边,初始化 strength 为 1
self.G.add_edge(concept1, concept2, strength=1) self.G.add_edge(concept1, concept2, strength=1)
def add_dot(self, concept, memory): def add_dot(self, concept, memory):
if concept in self.G: if concept in self.G:
# 如果节点已存在,将新记忆添加到现有列表中 # 如果节点已存在,将新记忆添加到现有列表中
@@ -44,7 +45,7 @@ class Memory_graph:
else: else:
# 如果是新节点,创建新的记忆列表 # 如果是新节点,创建新的记忆列表
self.G.add_node(concept, memory_items=[memory]) self.G.add_node(concept, memory_items=[memory])
def get_dot(self, concept): def get_dot(self, concept):
# 检查节点是否存在于图中 # 检查节点是否存在于图中
if concept in self.G: if concept in self.G:
@@ -56,13 +57,13 @@ class Memory_graph:
def get_related_item(self, topic, depth=1): def get_related_item(self, topic, depth=1):
if topic not in self.G: if topic not in self.G:
return [], [] return [], []
first_layer_items = [] first_layer_items = []
second_layer_items = [] second_layer_items = []
# 获取相邻节点 # 获取相邻节点
neighbors = list(self.G.neighbors(topic)) neighbors = list(self.G.neighbors(topic))
# 获取当前节点的记忆项 # 获取当前节点的记忆项
node_data = self.get_dot(topic) node_data = self.get_dot(topic)
if node_data: if node_data:
@@ -73,7 +74,7 @@ class Memory_graph:
first_layer_items.extend(memory_items) first_layer_items.extend(memory_items)
else: else:
first_layer_items.append(memory_items) first_layer_items.append(memory_items)
# 只在depth=2时获取第二层记忆 # 只在depth=2时获取第二层记忆
if depth >= 2: if depth >= 2:
# 获取相邻节点的记忆项 # 获取相邻节点的记忆项
@@ -87,9 +88,9 @@ class Memory_graph:
second_layer_items.extend(memory_items) second_layer_items.extend(memory_items)
else: else:
second_layer_items.append(memory_items) second_layer_items.append(memory_items)
return first_layer_items, second_layer_items return first_layer_items, second_layer_items
@property @property
def dots(self): def dots(self):
# 返回所有节点对应的 Memory_dot 对象 # 返回所有节点对应的 Memory_dot 对象
@@ -99,43 +100,43 @@ class Memory_graph:
"""随机删除指定话题中的一条记忆,如果话题没有记忆则移除该话题节点""" """随机删除指定话题中的一条记忆,如果话题没有记忆则移除该话题节点"""
if topic not in self.G: if topic not in self.G:
return None return None
# 获取话题节点数据 # 获取话题节点数据
node_data = self.G.nodes[topic] node_data = self.G.nodes[topic]
# 如果节点存在memory_items # 如果节点存在memory_items
if 'memory_items' in node_data: if 'memory_items' in node_data:
memory_items = node_data['memory_items'] memory_items = node_data['memory_items']
# 确保memory_items是列表 # 确保memory_items是列表
if not isinstance(memory_items, list): if not isinstance(memory_items, list):
memory_items = [memory_items] if memory_items else [] memory_items = [memory_items] if memory_items else []
# 如果有记忆项可以删除 # 如果有记忆项可以删除
if memory_items: if memory_items:
# 随机选择一个记忆项删除 # 随机选择一个记忆项删除
removed_item = random.choice(memory_items) removed_item = random.choice(memory_items)
memory_items.remove(removed_item) memory_items.remove(removed_item)
# 更新节点的记忆项 # 更新节点的记忆项
if memory_items: if memory_items:
self.G.nodes[topic]['memory_items'] = memory_items self.G.nodes[topic]['memory_items'] = memory_items
else: else:
# 如果没有记忆项了,删除整个节点 # 如果没有记忆项了,删除整个节点
self.G.remove_node(topic) self.G.remove_node(topic)
return removed_item return removed_item
return None return None
# 海马体 # 海马体
class Hippocampus: class Hippocampus:
def __init__(self,memory_graph:Memory_graph): def __init__(self, memory_graph: Memory_graph):
self.memory_graph = memory_graph self.memory_graph = memory_graph
self.llm_topic_judge = LLM_request(model = global_config.llm_topic_judge,temperature=0.5) self.llm_topic_judge = LLM_request(model=global_config.llm_topic_judge, temperature=0.5)
self.llm_summary_by_topic = LLM_request(model = global_config.llm_summary_by_topic,temperature=0.5) self.llm_summary_by_topic = LLM_request(model=global_config.llm_summary_by_topic, temperature=0.5)
def get_all_node_names(self) -> list: def get_all_node_names(self) -> list:
"""获取记忆图中所有节点的名字列表 """获取记忆图中所有节点的名字列表
@@ -156,8 +157,8 @@ class Hippocampus:
"""计算边的特征值""" """计算边的特征值"""
nodes = sorted([source, target]) nodes = sorted([source, target])
return hash(f"{nodes[0]}:{nodes[1]}") return hash(f"{nodes[0]}:{nodes[1]}")
def get_memory_sample(self, chat_size=20, time_frequency:dict={'near':2,'mid':4,'far':3}): def get_memory_sample(self, chat_size=20, time_frequency: dict = {'near': 2, 'mid': 4, 'far': 3}):
"""获取记忆样本 """获取记忆样本
Returns: Returns:
@@ -165,26 +166,26 @@ class Hippocampus:
""" """
current_timestamp = datetime.datetime.now().timestamp() current_timestamp = datetime.datetime.now().timestamp()
chat_samples = [] chat_samples = []
# 短期1h 中期4h 长期24h # 短期1h 中期4h 长期24h
for _ in range(time_frequency.get('near')): for _ in range(time_frequency.get('near')):
random_time = current_timestamp - random.randint(1, 3600) random_time = current_timestamp - random.randint(1, 3600)
messages = get_cloest_chat_from_db(db=self.memory_graph.db, length=chat_size, timestamp=random_time) messages = get_cloest_chat_from_db(db=self.memory_graph.db, length=chat_size, timestamp=random_time)
if messages: if messages:
chat_samples.append(messages) chat_samples.append(messages)
for _ in range(time_frequency.get('mid')): for _ in range(time_frequency.get('mid')):
random_time = current_timestamp - random.randint(3600, 3600*4) random_time = current_timestamp - random.randint(3600, 3600 * 4)
messages = get_cloest_chat_from_db(db=self.memory_graph.db, length=chat_size, timestamp=random_time) messages = get_cloest_chat_from_db(db=self.memory_graph.db, length=chat_size, timestamp=random_time)
if messages: if messages:
chat_samples.append(messages) chat_samples.append(messages)
for _ in range(time_frequency.get('far')): for _ in range(time_frequency.get('far')):
random_time = current_timestamp - random.randint(3600*4, 3600*24) random_time = current_timestamp - random.randint(3600 * 4, 3600 * 24)
messages = get_cloest_chat_from_db(db=self.memory_graph.db, length=chat_size, timestamp=random_time) messages = get_cloest_chat_from_db(db=self.memory_graph.db, length=chat_size, timestamp=random_time)
if messages: if messages:
chat_samples.append(messages) chat_samples.append(messages)
return chat_samples return chat_samples
async def memory_compress(self, messages: list, compress_rate=0.1): async def memory_compress(self, messages: list, compress_rate=0.1):
@@ -199,17 +200,17 @@ class Hippocampus:
""" """
if not messages: if not messages:
return set() return set()
# 合并消息文本,同时保留时间信息 # 合并消息文本,同时保留时间信息
input_text = "" input_text = ""
time_info = "" time_info = ""
# 计算最早和最晚时间 # 计算最早和最晚时间
earliest_time = min(msg['time'] for msg in messages) earliest_time = min(msg['time'] for msg in messages)
latest_time = max(msg['time'] for msg in messages) latest_time = max(msg['time'] for msg in messages)
earliest_dt = datetime.datetime.fromtimestamp(earliest_time) earliest_dt = datetime.datetime.fromtimestamp(earliest_time)
latest_dt = datetime.datetime.fromtimestamp(latest_time) latest_dt = datetime.datetime.fromtimestamp(latest_time)
# 如果是同一年 # 如果是同一年
if earliest_dt.year == latest_dt.year: if earliest_dt.year == latest_dt.year:
earliest_str = earliest_dt.strftime("%m-%d %H:%M:%S") earliest_str = earliest_dt.strftime("%m-%d %H:%M:%S")
@@ -217,54 +218,57 @@ class Hippocampus:
time_info += f"是在{earliest_dt.year}年,{earliest_str}{latest_str} 的对话:\n" time_info += f"是在{earliest_dt.year}年,{earliest_str}{latest_str} 的对话:\n"
else: else:
earliest_str = earliest_dt.strftime("%Y-%m-%d %H:%M:%S") earliest_str = earliest_dt.strftime("%Y-%m-%d %H:%M:%S")
latest_str = latest_dt.strftime("%Y-%m-%d %H:%M:%S") latest_str = latest_dt.strftime("%Y-%m-%d %H:%M:%S")
time_info += f"是从 {earliest_str}{latest_str} 的对话:\n" time_info += f"是从 {earliest_str}{latest_str} 的对话:\n"
for msg in messages: for msg in messages:
input_text += f"{msg['text']}\n" input_text += f"{msg['text']}\n"
print(input_text) logger.debug(input_text)
topic_num = self.calculate_topic_num(input_text, compress_rate) topic_num = self.calculate_topic_num(input_text, compress_rate)
topics_response = await self.llm_topic_judge.generate_response(self.find_topic_llm(input_text, topic_num)) topics_response = await self.llm_topic_judge.generate_response(self.find_topic_llm(input_text, topic_num))
# 过滤topics # 过滤topics
filter_keywords = global_config.memory_ban_words filter_keywords = global_config.memory_ban_words
topics = [topic.strip() for topic in topics_response[0].replace("", ",").replace("", ",").replace(" ", ",").split(",") if topic.strip()] topics = [topic.strip() for topic in
topics_response[0].replace("", ",").replace("", ",").replace(" ", ",").split(",") if topic.strip()]
filtered_topics = [topic for topic in topics if not any(keyword in topic for keyword in filter_keywords)] filtered_topics = [topic for topic in topics if not any(keyword in topic for keyword in filter_keywords)]
print(f"过滤后话题: {filtered_topics}") logger.info(f"过滤后话题: {filtered_topics}")
# 创建所有话题的请求任务 # 创建所有话题的请求任务
tasks = [] tasks = []
for topic in filtered_topics: for topic in filtered_topics:
topic_what_prompt = self.topic_what(input_text, topic, time_info) topic_what_prompt = self.topic_what(input_text, topic, time_info)
task = self.llm_summary_by_topic.generate_response_async(topic_what_prompt) task = self.llm_summary_by_topic.generate_response_async(topic_what_prompt)
tasks.append((topic.strip(), task)) tasks.append((topic.strip(), task))
# 等待所有任务完成 # 等待所有任务完成
compressed_memory = set() compressed_memory = set()
for topic, task in tasks: for topic, task in tasks:
response = await task response = await task
if response: if response:
compressed_memory.add((topic, response[0])) compressed_memory.add((topic, response[0]))
return compressed_memory return compressed_memory
def calculate_topic_num(self,text, compress_rate): def calculate_topic_num(self, text, compress_rate):
"""计算文本的话题数量""" """计算文本的话题数量"""
information_content = calculate_information_content(text) information_content = calculate_information_content(text)
topic_by_length = text.count('\n')*compress_rate topic_by_length = text.count('\n') * compress_rate
topic_by_information_content = max(1, min(5, int((information_content-3) * 2))) topic_by_information_content = max(1, min(5, int((information_content - 3) * 2)))
topic_num = int((topic_by_length + topic_by_information_content)/2) topic_num = int((topic_by_length + topic_by_information_content) / 2)
print(f"topic_by_length: {topic_by_length}, topic_by_information_content: {topic_by_information_content}, topic_num: {topic_num}") logger.debug(
f"topic_by_length: {topic_by_length}, topic_by_information_content: {topic_by_information_content}, "
f"topic_num: {topic_num}")
return topic_num return topic_num
async def operation_build_memory(self,chat_size=20): async def operation_build_memory(self, chat_size=20):
# 最近消息获取频率 # 最近消息获取频率
time_frequency = {'near':2,'mid':4,'far':2} time_frequency = {'near': 2, 'mid': 4, 'far': 2}
memory_sample = self.get_memory_sample(chat_size,time_frequency) memory_sample = self.get_memory_sample(chat_size, time_frequency)
for i, input_text in enumerate(memory_sample, 1): for i, input_text in enumerate(memory_sample, 1):
# 加载进度可视化 # 加载进度可视化
all_topics = [] all_topics = []
@@ -272,24 +276,24 @@ class Hippocampus:
bar_length = 30 bar_length = 30
filled_length = int(bar_length * i // len(memory_sample)) filled_length = int(bar_length * i // len(memory_sample))
bar = '' * filled_length + '-' * (bar_length - filled_length) bar = '' * filled_length + '-' * (bar_length - filled_length)
print(f"\n进度: [{bar}] {progress:.1f}% ({i}/{len(memory_sample)})") logger.debug(f"进度: [{bar}] {progress:.1f}% ({i}/{len(memory_sample)})")
# 生成压缩后记忆 ,表现为 (话题,记忆) 的元组 # 生成压缩后记忆 ,表现为 (话题,记忆) 的元组
compressed_memory = set() compressed_memory = set()
compress_rate = 0.1 compress_rate = 0.1
compressed_memory = await self.memory_compress(input_text, compress_rate) compressed_memory = await self.memory_compress(input_text, compress_rate)
print(f"\033[1;33m压缩后记忆数量\033[0m: {len(compressed_memory)}") logger.info(f"压缩后记忆数量: {len(compressed_memory)}")
# 将记忆加入到图谱中 # 将记忆加入到图谱中
for topic, memory in compressed_memory: for topic, memory in compressed_memory:
print(f"\033[1;32m添加节点\033[0m: {topic}") logger.info(f"添加节点: {topic}")
self.memory_graph.add_dot(topic, memory) self.memory_graph.add_dot(topic, memory)
all_topics.append(topic) # 收集所有话题 all_topics.append(topic) # 收集所有话题
for i in range(len(all_topics)): for i in range(len(all_topics)):
for j in range(i + 1, len(all_topics)): for j in range(i + 1, len(all_topics)):
print(f"\033[1;32m连接节点\033[0m: {all_topics[i]}{all_topics[j]}") logger.info(f"连接节点: {all_topics[i]}{all_topics[j]}")
self.memory_graph.connect_dot(all_topics[i], all_topics[j]) self.memory_graph.connect_dot(all_topics[i], all_topics[j])
self.sync_memory_to_db() self.sync_memory_to_db()
def sync_memory_to_db(self): def sync_memory_to_db(self):
@@ -297,19 +301,19 @@ class Hippocampus:
# 获取数据库中所有节点和内存中所有节点 # 获取数据库中所有节点和内存中所有节点
db_nodes = list(self.memory_graph.db.db.graph_data.nodes.find()) db_nodes = list(self.memory_graph.db.db.graph_data.nodes.find())
memory_nodes = list(self.memory_graph.G.nodes(data=True)) memory_nodes = list(self.memory_graph.G.nodes(data=True))
# 转换数据库节点为字典格式,方便查找 # 转换数据库节点为字典格式,方便查找
db_nodes_dict = {node['concept']: node for node in db_nodes} db_nodes_dict = {node['concept']: node for node in db_nodes}
# 检查并更新节点 # 检查并更新节点
for concept, data in memory_nodes: for concept, data in memory_nodes:
memory_items = data.get('memory_items', []) memory_items = data.get('memory_items', [])
if not isinstance(memory_items, list): if not isinstance(memory_items, list):
memory_items = [memory_items] if memory_items else [] memory_items = [memory_items] if memory_items else []
# 计算内存中节点的特征值 # 计算内存中节点的特征值
memory_hash = self.calculate_node_hash(concept, memory_items) memory_hash = self.calculate_node_hash(concept, memory_items)
if concept not in db_nodes_dict: if concept not in db_nodes_dict:
# 数据库中缺少的节点,添加 # 数据库中缺少的节点,添加
node_data = { node_data = {
@@ -322,7 +326,7 @@ class Hippocampus:
# 获取数据库中节点的特征值 # 获取数据库中节点的特征值
db_node = db_nodes_dict[concept] db_node = db_nodes_dict[concept]
db_hash = db_node.get('hash', None) db_hash = db_node.get('hash', None)
# 如果特征值不同,则更新节点 # 如果特征值不同,则更新节点
if db_hash != memory_hash: if db_hash != memory_hash:
self.memory_graph.db.db.graph_data.nodes.update_one( self.memory_graph.db.db.graph_data.nodes.update_one(
@@ -332,17 +336,17 @@ class Hippocampus:
'hash': memory_hash 'hash': memory_hash
}} }}
) )
# 检查并删除数据库中多余的节点 # 检查并删除数据库中多余的节点
memory_concepts = set(node[0] for node in memory_nodes) memory_concepts = set(node[0] for node in memory_nodes)
for db_node in db_nodes: for db_node in db_nodes:
if db_node['concept'] not in memory_concepts: if db_node['concept'] not in memory_concepts:
self.memory_graph.db.db.graph_data.nodes.delete_one({'concept': db_node['concept']}) self.memory_graph.db.db.graph_data.nodes.delete_one({'concept': db_node['concept']})
# 处理边的信息 # 处理边的信息
db_edges = list(self.memory_graph.db.db.graph_data.edges.find()) db_edges = list(self.memory_graph.db.db.graph_data.edges.find())
memory_edges = list(self.memory_graph.G.edges()) memory_edges = list(self.memory_graph.G.edges())
# 创建边的哈希值字典 # 创建边的哈希值字典
db_edge_dict = {} db_edge_dict = {}
for edge in db_edges: for edge in db_edges:
@@ -351,13 +355,13 @@ class Hippocampus:
'hash': edge_hash, 'hash': edge_hash,
'strength': edge.get('strength', 1) 'strength': edge.get('strength', 1)
} }
# 检查并更新边 # 检查并更新边
for source, target in memory_edges: for source, target in memory_edges:
edge_hash = self.calculate_edge_hash(source, target) edge_hash = self.calculate_edge_hash(source, target)
edge_key = (source, target) edge_key = (source, target)
strength = self.memory_graph.G[source][target].get('strength', 1) strength = self.memory_graph.G[source][target].get('strength', 1)
if edge_key not in db_edge_dict: if edge_key not in db_edge_dict:
# 添加新边 # 添加新边
edge_data = { edge_data = {
@@ -377,7 +381,7 @@ class Hippocampus:
'strength': strength 'strength': strength
}} }}
) )
# 删除多余的边 # 删除多余的边
memory_edge_set = set(memory_edges) memory_edge_set = set(memory_edges)
for edge_key in db_edge_dict: for edge_key in db_edge_dict:
@@ -392,7 +396,7 @@ class Hippocampus:
"""从数据库同步数据到内存中的图结构""" """从数据库同步数据到内存中的图结构"""
# 清空当前图 # 清空当前图
self.memory_graph.G.clear() self.memory_graph.G.clear()
# 从数据库加载所有节点 # 从数据库加载所有节点
nodes = self.memory_graph.db.db.graph_data.nodes.find() nodes = self.memory_graph.db.db.graph_data.nodes.find()
for node in nodes: for node in nodes:
@@ -403,7 +407,7 @@ class Hippocampus:
memory_items = [memory_items] if memory_items else [] memory_items = [memory_items] if memory_items else []
# 添加节点到图中 # 添加节点到图中
self.memory_graph.G.add_node(concept, memory_items=memory_items) self.memory_graph.G.add_node(concept, memory_items=memory_items)
# 从数据库加载所有边 # 从数据库加载所有边
edges = self.memory_graph.db.db.graph_data.edges.find() edges = self.memory_graph.db.db.graph_data.edges.find()
for edge in edges: for edge in edges:
@@ -413,7 +417,7 @@ class Hippocampus:
# 只有当源节点和目标节点都存在时才添加边 # 只有当源节点和目标节点都存在时才添加边
if source in self.memory_graph.G and target in self.memory_graph.G: if source in self.memory_graph.G and target in self.memory_graph.G:
self.memory_graph.G.add_edge(source, target, strength=strength) self.memory_graph.G.add_edge(source, target, strength=strength)
async def operation_forget_topic(self, percentage=0.1): async def operation_forget_topic(self, percentage=0.1):
"""随机选择图中一定比例的节点进行检查,根据条件决定是否遗忘""" """随机选择图中一定比例的节点进行检查,根据条件决定是否遗忘"""
# 获取所有节点 # 获取所有节点
@@ -422,18 +426,18 @@ class Hippocampus:
check_count = max(1, int(len(all_nodes) * percentage)) check_count = max(1, int(len(all_nodes) * percentage))
# 随机选择节点 # 随机选择节点
nodes_to_check = random.sample(all_nodes, check_count) nodes_to_check = random.sample(all_nodes, check_count)
forgotten_nodes = [] forgotten_nodes = []
for node in nodes_to_check: for node in nodes_to_check:
# 获取节点的连接数 # 获取节点的连接数
connections = self.memory_graph.G.degree(node) connections = self.memory_graph.G.degree(node)
# 获取节点的内容条数 # 获取节点的内容条数
memory_items = self.memory_graph.G.nodes[node].get('memory_items', []) memory_items = self.memory_graph.G.nodes[node].get('memory_items', [])
if not isinstance(memory_items, list): if not isinstance(memory_items, list):
memory_items = [memory_items] if memory_items else [] memory_items = [memory_items] if memory_items else []
content_count = len(memory_items) content_count = len(memory_items)
# 检查连接强度 # 检查连接强度
weak_connections = True weak_connections = True
if connections > 1: # 只有当连接数大于1时才检查强度 if connections > 1: # 只有当连接数大于1时才检查强度
@@ -442,20 +446,20 @@ class Hippocampus:
if strength > 2: if strength > 2:
weak_connections = False weak_connections = False
break break
# 如果满足遗忘条件 # 如果满足遗忘条件
if (connections <= 1 and weak_connections) or content_count <= 2: if (connections <= 1 and weak_connections) or content_count <= 2:
removed_item = self.memory_graph.forget_topic(node) removed_item = self.memory_graph.forget_topic(node)
if removed_item: if removed_item:
forgotten_nodes.append((node, removed_item)) forgotten_nodes.append((node, removed_item))
print(f"遗忘节点 {node} 的记忆: {removed_item}") logger.debug(f"遗忘节点 {node} 的记忆: {removed_item}")
# 同步到数据库 # 同步到数据库
if forgotten_nodes: if forgotten_nodes:
self.sync_memory_to_db() self.sync_memory_to_db()
print(f"完成遗忘操作,共遗忘 {len(forgotten_nodes)} 个节点的记忆") logger.debug(f"完成遗忘操作,共遗忘 {len(forgotten_nodes)} 个节点的记忆")
else: else:
print("本次检查没有节点满足遗忘条件") logger.debug("本次检查没有节点满足遗忘条件")
async def merge_memory(self, topic): async def merge_memory(self, topic):
""" """
@@ -468,35 +472,35 @@ class Hippocampus:
memory_items = self.memory_graph.G.nodes[topic].get('memory_items', []) memory_items = self.memory_graph.G.nodes[topic].get('memory_items', [])
if not isinstance(memory_items, list): if not isinstance(memory_items, list):
memory_items = [memory_items] if memory_items else [] memory_items = [memory_items] if memory_items else []
# 如果记忆项不足,直接返回 # 如果记忆项不足,直接返回
if len(memory_items) < 10: if len(memory_items) < 10:
return return
# 随机选择10条记忆 # 随机选择10条记忆
selected_memories = random.sample(memory_items, 10) selected_memories = random.sample(memory_items, 10)
# 拼接成文本 # 拼接成文本
merged_text = "\n".join(selected_memories) merged_text = "\n".join(selected_memories)
print(f"\n[合并记忆] 话题: {topic}") logger.debug(f"\n[合并记忆] 话题: {topic}")
print(f"选择的记忆:\n{merged_text}") logger.debug(f"选择的记忆:\n{merged_text}")
# 使用memory_compress生成新的压缩记忆 # 使用memory_compress生成新的压缩记忆
compressed_memories = await self.memory_compress(selected_memories, 0.1) compressed_memories = await self.memory_compress(selected_memories, 0.1)
# 从原记忆列表中移除被选中的记忆 # 从原记忆列表中移除被选中的记忆
for memory in selected_memories: for memory in selected_memories:
memory_items.remove(memory) memory_items.remove(memory)
# 添加新的压缩记忆 # 添加新的压缩记忆
for _, compressed_memory in compressed_memories: for _, compressed_memory in compressed_memories:
memory_items.append(compressed_memory) memory_items.append(compressed_memory)
print(f"添加压缩记忆: {compressed_memory}") logger.info(f"添加压缩记忆: {compressed_memory}")
# 更新节点的记忆项 # 更新节点的记忆项
self.memory_graph.G.nodes[topic]['memory_items'] = memory_items self.memory_graph.G.nodes[topic]['memory_items'] = memory_items
print(f"完成记忆合并,当前记忆数量: {len(memory_items)}") logger.debug(f"完成记忆合并,当前记忆数量: {len(memory_items)}")
async def operation_merge_memory(self, percentage=0.1): async def operation_merge_memory(self, percentage=0.1):
""" """
随机检查一定比例的节点对内容数量超过100的节点进行记忆合并 随机检查一定比例的节点对内容数量超过100的节点进行记忆合并
@@ -510,7 +514,7 @@ class Hippocampus:
check_count = max(1, int(len(all_nodes) * percentage)) check_count = max(1, int(len(all_nodes) * percentage))
# 随机选择节点 # 随机选择节点
nodes_to_check = random.sample(all_nodes, check_count) nodes_to_check = random.sample(all_nodes, check_count)
merged_nodes = [] merged_nodes = []
for node in nodes_to_check: for node in nodes_to_check:
# 获取节点的内容条数 # 获取节点的内容条数
@@ -518,25 +522,25 @@ class Hippocampus:
if not isinstance(memory_items, list): if not isinstance(memory_items, list):
memory_items = [memory_items] if memory_items else [] memory_items = [memory_items] if memory_items else []
content_count = len(memory_items) content_count = len(memory_items)
# 如果内容数量超过100进行合并 # 如果内容数量超过100进行合并
if content_count > 100: if content_count > 100:
print(f"\n检查节点: {node}, 当前记忆数量: {content_count}") logger.debug(f"检查节点: {node}, 当前记忆数量: {content_count}")
await self.merge_memory(node) await self.merge_memory(node)
merged_nodes.append(node) merged_nodes.append(node)
# 同步到数据库 # 同步到数据库
if merged_nodes: if merged_nodes:
self.sync_memory_to_db() self.sync_memory_to_db()
print(f"\n完成记忆合并操作,共处理 {len(merged_nodes)} 个节点") logger.debug(f"完成记忆合并操作,共处理 {len(merged_nodes)} 个节点")
else: else:
print("\n本次检查没有需要合并的节点") logger.debug("本次检查没有需要合并的节点")
def find_topic_llm(self,text, topic_num): def find_topic_llm(self, text, topic_num):
prompt = f'这是一段文字:{text}。请你从这段话中总结出{topic_num}个关键的概念,可以是名词,动词,或者特定人物,帮我列出来,用逗号,隔开,尽可能精简。只需要列举{topic_num}个话题就好,不要有序号,不要告诉我其他内容。' prompt = f'这是一段文字:{text}。请你从这段话中总结出{topic_num}个关键的概念,可以是名词,动词,或者特定人物,帮我列出来,用逗号,隔开,尽可能精简。只需要列举{topic_num}个话题就好,不要有序号,不要告诉我其他内容。'
return prompt return prompt
def topic_what(self,text, topic, time_info): def topic_what(self, text, topic, time_info):
prompt = f'这是一段文字,{time_info}{text}。我想让你基于这段文字来概括"{topic}"这个概念,帮我总结成一句自然的话,可以包含时间和人物,以及具体的观点。只输出这句话就好' prompt = f'这是一段文字,{time_info}{text}。我想让你基于这段文字来概括"{topic}"这个概念,帮我总结成一句自然的话,可以包含时间和人物,以及具体的观点。只输出这句话就好'
return prompt return prompt
@@ -551,11 +555,12 @@ class Hippocampus:
""" """
topics_response = await self.llm_topic_judge.generate_response(self.find_topic_llm(text, 5)) topics_response = await self.llm_topic_judge.generate_response(self.find_topic_llm(text, 5))
# print(f"话题: {topics_response[0]}") # print(f"话题: {topics_response[0]}")
topics = [topic.strip() for topic in topics_response[0].replace("", ",").replace("", ",").replace(" ", ",").split(",") if topic.strip()] topics = [topic.strip() for topic in
topics_response[0].replace("", ",").replace("", ",").replace(" ", ",").split(",") if topic.strip()]
# print(f"话题: {topics}") # print(f"话题: {topics}")
return topics return topics
def _find_similar_topics(self, topics: list, similarity_threshold: float = 0.4, debug_info: str = "") -> list: def _find_similar_topics(self, topics: list, similarity_threshold: float = 0.4, debug_info: str = "") -> list:
"""查找与给定主题相似的记忆主题 """查找与给定主题相似的记忆主题
@@ -569,16 +574,16 @@ class Hippocampus:
""" """
all_memory_topics = self.get_all_node_names() all_memory_topics = self.get_all_node_names()
all_similar_topics = [] all_similar_topics = []
# 计算每个识别出的主题与记忆主题的相似度 # 计算每个识别出的主题与记忆主题的相似度
for topic in topics: for topic in topics:
if debug_info: if debug_info:
# print(f"\033[1;32m[{debug_info}]\033[0m 正在思考有没有见过: {topic}") # print(f"\033[1;32m[{debug_info}]\033[0m 正在思考有没有见过: {topic}")
pass pass
topic_vector = text_to_vector(topic) topic_vector = text_to_vector(topic)
has_similar_topic = False has_similar_topic = False
for memory_topic in all_memory_topics: for memory_topic in all_memory_topics:
memory_vector = text_to_vector(memory_topic) memory_vector = text_to_vector(memory_topic)
# 获取所有唯一词 # 获取所有唯一词
@@ -588,20 +593,20 @@ class Hippocampus:
v2 = [memory_vector.get(word, 0) for word in all_words] v2 = [memory_vector.get(word, 0) for word in all_words]
# 计算相似度 # 计算相似度
similarity = cosine_similarity(v1, v2) similarity = cosine_similarity(v1, v2)
if similarity >= similarity_threshold: if similarity >= similarity_threshold:
has_similar_topic = True has_similar_topic = True
if debug_info: if debug_info:
# print(f"\033[1;32m[{debug_info}]\033[0m 找到相似主题: {topic} -> {memory_topic} (相似度: {similarity:.2f})") # print(f"\033[1;32m[{debug_info}]\033[0m 找到相似主题: {topic} -> {memory_topic} (相似度: {similarity:.2f})")
pass pass
all_similar_topics.append((memory_topic, similarity)) all_similar_topics.append((memory_topic, similarity))
if not has_similar_topic and debug_info: if not has_similar_topic and debug_info:
# print(f"\033[1;31m[{debug_info}]\033[0m 没有见过: {topic} ,呃呃") # print(f"\033[1;31m[{debug_info}]\033[0m 没有见过: {topic} ,呃呃")
pass pass
return all_similar_topics return all_similar_topics
def _get_top_topics(self, similar_topics: list, max_topics: int = 5) -> list: def _get_top_topics(self, similar_topics: list, max_topics: int = 5) -> list:
"""获取相似度最高的主题 """获取相似度最高的主题
@@ -614,36 +619,36 @@ class Hippocampus:
""" """
seen_topics = set() seen_topics = set()
top_topics = [] top_topics = []
for topic, score in sorted(similar_topics, key=lambda x: x[1], reverse=True): for topic, score in sorted(similar_topics, key=lambda x: x[1], reverse=True):
if topic not in seen_topics and len(top_topics) < max_topics: if topic not in seen_topics and len(top_topics) < max_topics:
seen_topics.add(topic) seen_topics.add(topic)
top_topics.append((topic, score)) top_topics.append((topic, score))
return top_topics return top_topics
async def memory_activate_value(self, text: str, max_topics: int = 5, similarity_threshold: float = 0.3) -> int: async def memory_activate_value(self, text: str, max_topics: int = 5, similarity_threshold: float = 0.3) -> int:
"""计算输入文本对记忆的激活程度""" """计算输入文本对记忆的激活程度"""
print(f"\033[1;32m[记忆激活]\033[0m 识别主题: {await self._identify_topics(text)}") logger.info(f"识别主题: {await self._identify_topics(text)}")
# 识别主题 # 识别主题
identified_topics = await self._identify_topics(text) identified_topics = await self._identify_topics(text)
if not identified_topics: if not identified_topics:
return 0 return 0
# 查找相似主题 # 查找相似主题
all_similar_topics = self._find_similar_topics( all_similar_topics = self._find_similar_topics(
identified_topics, identified_topics,
similarity_threshold=similarity_threshold, similarity_threshold=similarity_threshold,
debug_info="记忆激活" debug_info="记忆激活"
) )
if not all_similar_topics: if not all_similar_topics:
return 0 return 0
# 获取最相关的主题 # 获取最相关的主题
top_topics = self._get_top_topics(all_similar_topics, max_topics) top_topics = self._get_top_topics(all_similar_topics, max_topics)
# 如果只找到一个主题,进行惩罚 # 如果只找到一个主题,进行惩罚
if len(top_topics) == 1: if len(top_topics) == 1:
topic, score = top_topics[0] topic, score = top_topics[0]
@@ -653,15 +658,16 @@ class Hippocampus:
memory_items = [memory_items] if memory_items else [] memory_items = [memory_items] if memory_items else []
content_count = len(memory_items) content_count = len(memory_items)
penalty = 1.0 / (1 + math.log(content_count + 1)) penalty = 1.0 / (1 + math.log(content_count + 1))
activation = int(score * 50 * penalty) activation = int(score * 50 * penalty)
print(f"\033[1;32m[记忆激活]\033[0m 单主题「{topic}」- 相似度: {score:.3f}, 内容数: {content_count}, 激活值: {activation}") logger.info(
f"[记忆激活]单主题「{topic}」- 相似度: {score:.3f}, 内容数: {content_count}, 激活值: {activation}")
return activation return activation
# 计算关键词匹配率,同时考虑内容数量 # 计算关键词匹配率,同时考虑内容数量
matched_topics = set() matched_topics = set()
topic_similarities = {} topic_similarities = {}
for memory_topic, similarity in top_topics: for memory_topic, similarity in top_topics:
# 计算内容数量惩罚 # 计算内容数量惩罚
memory_items = self.memory_graph.G.nodes[memory_topic].get('memory_items', []) memory_items = self.memory_graph.G.nodes[memory_topic].get('memory_items', [])
@@ -669,7 +675,7 @@ class Hippocampus:
memory_items = [memory_items] if memory_items else [] memory_items = [memory_items] if memory_items else []
content_count = len(memory_items) content_count = len(memory_items)
penalty = 1.0 / (1 + math.log(content_count + 1)) penalty = 1.0 / (1 + math.log(content_count + 1))
# 对每个记忆主题,检查它与哪些输入主题相似 # 对每个记忆主题,检查它与哪些输入主题相似
for input_topic in identified_topics: for input_topic in identified_topics:
topic_vector = text_to_vector(input_topic) topic_vector = text_to_vector(input_topic)
@@ -682,33 +688,36 @@ class Hippocampus:
matched_topics.add(input_topic) matched_topics.add(input_topic)
adjusted_sim = sim * penalty adjusted_sim = sim * penalty
topic_similarities[input_topic] = max(topic_similarities.get(input_topic, 0), adjusted_sim) topic_similarities[input_topic] = max(topic_similarities.get(input_topic, 0), adjusted_sim)
print(f"\033[1;32m[记忆激活]\033[0m 主题「{input_topic}」-> 「{memory_topic}」(内容数: {content_count}, 相似度: {adjusted_sim:.3f})") logger.info(
f"[记忆激活]主题「{input_topic}」-> 「{memory_topic}」(内容数: {content_count}, 相似度: {adjusted_sim:.3f})")
# 计算主题匹配率和平均相似度 # 计算主题匹配率和平均相似度
topic_match = len(matched_topics) / len(identified_topics) topic_match = len(matched_topics) / len(identified_topics)
average_similarities = sum(topic_similarities.values()) / len(topic_similarities) if topic_similarities else 0 average_similarities = sum(topic_similarities.values()) / len(topic_similarities) if topic_similarities else 0
# 计算最终激活值 # 计算最终激活值
activation = int((topic_match + average_similarities) / 2 * 100) activation = int((topic_match + average_similarities) / 2 * 100)
print(f"\033[1;32m[记忆激活]\033[0m 匹配率: {topic_match:.3f}, 平均相似度: {average_similarities:.3f}, 激活值: {activation}") logger.info(
f"[记忆激活]匹配率: {topic_match:.3f}, 平均相似度: {average_similarities:.3f}, 激活值: {activation}")
return activation return activation
async def get_relevant_memories(self, text: str, max_topics: int = 5, similarity_threshold: float = 0.4, max_memory_num: int = 5) -> list: async def get_relevant_memories(self, text: str, max_topics: int = 5, similarity_threshold: float = 0.4,
max_memory_num: int = 5) -> list:
"""根据输入文本获取相关的记忆内容""" """根据输入文本获取相关的记忆内容"""
# 识别主题 # 识别主题
identified_topics = await self._identify_topics(text) identified_topics = await self._identify_topics(text)
# 查找相似主题 # 查找相似主题
all_similar_topics = self._find_similar_topics( all_similar_topics = self._find_similar_topics(
identified_topics, identified_topics,
similarity_threshold=similarity_threshold, similarity_threshold=similarity_threshold,
debug_info="记忆检索" debug_info="记忆检索"
) )
# 获取最相关的主题 # 获取最相关的主题
relevant_topics = self._get_top_topics(all_similar_topics, max_topics) relevant_topics = self._get_top_topics(all_similar_topics, max_topics)
# 获取相关记忆内容 # 获取相关记忆内容
relevant_memories = [] relevant_memories = []
for topic, score in relevant_topics: for topic, score in relevant_topics:
@@ -716,8 +725,8 @@ class Hippocampus:
first_layer, _ = self.memory_graph.get_related_item(topic, depth=1) first_layer, _ = self.memory_graph.get_related_item(topic, depth=1)
if first_layer: if first_layer:
# 如果记忆条数超过限制,随机选择指定数量的记忆 # 如果记忆条数超过限制,随机选择指定数量的记忆
if len(first_layer) > max_memory_num/2: if len(first_layer) > max_memory_num / 2:
first_layer = random.sample(first_layer, max_memory_num//2) first_layer = random.sample(first_layer, max_memory_num // 2)
# 为每条记忆添加来源主题和相似度信息 # 为每条记忆添加来源主题和相似度信息
for memory in first_layer: for memory in first_layer:
relevant_memories.append({ relevant_memories.append({
@@ -725,20 +734,20 @@ class Hippocampus:
'similarity': score, 'similarity': score,
'content': memory 'content': memory
}) })
# 如果记忆数量超过5个,随机选择5个 # 如果记忆数量超过5个,随机选择5个
# 按相似度排序 # 按相似度排序
relevant_memories.sort(key=lambda x: x['similarity'], reverse=True) relevant_memories.sort(key=lambda x: x['similarity'], reverse=True)
if len(relevant_memories) > max_memory_num: if len(relevant_memories) > max_memory_num:
relevant_memories = random.sample(relevant_memories, max_memory_num) relevant_memories = random.sample(relevant_memories, max_memory_num)
return relevant_memories return relevant_memories
def segment_text(text): def segment_text(text):
seg_text = list(jieba.cut(text)) seg_text = list(jieba.cut(text))
return seg_text return seg_text
from nonebot import get_driver from nonebot import get_driver
@@ -749,19 +758,19 @@ config = driver.config
start_time = time.time() start_time = time.time()
Database.initialize( Database.initialize(
host= config.MONGODB_HOST, host=config.MONGODB_HOST,
port= config.MONGODB_PORT, port=config.MONGODB_PORT,
db_name= config.DATABASE_NAME, db_name=config.DATABASE_NAME,
username= config.MONGODB_USERNAME, username=config.MONGODB_USERNAME,
password= config.MONGODB_PASSWORD, password=config.MONGODB_PASSWORD,
auth_source=config.MONGODB_AUTH_SOURCE auth_source=config.MONGODB_AUTH_SOURCE
) )
#创建记忆图 # 创建记忆图
memory_graph = Memory_graph() memory_graph = Memory_graph()
#创建海马体 # 创建海马体
hippocampus = Hippocampus(memory_graph) hippocampus = Hippocampus(memory_graph)
#从数据库加载记忆图 # 从数据库加载记忆图
hippocampus.sync_memory_from_db() hippocampus.sync_memory_from_db()
end_time = time.time() end_time = time.time()
print(f"\033[32m[加载海马体耗时: {end_time - start_time:.2f}]\033[0m") logger.success(f"加载海马体耗时: {end_time - start_time:.2f}")

View File

@@ -743,7 +743,7 @@ class Hippocampus:
async def memory_activate_value(self, text: str, max_topics: int = 5, similarity_threshold: float = 0.3) -> int: async def memory_activate_value(self, text: str, max_topics: int = 5, similarity_threshold: float = 0.3) -> int:
"""计算输入文本对记忆的激活程度""" """计算输入文本对记忆的激活程度"""
print(f"\033[1;32m[记忆激活]\033[0m 识别主题: {await self._identify_topics(text)}") logger.info(f"[记忆激活]识别主题: {await self._identify_topics(text)}")
identified_topics = await self._identify_topics(text) identified_topics = await self._identify_topics(text)
if not identified_topics: if not identified_topics:

View File

@@ -28,10 +28,10 @@ class LLM_request:
raise ValueError(f"配置错误:找不到对应的配置项 - {str(e)}") from e raise ValueError(f"配置错误:找不到对应的配置项 - {str(e)}") from e
self.model_name = model["name"] self.model_name = model["name"]
self.params = kwargs self.params = kwargs
self.pri_in = model.get("pri_in", 0) self.pri_in = model.get("pri_in", 0)
self.pri_out = model.get("pri_out", 0) self.pri_out = model.get("pri_out", 0)
# 获取数据库实例 # 获取数据库实例
self.db = Database.get_instance() self.db = Database.get_instance()
self._init_database() self._init_database()
@@ -45,11 +45,11 @@ class LLM_request:
self.db.db.llm_usage.create_index([("user_id", 1)]) self.db.db.llm_usage.create_index([("user_id", 1)])
self.db.db.llm_usage.create_index([("request_type", 1)]) self.db.db.llm_usage.create_index([("request_type", 1)])
except Exception as e: except Exception as e:
logger.error(f"创建数据库索引失败: {e}") logger.error(f"创建数据库索引失败")
def _record_usage(self, prompt_tokens: int, completion_tokens: int, total_tokens: int, def _record_usage(self, prompt_tokens: int, completion_tokens: int, total_tokens: int,
user_id: str = "system", request_type: str = "chat", user_id: str = "system", request_type: str = "chat",
endpoint: str = "/chat/completions"): endpoint: str = "/chat/completions"):
"""记录模型使用情况到数据库 """记录模型使用情况到数据库
Args: Args:
prompt_tokens: 输入token数 prompt_tokens: 输入token数
@@ -79,8 +79,8 @@ class LLM_request:
f"提示词: {prompt_tokens}, 完成: {completion_tokens}, " f"提示词: {prompt_tokens}, 完成: {completion_tokens}, "
f"总计: {total_tokens}" f"总计: {total_tokens}"
) )
except Exception as e: except Exception:
logger.error(f"记录token使用情况失败: {e}") logger.error(f"记录token使用情况失败")
def _calculate_cost(self, prompt_tokens: int, completion_tokens: int) -> float: def _calculate_cost(self, prompt_tokens: int, completion_tokens: int) -> float:
"""计算API调用成本 """计算API调用成本
@@ -140,12 +140,12 @@ class LLM_request:
} }
api_url = f"{self.base_url.rstrip('/')}/{endpoint.lstrip('/')}" api_url = f"{self.base_url.rstrip('/')}/{endpoint.lstrip('/')}"
#判断是否为流式 # 判断是否为流式
stream_mode = self.params.get("stream", False) stream_mode = self.params.get("stream", False)
if self.params.get("stream", False) is True: if self.params.get("stream", False) is True:
logger.info(f"进入流式输出模式发送请求到URL: {api_url}") logger.debug(f"进入流式输出模式发送请求到URL: {api_url}")
else: else:
logger.info(f"发送请求到URL: {api_url}") logger.debug(f"发送请求到URL: {api_url}")
logger.info(f"使用模型: {self.model_name}") logger.info(f"使用模型: {self.model_name}")
# 构建请求体 # 构建请求体
@@ -158,7 +158,7 @@ class LLM_request:
try: try:
# 使用上下文管理器处理会话 # 使用上下文管理器处理会话
headers = await self._build_headers() headers = await self._build_headers()
#似乎是openai流式必须要的东西,不过阿里云的qwq-plus加了这个没有影响 # 似乎是openai流式必须要的东西,不过阿里云的qwq-plus加了这个没有影响
if stream_mode: if stream_mode:
headers["Accept"] = "text/event-stream" headers["Accept"] = "text/event-stream"
@@ -184,29 +184,31 @@ class LLM_request:
logger.error(f"错误码: {response.status} - {error_code_mapping.get(response.status)}") logger.error(f"错误码: {response.status} - {error_code_mapping.get(response.status)}")
if response.status == 403: if response.status == 403:
# 尝试降级Pro模型 # 尝试降级Pro模型
if self.model_name.startswith("Pro/") and self.base_url == "https://api.siliconflow.cn/v1/": if self.model_name.startswith(
"Pro/") and self.base_url == "https://api.siliconflow.cn/v1/":
old_model_name = self.model_name old_model_name = self.model_name
self.model_name = self.model_name[4:] # 移除"Pro/"前缀 self.model_name = self.model_name[4:] # 移除"Pro/"前缀
logger.warning(f"检测到403错误模型从 {old_model_name} 降级为 {self.model_name}") logger.warning(f"检测到403错误模型从 {old_model_name} 降级为 {self.model_name}")
# 对全局配置进行更新 # 对全局配置进行更新
if hasattr(global_config, 'llm_normal') and global_config.llm_normal.get('name') == old_model_name: if hasattr(global_config, 'llm_normal') and global_config.llm_normal.get(
'name') == old_model_name:
global_config.llm_normal['name'] = self.model_name global_config.llm_normal['name'] = self.model_name
logger.warning(f"已将全局配置中的 llm_normal 模型降级") logger.warning(f"已将全局配置中的 llm_normal 模型降级")
# 更新payload中的模型名 # 更新payload中的模型名
if payload and 'model' in payload: if payload and 'model' in payload:
payload['model'] = self.model_name payload['model'] = self.model_name
# 重新尝试请求 # 重新尝试请求
retry -= 1 # 不计入重试次数 retry -= 1 # 不计入重试次数
continue continue
raise RuntimeError(f"请求被拒绝: {error_code_mapping.get(response.status)}") raise RuntimeError(f"请求被拒绝: {error_code_mapping.get(response.status)}")
response.raise_for_status() response.raise_for_status()
#将流式输出转化为非流式输出 # 将流式输出转化为非流式输出
if stream_mode: if stream_mode:
accumulated_content = "" accumulated_content = ""
async for line_bytes in response.content: async for line_bytes in response.content:
@@ -224,8 +226,8 @@ class LLM_request:
if delta_content is None: if delta_content is None:
delta_content = "" delta_content = ""
accumulated_content += delta_content accumulated_content += delta_content
except Exception as e: except Exception:
logger.error(f"解析流式输出错误: {e}") logger.exception(f"解析流式输出错")
content = accumulated_content content = accumulated_content
reasoning_content = "" reasoning_content = ""
think_match = re.search(r'<think>(.*?)</think>', content, re.DOTALL) think_match = re.search(r'<think>(.*?)</think>', content, re.DOTALL)
@@ -233,12 +235,15 @@ class LLM_request:
reasoning_content = think_match.group(1).strip() reasoning_content = think_match.group(1).strip()
content = re.sub(r'<think>.*?</think>', '', content, flags=re.DOTALL).strip() content = re.sub(r'<think>.*?</think>', '', content, flags=re.DOTALL).strip()
# 构造一个伪result以便调用自定义响应处理器或默认处理器 # 构造一个伪result以便调用自定义响应处理器或默认处理器
result = {"choices": [{"message": {"content": content, "reasoning_content": reasoning_content}}]} result = {
return response_handler(result) if response_handler else self._default_response_handler(result, user_id, request_type, endpoint) "choices": [{"message": {"content": content, "reasoning_content": reasoning_content}}]}
return response_handler(result) if response_handler else self._default_response_handler(
result, user_id, request_type, endpoint)
else: else:
result = await response.json() result = await response.json()
# 使用自定义处理器或默认处理 # 使用自定义处理器或默认处理
return response_handler(result) if response_handler else self._default_response_handler(result, user_id, request_type, endpoint) return response_handler(result) if response_handler else self._default_response_handler(
result, user_id, request_type, endpoint)
except Exception as e: except Exception as e:
if retry < policy["max_retries"] - 1: if retry < policy["max_retries"] - 1:
@@ -252,8 +257,8 @@ class LLM_request:
logger.error("达到最大重试次数,请求仍然失败") logger.error("达到最大重试次数,请求仍然失败")
raise RuntimeError("达到最大重试次数API请求仍然失败") raise RuntimeError("达到最大重试次数API请求仍然失败")
async def _transform_parameters(self, params: dict) ->dict: async def _transform_parameters(self, params: dict) -> dict:
""" """
根据模型名称转换参数: 根据模型名称转换参数:
- 对于需要转换的OpenAI CoT系列模型例如 "o3-mini"),删除 'temprature' 参数, - 对于需要转换的OpenAI CoT系列模型例如 "o3-mini"),删除 'temprature' 参数,
@@ -262,7 +267,8 @@ class LLM_request:
# 复制一份参数,避免直接修改原始数据 # 复制一份参数,避免直接修改原始数据
new_params = dict(params) new_params = dict(params)
# 定义需要转换的模型列表 # 定义需要转换的模型列表
models_needing_transformation = ["o3-mini", "o1-mini", "o1-preview", "o1-2024-12-17", "o1-preview-2024-09-12", "o3-mini-2025-01-31", "o1-mini-2024-09-12"] models_needing_transformation = ["o3-mini", "o1-mini", "o1-preview", "o1-2024-12-17", "o1-preview-2024-09-12",
"o3-mini-2025-01-31", "o1-mini-2024-09-12"]
if self.model_name.lower() in models_needing_transformation: if self.model_name.lower() in models_needing_transformation:
# 删除 'temprature' 参数(如果存在) # 删除 'temprature' 参数(如果存在)
new_params.pop("temperature", None) new_params.pop("temperature", None)
@@ -298,13 +304,13 @@ class LLM_request:
**params_copy **params_copy
} }
# 如果 payload 中依然存在 max_tokens 且需要转换,在这里进行再次检查 # 如果 payload 中依然存在 max_tokens 且需要转换,在这里进行再次检查
if self.model_name.lower() in ["o3-mini", "o1-mini", "o1-preview", "o1-2024-12-17", "o1-preview-2024-09-12", "o3-mini-2025-01-31", "o1-mini-2024-09-12"] and "max_tokens" in payload: if self.model_name.lower() in ["o3-mini", "o1-mini", "o1-preview", "o1-2024-12-17", "o1-preview-2024-09-12",
"o3-mini-2025-01-31", "o1-mini-2024-09-12"] and "max_tokens" in payload:
payload["max_completion_tokens"] = payload.pop("max_tokens") payload["max_completion_tokens"] = payload.pop("max_tokens")
return payload return payload
def _default_response_handler(self, result: dict, user_id: str = "system", def _default_response_handler(self, result: dict, user_id: str = "system",
request_type: str = "chat", endpoint: str = "/chat/completions") -> Tuple: request_type: str = "chat", endpoint: str = "/chat/completions") -> Tuple:
"""默认响应解析""" """默认响应解析"""
if "choices" in result and result["choices"]: if "choices" in result and result["choices"]:
message = result["choices"][0]["message"] message = result["choices"][0]["message"]
@@ -356,8 +362,8 @@ class LLM_request:
return { return {
"Authorization": f"Bearer {self.api_key}", "Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json" "Content-Type": "application/json"
} }
# 防止小朋友们截图自己的key # 防止小朋友们截图自己的key
async def generate_response(self, prompt: str) -> Tuple[str, str]: async def generate_response(self, prompt: str) -> Tuple[str, str]:
"""根据输入的提示生成模型的异步响应""" """根据输入的提示生成模型的异步响应"""
@@ -404,6 +410,7 @@ class LLM_request:
Returns: Returns:
list: embedding向量如果失败则返回None list: embedding向量如果失败则返回None
""" """
def embedding_handler(result): def embedding_handler(result):
"""处理响应""" """处理响应"""
if "data" in result and len(result["data"]) > 0: if "data" in result and len(result["data"]) > 0:
@@ -425,4 +432,3 @@ class LLM_request:
response_handler=embedding_handler response_handler=embedding_handler
) )
return embedding return embedding

View File

@@ -4,7 +4,7 @@ import time
from dataclasses import dataclass from dataclasses import dataclass
from ..chat.config import global_config from ..chat.config import global_config
from loguru import logger
@dataclass @dataclass
class MoodState: class MoodState:
@@ -210,7 +210,7 @@ class MoodManager:
def print_mood_status(self) -> None: def print_mood_status(self) -> None:
"""打印当前情绪状态""" """打印当前情绪状态"""
print(f"\033[1;35m[情绪状态]\033[0m 愉悦度: {self.current_mood.valence:.2f}, " logger.info(f"[情绪状态]愉悦度: {self.current_mood.valence:.2f}, "
f"唤醒度: {self.current_mood.arousal:.2f}, " f"唤醒度: {self.current_mood.arousal:.2f}, "
f"心情: {self.current_mood.text}") f"心情: {self.current_mood.text}")

View File

@@ -57,12 +57,12 @@ class ScheduleGenerator:
existing_schedule = self.db.db.schedule.find_one({"date": date_str}) existing_schedule = self.db.db.schedule.find_one({"date": date_str})
if existing_schedule: if existing_schedule:
logger.info(f"{date_str}的日程已存在:") logger.debug(f"{date_str}的日程已存在:")
schedule_text = existing_schedule["schedule"] schedule_text = existing_schedule["schedule"]
# print(self.schedule_text) # print(self.schedule_text)
elif read_only == False: elif not read_only:
logger.info(f"{date_str}的日程不存在,准备生成新的日程。") logger.debug(f"{date_str}的日程不存在,准备生成新的日程。")
prompt = f"""我是{global_config.BOT_NICKNAME}{global_config.PROMPT_SCHEDULE_GEN},请为我生成{date_str}{weekday})的日程安排,包括:""" + \ prompt = f"""我是{global_config.BOT_NICKNAME}{global_config.PROMPT_SCHEDULE_GEN},请为我生成{date_str}{weekday})的日程安排,包括:""" + \
""" """
1. 早上的学习和工作安排 1. 早上的学习和工作安排
@@ -78,7 +78,7 @@ class ScheduleGenerator:
schedule_text = "生成日程时出错了" schedule_text = "生成日程时出错了"
# print(self.schedule_text) # print(self.schedule_text)
else: else:
logger.info(f"{date_str}的日程不存在。") logger.debug(f"{date_str}的日程不存在。")
schedule_text = "忘了" schedule_text = "忘了"
return schedule_text, None return schedule_text, None
@@ -154,10 +154,10 @@ class ScheduleGenerator:
logger.warning("今日日程有误,将在下次运行时重新生成") logger.warning("今日日程有误,将在下次运行时重新生成")
self.db.db.schedule.delete_one({"date": datetime.datetime.now().strftime("%Y-%m-%d")}) self.db.db.schedule.delete_one({"date": datetime.datetime.now().strftime("%Y-%m-%d")})
else: else:
logger.info("\n=== 今日日程安排 ===") logger.info("=== 今日日程安排 ===")
for time_str, activity in self.today_schedule.items(): for time_str, activity in self.today_schedule.items():
logger.info(f"时间[{time_str}]: 活动[{activity}]") logger.info(f"时间[{time_str}]: 活动[{activity}]")
logger.info("==================\n") logger.info("==================")
# def main(): # def main():

View File

@@ -3,6 +3,7 @@ import time
from collections import defaultdict from collections import defaultdict
from datetime import datetime, timedelta from datetime import datetime, timedelta
from typing import Any, Dict from typing import Any, Dict
from loguru import logger
from ...common.database import Database from ...common.database import Database
@@ -153,8 +154,8 @@ class LLMStatistics:
try: try:
all_stats = self._collect_all_statistics() all_stats = self._collect_all_statistics()
self._save_statistics(all_stats) self._save_statistics(all_stats)
except Exception as e: except Exception:
print(f"\033[1;31m[错误]\033[0m 统计数据处理失败: {e}") logger.exception(f"统计数据处理失败")
# 等待1分钟 # 等待1分钟
for _ in range(60): for _ in range(60):

View File

@@ -11,12 +11,14 @@ from pathlib import Path
import random import random
import math import math
import time import time
from loguru import logger
class ChineseTypoGenerator: class ChineseTypoGenerator:
def __init__(self, def __init__(self,
error_rate=0.3, error_rate=0.3,
min_freq=5, min_freq=5,
tone_error_rate=0.2, tone_error_rate=0.2,
word_replace_rate=0.3, word_replace_rate=0.3,
max_freq_diff=200): max_freq_diff=200):
""" """
@@ -34,27 +36,27 @@ class ChineseTypoGenerator:
self.tone_error_rate = tone_error_rate self.tone_error_rate = tone_error_rate
self.word_replace_rate = word_replace_rate self.word_replace_rate = word_replace_rate
self.max_freq_diff = max_freq_diff self.max_freq_diff = max_freq_diff
# 加载数据 # 加载数据
print("正在加载汉字数据库,请稍候...") logger.debug("正在加载汉字数据库,请稍候...")
self.pinyin_dict = self._create_pinyin_dict() self.pinyin_dict = self._create_pinyin_dict()
self.char_frequency = self._load_or_create_char_frequency() self.char_frequency = self._load_or_create_char_frequency()
def _load_or_create_char_frequency(self): def _load_or_create_char_frequency(self):
""" """
加载或创建汉字频率字典 加载或创建汉字频率字典
""" """
cache_file = Path("char_frequency.json") cache_file = Path("char_frequency.json")
# 如果缓存文件存在,直接加载 # 如果缓存文件存在,直接加载
if cache_file.exists(): if cache_file.exists():
with open(cache_file, 'r', encoding='utf-8') as f: with open(cache_file, 'r', encoding='utf-8') as f:
return json.load(f) return json.load(f)
# 使用内置的词频文件 # 使用内置的词频文件
char_freq = defaultdict(int) char_freq = defaultdict(int)
dict_path = os.path.join(os.path.dirname(jieba.__file__), 'dict.txt') dict_path = os.path.join(os.path.dirname(jieba.__file__), 'dict.txt')
# 读取jieba的词典文件 # 读取jieba的词典文件
with open(dict_path, 'r', encoding='utf-8') as f: with open(dict_path, 'r', encoding='utf-8') as f:
for line in f: for line in f:
@@ -63,15 +65,15 @@ class ChineseTypoGenerator:
for char in word: for char in word:
if self._is_chinese_char(char): if self._is_chinese_char(char):
char_freq[char] += int(freq) char_freq[char] += int(freq)
# 归一化频率值 # 归一化频率值
max_freq = max(char_freq.values()) max_freq = max(char_freq.values())
normalized_freq = {char: freq/max_freq * 1000 for char, freq in char_freq.items()} normalized_freq = {char: freq / max_freq * 1000 for char, freq in char_freq.items()}
# 保存到缓存文件 # 保存到缓存文件
with open(cache_file, 'w', encoding='utf-8') as f: with open(cache_file, 'w', encoding='utf-8') as f:
json.dump(normalized_freq, f, ensure_ascii=False, indent=2) json.dump(normalized_freq, f, ensure_ascii=False, indent=2)
return normalized_freq return normalized_freq
def _create_pinyin_dict(self): def _create_pinyin_dict(self):
@@ -81,7 +83,7 @@ class ChineseTypoGenerator:
# 常用汉字范围 # 常用汉字范围
chars = [chr(i) for i in range(0x4e00, 0x9fff)] chars = [chr(i) for i in range(0x4e00, 0x9fff)]
pinyin_dict = defaultdict(list) pinyin_dict = defaultdict(list)
# 为每个汉字建立拼音映射 # 为每个汉字建立拼音映射
for char in chars: for char in chars:
try: try:
@@ -89,7 +91,7 @@ class ChineseTypoGenerator:
pinyin_dict[py].append(char) pinyin_dict[py].append(char)
except Exception: except Exception:
continue continue
return pinyin_dict return pinyin_dict
def _is_chinese_char(self, char): def _is_chinese_char(self, char):
@@ -107,7 +109,7 @@ class ChineseTypoGenerator:
""" """
# 将句子拆分成单个字符 # 将句子拆分成单个字符
characters = list(sentence) characters = list(sentence)
# 获取每个字符的拼音 # 获取每个字符的拼音
result = [] result = []
for char in characters: for char in characters:
@@ -117,7 +119,7 @@ class ChineseTypoGenerator:
# 获取拼音(数字声调) # 获取拼音(数字声调)
py = pinyin(char, style=Style.TONE3)[0][0] py = pinyin(char, style=Style.TONE3)[0][0]
result.append((char, py)) result.append((char, py))
return result return result
def _get_similar_tone_pinyin(self, py): def _get_similar_tone_pinyin(self, py):
@@ -127,19 +129,19 @@ class ChineseTypoGenerator:
# 检查拼音是否为空或无效 # 检查拼音是否为空或无效
if not py or len(py) < 1: if not py or len(py) < 1:
return py return py
# 如果最后一个字符不是数字,说明可能是轻声或其他特殊情况 # 如果最后一个字符不是数字,说明可能是轻声或其他特殊情况
if not py[-1].isdigit(): if not py[-1].isdigit():
# 为非数字结尾的拼音添加数字声调1 # 为非数字结尾的拼音添加数字声调1
return py + '1' return py + '1'
base = py[:-1] # 去掉声调 base = py[:-1] # 去掉声调
tone = int(py[-1]) # 获取声调 tone = int(py[-1]) # 获取声调
# 处理轻声通常用5表示或无效声调 # 处理轻声通常用5表示或无效声调
if tone not in [1, 2, 3, 4]: if tone not in [1, 2, 3, 4]:
return base + str(random.choice([1, 2, 3, 4])) return base + str(random.choice([1, 2, 3, 4]))
# 正常处理声调 # 正常处理声调
possible_tones = [1, 2, 3, 4] possible_tones = [1, 2, 3, 4]
possible_tones.remove(tone) # 移除原声调 possible_tones.remove(tone) # 移除原声调
@@ -152,11 +154,11 @@ class ChineseTypoGenerator:
""" """
if target_freq > orig_freq: if target_freq > orig_freq:
return 1.0 # 如果替换字频率更高,保持原有概率 return 1.0 # 如果替换字频率更高,保持原有概率
freq_diff = orig_freq - target_freq freq_diff = orig_freq - target_freq
if freq_diff > self.max_freq_diff: if freq_diff > self.max_freq_diff:
return 0.0 # 频率差太大,不替换 return 0.0 # 频率差太大,不替换
# 使用指数衰减函数计算概率 # 使用指数衰减函数计算概率
# 频率差为0时概率为1频率差为max_freq_diff时概率接近0 # 频率差为0时概率为1频率差为max_freq_diff时概率接近0
return math.exp(-3 * freq_diff / self.max_freq_diff) return math.exp(-3 * freq_diff / self.max_freq_diff)
@@ -166,42 +168,42 @@ class ChineseTypoGenerator:
获取与给定字频率相近的同音字,可能包含声调错误 获取与给定字频率相近的同音字,可能包含声调错误
""" """
homophones = [] homophones = []
# 有一定概率使用错误声调 # 有一定概率使用错误声调
if random.random() < self.tone_error_rate: if random.random() < self.tone_error_rate:
wrong_tone_py = self._get_similar_tone_pinyin(py) wrong_tone_py = self._get_similar_tone_pinyin(py)
homophones.extend(self.pinyin_dict[wrong_tone_py]) homophones.extend(self.pinyin_dict[wrong_tone_py])
# 添加正确声调的同音字 # 添加正确声调的同音字
homophones.extend(self.pinyin_dict[py]) homophones.extend(self.pinyin_dict[py])
if not homophones: if not homophones:
return None return None
# 获取原字的频率 # 获取原字的频率
orig_freq = self.char_frequency.get(char, 0) orig_freq = self.char_frequency.get(char, 0)
# 计算所有同音字与原字的频率差,并过滤掉低频字 # 计算所有同音字与原字的频率差,并过滤掉低频字
freq_diff = [(h, self.char_frequency.get(h, 0)) freq_diff = [(h, self.char_frequency.get(h, 0))
for h in homophones for h in homophones
if h != char and self.char_frequency.get(h, 0) >= self.min_freq] if h != char and self.char_frequency.get(h, 0) >= self.min_freq]
if not freq_diff: if not freq_diff:
return None return None
# 计算每个候选字的替换概率 # 计算每个候选字的替换概率
candidates_with_prob = [] candidates_with_prob = []
for h, freq in freq_diff: for h, freq in freq_diff:
prob = self._calculate_replacement_probability(orig_freq, freq) prob = self._calculate_replacement_probability(orig_freq, freq)
if prob > 0: # 只保留有效概率的候选字 if prob > 0: # 只保留有效概率的候选字
candidates_with_prob.append((h, prob)) candidates_with_prob.append((h, prob))
if not candidates_with_prob: if not candidates_with_prob:
return None return None
# 根据概率排序 # 根据概率排序
candidates_with_prob.sort(key=lambda x: x[1], reverse=True) candidates_with_prob.sort(key=lambda x: x[1], reverse=True)
# 返回概率最高的几个字 # 返回概率最高的几个字
return [char for char, _ in candidates_with_prob[:num_candidates]] return [char for char, _ in candidates_with_prob[:num_candidates]]
@@ -223,10 +225,10 @@ class ChineseTypoGenerator:
""" """
if len(word) == 1: if len(word) == 1:
return [] return []
# 获取词的拼音 # 获取词的拼音
word_pinyin = self._get_word_pinyin(word) word_pinyin = self._get_word_pinyin(word)
# 遍历所有可能的同音字组合 # 遍历所有可能的同音字组合
candidates = [] candidates = []
for py in word_pinyin: for py in word_pinyin:
@@ -234,11 +236,11 @@ class ChineseTypoGenerator:
if not chars: if not chars:
return [] return []
candidates.append(chars) candidates.append(chars)
# 生成所有可能的组合 # 生成所有可能的组合
import itertools import itertools
all_combinations = itertools.product(*candidates) all_combinations = itertools.product(*candidates)
# 获取jieba词典和词频信息 # 获取jieba词典和词频信息
dict_path = os.path.join(os.path.dirname(jieba.__file__), 'dict.txt') dict_path = os.path.join(os.path.dirname(jieba.__file__), 'dict.txt')
valid_words = {} # 改用字典存储词语及其频率 valid_words = {} # 改用字典存储词语及其频率
@@ -249,11 +251,11 @@ class ChineseTypoGenerator:
word_text = parts[0] word_text = parts[0]
word_freq = float(parts[1]) # 获取词频 word_freq = float(parts[1]) # 获取词频
valid_words[word_text] = word_freq valid_words[word_text] = word_freq
# 获取原词的词频作为参考 # 获取原词的词频作为参考
original_word_freq = valid_words.get(word, 0) original_word_freq = valid_words.get(word, 0)
min_word_freq = original_word_freq * 0.1 # 设置最小词频为原词频的10% min_word_freq = original_word_freq * 0.1 # 设置最小词频为原词频的10%
# 过滤和计算频率 # 过滤和计算频率
homophones = [] homophones = []
for combo in all_combinations: for combo in all_combinations:
@@ -268,7 +270,7 @@ class ChineseTypoGenerator:
combined_score = (new_word_freq * 0.7 + char_avg_freq * 0.3) combined_score = (new_word_freq * 0.7 + char_avg_freq * 0.3)
if combined_score >= self.min_freq: if combined_score >= self.min_freq:
homophones.append((new_word, combined_score)) homophones.append((new_word, combined_score))
# 按综合分数排序并限制返回数量 # 按综合分数排序并限制返回数量
sorted_homophones = sorted(homophones, key=lambda x: x[1], reverse=True) sorted_homophones = sorted(homophones, key=lambda x: x[1], reverse=True)
return [word for word, _ in sorted_homophones[:5]] # 限制返回前5个结果 return [word for word, _ in sorted_homophones[:5]] # 限制返回前5个结果
@@ -286,19 +288,19 @@ class ChineseTypoGenerator:
""" """
result = [] result = []
typo_info = [] typo_info = []
# 分词 # 分词
words = self._segment_sentence(sentence) words = self._segment_sentence(sentence)
for word in words: for word in words:
# 如果是标点符号或空格,直接添加 # 如果是标点符号或空格,直接添加
if all(not self._is_chinese_char(c) for c in word): if all(not self._is_chinese_char(c) for c in word):
result.append(word) result.append(word)
continue continue
# 获取词语的拼音 # 获取词语的拼音
word_pinyin = self._get_word_pinyin(word) word_pinyin = self._get_word_pinyin(word)
# 尝试整词替换 # 尝试整词替换
if len(word) > 1 and random.random() < self.word_replace_rate: if len(word) > 1 and random.random() < self.word_replace_rate:
word_homophones = self._get_word_homophones(word) word_homophones = self._get_word_homophones(word)
@@ -307,15 +309,15 @@ class ChineseTypoGenerator:
# 计算词的平均频率 # 计算词的平均频率
orig_freq = sum(self.char_frequency.get(c, 0) for c in word) / len(word) orig_freq = sum(self.char_frequency.get(c, 0) for c in word) / len(word)
typo_freq = sum(self.char_frequency.get(c, 0) for c in typo_word) / len(typo_word) typo_freq = sum(self.char_frequency.get(c, 0) for c in typo_word) / len(typo_word)
# 添加到结果中 # 添加到结果中
result.append(typo_word) result.append(typo_word)
typo_info.append((word, typo_word, typo_info.append((word, typo_word,
' '.join(word_pinyin), ' '.join(word_pinyin),
' '.join(self._get_word_pinyin(typo_word)), ' '.join(self._get_word_pinyin(typo_word)),
orig_freq, typo_freq)) orig_freq, typo_freq))
continue continue
# 如果不进行整词替换,则进行单字替换 # 如果不进行整词替换,则进行单字替换
if len(word) == 1: if len(word) == 1:
char = word char = word
@@ -339,7 +341,7 @@ class ChineseTypoGenerator:
for i, (char, py) in enumerate(zip(word, word_pinyin)): for i, (char, py) in enumerate(zip(word, word_pinyin)):
# 词中的字替换概率降低 # 词中的字替换概率降低
word_error_rate = self.error_rate * (0.7 ** (len(word) - 1)) word_error_rate = self.error_rate * (0.7 ** (len(word) - 1))
if random.random() < word_error_rate: if random.random() < word_error_rate:
similar_chars = self._get_similar_frequency_chars(char, py) similar_chars = self._get_similar_frequency_chars(char, py)
if similar_chars: if similar_chars:
@@ -354,7 +356,7 @@ class ChineseTypoGenerator:
continue continue
word_result.append(char) word_result.append(char)
result.append(''.join(word_result)) result.append(''.join(word_result))
return ''.join(result), typo_info return ''.join(result), typo_info
def format_typo_info(self, typo_info): def format_typo_info(self, typo_info):
@@ -369,7 +371,7 @@ class ChineseTypoGenerator:
""" """
if not typo_info: if not typo_info:
return "未生成错别字" return "未生成错别字"
result = [] result = []
for orig, typo, orig_py, typo_py, orig_freq, typo_freq in typo_info: for orig, typo, orig_py, typo_py, orig_freq, typo_freq in typo_info:
# 判断是否为词语替换 # 判断是否为词语替换
@@ -379,12 +381,12 @@ class ChineseTypoGenerator:
else: else:
tone_error = orig_py[:-1] == typo_py[:-1] and orig_py[-1] != typo_py[-1] tone_error = orig_py[:-1] == typo_py[:-1] and orig_py[-1] != typo_py[-1]
error_type = "声调错误" if tone_error else "同音字替换" error_type = "声调错误" if tone_error else "同音字替换"
result.append(f"原文:{orig}({orig_py}) [频率:{orig_freq:.2f}] -> " result.append(f"原文:{orig}({orig_py}) [频率:{orig_freq:.2f}] -> "
f"替换:{typo}({typo_py}) [频率:{typo_freq:.2f}] [{error_type}]") f"替换:{typo}({typo_py}) [频率:{typo_freq:.2f}] [{error_type}]")
return "\n".join(result) return "\n".join(result)
def set_params(self, **kwargs): def set_params(self, **kwargs):
""" """
设置参数 设置参数
@@ -399,9 +401,10 @@ class ChineseTypoGenerator:
for key, value in kwargs.items(): for key, value in kwargs.items():
if hasattr(self, key): if hasattr(self, key):
setattr(self, key, value) setattr(self, key, value)
print(f"参数 {key} 已设置为 {value}") logger.debug(f"参数 {key} 已设置为 {value}")
else: else:
print(f"警告: 参数 {key} 不存在") logger.warning(f"警告: 参数 {key} 不存在")
def main(): def main():
# 创建错别字生成器实例 # 创建错别字生成器实例
@@ -411,27 +414,27 @@ def main():
tone_error_rate=0.02, tone_error_rate=0.02,
word_replace_rate=0.3 word_replace_rate=0.3
) )
# 获取用户输入 # 获取用户输入
sentence = input("请输入中文句子:") sentence = input("请输入中文句子:")
# 创建包含错别字的句子 # 创建包含错别字的句子
start_time = time.time() start_time = time.time()
typo_sentence, typo_info = typo_generator.create_typo_sentence(sentence) typo_sentence, typo_info = typo_generator.create_typo_sentence(sentence)
# 打印结果 # 打印结果
print("\n原句:", sentence) logger.debug("原句:", sentence)
print("错字版:", typo_sentence) logger.debug("错字版:", typo_sentence)
# 打印错别字信息 # 打印错别字信息
if typo_info: if typo_info:
print("\n错别字信息:") logger.debug(f"错别字信息:{typo_generator.format_typo_info(typo_info)})")
print(typo_generator.format_typo_info(typo_info))
# 计算并打印总耗时 # 计算并打印总耗时
end_time = time.time() end_time = time.time()
total_time = end_time - start_time total_time = end_time - start_time
print(f"\n总耗时:{total_time:.2f}") logger.debug(f"总耗时:{total_time:.2f}")
if __name__ == "__main__": if __name__ == "__main__":
main() main()