Merge pull request #155 from KX76/fix/20250310-logger-optimize
Fix/20250310 logger optimize
This commit is contained in:
3
bot.py
3
bot.py
@@ -100,7 +100,7 @@ def load_logger():
|
|||||||
"#777777>|</> <cyan>{name:.<8}</cyan>:<cyan>{function:.<8}</cyan>:<cyan>{line: >4}</cyan> <fg "
|
"#777777>|</> <cyan>{name:.<8}</cyan>:<cyan>{function:.<8}</cyan>:<cyan>{line: >4}</cyan> <fg "
|
||||||
"#777777>-</> <level>{message}</level>",
|
"#777777>-</> <level>{message}</level>",
|
||||||
colorize=True,
|
colorize=True,
|
||||||
level=os.getenv("LOG_LEVEL", "INFO") # 根据环境设置日志级别,默认为INFO
|
level=os.getenv("LOG_LEVEL", "DEBUG") # 根据环境设置日志级别,默认为INFO
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -149,6 +149,7 @@ if __name__ == "__main__":
|
|||||||
init_config()
|
init_config()
|
||||||
init_env()
|
init_env()
|
||||||
load_env()
|
load_env()
|
||||||
|
load_logger()
|
||||||
|
|
||||||
env_config = {key: os.getenv(key) for key in os.environ}
|
env_config = {key: os.getenv(key) for key in os.environ}
|
||||||
scan_provider(env_config)
|
scan_provider(env_config)
|
||||||
|
|||||||
@@ -5,6 +5,9 @@ import threading
|
|||||||
import time
|
import time
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Dict, List
|
from typing import Dict, List
|
||||||
|
from loguru import logger
|
||||||
|
from typing import Optional
|
||||||
|
from pymongo import MongoClient
|
||||||
|
|
||||||
import customtkinter as ctk
|
import customtkinter as ctk
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
@@ -17,23 +20,20 @@ root_dir = os.path.abspath(os.path.join(current_dir, '..', '..'))
|
|||||||
# 加载环境变量
|
# 加载环境变量
|
||||||
if os.path.exists(os.path.join(root_dir, '.env.dev')):
|
if os.path.exists(os.path.join(root_dir, '.env.dev')):
|
||||||
load_dotenv(os.path.join(root_dir, '.env.dev'))
|
load_dotenv(os.path.join(root_dir, '.env.dev'))
|
||||||
print("成功加载开发环境配置")
|
logger.info("成功加载开发环境配置")
|
||||||
elif os.path.exists(os.path.join(root_dir, '.env.prod')):
|
elif os.path.exists(os.path.join(root_dir, '.env.prod')):
|
||||||
load_dotenv(os.path.join(root_dir, '.env.prod'))
|
load_dotenv(os.path.join(root_dir, '.env.prod'))
|
||||||
print("成功加载生产环境配置")
|
logger.info("成功加载生产环境配置")
|
||||||
else:
|
else:
|
||||||
print("未找到环境配置文件")
|
logger.error("未找到环境配置文件")
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
from pymongo import MongoClient
|
|
||||||
|
|
||||||
|
|
||||||
class Database:
|
class Database:
|
||||||
_instance: Optional["Database"] = None
|
_instance: Optional["Database"] = None
|
||||||
|
|
||||||
def __init__(self, host: str, port: int, db_name: str, username: str = None, password: str = None, auth_source: str = None):
|
def __init__(self, host: str, port: int, db_name: str, username: str = None, password: str = None,
|
||||||
|
auth_source: str = None):
|
||||||
if username and password:
|
if username and password:
|
||||||
self.client = MongoClient(
|
self.client = MongoClient(
|
||||||
host=host,
|
host=host,
|
||||||
@@ -47,7 +47,8 @@ class Database:
|
|||||||
self.db = self.client[db_name]
|
self.db = self.client[db_name]
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def initialize(cls, host: str, port: int, db_name: str, username: str = None, password: str = None, auth_source: str = None) -> "Database":
|
def initialize(cls, host: str, port: int, db_name: str, username: str = None, password: str = None,
|
||||||
|
auth_source: str = None) -> "Database":
|
||||||
if cls._instance is None:
|
if cls._instance is None:
|
||||||
cls._instance = cls(host, port, db_name, username, password, auth_source)
|
cls._instance = cls(host, port, db_name, username, password, auth_source)
|
||||||
return cls._instance
|
return cls._instance
|
||||||
@@ -59,12 +60,11 @@ class Database:
|
|||||||
return cls._instance
|
return cls._instance
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class ReasoningGUI:
|
class ReasoningGUI:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
# 记录启动时间戳,转换为Unix时间戳
|
# 记录启动时间戳,转换为Unix时间戳
|
||||||
self.start_timestamp = datetime.now().timestamp()
|
self.start_timestamp = datetime.now().timestamp()
|
||||||
print(f"程序启动时间戳: {self.start_timestamp}")
|
logger.info(f"程序启动时间戳: {self.start_timestamp}")
|
||||||
|
|
||||||
# 设置主题
|
# 设置主题
|
||||||
ctk.set_appearance_mode("dark")
|
ctk.set_appearance_mode("dark")
|
||||||
@@ -79,15 +79,15 @@ class ReasoningGUI:
|
|||||||
# 初始化数据库连接
|
# 初始化数据库连接
|
||||||
try:
|
try:
|
||||||
self.db = Database.get_instance().db
|
self.db = Database.get_instance().db
|
||||||
print("数据库连接成功")
|
logger.success("数据库连接成功")
|
||||||
except RuntimeError:
|
except RuntimeError:
|
||||||
print("数据库未初始化,正在尝试初始化...")
|
logger.warning("数据库未初始化,正在尝试初始化...")
|
||||||
try:
|
try:
|
||||||
Database.initialize("127.0.0.1", 27017, "maimai_bot")
|
Database.initialize("127.0.0.1", 27017, "maimai_bot")
|
||||||
self.db = Database.get_instance().db
|
self.db = Database.get_instance().db
|
||||||
print("数据库初始化成功")
|
logger.success("数据库初始化成功")
|
||||||
except Exception as e:
|
except Exception:
|
||||||
print(f"数据库初始化失败: {e}")
|
logger.exception(f"数据库初始化失败")
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
# 存储群组数据
|
# 存储群组数据
|
||||||
@@ -285,12 +285,12 @@ class ReasoningGUI:
|
|||||||
try:
|
try:
|
||||||
# 从数据库获取最新数据,只获取启动时间之后的记录
|
# 从数据库获取最新数据,只获取启动时间之后的记录
|
||||||
query = {"time": {"$gt": self.start_timestamp}}
|
query = {"time": {"$gt": self.start_timestamp}}
|
||||||
print(f"查询条件: {query}")
|
logger.debug(f"查询条件: {query}")
|
||||||
|
|
||||||
# 先获取一条记录检查时间格式
|
# 先获取一条记录检查时间格式
|
||||||
sample = self.db.reasoning_logs.find_one()
|
sample = self.db.reasoning_logs.find_one()
|
||||||
if sample:
|
if sample:
|
||||||
print(f"样本记录时间格式: {type(sample['time'])} 值: {sample['time']}")
|
logger.debug(f"样本记录时间格式: {type(sample['time'])} 值: {sample['time']}")
|
||||||
|
|
||||||
cursor = self.db.reasoning_logs.find(query).sort("time", -1)
|
cursor = self.db.reasoning_logs.find(query).sort("time", -1)
|
||||||
new_data = {}
|
new_data = {}
|
||||||
@@ -299,7 +299,7 @@ class ReasoningGUI:
|
|||||||
for item in cursor:
|
for item in cursor:
|
||||||
# 调试输出
|
# 调试输出
|
||||||
if total_count == 0:
|
if total_count == 0:
|
||||||
print(f"记录时间: {item['time']}, 类型: {type(item['time'])}")
|
logger.debug(f"记录时间: {item['time']}, 类型: {type(item['time'])}")
|
||||||
|
|
||||||
total_count += 1
|
total_count += 1
|
||||||
group_id = str(item.get('group_id', 'unknown'))
|
group_id = str(item.get('group_id', 'unknown'))
|
||||||
@@ -312,7 +312,7 @@ class ReasoningGUI:
|
|||||||
elif isinstance(item['time'], datetime):
|
elif isinstance(item['time'], datetime):
|
||||||
time_obj = item['time']
|
time_obj = item['time']
|
||||||
else:
|
else:
|
||||||
print(f"未知的时间格式: {type(item['time'])}")
|
logger.warning(f"未知的时间格式: {type(item['time'])}")
|
||||||
time_obj = datetime.now() # 使用当前时间作为后备
|
time_obj = datetime.now() # 使用当前时间作为后备
|
||||||
|
|
||||||
new_data[group_id].append({
|
new_data[group_id].append({
|
||||||
@@ -325,12 +325,12 @@ class ReasoningGUI:
|
|||||||
'prompt': item.get('prompt', '') # 添加prompt字段
|
'prompt': item.get('prompt', '') # 添加prompt字段
|
||||||
})
|
})
|
||||||
|
|
||||||
print(f"从数据库加载了 {total_count} 条记录,分布在 {len(new_data)} 个群组中")
|
logger.info(f"从数据库加载了 {total_count} 条记录,分布在 {len(new_data)} 个群组中")
|
||||||
|
|
||||||
# 更新数据
|
# 更新数据
|
||||||
if new_data != self.group_data:
|
if new_data != self.group_data:
|
||||||
self.group_data = new_data
|
self.group_data = new_data
|
||||||
print("数据已更新,正在刷新显示...")
|
logger.info("数据已更新,正在刷新显示...")
|
||||||
# 将更新任务添加到队列
|
# 将更新任务添加到队列
|
||||||
self.update_queue.put({'type': 'update_group_list'})
|
self.update_queue.put({'type': 'update_group_list'})
|
||||||
if self.group_data:
|
if self.group_data:
|
||||||
@@ -341,8 +341,8 @@ class ReasoningGUI:
|
|||||||
'type': 'update_display',
|
'type': 'update_display',
|
||||||
'group_id': self.selected_group_id
|
'group_id': self.selected_group_id
|
||||||
})
|
})
|
||||||
except Exception as e:
|
except Exception:
|
||||||
print(f"自动更新出错: {e}")
|
logger.exception(f"自动更新出错")
|
||||||
|
|
||||||
# 每5秒更新一次
|
# 每5秒更新一次
|
||||||
time.sleep(5)
|
time.sleep(5)
|
||||||
@@ -371,6 +371,5 @@ def main():
|
|||||||
app.run()
|
app.run()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
|
|||||||
@@ -2,9 +2,8 @@ import asyncio
|
|||||||
import time
|
import time
|
||||||
|
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
from nonebot import get_driver, on_command, on_message, require
|
from nonebot import get_driver, on_message, require
|
||||||
from nonebot.adapters.onebot.v11 import Bot, GroupMessageEvent, Message, MessageSegment
|
from nonebot.adapters.onebot.v11 import Bot, GroupMessageEvent, Message, MessageSegment
|
||||||
from nonebot.rule import to_me
|
|
||||||
from nonebot.typing import T_State
|
from nonebot.typing import T_State
|
||||||
|
|
||||||
from ...common.database import Database
|
from ...common.database import Database
|
||||||
@@ -16,6 +15,10 @@ from .config import global_config
|
|||||||
from .emoji_manager import emoji_manager
|
from .emoji_manager import emoji_manager
|
||||||
from .relationship_manager import relationship_manager
|
from .relationship_manager import relationship_manager
|
||||||
from .willing_manager import willing_manager
|
from .willing_manager import willing_manager
|
||||||
|
from ..memory_system.memory import hippocampus, memory_graph
|
||||||
|
from .bot import ChatBot
|
||||||
|
from .message_sender import message_manager, message_sender
|
||||||
|
|
||||||
|
|
||||||
# 创建LLM统计实例
|
# 创建LLM统计实例
|
||||||
llm_stats = LLMStatistics("llm_statistics.txt")
|
llm_stats = LLMStatistics("llm_statistics.txt")
|
||||||
@@ -35,19 +38,13 @@ Database.initialize(
|
|||||||
password=config.MONGODB_PASSWORD,
|
password=config.MONGODB_PASSWORD,
|
||||||
auth_source=config.MONGODB_AUTH_SOURCE
|
auth_source=config.MONGODB_AUTH_SOURCE
|
||||||
)
|
)
|
||||||
print("\033[1;32m[初始化数据库完成]\033[0m")
|
logger.success("初始化数据库成功")
|
||||||
|
|
||||||
# 导入其他模块
|
|
||||||
from ..memory_system.memory import hippocampus, memory_graph
|
|
||||||
from .bot import ChatBot
|
|
||||||
|
|
||||||
# from .message_send_control import message_sender
|
|
||||||
from .message_sender import message_manager, message_sender
|
|
||||||
|
|
||||||
# 初始化表情管理器
|
# 初始化表情管理器
|
||||||
emoji_manager.initialize()
|
emoji_manager.initialize()
|
||||||
|
|
||||||
print(f"\033[1;32m正在唤醒{global_config.BOT_NICKNAME}......\033[0m")
|
logger.debug(f"正在唤醒{global_config.BOT_NICKNAME}......")
|
||||||
# 创建机器人实例
|
# 创建机器人实例
|
||||||
chat_bot = ChatBot()
|
chat_bot = ChatBot()
|
||||||
# 注册群消息处理器
|
# 注册群消息处理器
|
||||||
@@ -61,12 +58,12 @@ async def start_background_tasks():
|
|||||||
"""启动后台任务"""
|
"""启动后台任务"""
|
||||||
# 启动LLM统计
|
# 启动LLM统计
|
||||||
llm_stats.start()
|
llm_stats.start()
|
||||||
logger.success("[初始化]LLM统计功能已启动")
|
logger.success("LLM统计功能启动成功")
|
||||||
|
|
||||||
# 初始化并启动情绪管理器
|
# 初始化并启动情绪管理器
|
||||||
mood_manager = MoodManager.get_instance()
|
mood_manager = MoodManager.get_instance()
|
||||||
mood_manager.start_mood_update(update_interval=global_config.mood_update_interval)
|
mood_manager.start_mood_update(update_interval=global_config.mood_update_interval)
|
||||||
logger.success("[初始化]情绪管理器已启动")
|
logger.success("情绪管理器启动成功")
|
||||||
|
|
||||||
# 只启动表情包管理任务
|
# 只启动表情包管理任务
|
||||||
asyncio.create_task(emoji_manager.start_periodic_check(interval_MINS=global_config.EMOJI_CHECK_INTERVAL))
|
asyncio.create_task(emoji_manager.start_periodic_check(interval_MINS=global_config.EMOJI_CHECK_INTERVAL))
|
||||||
@@ -77,7 +74,7 @@ async def start_background_tasks():
|
|||||||
@driver.on_startup
|
@driver.on_startup
|
||||||
async def init_relationships():
|
async def init_relationships():
|
||||||
"""在 NoneBot2 启动时初始化关系管理器"""
|
"""在 NoneBot2 启动时初始化关系管理器"""
|
||||||
print("\033[1;32m[初始化]\033[0m 正在加载用户关系数据...")
|
logger.debug("正在加载用户关系数据...")
|
||||||
await relationship_manager.load_all_relationships()
|
await relationship_manager.load_all_relationships()
|
||||||
asyncio.create_task(relationship_manager._start_relationship_manager())
|
asyncio.create_task(relationship_manager._start_relationship_manager())
|
||||||
|
|
||||||
@@ -86,19 +83,19 @@ async def init_relationships():
|
|||||||
async def _(bot: Bot):
|
async def _(bot: Bot):
|
||||||
"""Bot连接成功时的处理"""
|
"""Bot连接成功时的处理"""
|
||||||
global _message_manager_started
|
global _message_manager_started
|
||||||
print(f"\033[1;38;5;208m-----------{global_config.BOT_NICKNAME}成功连接!-----------\033[0m")
|
logger.debug(f"-----------{global_config.BOT_NICKNAME}成功连接!-----------")
|
||||||
await willing_manager.ensure_started()
|
await willing_manager.ensure_started()
|
||||||
|
|
||||||
message_sender.set_bot(bot)
|
message_sender.set_bot(bot)
|
||||||
print("\033[1;38;5;208m-----------消息发送器已启动!-----------\033[0m")
|
logger.success("-----------消息发送器已启动!-----------")
|
||||||
|
|
||||||
if not _message_manager_started:
|
if not _message_manager_started:
|
||||||
asyncio.create_task(message_manager.start_processor())
|
asyncio.create_task(message_manager.start_processor())
|
||||||
_message_manager_started = True
|
_message_manager_started = True
|
||||||
print("\033[1;38;5;208m-----------消息处理器已启动!-----------\033[0m")
|
logger.success("-----------消息处理器已启动!-----------")
|
||||||
|
|
||||||
asyncio.create_task(emoji_manager._periodic_scan(interval_MINS=global_config.EMOJI_REGISTER_INTERVAL))
|
asyncio.create_task(emoji_manager._periodic_scan(interval_MINS=global_config.EMOJI_REGISTER_INTERVAL))
|
||||||
print("\033[1;38;5;208m-----------开始偷表情包!-----------\033[0m")
|
logger.success("-----------开始偷表情包!-----------")
|
||||||
|
|
||||||
|
|
||||||
@group_msg.handle()
|
@group_msg.handle()
|
||||||
@@ -110,13 +107,15 @@ async def _(bot: Bot, event: GroupMessageEvent, state: T_State):
|
|||||||
@scheduler.scheduled_job("interval", seconds=global_config.build_memory_interval, id="build_memory")
|
@scheduler.scheduled_job("interval", seconds=global_config.build_memory_interval, id="build_memory")
|
||||||
async def build_memory_task():
|
async def build_memory_task():
|
||||||
"""每build_memory_interval秒执行一次记忆构建"""
|
"""每build_memory_interval秒执行一次记忆构建"""
|
||||||
print(
|
logger.debug(
|
||||||
"\033[1;32m[记忆构建]\033[0m -------------------------------------------开始构建记忆-------------------------------------------")
|
"[记忆构建]"
|
||||||
|
"------------------------------------开始构建记忆--------------------------------------")
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
await hippocampus.operation_build_memory(chat_size=20)
|
await hippocampus.operation_build_memory(chat_size=20)
|
||||||
end_time = time.time()
|
end_time = time.time()
|
||||||
print(
|
logger.success(
|
||||||
f"\033[1;32m[记忆构建]\033[0m -------------------------------------------记忆构建完成:耗时: {end_time - start_time:.2f} 秒-------------------------------------------")
|
f"[记忆构建]--------------------------记忆构建完成:耗时: {end_time - start_time:.2f} "
|
||||||
|
"秒-------------------------------------------")
|
||||||
|
|
||||||
|
|
||||||
@scheduler.scheduled_job("interval", seconds=global_config.forget_memory_interval, id="forget_memory")
|
@scheduler.scheduled_job("interval", seconds=global_config.forget_memory_interval, id="forget_memory")
|
||||||
|
|||||||
@@ -70,19 +70,19 @@ class ChatBot:
|
|||||||
# 过滤词
|
# 过滤词
|
||||||
for word in global_config.ban_words:
|
for word in global_config.ban_words:
|
||||||
if word in message.detailed_plain_text:
|
if word in message.detailed_plain_text:
|
||||||
logger.info(f"\033[1;32m[{message.group_name}]{message.user_nickname}:\033[0m {message.processed_plain_text}")
|
logger.info(
|
||||||
logger.info(f"\033[1;32m[过滤词识别]\033[0m 消息中含有{word},filtered")
|
f"[{message.group_name}]{message.user_nickname}:{message.processed_plain_text}")
|
||||||
|
logger.info(f"[过滤词识别]消息中含有{word},filtered")
|
||||||
return
|
return
|
||||||
|
|
||||||
current_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(message.time))
|
current_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(message.time))
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# topic=await topic_identifier.identify_topic_llm(message.processed_plain_text)
|
# topic=await topic_identifier.identify_topic_llm(message.processed_plain_text)
|
||||||
topic = ''
|
topic = ''
|
||||||
interested_rate = 0
|
interested_rate = 0
|
||||||
interested_rate = await hippocampus.memory_activate_value(message.processed_plain_text) / 100
|
interested_rate = await hippocampus.memory_activate_value(message.processed_plain_text) / 100
|
||||||
print(f"\033[1;32m[记忆激活]\033[0m 对{message.processed_plain_text}的激活度:---------------------------------------{interested_rate}\n")
|
logger.debug(f"对{message.processed_plain_text}"
|
||||||
|
f"的激活度:{interested_rate}")
|
||||||
# logger.info(f"\033[1;32m[主题识别]\033[0m 使用{global_config.topic_extract}主题: {topic}")
|
# logger.info(f"\033[1;32m[主题识别]\033[0m 使用{global_config.topic_extract}主题: {topic}")
|
||||||
|
|
||||||
await self.storage.store_message(message, topic[0] if topic else None)
|
await self.storage.store_message(message, topic[0] if topic else None)
|
||||||
@@ -99,14 +99,13 @@ class ChatBot:
|
|||||||
)
|
)
|
||||||
current_willing = willing_manager.get_willing(event.group_id)
|
current_willing = willing_manager.get_willing(event.group_id)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
print(f"\033[1;32m[{current_time}][{message.group_name}]{message.user_nickname}:\033[0m {message.processed_plain_text}\033[1;36m[回复意愿:{current_willing:.2f}][概率:{reply_probability * 100:.1f}%]\033[0m")
|
f"[{current_time}][{message.group_name}]{message.user_nickname}:"
|
||||||
|
f"{message.processed_plain_text}[回复意愿:{current_willing:.2f}][概率:{reply_probability * 100:.1f}%]")
|
||||||
|
|
||||||
response = ""
|
response = ""
|
||||||
|
|
||||||
if random() < reply_probability:
|
if random() < reply_probability:
|
||||||
|
|
||||||
|
|
||||||
tinking_time_point = round(time.time(), 2)
|
tinking_time_point = round(time.time(), 2)
|
||||||
think_id = 'mt' + str(tinking_time_point)
|
think_id = 'mt' + str(tinking_time_point)
|
||||||
thinking_message = Message_Thinking(message=message, message_id=think_id)
|
thinking_message = Message_Thinking(message=message, message_id=think_id)
|
||||||
@@ -130,12 +129,13 @@ class ChatBot:
|
|||||||
|
|
||||||
# 如果找不到思考消息,直接返回
|
# 如果找不到思考消息,直接返回
|
||||||
if not thinking_message:
|
if not thinking_message:
|
||||||
print(f"\033[1;33m[警告]\033[0m 未找到对应的思考消息,可能已超时被移除")
|
logger.warning(f"未找到对应的思考消息,可能已超时被移除")
|
||||||
return
|
return
|
||||||
|
|
||||||
# 记录开始思考的时间,避免从思考到回复的时间太久
|
# 记录开始思考的时间,避免从思考到回复的时间太久
|
||||||
thinking_start_time = thinking_message.thinking_start_time
|
thinking_start_time = thinking_message.thinking_start_time
|
||||||
message_set = MessageSet(event.group_id, global_config.BOT_QQ, think_id) # 发送消息的id和产生发送消息的message_thinking是一致的
|
message_set = MessageSet(event.group_id, global_config.BOT_QQ,
|
||||||
|
think_id) # 发送消息的id和产生发送消息的message_thinking是一致的
|
||||||
# 计算打字时间,1是为了模拟打字,2是避免多条回复乱序
|
# 计算打字时间,1是为了模拟打字,2是避免多条回复乱序
|
||||||
accu_typing_time = 0
|
accu_typing_time = 0
|
||||||
|
|
||||||
@@ -206,7 +206,7 @@ class ChatBot:
|
|||||||
await bot_message.initialize()
|
await bot_message.initialize()
|
||||||
message_manager.add_message(bot_message)
|
message_manager.add_message(bot_message)
|
||||||
emotion = await self.gpt._get_emotion_tags(raw_content)
|
emotion = await self.gpt._get_emotion_tags(raw_content)
|
||||||
print(f"为 '{response}' 获取到的情感标签为:{emotion}")
|
logger.debug(f"为 '{response}' 获取到的情感标签为:{emotion}")
|
||||||
valuedict = {
|
valuedict = {
|
||||||
'happy': 0.5,
|
'happy': 0.5,
|
||||||
'angry': -1,
|
'angry': -1,
|
||||||
@@ -216,11 +216,13 @@ class ChatBot:
|
|||||||
'fearful': -0.7,
|
'fearful': -0.7,
|
||||||
'neutral': 0.1
|
'neutral': 0.1
|
||||||
}
|
}
|
||||||
await relationship_manager.update_relationship_value(message.user_id, relationship_value=valuedict[emotion[0]])
|
await relationship_manager.update_relationship_value(message.user_id,
|
||||||
|
relationship_value=valuedict[emotion[0]])
|
||||||
# 使用情绪管理器更新情绪
|
# 使用情绪管理器更新情绪
|
||||||
self.mood_manager.update_mood_from_emotion(emotion[0], global_config.mood_intensity_factor)
|
self.mood_manager.update_mood_from_emotion(emotion[0], global_config.mood_intensity_factor)
|
||||||
|
|
||||||
# willing_manager.change_reply_willing_after_sent(event.group_id)
|
# willing_manager.change_reply_willing_after_sent(event.group_id)
|
||||||
|
|
||||||
|
|
||||||
# 创建全局ChatBot实例
|
# 创建全局ChatBot实例
|
||||||
chat_bot = ChatBot()
|
chat_bot = ChatBot()
|
||||||
@@ -135,7 +135,7 @@ class BotConfig:
|
|||||||
try:
|
try:
|
||||||
config_version: str = toml["inner"]["version"]
|
config_version: str = toml["inner"]["version"]
|
||||||
except KeyError as e:
|
except KeyError as e:
|
||||||
logger.error(f"配置文件中 inner 段 不存在 {e}, 这是错误的配置文件")
|
logger.error(f"配置文件中 inner 段 不存在, 这是错误的配置文件")
|
||||||
raise KeyError(f"配置文件中 inner 段 不存在 {e}, 这是错误的配置文件")
|
raise KeyError(f"配置文件中 inner 段 不存在 {e}, 这是错误的配置文件")
|
||||||
else:
|
else:
|
||||||
toml["inner"] = {"version": "0.0.0"}
|
toml["inner"] = {"version": "0.0.0"}
|
||||||
@@ -162,7 +162,7 @@ class BotConfig:
|
|||||||
personality_config = parent['personality']
|
personality_config = parent['personality']
|
||||||
personality = personality_config.get('prompt_personality')
|
personality = personality_config.get('prompt_personality')
|
||||||
if len(personality) >= 2:
|
if len(personality) >= 2:
|
||||||
logger.info(f"载入自定义人格:{personality}")
|
logger.debug(f"载入自定义人格:{personality}")
|
||||||
config.PROMPT_PERSONALITY = personality_config.get('prompt_personality', config.PROMPT_PERSONALITY)
|
config.PROMPT_PERSONALITY = personality_config.get('prompt_personality', config.PROMPT_PERSONALITY)
|
||||||
logger.info(f"载入自定义日程prompt:{personality_config.get('prompt_schedule', config.PROMPT_SCHEDULE_GEN)}")
|
logger.info(f"载入自定义日程prompt:{personality_config.get('prompt_schedule', config.PROMPT_SCHEDULE_GEN)}")
|
||||||
config.PROMPT_SCHEDULE_GEN = personality_config.get('prompt_schedule', config.PROMPT_SCHEDULE_GEN)
|
config.PROMPT_SCHEDULE_GEN = personality_config.get('prompt_schedule', config.PROMPT_SCHEDULE_GEN)
|
||||||
@@ -246,11 +246,11 @@ class BotConfig:
|
|||||||
try:
|
try:
|
||||||
cfg_target[i] = cfg_item[i]
|
cfg_target[i] = cfg_item[i]
|
||||||
except KeyError as e:
|
except KeyError as e:
|
||||||
logger.error(f"{item} 中的必要字段 {e} 不存在,请检查")
|
logger.error(f"{item} 中的必要字段不存在,请检查")
|
||||||
raise KeyError(f"{item} 中的必要字段 {e} 不存在,请检查")
|
raise KeyError(f"{item} 中的必要字段 {e} 不存在,请检查")
|
||||||
|
|
||||||
provider = cfg_item.get("provider")
|
provider = cfg_item.get("provider")
|
||||||
if provider == None:
|
if provider is None:
|
||||||
logger.error(f"provider 字段在模型配置 {item} 中不存在,请检查")
|
logger.error(f"provider 字段在模型配置 {item} 中不存在,请检查")
|
||||||
raise KeyError(f"provider 字段在模型配置 {item} 中不存在,请检查")
|
raise KeyError(f"provider 字段在模型配置 {item} 中不存在,请检查")
|
||||||
|
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import os
|
|||||||
import time
|
import time
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Dict, Optional
|
from typing import Dict, Optional
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
|
|
||||||
@@ -151,11 +152,11 @@ class CQCode:
|
|||||||
|
|
||||||
except (requests.exceptions.SSLError, requests.exceptions.HTTPError) as e:
|
except (requests.exceptions.SSLError, requests.exceptions.HTTPError) as e:
|
||||||
if retry == max_retries - 1:
|
if retry == max_retries - 1:
|
||||||
print(f"\033[1;31m[致命错误]\033[0m 最终请求失败: {str(e)}")
|
logger.error(f"最终请求失败: {str(e)}")
|
||||||
time.sleep(1.5 ** retry) # 指数退避
|
time.sleep(1.5 ** retry) # 指数退避
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"\033[1;33m[未知错误]\033[0m {str(e)}")
|
logger.exception(f"[未知错误]")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
return None
|
return None
|
||||||
@@ -194,7 +195,7 @@ class CQCode:
|
|||||||
description, _ = await self._llm.generate_response_for_image(prompt, image_base64)
|
description, _ = await self._llm.generate_response_for_image(prompt, image_base64)
|
||||||
return f"[表情包:{description}]"
|
return f"[表情包:{description}]"
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"\033[1;31m[错误]\033[0m AI接口调用失败: {str(e)}")
|
logger.exception(f"AI接口调用失败: {str(e)}")
|
||||||
return "[表情包]"
|
return "[表情包]"
|
||||||
|
|
||||||
async def get_image_description(self, image_base64: str) -> str:
|
async def get_image_description(self, image_base64: str) -> str:
|
||||||
@@ -205,7 +206,7 @@ class CQCode:
|
|||||||
description, _ = await self._llm.generate_response_for_image(prompt, image_base64)
|
description, _ = await self._llm.generate_response_for_image(prompt, image_base64)
|
||||||
return f"[图片:{description}]"
|
return f"[图片:{description}]"
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"\033[1;31m[错误]\033[0m AI接口调用失败: {str(e)}")
|
logger.exception(f"AI接口调用失败: {str(e)}")
|
||||||
return "[图片]"
|
return "[图片]"
|
||||||
|
|
||||||
async def translate_forward(self) -> str:
|
async def translate_forward(self) -> str:
|
||||||
@@ -222,7 +223,7 @@ class CQCode:
|
|||||||
try:
|
try:
|
||||||
messages = ast.literal_eval(content)
|
messages = ast.literal_eval(content)
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
print(f"\033[1;31m[错误]\033[0m 解析转发消息内容失败: {str(e)}")
|
logger.error(f"解析转发消息内容失败: {str(e)}")
|
||||||
return '[转发消息]'
|
return '[转发消息]'
|
||||||
|
|
||||||
# 处理每条消息
|
# 处理每条消息
|
||||||
@@ -277,11 +278,11 @@ class CQCode:
|
|||||||
|
|
||||||
# 合并所有消息
|
# 合并所有消息
|
||||||
combined_messages = '\n'.join(formatted_messages)
|
combined_messages = '\n'.join(formatted_messages)
|
||||||
print(f"\033[1;34m[调试信息]\033[0m 合并后的转发消息: {combined_messages}")
|
logger.debug(f"合并后的转发消息: {combined_messages}")
|
||||||
return f"[转发消息:\n{combined_messages}]"
|
return f"[转发消息:\n{combined_messages}]"
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"\033[1;31m[错误]\033[0m 处理转发消息失败: {str(e)}")
|
logger.exception("处理转发消息失败")
|
||||||
return '[转发消息]'
|
return '[转发消息]'
|
||||||
|
|
||||||
async def translate_reply(self) -> str:
|
async def translate_reply(self) -> str:
|
||||||
@@ -307,7 +308,7 @@ class CQCode:
|
|||||||
return f"[回复 {self.reply_message.sender.nickname} 的消息: {message_obj.processed_plain_text}]"
|
return f"[回复 {self.reply_message.sender.nickname} 的消息: {message_obj.processed_plain_text}]"
|
||||||
|
|
||||||
else:
|
else:
|
||||||
print("\033[1;31m[错误]\033[0m 回复消息的sender.user_id为空")
|
logger.error("回复消息的sender.user_id为空")
|
||||||
return '[回复某人消息]'
|
return '[回复某人消息]'
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|||||||
@@ -33,7 +33,9 @@ class EmojiManager:
|
|||||||
self.db = Database.get_instance()
|
self.db = Database.get_instance()
|
||||||
self._scan_task = None
|
self._scan_task = None
|
||||||
self.vlm = LLM_request(model=global_config.vlm, temperature=0.3, max_tokens=1000)
|
self.vlm = LLM_request(model=global_config.vlm, temperature=0.3, max_tokens=1000)
|
||||||
self.llm_emotion_judge = LLM_request(model=global_config.llm_emotion_judge, max_tokens=60,temperature=0.8) #更高的温度,更少的token(后续可以根据情绪来调整温度)
|
self.llm_emotion_judge = LLM_request(model=global_config.llm_normal_minor, max_tokens=60,
|
||||||
|
temperature=0.8) # 更高的温度,更少的token(后续可以根据情绪来调整温度)
|
||||||
|
|
||||||
|
|
||||||
def _ensure_emoji_dir(self):
|
def _ensure_emoji_dir(self):
|
||||||
"""确保表情存储目录存在"""
|
"""确保表情存储目录存在"""
|
||||||
@@ -50,7 +52,7 @@ class EmojiManager:
|
|||||||
# 启动时执行一次完整性检查
|
# 启动时执行一次完整性检查
|
||||||
self.check_emoji_file_integrity()
|
self.check_emoji_file_integrity()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"初始化表情管理器失败: {str(e)}")
|
logger.exception(f"初始化表情管理器失败")
|
||||||
|
|
||||||
def _ensure_db(self):
|
def _ensure_db(self):
|
||||||
"""确保数据库已初始化"""
|
"""确保数据库已初始化"""
|
||||||
@@ -86,7 +88,7 @@ class EmojiManager:
|
|||||||
{'$inc': {'usage_count': 1}}
|
{'$inc': {'usage_count': 1}}
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"记录表情使用失败: {str(e)}")
|
logger.exception(f"记录表情使用失败")
|
||||||
|
|
||||||
async def get_emoji_for_text(self, text: str) -> Optional[str]:
|
async def get_emoji_for_text(self, text: str) -> Optional[str]:
|
||||||
"""根据文本内容获取相关表情包
|
"""根据文本内容获取相关表情包
|
||||||
@@ -157,7 +159,8 @@ class EmojiManager:
|
|||||||
{'_id': selected_emoji['_id']},
|
{'_id': selected_emoji['_id']},
|
||||||
{'$inc': {'usage_count': 1}}
|
{'$inc': {'usage_count': 1}}
|
||||||
)
|
)
|
||||||
logger.success(f"找到匹配的表情包: {selected_emoji.get('discription', '无描述')} (相似度: {similarity:.4f})")
|
logger.success(
|
||||||
|
f"找到匹配的表情包: {selected_emoji.get('discription', '无描述')} (相似度: {similarity:.4f})")
|
||||||
# 稍微改一下文本描述,不然容易产生幻觉,描述已经包含 表情包 了
|
# 稍微改一下文本描述,不然容易产生幻觉,描述已经包含 表情包 了
|
||||||
return selected_emoji['path'], "[ %s ]" % selected_emoji.get('discription', '无描述')
|
return selected_emoji['path'], "[ %s ]" % selected_emoji.get('discription', '无描述')
|
||||||
|
|
||||||
@@ -215,7 +218,8 @@ class EmojiManager:
|
|||||||
os.makedirs(emoji_dir, exist_ok=True)
|
os.makedirs(emoji_dir, exist_ok=True)
|
||||||
|
|
||||||
# 获取所有支持的图片文件
|
# 获取所有支持的图片文件
|
||||||
files_to_process = [f for f in os.listdir(emoji_dir) if f.lower().endswith(('.jpg', '.jpeg', '.png', '.gif'))]
|
files_to_process = [f for f in os.listdir(emoji_dir) if
|
||||||
|
f.lower().endswith(('.jpg', '.jpeg', '.png', '.gif'))]
|
||||||
|
|
||||||
for filename in files_to_process:
|
for filename in files_to_process:
|
||||||
image_path = os.path.join(emoji_dir, filename)
|
image_path = os.path.join(emoji_dir, filename)
|
||||||
@@ -260,17 +264,15 @@ class EmojiManager:
|
|||||||
logger.warning(f"跳过表情包: {filename}")
|
logger.warning(f"跳过表情包: {filename}")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"扫描表情包失败: {str(e)}")
|
logger.exception(f"扫描表情包失败")
|
||||||
logger.error(traceback.format_exc())
|
|
||||||
|
|
||||||
async def _periodic_scan(self, interval_MINS: int = 10):
|
async def _periodic_scan(self, interval_MINS: int = 10):
|
||||||
"""定期扫描新表情包"""
|
"""定期扫描新表情包"""
|
||||||
while True:
|
while True:
|
||||||
print("\033[1;36m[表情包]\033[0m 开始扫描新表情包...")
|
logger.info("开始扫描新表情包...")
|
||||||
await self.scan_new_emojis()
|
await self.scan_new_emojis()
|
||||||
await asyncio.sleep(interval_MINS * 60) # 每600秒扫描一次
|
await asyncio.sleep(interval_MINS * 60) # 每600秒扫描一次
|
||||||
|
|
||||||
|
|
||||||
def check_emoji_file_integrity(self):
|
def check_emoji_file_integrity(self):
|
||||||
"""检查表情包文件完整性
|
"""检查表情包文件完整性
|
||||||
如果文件已被删除,则从数据库中移除对应记录
|
如果文件已被删除,则从数据库中移除对应记录
|
||||||
@@ -302,7 +304,7 @@ class EmojiManager:
|
|||||||
# 从数据库中删除记录
|
# 从数据库中删除记录
|
||||||
result = self.db.db.emoji.delete_one({'_id': emoji['_id']})
|
result = self.db.db.emoji.delete_one({'_id': emoji['_id']})
|
||||||
if result.deleted_count > 0:
|
if result.deleted_count > 0:
|
||||||
logger.success(f"成功删除数据库记录: {emoji['_id']}")
|
logger.debug(f"成功删除数据库记录: {emoji['_id']}")
|
||||||
removed_count += 1
|
removed_count += 1
|
||||||
else:
|
else:
|
||||||
logger.error(f"删除数据库记录失败: {emoji['_id']}")
|
logger.error(f"删除数据库记录失败: {emoji['_id']}")
|
||||||
@@ -328,6 +330,6 @@ class EmojiManager:
|
|||||||
await asyncio.sleep(interval_MINS * 60)
|
await asyncio.sleep(interval_MINS * 60)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# 创建全局单例
|
# 创建全局单例
|
||||||
emoji_manager = EmojiManager()
|
emoji_manager = EmojiManager()
|
||||||
|
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ import time
|
|||||||
from typing import List, Optional, Tuple, Union
|
from typing import List, Optional, Tuple, Union
|
||||||
|
|
||||||
from nonebot import get_driver
|
from nonebot import get_driver
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
from ...common.database import Database
|
from ...common.database import Database
|
||||||
from ..models.utils_model import LLM_request
|
from ..models.utils_model import LLM_request
|
||||||
@@ -39,13 +40,13 @@ class ResponseGenerator:
|
|||||||
self.current_model_type = 'r1_distill'
|
self.current_model_type = 'r1_distill'
|
||||||
current_model = self.model_r1_distill
|
current_model = self.model_r1_distill
|
||||||
|
|
||||||
print(f"+++++++++++++++++{global_config.BOT_NICKNAME}{self.current_model_type}思考中+++++++++++++++++")
|
logger.info(f"{global_config.BOT_NICKNAME}{self.current_model_type}思考中")
|
||||||
|
|
||||||
model_response = await self._generate_response_with_model(message, current_model)
|
model_response = await self._generate_response_with_model(message, current_model)
|
||||||
raw_content=model_response
|
raw_content=model_response
|
||||||
|
|
||||||
if model_response:
|
if model_response:
|
||||||
print(f'{global_config.BOT_NICKNAME}的回复是:{model_response}')
|
logger.info(f'{global_config.BOT_NICKNAME}的回复是:{model_response}')
|
||||||
model_response = await self._process_response(model_response)
|
model_response = await self._process_response(model_response)
|
||||||
if model_response:
|
if model_response:
|
||||||
|
|
||||||
@@ -92,8 +93,8 @@ class ResponseGenerator:
|
|||||||
# 生成回复
|
# 生成回复
|
||||||
try:
|
try:
|
||||||
content, reasoning_content = await model.generate_response(prompt)
|
content, reasoning_content = await model.generate_response(prompt)
|
||||||
except Exception as e:
|
except Exception:
|
||||||
print(f"生成回复时出错: {e}")
|
logger.exception(f"生成回复时出错")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# 保存到数据库
|
# 保存到数据库
|
||||||
@@ -144,8 +145,8 @@ class ResponseGenerator:
|
|||||||
else:
|
else:
|
||||||
return ["neutral"]
|
return ["neutral"]
|
||||||
|
|
||||||
except Exception as e:
|
except Exception:
|
||||||
print(f"获取情感标签时出错: {e}")
|
logger.exception(f"获取情感标签时出错")
|
||||||
return ["neutral"]
|
return ["neutral"]
|
||||||
|
|
||||||
async def _process_response(self, content: str) -> Tuple[List[str], List[str]]:
|
async def _process_response(self, content: str) -> Tuple[List[str], List[str]]:
|
||||||
@@ -172,7 +173,7 @@ class InitiativeMessageGenerate:
|
|||||||
prompt_builder._build_initiative_prompt_select(message.group_id)
|
prompt_builder._build_initiative_prompt_select(message.group_id)
|
||||||
)
|
)
|
||||||
content_select, reasoning = self.model_v3.generate_response(topic_select_prompt)
|
content_select, reasoning = self.model_v3.generate_response(topic_select_prompt)
|
||||||
print(f"[DEBUG] {content_select} {reasoning}")
|
logger.debug(f"{content_select} {reasoning}")
|
||||||
topics_list = [dot[0] for dot in dots_for_select]
|
topics_list = [dot[0] for dot in dots_for_select]
|
||||||
if content_select:
|
if content_select:
|
||||||
if content_select in topics_list:
|
if content_select in topics_list:
|
||||||
@@ -185,12 +186,12 @@ class InitiativeMessageGenerate:
|
|||||||
select_dot[1], prompt_template
|
select_dot[1], prompt_template
|
||||||
)
|
)
|
||||||
content_check, reasoning_check = self.model_v3.generate_response(prompt_check)
|
content_check, reasoning_check = self.model_v3.generate_response(prompt_check)
|
||||||
print(f"[DEBUG] {content_check} {reasoning_check}")
|
logger.info(f"{content_check} {reasoning_check}")
|
||||||
if "yes" not in content_check.lower():
|
if "yes" not in content_check.lower():
|
||||||
return None
|
return None
|
||||||
prompt = prompt_builder._build_initiative_prompt(
|
prompt = prompt_builder._build_initiative_prompt(
|
||||||
select_dot, prompt_template, memory
|
select_dot, prompt_template, memory
|
||||||
)
|
)
|
||||||
content, reasoning = self.model_r1.generate_response_async(prompt)
|
content, reasoning = self.model_r1.generate_response_async(prompt)
|
||||||
print(f"[DEBUG] {content} {reasoning}")
|
logger.debug(f"[DEBUG] {content} {reasoning}")
|
||||||
return content
|
return content
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ import asyncio
|
|||||||
import time
|
import time
|
||||||
from typing import Dict, List, Optional, Union
|
from typing import Dict, List, Optional, Union
|
||||||
|
|
||||||
|
from loguru import logger
|
||||||
from nonebot.adapters.onebot.v11 import Bot
|
from nonebot.adapters.onebot.v11 import Bot
|
||||||
|
|
||||||
from .cq_code import cq_code_tool
|
from .cq_code import cq_code_tool
|
||||||
@@ -13,6 +14,7 @@ from .config import global_config
|
|||||||
|
|
||||||
class Message_Sender:
|
class Message_Sender:
|
||||||
"""发送器"""
|
"""发送器"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.message_interval = (0.5, 1) # 消息间隔时间范围(秒)
|
self.message_interval = (0.5, 1) # 消息间隔时间范围(秒)
|
||||||
self.last_send_time = 0
|
self.last_send_time = 0
|
||||||
@@ -46,7 +48,6 @@ class Message_Sender:
|
|||||||
# at_cq = cq_code_tool.create_at_cq(at_user_id)
|
# at_cq = cq_code_tool.create_at_cq(at_user_id)
|
||||||
# message = at_cq + " " + message
|
# message = at_cq + " " + message
|
||||||
|
|
||||||
|
|
||||||
typing_time = calculate_typing_time(message)
|
typing_time = calculate_typing_time(message)
|
||||||
if typing_time > 10:
|
if typing_time > 10:
|
||||||
typing_time = 10
|
typing_time = 10
|
||||||
@@ -59,14 +60,14 @@ class Message_Sender:
|
|||||||
message=message,
|
message=message,
|
||||||
auto_escape=auto_escape
|
auto_escape=auto_escape
|
||||||
)
|
)
|
||||||
print(f"\033[1;34m[调试]\033[0m 发送消息{message}成功")
|
logger.debug(f"发送消息{message}成功")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"发生错误 {e}")
|
logger.exception(f"发送消息{message}失败")
|
||||||
print(f"\033[1;34m[调试]\033[0m 发送消息{message}失败")
|
|
||||||
|
|
||||||
|
|
||||||
class MessageContainer:
|
class MessageContainer:
|
||||||
"""单个群的发送/思考消息容器"""
|
"""单个群的发送/思考消息容器"""
|
||||||
|
|
||||||
def __init__(self, group_id: int, max_size: int = 100):
|
def __init__(self, group_id: int, max_size: int = 100):
|
||||||
self.group_id = group_id
|
self.group_id = group_id
|
||||||
self.max_size = max_size
|
self.max_size = max_size
|
||||||
@@ -118,8 +119,8 @@ class MessageContainer:
|
|||||||
self.messages.remove(message)
|
self.messages.remove(message)
|
||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
except Exception as e:
|
except Exception:
|
||||||
print(f"\033[1;31m[错误]\033[0m 移除消息时发生错误: {e}")
|
logger.exception(f"移除消息时发生错误")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def has_messages(self) -> bool:
|
def has_messages(self) -> bool:
|
||||||
@@ -133,6 +134,7 @@ class MessageContainer:
|
|||||||
|
|
||||||
class MessageManager:
|
class MessageManager:
|
||||||
"""管理所有群的消息容器"""
|
"""管理所有群的消息容器"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.containers: Dict[int, MessageContainer] = {}
|
self.containers: Dict[int, MessageContainer] = {}
|
||||||
self.storage = MessageStorage()
|
self.storage = MessageStorage()
|
||||||
@@ -162,19 +164,22 @@ class MessageManager:
|
|||||||
# 优先等待这条消息
|
# 优先等待这条消息
|
||||||
message_earliest.update_thinking_time()
|
message_earliest.update_thinking_time()
|
||||||
thinking_time = message_earliest.thinking_time
|
thinking_time = message_earliest.thinking_time
|
||||||
print(f"\033[1;34m[调试]\033[0m 消息正在思考中,已思考{int(thinking_time)}秒\033[K\r", end='', flush=True)
|
print(f"消息正在思考中,已思考{int(thinking_time)}秒\r", end='', flush=True)
|
||||||
|
|
||||||
# 检查是否超时
|
# 检查是否超时
|
||||||
if thinking_time > global_config.thinking_timeout:
|
if thinking_time > global_config.thinking_timeout:
|
||||||
print(f"\033[1;33m[警告]\033[0m 消息思考超时({thinking_time}秒),移除该消息")
|
logger.warning(f"消息思考超时({thinking_time}秒),移除该消息")
|
||||||
container.remove_message(message_earliest)
|
container.remove_message(message_earliest)
|
||||||
else: # 如果不是message_thinking就只能是message_sending
|
else: # 如果不是message_thinking就只能是message_sending
|
||||||
print(f"\033[1;34m[调试]\033[0m 消息'{message_earliest.processed_plain_text}'正在发送中")
|
logger.debug(f"消息'{message_earliest.processed_plain_text}'正在发送中")
|
||||||
# 直接发,等什么呢
|
# 直接发,等什么呢
|
||||||
if message_earliest.is_head and message_earliest.update_thinking_time() > 30:
|
if message_earliest.is_head and message_earliest.update_thinking_time() > 30:
|
||||||
await message_sender.send_group_message(group_id, message_earliest.processed_plain_text, auto_escape=False, reply_message_id=message_earliest.reply_message_id)
|
await message_sender.send_group_message(group_id, message_earliest.processed_plain_text,
|
||||||
|
auto_escape=False,
|
||||||
|
reply_message_id=message_earliest.reply_message_id)
|
||||||
else:
|
else:
|
||||||
await message_sender.send_group_message(group_id, message_earliest.processed_plain_text, auto_escape=False)
|
await message_sender.send_group_message(group_id, message_earliest.processed_plain_text,
|
||||||
|
auto_escape=False)
|
||||||
# 移除消息
|
# 移除消息
|
||||||
if message_earliest.is_emoji:
|
if message_earliest.is_emoji:
|
||||||
message_earliest.processed_plain_text = "[表情包]"
|
message_earliest.processed_plain_text = "[表情包]"
|
||||||
@@ -185,7 +190,7 @@ class MessageManager:
|
|||||||
# 获取并处理超时消息
|
# 获取并处理超时消息
|
||||||
message_timeout = container.get_timeout_messages() # 也许是一堆message_sending
|
message_timeout = container.get_timeout_messages() # 也许是一堆message_sending
|
||||||
if message_timeout:
|
if message_timeout:
|
||||||
print(f"\033[1;34m[调试]\033[0m 发现{len(message_timeout)}条超时消息")
|
logger.warning(f"发现{len(message_timeout)}条超时消息")
|
||||||
for msg in message_timeout:
|
for msg in message_timeout:
|
||||||
if msg == message_earliest:
|
if msg == message_earliest:
|
||||||
continue # 跳过已经处理过的消息
|
continue # 跳过已经处理过的消息
|
||||||
@@ -193,10 +198,12 @@ class MessageManager:
|
|||||||
try:
|
try:
|
||||||
# 发送
|
# 发送
|
||||||
if msg.is_head and msg.update_thinking_time() > 30:
|
if msg.is_head and msg.update_thinking_time() > 30:
|
||||||
await message_sender.send_group_message(group_id, msg.processed_plain_text, auto_escape=False, reply_message_id=msg.reply_message_id)
|
await message_sender.send_group_message(group_id, msg.processed_plain_text,
|
||||||
|
auto_escape=False,
|
||||||
|
reply_message_id=msg.reply_message_id)
|
||||||
else:
|
else:
|
||||||
await message_sender.send_group_message(group_id, msg.processed_plain_text, auto_escape=False)
|
await message_sender.send_group_message(group_id, msg.processed_plain_text,
|
||||||
|
auto_escape=False)
|
||||||
|
|
||||||
# 如果是表情包,则替换为"[表情包]"
|
# 如果是表情包,则替换为"[表情包]"
|
||||||
if msg.is_emoji:
|
if msg.is_emoji:
|
||||||
@@ -205,9 +212,9 @@ class MessageManager:
|
|||||||
|
|
||||||
# 安全地移除消息
|
# 安全地移除消息
|
||||||
if not container.remove_message(msg):
|
if not container.remove_message(msg):
|
||||||
print("\033[1;33m[警告]\033[0m 尝试删除不存在的消息")
|
logger.warning("尝试删除不存在的消息")
|
||||||
except Exception as e:
|
except Exception:
|
||||||
print(f"\033[1;31m[错误]\033[0m 处理超时消息时发生错误: {e}")
|
logger.exception(f"处理超时消息时发生错误")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
async def start_processor(self):
|
async def start_processor(self):
|
||||||
@@ -220,6 +227,7 @@ class MessageManager:
|
|||||||
|
|
||||||
await asyncio.gather(*tasks)
|
await asyncio.gather(*tasks)
|
||||||
|
|
||||||
|
|
||||||
# 创建全局消息管理器实例
|
# 创建全局消息管理器实例
|
||||||
message_manager = MessageManager()
|
message_manager = MessageManager()
|
||||||
# 创建全局发送器实例
|
# 创建全局发送器实例
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
import random
|
import random
|
||||||
import time
|
import time
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
from ...common.database import Database
|
from ...common.database import Database
|
||||||
from ..memory_system.memory import hippocampus, memory_graph
|
from ..memory_system.memory import hippocampus, memory_graph
|
||||||
@@ -16,8 +17,6 @@ class PromptBuilder:
|
|||||||
self.activate_messages = ''
|
self.activate_messages = ''
|
||||||
self.db = Database.get_instance()
|
self.db = Database.get_instance()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
async def _build_prompt(self,
|
async def _build_prompt(self,
|
||||||
message_txt: str,
|
message_txt: str,
|
||||||
sender_name: str = "某人",
|
sender_name: str = "某人",
|
||||||
@@ -47,12 +46,10 @@ class PromptBuilder:
|
|||||||
|
|
||||||
# 开始构建prompt
|
# 开始构建prompt
|
||||||
|
|
||||||
|
|
||||||
# 心情
|
# 心情
|
||||||
mood_manager = MoodManager.get_instance()
|
mood_manager = MoodManager.get_instance()
|
||||||
mood_prompt = mood_manager.get_prompt()
|
mood_prompt = mood_manager.get_prompt()
|
||||||
|
|
||||||
|
|
||||||
# 日程构建
|
# 日程构建
|
||||||
current_date = time.strftime("%Y-%m-%d", time.localtime())
|
current_date = time.strftime("%Y-%m-%d", time.localtime())
|
||||||
current_time = time.strftime("%H:%M:%S", time.localtime())
|
current_time = time.strftime("%H:%M:%S", time.localtime())
|
||||||
@@ -66,20 +63,21 @@ class PromptBuilder:
|
|||||||
promt_info_prompt = ''
|
promt_info_prompt = ''
|
||||||
prompt_info = await self.get_prompt_info(message_txt, threshold=0.5)
|
prompt_info = await self.get_prompt_info(message_txt, threshold=0.5)
|
||||||
if prompt_info:
|
if prompt_info:
|
||||||
prompt_info = f'''\n----------------------------------------------------\n你有以下这些[知识]:\n{prompt_info}\n请你记住上面的[知识],之后可能会用到\n----------------------------------------------------\n'''
|
prompt_info = f'''你有以下这些[知识]:{prompt_info}请你记住上面的[
|
||||||
|
知识],之后可能会用到-'''
|
||||||
|
|
||||||
end_time = time.time()
|
end_time = time.time()
|
||||||
print(f"\033[1;32m[知识检索]\033[0m 耗时: {(end_time - start_time):.3f}秒")
|
logger.debug(f"知识检索耗时: {(end_time - start_time):.3f}秒")
|
||||||
|
|
||||||
# 获取聊天上下文
|
# 获取聊天上下文
|
||||||
chat_talking_prompt = ''
|
chat_talking_prompt = ''
|
||||||
if group_id:
|
if group_id:
|
||||||
chat_talking_prompt = get_recent_group_detailed_plain_text(self.db, group_id, limit=global_config.MAX_CONTEXT_SIZE,combine = True)
|
chat_talking_prompt = get_recent_group_detailed_plain_text(self.db, group_id,
|
||||||
|
limit=global_config.MAX_CONTEXT_SIZE,
|
||||||
|
combine=True)
|
||||||
|
|
||||||
chat_talking_prompt = f"以下是群里正在聊天的内容:\n{chat_talking_prompt}"
|
chat_talking_prompt = f"以下是群里正在聊天的内容:\n{chat_talking_prompt}"
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# 使用新的记忆获取方法
|
# 使用新的记忆获取方法
|
||||||
memory_prompt = ''
|
memory_prompt = ''
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
@@ -101,14 +99,12 @@ class PromptBuilder:
|
|||||||
memory_prompt = "看到这些聊天,你想起来:\n" + "\n".join(memory_items) + "\n"
|
memory_prompt = "看到这些聊天,你想起来:\n" + "\n".join(memory_items) + "\n"
|
||||||
|
|
||||||
# 打印调试信息
|
# 打印调试信息
|
||||||
print("\n\033[1;32m[记忆检索]\033[0m 找到以下相关记忆:")
|
logger.debug("[记忆检索]找到以下相关记忆:")
|
||||||
for memory in relevant_memories:
|
for memory in relevant_memories:
|
||||||
print(f"- 主题「{memory['topic']}」[相似度: {memory['similarity']:.2f}]: {memory['content']}")
|
logger.debug(f"- 主题「{memory['topic']}」[相似度: {memory['similarity']:.2f}]: {memory['content']}")
|
||||||
|
|
||||||
end_time = time.time()
|
end_time = time.time()
|
||||||
print(f"\033[1;32m[回忆耗时]\033[0m 耗时: {(end_time - start_time):.3f}秒")
|
logger.info(f"回忆耗时: {(end_time - start_time):.3f}秒")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# 激活prompt构建
|
# 激活prompt构建
|
||||||
activate_prompt = ''
|
activate_prompt = ''
|
||||||
@@ -127,10 +123,9 @@ class PromptBuilder:
|
|||||||
for rule in global_config.keywords_reaction_rules:
|
for rule in global_config.keywords_reaction_rules:
|
||||||
if rule.get("enable", False):
|
if rule.get("enable", False):
|
||||||
if any(keyword in message_txt.lower() for keyword in rule.get("keywords", [])):
|
if any(keyword in message_txt.lower() for keyword in rule.get("keywords", [])):
|
||||||
print(f"检测到以下关键词之一:{rule.get('keywords', [])},触发反应:{rule.get('reaction', '')}")
|
logger.info(f"检测到以下关键词之一:{rule.get('keywords', [])},触发反应:{rule.get('reaction', '')}")
|
||||||
keywords_reaction_prompt += rule.get("reaction", "") + ','
|
keywords_reaction_prompt += rule.get("reaction", "") + ','
|
||||||
|
|
||||||
|
|
||||||
# 人格选择
|
# 人格选择
|
||||||
personality = global_config.PROMPT_PERSONALITY
|
personality = global_config.PROMPT_PERSONALITY
|
||||||
probability_1 = global_config.PERSONALITY_1
|
probability_1 = global_config.PERSONALITY_1
|
||||||
@@ -187,7 +182,7 @@ class PromptBuilder:
|
|||||||
|
|
||||||
return prompt, prompt_check_if_response
|
return prompt, prompt_check_if_response
|
||||||
|
|
||||||
def _build_initiative_prompt_select(self,group_id):
|
def _build_initiative_prompt_select(self, group_id, probability_1=0.8, probability_2=0.1):
|
||||||
current_date = time.strftime("%Y-%m-%d", time.localtime())
|
current_date = time.strftime("%Y-%m-%d", time.localtime())
|
||||||
current_time = time.strftime("%H:%M:%S", time.localtime())
|
current_time = time.strftime("%H:%M:%S", time.localtime())
|
||||||
bot_schedule_now_time, bot_schedule_now_activity = bot_schedule.get_current_task()
|
bot_schedule_now_time, bot_schedule_now_activity = bot_schedule.get_current_task()
|
||||||
@@ -195,7 +190,9 @@ class PromptBuilder:
|
|||||||
|
|
||||||
chat_talking_prompt = ''
|
chat_talking_prompt = ''
|
||||||
if group_id:
|
if group_id:
|
||||||
chat_talking_prompt = get_recent_group_detailed_plain_text(self.db, group_id, limit=global_config.MAX_CONTEXT_SIZE,combine = True)
|
chat_talking_prompt = get_recent_group_detailed_plain_text(self.db, group_id,
|
||||||
|
limit=global_config.MAX_CONTEXT_SIZE,
|
||||||
|
combine=True)
|
||||||
|
|
||||||
chat_talking_prompt = f"以下是群里正在聊天的内容:\n{chat_talking_prompt}"
|
chat_talking_prompt = f"以下是群里正在聊天的内容:\n{chat_talking_prompt}"
|
||||||
# print(f"\033[1;34m[调试]\033[0m 已从数据库获取群 {group_id} 的消息记录:{chat_talking_prompt}")
|
# print(f"\033[1;34m[调试]\033[0m 已从数据库获取群 {group_id} 的消息记录:{chat_talking_prompt}")
|
||||||
@@ -238,10 +235,9 @@ class PromptBuilder:
|
|||||||
prompt_for_initiative = f"{prompt_regular}你现在想在群里发言,回忆了一下,想到一个话题,是{selected_node['concept']},关于这个话题的记忆有\n{memory}\n,请在把握群里的聊天内容的基础上,综合群内的氛围,以日常且口语化的口吻,简短且随意一点进行发言,不要说的太有条理,可以有个性。记住不要输出多余内容(包括前后缀,冒号和引号,括号,表情等)"
|
prompt_for_initiative = f"{prompt_regular}你现在想在群里发言,回忆了一下,想到一个话题,是{selected_node['concept']},关于这个话题的记忆有\n{memory}\n,请在把握群里的聊天内容的基础上,综合群内的氛围,以日常且口语化的口吻,简短且随意一点进行发言,不要说的太有条理,可以有个性。记住不要输出多余内容(包括前后缀,冒号和引号,括号,表情等)"
|
||||||
return prompt_for_initiative
|
return prompt_for_initiative
|
||||||
|
|
||||||
|
|
||||||
async def get_prompt_info(self, message: str, threshold: float):
|
async def get_prompt_info(self, message: str, threshold: float):
|
||||||
related_info = ''
|
related_info = ''
|
||||||
print(f"\033[1;34m[调试]\033[0m 获取知识库内容,元消息:{message[:30]}...,消息长度: {len(message)}")
|
logger.debug(f"获取知识库内容,元消息:{message[:30]}...,消息长度: {len(message)}")
|
||||||
embedding = await get_embedding(message)
|
embedding = await get_embedding(message)
|
||||||
related_info += self.get_info_from_db(embedding, threshold=threshold)
|
related_info += self.get_info_from_db(embedding, threshold=threshold)
|
||||||
|
|
||||||
@@ -315,4 +311,5 @@ class PromptBuilder:
|
|||||||
# 返回所有找到的内容,用换行分隔
|
# 返回所有找到的内容,用换行分隔
|
||||||
return '\n'.join(str(result['content']) for result in results)
|
return '\n'.join(str(result['content']) for result in results)
|
||||||
|
|
||||||
|
|
||||||
prompt_builder = PromptBuilder()
|
prompt_builder = PromptBuilder()
|
||||||
@@ -1,4 +1,5 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
|
from loguru import logger
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from ...common.database import Database
|
from ...common.database import Database
|
||||||
@@ -11,6 +12,7 @@ class Impression:
|
|||||||
|
|
||||||
relationship_value: float = None
|
relationship_value: float = None
|
||||||
|
|
||||||
|
|
||||||
class Relationship:
|
class Relationship:
|
||||||
user_id: int = None
|
user_id: int = None
|
||||||
# impression: Impression = None
|
# impression: Impression = None
|
||||||
@@ -41,8 +43,6 @@ class Relationship:
|
|||||||
self.saved = kwargs.get('saved', False)
|
self.saved = kwargs.get('saved', False)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class RelationshipManager:
|
class RelationshipManager:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.relationships: dict[int, Relationship] = {}
|
self.relationships: dict[int, Relationship] = {}
|
||||||
@@ -62,7 +62,8 @@ class RelationshipManager:
|
|||||||
setattr(relationship, key, value)
|
setattr(relationship, key, value)
|
||||||
else:
|
else:
|
||||||
# 如果不存在,创建新对象
|
# 如果不存在,创建新对象
|
||||||
relationship = Relationship(user_id, data=data) if isinstance(data, dict) else Relationship(user_id, **kwargs)
|
relationship = Relationship(user_id, data=data) if isinstance(data, dict) else Relationship(user_id,
|
||||||
|
**kwargs)
|
||||||
self.relationships[user_id] = relationship
|
self.relationships[user_id] = relationship
|
||||||
|
|
||||||
# 更新 id_name_nickname_table
|
# 更新 id_name_nickname_table
|
||||||
@@ -85,10 +86,9 @@ class RelationshipManager:
|
|||||||
relationship.saved = True
|
relationship.saved = True
|
||||||
return relationship
|
return relationship
|
||||||
else:
|
else:
|
||||||
print(f"\033[1;31m[关系管理]\033[0m 用户 {user_id} 不存在,无法更新")
|
logger.warning(f"用户 {user_id} 不存在,无法更新")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def get_relationship(self, user_id: int) -> Optional[Relationship]:
|
def get_relationship(self, user_id: int) -> Optional[Relationship]:
|
||||||
"""获取用户关系对象"""
|
"""获取用户关系对象"""
|
||||||
if user_id in self.relationships:
|
if user_id in self.relationships:
|
||||||
@@ -120,10 +120,10 @@ class RelationshipManager:
|
|||||||
user_id = data['user_id']
|
user_id = data['user_id']
|
||||||
relationship = await self.load_relationship(data)
|
relationship = await self.load_relationship(data)
|
||||||
self.relationships[user_id] = relationship
|
self.relationships[user_id] = relationship
|
||||||
print(f"\033[1;32m[关系管理]\033[0m 已加载 {len(self.relationships)} 条关系记录")
|
logger.debug(f"已加载 {len(self.relationships)} 条关系记录")
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
print("\033[1;32m[关系管理]\033[0m 正在自动保存关系")
|
logger.debug("正在自动保存关系")
|
||||||
await asyncio.sleep(300) # 等待300秒(5分钟)
|
await asyncio.sleep(300) # 等待300秒(5分钟)
|
||||||
await self._save_all_relationships()
|
await self._save_all_relationships()
|
||||||
|
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ from typing import Optional
|
|||||||
|
|
||||||
from ...common.database import Database
|
from ...common.database import Database
|
||||||
from .message import Message
|
from .message import Message
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
|
|
||||||
class MessageStorage:
|
class MessageStorage:
|
||||||
@@ -43,7 +44,7 @@ class MessageStorage:
|
|||||||
}
|
}
|
||||||
|
|
||||||
self.db.db.messages.insert_one(message_data)
|
self.db.db.messages.insert_one(message_data)
|
||||||
except Exception as e:
|
except Exception:
|
||||||
print(f"\033[1;31m[错误]\033[0m 存储消息失败: {e}")
|
logger.exception(f"存储消息失败")
|
||||||
|
|
||||||
# 如果需要其他存储相关的函数,可以在这里添加
|
# 如果需要其他存储相关的函数,可以在这里添加
|
||||||
@@ -4,10 +4,12 @@ from nonebot import get_driver
|
|||||||
|
|
||||||
from ..models.utils_model import LLM_request
|
from ..models.utils_model import LLM_request
|
||||||
from .config import global_config
|
from .config import global_config
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
driver = get_driver()
|
driver = get_driver()
|
||||||
config = driver.config
|
config = driver.config
|
||||||
|
|
||||||
|
|
||||||
class TopicIdentifier:
|
class TopicIdentifier:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.llm_topic_judge = LLM_request(model=global_config.llm_topic_judge)
|
self.llm_topic_judge = LLM_request(model=global_config.llm_topic_judge)
|
||||||
@@ -25,7 +27,7 @@ class TopicIdentifier:
|
|||||||
topic, _ = await self.llm_topic_judge.generate_response(prompt)
|
topic, _ = await self.llm_topic_judge.generate_response(prompt)
|
||||||
|
|
||||||
if not topic:
|
if not topic:
|
||||||
print("\033[1;31m[错误]\033[0m LLM API 返回为空")
|
logger.error("LLM API 返回为空")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# 直接在这里处理主题解析
|
# 直接在这里处理主题解析
|
||||||
@@ -35,7 +37,8 @@ class TopicIdentifier:
|
|||||||
# 解析主题字符串为列表
|
# 解析主题字符串为列表
|
||||||
topic_list = [t.strip() for t in topic.split(",") if t.strip()]
|
topic_list = [t.strip() for t in topic.split(",") if t.strip()]
|
||||||
|
|
||||||
print(f"\033[1;32m[主题识别]\033[0m 主题: {topic_list}")
|
logger.info(f"主题: {topic_list}")
|
||||||
return topic_list if topic_list else None
|
return topic_list if topic_list else None
|
||||||
|
|
||||||
|
|
||||||
topic_identifier = TopicIdentifier()
|
topic_identifier = TopicIdentifier()
|
||||||
@@ -7,6 +7,7 @@ from typing import Dict, List
|
|||||||
import jieba
|
import jieba
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from nonebot import get_driver
|
from nonebot import get_driver
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
from ..models.utils_model import LLM_request
|
from ..models.utils_model import LLM_request
|
||||||
from ..utils.typo_generator import ChineseTypoGenerator
|
from ..utils.typo_generator import ChineseTypoGenerator
|
||||||
@@ -39,7 +40,7 @@ def combine_messages(messages: List[Message]) -> str:
|
|||||||
|
|
||||||
|
|
||||||
def db_message_to_str(message_dict: Dict) -> str:
|
def db_message_to_str(message_dict: Dict) -> str:
|
||||||
print(f"message_dict: {message_dict}")
|
logger.debug(f"message_dict: {message_dict}")
|
||||||
time_str = time.strftime("%m-%d %H:%M:%S", time.localtime(message_dict["time"]))
|
time_str = time.strftime("%m-%d %H:%M:%S", time.localtime(message_dict["time"]))
|
||||||
try:
|
try:
|
||||||
name = "[(%s)%s]%s" % (
|
name = "[(%s)%s]%s" % (
|
||||||
@@ -48,7 +49,7 @@ def db_message_to_str(message_dict: Dict) -> str:
|
|||||||
name = message_dict.get("user_nickname", "") or f"用户{message_dict['user_id']}"
|
name = message_dict.get("user_nickname", "") or f"用户{message_dict['user_id']}"
|
||||||
content = message_dict.get("processed_plain_text", "")
|
content = message_dict.get("processed_plain_text", "")
|
||||||
result = f"[{time_str}] {name}: {content}\n"
|
result = f"[{time_str}] {name}: {content}\n"
|
||||||
print(f"result: {result}")
|
logger.debug(f"result: {result}")
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
@@ -182,7 +183,7 @@ async def get_recent_group_messages(db, group_id: int, limit: int = 12) -> list:
|
|||||||
await msg.initialize()
|
await msg.initialize()
|
||||||
message_objects.append(msg)
|
message_objects.append(msg)
|
||||||
except KeyError:
|
except KeyError:
|
||||||
print("[WARNING] 数据库中存在无效的消息")
|
logger.warning("数据库中存在无效的消息")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# 按时间正序排列
|
# 按时间正序排列
|
||||||
@@ -298,11 +299,10 @@ def split_into_sentences_w_remove_punctuation(text: str) -> List[str]:
|
|||||||
sentence = sentence.replace(',', ' ').replace(',', ' ')
|
sentence = sentence.replace(',', ' ').replace(',', ' ')
|
||||||
sentences_done.append(sentence)
|
sentences_done.append(sentence)
|
||||||
|
|
||||||
print(f"处理后的句子: {sentences_done}")
|
logger.info(f"处理后的句子: {sentences_done}")
|
||||||
return sentences_done
|
return sentences_done
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def random_remove_punctuation(text: str) -> str:
|
def random_remove_punctuation(text: str) -> str:
|
||||||
"""随机处理标点符号,模拟人类打字习惯
|
"""随机处理标点符号,模拟人类打字习惯
|
||||||
|
|
||||||
@@ -330,11 +330,10 @@ def random_remove_punctuation(text: str) -> str:
|
|||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def process_llm_response(text: str) -> List[str]:
|
def process_llm_response(text: str) -> List[str]:
|
||||||
# processed_response = process_text_with_typos(content)
|
# processed_response = process_text_with_typos(content)
|
||||||
if len(text) > 200:
|
if len(text) > 200:
|
||||||
print(f"回复过长 ({len(text)} 字符),返回默认回复")
|
logger.warning(f"回复过长 ({len(text)} 字符),返回默认回复")
|
||||||
return ['懒得说']
|
return ['懒得说']
|
||||||
# 处理长消息
|
# 处理长消息
|
||||||
typo_generator = ChineseTypoGenerator(
|
typo_generator = ChineseTypoGenerator(
|
||||||
@@ -356,7 +355,7 @@ def process_llm_response(text: str) -> List[str]:
|
|||||||
# 检查分割后的消息数量是否过多(超过3条)
|
# 检查分割后的消息数量是否过多(超过3条)
|
||||||
|
|
||||||
if len(sentences) > 5:
|
if len(sentences) > 5:
|
||||||
print(f"分割后消息数量过多 ({len(sentences)} 条),返回默认回复")
|
logger.warning(f"分割后消息数量过多 ({len(sentences)} 条),返回默认回复")
|
||||||
return [f'{global_config.BOT_NICKNAME}不知道哦']
|
return [f'{global_config.BOT_NICKNAME}不知道哦']
|
||||||
|
|
||||||
return sentences
|
return sentences
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
from .config import global_config
|
from .config import global_config
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
|
|
||||||
class WillingManager:
|
class WillingManager:
|
||||||
@@ -30,16 +31,16 @@ class WillingManager:
|
|||||||
# print(f"初始意愿: {current_willing}")
|
# print(f"初始意愿: {current_willing}")
|
||||||
if is_mentioned_bot and current_willing < 1.0:
|
if is_mentioned_bot and current_willing < 1.0:
|
||||||
current_willing += 0.9
|
current_willing += 0.9
|
||||||
print(f"被提及, 当前意愿: {current_willing}")
|
logger.info(f"被提及, 当前意愿: {current_willing}")
|
||||||
elif is_mentioned_bot:
|
elif is_mentioned_bot:
|
||||||
current_willing += 0.05
|
current_willing += 0.05
|
||||||
print(f"被重复提及, 当前意愿: {current_willing}")
|
logger.info(f"被重复提及, 当前意愿: {current_willing}")
|
||||||
|
|
||||||
if is_emoji:
|
if is_emoji:
|
||||||
current_willing *= 0.1
|
current_willing *= 0.1
|
||||||
print(f"表情包, 当前意愿: {current_willing}")
|
logger.info(f"表情包, 当前意愿: {current_willing}")
|
||||||
|
|
||||||
print(f"放大系数_interested_rate: {global_config.response_interested_rate_amplifier}")
|
logger.debug(f"放大系数_interested_rate: {global_config.response_interested_rate_amplifier}")
|
||||||
interested_rate *= global_config.response_interested_rate_amplifier #放大回复兴趣度
|
interested_rate *= global_config.response_interested_rate_amplifier #放大回复兴趣度
|
||||||
if interested_rate > 0.4:
|
if interested_rate > 0.4:
|
||||||
# print(f"兴趣度: {interested_rate}, 当前意愿: {current_willing}")
|
# print(f"兴趣度: {interested_rate}, 当前意愿: {current_willing}")
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ import jieba
|
|||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
import networkx as nx
|
import networkx as nx
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
sys.path.append("C:/GitHub/MaiMBot") # 添加项目根目录到 Python 路径
|
sys.path.append("C:/GitHub/MaiMBot") # 添加项目根目录到 Python 路径
|
||||||
from src.common.database import Database # 使用正确的导入语法
|
from src.common.database import Database # 使用正确的导入语法
|
||||||
@@ -99,18 +100,20 @@ class Memory_graph:
|
|||||||
# 返回所有节点对应的 Memory_dot 对象
|
# 返回所有节点对应的 Memory_dot 对象
|
||||||
return [self.get_dot(node) for node in self.G.nodes()]
|
return [self.get_dot(node) for node in self.G.nodes()]
|
||||||
|
|
||||||
|
|
||||||
def get_random_chat_from_db(self, length: int, timestamp: str):
|
def get_random_chat_from_db(self, length: int, timestamp: str):
|
||||||
# 从数据库中根据时间戳获取离其最近的聊天记录
|
# 从数据库中根据时间戳获取离其最近的聊天记录
|
||||||
chat_text = ''
|
chat_text = ''
|
||||||
closest_record = self.db.db.messages.find_one({"time": {"$lte": timestamp}}, sort=[('time', -1)]) # 调试输出
|
closest_record = self.db.db.messages.find_one({"time": {"$lte": timestamp}}, sort=[('time', -1)]) # 调试输出
|
||||||
print(f"距离time最近的消息时间: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(int(closest_record['time'])))}")
|
logger.info(
|
||||||
|
f"距离time最近的消息时间: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(int(closest_record['time'])))}")
|
||||||
|
|
||||||
if closest_record:
|
if closest_record:
|
||||||
closest_time = closest_record['time']
|
closest_time = closest_record['time']
|
||||||
group_id = closest_record['group_id'] # 获取groupid
|
group_id = closest_record['group_id'] # 获取groupid
|
||||||
# 获取该时间戳之后的length条消息,且groupid相同
|
# 获取该时间戳之后的length条消息,且groupid相同
|
||||||
chat_record = list(self.db.db.messages.find({"time": {"$gt": closest_time}, "group_id": group_id}).sort('time', 1).limit(length))
|
chat_record = list(
|
||||||
|
self.db.db.messages.find({"time": {"$gt": closest_time}, "group_id": group_id}).sort('time', 1).limit(
|
||||||
|
length))
|
||||||
for record in chat_record:
|
for record in chat_record:
|
||||||
time_str = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(int(record['time'])))
|
time_str = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(int(record['time'])))
|
||||||
try:
|
try:
|
||||||
@@ -179,30 +182,31 @@ def main():
|
|||||||
break
|
break
|
||||||
first_layer_items, second_layer_items = memory_graph.get_related_item(query)
|
first_layer_items, second_layer_items = memory_graph.get_related_item(query)
|
||||||
if first_layer_items or second_layer_items:
|
if first_layer_items or second_layer_items:
|
||||||
print("\n第一层记忆:")
|
logger.debug("第一层记忆:")
|
||||||
for item in first_layer_items:
|
for item in first_layer_items:
|
||||||
print(item)
|
logger.debug(item)
|
||||||
print("\n第二层记忆:")
|
logger.debug("第二层记忆:")
|
||||||
for item in second_layer_items:
|
for item in second_layer_items:
|
||||||
print(item)
|
logger.debug(item)
|
||||||
else:
|
else:
|
||||||
print("未找到相关记忆。")
|
logger.debug("未找到相关记忆。")
|
||||||
|
|
||||||
|
|
||||||
def segment_text(text):
|
def segment_text(text):
|
||||||
seg_text = list(jieba.cut(text))
|
seg_text = list(jieba.cut(text))
|
||||||
return seg_text
|
return seg_text
|
||||||
|
|
||||||
|
|
||||||
def find_topic(text, topic_num):
|
def find_topic(text, topic_num):
|
||||||
prompt = f'这是一段文字:{text}。请你从这段话中总结出{topic_num}个话题,帮我列出来,用逗号隔开,尽可能精简。只需要列举{topic_num}个话题就好,不要告诉我其他内容。'
|
prompt = f'这是一段文字:{text}。请你从这段话中总结出{topic_num}个话题,帮我列出来,用逗号隔开,尽可能精简。只需要列举{topic_num}个话题就好,不要告诉我其他内容。'
|
||||||
return prompt
|
return prompt
|
||||||
|
|
||||||
|
|
||||||
def topic_what(text, topic):
|
def topic_what(text, topic):
|
||||||
prompt = f'这是一段文字:{text}。我想知道这记忆里有什么关于{topic}的话题,帮我总结成一句自然的话,可以包含时间和人物。只输出这句话就好'
|
prompt = f'这是一段文字:{text}。我想知道这记忆里有什么关于{topic}的话题,帮我总结成一句自然的话,可以包含时间和人物。只输出这句话就好'
|
||||||
return prompt
|
return prompt
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def visualize_graph_lite(memory_graph: Memory_graph, color_by_memory: bool = False):
|
def visualize_graph_lite(memory_graph: Memory_graph, color_by_memory: bool = False):
|
||||||
# 设置中文字体
|
# 设置中文字体
|
||||||
plt.rcParams['font.sans-serif'] = ['SimHei'] # 用来正常显示中文标签
|
plt.rcParams['font.sans-serif'] = ['SimHei'] # 用来正常显示中文标签
|
||||||
@@ -226,7 +230,7 @@ def visualize_graph_lite(memory_graph: Memory_graph, color_by_memory: bool = Fal
|
|||||||
|
|
||||||
# 如果过滤后没有节点,则返回
|
# 如果过滤后没有节点,则返回
|
||||||
if len(H.nodes()) == 0:
|
if len(H.nodes()) == 0:
|
||||||
print("过滤后没有符合条件的节点可显示")
|
logger.debug("过滤后没有符合条件的节点可显示")
|
||||||
return
|
return
|
||||||
|
|
||||||
# 保存图到本地
|
# 保存图到本地
|
||||||
@@ -287,6 +291,5 @@ def visualize_graph_lite(memory_graph: Memory_graph, color_by_memory: bool = Fal
|
|||||||
plt.show()
|
plt.show()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
@@ -7,6 +7,7 @@ import time
|
|||||||
import jieba
|
import jieba
|
||||||
import networkx as nx
|
import networkx as nx
|
||||||
|
|
||||||
|
from loguru import logger
|
||||||
from ...common.database import Database # 使用正确的导入语法
|
from ...common.database import Database # 使用正确的导入语法
|
||||||
from ..chat.config import global_config
|
from ..chat.config import global_config
|
||||||
from ..chat.utils import (
|
from ..chat.utils import (
|
||||||
@@ -223,17 +224,18 @@ class Hippocampus:
|
|||||||
for msg in messages:
|
for msg in messages:
|
||||||
input_text += f"{msg['text']}\n"
|
input_text += f"{msg['text']}\n"
|
||||||
|
|
||||||
print(input_text)
|
logger.debug(input_text)
|
||||||
|
|
||||||
topic_num = self.calculate_topic_num(input_text, compress_rate)
|
topic_num = self.calculate_topic_num(input_text, compress_rate)
|
||||||
topics_response = await self.llm_topic_judge.generate_response(self.find_topic_llm(input_text, topic_num))
|
topics_response = await self.llm_topic_judge.generate_response(self.find_topic_llm(input_text, topic_num))
|
||||||
|
|
||||||
# 过滤topics
|
# 过滤topics
|
||||||
filter_keywords = global_config.memory_ban_words
|
filter_keywords = global_config.memory_ban_words
|
||||||
topics = [topic.strip() for topic in topics_response[0].replace(",", ",").replace("、", ",").replace(" ", ",").split(",") if topic.strip()]
|
topics = [topic.strip() for topic in
|
||||||
|
topics_response[0].replace(",", ",").replace("、", ",").replace(" ", ",").split(",") if topic.strip()]
|
||||||
filtered_topics = [topic for topic in topics if not any(keyword in topic for keyword in filter_keywords)]
|
filtered_topics = [topic for topic in topics if not any(keyword in topic for keyword in filter_keywords)]
|
||||||
|
|
||||||
print(f"过滤后话题: {filtered_topics}")
|
logger.info(f"过滤后话题: {filtered_topics}")
|
||||||
|
|
||||||
# 创建所有话题的请求任务
|
# 创建所有话题的请求任务
|
||||||
tasks = []
|
tasks = []
|
||||||
@@ -257,7 +259,9 @@ class Hippocampus:
|
|||||||
topic_by_length = text.count('\n') * compress_rate
|
topic_by_length = text.count('\n') * compress_rate
|
||||||
topic_by_information_content = max(1, min(5, int((information_content - 3) * 2)))
|
topic_by_information_content = max(1, min(5, int((information_content - 3) * 2)))
|
||||||
topic_num = int((topic_by_length + topic_by_information_content) / 2)
|
topic_num = int((topic_by_length + topic_by_information_content) / 2)
|
||||||
print(f"topic_by_length: {topic_by_length}, topic_by_information_content: {topic_by_information_content}, topic_num: {topic_num}")
|
logger.debug(
|
||||||
|
f"topic_by_length: {topic_by_length}, topic_by_information_content: {topic_by_information_content}, "
|
||||||
|
f"topic_num: {topic_num}")
|
||||||
return topic_num
|
return topic_num
|
||||||
|
|
||||||
async def operation_build_memory(self, chat_size=20):
|
async def operation_build_memory(self, chat_size=20):
|
||||||
@@ -272,22 +276,22 @@ class Hippocampus:
|
|||||||
bar_length = 30
|
bar_length = 30
|
||||||
filled_length = int(bar_length * i // len(memory_sample))
|
filled_length = int(bar_length * i // len(memory_sample))
|
||||||
bar = '█' * filled_length + '-' * (bar_length - filled_length)
|
bar = '█' * filled_length + '-' * (bar_length - filled_length)
|
||||||
print(f"\n进度: [{bar}] {progress:.1f}% ({i}/{len(memory_sample)})")
|
logger.debug(f"进度: [{bar}] {progress:.1f}% ({i}/{len(memory_sample)})")
|
||||||
|
|
||||||
# 生成压缩后记忆 ,表现为 (话题,记忆) 的元组
|
# 生成压缩后记忆 ,表现为 (话题,记忆) 的元组
|
||||||
compressed_memory = set()
|
compressed_memory = set()
|
||||||
compress_rate = 0.1
|
compress_rate = 0.1
|
||||||
compressed_memory = await self.memory_compress(input_text, compress_rate)
|
compressed_memory = await self.memory_compress(input_text, compress_rate)
|
||||||
print(f"\033[1;33m压缩后记忆数量\033[0m: {len(compressed_memory)}")
|
logger.info(f"压缩后记忆数量: {len(compressed_memory)}")
|
||||||
|
|
||||||
# 将记忆加入到图谱中
|
# 将记忆加入到图谱中
|
||||||
for topic, memory in compressed_memory:
|
for topic, memory in compressed_memory:
|
||||||
print(f"\033[1;32m添加节点\033[0m: {topic}")
|
logger.info(f"添加节点: {topic}")
|
||||||
self.memory_graph.add_dot(topic, memory)
|
self.memory_graph.add_dot(topic, memory)
|
||||||
all_topics.append(topic) # 收集所有话题
|
all_topics.append(topic) # 收集所有话题
|
||||||
for i in range(len(all_topics)):
|
for i in range(len(all_topics)):
|
||||||
for j in range(i + 1, len(all_topics)):
|
for j in range(i + 1, len(all_topics)):
|
||||||
print(f"\033[1;32m连接节点\033[0m: {all_topics[i]} 和 {all_topics[j]}")
|
logger.info(f"连接节点: {all_topics[i]} 和 {all_topics[j]}")
|
||||||
self.memory_graph.connect_dot(all_topics[i], all_topics[j])
|
self.memory_graph.connect_dot(all_topics[i], all_topics[j])
|
||||||
|
|
||||||
self.sync_memory_to_db()
|
self.sync_memory_to_db()
|
||||||
@@ -448,14 +452,14 @@ class Hippocampus:
|
|||||||
removed_item = self.memory_graph.forget_topic(node)
|
removed_item = self.memory_graph.forget_topic(node)
|
||||||
if removed_item:
|
if removed_item:
|
||||||
forgotten_nodes.append((node, removed_item))
|
forgotten_nodes.append((node, removed_item))
|
||||||
print(f"遗忘节点 {node} 的记忆: {removed_item}")
|
logger.debug(f"遗忘节点 {node} 的记忆: {removed_item}")
|
||||||
|
|
||||||
# 同步到数据库
|
# 同步到数据库
|
||||||
if forgotten_nodes:
|
if forgotten_nodes:
|
||||||
self.sync_memory_to_db()
|
self.sync_memory_to_db()
|
||||||
print(f"完成遗忘操作,共遗忘 {len(forgotten_nodes)} 个节点的记忆")
|
logger.debug(f"完成遗忘操作,共遗忘 {len(forgotten_nodes)} 个节点的记忆")
|
||||||
else:
|
else:
|
||||||
print("本次检查没有节点满足遗忘条件")
|
logger.debug("本次检查没有节点满足遗忘条件")
|
||||||
|
|
||||||
async def merge_memory(self, topic):
|
async def merge_memory(self, topic):
|
||||||
"""
|
"""
|
||||||
@@ -478,8 +482,8 @@ class Hippocampus:
|
|||||||
|
|
||||||
# 拼接成文本
|
# 拼接成文本
|
||||||
merged_text = "\n".join(selected_memories)
|
merged_text = "\n".join(selected_memories)
|
||||||
print(f"\n[合并记忆] 话题: {topic}")
|
logger.debug(f"\n[合并记忆] 话题: {topic}")
|
||||||
print(f"选择的记忆:\n{merged_text}")
|
logger.debug(f"选择的记忆:\n{merged_text}")
|
||||||
|
|
||||||
# 使用memory_compress生成新的压缩记忆
|
# 使用memory_compress生成新的压缩记忆
|
||||||
compressed_memories = await self.memory_compress(selected_memories, 0.1)
|
compressed_memories = await self.memory_compress(selected_memories, 0.1)
|
||||||
@@ -491,11 +495,11 @@ class Hippocampus:
|
|||||||
# 添加新的压缩记忆
|
# 添加新的压缩记忆
|
||||||
for _, compressed_memory in compressed_memories:
|
for _, compressed_memory in compressed_memories:
|
||||||
memory_items.append(compressed_memory)
|
memory_items.append(compressed_memory)
|
||||||
print(f"添加压缩记忆: {compressed_memory}")
|
logger.info(f"添加压缩记忆: {compressed_memory}")
|
||||||
|
|
||||||
# 更新节点的记忆项
|
# 更新节点的记忆项
|
||||||
self.memory_graph.G.nodes[topic]['memory_items'] = memory_items
|
self.memory_graph.G.nodes[topic]['memory_items'] = memory_items
|
||||||
print(f"完成记忆合并,当前记忆数量: {len(memory_items)}")
|
logger.debug(f"完成记忆合并,当前记忆数量: {len(memory_items)}")
|
||||||
|
|
||||||
async def operation_merge_memory(self, percentage=0.1):
|
async def operation_merge_memory(self, percentage=0.1):
|
||||||
"""
|
"""
|
||||||
@@ -521,16 +525,16 @@ class Hippocampus:
|
|||||||
|
|
||||||
# 如果内容数量超过100,进行合并
|
# 如果内容数量超过100,进行合并
|
||||||
if content_count > 100:
|
if content_count > 100:
|
||||||
print(f"\n检查节点: {node}, 当前记忆数量: {content_count}")
|
logger.debug(f"检查节点: {node}, 当前记忆数量: {content_count}")
|
||||||
await self.merge_memory(node)
|
await self.merge_memory(node)
|
||||||
merged_nodes.append(node)
|
merged_nodes.append(node)
|
||||||
|
|
||||||
# 同步到数据库
|
# 同步到数据库
|
||||||
if merged_nodes:
|
if merged_nodes:
|
||||||
self.sync_memory_to_db()
|
self.sync_memory_to_db()
|
||||||
print(f"\n完成记忆合并操作,共处理 {len(merged_nodes)} 个节点")
|
logger.debug(f"完成记忆合并操作,共处理 {len(merged_nodes)} 个节点")
|
||||||
else:
|
else:
|
||||||
print("\n本次检查没有需要合并的节点")
|
logger.debug("本次检查没有需要合并的节点")
|
||||||
|
|
||||||
def find_topic_llm(self, text, topic_num):
|
def find_topic_llm(self, text, topic_num):
|
||||||
prompt = f'这是一段文字:{text}。请你从这段话中总结出{topic_num}个关键的概念,可以是名词,动词,或者特定人物,帮我列出来,用逗号,隔开,尽可能精简。只需要列举{topic_num}个话题就好,不要有序号,不要告诉我其他内容。'
|
prompt = f'这是一段文字:{text}。请你从这段话中总结出{topic_num}个关键的概念,可以是名词,动词,或者特定人物,帮我列出来,用逗号,隔开,尽可能精简。只需要列举{topic_num}个话题就好,不要有序号,不要告诉我其他内容。'
|
||||||
@@ -551,7 +555,8 @@ class Hippocampus:
|
|||||||
"""
|
"""
|
||||||
topics_response = await self.llm_topic_judge.generate_response(self.find_topic_llm(text, 5))
|
topics_response = await self.llm_topic_judge.generate_response(self.find_topic_llm(text, 5))
|
||||||
# print(f"话题: {topics_response[0]}")
|
# print(f"话题: {topics_response[0]}")
|
||||||
topics = [topic.strip() for topic in topics_response[0].replace(",", ",").replace("、", ",").replace(" ", ",").split(",") if topic.strip()]
|
topics = [topic.strip() for topic in
|
||||||
|
topics_response[0].replace(",", ",").replace("、", ",").replace(" ", ",").split(",") if topic.strip()]
|
||||||
# print(f"话题: {topics}")
|
# print(f"话题: {topics}")
|
||||||
|
|
||||||
return topics
|
return topics
|
||||||
@@ -624,7 +629,7 @@ class Hippocampus:
|
|||||||
|
|
||||||
async def memory_activate_value(self, text: str, max_topics: int = 5, similarity_threshold: float = 0.3) -> int:
|
async def memory_activate_value(self, text: str, max_topics: int = 5, similarity_threshold: float = 0.3) -> int:
|
||||||
"""计算输入文本对记忆的激活程度"""
|
"""计算输入文本对记忆的激活程度"""
|
||||||
print(f"\033[1;32m[记忆激活]\033[0m 识别主题: {await self._identify_topics(text)}")
|
logger.info(f"识别主题: {await self._identify_topics(text)}")
|
||||||
|
|
||||||
# 识别主题
|
# 识别主题
|
||||||
identified_topics = await self._identify_topics(text)
|
identified_topics = await self._identify_topics(text)
|
||||||
@@ -655,7 +660,8 @@ class Hippocampus:
|
|||||||
penalty = 1.0 / (1 + math.log(content_count + 1))
|
penalty = 1.0 / (1 + math.log(content_count + 1))
|
||||||
|
|
||||||
activation = int(score * 50 * penalty)
|
activation = int(score * 50 * penalty)
|
||||||
print(f"\033[1;32m[记忆激活]\033[0m 单主题「{topic}」- 相似度: {score:.3f}, 内容数: {content_count}, 激活值: {activation}")
|
logger.info(
|
||||||
|
f"[记忆激活]单主题「{topic}」- 相似度: {score:.3f}, 内容数: {content_count}, 激活值: {activation}")
|
||||||
return activation
|
return activation
|
||||||
|
|
||||||
# 计算关键词匹配率,同时考虑内容数量
|
# 计算关键词匹配率,同时考虑内容数量
|
||||||
@@ -682,7 +688,8 @@ class Hippocampus:
|
|||||||
matched_topics.add(input_topic)
|
matched_topics.add(input_topic)
|
||||||
adjusted_sim = sim * penalty
|
adjusted_sim = sim * penalty
|
||||||
topic_similarities[input_topic] = max(topic_similarities.get(input_topic, 0), adjusted_sim)
|
topic_similarities[input_topic] = max(topic_similarities.get(input_topic, 0), adjusted_sim)
|
||||||
print(f"\033[1;32m[记忆激活]\033[0m 主题「{input_topic}」-> 「{memory_topic}」(内容数: {content_count}, 相似度: {adjusted_sim:.3f})")
|
logger.info(
|
||||||
|
f"[记忆激活]主题「{input_topic}」-> 「{memory_topic}」(内容数: {content_count}, 相似度: {adjusted_sim:.3f})")
|
||||||
|
|
||||||
# 计算主题匹配率和平均相似度
|
# 计算主题匹配率和平均相似度
|
||||||
topic_match = len(matched_topics) / len(identified_topics)
|
topic_match = len(matched_topics) / len(identified_topics)
|
||||||
@@ -690,11 +697,13 @@ class Hippocampus:
|
|||||||
|
|
||||||
# 计算最终激活值
|
# 计算最终激活值
|
||||||
activation = int((topic_match + average_similarities) / 2 * 100)
|
activation = int((topic_match + average_similarities) / 2 * 100)
|
||||||
print(f"\033[1;32m[记忆激活]\033[0m 匹配率: {topic_match:.3f}, 平均相似度: {average_similarities:.3f}, 激活值: {activation}")
|
logger.info(
|
||||||
|
f"[记忆激活]匹配率: {topic_match:.3f}, 平均相似度: {average_similarities:.3f}, 激活值: {activation}")
|
||||||
|
|
||||||
return activation
|
return activation
|
||||||
|
|
||||||
async def get_relevant_memories(self, text: str, max_topics: int = 5, similarity_threshold: float = 0.4, max_memory_num: int = 5) -> list:
|
async def get_relevant_memories(self, text: str, max_topics: int = 5, similarity_threshold: float = 0.4,
|
||||||
|
max_memory_num: int = 5) -> list:
|
||||||
"""根据输入文本获取相关的记忆内容"""
|
"""根据输入文本获取相关的记忆内容"""
|
||||||
# 识别主题
|
# 识别主题
|
||||||
identified_topics = await self._identify_topics(text)
|
identified_topics = await self._identify_topics(text)
|
||||||
@@ -764,4 +773,4 @@ hippocampus = Hippocampus(memory_graph)
|
|||||||
hippocampus.sync_memory_from_db()
|
hippocampus.sync_memory_from_db()
|
||||||
|
|
||||||
end_time = time.time()
|
end_time = time.time()
|
||||||
print(f"\033[32m[加载海马体耗时: {end_time - start_time:.2f} 秒]\033[0m")
|
logger.success(f"加载海马体耗时: {end_time - start_time:.2f} 秒")
|
||||||
|
|||||||
@@ -743,7 +743,7 @@ class Hippocampus:
|
|||||||
|
|
||||||
async def memory_activate_value(self, text: str, max_topics: int = 5, similarity_threshold: float = 0.3) -> int:
|
async def memory_activate_value(self, text: str, max_topics: int = 5, similarity_threshold: float = 0.3) -> int:
|
||||||
"""计算输入文本对记忆的激活程度"""
|
"""计算输入文本对记忆的激活程度"""
|
||||||
print(f"\033[1;32m[记忆激活]\033[0m 识别主题: {await self._identify_topics(text)}")
|
logger.info(f"[记忆激活]识别主题: {await self._identify_topics(text)}")
|
||||||
|
|
||||||
identified_topics = await self._identify_topics(text)
|
identified_topics = await self._identify_topics(text)
|
||||||
if not identified_topics:
|
if not identified_topics:
|
||||||
|
|||||||
@@ -45,7 +45,7 @@ class LLM_request:
|
|||||||
self.db.db.llm_usage.create_index([("user_id", 1)])
|
self.db.db.llm_usage.create_index([("user_id", 1)])
|
||||||
self.db.db.llm_usage.create_index([("request_type", 1)])
|
self.db.db.llm_usage.create_index([("request_type", 1)])
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"创建数据库索引失败: {e}")
|
logger.error(f"创建数据库索引失败")
|
||||||
|
|
||||||
def _record_usage(self, prompt_tokens: int, completion_tokens: int, total_tokens: int,
|
def _record_usage(self, prompt_tokens: int, completion_tokens: int, total_tokens: int,
|
||||||
user_id: str = "system", request_type: str = "chat",
|
user_id: str = "system", request_type: str = "chat",
|
||||||
@@ -79,8 +79,8 @@ class LLM_request:
|
|||||||
f"提示词: {prompt_tokens}, 完成: {completion_tokens}, "
|
f"提示词: {prompt_tokens}, 完成: {completion_tokens}, "
|
||||||
f"总计: {total_tokens}"
|
f"总计: {total_tokens}"
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception:
|
||||||
logger.error(f"记录token使用情况失败: {e}")
|
logger.error(f"记录token使用情况失败")
|
||||||
|
|
||||||
def _calculate_cost(self, prompt_tokens: int, completion_tokens: int) -> float:
|
def _calculate_cost(self, prompt_tokens: int, completion_tokens: int) -> float:
|
||||||
"""计算API调用成本
|
"""计算API调用成本
|
||||||
@@ -143,9 +143,9 @@ class LLM_request:
|
|||||||
# 判断是否为流式
|
# 判断是否为流式
|
||||||
stream_mode = self.params.get("stream", False)
|
stream_mode = self.params.get("stream", False)
|
||||||
if self.params.get("stream", False) is True:
|
if self.params.get("stream", False) is True:
|
||||||
logger.info(f"进入流式输出模式,发送请求到URL: {api_url}")
|
logger.debug(f"进入流式输出模式,发送请求到URL: {api_url}")
|
||||||
else:
|
else:
|
||||||
logger.info(f"发送请求到URL: {api_url}")
|
logger.debug(f"发送请求到URL: {api_url}")
|
||||||
logger.info(f"使用模型: {self.model_name}")
|
logger.info(f"使用模型: {self.model_name}")
|
||||||
|
|
||||||
# 构建请求体
|
# 构建请求体
|
||||||
@@ -184,13 +184,15 @@ class LLM_request:
|
|||||||
logger.error(f"错误码: {response.status} - {error_code_mapping.get(response.status)}")
|
logger.error(f"错误码: {response.status} - {error_code_mapping.get(response.status)}")
|
||||||
if response.status == 403:
|
if response.status == 403:
|
||||||
# 尝试降级Pro模型
|
# 尝试降级Pro模型
|
||||||
if self.model_name.startswith("Pro/") and self.base_url == "https://api.siliconflow.cn/v1/":
|
if self.model_name.startswith(
|
||||||
|
"Pro/") and self.base_url == "https://api.siliconflow.cn/v1/":
|
||||||
old_model_name = self.model_name
|
old_model_name = self.model_name
|
||||||
self.model_name = self.model_name[4:] # 移除"Pro/"前缀
|
self.model_name = self.model_name[4:] # 移除"Pro/"前缀
|
||||||
logger.warning(f"检测到403错误,模型从 {old_model_name} 降级为 {self.model_name}")
|
logger.warning(f"检测到403错误,模型从 {old_model_name} 降级为 {self.model_name}")
|
||||||
|
|
||||||
# 对全局配置进行更新
|
# 对全局配置进行更新
|
||||||
if hasattr(global_config, 'llm_normal') and global_config.llm_normal.get('name') == old_model_name:
|
if hasattr(global_config, 'llm_normal') and global_config.llm_normal.get(
|
||||||
|
'name') == old_model_name:
|
||||||
global_config.llm_normal['name'] = self.model_name
|
global_config.llm_normal['name'] = self.model_name
|
||||||
logger.warning(f"已将全局配置中的 llm_normal 模型降级")
|
logger.warning(f"已将全局配置中的 llm_normal 模型降级")
|
||||||
|
|
||||||
@@ -224,8 +226,8 @@ class LLM_request:
|
|||||||
if delta_content is None:
|
if delta_content is None:
|
||||||
delta_content = ""
|
delta_content = ""
|
||||||
accumulated_content += delta_content
|
accumulated_content += delta_content
|
||||||
except Exception as e:
|
except Exception:
|
||||||
logger.error(f"解析流式输出错误: {e}")
|
logger.exception(f"解析流式输出错")
|
||||||
content = accumulated_content
|
content = accumulated_content
|
||||||
reasoning_content = ""
|
reasoning_content = ""
|
||||||
think_match = re.search(r'<think>(.*?)</think>', content, re.DOTALL)
|
think_match = re.search(r'<think>(.*?)</think>', content, re.DOTALL)
|
||||||
@@ -233,12 +235,15 @@ class LLM_request:
|
|||||||
reasoning_content = think_match.group(1).strip()
|
reasoning_content = think_match.group(1).strip()
|
||||||
content = re.sub(r'<think>.*?</think>', '', content, flags=re.DOTALL).strip()
|
content = re.sub(r'<think>.*?</think>', '', content, flags=re.DOTALL).strip()
|
||||||
# 构造一个伪result以便调用自定义响应处理器或默认处理器
|
# 构造一个伪result以便调用自定义响应处理器或默认处理器
|
||||||
result = {"choices": [{"message": {"content": content, "reasoning_content": reasoning_content}}]}
|
result = {
|
||||||
return response_handler(result) if response_handler else self._default_response_handler(result, user_id, request_type, endpoint)
|
"choices": [{"message": {"content": content, "reasoning_content": reasoning_content}}]}
|
||||||
|
return response_handler(result) if response_handler else self._default_response_handler(
|
||||||
|
result, user_id, request_type, endpoint)
|
||||||
else:
|
else:
|
||||||
result = await response.json()
|
result = await response.json()
|
||||||
# 使用自定义处理器或默认处理
|
# 使用自定义处理器或默认处理
|
||||||
return response_handler(result) if response_handler else self._default_response_handler(result, user_id, request_type, endpoint)
|
return response_handler(result) if response_handler else self._default_response_handler(
|
||||||
|
result, user_id, request_type, endpoint)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
if retry < policy["max_retries"] - 1:
|
if retry < policy["max_retries"] - 1:
|
||||||
@@ -262,7 +267,8 @@ class LLM_request:
|
|||||||
# 复制一份参数,避免直接修改原始数据
|
# 复制一份参数,避免直接修改原始数据
|
||||||
new_params = dict(params)
|
new_params = dict(params)
|
||||||
# 定义需要转换的模型列表
|
# 定义需要转换的模型列表
|
||||||
models_needing_transformation = ["o3-mini", "o1-mini", "o1-preview", "o1-2024-12-17", "o1-preview-2024-09-12", "o3-mini-2025-01-31", "o1-mini-2024-09-12"]
|
models_needing_transformation = ["o3-mini", "o1-mini", "o1-preview", "o1-2024-12-17", "o1-preview-2024-09-12",
|
||||||
|
"o3-mini-2025-01-31", "o1-mini-2024-09-12"]
|
||||||
if self.model_name.lower() in models_needing_transformation:
|
if self.model_name.lower() in models_needing_transformation:
|
||||||
# 删除 'temprature' 参数(如果存在)
|
# 删除 'temprature' 参数(如果存在)
|
||||||
new_params.pop("temperature", None)
|
new_params.pop("temperature", None)
|
||||||
@@ -298,11 +304,11 @@ class LLM_request:
|
|||||||
**params_copy
|
**params_copy
|
||||||
}
|
}
|
||||||
# 如果 payload 中依然存在 max_tokens 且需要转换,在这里进行再次检查
|
# 如果 payload 中依然存在 max_tokens 且需要转换,在这里进行再次检查
|
||||||
if self.model_name.lower() in ["o3-mini", "o1-mini", "o1-preview", "o1-2024-12-17", "o1-preview-2024-09-12", "o3-mini-2025-01-31", "o1-mini-2024-09-12"] and "max_tokens" in payload:
|
if self.model_name.lower() in ["o3-mini", "o1-mini", "o1-preview", "o1-2024-12-17", "o1-preview-2024-09-12",
|
||||||
|
"o3-mini-2025-01-31", "o1-mini-2024-09-12"] and "max_tokens" in payload:
|
||||||
payload["max_completion_tokens"] = payload.pop("max_tokens")
|
payload["max_completion_tokens"] = payload.pop("max_tokens")
|
||||||
return payload
|
return payload
|
||||||
|
|
||||||
|
|
||||||
def _default_response_handler(self, result: dict, user_id: str = "system",
|
def _default_response_handler(self, result: dict, user_id: str = "system",
|
||||||
request_type: str = "chat", endpoint: str = "/chat/completions") -> Tuple:
|
request_type: str = "chat", endpoint: str = "/chat/completions") -> Tuple:
|
||||||
"""默认响应解析"""
|
"""默认响应解析"""
|
||||||
@@ -404,6 +410,7 @@ class LLM_request:
|
|||||||
Returns:
|
Returns:
|
||||||
list: embedding向量,如果失败则返回None
|
list: embedding向量,如果失败则返回None
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def embedding_handler(result):
|
def embedding_handler(result):
|
||||||
"""处理响应"""
|
"""处理响应"""
|
||||||
if "data" in result and len(result["data"]) > 0:
|
if "data" in result and len(result["data"]) > 0:
|
||||||
@@ -425,4 +432,3 @@ class LLM_request:
|
|||||||
response_handler=embedding_handler
|
response_handler=embedding_handler
|
||||||
)
|
)
|
||||||
return embedding
|
return embedding
|
||||||
|
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ import time
|
|||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
|
||||||
from ..chat.config import global_config
|
from ..chat.config import global_config
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class MoodState:
|
class MoodState:
|
||||||
@@ -210,7 +210,7 @@ class MoodManager:
|
|||||||
|
|
||||||
def print_mood_status(self) -> None:
|
def print_mood_status(self) -> None:
|
||||||
"""打印当前情绪状态"""
|
"""打印当前情绪状态"""
|
||||||
print(f"\033[1;35m[情绪状态]\033[0m 愉悦度: {self.current_mood.valence:.2f}, "
|
logger.info(f"[情绪状态]愉悦度: {self.current_mood.valence:.2f}, "
|
||||||
f"唤醒度: {self.current_mood.arousal:.2f}, "
|
f"唤醒度: {self.current_mood.arousal:.2f}, "
|
||||||
f"心情: {self.current_mood.text}")
|
f"心情: {self.current_mood.text}")
|
||||||
|
|
||||||
|
|||||||
@@ -57,12 +57,12 @@ class ScheduleGenerator:
|
|||||||
|
|
||||||
existing_schedule = self.db.db.schedule.find_one({"date": date_str})
|
existing_schedule = self.db.db.schedule.find_one({"date": date_str})
|
||||||
if existing_schedule:
|
if existing_schedule:
|
||||||
logger.info(f"{date_str}的日程已存在:")
|
logger.debug(f"{date_str}的日程已存在:")
|
||||||
schedule_text = existing_schedule["schedule"]
|
schedule_text = existing_schedule["schedule"]
|
||||||
# print(self.schedule_text)
|
# print(self.schedule_text)
|
||||||
|
|
||||||
elif read_only == False:
|
elif not read_only:
|
||||||
logger.info(f"{date_str}的日程不存在,准备生成新的日程。")
|
logger.debug(f"{date_str}的日程不存在,准备生成新的日程。")
|
||||||
prompt = f"""我是{global_config.BOT_NICKNAME},{global_config.PROMPT_SCHEDULE_GEN},请为我生成{date_str}({weekday})的日程安排,包括:""" + \
|
prompt = f"""我是{global_config.BOT_NICKNAME},{global_config.PROMPT_SCHEDULE_GEN},请为我生成{date_str}({weekday})的日程安排,包括:""" + \
|
||||||
"""
|
"""
|
||||||
1. 早上的学习和工作安排
|
1. 早上的学习和工作安排
|
||||||
@@ -78,7 +78,7 @@ class ScheduleGenerator:
|
|||||||
schedule_text = "生成日程时出错了"
|
schedule_text = "生成日程时出错了"
|
||||||
# print(self.schedule_text)
|
# print(self.schedule_text)
|
||||||
else:
|
else:
|
||||||
logger.info(f"{date_str}的日程不存在。")
|
logger.debug(f"{date_str}的日程不存在。")
|
||||||
schedule_text = "忘了"
|
schedule_text = "忘了"
|
||||||
|
|
||||||
return schedule_text, None
|
return schedule_text, None
|
||||||
@@ -154,10 +154,10 @@ class ScheduleGenerator:
|
|||||||
logger.warning("今日日程有误,将在下次运行时重新生成")
|
logger.warning("今日日程有误,将在下次运行时重新生成")
|
||||||
self.db.db.schedule.delete_one({"date": datetime.datetime.now().strftime("%Y-%m-%d")})
|
self.db.db.schedule.delete_one({"date": datetime.datetime.now().strftime("%Y-%m-%d")})
|
||||||
else:
|
else:
|
||||||
logger.info("\n=== 今日日程安排 ===")
|
logger.info("=== 今日日程安排 ===")
|
||||||
for time_str, activity in self.today_schedule.items():
|
for time_str, activity in self.today_schedule.items():
|
||||||
logger.info(f"时间[{time_str}]: 活动[{activity}]")
|
logger.info(f"时间[{time_str}]: 活动[{activity}]")
|
||||||
logger.info("==================\n")
|
logger.info("==================")
|
||||||
|
|
||||||
|
|
||||||
# def main():
|
# def main():
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ import time
|
|||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
from typing import Any, Dict
|
from typing import Any, Dict
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
from ...common.database import Database
|
from ...common.database import Database
|
||||||
|
|
||||||
@@ -153,8 +154,8 @@ class LLMStatistics:
|
|||||||
try:
|
try:
|
||||||
all_stats = self._collect_all_statistics()
|
all_stats = self._collect_all_statistics()
|
||||||
self._save_statistics(all_stats)
|
self._save_statistics(all_stats)
|
||||||
except Exception as e:
|
except Exception:
|
||||||
print(f"\033[1;31m[错误]\033[0m 统计数据处理失败: {e}")
|
logger.exception(f"统计数据处理失败")
|
||||||
|
|
||||||
# 等待1分钟
|
# 等待1分钟
|
||||||
for _ in range(60):
|
for _ in range(60):
|
||||||
|
|||||||
@@ -11,6 +11,8 @@ from pathlib import Path
|
|||||||
import random
|
import random
|
||||||
import math
|
import math
|
||||||
import time
|
import time
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
|
|
||||||
class ChineseTypoGenerator:
|
class ChineseTypoGenerator:
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
@@ -36,7 +38,7 @@ class ChineseTypoGenerator:
|
|||||||
self.max_freq_diff = max_freq_diff
|
self.max_freq_diff = max_freq_diff
|
||||||
|
|
||||||
# 加载数据
|
# 加载数据
|
||||||
print("正在加载汉字数据库,请稍候...")
|
logger.debug("正在加载汉字数据库,请稍候...")
|
||||||
self.pinyin_dict = self._create_pinyin_dict()
|
self.pinyin_dict = self._create_pinyin_dict()
|
||||||
self.char_frequency = self._load_or_create_char_frequency()
|
self.char_frequency = self._load_or_create_char_frequency()
|
||||||
|
|
||||||
@@ -399,9 +401,10 @@ class ChineseTypoGenerator:
|
|||||||
for key, value in kwargs.items():
|
for key, value in kwargs.items():
|
||||||
if hasattr(self, key):
|
if hasattr(self, key):
|
||||||
setattr(self, key, value)
|
setattr(self, key, value)
|
||||||
print(f"参数 {key} 已设置为 {value}")
|
logger.debug(f"参数 {key} 已设置为 {value}")
|
||||||
else:
|
else:
|
||||||
print(f"警告: 参数 {key} 不存在")
|
logger.warning(f"警告: 参数 {key} 不存在")
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
# 创建错别字生成器实例
|
# 创建错别字生成器实例
|
||||||
@@ -420,18 +423,18 @@ def main():
|
|||||||
typo_sentence, typo_info = typo_generator.create_typo_sentence(sentence)
|
typo_sentence, typo_info = typo_generator.create_typo_sentence(sentence)
|
||||||
|
|
||||||
# 打印结果
|
# 打印结果
|
||||||
print("\n原句:", sentence)
|
logger.debug("原句:", sentence)
|
||||||
print("错字版:", typo_sentence)
|
logger.debug("错字版:", typo_sentence)
|
||||||
|
|
||||||
# 打印错别字信息
|
# 打印错别字信息
|
||||||
if typo_info:
|
if typo_info:
|
||||||
print("\n错别字信息:")
|
logger.debug(f"错别字信息:{typo_generator.format_typo_info(typo_info)})")
|
||||||
print(typo_generator.format_typo_info(typo_info))
|
|
||||||
|
|
||||||
# 计算并打印总耗时
|
# 计算并打印总耗时
|
||||||
end_time = time.time()
|
end_time = time.time()
|
||||||
total_time = end_time - start_time
|
total_time = end_time - start_time
|
||||||
print(f"\n总耗时:{total_time:.2f}秒")
|
logger.debug(f"总耗时:{total_time:.2f}秒")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
|
|||||||
Reference in New Issue
Block a user