初始化

This commit is contained in:
雅诺狐
2025-08-11 19:34:18 +08:00
parent ff7d1177fa
commit 2d4745cd58
257 changed files with 69069 additions and 0 deletions

1
src/common/__init__.py Normal file
View File

@@ -0,0 +1 @@
# 这个文件可以为空,但必须存在

View File

View File

@@ -0,0 +1,192 @@
import os
from pymongo import MongoClient
from pymongo.database import Database
from rich.traceback import install
from src.common.logger import get_logger
# SQLAlchemy相关导入
from src.common.database.sqlalchemy_init import initialize_database_compat
from src.common.database.sqlalchemy_models import get_engine, get_session
install(extra_lines=3)
_client = None
_db = None
_sql_engine = None
logger = get_logger("database")
# 兼容性为了不破坏现有代码保留db变量但指向SQLAlchemy
class DatabaseProxy:
"""数据库代理类提供Peewee到SQLAlchemy的兼容性接口"""
def __init__(self):
self._engine = None
self._session = None
def initialize(self, *args, **kwargs):
"""初始化数据库连接"""
return initialize_database_compat()
def connect(self, reuse_if_open=True):
"""连接数据库(兼容性方法)"""
try:
self._engine = get_engine()
return True
except Exception as e:
logger.error(f"数据库连接失败: {e}")
return False
def is_closed(self):
"""检查数据库是否关闭(兼容性方法)"""
return self._engine is None
def create_tables(self, models, safe=True):
"""创建表(兼容性方法)"""
try:
from src.common.database.sqlalchemy_models import Base
engine = get_engine()
Base.metadata.create_all(bind=engine)
return True
except Exception as e:
logger.error(f"创建表失败: {e}")
return False
def table_exists(self, model):
"""检查表是否存在(兼容性方法)"""
try:
from sqlalchemy import inspect
engine = get_engine()
inspector = inspect(engine)
table_name = getattr(model, '_meta', {}).get('table_name', model.__name__.lower())
return table_name in inspector.get_table_names()
except Exception:
return False
def execute_sql(self, sql):
"""执行SQL兼容性方法"""
try:
from sqlalchemy import text
session = get_session()
result = session.execute(text(sql))
session.close()
return result
except Exception as e:
logger.error(f"执行SQL失败: {e}")
raise
def atomic(self):
"""事务上下文管理器(兼容性方法)"""
return SQLAlchemyTransaction()
class SQLAlchemyTransaction:
"""SQLAlchemy事务上下文管理器"""
def __init__(self):
self.session = None
def __enter__(self):
self.session = get_session()
return self.session
def __exit__(self, exc_type, exc_val, exc_tb):
if exc_type is None:
self.session.commit()
else:
self.session.rollback()
self.session.close()
# 创建全局数据库代理实例
db = DatabaseProxy()
def __create_database_instance():
uri = os.getenv("MONGODB_URI")
host = os.getenv("MONGODB_HOST", "127.0.0.1")
port = int(os.getenv("MONGODB_PORT", "27017"))
# db_name 变量在创建连接时不需要,在获取数据库实例时才使用
username = os.getenv("MONGODB_USERNAME")
password = os.getenv("MONGODB_PASSWORD")
auth_source = os.getenv("MONGODB_AUTH_SOURCE")
if uri:
# 支持标准mongodb://和mongodb+srv://连接字符串
if uri.startswith(("mongodb://", "mongodb+srv://")):
return MongoClient(uri)
else:
raise ValueError(
"Invalid MongoDB URI format. URI must start with 'mongodb://' or 'mongodb+srv://'. "
"For MongoDB Atlas, use 'mongodb+srv://' format. "
"See: https://www.mongodb.com/docs/manual/reference/connection-string/"
)
if username and password:
# 如果有用户名和密码,使用认证连接
return MongoClient(host, port, username=username, password=password, authSource=auth_source)
# 否则使用无认证连接
return MongoClient(host, port)
def get_db():
"""获取MongoDB连接实例延迟初始化。"""
global _client, _db
if _client is None:
_client = __create_database_instance()
_db = _client[os.getenv("DATABASE_NAME", "MegBot")]
return _db
def initialize_sql_database(database_config):
"""
根据配置初始化SQL数据库连接SQLAlchemy版本
Args:
database_config: DatabaseConfig对象
"""
global _sql_engine
try:
logger.info("使用SQLAlchemy初始化SQL数据库...")
# 记录数据库配置信息
if database_config.database_type == "mysql":
connection_info = f"{database_config.mysql_user}@{database_config.mysql_host}:{database_config.mysql_port}/{database_config.mysql_database}"
logger.info("MySQL数据库连接配置:")
logger.info(f" 连接信息: {connection_info}")
logger.info(f" 字符集: {database_config.mysql_charset}")
else:
ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", ".."))
if not os.path.isabs(database_config.sqlite_path):
db_path = os.path.join(ROOT_PATH, database_config.sqlite_path)
else:
db_path = database_config.sqlite_path
logger.info("SQLite数据库连接配置:")
logger.info(f" 数据库文件: {db_path}")
# 使用SQLAlchemy初始化
success = initialize_database_compat()
if success:
_sql_engine = get_engine()
logger.info("SQLAlchemy数据库初始化成功")
else:
logger.error("SQLAlchemy数据库初始化失败")
return _sql_engine
except Exception as e:
logger.error(f"初始化SQL数据库失败: {e}")
return None
class DBWrapper:
"""数据库代理类,保持接口兼容性同时实现懒加载。"""
def __getattr__(self, name):
return getattr(get_db(), name)
def __getitem__(self, key):
return get_db()[key] # type: ignore
# 全局MongoDB数据库访问点
memory_db: Database = DBWrapper() # type: ignore

View File

