Merge pull request #412 from KX76/main-fix

refactor:日志重构
This commit is contained in:
SengokuCola
2025-03-15 17:37:04 +08:00
committed by GitHub
34 changed files with 366 additions and 135 deletions

16
bot.py
View File

@@ -11,10 +11,11 @@ import uvicorn
from dotenv import load_dotenv from dotenv import load_dotenv
from nonebot.adapters.onebot.v11 import Adapter from nonebot.adapters.onebot.v11 import Adapter
import platform import platform
from src.plugins.utils.logger_config import LogModule, LogClassification from src.common.logger import get_module_logger
# 配置日志格式 # 配置主程序日志格式
logger = get_module_logger("main_bot")
# 获取没有加载env时的环境变量 # 获取没有加载env时的环境变量
env_mask = {key: os.getenv(key) for key in os.environ} env_mask = {key: os.getenv(key) for key in os.environ}
@@ -77,11 +78,11 @@ def init_env():
def load_env(): def load_env():
# 使用闭包实现对加载器的横向扩展,避免大量重复判断 # 使用闭包实现对加载器的横向扩展,避免大量重复判断
def prod(): def prod():
logger.success("加载生产环境变量配置") logger.success("成功加载生产环境变量配置")
load_dotenv(".env.prod", override=True) # override=True 允许覆盖已存在的环境变量 load_dotenv(".env.prod", override=True) # override=True 允许覆盖已存在的环境变量
def dev(): def dev():
logger.success("加载开发环境变量配置") logger.success("成功加载开发环境变量配置")
load_dotenv(".env.dev", override=True) # override=True 允许覆盖已存在的环境变量 load_dotenv(".env.dev", override=True) # override=True 允许覆盖已存在的环境变量
fn_map = {"prod": prod, "dev": dev} fn_map = {"prod": prod, "dev": dev}
@@ -101,11 +102,6 @@ def load_env():
RuntimeError(f"ENVIRONMENT 配置错误,请检查 .env 文件中的 ENVIRONMENT 变量及对应 .env.{env} 是否存在") RuntimeError(f"ENVIRONMENT 配置错误,请检查 .env 文件中的 ENVIRONMENT 变量及对应 .env.{env} 是否存在")
def load_logger():
global logger # 使得bot.py中其他函数也能调用
log_module = LogModule()
logger = log_module.setup_logger(LogClassification.BASE)
def scan_provider(env_config: dict): def scan_provider(env_config: dict):
provider = {} provider = {}
@@ -229,8 +225,6 @@ def raw_main():
if __name__ == "__main__": if __name__ == "__main__":
try: try:
# 配置日志使得主程序直接退出时候也能访问logger
load_logger()
raw_main() raw_main()
app = nonebot.get_asgi() app = nonebot.get_asgi()

198
src/common/logger.py Normal file
View File

@@ -0,0 +1,198 @@
from loguru import logger
from typing import Dict, Optional, Union, List
import sys
import os
from types import ModuleType
from pathlib import Path
from dotenv import load_dotenv
load_dotenv()
# 保存原生处理器ID
default_handler_id = None
for handler_id in logger._core.handlers:
default_handler_id = handler_id
break
# 移除默认处理器
if default_handler_id is not None:
logger.remove(default_handler_id)
# 类型别名
LoguruLogger = logger.__class__
# 全局注册表记录模块与处理器ID的映射
_handler_registry: Dict[str, List[int]] = {}
# 获取日志存储根地址
current_file_path = Path(__file__).resolve()
LOG_ROOT = "logs"
# 默认全局配置
DEFAULT_CONFIG = {
# 日志级别配置
"console_level": "INFO",
"file_level": "DEBUG",
# 格式配置
"console_format": (
"<green>{time:YYYY-MM-DD HH:mm:ss}</green> | "
"<level>{level: <8}</level> | "
"<cyan>{extra[module]: <12}</cyan> | "
"<level>{message}</level>"
),
"file_format": (
"{time:YYYY-MM-DD HH:mm:ss} | "
"{level: <8} | "
"{extra[module]: <15} | "
"{message}"
),
"log_dir": LOG_ROOT,
"rotation": "00:00",
"retention": "3 days",
"compression": "zip",
}
def is_registered_module(record: dict) -> bool:
"""检查是否为已注册的模块"""
return record["extra"].get("module") in _handler_registry
def is_unregistered_module(record: dict) -> bool:
"""检查是否为未注册的模块"""
return not is_registered_module(record)
def log_patcher(record: dict) -> None:
"""自动填充未设置模块名的日志记录,保留原生模块名称"""
if "module" not in record["extra"]:
# 尝试从name中提取模块名
module_name = record.get("name", "")
if module_name == "":
module_name = "root"
record["extra"]["module"] = module_name
# 应用全局修补器
logger.configure(patcher=log_patcher)
class LogConfig:
"""日志配置类"""
def __init__(self, **kwargs):
self.config = DEFAULT_CONFIG.copy()
self.config.update(kwargs)
def to_dict(self) -> dict:
return self.config.copy()
def update(self, **kwargs):
self.config.update(kwargs)
def get_module_logger(
module: Union[str, ModuleType],
*,
console_level: Optional[str] = None,
file_level: Optional[str] = None,
extra_handlers: Optional[List[dict]] = None,
config: Optional[LogConfig] = None
) -> LoguruLogger:
module_name = module if isinstance(module, str) else module.__name__
current_config = config.config if config else DEFAULT_CONFIG
# 清理旧处理器
if module_name in _handler_registry:
for handler_id in _handler_registry[module_name]:
logger.remove(handler_id)
del _handler_registry[module_name]
handler_ids = []
# 控制台处理器
console_id = logger.add(
sink=sys.stderr,
level=os.getenv("CONSOLE_LOG_LEVEL", console_level or current_config["console_level"]),
format=current_config["console_format"],
filter=lambda record: record["extra"].get("module") == module_name,
enqueue=True,
)
handler_ids.append(console_id)
# 文件处理器
log_dir = Path(current_config["log_dir"])
log_dir.mkdir(parents=True, exist_ok=True)
log_file = log_dir / module_name / f"{{time:YYYY-MM-DD}}.log"
log_file.parent.mkdir(parents=True, exist_ok=True)
file_id = logger.add(
sink=str(log_file),
level=os.getenv("FILE_LOG_LEVEL", file_level or current_config["file_level"]),
format=current_config["file_format"],
rotation=current_config["rotation"],
retention=current_config["retention"],
compression=current_config["compression"],
encoding="utf-8",
filter=lambda record: record["extra"].get("module") == module_name,
enqueue=True,
)
handler_ids.append(file_id)
# 额外处理器
if extra_handlers:
for handler in extra_handlers:
handler_id = logger.add(**handler)
handler_ids.append(handler_id)
# 更新注册表
_handler_registry[module_name] = handler_ids
return logger.bind(module=module_name)
def remove_module_logger(module_name: str) -> None:
"""清理指定模块的日志处理器"""
if module_name in _handler_registry:
for handler_id in _handler_registry[module_name]:
logger.remove(handler_id)
del _handler_registry[module_name]
# 添加全局默认处理器(只处理未注册模块的日志--->控制台)
DEFAULT_GLOBAL_HANDLER = logger.add(
sink=sys.stderr,
level=os.getenv("DEFAULT_CONSOLE_LOG_LEVEL", "SUCCESS"),
format=(
"<green>{time:YYYY-MM-DD HH:mm:ss}</green> | "
"<level>{level: <8}</level> | "
"<cyan>{name: <12}</cyan> | "
"<level>{message}</level>"
),
filter=is_unregistered_module, # 只处理未注册模块的日志
enqueue=True,
)
# 添加全局默认文件处理器(只处理未注册模块的日志--->logs文件夹
log_dir = Path(DEFAULT_CONFIG["log_dir"])
log_dir.mkdir(parents=True, exist_ok=True)
other_log_dir = log_dir / "other"
other_log_dir.mkdir(parents=True, exist_ok=True)
DEFAULT_FILE_HANDLER = logger.add(
sink=str(other_log_dir / f"{{time:YYYY-MM-DD}}.log"),
level=os.getenv("DEFAULT_FILE_LOG_LEVEL", "DEBUG"),
format=(
"{time:YYYY-MM-DD HH:mm:ss} | "
"{level: <8} | "
"{name: <15} | "
"{message}"
),
rotation=DEFAULT_CONFIG["rotation"],
retention=DEFAULT_CONFIG["retention"],
compression=DEFAULT_CONFIG["compression"],
encoding="utf-8",
filter=is_unregistered_module, # 只处理未注册模块的日志
enqueue=True,
)

