合并 (Merge)
@@ -2,7 +2,6 @@ import asyncio
import time
import os

from loguru import logger
from nonebot import get_driver, on_message, on_notice, require
from nonebot.rule import to_me
from nonebot.adapters.onebot.v11 import Bot, GroupMessageEvent, Message, MessageSegment, MessageEvent, NoticeEvent
@@ -21,6 +20,9 @@ from ..memory_system.memory import hippocampus, memory_graph
from .bot import ChatBot
from .message_sender import message_manager, message_sender
from .storage import MessageStorage
from src.common.logger import get_module_logger

logger = get_module_logger("chat_init")

# 创建LLM统计实例
llm_stats = LLMStatistics("llm_statistics.txt")
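The change repeated across the files below swaps the bare `from loguru import logger` import for a named, per-module logger from `src.common.logger`. The factory itself is not part of this diff; a minimal sketch of what such a `get_module_logger` helper could look like (the `bind`-based implementation here is an assumption, not the project's actual code):

```python
# Hypothetical sketch of src/common/logger.py — the real implementation is not shown in this commit.
from loguru import logger as _base_logger


def get_module_logger(module_name: str):
    """Return a logger with the module name bound to every record, so logs can be filtered per module."""
    return _base_logger.bind(module=module_name)


# Each module then creates exactly one logger at import time, mirroring the lines added in this diff:
logger = get_module_logger("chat_init")
logger.info("chat module initialized")
```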
@@ -12,6 +12,7 @@ from nonebot.adapters.onebot.v11 import (
    FriendRecallNoticeEvent,
)

from src.common.logger import get_module_logger
from ..memory_system.memory import hippocampus
from ..moods.moods import MoodManager # 导入情绪管理器
from .config import global_config
@@ -31,11 +32,8 @@ from .utils_image import image_path_to_base64
from .utils_user import get_user_nickname, get_user_cardname, get_groupname
from ..willing.willing_manager import willing_manager # 导入意愿管理器
from .message_base import UserInfo, GroupInfo, Seg
from ..utils.logger_config import LogClassification, LogModule

# 配置日志
log_module = LogModule()
logger = log_module.setup_logger(LogClassification.CHAT)
logger = get_module_logger("chat_bot")


class ChatBot:
@@ -336,9 +334,12 @@ class ChatBot:
        platform="qq",
    )

    group_info = GroupInfo(
        group_id=event.group_id, group_name=None, platform="qq"
    )
    if isinstance(event, GroupRecallNoticeEvent):
        group_info = GroupInfo(
            group_id=event.group_id, group_name=None, platform="qq"
        )
    else:
        group_info = None

    chat = await chat_manager.get_or_create_stream(
        platform=user_info.platform, user_info=user_info, group_info=group_info

@@ -4,11 +4,14 @@ import time
import copy
from typing import Dict, Optional

from loguru import logger

from ...common.database import db
from .message_base import GroupInfo, UserInfo

from src.common.logger import get_module_logger

logger = get_module_logger("chat_stream")


class ChatStream:
    """聊天流对象,存储一个完整的聊天上下文"""

@@ -4,11 +4,14 @@ from dataclasses import dataclass, field
from typing import Dict, List, Optional

import tomli
from loguru import logger
from packaging import version
from packaging.version import Version, InvalidVersion
from packaging.specifiers import SpecifierSet, InvalidSpecifier

from src.common.logger import get_module_logger

logger = get_module_logger("config")


@dataclass
class BotConfig:
@@ -440,10 +443,3 @@ else:

global_config = BotConfig.load_config(config_path=bot_config_path)

if not global_config.enable_advance_output:
    logger.remove()

# 调试输出功能
if global_config.enable_debug_output:
    logger.remove()
    logger.add(sys.stdout, level="DEBUG")

@@ -7,7 +7,7 @@ from typing import Dict, List, Optional, Union
import ssl
import os
import aiohttp
from loguru import logger
from src.common.logger import get_module_logger
from nonebot import get_driver

from ..models.utils_model import LLM_request
@@ -24,6 +24,7 @@ config = driver.config
ssl_context = ssl.create_default_context()
ssl_context.set_ciphers("AES128-GCM-SHA256")

logger = get_module_logger("cq_code")

@dataclass
class CQCode:
@@ -249,6 +250,13 @@ class CQCode:
        if self.reply_message is None:
            return None

        if hasattr(self.reply_message, "group_id"):
            group_info = GroupInfo(
                platform="qq", group_id=self.reply_message.group_id, group_name=""
            )
        else:
            group_info = None

        if self.reply_message.sender.user_id:
            message_obj = MessageRecvCQ(
                user_info=UserInfo(
@@ -256,7 +264,7 @@ class CQCode:
                ),
                message_id=self.reply_message.message_id,
                raw_message=str(self.reply_message.message),
                group_info=GroupInfo(group_id=self.reply_message.group_id),
                group_info=group_info,
            )
            await message_obj.initialize()


@@ -9,7 +9,6 @@ from typing import Optional, Tuple
from PIL import Image
import io

from loguru import logger
from nonebot import get_driver

from ...common.database import db
@@ -17,12 +16,10 @@ from ..chat.config import global_config
from ..chat.utils import get_embedding
from ..chat.utils_image import ImageManager, image_path_to_base64
from ..models.utils_model import LLM_request
from src.common.logger import get_module_logger

from ..utils.logger_config import LogClassification, LogModule
logger = get_module_logger("emoji")

# 配置日志
log_module = LogModule()
logger = log_module.setup_logger(LogClassification.EMOJI)

driver = get_driver()
config = driver.config

@@ -3,7 +3,6 @@ import time
from typing import List, Optional, Tuple, Union

from nonebot import get_driver
from loguru import logger

from ...common.database import db
from ..models.utils_model import LLM_request
@@ -12,6 +11,9 @@ from .message import MessageRecv, MessageThinking, Message
from .prompt_builder import prompt_builder
from .relationship_manager import relationship_manager
from .utils import process_llm_response
from src.common.logger import get_module_logger

logger = get_module_logger("response_gen")

driver = get_driver()
config = driver.config

@@ -6,12 +6,14 @@ from dataclasses import dataclass
from typing import Dict, List, Optional

import urllib3
from loguru import logger

from .utils_image import image_manager

from .message_base import Seg, GroupInfo, UserInfo, BaseMessageInfo, MessageBase
from .chat_stream import ChatStream, chat_manager
from src.common.logger import get_module_logger

logger = get_module_logger("chat_message")

# 禁用SSL警告
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)

@@ -2,7 +2,7 @@ import asyncio
import time
from typing import Dict, List, Optional, Union

from loguru import logger
from src.common.logger import get_module_logger
from nonebot.adapters.onebot.v11 import Bot
from ...common.database import db
from .message_cq import MessageSendCQ
@@ -12,6 +12,7 @@ from .storage import MessageStorage
from .config import global_config
from .utils import truncate_message

logger = get_module_logger("msg_sender")

class Message_Sender:
    """发送器"""
@@ -50,7 +51,6 @@ class Message_Sender:
        if not is_recalled:
            message_json = message.to_dict()
            message_send = MessageSendCQ(data=message_json)
            # logger.debug(message_send.message_info,message_send.raw_message)
            message_preview = truncate_message(message.processed_plain_text)
            if message_send.message_info.group_info and message_send.message_info.group_info.group_id:
                try:
@@ -188,16 +188,17 @@ class MessageManager:
            else:
                if (
                    message_earliest.is_head
                    and message_earliest.update_thinking_time() > 30
                    and message_earliest.update_thinking_time() > 10
                    and not message_earliest.is_private_message() # 避免在私聊时插入reply
                ):
                    message_earliest.set_reply()
                await message_sender.send_message(message_earliest)

                await message_earliest.process()

                await message_sender.send_message(message_earliest)

                print(
                    f"\033[1;34m[调试]\033[0m 消息“{truncate_message(message_earliest.processed_plain_text)}”正在发送中"
                )

                await self.storage.store_message(message_earliest, message_earliest.chat_stream, None)

@@ -217,11 +218,11 @@ class MessageManager:
                    and not message_earliest.is_private_message() # 避免在私聊时插入reply
                ):
                    msg.set_reply()

                await msg.process()

                await message_sender.send_message(msg)

                # if msg.is_emoji:
                # msg.processed_plain_text = "[表情包]"
                await msg.process()

                await self.storage.store_message(msg, msg.chat_stream, None)

                if not container.remove_message(msg):
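For context, the reply-insertion guard in `MessageManager` now fires after 10 seconds of "thinking" instead of 30, and is skipped in private chats. A standalone sketch of that decision, with `update_thinking_time` and `is_private_message` stubbed out since their real implementations are not shown in this diff:

```python
import time


class ThinkingMessage:
    """Minimal stand-in for the real thinking-message object; only the methods the guard needs."""

    def __init__(self, start_time: float, is_head: bool, group_id=None):
        self.start_time = start_time
        self.is_head = is_head
        self.group_id = group_id
        self.replied = False

    def update_thinking_time(self) -> float:
        # Seconds elapsed since the bot started "thinking" about this message.
        return time.time() - self.start_time

    def is_private_message(self) -> bool:
        # Private chats have no group id, so no reply segment should be inserted there.
        return self.group_id is None

    def set_reply(self) -> None:
        self.replied = True


msg = ThinkingMessage(start_time=time.time() - 12, is_head=True, group_id=12345)
# Same shape as the condition in the diff, with the new 10-second threshold.
if msg.is_head and msg.update_thinking_time() > 10 and not msg.is_private_message():
    msg.set_reply()
```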
@@ -9,12 +9,13 @@ from ..schedule.schedule_generator import bot_schedule
from .config import global_config
from .utils import get_embedding, get_recent_group_detailed_plain_text, get_recent_group_speaker
from .chat_stream import chat_manager
<<<<<<< HEAD
from .relationship_manager import relationship_manager
=======
from src.common.logger import get_module_logger
>>>>>>> main-fix

from ..utils.logger_config import LogClassification, LogModule

log_module = LogModule()
logger = log_module.setup_logger(LogClassification.PBUILDER)
logger = get_module_logger("prompt")

logger.info("初始化Prompt系统")


@@ -1,12 +1,14 @@
import asyncio
from typing import Optional
from loguru import logger
from src.common.logger import get_module_logger

from ...common.database import db
from .message_base import UserInfo
from .chat_stream import ChatStream
import math

logger = get_module_logger("rel_manager")

class Impression:
    traits: str = None
    called: str = None

@@ -3,7 +3,9 @@ from typing import Optional, Union
from ...common.database import db
from .message import MessageSending, MessageRecv
from .chat_stream import ChatStream
from loguru import logger
from src.common.logger import get_module_logger

logger = get_module_logger("message_storage")


class MessageStorage:

@@ -4,7 +4,9 @@ from nonebot import get_driver

from ..models.utils_model import LLM_request
from .config import global_config
from loguru import logger
from src.common.logger import get_module_logger

logger = get_module_logger("topic_identifier")

driver = get_driver()
config = driver.config

@@ -7,7 +7,7 @@ from typing import Dict, List
import jieba
import numpy as np
from nonebot import get_driver
from loguru import logger
from src.common.logger import get_module_logger

from ..models.utils_model import LLM_request
from ..utils.typo_generator import ChineseTypoGenerator
@@ -21,6 +21,8 @@ from ...common.database import db
driver = get_driver()
config = driver.config

logger = get_module_logger("chat_utils")


def db_message_to_str(message_dict: Dict) -> str:

@@ -7,13 +7,16 @@ from typing import Optional, Union
from PIL import Image
import io

from loguru import logger
from nonebot import get_driver

from ...common.database import db
from ..chat.config import global_config
from ..models.utils_model import LLM_request

from src.common.logger import get_module_logger

logger = get_module_logger("chat_image")

driver = get_driver()
config = driver.config


@@ -1,10 +1,11 @@
from nonebot import get_app
from .api import router
from loguru import logger
from src.common.logger import get_module_logger

# 获取主应用实例并挂载路由
app = get_app()
app.include_router(router, prefix="/api")

# 打印日志,方便确认API已注册
logger = get_module_logger("cfg_reload")
logger.success("配置重载API已注册,可通过 /api/reload-config 访问")
@@ -7,7 +7,9 @@ import jieba
import matplotlib.pyplot as plt
import networkx as nx
from dotenv import load_dotenv
from loguru import logger
from src.common.logger import get_module_logger

logger = get_module_logger("draw_memory")

# 添加项目根目录到 Python 路径
root_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../.."))

@@ -3,7 +3,6 @@ import datetime
import math
import random
import time
import os

import jieba
import networkx as nx
@@ -18,14 +17,10 @@ from ..chat.utils import (
    text_to_vector,
)
from ..models.utils_model import LLM_request
from src.common.logger import get_module_logger

from ..utils.logger_config import LogClassification, LogModule
logger = get_module_logger("memory_sys")

# 配置日志
log_module = LogModule()
logger = log_module.setup_logger(LogClassification.MEMORY)

logger.info("初始化记忆系统")

class Memory_graph:
    def __init__(self):
@@ -35,9 +30,9 @@ class Memory_graph:
        # 避免自连接
        if concept1 == concept2:
            return

        current_time = datetime.datetime.now().timestamp()

        # 如果边已存在,增加 strength
        if self.G.has_edge(concept1, concept2):
            self.G[concept1][concept2]['strength'] = self.G[concept1][concept2].get('strength', 1) + 1
@@ -45,14 +40,14 @@
            self.G[concept1][concept2]['last_modified'] = current_time
        else:
            # 如果是新边,初始化 strength 为 1
            self.G.add_edge(concept1, concept2,
                            strength=1,
                            created_time=current_time, # 添加创建时间
                            last_modified=current_time) # 添加最后修改时间
            self.G.add_edge(concept1, concept2,
                            strength=1,
                            created_time=current_time, # 添加创建时间
                            last_modified=current_time) # 添加最后修改时间

    def add_dot(self, concept, memory):
        current_time = datetime.datetime.now().timestamp()

        if concept in self.G:
            if 'memory_items' in self.G.nodes[concept]:
                if not isinstance(self.G.nodes[concept]['memory_items'], list):
@@ -68,10 +63,10 @@ class Memory_graph:
            self.G.nodes[concept]['last_modified'] = current_time
        else:
            # 如果是新节点,创建新的记忆列表
            self.G.add_node(concept,
                            memory_items=[memory],
                            created_time=current_time, # 添加创建时间
                            last_modified=current_time) # 添加最后修改时间
            self.G.add_node(concept,
                            memory_items=[memory],
                            created_time=current_time, # 添加创建时间
                            last_modified=current_time) # 添加最后修改时间

    def get_dot(self, concept):
        # 检查节点是否存在于图中
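The `Memory_graph` hunks above reformat the `add_edge`/`add_node` calls but keep the behaviour: an edge between two concepts carries a `strength` counter plus `created_time`/`last_modified` timestamps, and reconnecting an existing pair bumps the counter instead of adding a duplicate edge. A self-contained sketch of that pattern with networkx (names follow the diff; the standalone function is illustrative, not the project's exact code):

```python
import datetime

import networkx as nx


def connect_dot(G: nx.Graph, concept1: str, concept2: str) -> None:
    """Strengthen an existing edge or create a new one with strength 1 and timestamps."""
    if concept1 == concept2:  # 避免自连接
        return
    now = datetime.datetime.now().timestamp()
    if G.has_edge(concept1, concept2):
        G[concept1][concept2]["strength"] = G[concept1][concept2].get("strength", 1) + 1
        G[concept1][concept2]["last_modified"] = now
    else:
        G.add_edge(concept1, concept2, strength=1, created_time=now, last_modified=now)


G = nx.Graph()
connect_dot(G, "猫", "表情包")
connect_dot(G, "猫", "表情包")  # second call only increments strength
print(G["猫"]["表情包"]["strength"])  # -> 2
```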
@@ -210,12 +205,13 @@ class Hippocampus:
        # 成功抽取短期消息样本
        # 数据写回:增加记忆次数
        for message in messages:
            db.messages.update_one({"_id": message["_id"]}, {"$set": {"memorized_times": message["memorized_times"] + 1}})
            db.messages.update_one({"_id": message["_id"]},
                                   {"$set": {"memorized_times": message["memorized_times"] + 1}})
        return messages
        try_count += 1
        # 三次尝试均失败
        return None


    def get_memory_sample(self, chat_size=20, time_frequency: dict = {'near': 2, 'mid': 4, 'far': 3}):
        """获取记忆样本

@@ -225,7 +221,7 @@ class Hippocampus:
        # 硬编码:每条消息最大记忆次数
        # 如有需求可写入global_config
        max_memorized_time_per_msg = 3

        current_timestamp = datetime.datetime.now().timestamp()
        chat_samples = []

@@ -324,20 +320,20 @@ class Hippocampus:
            # 为每个话题查找相似的已存在主题
            existing_topics = list(self.memory_graph.G.nodes())
            similar_topics = []

            for existing_topic in existing_topics:
                topic_words = set(jieba.cut(topic))
                existing_words = set(jieba.cut(existing_topic))

                all_words = topic_words | existing_words
                v1 = [1 if word in topic_words else 0 for word in all_words]
                v2 = [1 if word in existing_words else 0 for word in all_words]

                similarity = cosine_similarity(v1, v2)

                if similarity >= 0.6:
                    similar_topics.append((existing_topic, similarity))

            similar_topics.sort(key=lambda x: x[1], reverse=True)
            similar_topics = similar_topics[:5]
            similar_topics_dict[topic] = similar_topics
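The topic-matching hunk above builds binary word-membership vectors from `jieba` tokens and links topics whose cosine similarity reaches 0.6. `cosine_similarity` itself is not shown in this diff; a plausible definition, consistent with how it is called there:

```python
import math

import jieba


def cosine_similarity(v1: list, v2: list) -> float:
    """Cosine similarity of two equal-length numeric vectors; 0.0 when either vector is all zeros."""
    dot = sum(a * b for a, b in zip(v1, v2))
    norm1 = math.sqrt(sum(a * a for a in v1))
    norm2 = math.sqrt(sum(b * b for b in v2))
    if norm1 == 0 or norm2 == 0:
        return 0.0
    return dot / (norm1 * norm2)


topic, existing_topic = "今天的晚饭", "晚饭吃什么"
topic_words = set(jieba.cut(topic))
existing_words = set(jieba.cut(existing_topic))
all_words = topic_words | existing_words
v1 = [1 if word in topic_words else 0 for word in all_words]
v2 = [1 if word in existing_words else 0 for word in all_words]
print(cosine_similarity(v1, v2))  # topics count as "similar" when this reaches 0.6
```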
@@ -358,7 +354,7 @@ class Hippocampus:
    async def operation_build_memory(self, chat_size=20):
        time_frequency = {'near': 1, 'mid': 4, 'far': 4}
        memory_samples = self.get_memory_sample(chat_size, time_frequency)

        for i, messages in enumerate(memory_samples, 1):
            all_topics = []
            # 加载进度可视化
@@ -371,14 +367,14 @@ class Hippocampus:
            compress_rate = global_config.memory_compress_rate
            compressed_memory, similar_topics_dict = await self.memory_compress(messages, compress_rate)
            logger.info(f"压缩后记忆数量: {len(compressed_memory)},似曾相识的话题: {len(similar_topics_dict)}")

            current_time = datetime.datetime.now().timestamp()

            for topic, memory in compressed_memory:
                logger.info(f"添加节点: {topic}")
                self.memory_graph.add_dot(topic, memory)
                all_topics.append(topic)

                # 连接相似的已存在主题
                if topic in similar_topics_dict:
                    similar_topics = similar_topics_dict[topic]
@@ -386,11 +382,11 @@ class Hippocampus:
                        if topic != similar_topic:
                            strength = int(similarity * 10)
                            logger.info(f"连接相似节点: {topic} 和 {similar_topic} (强度: {strength})")
                            self.memory_graph.G.add_edge(topic, similar_topic,
                                                         strength=strength,
                                                         created_time=current_time,
                                                         last_modified=current_time)

                            self.memory_graph.G.add_edge(topic, similar_topic,
                                                         strength=strength,
                                                         created_time=current_time,
                                                         last_modified=current_time)

            # 连接同批次的相关话题
            for i in range(len(all_topics)):
                for j in range(i + 1, len(all_topics)):
@@ -416,7 +412,7 @@

            # 计算内存中节点的特征值
            memory_hash = self.calculate_node_hash(concept, memory_items)

            # 获取时间信息
            created_time = data.get('created_time', datetime.datetime.now().timestamp())
            last_modified = data.get('last_modified', datetime.datetime.now().timestamp())
@@ -466,7 +462,7 @@
            edge_hash = self.calculate_edge_hash(source, target)
            edge_key = (source, target)
            strength = data.get('strength', 1)

            # 获取边的时间信息
            created_time = data.get('created_time', datetime.datetime.now().timestamp())
            last_modified = data.get('last_modified', datetime.datetime.now().timestamp())
@@ -499,7 +495,7 @@
        """从数据库同步数据到内存中的图结构"""
        current_time = datetime.datetime.now().timestamp()
        need_update = False

        # 清空当前图
        self.memory_graph.G.clear()

@@ -510,7 +506,7 @@
            memory_items = node.get('memory_items', [])
            if not isinstance(memory_items, list):
                memory_items = [memory_items] if memory_items else []

            # 检查时间字段是否存在
            if 'created_time' not in node or 'last_modified' not in node:
                need_update = True
@@ -520,22 +516,22 @@
                    update_data['created_time'] = current_time
                if 'last_modified' not in node:
                    update_data['last_modified'] = current_time

                db.graph_data.nodes.update_one(
                    {'concept': concept},
                    {'$set': update_data}
                )
                logger.info(f"[时间更新] 节点 {concept} 添加缺失的时间字段")

            # 获取时间信息(如果不存在则使用当前时间)
            created_time = node.get('created_time', current_time)
            last_modified = node.get('last_modified', current_time)

            # 添加节点到图中
            self.memory_graph.G.add_node(concept,
                                         memory_items=memory_items,
                                         created_time=created_time,
                                         last_modified=last_modified)
            self.memory_graph.G.add_node(concept,
                                         memory_items=memory_items,
                                         created_time=created_time,
                                         last_modified=last_modified)

        # 从数据库加载所有边
        edges = list(db.graph_data.edges.find())
@@ -543,7 +539,7 @@
            source = edge['source']
            target = edge['target']
            strength = edge.get('strength', 1)

            # 检查时间字段是否存在
            if 'created_time' not in edge or 'last_modified' not in edge:
                need_update = True
@@ -553,24 +549,24 @@
                    update_data['created_time'] = current_time
                if 'last_modified' not in edge:
                    update_data['last_modified'] = current_time

                db.graph_data.edges.update_one(
                    {'source': source, 'target': target},
                    {'$set': update_data}
                )
                logger.info(f"[时间更新] 边 {source} - {target} 添加缺失的时间字段")

            # 获取时间信息(如果不存在则使用当前时间)
            created_time = edge.get('created_time', current_time)
            last_modified = edge.get('last_modified', current_time)

            # 只有当源节点和目标节点都存在时才添加边
            if source in self.memory_graph.G and target in self.memory_graph.G:
                self.memory_graph.G.add_edge(source, target,
                                             strength=strength,
                                             created_time=created_time,
                                             last_modified=last_modified)

                self.memory_graph.G.add_edge(source, target,
                                             strength=strength,
                                             created_time=created_time,
                                             last_modified=last_modified)

        if need_update:
            logger.success("[数据库] 已为缺失的时间字段进行补充")

@@ -578,44 +574,44 @@
        """随机选择图中一定比例的节点和边进行检查,根据时间条件决定是否遗忘"""
        # 检查数据库是否为空
        # logger.remove()

        logger.info(f"[遗忘] 开始检查数据库... 当前Logger信息:")
        # logger.info(f"- Logger名称: {logger.name}")
        logger.info(f"- Logger等级: {logger.level}")
        # logger.info(f"- Logger处理器: {[handler.__class__.__name__ for handler in logger.handlers]}")

        # logger2 = setup_logger(LogModule.MEMORY)
        # logger2.info(f"[遗忘] 开始检查数据库... 当前Logger信息:")
        # logger.info(f"[遗忘] 开始检查数据库... 当前Logger信息:")

        all_nodes = list(self.memory_graph.G.nodes())
        all_edges = list(self.memory_graph.G.edges())

        if not all_nodes and not all_edges:
            logger.info("[遗忘] 记忆图为空,无需进行遗忘操作")
            return

        check_nodes_count = max(1, int(len(all_nodes) * percentage))
        check_edges_count = max(1, int(len(all_edges) * percentage))

        nodes_to_check = random.sample(all_nodes, check_nodes_count)
        edges_to_check = random.sample(all_edges, check_edges_count)

        edge_changes = {'weakened': 0, 'removed': 0}
        node_changes = {'reduced': 0, 'removed': 0}

        current_time = datetime.datetime.now().timestamp()

        # 检查并遗忘连接
        logger.info("[遗忘] 开始检查连接...")
        for source, target in edges_to_check:
            edge_data = self.memory_graph.G[source][target]
            last_modified = edge_data.get('last_modified')
            if current_time - last_modified > 3600*global_config.memory_forget_time:

            if current_time - last_modified > 3600 * global_config.memory_forget_time:
                current_strength = edge_data.get('strength', 1)
                new_strength = current_strength - 1

                if new_strength <= 0:
                    self.memory_graph.G.remove_edge(source, target)
                    edge_changes['removed'] += 1
@@ -625,23 +621,23 @@
                    edge_data['last_modified'] = current_time
                    edge_changes['weakened'] += 1
                    logger.info(f"[遗忘] 连接减弱: {source} -> {target} (强度: {current_strength} -> {new_strength})")

        # 检查并遗忘话题
        logger.info("[遗忘] 开始检查节点...")
        for node in nodes_to_check:
            node_data = self.memory_graph.G.nodes[node]
            last_modified = node_data.get('last_modified', current_time)
            if current_time - last_modified > 3600*24:

            if current_time - last_modified > 3600 * 24:
                memory_items = node_data.get('memory_items', [])
                if not isinstance(memory_items, list):
                    memory_items = [memory_items] if memory_items else []

                if memory_items:
                    current_count = len(memory_items)
                    removed_item = random.choice(memory_items)
                    memory_items.remove(removed_item)

                    if memory_items:
                        self.memory_graph.G.nodes[node]['memory_items'] = memory_items
                        self.memory_graph.G.nodes[node]['last_modified'] = current_time
@@ -651,7 +647,7 @@
            self.memory_graph.G.remove_node(node)
            node_changes['removed'] += 1
            logger.info(f"[遗忘] 节点移除: {node}")

        if any(count > 0 for count in edge_changes.values()) or any(count > 0 for count in node_changes.values()):
            self.sync_memory_to_db()
            logger.info("[遗忘] 统计信息:")
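The forgetting pass above samples a fraction of nodes and edges and only touches those that have been idle long enough: edges untouched for more than `3600 * memory_forget_time` seconds lose one point of strength (and are removed at zero), nodes idle for more than 24 hours lose one random memory item. A compact sketch of the per-edge decision, assuming `memory_forget_time` is expressed in hours as the `3600 *` factor suggests (the example value is made up):

```python
import datetime

# Assumed example value: forget edges untouched for 24 hours (global_config.memory_forget_time).
memory_forget_time = 24

now = datetime.datetime.now().timestamp()
edge_data = {"strength": 2, "last_modified": now - 30 * 3600}  # last touched 30 hours ago

if now - edge_data["last_modified"] > 3600 * memory_forget_time:
    new_strength = edge_data.get("strength", 1) - 1
    if new_strength <= 0:
        print("edge removed")  # G.remove_edge(source, target) in the real code
    else:
        edge_data["strength"] = new_strength
        edge_data["last_modified"] = now
        print(f"edge weakened to {new_strength}")
```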
@@ -882,8 +878,8 @@ class Hippocampus:
                    matched_topics.add(input_topic)
                    adjusted_sim = sim * penalty
                    topic_similarities[input_topic] = max(topic_similarities.get(input_topic, 0), adjusted_sim)
                    logger.debug(
                        f"[激活] 主题「{input_topic}」-> 「{memory_topic}」(内容数: {content_count}, 相似度: {adjusted_sim:.3f})")
                    # logger.debug(
                    #     f"[激活] 主题「{input_topic}」-> 「{memory_topic}」(内容数: {content_count}, 相似度: {adjusted_sim:.3f})")

        # 计算主题匹配率和平均相似度
        topic_match = len(matched_topics) / len(identified_topics)
@@ -943,6 +939,7 @@ def segment_text(text):
    seg_text = list(jieba.cut(text))
    return seg_text


driver = get_driver()
config = driver.config


@@ -11,7 +11,7 @@ from pathlib import Path
import matplotlib.pyplot as plt
import networkx as nx
from dotenv import load_dotenv
from loguru import logger
from src.common.logger import get_module_logger
import jieba

# from chat.config import global_config
@@ -29,6 +29,8 @@ project_root = current_dir.parent.parent.parent
# env.dev文件路径
env_path = project_root / ".env.dev"

logger = get_module_logger("mem_manual_bd")

# 加载环境变量
if env_path.exists():
    logger.info(f"从 {env_path} 加载环境变量")

@@ -12,9 +12,11 @@ import matplotlib.pyplot as plt
import networkx as nx
import pymongo
from dotenv import load_dotenv
from loguru import logger
from src.common.logger import get_module_logger
import jieba

logger = get_module_logger("mem_test")

'''
该理论认为,当两个或多个事物在形态上具有相似性时,
它们在记忆中会形成关联。

@@ -5,8 +5,9 @@ from typing import Tuple, Union

import aiohttp
import requests
from loguru import logger
from src.common.logger import get_module_logger

logger = get_module_logger("offline_llm")

class LLMModel:
    def __init__(self, model_name="deepseek-ai/DeepSeek-V3", **kwargs):

@@ -5,7 +5,7 @@ from datetime import datetime
from typing import Tuple, Union

import aiohttp
from loguru import logger
from src.common.logger import get_module_logger
from nonebot import get_driver
import base64
from PIL import Image
@@ -16,6 +16,8 @@ from ..chat.config import global_config
driver = get_driver()
config = driver.config

logger = get_module_logger("model_utils")


class LLM_request:
    # 定义需要转换的模型列表,作为类变量避免重复
@@ -165,8 +167,8 @@ class LLM_request:
        # 判断是否为流式
        stream_mode = self.params.get("stream", False)
        logger_msg = "进入流式输出模式," if stream_mode else ""
        logger.debug(f"{logger_msg}发送请求到URL: {api_url}")
        logger.info(f"使用模型: {self.model_name}")
        # logger.debug(f"{logger_msg}发送请求到URL: {api_url}")
        # logger.info(f"使用模型: {self.model_name}")

        # 构建请求体
        if image_base64:

@@ -4,7 +4,9 @@ import time
from dataclasses import dataclass

from ..chat.config import global_config
from loguru import logger
from src.common.logger import get_module_logger

logger = get_module_logger("mood_manager")

@dataclass
class MoodState:

@@ -5,7 +5,9 @@ import platform
import os
import json
import threading
from loguru import logger
from src.common.logger import get_module_logger

logger = get_module_logger("remote")

# UUID文件路径
UUID_FILE = os.path.join(os.path.dirname(os.path.abspath(__file__)), "client_uuid.json")
@@ -30,9 +32,9 @@ def get_unique_id():
        try:
            with open(UUID_FILE, "w") as f:
                json.dump({"client_id": client_id}, f)
            print("已保存新生成的客户端ID到本地文件")
            logger.info("已保存新生成的客户端ID到本地文件")
        except IOError as e:
            print(f"保存UUID时出错: {e}")
            logger.error(f"保存UUID时出错: {e}")

    return client_id
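The `remote` hunk only swaps `print` for logger calls inside `get_unique_id`; the surrounding load-or-create logic is not shown. A sketch of how such a persistent client id is typically handled, assuming the layout implied by `UUID_FILE` (this is a reconstruction for illustration, not the project's exact code):

```python
import json
import os
import uuid

UUID_FILE = os.path.join(os.path.dirname(os.path.abspath(__file__)), "client_uuid.json")


def get_unique_id() -> str:
    """Reuse the client id stored next to this file, or create and persist a new one."""
    if os.path.exists(UUID_FILE):
        try:
            with open(UUID_FILE, "r") as f:
                return json.load(f)["client_id"]
        except (IOError, KeyError, json.JSONDecodeError):
            pass  # fall through and regenerate a fresh id

    client_id = str(uuid.uuid4())
    try:
        with open(UUID_FILE, "w") as f:
            json.dump({"client_id": client_id}, f)
    except IOError as e:
        # The commit replaces prints like this with logger.error(...) from get_module_logger.
        print(f"保存UUID时出错: {e}")
    return client_id
```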
@@ -3,13 +3,15 @@ import json
import re
from typing import Dict, Union

from loguru import logger
from nonebot import get_driver

from src.plugins.chat.config import global_config

from ...common.database import db # 使用正确的导入语法
from ..models.utils_model import LLM_request
from src.common.logger import get_module_logger

logger = get_module_logger("scheduler")

driver = get_driver()
config = driver.config

@@ -3,10 +3,11 @@ import time
from collections import defaultdict
from datetime import datetime, timedelta
from typing import Any, Dict
from loguru import logger
from src.common.logger import get_module_logger

from ...common.database import db

logger = get_module_logger("llm_statistics")

class LLMStatistics:
    def __init__(self, output_file: str = "llm_statistics.txt"):

@@ -13,8 +13,9 @@ from pathlib import Path
import jieba
from pypinyin import Style, pinyin

from loguru import logger
from src.common.logger import get_module_logger

logger = get_module_logger("typo_gen")

class ChineseTypoGenerator:
    def __init__(self,

@@ -2,7 +2,9 @@ import asyncio
import random
import time
from typing import Dict
from loguru import logger
from src.common.logger import get_module_logger

logger = get_module_logger("mode_dynamic")


from ..chat.config import global_config

@@ -1,11 +1,13 @@
from typing import Optional
from loguru import logger
from src.common.logger import get_module_logger

from ..chat.config import global_config
from .mode_classical import WillingManager as ClassicalWillingManager
from .mode_dynamic import WillingManager as DynamicWillingManager
from .mode_custom import WillingManager as CustomWillingManager

logger = get_module_logger("willing")

def init_willing_manager() -> Optional[object]:
    """
    根据配置初始化并返回对应的WillingManager实例
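The final hunk adds a module logger to the willing-manager factory. `init_willing_manager` dispatches on configuration to one of the three imported `WillingManager` variants; its body is cut off here, so the sketch below is only an assumption about how that dispatch could look (the `mode` parameter stands in for whatever field `global_config` actually provides):

```python
from typing import Optional


# Stand-ins for the three imported managers; the real classes come from
# .mode_classical / .mode_dynamic / .mode_custom.
class ClassicalWillingManager: ...
class DynamicWillingManager: ...
class CustomWillingManager: ...


def init_willing_manager(mode: str = "classical") -> Optional[object]:
    """根据配置初始化并返回对应的WillingManager实例 (mode would come from global_config; the name is assumed)."""
    managers = {
        "classical": ClassicalWillingManager,
        "dynamic": DynamicWillingManager,
        "custom": CustomWillingManager,
    }
    manager_cls = managers.get(mode)
    return manager_cls() if manager_cls else None


willing_manager = init_willing_manager("classical")
```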