@@ -0,0 +1,420 @@
"""SQLAlchemy数据库API模块
提供基于SQLAlchemy的数据库操作替换Peewee以解决MySQL连接问题
支持自动重连、连接池管理和更好的错误处理
"""
import traceback
import time
from typing import Dict, List, Any, Union, Type, Optional
from contextlib import contextmanager
from sqlalchemy.orm import Session
from sqlalchemy.exc import SQLAlchemyError, DisconnectionError, OperationalError
from sqlalchemy import desc, asc, func, and_, or_
from src.common.logger import get_logger
from src.common.database.sqlalchemy_models import (
Base, get_db_session, Messages, ActionRecords, PersonInfo, ChatStreams,
LLMUsage, Emoji, Images, ImageDescriptions, OnlineTime, Memory,
Expression, ThinkingLog, GraphNodes, GraphEdges,get_session
)
logger = get_logger("sqlalchemy_database_api")
# 模型映射表,用于通过名称获取模型类
MODEL_MAPPING = {
'Messages': Messages,
'ActionRecords': ActionRecords,
'PersonInfo': PersonInfo,
'ChatStreams': ChatStreams,
'LLMUsage': LLMUsage,
'Emoji': Emoji,
'Images': Images,
'ImageDescriptions': ImageDescriptions,
'OnlineTime': OnlineTime,
'Memory': Memory,
'Expression': Expression,
'ThinkingLog': ThinkingLog,
'GraphNodes': GraphNodes,
'GraphEdges': GraphEdges,
}
@contextmanager
def get_db_session():
"""数据库会话上下文管理器,自动处理事务和连接错误"""
session = None
max_retries = 3
retry_delay = 1.0
for attempt in range(max_retries):
try:
session = get_session()
yield session
session.commit()
break
except (DisconnectionError, OperationalError) as e:
logger.warning(f"数据库连接错误 (尝试 {attempt + 1}/{max_retries}): {e}")
if session:
session.rollback()
session.close()
if attempt < max_retries - 1:
time.sleep(retry_delay * (attempt + 1))
else:
raise
except Exception as e:
if session:
session.rollback()
raise
finally:
if session:
session.close()
def build_filters(session: Session, model_class: Type[Base], filters: Dict[str, Any]):
"""构建查询过滤条件"""
conditions = []
for field_name, value in filters.items():
if not hasattr(model_class, field_name):
logger.warning(f"模型 {model_class.__name__} 中不存在字段 '{field_name}'")
continue
field = getattr(model_class, field_name)
if isinstance(value, dict):
# 处理 MongoDB 风格的操作符
for op, op_value in value.items():
if op == "$gt":
conditions.append(field > op_value)
elif op == "$lt":
conditions.append(field < op_value)
elif op == "$gte":
conditions.append(field >= op_value)
elif op == "$lte":
conditions.append(field <= op_value)
elif op == "$ne":
conditions.append(field != op_value)
elif op == "$in":
conditions.append(field.in_(op_value))
elif op == "$nin":
conditions.append(~field.in_(op_value))
else:
logger.warning(f"未知操作符 '{op}' (字段: '{field_name}')")
else:
# 直接相等比较
conditions.append(field == value)
return conditions
async def db_query(
model_class: Type[Base],
data: Optional[Dict[str, Any]] = None,
query_type: Optional[str] = "get",
filters: Optional[Dict[str, Any]] = None,
limit: Optional[int] = None,
order_by: Optional[List[str]] = None,
single_result: Optional[bool] = False,
) -> Union[List[Dict[str, Any]], Dict[str, Any], None]:
"""执行数据库查询操作
Args:
model_class: SQLAlchemy模型类
data: 用于创建或更新的数据字典
query_type: 查询类型 ("get", "create", "update", "delete", "count")
filters: 过滤条件字典
limit: 限制结果数量
order_by: 排序字段,前缀'-'表示降序
single_result: 是否只返回单个结果
Returns:
根据查询类型返回相应结果
"""
try:
if query_type not in ["get", "create", "update", "delete", "count"]:
raise ValueError("query_type must be 'get', 'create', 'update', 'delete' or 'count'")
with get_db_session() as session:
if query_type == "get":
query = session.query(model_class)
# 应用过滤条件
if filters:
conditions = build_filters(session, model_class, filters)
if conditions:
query = query.filter(and_(*conditions))
# 应用排序
if order_by:
for field_name in order_by:
if field_name.startswith("-"):
field_name = field_name[1:]
if hasattr(model_class, field_name):
query = query.order_by(desc(getattr(model_class, field_name)))
else:
if hasattr(model_class, field_name):
query = query.order_by(asc(getattr(model_class, field_name)))
# 应用限制
if limit and limit > 0:
query = query.limit(limit)
# 执行查询
results = query.all()
# 转换为字典格式
result_dicts = []
for result in results:
result_dict = {}
for column in result.__table__.columns:
result_dict[column.name] = getattr(result, column.name)
result_dicts.append(result_dict)
if single_result:
return result_dicts[0] if result_dicts else None
return result_dicts
elif query_type == "create":
if not data:
raise ValueError("创建记录需要提供data参数")
# 创建新记录
new_record = model_class(**data)
session.add(new_record)
session.flush() # 获取自动生成的ID
# 转换为字典格式返回
result_dict = {}
for column in new_record.__table__.columns:
result_dict[column.name] = getattr(new_record, column.name)
return result_dict
elif query_type == "update":
if not data:
raise ValueError("更新记录需要提供data参数")
query = session.query(model_class)
# 应用过滤条件
if filters:
conditions = build_filters(session, model_class, filters)
if conditions:
query = query.filter(and_(*conditions))
# 执行更新
affected_rows = query.update(data)
return affected_rows
elif query_type == "delete":
query = session.query(model_class)
# 应用过滤条件
if filters:
conditions = build_filters(session, model_class, filters)
if conditions:
query = query.filter(and_(*conditions))
# 执行删除
affected_rows = query.delete()
return affected_rows
elif query_type == "count":
query = session.query(func.count(model_class.id))
# 应用过滤条件
if filters:
base_query = session.query(model_class)
conditions = build_filters(session, model_class, filters)
if conditions:
base_query = base_query.filter(and_(*conditions))
query = session.query(func.count()).select_from(base_query.subquery())
return query.scalar()
except SQLAlchemyError as e:
logger.error(f"[SQLAlchemy] 数据库操作出错: {e}")
traceback.print_exc()
# 根据查询类型返回合适的默认值
if query_type == "get":
return None if single_result else []
elif query_type in ["create", "update", "delete", "count"]:
return None
return None
except Exception as e:
logger.error(f"[SQLAlchemy] 意外错误: {e}")
traceback.print_exc()
if query_type == "get":
return None if single_result else []
return None
async def db_save(
model_class: Type[Base],
data: Dict[str, Any],
key_field: Optional[str] = None,
key_value: Optional[Any] = None
) -> Optional[Dict[str, Any]]:
"""保存数据到数据库(创建或更新)
Args:
model_class: SQLAlchemy模型类
data: 要保存的数据字典
key_field: 用于查找现有记录的字段名
key_value: 用于查找现有记录的字段值
Returns:
保存后的记录数据或None
"""
try:
with get_db_session() as session:
# 如果提供了key_field和key_value尝试更新现有记录
if key_field and key_value is not None:
if hasattr(model_class, key_field):
existing_record = session.query(model_class).filter(
getattr(model_class, key_field) == key_value
).first()
if existing_record:
# 更新现有记录
for field, value in data.items():
if hasattr(existing_record, field):
setattr(existing_record, field, value)
session.flush()
# 转换为字典格式返回
result_dict = {}
for column in existing_record.__table__.columns:
result_dict[column.name] = getattr(existing_record, column.name)
return result_dict
# 创建新记录
new_record = model_class(**data)
session.add(new_record)
session.flush()
# 转换为字典格式返回
result_dict = {}
for column in new_record.__table__.columns:
result_dict[column.name] = getattr(new_record, column.name)
return result_dict
except SQLAlchemyError as e:
logger.error(f"[SQLAlchemy] 保存数据库记录出错: {e}")
traceback.print_exc()
return None
except Exception as e:
logger.error(f"[SQLAlchemy] 保存时意外错误: {e}")
traceback.print_exc()
return None
async def db_get(
model_class: Type[Base],
filters: Optional[Dict[str, Any]] = None,
limit: Optional[int] = None,
order_by: Optional[str] = None,
single_result: Optional[bool] = False,
) -> Union[List[Dict[str, Any]], Dict[str, Any], None]:
"""从数据库获取记录
Args:
model_class: SQLAlchemy模型类
filters: 过滤条件
limit: 结果数量限制
order_by: 排序字段,前缀'-'表示降序
single_result: 是否只返回单个结果
Returns:
记录数据或None
"""
order_by_list = [order_by] if order_by else None
return await db_query(
model_class=model_class,
query_type="get",
filters=filters,
limit=limit,
order_by=order_by_list,
single_result=single_result
)
async def store_action_info(
chat_stream=None,
action_build_into_prompt: bool = False,
action_prompt_display: str = "",
action_done: bool = True,
thinking_id: str = "",
action_data: Optional[dict] = None,
action_name: str = "",
) -> Optional[Dict[str, Any]]:
"""存储动作信息到数据库
Args:
chat_stream: 聊天流对象
action_build_into_prompt: 是否将此动作构建到提示中
action_prompt_display: 动作的提示显示文本
action_done: 动作是否完成
thinking_id: 关联的思考ID
action_data: 动作数据字典
action_name: 动作名称
Returns:
保存的记录数据或None
"""
try:
import json
# 构建动作记录数据
record_data = {
"action_id": thinking_id or str(int(time.time() * 1000000)),
"time": time.time(),
"action_name": action_name,
"action_data": json.dumps(action_data or {}, ensure_ascii=False),
"action_done": action_done,
"action_build_into_prompt": action_build_into_prompt,
"action_prompt_display": action_prompt_display,
}
# 从chat_stream获取聊天信息
if chat_stream:
record_data.update({
"chat_id": getattr(chat_stream, "stream_id", ""),
"chat_info_stream_id": getattr(chat_stream, "stream_id", ""),
"chat_info_platform": getattr(chat_stream, "platform", ""),
})
else:
record_data.update({
"chat_id": "",
"chat_info_stream_id": "",
"chat_info_platform": "",
})
# 保存记录
saved_record = await db_save(
ActionRecords,
data=record_data,
key_field="action_id",
key_value=record_data["action_id"]
)
if saved_record:
logger.debug(f"[SQLAlchemy] 成功存储动作信息: {action_name} (ID: {record_data['action_id']})")
else:
logger.error(f"[SQLAlchemy] 存储动作信息失败: {action_name}")
return saved_record
except Exception as e:
logger.error(f"[SQLAlchemy] 存储动作信息时发生错误: {e}")
traceback.print_exc()
return None
# 兼容性函数方便从Peewee迁移
def get_model_class(model_name: str) -> Optional[Type[Base]]:
"""根据模型名称获取模型类"""
return MODEL_MAPPING.get(model_name)

View File

@@ -0,0 +1,158 @@
"""SQLAlchemy数据库初始化模块
替换Peewee的数据库初始化逻辑
提供统一的数据库初始化接口
"""
from typing import Optional
from sqlalchemy.exc import SQLAlchemyError
from src.common.logger import get_logger
from src.common.database.sqlalchemy_models import (
Base, get_engine, get_session, initialize_database
)
logger = get_logger("sqlalchemy_init")
def initialize_sqlalchemy_database() -> bool:
"""
初始化SQLAlchemy数据库
创建所有表结构
Returns:
bool: 初始化是否成功
"""
try:
logger.info("开始初始化SQLAlchemy数据库...")
# 初始化数据库引擎和会话
engine, session_local = initialize_database()
if engine is None:
logger.error("数据库引擎初始化失败")
return False
logger.info("SQLAlchemy数据库初始化成功")
return True
except SQLAlchemyError as e:
logger.error(f"SQLAlchemy数据库初始化失败: {e}")
return False
except Exception as e:
logger.error(f"数据库初始化过程中发生未知错误: {e}")
return False
def create_all_tables() -> bool:
"""
创建所有数据库表
Returns:
bool: 创建是否成功
"""
try:
logger.info("开始创建数据库表...")
engine = get_engine()
if engine is None:
logger.error("无法获取数据库引擎")
return False
# 创建所有表
Base.metadata.create_all(bind=engine)
logger.info("数据库表创建成功")
return True
except SQLAlchemyError as e:
logger.error(f"创建数据库表失败: {e}")
return False
except Exception as e:
logger.error(f"创建数据库表过程中发生未知错误: {e}")
return False
def check_database_connection() -> bool:
"""
检查数据库连接是否正常
Returns:
bool: 连接是否正常
"""
try:
session = get_session()
if session is None:
logger.error("无法获取数据库会话")
return False
# 检查会话是否可用(如果能获取到会话说明连接正常)
if session is None:
logger.error("数据库会话无效")
return False
session.close()
logger.info("数据库连接检查通过")
return True
except SQLAlchemyError as e:
logger.error(f"数据库连接检查失败: {e}")
return False
except Exception as e:
logger.error(f"数据库连接检查过程中发生未知错误: {e}")
return False
def get_database_info() -> Optional[dict]:
"""
获取数据库信息
Returns:
dict: 数据库信息字典,包含引擎信息等
"""
try:
engine = get_engine()
if engine is None:
return None
info = {
'engine_name': engine.name,
'driver': engine.driver,
'url': str(engine.url).replace(engine.url.password or '', '***'), # 隐藏密码
'pool_size': getattr(engine.pool, 'size', None),
'max_overflow': getattr(engine.pool, 'max_overflow', None),
}
return info
except Exception as e:
logger.error(f"获取数据库信息失败: {e}")
return None
_database_initialized = False
def initialize_database_compat() -> bool:
"""
兼容性数据库初始化函数
用于替换原有的Peewee初始化代码
Returns:
bool: 初始化是否成功
"""
global _database_initialized
if _database_initialized:
return True
success = initialize_sqlalchemy_database()
if success:
success = create_all_tables()
if success:
success = check_database_connection()
if success:
_database_initialized = True
return success

View File