View File

@@ -5,13 +5,14 @@ import threading
import time import time
from datetime import datetime from datetime import datetime
from typing import Dict, List from typing import Dict, List
from loguru import logger
from typing import Optional from typing import Optional
from src.common.logger import get_module_logger
import customtkinter as ctk import customtkinter as ctk
from dotenv import load_dotenv from dotenv import load_dotenv
logger = get_module_logger("gui")
# 获取当前文件的目录 # 获取当前文件的目录
current_dir = os.path.dirname(os.path.abspath(__file__)) current_dir = os.path.dirname(os.path.abspath(__file__))
# 获取项目根目录 # 获取项目根目录
@@ -30,6 +31,7 @@ else:
logger.error("未找到环境配置文件") logger.error("未找到环境配置文件")
sys.exit(1) sys.exit(1)
class ReasoningGUI: class ReasoningGUI:
def __init__(self): def __init__(self):
# 记录启动时间戳转换为Unix时间戳 # 记录启动时间戳转换为Unix时间戳

View File

@@ -2,7 +2,6 @@ import asyncio
import time import time
import os import os
from loguru import logger
from nonebot import get_driver, on_message, on_notice, require from nonebot import get_driver, on_message, on_notice, require
from nonebot.rule import to_me from nonebot.rule import to_me
from nonebot.adapters.onebot.v11 import Bot, GroupMessageEvent, Message, MessageSegment, MessageEvent, NoticeEvent from nonebot.adapters.onebot.v11 import Bot, GroupMessageEvent, Message, MessageSegment, MessageEvent, NoticeEvent
@@ -21,6 +20,9 @@ from ..memory_system.memory import hippocampus, memory_graph
from .bot import ChatBot from .bot import ChatBot
from .message_sender import message_manager, message_sender from .message_sender import message_manager, message_sender
from .storage import MessageStorage from .storage import MessageStorage
from src.common.logger import get_module_logger
logger = get_module_logger("chat_init")
# 创建LLM统计实例 # 创建LLM统计实例
llm_stats = LLMStatistics("llm_statistics.txt") llm_stats = LLMStatistics("llm_statistics.txt")

View File

@@ -12,6 +12,7 @@ from nonebot.adapters.onebot.v11 import (
FriendRecallNoticeEvent, FriendRecallNoticeEvent,
) )
from src.common.logger import get_module_logger
from ..memory_system.memory import hippocampus from ..memory_system.memory import hippocampus
from ..moods.moods import MoodManager # 导入情绪管理器 from ..moods.moods import MoodManager # 导入情绪管理器
from .config import global_config from .config import global_config
@@ -31,11 +32,8 @@ from .utils_image import image_path_to_base64
from .utils_user import get_user_nickname, get_user_cardname, get_groupname from .utils_user import get_user_nickname, get_user_cardname, get_groupname
from ..willing.willing_manager import willing_manager # 导入意愿管理器 from ..willing.willing_manager import willing_manager # 导入意愿管理器
from .message_base import UserInfo, GroupInfo, Seg from .message_base import UserInfo, GroupInfo, Seg
from ..utils.logger_config import LogClassification, LogModule
# 配置日志 logger = get_module_logger("chat_bot")
log_module = LogModule()
logger = log_module.setup_logger(LogClassification.CHAT)
class ChatBot: class ChatBot:

View File

@@ -4,11 +4,14 @@ import time
import copy import copy
from typing import Dict, Optional from typing import Dict, Optional
from loguru import logger
from ...common.database import db from ...common.database import db
from .message_base import GroupInfo, UserInfo from .message_base import GroupInfo, UserInfo
from src.common.logger import get_module_logger
logger = get_module_logger("chat_stream")
class ChatStream: class ChatStream:
"""聊天流对象,存储一个完整的聊天上下文""" """聊天流对象,存储一个完整的聊天上下文"""