@@ -0,0 +1,555 @@
"""SQLAlchemy数据库模型定义
替换Peewee ORM使用SQLAlchemy提供更好的连接池管理和错误恢复能力
"""
from sqlalchemy import Column, String, Float, Integer, Boolean, Text, Index, create_engine, DateTime
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker
from sqlalchemy.pool import QueuePool
import os
import datetime
from src.config.config import global_config
from src.common.logger import get_logger
import threading
from contextlib import contextmanager
logger = get_logger("sqlalchemy_models")
# 创建基类
Base = declarative_base()
# MySQL兼容的字段类型辅助函数
def get_string_field(max_length=255, **kwargs):
"""
根据数据库类型返回合适的字符串字段
MySQL需要指定长度的VARCHAR用于索引SQLite可以使用Text
"""
if global_config.database.database_type == "mysql":
return String(max_length, **kwargs)
else:
return Text(**kwargs)
class SessionProxy:
"""线程安全的Session代理类自动管理session生命周期"""
def __init__(self):
self._local = threading.local()
def _get_current_session(self):
"""获取当前线程的session如果没有则创建新的"""
if not hasattr(self._local, 'session') or self._local.session is None:
_, SessionLocal = initialize_database()
self._local.session = SessionLocal()
return self._local.session
def _close_current_session(self):
"""关闭当前线程的session"""
if hasattr(self._local, 'session') and self._local.session is not None:
try:
self._local.session.close()
except:
pass
finally:
self._local.session = None
def __getattr__(self, name):
"""代理所有session方法"""
session = self._get_current_session()
attr = getattr(session, name)
# 如果是方法,需要特殊处理一些关键方法
if callable(attr):
if name in ['commit', 'rollback']:
def wrapper(*args, **kwargs):
try:
result = attr(*args, **kwargs)
if name == 'commit':
# commit后不要清除session只是刷新状态
pass # 保持session活跃
return result
except Exception as e:
try:
if session and hasattr(session, 'rollback'):
session.rollback()
except:
pass
# 发生错误时重新创建session
self._close_current_session()
raise
return wrapper
elif name == 'close':
def wrapper(*args, **kwargs):
result = attr(*args, **kwargs)
self._close_current_session()
return result
return wrapper
elif name in ['execute', 'query', 'add', 'delete', 'merge']:
def wrapper(*args, **kwargs):
try:
return attr(*args, **kwargs)
except Exception as e:
# 如果是连接相关错误重新创建session再试一次
if "not bound to a Session" in str(e) or "provisioning a new connection" in str(e):
logger.warning(f"Session问题重新创建session: {e}")
self._close_current_session()
new_session = self._get_current_session()
new_attr = getattr(new_session, name)
return new_attr(*args, **kwargs)
raise
return wrapper
return attr
def new_session(self):
"""强制创建新的session关闭当前的创建新的"""
self._close_current_session()
return self._get_current_session()
def ensure_fresh_session(self):
"""确保使用新鲜的session如果当前session有问题则重新创建"""
if hasattr(self._local, 'session') and self._local.session is not None:
try:
# 测试session是否还可用
self._local.session.execute("SELECT 1")
except Exception:
# session有问题重新创建
self._close_current_session()
return self._get_current_session()
# 创建全局session代理实例
_global_session_proxy = SessionProxy()
def get_session():
"""返回线程安全的session代理自动管理生命周期"""
return _global_session_proxy
class ChatStreams(Base):
"""聊天流模型"""
__tablename__ = 'chat_streams'
id = Column(Integer, primary_key=True, autoincrement=True)
stream_id = Column(get_string_field(64), nullable=False, unique=True, index=True)
create_time = Column(Float, nullable=False)
group_platform = Column(Text, nullable=True)
group_id = Column(get_string_field(100), nullable=True, index=True)
group_name = Column(Text, nullable=True)
last_active_time = Column(Float, nullable=False)
platform = Column(Text, nullable=False)
user_platform = Column(Text, nullable=False)
user_id = Column(get_string_field(100), nullable=False, index=True)
user_nickname = Column(Text, nullable=False)
user_cardname = Column(Text, nullable=True)
__table_args__ = (
Index('idx_chatstreams_stream_id', 'stream_id'),
Index('idx_chatstreams_user_id', 'user_id'),
Index('idx_chatstreams_group_id', 'group_id'),
)
class LLMUsage(Base):
"""LLM使用记录模型"""
__tablename__ = 'llm_usage'
id = Column(Integer, primary_key=True, autoincrement=True)
model_name = Column(get_string_field(100), nullable=False, index=True)
user_id = Column(get_string_field(50), nullable=False, index=True)
request_type = Column(get_string_field(50), nullable=False, index=True)
endpoint = Column(Text, nullable=False)
prompt_tokens = Column(Integer, nullable=False)
completion_tokens = Column(Integer, nullable=False)
total_tokens = Column(Integer, nullable=False)
cost = Column(Float, nullable=False)
status = Column(Text, nullable=False)
timestamp = Column(DateTime, nullable=False, index=True, default=datetime.datetime.now)
__table_args__ = (
Index('idx_llmusage_model_name', 'model_name'),
Index('idx_llmusage_user_id', 'user_id'),
Index('idx_llmusage_request_type', 'request_type'),
Index('idx_llmusage_timestamp', 'timestamp'),
)
class Emoji(Base):
"""表情包模型"""
__tablename__ = 'emoji'
id = Column(Integer, primary_key=True, autoincrement=True)
full_path = Column(get_string_field(500), nullable=False, unique=True, index=True)
format = Column(Text, nullable=False)
emoji_hash = Column(get_string_field(64), nullable=False, index=True)
description = Column(Text, nullable=False)
query_count = Column(Integer, nullable=False, default=0)
is_registered = Column(Boolean, nullable=False, default=False)
is_banned = Column(Boolean, nullable=False, default=False)
emotion = Column(Text, nullable=True)
record_time = Column(Float, nullable=False)
register_time = Column(Float, nullable=True)
usage_count = Column(Integer, nullable=False, default=0)
last_used_time = Column(Float, nullable=True)
__table_args__ = (
Index('idx_emoji_full_path', 'full_path'),
Index('idx_emoji_hash', 'emoji_hash'),
)
class Messages(Base):
"""消息模型"""
__tablename__ = 'messages'
id = Column(Integer, primary_key=True, autoincrement=True)
message_id = Column(get_string_field(100), nullable=False, index=True)
time = Column(Float, nullable=False)
chat_id = Column(get_string_field(64), nullable=False, index=True)
reply_to = Column(Text, nullable=True)
interest_value = Column(Float, nullable=True)
is_mentioned = Column(Boolean, nullable=True)
# 从 chat_info 扁平化而来的字段
chat_info_stream_id = Column(Text, nullable=False)
chat_info_platform = Column(Text, nullable=False)
chat_info_user_platform = Column(Text, nullable=False)
chat_info_user_id = Column(Text, nullable=False)
chat_info_user_nickname = Column(Text, nullable=False)
chat_info_user_cardname = Column(Text, nullable=True)
chat_info_group_platform = Column(Text, nullable=True)
chat_info_group_id = Column(Text, nullable=True)
chat_info_group_name = Column(Text, nullable=True)
chat_info_create_time = Column(Float, nullable=False)
chat_info_last_active_time = Column(Float, nullable=False)
# 从顶层 user_info 扁平化而来的字段
user_platform = Column(Text, nullable=True)
user_id = Column(get_string_field(100), nullable=True, index=True)
user_nickname = Column(Text, nullable=True)
user_cardname = Column(Text, nullable=True)
processed_plain_text = Column(Text, nullable=True)
display_message = Column(Text, nullable=True)
memorized_times = Column(Integer, nullable=False, default=0)
priority_mode = Column(Text, nullable=True)
priority_info = Column(Text, nullable=True)
additional_config = Column(Text, nullable=True)
is_emoji = Column(Boolean, nullable=False, default=False)
is_picid = Column(Boolean, nullable=False, default=False)
is_command = Column(Boolean, nullable=False, default=False)
is_notify = Column(Boolean, nullable=False, default=False)
__table_args__ = (
Index('idx_messages_message_id', 'message_id'),
Index('idx_messages_chat_id', 'chat_id'),
Index('idx_messages_time', 'time'),
Index('idx_messages_user_id', 'user_id'),
)
class ActionRecords(Base):
"""动作记录模型"""
__tablename__ = 'action_records'
id = Column(Integer, primary_key=True, autoincrement=True)
action_id = Column(get_string_field(100), nullable=False, index=True)
time = Column(Float, nullable=False)
action_name = Column(Text, nullable=False)
action_data = Column(Text, nullable=False)
action_done = Column(Boolean, nullable=False, default=False)
action_build_into_prompt = Column(Boolean, nullable=False, default=False)
action_prompt_display = Column(Text, nullable=False)
chat_id = Column(get_string_field(64), nullable=False, index=True)
chat_info_stream_id = Column(Text, nullable=False)
chat_info_platform = Column(Text, nullable=False)
__table_args__ = (
Index('idx_actionrecords_action_id', 'action_id'),
Index('idx_actionrecords_chat_id', 'chat_id'),
Index('idx_actionrecords_time', 'time'),
)
class Images(Base):
"""图像信息模型"""
__tablename__ = 'images'
id = Column(Integer, primary_key=True, autoincrement=True)
image_id = Column(Text, nullable=False, default="")
emoji_hash = Column(get_string_field(64), nullable=False, index=True)
description = Column(Text, nullable=True)
path = Column(get_string_field(500), nullable=False, unique=True)
count = Column(Integer, nullable=False, default=1)
timestamp = Column(Float, nullable=False)
type = Column(Text, nullable=False)
vlm_processed = Column(Boolean, nullable=False, default=False)
__table_args__ = (
Index('idx_images_emoji_hash', 'emoji_hash'),
Index('idx_images_path', 'path'),
)
class ImageDescriptions(Base):
"""图像描述信息模型"""
__tablename__ = 'image_descriptions'
id = Column(Integer, primary_key=True, autoincrement=True)
type = Column(Text, nullable=False)
image_description_hash = Column(get_string_field(64), nullable=False, index=True)
description = Column(Text, nullable=False)
timestamp = Column(Float, nullable=False)
__table_args__ = (
Index('idx_imagedesc_hash', 'image_description_hash'),
)
class OnlineTime(Base):
"""在线时长记录模型"""
__tablename__ = 'online_time'
id = Column(Integer, primary_key=True, autoincrement=True)
timestamp = Column(Text, nullable=False, default=str(datetime.datetime.now))
duration = Column(Integer, nullable=False)
start_timestamp = Column(DateTime, nullable=False, default=datetime.datetime.now)
end_timestamp = Column(DateTime, nullable=False, index=True)
__table_args__ = (
Index('idx_onlinetime_end_timestamp', 'end_timestamp'),
)
class PersonInfo(Base):
"""人物信息模型"""
__tablename__ = 'person_info'
id = Column(Integer, primary_key=True, autoincrement=True)
person_id = Column(get_string_field(100), nullable=False, unique=True, index=True)
person_name = Column(Text, nullable=True)
name_reason = Column(Text, nullable=True)
platform = Column(Text, nullable=False)
user_id = Column(get_string_field(50), nullable=False, index=True)
nickname = Column(Text, nullable=True)
impression = Column(Text, nullable=True)
short_impression = Column(Text, nullable=True)
points = Column(Text, nullable=True)
forgotten_points = Column(Text, nullable=True)
info_list = Column(Text, nullable=True)
know_times = Column(Float, nullable=True)
know_since = Column(Float, nullable=True)
last_know = Column(Float, nullable=True)
attitude = Column(Integer, nullable=True, default=50)
__table_args__ = (
Index('idx_personinfo_person_id', 'person_id'),
Index('idx_personinfo_user_id', 'user_id'),
)
class Memory(Base):
"""记忆模型"""
__tablename__ = 'memory'
id = Column(Integer, primary_key=True, autoincrement=True)
memory_id = Column(get_string_field(64), nullable=False, index=True)
chat_id = Column(Text, nullable=True)
memory_text = Column(Text, nullable=True)
keywords = Column(Text, nullable=True)
create_time = Column(Float, nullable=True)
last_view_time = Column(Float, nullable=True)
__table_args__ = (
Index('idx_memory_memory_id', 'memory_id'),
)
class Expression(Base):
"""表达风格模型"""
__tablename__ = 'expression'
id = Column(Integer, primary_key=True, autoincrement=True)
situation = Column(Text, nullable=False)
style = Column(Text, nullable=False)
count = Column(Float, nullable=False)
last_active_time = Column(Float, nullable=False)
chat_id = Column(get_string_field(64), nullable=False, index=True)
type = Column(Text, nullable=False)
create_date = Column(Float, nullable=True)
__table_args__ = (
Index('idx_expression_chat_id', 'chat_id'),
)
class ThinkingLog(Base):
"""思考日志模型"""
__tablename__ = 'thinking_logs'
id = Column(Integer, primary_key=True, autoincrement=True)
chat_id = Column(get_string_field(64), nullable=False, index=True)
trigger_text = Column(Text, nullable=True)
response_text = Column(Text, nullable=True)
trigger_info_json = Column(Text, nullable=True)
response_info_json = Column(Text, nullable=True)
timing_results_json = Column(Text, nullable=True)
chat_history_json = Column(Text, nullable=True)
chat_history_in_thinking_json = Column(Text, nullable=True)
chat_history_after_response_json = Column(Text, nullable=True)
heartflow_data_json = Column(Text, nullable=True)
reasoning_data_json = Column(Text, nullable=True)
created_at = Column(DateTime, nullable=False, default=datetime.datetime.now)
__table_args__ = (
Index('idx_thinkinglog_chat_id', 'chat_id'),
)
class GraphNodes(Base):
"""记忆图节点模型"""
__tablename__ = 'graph_nodes'
id = Column(Integer, primary_key=True, autoincrement=True)
concept = Column(get_string_field(255), nullable=False, unique=True, index=True)
memory_items = Column(Text, nullable=False)
hash = Column(Text, nullable=False)
created_time = Column(Float, nullable=False)
last_modified = Column(Float, nullable=False)
__table_args__ = (
Index('idx_graphnodes_concept', 'concept'),
)
class GraphEdges(Base):
"""记忆图边模型"""
__tablename__ = 'graph_edges'
id = Column(Integer, primary_key=True, autoincrement=True)
source = Column(get_string_field(255), nullable=False, index=True)
target = Column(get_string_field(255), nullable=False, index=True)
strength = Column(Integer, nullable=False)
hash = Column(Text, nullable=False)
created_time = Column(Float, nullable=False)
last_modified = Column(Float, nullable=False)
__table_args__ = (
Index('idx_graphedges_source', 'source'),
Index('idx_graphedges_target', 'target'),
)
# 数据库引擎和会话管理
_engine = None
_SessionLocal = None
def get_database_url():
"""获取数据库连接URL"""
config = global_config.database
if config.database_type == "mysql":
# 对用户名和密码进行URL编码处理特殊字符
from urllib.parse import quote_plus
encoded_user = quote_plus(config.mysql_user)
encoded_password = quote_plus(config.mysql_password)
return (
f"mysql+pymysql://{encoded_user}:{encoded_password}"
f"@{config.mysql_host}:{config.mysql_port}/{config.mysql_database}"
f"?charset={config.mysql_charset}"
)
else: # SQLite
# 如果是相对路径,则相对于项目根目录
if not os.path.isabs(config.sqlite_path):
ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", ".."))
db_path = os.path.join(ROOT_PATH, config.sqlite_path)
else:
db_path = config.sqlite_path
# 确保数据库目录存在
os.makedirs(os.path.dirname(db_path), exist_ok=True)
return f"sqlite:///{db_path}"
def initialize_database():
"""初始化数据库引擎和会话"""
global _engine, _SessionLocal
if _engine is not None:
return _engine, _SessionLocal
database_url = get_database_url()
config = global_config.database
# 配置引擎参数
engine_kwargs = {
'echo': False, # 生产环境关闭SQL日志
'future': True,
}
if config.database_type == "mysql":
# MySQL连接池配置
engine_kwargs.update({
'poolclass': QueuePool,
'pool_size': config.connection_pool_size,
'max_overflow': config.connection_pool_size * 2,
'pool_timeout': config.connection_timeout,
'pool_recycle': 3600, # 1小时回收连接
'pool_pre_ping': True, # 连接前ping检查
'connect_args': {
'autocommit': config.mysql_autocommit,
'charset': config.mysql_charset,
'connect_timeout': config.connection_timeout,
'read_timeout': 30,
'write_timeout': 30,
}
})
else:
# SQLite配置 - 添加连接池设置以避免连接耗尽
engine_kwargs.update({
'poolclass': QueuePool,
'pool_size': 20, # 增加池大小
'max_overflow': 30, # 增加溢出连接数
'pool_timeout': 60, # 增加超时时间
'pool_recycle': 3600, # 1小时回收连接
'pool_pre_ping': True, # 连接前ping检查
'connect_args': {
'check_same_thread': False,
'timeout': 30,
}
})
_engine = create_engine(database_url, **engine_kwargs)
_SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=_engine)
# 创建所有表
Base.metadata.create_all(bind=_engine)
logger.info(f"SQLAlchemy数据库初始化成功: {config.database_type}")
return _engine, _SessionLocal
@contextmanager
def get_db_session():
"""数据库会话上下文管理器 - 推荐使用这个而不是get_session()"""
session = None
try:
_, SessionLocal = initialize_database()
session = SessionLocal()
yield session
session.commit()
except Exception as e:
if session:
session.rollback()
raise
finally:
if session:
session.close()
def get_engine():
"""获取数据库引擎"""
engine, _ = initialize_database()
return engine

808
src/common/logger.py Normal file
View File

@@ -0,0 +1,808 @@
# 使用基于时间戳的文件处理器,简单的轮转份数限制
import logging
import json
import threading
import time
import structlog
import tomlkit
from pathlib import Path
from typing import Callable, Optional
from datetime import datetime, timedelta
# 创建logs目录
LOG_DIR = Path("logs")
LOG_DIR.mkdir(exist_ok=True)
# 全局handler实例避免重复创建
_file_handler = None
_console_handler = None
def get_file_handler():
"""获取文件handler单例"""
global _file_handler
if _file_handler is None:
# 确保日志目录存在
LOG_DIR.mkdir(exist_ok=True)
# 检查现有handler避免重复创建
root_logger = logging.getLogger()
for handler in root_logger.handlers:
if isinstance(handler, TimestampedFileHandler):
_file_handler = handler
return _file_handler
# 使用基于时间戳的handler简单的轮转份数限制
_file_handler = TimestampedFileHandler(
log_dir=LOG_DIR,
max_bytes=5 * 1024 * 1024, # 5MB
backup_count=30,
encoding="utf-8",
)
# 设置文件handler的日志级别
file_level = LOG_CONFIG.get("file_log_level", LOG_CONFIG.get("log_level", "INFO"))
_file_handler.setLevel(getattr(logging, file_level.upper(), logging.INFO))
return _file_handler
def get_console_handler():
"""获取控制台handler单例"""
global _console_handler
if _console_handler is None:
_console_handler = logging.StreamHandler()
# 设置控制台handler的日志级别
console_level = LOG_CONFIG.get("console_log_level", LOG_CONFIG.get("log_level", "INFO"))
_console_handler.setLevel(getattr(logging, console_level.upper(), logging.INFO))
return _console_handler
class TimestampedFileHandler(logging.Handler):
"""基于时间戳的文件处理器,简单的轮转份数限制"""
def __init__(self, log_dir, max_bytes=5 * 1024 * 1024, backup_count=30, encoding="utf-8"):
super().__init__()
self.log_dir = Path(log_dir)
self.log_dir.mkdir(exist_ok=True)
self.max_bytes = max_bytes
self.backup_count = backup_count
self.encoding = encoding
self._lock = threading.Lock()
# 当前活跃的日志文件
self.current_file = None
self.current_stream = None
self._init_current_file()
def _init_current_file(self):
"""初始化当前日志文件"""
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
self.current_file = self.log_dir / f"app_{timestamp}.log.jsonl"
self.current_stream = open(self.current_file, "a", encoding=self.encoding)
def _should_rollover(self):
"""检查是否需要轮转"""
if self.current_file and self.current_file.exists():
return self.current_file.stat().st_size >= self.max_bytes
return False
def _do_rollover(self):
"""执行轮转:关闭当前文件,创建新文件"""
if self.current_stream:
self.current_stream.close()
# 清理旧文件
self._cleanup_old_files()
# 创建新文件
self._init_current_file()
def _cleanup_old_files(self):
"""清理旧的日志文件,保留指定数量"""
try:
# 获取所有日志文件
log_files = list(self.log_dir.glob("app_*.log.jsonl"))
# 按修改时间排序
log_files.sort(key=lambda f: f.stat().st_mtime, reverse=True)
# 删除超出数量限制的文件
for old_file in log_files[self.backup_count :]:
try:
old_file.unlink()
print(f"[日志清理] 删除旧文件: {old_file.name}")
except Exception as e:
print(f"[日志清理] 删除失败 {old_file}: {e}")
except Exception as e:
print(f"[日志清理] 清理过程出错: {e}")
def emit(self, record):
"""发出日志记录"""
try:
with self._lock:
# 检查是否需要轮转
if self._should_rollover():
self._do_rollover()
# 写入日志
if self.current_stream:
msg = self.format(record)
self.current_stream.write(msg + "\n")
self.current_stream.flush()
except Exception:
self.handleError(record)
def close(self):
"""关闭处理器"""
with self._lock:
if self.current_stream:
self.current_stream.close()
self.current_stream = None
super().close()
# 旧的轮转文件处理器已移除,现在使用基于时间戳的处理器
def close_handlers():
"""安全关闭所有handler"""
global _file_handler, _console_handler
if _file_handler:
_file_handler.close()
_file_handler = None
if _console_handler:
_console_handler.close()
_console_handler = None
def remove_duplicate_handlers(): # sourcery skip: for-append-to-extend, list-comprehension
"""移除重复的handler特别是文件handler"""
root_logger = logging.getLogger()
# 收集所有时间戳文件handler
file_handlers = []
for handler in root_logger.handlers[:]:
if isinstance(handler, TimestampedFileHandler):
file_handlers.append(handler)
# 如果有多个文件handler保留第一个关闭其他的
if len(file_handlers) > 1:
print(f"[日志系统] 检测到 {len(file_handlers)} 个重复的文件handler正在清理...")
for i, handler in enumerate(file_handlers[1:], 1):
print(f"[日志系统] 关闭重复的文件handler {i}")
root_logger.removeHandler(handler)
handler.close()
# 更新全局引用
global _file_handler
_file_handler = file_handlers[0]
# 读取日志配置
def load_log_config(): # sourcery skip: use-contextlib-suppress
"""从配置文件加载日志设置"""
config_path = Path("config/bot_config.toml")
default_config = {
"date_style": "m-d H:i:s",
"log_level_style": "lite",
"color_text": "full",
"log_level": "INFO", # 全局日志级别(向下兼容)
"console_log_level": "INFO", # 控制台日志级别
"file_log_level": "DEBUG", # 文件日志级别
"suppress_libraries": ["faiss","httpx", "urllib3", "asyncio", "websockets", "httpcore", "requests", "peewee", "openai","uvicorn","jieba"],
"library_log_levels": { "aiohttp": "WARNING"},
}
try:
if config_path.exists():
with open(config_path, "r", encoding="utf-8") as f:
config = tomlkit.load(f)
return config.get("log", default_config)
except Exception as e:
print(f"[日志系统] 加载日志配置失败: {e}")
pass
return default_config
LOG_CONFIG = load_log_config()
def get_timestamp_format():
"""将配置中的日期格式转换为Python格式"""
date_style = LOG_CONFIG.get("date_style", "Y-m-d H:i:s")
# 转换PHP风格的日期格式到Python格式
format_map = {
"Y": "%Y", # 4位年份
"m": "%m", # 月份01-12
"d": "%d", # 日期01-31
"H": "%H", # 小时00-23
"i": "%M", # 分钟00-59
"s": "%S", # 秒数00-59
}
python_format = date_style
for php_char, python_char in format_map.items():
python_format = python_format.replace(php_char, python_char)
return python_format
def configure_third_party_loggers():
"""配置第三方库的日志级别"""
# 设置根logger级别为所有handler中最低的级别确保所有日志都能被捕获
console_level = LOG_CONFIG.get("console_log_level", LOG_CONFIG.get("log_level", "INFO"))
file_level = LOG_CONFIG.get("file_log_level", LOG_CONFIG.get("log_level", "INFO"))
# 获取最低级别DEBUG < INFO < WARNING < ERROR < CRITICAL
console_level_num = getattr(logging, console_level.upper(), logging.INFO)
file_level_num = getattr(logging, file_level.upper(), logging.INFO)
min_level = min(console_level_num, file_level_num)
root_logger = logging.getLogger()
root_logger.setLevel(min_level)
# 完全屏蔽的库
suppress_libraries = LOG_CONFIG.get("suppress_libraries", [])
for lib_name in suppress_libraries:
lib_logger = logging.getLogger(lib_name)
lib_logger.setLevel(logging.CRITICAL + 1) # 设置为比CRITICAL更高的级别基本屏蔽所有日志
lib_logger.propagate = False # 阻止向上传播
# 设置特定级别的库
library_log_levels = LOG_CONFIG.get("library_log_levels", {})
for lib_name, level_name in library_log_levels.items():
lib_logger = logging.getLogger(lib_name)
level = getattr(logging, level_name.upper(), logging.WARNING)
lib_logger.setLevel(level)
def reconfigure_existing_loggers():
"""重新配置所有已存在的logger解决加载顺序问题"""
# 获取根logger
root_logger = logging.getLogger()
# 重新设置根logger的所有handler的格式化器
for handler in root_logger.handlers:
if isinstance(handler, TimestampedFileHandler):
handler.setFormatter(file_formatter)
elif isinstance(handler, logging.StreamHandler):
handler.setFormatter(console_formatter)
# 遍历所有已存在的logger并重新配置
logger_dict = logging.getLogger().manager.loggerDict
for name, logger_obj in logger_dict.items():
if isinstance(logger_obj, logging.Logger):
# 检查是否是第三方库logger
suppress_libraries = LOG_CONFIG.get("suppress_libraries", [])
library_log_levels = LOG_CONFIG.get("library_log_levels", {})
# 如果在屏蔽列表中
if any(name.startswith(lib) for lib in suppress_libraries):
logger_obj.setLevel(logging.CRITICAL + 1)
logger_obj.propagate = False
continue
# 如果在特定级别设置中
for lib_name, level_name in library_log_levels.items():
if name.startswith(lib_name):
level = getattr(logging, level_name.upper(), logging.WARNING)
logger_obj.setLevel(level)
break
# 强制清除并重新设置所有handler
original_handlers = logger_obj.handlers[:]
for handler in original_handlers:
# 安全关闭handler
if hasattr(handler, "close"):
handler.close()
logger_obj.removeHandler(handler)
# 如果logger没有handler让它使用根logger的handlerpropagate=True
if not logger_obj.handlers:
logger_obj.propagate = True
# 如果logger有自己的handler重新配置它们避免重复创建文件handler
for handler in original_handlers:
if isinstance(handler, TimestampedFileHandler):
# 不重新添加让它使用根logger的文件handler
continue
elif isinstance(handler, logging.StreamHandler):
handler.setFormatter(console_formatter)
logger_obj.addHandler(handler)
# 定义模块颜色映射
MODULE_COLORS = {
# 核心模块
"main": "\033[1;97m", # 亮白色+粗体 (主程序)
"api": "\033[92m", # 亮绿色
"emoji": "\033[38;5;214m", # 橙黄色偏向橙色但与replyer和action_manager不同
"chat": "\033[92m", # 亮蓝色
"config": "\033[93m", # 亮黄色
"common": "\033[95m", # 亮紫色
"tools": "\033[96m", # 亮青色
"lpmm": "\033[96m",
"plugin_system": "\033[91m", # 亮红色
"person_info": "\033[32m", # 绿色
"individuality": "\033[94m", # 显眼的亮蓝色
"manager": "\033[35m", # 紫色
"llm_models": "\033[36m", # 青色
"remote": "\033[38;5;242m", # 深灰色,更不显眼
"planner": "\033[36m",
"memory": "\033[38;5;117m", # 天蓝色
"hfc": "\033[38;5;81m", # 稍微暗一些的青色,保持可读
"action_manager": "\033[38;5;208m", # 橙色不与replyer重复
# 关系系统
"relation": "\033[38;5;139m", # 柔和的紫色,不刺眼
# 聊天相关模块
"normal_chat": "\033[38;5;81m", # 亮蓝绿色
"heartflow": "\033[38;5;175m", # 柔和的粉色,不显眼但保持粉色系
"sub_heartflow": "\033[38;5;207m", # 粉紫色
"subheartflow_manager": "\033[38;5;201m", # 深粉色
"background_tasks": "\033[38;5;240m", # 灰色
"chat_message": "\033[38;5;45m", # 青色
"chat_stream": "\033[38;5;51m", # 亮青色
"sender": "\033[38;5;67m", # 稍微暗一些的蓝色,不显眼
"message_storage": "\033[38;5;33m", # 深蓝色
"expressor": "\033[38;5;166m", # 橙色
# 专注聊天模块
"replyer": "\033[38;5;166m", # 橙色
"memory_activator": "\033[38;5;117m", # 天蓝色
# 插件系统
"plugins": "\033[31m", # 红色
"plugin_api": "\033[33m", # 黄色
"plugin_manager": "\033[38;5;208m", # 红色
"base_plugin": "\033[38;5;202m", # 橙红色
"send_api": "\033[38;5;208m", # 橙色
"base_command": "\033[38;5;208m", # 橙色
"component_registry": "\033[38;5;214m", # 橙黄色
"stream_api": "\033[38;5;220m", # 黄色
"plugin_hot_reload": "\033[38;5;226m", #品红色
"config_api": "\033[38;5;226m", # 亮黄色
"heartflow_api": "\033[38;5;154m", # 黄绿色
"action_apis": "\033[38;5;118m", # 绿色
"independent_apis": "\033[38;5;82m", # 绿色
"llm_api": "\033[38;5;46m", # 亮绿色
"database_api": "\033[38;5;10m", # 绿色
"utils_api": "\033[38;5;14m", # 青色
"message_api": "\033[38;5;6m", # 青色
# 管理器模块
"async_task_manager": "\033[38;5;129m", # 紫色
"mood": "\033[38;5;135m", # 紫红色
"local_storage": "\033[38;5;141m", # 紫色
"willing": "\033[38;5;147m", # 浅紫色
# 工具模块
"tool_use": "\033[38;5;172m", # 橙褐色
"tool_executor": "\033[38;5;172m", # 橙褐色
"base_tool": "\033[38;5;178m", # 金黄色
# 工具和实用模块
"prompt_build": "\033[38;5;105m", # 紫色
"chat_utils": "\033[38;5;111m", # 蓝色
"chat_image": "\033[38;5;117m", # 浅蓝色
"maibot_statistic": "\033[38;5;129m", # 紫色
# 特殊功能插件
"mute_plugin": "\033[38;5;240m", # 灰色
"core_actions": "\033[38;5;117m", # 深红色
"tts_action": "\033[38;5;58m", # 深黄色
"doubao_pic_plugin": "\033[38;5;64m", # 深绿色
# Action组件
"no_reply_action": "\033[38;5;214m", # 亮橙色,显眼但不像警告
"reply_action": "\033[38;5;46m", # 亮绿色
"base_action": "\033[38;5;250m", # 浅灰色
# 数据库和消息
"database_model": "\033[38;5;94m", # 橙褐色
"database": "\033[38;5;46m", # 橙褐色
"maim_message": "\033[38;5;140m", # 紫褐色
# 日志系统
"logger": "\033[38;5;8m", # 深灰色
"confirm": "\033[1;93m", # 黄色+粗体
# 模型相关
"model_utils": "\033[38;5;164m", # 紫红色
"relationship_fetcher": "\033[38;5;170m", # 浅紫色
"relationship_builder": "\033[38;5;93m", # 浅蓝色
#s4u
"context_web_api": "\033[38;5;240m", # 深灰色
"S4U_chat": "\033[92m", # 深灰色
}
# 定义模块别名映射 - 将真实的logger名称映射到显示的别名
MODULE_ALIASES = {
# 示例映射
"individuality": "人格特质",
"emoji": "表情包",
"no_reply_action": "摸鱼",
"reply_action": "回复",
"action_manager": "动作",
"memory_activator": "记忆",
"tool_use": "工具",
"expressor": "表达方式",
"plugin_hot_reload": "热重载",
"database": "数据库",
"database_model": "数据库",
"mood": "情绪",
"memory": "记忆",
"tool_executor": "工具",
"hfc": "聊天节奏",
"chat": "所见",
"plugin_manager": "插件",
"relationship_builder": "关系",
"llm_models": "模型",
"person_info": "人物",
"chat_stream": "聊天流",
"planner": "规划器",
"replyer": "言语",
"config": "配置",
"main": "主程序",
}
RESET_COLOR = "\033[0m"
class ModuleColoredConsoleRenderer:
"""自定义控制台渲染器,为不同模块提供不同颜色"""
def __init__(self, colors=True):
# sourcery skip: merge-duplicate-blocks, remove-redundant-if
self._colors = colors
self._config = LOG_CONFIG
# 日志级别颜色
self._level_colors = {
"debug": "\033[38;5;208m", # 橙色
"info": "\033[38;5;117m", # 天蓝色
"success": "\033[32m", # 绿色
"warning": "\033[33m", # 黄色
"error": "\033[31m", # 红色
"critical": "\033[35m", # 紫色
}
# 根据配置决定是否启用颜色
color_text = self._config.get("color_text", "title")
if color_text == "none":
self._colors = False
elif color_text == "title":
self._enable_module_colors = True
self._enable_level_colors = False
self._enable_full_content_colors = False
elif color_text == "full":
self._enable_module_colors = True
self._enable_level_colors = True
self._enable_full_content_colors = True
else:
self._enable_module_colors = True
self._enable_level_colors = False
self._enable_full_content_colors = False
def __call__(self, logger, method_name, event_dict):
# sourcery skip: merge-duplicate-blocks
"""渲染日志消息"""
# 获取基本信息
timestamp = event_dict.get("timestamp", "")
level = event_dict.get("level", "info")
logger_name = event_dict.get("logger_name", "")
event = event_dict.get("event", "")
# 构建输出
parts = []
# 日志级别样式配置
log_level_style = self._config.get("log_level_style", "lite")
level_color = self._level_colors.get(level.lower(), "") if self._colors else ""
# 时间戳lite模式下按级别着色
if timestamp:
if log_level_style == "lite" and level_color:
timestamp_part = f"{level_color}{timestamp}{RESET_COLOR}"
else:
timestamp_part = timestamp
parts.append(timestamp_part)
# 日志级别显示(根据配置样式)
if log_level_style == "full":
# 显示完整级别名并着色
level_text = level.upper()
if level_color:
level_part = f"{level_color}[{level_text:>8}]{RESET_COLOR}"
else:
level_part = f"[{level_text:>8}]"
parts.append(level_part)
elif log_level_style == "compact":
# 只显示首字母并着色
level_text = level.upper()[0]
if level_color:
level_part = f"{level_color}[{level_text:>8}]{RESET_COLOR}"
else:
level_part = f"[{level_text:>8}]"
parts.append(level_part)
# lite模式不显示级别只给时间戳着色
# 获取模块颜色用于full模式下的整体着色
module_color = ""
if self._colors and self._enable_module_colors and logger_name:
module_color = MODULE_COLORS.get(logger_name, "")
# 模块名称(带颜色和别名支持)
if logger_name:
# 获取别名,如果没有别名则使用原名称
display_name = MODULE_ALIASES.get(logger_name, logger_name)
if self._colors and self._enable_module_colors:
if module_color:
module_part = f"{module_color}[{display_name}]{RESET_COLOR}"
else:
module_part = f"[{display_name}]"
else:
module_part = f"[{display_name}]"
parts.append(module_part)
# 消息内容(确保转换为字符串)
event_content = ""
if isinstance(event, str):
event_content = event
elif isinstance(event, dict):
# 如果是字典,格式化为可读字符串
try:
event_content = json.dumps(event, ensure_ascii=False, indent=None)
except (TypeError, ValueError):
event_content = str(event)
else:
# 其他类型直接转换为字符串
event_content = str(event)
# 在full模式下为消息内容着色
if self._colors and self._enable_full_content_colors and module_color:
event_content = f"{module_color}{event_content}{RESET_COLOR}"
parts.append(event_content)
# 处理其他字段
extras = []
for key, value in event_dict.items():
if key not in ("timestamp", "level", "logger_name", "event"):
# 确保值也转换为字符串
if isinstance(value, (dict, list)):
try:
value_str = json.dumps(value, ensure_ascii=False, indent=None)
except (TypeError, ValueError):
value_str = str(value)
else:
value_str = str(value)
# 在full模式下为额外字段着色
extra_field = f"{key}={value_str}"
if self._colors and self._enable_full_content_colors and module_color:
extra_field = f"{module_color}{extra_field}{RESET_COLOR}"
extras.append(extra_field)
if extras:
parts.append(" ".join(extras))
return " ".join(parts)
# 配置标准logging以支持文件输出和压缩
# 使用单例handler避免重复创建
file_handler = get_file_handler()
console_handler = get_console_handler()
logging.basicConfig(
level=logging.INFO,
format="%(message)s",
handlers=[file_handler, console_handler],
)
def configure_structlog():
"""配置structlog"""
structlog.configure(
processors=[
structlog.contextvars.merge_contextvars,
structlog.processors.add_log_level,
structlog.processors.StackInfoRenderer(),
structlog.dev.set_exc_info,
structlog.processors.TimeStamper(fmt=get_timestamp_format(), utc=False),
# 根据输出类型选择不同的渲染器
structlog.stdlib.ProcessorFormatter.wrap_for_formatter,
],
wrapper_class=structlog.stdlib.BoundLogger,
context_class=dict,
logger_factory=structlog.stdlib.LoggerFactory(),
cache_logger_on_first_use=True,
)
# 配置structlog
configure_structlog()
# 为文件输出配置JSON格式
file_formatter = structlog.stdlib.ProcessorFormatter(
processor=structlog.processors.JSONRenderer(ensure_ascii=False),
foreign_pre_chain=[
structlog.stdlib.add_logger_name,
structlog.stdlib.add_log_level,
structlog.stdlib.PositionalArgumentsFormatter(),
structlog.processors.TimeStamper(fmt="iso"),
structlog.processors.StackInfoRenderer(),
structlog.processors.format_exc_info,
],
)
# 为控制台输出配置可读格式
console_formatter = structlog.stdlib.ProcessorFormatter(
processor=ModuleColoredConsoleRenderer(colors=True),
foreign_pre_chain=[
structlog.stdlib.add_logger_name,
structlog.stdlib.add_log_level,
structlog.stdlib.PositionalArgumentsFormatter(),
structlog.processors.TimeStamper(fmt=get_timestamp_format(), utc=False),
structlog.processors.StackInfoRenderer(),
structlog.processors.format_exc_info,
],
)
# 获取根logger并配置格式化器
root_logger = logging.getLogger()
for handler in root_logger.handlers:
if isinstance(handler, TimestampedFileHandler):
handler.setFormatter(file_formatter)
else:
handler.setFormatter(console_formatter)
# 立即配置日志系统,确保最早期的日志也使用正确格式
def _immediate_setup():
"""立即设置日志系统,在模块导入时就生效"""
# 重新配置structlog
configure_structlog()
# 清除所有已有的handler重新配置
root_logger = logging.getLogger()
for handler in root_logger.handlers[:]:
root_logger.removeHandler(handler)
# 使用单例handler避免重复创建
file_handler = get_file_handler()
console_handler = get_console_handler()
# 重新添加配置好的handler
root_logger.addHandler(file_handler)
root_logger.addHandler(console_handler)
# 设置格式化器
file_handler.setFormatter(file_formatter)
console_handler.setFormatter(console_formatter)
# 清理重复的handler
remove_duplicate_handlers()
# 配置第三方库日志
configure_third_party_loggers()
# 重新配置所有已存在的logger
reconfigure_existing_loggers()
# 立即执行配置
_immediate_setup()
raw_logger: structlog.stdlib.BoundLogger = structlog.get_logger()
binds: dict[str, Callable] = {}
def get_logger(name: Optional[str]) -> structlog.stdlib.BoundLogger:
"""获取logger实例支持按名称绑定"""
if name is None:
return raw_logger
logger = binds.get(name) # type: ignore
if logger is None:
logger: structlog.stdlib.BoundLogger = structlog.get_logger(name).bind(logger_name=name)
binds[name] = logger
return logger
def initialize_logging():
"""手动初始化日志系统确保所有logger都使用正确的配置
在应用程序的早期调用此函数,确保所有模块都使用统一的日志配置
"""
global LOG_CONFIG
LOG_CONFIG = load_log_config()
# print(LOG_CONFIG)
configure_third_party_loggers()
reconfigure_existing_loggers()
# 启动日志清理任务
start_log_cleanup_task()
# 输出初始化信息
logger = get_logger("logger")
console_level = LOG_CONFIG.get("console_log_level", LOG_CONFIG.get("log_level", "INFO"))
file_level = LOG_CONFIG.get("file_log_level", LOG_CONFIG.get("log_level", "INFO"))
logger.info("日志系统已初始化:")
logger.info(f" - 控制台级别: {console_level}")
logger.info(f" - 文件级别: {file_level}")
logger.info(" - 轮转份数: 30个文件|自动清理: 30天前的日志")
def cleanup_old_logs():
"""清理过期的日志文件"""
try:
cleanup_days = 30 # 硬编码30天
cutoff_date = datetime.now() - timedelta(days=cleanup_days)
deleted_count = 0
deleted_size = 0
# 遍历日志目录
for log_file in LOG_DIR.glob("*.log*"):
try:
file_time = datetime.fromtimestamp(log_file.stat().st_mtime)
if file_time < cutoff_date:
file_size = log_file.stat().st_size
log_file.unlink()
deleted_count += 1
deleted_size += file_size
except Exception as e:
logger = get_logger("logger")
logger.warning(f"清理日志文件 {log_file} 时出错: {e}")
if deleted_count > 0:
logger = get_logger("logger")
logger.info(f"清理了 {deleted_count} 个过期日志文件,释放空间 {deleted_size / 1024 / 1024:.2f} MB")
except Exception as e:
logger = get_logger("logger")
logger.error(f"清理旧日志文件时出错: {e}")
def start_log_cleanup_task():
"""启动日志清理任务"""
def cleanup_task():
while True:
time.sleep(24 * 60 * 60) # 每24小时执行一次
cleanup_old_logs()
cleanup_thread = threading.Thread(target=cleanup_task, daemon=True)
cleanup_thread.start()
logger = get_logger("logger")
logger.info("已启动日志清理任务将自动清理30天前的日志文件轮转份数限制: 30个文件")
def shutdown_logging():
"""优雅关闭日志系统,释放所有文件句柄"""
logger = get_logger("logger")
logger.info("正在关闭日志系统...")
# 关闭所有handler
root_logger = logging.getLogger()
for handler in root_logger.handlers[:]:
if hasattr(handler, "close"):
handler.close()
root_logger.removeHandler(handler)
# 关闭全局handler
close_handlers()
# 关闭所有其他logger的handler
logger_dict = logging.getLogger().manager.loggerDict
for _name, logger_obj in logger_dict.items():
if isinstance(logger_obj, logging.Logger):
for handler in logger_obj.handlers[:]:
if hasattr(handler, "close"):
handler.close()
logger_obj.removeHandler(handler)
logger.info("日志系统已关闭")