View File

@@ -4,11 +4,14 @@ from dataclasses import dataclass, field
from typing import Dict, List, Optional from typing import Dict, List, Optional
import tomli import tomli
from loguru import logger
from packaging import version from packaging import version
from packaging.version import Version, InvalidVersion from packaging.version import Version, InvalidVersion
from packaging.specifiers import SpecifierSet, InvalidSpecifier from packaging.specifiers import SpecifierSet, InvalidSpecifier
from src.common.logger import get_module_logger
logger = get_module_logger("config")
@dataclass @dataclass
class BotConfig: class BotConfig:
@@ -440,10 +443,3 @@ else:
global_config = BotConfig.load_config(config_path=bot_config_path) global_config = BotConfig.load_config(config_path=bot_config_path)
if not global_config.enable_advance_output:
logger.remove()
# 调试输出功能
if global_config.enable_debug_output:
logger.remove()
logger.add(sys.stdout, level="DEBUG")

View File

@@ -7,7 +7,7 @@ from typing import Dict, List, Optional, Union
import ssl import ssl
import os import os
import aiohttp import aiohttp
from loguru import logger from src.common.logger import get_module_logger
from nonebot import get_driver from nonebot import get_driver
from ..models.utils_model import LLM_request from ..models.utils_model import LLM_request
@@ -24,6 +24,7 @@ config = driver.config
ssl_context = ssl.create_default_context() ssl_context = ssl.create_default_context()
ssl_context.set_ciphers("AES128-GCM-SHA256") ssl_context.set_ciphers("AES128-GCM-SHA256")
logger = get_module_logger("cq_code")
@dataclass @dataclass
class CQCode: class CQCode:

View File

@@ -9,7 +9,6 @@ from typing import Optional, Tuple
from PIL import Image from PIL import Image
import io import io
from loguru import logger
from nonebot import get_driver from nonebot import get_driver
from ...common.database import db from ...common.database import db
@@ -17,12 +16,10 @@ from ..chat.config import global_config
from ..chat.utils import get_embedding from ..chat.utils import get_embedding
from ..chat.utils_image import ImageManager, image_path_to_base64 from ..chat.utils_image import ImageManager, image_path_to_base64
from ..models.utils_model import LLM_request from ..models.utils_model import LLM_request
from src.common.logger import get_module_logger
from ..utils.logger_config import LogClassification, LogModule logger = get_module_logger("emoji")
# 配置日志
log_module = LogModule()
logger = log_module.setup_logger(LogClassification.EMOJI)
driver = get_driver() driver = get_driver()
config = driver.config config = driver.config

View File

@@ -3,7 +3,6 @@ import time
from typing import List, Optional, Tuple, Union from typing import List, Optional, Tuple, Union
from nonebot import get_driver from nonebot import get_driver
from loguru import logger
from ...common.database import db from ...common.database import db
from ..models.utils_model import LLM_request from ..models.utils_model import LLM_request
@@ -12,6 +11,9 @@ from .message import MessageRecv, MessageThinking, Message
from .prompt_builder import prompt_builder from .prompt_builder import prompt_builder
from .relationship_manager import relationship_manager from .relationship_manager import relationship_manager
from .utils import process_llm_response from .utils import process_llm_response
from src.common.logger import get_module_logger
logger = get_module_logger("response_gen")
driver = get_driver() driver = get_driver()
config = driver.config config = driver.config

View File

@@ -6,12 +6,14 @@ from dataclasses import dataclass
from typing import Dict, List, Optional from typing import Dict, List, Optional
import urllib3 import urllib3
from loguru import logger
from .utils_image import image_manager from .utils_image import image_manager
from .message_base import Seg, GroupInfo, UserInfo, BaseMessageInfo, MessageBase from .message_base import Seg, GroupInfo, UserInfo, BaseMessageInfo, MessageBase
from .chat_stream import ChatStream, chat_manager from .chat_stream import ChatStream, chat_manager
from src.common.logger import get_module_logger
logger = get_module_logger("chat_message")
# 禁用SSL警告 # 禁用SSL警告
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning) urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)

View File

@@ -2,7 +2,7 @@ import asyncio
import time import time
from typing import Dict, List, Optional, Union from typing import Dict, List, Optional, Union
from loguru import logger from src.common.logger import get_module_logger
from nonebot.adapters.onebot.v11 import Bot from nonebot.adapters.onebot.v11 import Bot
from ...common.database import db from ...common.database import db
from .message_cq import MessageSendCQ from .message_cq import MessageSendCQ
@@ -12,6 +12,7 @@ from .storage import MessageStorage
from .config import global_config from .config import global_config
from .utils import truncate_message from .utils import truncate_message
logger = get_module_logger("msg_sender")
class Message_Sender: class Message_Sender:
"""发送器""" """发送器"""

View File

@@ -9,11 +9,9 @@ from ..schedule.schedule_generator import bot_schedule
from .config import global_config from .config import global_config
from .utils import get_embedding, get_recent_group_detailed_plain_text from .utils import get_embedding, get_recent_group_detailed_plain_text
from .chat_stream import chat_manager from .chat_stream import chat_manager
from src.common.logger import get_module_logger
from ..utils.logger_config import LogClassification, LogModule logger = get_module_logger("prompt")
log_module = LogModule()
logger = log_module.setup_logger(LogClassification.PBUILDER)
logger.info("初始化Prompt系统") logger.info("初始化Prompt系统")

View File

@@ -1,11 +1,13 @@
import asyncio import asyncio
from typing import Optional from typing import Optional
from loguru import logger from src.common.logger import get_module_logger
from ...common.database import db from ...common.database import db
from .message_base import UserInfo from .message_base import UserInfo
from .chat_stream import ChatStream from .chat_stream import ChatStream
logger = get_module_logger("rel_manager")
class Impression: class Impression:
traits: str = None traits: str = None
called: str = None called: str = None

View File

@@ -3,7 +3,9 @@ from typing import Optional, Union
from ...common.database import db from ...common.database import db
from .message import MessageSending, MessageRecv from .message import MessageSending, MessageRecv
from .chat_stream import ChatStream from .chat_stream import ChatStream
from loguru import logger from src.common.logger import get_module_logger
logger = get_module_logger("message_storage")
class MessageStorage: class MessageStorage:

View File

@@ -4,7 +4,9 @@ from nonebot import get_driver
from ..models.utils_model import LLM_request from ..models.utils_model import LLM_request
from .config import global_config from .config import global_config
from loguru import logger from src.common.logger import get_module_logger
logger = get_module_logger("topic_identifier")
driver = get_driver() driver = get_driver()
config = driver.config config = driver.config

View File

@@ -7,7 +7,7 @@ from typing import Dict, List
import jieba import jieba
import numpy as np import numpy as np
from nonebot import get_driver from nonebot import get_driver
from loguru import logger from src.common.logger import get_module_logger
from ..models.utils_model import LLM_request from ..models.utils_model import LLM_request
from ..utils.typo_generator import ChineseTypoGenerator from ..utils.typo_generator import ChineseTypoGenerator
@@ -21,6 +21,8 @@ from ...common.database import db
driver = get_driver() driver = get_driver()
config = driver.config config = driver.config
logger = get_module_logger("chat_utils")
def db_message_to_str(message_dict: Dict) -> str: def db_message_to_str(message_dict: Dict) -> str:

View File

@@ -7,13 +7,16 @@ from typing import Optional, Union
from PIL import Image from PIL import Image
import io import io
from loguru import logger
from nonebot import get_driver from nonebot import get_driver
from ...common.database import db from ...common.database import db
from ..chat.config import global_config from ..chat.config import global_config
from ..models.utils_model import LLM_request from ..models.utils_model import LLM_request
from src.common.logger import get_module_logger
logger = get_module_logger("chat_image")
driver = get_driver() driver = get_driver()
config = driver.config config = driver.config

View File

@@ -1,10 +1,11 @@
from nonebot import get_app from nonebot import get_app
from .api import router from .api import router
from loguru import logger from src.common.logger import get_module_logger
# 获取主应用实例并挂载路由 # 获取主应用实例并挂载路由
app = get_app() app = get_app()
app.include_router(router, prefix="/api") app.include_router(router, prefix="/api")
# 打印日志方便确认API已注册 # 打印日志方便确认API已注册
logger = get_module_logger("cfg_reload")
logger.success("配置重载API已注册可通过 /api/reload-config 访问") logger.success("配置重载API已注册可通过 /api/reload-config 访问")

View File

@@ -7,7 +7,9 @@ import jieba
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import networkx as nx import networkx as nx
from dotenv import load_dotenv from dotenv import load_dotenv
from loguru import logger from src.common.logger import get_module_logger
logger = get_module_logger("draw_memory")
# 添加项目根目录到 Python 路径 # 添加项目根目录到 Python 路径
root_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../..")) root_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../.."))

View File