View File

@@ -0,0 +1,10 @@
"""Maim Message - A message handling library"""
__version__ = "0.1.0"
from .api import get_global_api
__all__ = [
"get_global_api",
]

59
src/common/message/api.py Normal file
View File

@@ -0,0 +1,59 @@
from src.common.server import get_global_server
import os
import importlib.metadata
from maim_message import MessageServer
from src.common.logger import get_logger
from src.config.config import global_config
global_api = None
def get_global_api() -> MessageServer: # sourcery skip: extract-method
"""获取全局MessageServer实例"""
global global_api
if global_api is None:
# 检查maim_message版本
try:
maim_message_version = importlib.metadata.version("maim_message")
version_compatible = [int(x) for x in maim_message_version.split(".")] >= [0, 3, 3]
except (importlib.metadata.PackageNotFoundError, ValueError):
version_compatible = False
# 读取配置项
maim_message_config = global_config.maim_message
# 设置基本参数
kwargs = {
"host": os.environ["HOST"],
"port": int(os.environ["PORT"]),
"app": get_global_server().get_app(),
}
# 只有在版本 >= 0.3.0 时才使用高级特性
if version_compatible:
# 添加自定义logger
maim_message_logger = get_logger("maim_message")
kwargs["custom_logger"] = maim_message_logger
# 添加token认证
if maim_message_config.auth_token and len(maim_message_config.auth_token) > 0:
kwargs["enable_token"] = True
if maim_message_config.use_custom:
# 添加WSS模式支持
del kwargs["app"]
kwargs["host"] = maim_message_config.host
kwargs["port"] = maim_message_config.port
kwargs["mode"] = maim_message_config.mode
if maim_message_config.use_wss:
if maim_message_config.cert_file:
kwargs["ssl_certfile"] = maim_message_config.cert_file
if maim_message_config.key_file:
kwargs["ssl_keyfile"] = maim_message_config.key_file
kwargs["enable_custom_uvicorn_logger"] = False
global_api = MessageServer(**kwargs)
if version_compatible and maim_message_config.auth_token:
for token in maim_message_config.auth_token:
global_api.add_valid_token(token)
return global_api

View File

@@ -0,0 +1,203 @@
import traceback
from typing import List, Optional, Any, Dict
from sqlalchemy import not_, select, func
from sqlalchemy.orm import DeclarativeBase
from src.config.config import global_config
# from src.common.database.database_model import Messages
from src.common.database.sqlalchemy_models import Messages
from src.common.database.sqlalchemy_database_api import get_session
from src.common.logger import get_logger
logger = get_logger(__name__)
class Base(DeclarativeBase):
pass
def _model_to_dict(instance: Base) -> Dict[str, Any]:
"""
将 SQLAlchemy 模型实例转换为字典。
"""
return {col.name: getattr(instance, col.name) for col in instance.__table__.columns}
def find_messages(
message_filter: dict[str, Any],
sort: Optional[List[tuple[str, int]]] = None,
limit: int = 0,
limit_mode: str = "latest",
filter_bot=False,
filter_command=False,
) -> List[dict[str, Any]]:
"""
根据提供的过滤器、排序和限制条件查找消息。
Args:
message_filter: 查询过滤器字典,键为模型字段名,值为期望值或包含操作符的字典 (例如 {'$gt': value}).
sort: 排序条件列表,例如 [('time', 1)] (1 for asc, -1 for desc)。仅在 limit 为 0 时生效。
limit: 返回的最大文档数0表示不限制。
limit_mode: 当 limit > 0 时生效。 'earliest' 表示获取最早的记录, 'latest' 表示获取最新的记录(结果仍按时间正序排列)。默认为 'latest'
Returns:
消息字典列表,如果出错则返回空列表。
"""
try:
session = get_session()
query = select(Messages)
# 应用过滤器
if message_filter:
conditions = []
for key, value in message_filter.items():
if hasattr(Messages, key):
field = getattr(Messages, key)
if isinstance(value, dict):
# 处理 MongoDB 风格的操作符
for op, op_value in value.items():
if op == "$gt":
conditions.append(field > op_value)
elif op == "$lt":
conditions.append(field < op_value)
elif op == "$gte":
conditions.append(field >= op_value)
elif op == "$lte":
conditions.append(field <= op_value)
elif op == "$ne":
conditions.append(field != op_value)
elif op == "$in":
conditions.append(field.in_(op_value))
elif op == "$nin":
conditions.append(field.not_in(op_value))
else:
logger.warning(f"过滤器中遇到未知操作符 '{op}' (字段: '{key}')。将跳过此操作符。")
else:
# 直接相等比较
conditions.append(field == value)
else:
logger.warning(f"过滤器键 '{key}' 在 Messages 模型中未找到。将跳过此条件。")
if conditions:
query = query.where(*conditions)
if filter_bot:
query = query.where(Messages.user_id != global_config.bot.qq_account)
if filter_command:
query = query.where(not_(Messages.is_command))
if limit > 0:
# 确保limit是正整数
limit = max(1, int(limit))
if limit_mode == "earliest":
# 获取时间最早的 limit 条记录,已经是正序
query = query.order_by(Messages.time.asc()).limit(limit)
try:
results = session.execute(query).scalars().all()
except Exception as e:
logger.error(f"执行earliest查询失败: {e}")
results = []
else: # 默认为 'latest'
# 获取时间最晚的 limit 条记录
query = query.order_by(Messages.time.desc()).limit(limit)
try:
latest_results = session.execute(query).scalars().all()
# 将结果按时间正序排列
results = sorted(latest_results, key=lambda msg: msg.time)
except Exception as e:
logger.error(f"执行latest查询失败: {e}")
results = []
else:
# limit 为 0 时,应用传入的 sort 参数
if sort:
sort_terms = []
for field_name, direction in sort:
if hasattr(Messages, field_name):
field = getattr(Messages, field_name)
if direction == 1: # ASC
sort_terms.append(field.asc())
elif direction == -1: # DESC
sort_terms.append(field.desc())
else:
logger.warning(f"字段 '{field_name}' 的排序方向 '{direction}' 无效。将跳过此排序条件。")
else:
logger.warning(f"排序字段 '{field_name}' 在 Messages 模型中未找到。将跳过此排序条件。")
if sort_terms:
query = query.order_by(*sort_terms)
try:
results = session.execute(query).scalars().all()
except Exception as e:
logger.error(f"执行无限制查询失败: {e}")
results = []
return [_model_to_dict(msg) for msg in results]
except Exception as e:
log_message = (
f"使用 SQLAlchemy 查找消息失败 (filter={message_filter}, sort={sort}, limit={limit}, limit_mode={limit_mode}): {e}\n"
+ traceback.format_exc()
)
logger.error(log_message)
return []
def count_messages(message_filter: dict[str, Any]) -> int:
"""
根据提供的过滤器计算消息数量。
Args:
message_filter: 查询过滤器字典,键为模型字段名,值为期望值或包含操作符的字典 (例如 {'$gt': value}).
Returns:
符合条件的消息数量,如果出错则返回 0。
"""
try:
session = get_session()
query = select(func.count(Messages.id))
# 应用过滤器
if message_filter:
conditions = []
for key, value in message_filter.items():
if hasattr(Messages, key):
field = getattr(Messages, key)
if isinstance(value, dict):
# 处理 MongoDB 风格的操作符
for op, op_value in value.items():
if op == "$gt":
conditions.append(field > op_value)
elif op == "$lt":
conditions.append(field < op_value)
elif op == "$gte":
conditions.append(field >= op_value)
elif op == "$lte":
conditions.append(field <= op_value)
elif op == "$ne":
conditions.append(field != op_value)
elif op == "$in":
conditions.append(field.in_(op_value))
elif op == "$nin":
conditions.append(field.not_in(op_value))
else:
logger.warning(
f"计数时,过滤器中遇到未知操作符 '{op}' (字段: '{key}')。将跳过此操作符。"
)
else:
# 直接相等比较
conditions.append(field == value)
else:
logger.warning(f"计数时,过滤器键 '{key}' 在 Messages 模型中未找到。将跳过此条件。")
if conditions:
query = query.where(*conditions)
count = session.execute(query).scalar()
return count or 0
except Exception as e:
log_message = f"使用 SQLAlchemy 计数消息失败 (message_filter={message_filter}): {e}\n{traceback.format_exc()}"
logger.error(log_message)
return 0
# 你可以在这里添加更多与 messages 集合相关的数据库操作函数,例如 find_one_message, insert_message 等。
# 注意:对于 SQLAlchemy插入操作通常是使用 session.add() 和 session.commit()。
# 查找单个消息可以使用 session.execute(select(Messages).where(...)).scalar_one_or_none()。