@@ -3,7 +3,6 @@ import datetime
import math import math
import random import random
import time import time
import os
import jieba import jieba
import networkx as nx import networkx as nx
@@ -18,14 +17,10 @@ from ..chat.utils import (
text_to_vector, text_to_vector,
) )
from ..models.utils_model import LLM_request from ..models.utils_model import LLM_request
from src.common.logger import get_module_logger
from ..utils.logger_config import LogClassification, LogModule logger = get_module_logger("memory_sys")
# 配置日志
log_module = LogModule()
logger = log_module.setup_logger(LogClassification.MEMORY)
logger.info("初始化记忆系统")
class Memory_graph: class Memory_graph:
def __init__(self): def __init__(self):
@@ -35,9 +30,9 @@ class Memory_graph:
# 避免自连接 # 避免自连接
if concept1 == concept2: if concept1 == concept2:
return return
current_time = datetime.datetime.now().timestamp() current_time = datetime.datetime.now().timestamp()
# 如果边已存在,增加 strength # 如果边已存在,增加 strength
if self.G.has_edge(concept1, concept2): if self.G.has_edge(concept1, concept2):
self.G[concept1][concept2]['strength'] = self.G[concept1][concept2].get('strength', 1) + 1 self.G[concept1][concept2]['strength'] = self.G[concept1][concept2].get('strength', 1) + 1
@@ -45,14 +40,14 @@ class Memory_graph:
self.G[concept1][concept2]['last_modified'] = current_time self.G[concept1][concept2]['last_modified'] = current_time
else: else:
# 如果是新边,初始化 strength 为 1 # 如果是新边,初始化 strength 为 1
self.G.add_edge(concept1, concept2, self.G.add_edge(concept1, concept2,
strength=1, strength=1,
created_time=current_time, # 添加创建时间 created_time=current_time, # 添加创建时间
last_modified=current_time) # 添加最后修改时间 last_modified=current_time) # 添加最后修改时间
def add_dot(self, concept, memory): def add_dot(self, concept, memory):
current_time = datetime.datetime.now().timestamp() current_time = datetime.datetime.now().timestamp()
if concept in self.G: if concept in self.G:
if 'memory_items' in self.G.nodes[concept]: if 'memory_items' in self.G.nodes[concept]:
if not isinstance(self.G.nodes[concept]['memory_items'], list): if not isinstance(self.G.nodes[concept]['memory_items'], list):
@@ -68,10 +63,10 @@ class Memory_graph:
self.G.nodes[concept]['last_modified'] = current_time self.G.nodes[concept]['last_modified'] = current_time
else: else:
# 如果是新节点,创建新的记忆列表 # 如果是新节点,创建新的记忆列表
self.G.add_node(concept, self.G.add_node(concept,
memory_items=[memory], memory_items=[memory],
created_time=current_time, # 添加创建时间 created_time=current_time, # 添加创建时间
last_modified=current_time) # 添加最后修改时间 last_modified=current_time) # 添加最后修改时间
def get_dot(self, concept): def get_dot(self, concept):
# 检查节点是否存在于图中 # 检查节点是否存在于图中
@@ -210,12 +205,13 @@ class Hippocampus:
# 成功抽取短期消息样本 # 成功抽取短期消息样本
# 数据写回:增加记忆次数 # 数据写回:增加记忆次数
for message in messages: for message in messages:
db.messages.update_one({"_id": message["_id"]}, {"$set": {"memorized_times": message["memorized_times"] + 1}}) db.messages.update_one({"_id": message["_id"]},
{"$set": {"memorized_times": message["memorized_times"] + 1}})
return messages return messages
try_count += 1 try_count += 1
# 三次尝试均失败 # 三次尝试均失败
return None return None
def get_memory_sample(self, chat_size=20, time_frequency: dict = {'near': 2, 'mid': 4, 'far': 3}): def get_memory_sample(self, chat_size=20, time_frequency: dict = {'near': 2, 'mid': 4, 'far': 3}):
"""获取记忆样本 """获取记忆样本
@@ -225,7 +221,7 @@ class Hippocampus:
# 硬编码:每条消息最大记忆次数 # 硬编码:每条消息最大记忆次数
# 如有需求可写入global_config # 如有需求可写入global_config
max_memorized_time_per_msg = 3 max_memorized_time_per_msg = 3
current_timestamp = datetime.datetime.now().timestamp() current_timestamp = datetime.datetime.now().timestamp()
chat_samples = [] chat_samples = []
@@ -324,20 +320,20 @@ class Hippocampus:
# 为每个话题查找相似的已存在主题 # 为每个话题查找相似的已存在主题
existing_topics = list(self.memory_graph.G.nodes()) existing_topics = list(self.memory_graph.G.nodes())
similar_topics = [] similar_topics = []
for existing_topic in existing_topics: for existing_topic in existing_topics:
topic_words = set(jieba.cut(topic)) topic_words = set(jieba.cut(topic))
existing_words = set(jieba.cut(existing_topic)) existing_words = set(jieba.cut(existing_topic))
all_words = topic_words | existing_words all_words = topic_words | existing_words
v1 = [1 if word in topic_words else 0 for word in all_words] v1 = [1 if word in topic_words else 0 for word in all_words]
v2 = [1 if word in existing_words else 0 for word in all_words] v2 = [1 if word in existing_words else 0 for word in all_words]
similarity = cosine_similarity(v1, v2) similarity = cosine_similarity(v1, v2)
if similarity >= 0.6: if similarity >= 0.6:
similar_topics.append((existing_topic, similarity)) similar_topics.append((existing_topic, similarity))
similar_topics.sort(key=lambda x: x[1], reverse=True) similar_topics.sort(key=lambda x: x[1], reverse=True)
similar_topics = similar_topics[:5] similar_topics = similar_topics[:5]
similar_topics_dict[topic] = similar_topics similar_topics_dict[topic] = similar_topics
@@ -358,7 +354,7 @@ class Hippocampus:
async def operation_build_memory(self, chat_size=20): async def operation_build_memory(self, chat_size=20):
time_frequency = {'near': 1, 'mid': 4, 'far': 4} time_frequency = {'near': 1, 'mid': 4, 'far': 4}
memory_samples = self.get_memory_sample(chat_size, time_frequency) memory_samples = self.get_memory_sample(chat_size, time_frequency)
for i, messages in enumerate(memory_samples, 1): for i, messages in enumerate(memory_samples, 1):
all_topics = [] all_topics = []
# 加载进度可视化 # 加载进度可视化
@@ -371,14 +367,14 @@ class Hippocampus:
compress_rate = global_config.memory_compress_rate compress_rate = global_config.memory_compress_rate
compressed_memory, similar_topics_dict = await self.memory_compress(messages, compress_rate) compressed_memory, similar_topics_dict = await self.memory_compress(messages, compress_rate)
logger.info(f"压缩后记忆数量: {len(compressed_memory)},似曾相识的话题: {len(similar_topics_dict)}") logger.info(f"压缩后记忆数量: {len(compressed_memory)},似曾相识的话题: {len(similar_topics_dict)}")
current_time = datetime.datetime.now().timestamp() current_time = datetime.datetime.now().timestamp()
for topic, memory in compressed_memory: for topic, memory in compressed_memory:
logger.info(f"添加节点: {topic}") logger.info(f"添加节点: {topic}")
self.memory_graph.add_dot(topic, memory) self.memory_graph.add_dot(topic, memory)
all_topics.append(topic) all_topics.append(topic)
# 连接相似的已存在主题 # 连接相似的已存在主题
if topic in similar_topics_dict: if topic in similar_topics_dict:
similar_topics = similar_topics_dict[topic] similar_topics = similar_topics_dict[topic]
@@ -386,11 +382,11 @@ class Hippocampus:
if topic != similar_topic: if topic != similar_topic:
strength = int(similarity * 10) strength = int(similarity * 10)
logger.info(f"连接相似节点: {topic}{similar_topic} (强度: {strength})") logger.info(f"连接相似节点: {topic}{similar_topic} (强度: {strength})")
self.memory_graph.G.add_edge(topic, similar_topic, self.memory_graph.G.add_edge(topic, similar_topic,
strength=strength, strength=strength,
created_time=current_time, created_time=current_time,
last_modified=current_time) last_modified=current_time)
# 连接同批次的相关话题 # 连接同批次的相关话题
for i in range(len(all_topics)): for i in range(len(all_topics)):
for j in range(i + 1, len(all_topics)): for j in range(i + 1, len(all_topics)):
@@ -416,7 +412,7 @@ class Hippocampus:
# 计算内存中节点的特征值 # 计算内存中节点的特征值
memory_hash = self.calculate_node_hash(concept, memory_items) memory_hash = self.calculate_node_hash(concept, memory_items)
# 获取时间信息 # 获取时间信息
created_time = data.get('created_time', datetime.datetime.now().timestamp()) created_time = data.get('created_time', datetime.datetime.now().timestamp())
last_modified = data.get('last_modified', datetime.datetime.now().timestamp()) last_modified = data.get('last_modified', datetime.datetime.now().timestamp())
@@ -466,7 +462,7 @@ class Hippocampus:
edge_hash = self.calculate_edge_hash(source, target) edge_hash = self.calculate_edge_hash(source, target)
edge_key = (source, target) edge_key = (source, target)
strength = data.get('strength', 1) strength = data.get('strength', 1)
# 获取边的时间信息 # 获取边的时间信息
created_time = data.get('created_time', datetime.datetime.now().timestamp()) created_time = data.get('created_time', datetime.datetime.now().timestamp())
last_modified = data.get('last_modified', datetime.datetime.now().timestamp()) last_modified = data.get('last_modified', datetime.datetime.now().timestamp())
@@ -499,7 +495,7 @@ class Hippocampus:
"""从数据库同步数据到内存中的图结构""" """从数据库同步数据到内存中的图结构"""
current_time = datetime.datetime.now().timestamp() current_time = datetime.datetime.now().timestamp()
need_update = False need_update = False
# 清空当前图 # 清空当前图
self.memory_graph.G.clear() self.memory_graph.G.clear()
@@ -510,7 +506,7 @@ class Hippocampus:
memory_items = node.get('memory_items', []) memory_items = node.get('memory_items', [])
if not isinstance(memory_items, list): if not isinstance(memory_items, list):
memory_items = [memory_items] if memory_items else [] memory_items = [memory_items] if memory_items else []
# 检查时间字段是否存在 # 检查时间字段是否存在
if 'created_time' not in node or 'last_modified' not in node: if 'created_time' not in node or 'last_modified' not in node:
need_update = True need_update = True
@@ -520,22 +516,22 @@ class Hippocampus:
update_data['created_time'] = current_time update_data['created_time'] = current_time
if 'last_modified' not in node: if 'last_modified' not in node:
update_data['last_modified'] = current_time update_data['last_modified'] = current_time
db.graph_data.nodes.update_one( db.graph_data.nodes.update_one(
{'concept': concept}, {'concept': concept},
{'$set': update_data} {'$set': update_data}
) )
logger.info(f"[时间更新] 节点 {concept} 添加缺失的时间字段") logger.info(f"[时间更新] 节点 {concept} 添加缺失的时间字段")
# 获取时间信息(如果不存在则使用当前时间) # 获取时间信息(如果不存在则使用当前时间)
created_time = node.get('created_time', current_time) created_time = node.get('created_time', current_time)
last_modified = node.get('last_modified', current_time) last_modified = node.get('last_modified', current_time)
# 添加节点到图中 # 添加节点到图中
self.memory_graph.G.add_node(concept, self.memory_graph.G.add_node(concept,
memory_items=memory_items, memory_items=memory_items,
created_time=created_time, created_time=created_time,
last_modified=last_modified) last_modified=last_modified)
# 从数据库加载所有边 # 从数据库加载所有边
edges = list(db.graph_data.edges.find()) edges = list(db.graph_data.edges.find())
@@ -543,7 +539,7 @@ class Hippocampus:
source = edge['source'] source = edge['source']
target = edge['target'] target = edge['target']
strength = edge.get('strength', 1) strength = edge.get('strength', 1)
# 检查时间字段是否存在 # 检查时间字段是否存在
if 'created_time' not in edge or 'last_modified' not in edge: if 'created_time' not in edge or 'last_modified' not in edge:
need_update = True need_update = True
@@ -553,24 +549,24 @@ class Hippocampus:
update_data['created_time'] = current_time update_data['created_time'] = current_time
if 'last_modified' not in edge: if 'last_modified' not in edge:
update_data['last_modified'] = current_time update_data['last_modified'] = current_time
db.graph_data.edges.update_one( db.graph_data.edges.update_one(
{'source': source, 'target': target}, {'source': source, 'target': target},
{'$set': update_data} {'$set': update_data}
) )
logger.info(f"[时间更新] 边 {source} - {target} 添加缺失的时间字段") logger.info(f"[时间更新] 边 {source} - {target} 添加缺失的时间字段")
# 获取时间信息(如果不存在则使用当前时间) # 获取时间信息(如果不存在则使用当前时间)
created_time = edge.get('created_time', current_time) created_time = edge.get('created_time', current_time)
last_modified = edge.get('last_modified', current_time) last_modified = edge.get('last_modified', current_time)
# 只有当源节点和目标节点都存在时才添加边 # 只有当源节点和目标节点都存在时才添加边
if source in self.memory_graph.G and target in self.memory_graph.G: if source in self.memory_graph.G and target in self.memory_graph.G:
self.memory_graph.G.add_edge(source, target, self.memory_graph.G.add_edge(source, target,
strength=strength, strength=strength,
created_time=created_time, created_time=created_time,
last_modified=last_modified) last_modified=last_modified)
if need_update: if need_update:
logger.success("[数据库] 已为缺失的时间字段进行补充") logger.success("[数据库] 已为缺失的时间字段进行补充")
@@ -578,44 +574,44 @@ class Hippocampus:
"""随机选择图中一定比例的节点和边进行检查,根据时间条件决定是否遗忘""" """随机选择图中一定比例的节点和边进行检查,根据时间条件决定是否遗忘"""
# 检查数据库是否为空 # 检查数据库是否为空
# logger.remove() # logger.remove()
logger.info(f"[遗忘] 开始检查数据库... 当前Logger信息:") logger.info(f"[遗忘] 开始检查数据库... 当前Logger信息:")
# logger.info(f"- Logger名称: {logger.name}") # logger.info(f"- Logger名称: {logger.name}")
logger.info(f"- Logger等级: {logger.level}") logger.info(f"- Logger等级: {logger.level}")
# logger.info(f"- Logger处理器: {[handler.__class__.__name__ for handler in logger.handlers]}") # logger.info(f"- Logger处理器: {[handler.__class__.__name__ for handler in logger.handlers]}")
# logger2 = setup_logger(LogModule.MEMORY) # logger2 = setup_logger(LogModule.MEMORY)
# logger2.info(f"[遗忘] 开始检查数据库... 当前Logger信息:") # logger2.info(f"[遗忘] 开始检查数据库... 当前Logger信息:")
# logger.info(f"[遗忘] 开始检查数据库... 当前Logger信息:") # logger.info(f"[遗忘] 开始检查数据库... 当前Logger信息:")
all_nodes = list(self.memory_graph.G.nodes()) all_nodes = list(self.memory_graph.G.nodes())
all_edges = list(self.memory_graph.G.edges()) all_edges = list(self.memory_graph.G.edges())
if not all_nodes and not all_edges: if not all_nodes and not all_edges:
logger.info("[遗忘] 记忆图为空,无需进行遗忘操作") logger.info("[遗忘] 记忆图为空,无需进行遗忘操作")
return return
check_nodes_count = max(1, int(len(all_nodes) * percentage)) check_nodes_count = max(1, int(len(all_nodes) * percentage))
check_edges_count = max(1, int(len(all_edges) * percentage)) check_edges_count = max(1, int(len(all_edges) * percentage))
nodes_to_check = random.sample(all_nodes, check_nodes_count) nodes_to_check = random.sample(all_nodes, check_nodes_count)
edges_to_check = random.sample(all_edges, check_edges_count) edges_to_check = random.sample(all_edges, check_edges_count)
edge_changes = {'weakened': 0, 'removed': 0} edge_changes = {'weakened': 0, 'removed': 0}
node_changes = {'reduced': 0, 'removed': 0} node_changes = {'reduced': 0, 'removed': 0}
current_time = datetime.datetime.now().timestamp() current_time = datetime.datetime.now().timestamp()
# 检查并遗忘连接 # 检查并遗忘连接
logger.info("[遗忘] 开始检查连接...") logger.info("[遗忘] 开始检查连接...")
for source, target in edges_to_check: for source, target in edges_to_check:
edge_data = self.memory_graph.G[source][target] edge_data = self.memory_graph.G[source][target]
last_modified = edge_data.get('last_modified') last_modified = edge_data.get('last_modified')
if current_time - last_modified > 3600*global_config.memory_forget_time: if current_time - last_modified > 3600 * global_config.memory_forget_time:
current_strength = edge_data.get('strength', 1) current_strength = edge_data.get('strength', 1)
new_strength = current_strength - 1 new_strength = current_strength - 1
if new_strength <= 0: if new_strength <= 0:
self.memory_graph.G.remove_edge(source, target) self.memory_graph.G.remove_edge(source, target)
edge_changes['removed'] += 1 edge_changes['removed'] += 1
@@ -625,23 +621,23 @@ class Hippocampus:
edge_data['last_modified'] = current_time edge_data['last_modified'] = current_time
edge_changes['weakened'] += 1 edge_changes['weakened'] += 1
logger.info(f"[遗忘] 连接减弱: {source} -> {target} (强度: {current_strength} -> {new_strength})") logger.info(f"[遗忘] 连接减弱: {source} -> {target} (强度: {current_strength} -> {new_strength})")
# 检查并遗忘话题 # 检查并遗忘话题
logger.info("[遗忘] 开始检查节点...") logger.info("[遗忘] 开始检查节点...")
for node in nodes_to_check: for node in nodes_to_check:
node_data = self.memory_graph.G.nodes[node] node_data = self.memory_graph.G.nodes[node]
last_modified = node_data.get('last_modified', current_time) last_modified = node_data.get('last_modified', current_time)
if current_time - last_modified > 3600*24: if current_time - last_modified > 3600 * 24:
memory_items = node_data.get('memory_items', []) memory_items = node_data.get('memory_items', [])
if not isinstance(memory_items, list): if not isinstance(memory_items, list):
memory_items = [memory_items] if memory_items else [] memory_items = [memory_items] if memory_items else []
if memory_items: if memory_items:
current_count = len(memory_items) current_count = len(memory_items)
removed_item = random.choice(memory_items) removed_item = random.choice(memory_items)
memory_items.remove(removed_item) memory_items.remove(removed_item)
if memory_items: if memory_items:
self.memory_graph.G.nodes[node]['memory_items'] = memory_items self.memory_graph.G.nodes[node]['memory_items'] = memory_items
self.memory_graph.G.nodes[node]['last_modified'] = current_time self.memory_graph.G.nodes[node]['last_modified'] = current_time
@@ -651,7 +647,7 @@ class Hippocampus:
self.memory_graph.G.remove_node(node) self.memory_graph.G.remove_node(node)
node_changes['removed'] += 1 node_changes['removed'] += 1
logger.info(f"[遗忘] 节点移除: {node}") logger.info(f"[遗忘] 节点移除: {node}")
if any(count > 0 for count in edge_changes.values()) or any(count > 0 for count in node_changes.values()): if any(count > 0 for count in edge_changes.values()) or any(count > 0 for count in node_changes.values()):
self.sync_memory_to_db() self.sync_memory_to_db()
logger.info("[遗忘] 统计信息:") logger.info("[遗忘] 统计信息:")
@@ -943,6 +939,7 @@ def segment_text(text):
seg_text = list(jieba.cut(text)) seg_text = list(jieba.cut(text))
return seg_text return seg_text
driver = get_driver() driver = get_driver()
config = driver.config config = driver.config

View File

@@ -11,7 +11,7 @@ from pathlib import Path
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import networkx as nx import networkx as nx
from dotenv import load_dotenv from dotenv import load_dotenv
from loguru import logger from src.common.logger import get_module_logger
import jieba import jieba
# from chat.config import global_config # from chat.config import global_config
@@ -29,6 +29,8 @@ project_root = current_dir.parent.parent.parent
# env.dev文件路径 # env.dev文件路径
env_path = project_root / ".env.dev" env_path = project_root / ".env.dev"
logger = get_module_logger("mem_manual_bd")
# 加载环境变量 # 加载环境变量
if env_path.exists(): if env_path.exists():
logger.info(f"{env_path} 加载环境变量") logger.info(f"{env_path} 加载环境变量")

View File

@@ -12,9 +12,11 @@ import matplotlib.pyplot as plt
import networkx as nx import networkx as nx
import pymongo import pymongo
from dotenv import load_dotenv from dotenv import load_dotenv
from loguru import logger from src.common.logger import get_module_logger
import jieba import jieba
logger = get_module_logger("mem_test")
''' '''
该理论认为,当两个或多个事物在形态上具有相似性时, 该理论认为,当两个或多个事物在形态上具有相似性时,
它们在记忆中会形成关联。 它们在记忆中会形成关联。

View File

@@ -5,8 +5,9 @@ from typing import Tuple, Union
import aiohttp import aiohttp
import requests import requests
from loguru import logger from src.common.logger import get_module_logger
logger = get_module_logger("offline_llm")
class LLMModel: class LLMModel:
def __init__(self, model_name="deepseek-ai/DeepSeek-V3", **kwargs): def __init__(self, model_name="deepseek-ai/DeepSeek-V3", **kwargs):

View File

@@ -5,7 +5,7 @@ from datetime import datetime
from typing import Tuple, Union from typing import Tuple, Union
import aiohttp import aiohttp
from loguru import logger from src.common.logger import get_module_logger
from nonebot import get_driver from nonebot import get_driver
import base64 import base64
from PIL import Image from PIL import Image
@@ -16,6 +16,8 @@ from ..chat.config import global_config
driver = get_driver() driver = get_driver()
config = driver.config config = driver.config
logger = get_module_logger("model_utils")
class LLM_request: class LLM_request:
# 定义需要转换的模型列表,作为类变量避免重复 # 定义需要转换的模型列表,作为类变量避免重复

View File

@@ -4,7 +4,9 @@ import time
from dataclasses import dataclass from dataclasses import dataclass
from ..chat.config import global_config from ..chat.config import global_config
from loguru import logger from src.common.logger import get_module_logger
logger = get_module_logger("mood_manager")
@dataclass @dataclass
class MoodState: class MoodState:

View File

@@ -5,7 +5,9 @@ import platform
import os import os
import json import json
import threading import threading
from loguru import logger from src.common.logger import get_module_logger
logger = get_module_logger("remote")
# UUID文件路径 # UUID文件路径
UUID_FILE = os.path.join(os.path.dirname(os.path.abspath(__file__)), "client_uuid.json") UUID_FILE = os.path.join(os.path.dirname(os.path.abspath(__file__)), "client_uuid.json")
@@ -30,9 +32,9 @@ def get_unique_id():
try: try:
with open(UUID_FILE, "w") as f: with open(UUID_FILE, "w") as f:
json.dump({"client_id": client_id}, f) json.dump({"client_id": client_id}, f)
print("已保存新生成的客户端ID到本地文件") logger.info("已保存新生成的客户端ID到本地文件")
except IOError as e: except IOError as e:
print(f"保存UUID时出错: {e}") logger.error(f"保存UUID时出错: {e}")
return client_id return client_id

View File

@@ -3,13 +3,15 @@ import json
import re import re
from typing import Dict, Union from typing import Dict, Union
from loguru import logger
from nonebot import get_driver from nonebot import get_driver
from src.plugins.chat.config import global_config from src.plugins.chat.config import global_config
from ...common.database import db # 使用正确的导入语法 from ...common.database import db # 使用正确的导入语法
from ..models.utils_model import LLM_request from ..models.utils_model import LLM_request
from src.common.logger import get_module_logger
logger = get_module_logger("scheduler")
driver = get_driver() driver = get_driver()
config = driver.config config = driver.config

View File

@@ -3,10 +3,11 @@ import time
from collections import defaultdict from collections import defaultdict
from datetime import datetime, timedelta from datetime import datetime, timedelta
from typing import Any, Dict from typing import Any, Dict
from loguru import logger from src.common.logger import get_module_logger
from ...common.database import db from ...common.database import db
logger = get_module_logger("llm_statistics")
class LLMStatistics: class LLMStatistics:
def __init__(self, output_file: str = "llm_statistics.txt"): def __init__(self, output_file: str = "llm_statistics.txt"):

View File

@@ -13,8 +13,9 @@ from pathlib import Path
import jieba import jieba
from pypinyin import Style, pinyin from pypinyin import Style, pinyin
from loguru import logger from src.common.logger import get_module_logger
logger = get_module_logger("typo_gen")
class ChineseTypoGenerator: class ChineseTypoGenerator:
def __init__(self, def __init__(self,

View File

@@ -2,7 +2,9 @@ import asyncio
import random import random
import time import time
from typing import Dict from typing import Dict
from loguru import logger from src.common.logger import get_module_logger
logger = get_module_logger("mode_dynamic")
from ..chat.config import global_config from ..chat.config import global_config

View File

@@ -1,11 +1,13 @@
from typing import Optional from typing import Optional
from loguru import logger from src.common.logger import get_module_logger
from ..chat.config import global_config from ..chat.config import global_config
from .mode_classical import WillingManager as ClassicalWillingManager from .mode_classical import WillingManager as ClassicalWillingManager
from .mode_dynamic import WillingManager as DynamicWillingManager from .mode_dynamic import WillingManager as DynamicWillingManager
from .mode_custom import WillingManager as CustomWillingManager from .mode_custom import WillingManager as CustomWillingManager
logger = get_module_logger("willing")
def init_willing_manager() -> Optional[object]: def init_willing_manager() -> Optional[object]:
""" """
根据配置初始化并返回对应的WillingManager实例 根据配置初始化并返回对应的WillingManager实例

View File

@@ -23,7 +23,13 @@ CHAT_ANY_WHERE_BASE_URL=https://api.chatanywhere.tech/v1
SILICONFLOW_BASE_URL=https://api.siliconflow.cn/v1/ SILICONFLOW_BASE_URL=https://api.siliconflow.cn/v1/
DEEP_SEEK_BASE_URL=https://api.deepseek.com/v1 DEEP_SEEK_BASE_URL=https://api.deepseek.com/v1
#定义你要用的api的key(需要去对应网站申请哦) # 定义你要用的api的key(需要去对应网站申请哦)
DEEP_SEEK_KEY= DEEP_SEEK_KEY=
CHAT_ANY_WHERE_KEY= CHAT_ANY_WHERE_KEY=
SILICONFLOW_KEY= SILICONFLOW_KEY=
# 定义日志相关配置
CONSOLE_LOG_LEVEL=INFO # 自定义日志的默认控制台输出日志级别
FILE_LOG_LEVEL=DEBUG # 自定义日志的默认文件输出日志级别
DEFAULT_CONSOLE_LOG_LEVEL=SUCCESS # 原生日志的控制台输出日志级别nonebot就是这一类
DEFAULT_FILE_LOG_LEVEL=DEBUG # 原生日志的默认文件输出日志级别nonebot就是这一类

View File

@@ -2,11 +2,12 @@ import gradio as gr
import os import os
import sys import sys
import toml import toml
from loguru import logger from src.common.logger import get_module_logger
import shutil import shutil
import ast import ast
import json import json
logger = get_module_logger("webui")
is_share = False is_share = False
debug = True debug = True