165
src/common/remote.py Normal file
View File

@@ -0,0 +1,165 @@
import asyncio
import aiohttp
import platform
from src.common.logger import get_logger
from src.common.tcp_connector import get_tcp_connector
from src.config.config import global_config
from src.manager.async_task_manager import AsyncTask
from src.manager.local_store_manager import local_storage
logger = get_logger("remote")
TELEMETRY_SERVER_URL = "http://hyybuth.xyz:10058"
"""遥测服务地址"""
class TelemetryHeartBeatTask(AsyncTask):
HEARTBEAT_INTERVAL = 300
def __init__(self):
super().__init__(task_name="Telemetry Heart Beat Task", run_interval=self.HEARTBEAT_INTERVAL)
self.server_url = TELEMETRY_SERVER_URL
"""遥测服务地址"""
self.client_uuid: str | None = local_storage["mmc_uuid"] if "mmc_uuid" in local_storage else None # type: ignore
"""客户端UUID"""
self.info_dict = self._get_sys_info()
"""系统信息字典"""
@staticmethod
def _get_sys_info() -> dict[str, str]:
"""获取系统信息"""
info_dict = {
"os_type": "Unknown",
"py_version": platform.python_version(),
"mmc_version": global_config.MMC_VERSION,
}
match platform.system():
case "Windows":
info_dict["os_type"] = "Windows"
case "Linux":
info_dict["os_type"] = "Linux"
case "Darwin":
info_dict["os_type"] = "macOS"
case _:
info_dict["os_type"] = "Unknown"
return info_dict
async def _req_uuid(self) -> bool:
"""
向服务端请求UUID不应在已存在UUID的情况下调用会覆盖原有的UUID
"""
if "deploy_time" not in local_storage:
logger.error("本地存储中缺少部署时间无法请求UUID")
return False
try_count: int = 0
while True:
# 如果不存在则向服务端请求一个新的UUID注册客户端
logger.info("正在向遥测服务端请求UUID...")
try:
async with aiohttp.ClientSession(connector=await get_tcp_connector()) as session:
async with session.post(
f"{TELEMETRY_SERVER_URL}/stat/reg_client",
json={"deploy_time": local_storage["deploy_time"]},
timeout=aiohttp.ClientTimeout(total=5), # 设置超时时间为5秒
) as response:
logger.debug(f"{TELEMETRY_SERVER_URL}/stat/reg_client")
logger.debug(local_storage["deploy_time"]) # type: ignore
logger.debug(f"Response status: {response.status}")
if response.status == 200:
data = await response.json()
if client_id := data.get("mmc_uuid"):
# 将UUID存储到本地
local_storage["mmc_uuid"] = client_id
self.client_uuid = client_id
logger.info(f"成功获取UUID: {self.client_uuid}")
return True # 成功获取UUID返回True
else:
logger.error("无效的服务端响应")
else:
response_text = await response.text()
logger.error(
f"请求UUID失败不过你还是可以正常使用麦麦状态码: {response.status}, 响应内容: {response_text}"
)
except Exception as e:
import traceback
error_msg = str(e) or "未知错误"
logger.warning(
f"请求UUID出错不过你还是可以正常使用麦麦: {type(e).__name__}: {error_msg}"
) # 可能是网络问题
logger.debug(f"完整错误信息: {traceback.format_exc()}")
# 请求失败,重试次数+1
try_count += 1
if try_count > 3:
# 如果超过3次仍然失败则退出
logger.error("获取UUID失败请检查网络连接或服务端状态")
return False
else:
# 如果可以重试,等待后继续(指数退避)
logger.info(f"获取UUID失败将于 {4**try_count} 秒后重试...")
await asyncio.sleep(4**try_count)
async def _send_heartbeat(self):
"""向服务器发送心跳"""
headers = {
"Client-UUID": self.client_uuid,
"User-Agent": f"HeartbeatClient/{self.client_uuid[:8]}", # type: ignore
}
logger.debug(f"正在发送心跳到服务器: {self.server_url}")
logger.debug(str(headers))
try:
async with aiohttp.ClientSession(connector=await get_tcp_connector()) as session:
async with session.post(
f"{self.server_url}/stat/client_heartbeat",
headers=headers,
json=self.info_dict,
timeout=aiohttp.ClientTimeout(total=5), # 设置超时时间为5秒
) as response:
logger.debug(f"Response status: {response.status}")
# 处理响应
if 200 <= response.status < 300:
# 成功
logger.debug(f"心跳发送成功,状态码: {response.status}")
elif response.status == 403:
# 403 Forbidden
logger.warning(
"此消息不会影响正常使用心跳发送失败403 Forbidden: 可能是UUID无效或未注册。"
"处理措施重置UUID下次发送心跳时将尝试重新注册。"
)
self.client_uuid = None
del local_storage["mmc_uuid"] # 删除本地存储的UUID
else:
# 其他错误
response_text = await response.text()
logger.warning(
f"(此消息不会影响正常使用)状态未发送,状态码: {response.status}, 响应内容: {response_text}"
)
except Exception as e:
import traceback
error_msg = str(e) or "未知错误"
logger.warning(f"(此消息不会影响正常使用)状态未发生: {type(e).__name__}: {error_msg}")
logger.debug(f"完整错误信息: {traceback.format_exc()}")
async def run(self):
# 发送心跳
if global_config.telemetry.enable:
if self.client_uuid is None and not await self._req_uuid():
logger.warning("获取UUID失败跳过此次心跳")
return
await self._send_heartbeat()

101
src/common/server.py Normal file
View File

@@ -0,0 +1,101 @@
from fastapi import FastAPI, APIRouter
from fastapi.middleware.cors import CORSMiddleware # 新增导入
from typing import Optional
from uvicorn import Config, Server as UvicornServer
import os
from rich.traceback import install
install(extra_lines=3)
class Server:
def __init__(self, host: Optional[str] = None, port: Optional[int] = None, app_name: str = "MaiMCore"):
self.app = FastAPI(title=app_name)
self._host: str = "127.0.0.1"
self._port: int = 8080
self._server: Optional[UvicornServer] = None
self.set_address(host, port)
# 配置 CORS
origins = [
"http://localhost:3000", # 允许的前端源
"http://127.0.0.1:3000",
# 在生产环境中,您应该添加实际的前端域名
]
self.app.add_middleware(
CORSMiddleware,
allow_origins=origins,
allow_credentials=True, # 是否支持 cookie
allow_methods=["*"], # 允许所有 HTTP 方法
allow_headers=["*"], # 允许所有 HTTP 请求头
)
def register_router(self, router: APIRouter, prefix: str = ""):
"""注册路由
APIRouter 用于对相关的路由端点进行分组和模块化管理:
1. 可以将相关的端点组织在一起,便于管理
2. 支持添加统一的路由前缀
3. 可以为一组路由添加共同的依赖项、标签等
示例:
router = APIRouter()
@router.get("/users")
def get_users():
return {"users": [...]}
@router.post("/users")
def create_user():
return {"msg": "user created"}
# 注册路由,添加前缀 "/api/v1"
server.register_router(router, prefix="/api/v1")
"""
self.app.include_router(router, prefix=prefix)
def set_address(self, host: Optional[str] = None, port: Optional[int] = None):
"""设置服务器地址和端口"""
if host:
self._host = host
if port:
self._port = port
async def run(self):
"""启动服务器"""
# 禁用 uvicorn 默认日志和访问日志
config = Config(app=self.app, host=self._host, port=self._port, log_config=None, access_log=False)
self._server = UvicornServer(config=config)
try:
await self._server.serve()
except KeyboardInterrupt:
await self.shutdown()
raise
except Exception as e:
await self.shutdown()
raise RuntimeError(f"服务器运行错误: {str(e)}") from e
finally:
await self.shutdown()
async def shutdown(self):
"""安全关闭服务器"""
if self._server:
self._server.should_exit = True
await self._server.shutdown()
self._server = None
def get_app(self) -> FastAPI:
"""获取 FastAPI 实例"""
return self.app
global_server = None
def get_global_server() -> Server:
"""获取全局服务器实例"""
global global_server
if global_server is None:
global_server = Server(host=os.environ["HOST"], port=int(os.environ["PORT"]))
return global_server

View File

@@ -0,0 +1,9 @@
import ssl
import certifi
import aiohttp
ssl_context = ssl.create_default_context(cafile=certifi.where())
async def get_tcp_connector():
return aiohttp.TCPConnector(ssl=ssl_context)