初始化
This commit is contained in:
1
src/common/__init__.py
Normal file
1
src/common/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
# 这个文件可以为空,但必须存在
|
||||
0
src/common/database/__init__.py
Normal file
0
src/common/database/__init__.py
Normal file
192
src/common/database/database.py
Normal file
192
src/common/database/database.py
Normal file
@@ -0,0 +1,192 @@
|
||||
import os
|
||||
from pymongo import MongoClient
|
||||
from pymongo.database import Database
|
||||
from rich.traceback import install
|
||||
from src.common.logger import get_logger
|
||||
|
||||
# SQLAlchemy相关导入
|
||||
from src.common.database.sqlalchemy_init import initialize_database_compat
|
||||
from src.common.database.sqlalchemy_models import get_engine, get_session
|
||||
|
||||
install(extra_lines=3)
|
||||
|
||||
_client = None
|
||||
_db = None
|
||||
_sql_engine = None
|
||||
|
||||
logger = get_logger("database")
|
||||
|
||||
# 兼容性:为了不破坏现有代码,保留db变量但指向SQLAlchemy
|
||||
class DatabaseProxy:
    """Compatibility proxy exposing a Peewee-like interface on top of SQLAlchemy."""

    def __init__(self):
        self._engine = None
        self._session = None

    def initialize(self, *args, **kwargs):
        """Initialize the database (extra args accepted for Peewee compatibility, ignored)."""
        return initialize_database_compat()

    def connect(self, reuse_if_open=True):
        """Connect to the database (compatibility method). Returns True on success."""
        try:
            self._engine = get_engine()
            return True
        except Exception as e:
            logger.error(f"数据库连接失败: {e}")
            return False

    def is_closed(self):
        """Report whether connect() has populated an engine yet (compatibility method)."""
        return self._engine is None

    def create_tables(self, models, safe=True):
        """Create tables (compatibility method).

        NOTE: ``models`` and ``safe`` are accepted for Peewee compatibility but
        ignored — every table declared on ``Base.metadata`` is created.
        """
        try:
            from src.common.database.sqlalchemy_models import Base
            engine = get_engine()
            Base.metadata.create_all(bind=engine)
            return True
        except Exception as e:
            logger.error(f"创建表失败: {e}")
            return False

    def table_exists(self, model):
        """Check whether the model's table exists (compatibility method).

        Fix: Peewee's ``model._meta`` is an options object, not a dict, so the
        previous ``getattr(model, '_meta', {}).get(...)`` raised AttributeError
        and this method always returned False for Peewee-style models.  For
        SQLAlchemy models, prefer ``__tablename__`` over the lowercased class
        name (e.g. ActionRecords -> 'action_records', not 'actionrecords').
        """
        try:
            from sqlalchemy import inspect
            engine = get_engine()
            inspector = inspect(engine)
            meta = getattr(model, '_meta', None)
            table_name = (
                getattr(meta, 'table_name', None)
                or getattr(model, '__tablename__', None)
                or model.__name__.lower()
            )
            return table_name in inspector.get_table_names()
        except Exception:
            return False

    def execute_sql(self, sql):
        """Execute raw SQL (compatibility method); re-raises on failure.

        NOTE(review): the session is closed before the result is consumed —
        callers relying on lazy row fetching should materialize rows first.
        """
        try:
            from sqlalchemy import text
            session = get_session()
            result = session.execute(text(sql))
            session.close()
            return result
        except Exception as e:
            logger.error(f"执行SQL失败: {e}")
            raise

    def atomic(self):
        """Return a transaction context manager (compatibility method)."""
        return SQLAlchemyTransaction()
|
||||
|
||||
class SQLAlchemyTransaction:
    """SQLAlchemy transaction context manager: commit on clean exit, rollback on error.

    Fix: the session is now closed in a ``finally`` block, so it is released
    even when ``commit()``/``rollback()`` themselves raise (previously the
    session leaked in that case).
    """

    def __init__(self):
        self.session = None

    def __enter__(self):
        self.session = get_session()
        return self.session

    def __exit__(self, exc_type, exc_val, exc_tb):
        try:
            if exc_type is None:
                self.session.commit()
            else:
                self.session.rollback()
        finally:
            self.session.close()
|
||||
|
||||
# 创建全局数据库代理实例
|
||||
db = DatabaseProxy()
|
||||
|
||||
def __create_database_instance():
    """Build a MongoClient from environment variables.

    Precedence: MONGODB_URI when set; otherwise host/port with optional
    username/password authentication.
    """
    uri = os.getenv("MONGODB_URI")
    host = os.getenv("MONGODB_HOST", "127.0.0.1")
    port = int(os.getenv("MONGODB_PORT", "27017"))
    # The database name is not needed to create the client; it is applied
    # later when the database handle is fetched.
    username = os.getenv("MONGODB_USERNAME")
    password = os.getenv("MONGODB_PASSWORD")
    auth_source = os.getenv("MONGODB_AUTH_SOURCE")

    if uri:
        # Accept both standard and DNS-seed-list (Atlas) connection strings.
        if not uri.startswith(("mongodb://", "mongodb+srv://")):
            raise ValueError(
                "Invalid MongoDB URI format. URI must start with 'mongodb://' or 'mongodb+srv://'. "
                "For MongoDB Atlas, use 'mongodb+srv://' format. "
                "See: https://www.mongodb.com/docs/manual/reference/connection-string/"
            )
        return MongoClient(uri)

    if username and password:
        # Credentials present: authenticated connection.
        return MongoClient(host, port, username=username, password=password, authSource=auth_source)

    # No credentials: plain connection.
    return MongoClient(host, port)
|
||||
|
||||
|
||||
def get_db():
    """Return the MongoDB database handle, creating the client lazily on first use."""
    global _client, _db
    if _client is not None:
        return _db
    _client = __create_database_instance()
    _db = _client[os.getenv("DATABASE_NAME", "MegBot")]
    return _db
|
||||
|
||||
|
||||
def initialize_sql_database(database_config):
    """
    Initialize the SQL database connection from configuration (SQLAlchemy version).

    Args:
        database_config: DatabaseConfig object.

    Returns:
        The SQLAlchemy engine on success, otherwise None.
    """
    global _sql_engine

    try:
        logger.info("使用SQLAlchemy初始化SQL数据库...")

        # Log the connection parameters for the configured backend.
        if database_config.database_type == "mysql":
            connection_info = (
                f"{database_config.mysql_user}@{database_config.mysql_host}"
                f":{database_config.mysql_port}/{database_config.mysql_database}"
            )
            logger.info("MySQL数据库连接配置:")
            logger.info(f" 连接信息: {connection_info}")
            logger.info(f" 字符集: {database_config.mysql_charset}")
        else:
            root_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", ".."))
            db_path = (
                database_config.sqlite_path
                if os.path.isabs(database_config.sqlite_path)
                else os.path.join(root_path, database_config.sqlite_path)
            )
            logger.info("SQLite数据库连接配置:")
            logger.info(f" 数据库文件: {db_path}")

        # Delegate the actual setup to the SQLAlchemy compat initializer.
        if initialize_database_compat():
            _sql_engine = get_engine()
            logger.info("SQLAlchemy数据库初始化成功")
        else:
            logger.error("SQLAlchemy数据库初始化失败")

        return _sql_engine

    except Exception as e:
        logger.error(f"初始化SQL数据库失败: {e}")
        return None
|
||||
|
||||
class DBWrapper:
    """Lazy-loading proxy that preserves the legacy module-level db interface."""

    def __getattr__(self, name):
        target = get_db()
        return getattr(target, name)

    def __getitem__(self, key):
        target = get_db()
        return target[key]  # type: ignore
|
||||
|
||||
|
||||
# 全局MongoDB数据库访问点
|
||||
memory_db: Database = DBWrapper() # type: ignore
|
||||
|
||||
420
src/common/database/sqlalchemy_database_api.py
Normal file
420
src/common/database/sqlalchemy_database_api.py
Normal file
@@ -0,0 +1,420 @@
|
||||
"""SQLAlchemy数据库API模块
|
||||
|
||||
提供基于SQLAlchemy的数据库操作,替换Peewee以解决MySQL连接问题
|
||||
支持自动重连、连接池管理和更好的错误处理
|
||||
"""
|
||||
|
||||
import traceback
|
||||
import time
|
||||
from typing import Dict, List, Any, Union, Type, Optional
|
||||
from contextlib import contextmanager
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.exc import SQLAlchemyError, DisconnectionError, OperationalError
|
||||
from sqlalchemy import desc, asc, func, and_, or_
|
||||
from src.common.logger import get_logger
|
||||
from src.common.database.sqlalchemy_models import (
|
||||
Base, get_db_session, Messages, ActionRecords, PersonInfo, ChatStreams,
|
||||
LLMUsage, Emoji, Images, ImageDescriptions, OnlineTime, Memory,
|
||||
Expression, ThinkingLog, GraphNodes, GraphEdges,get_session
|
||||
)
|
||||
|
||||
logger = get_logger("sqlalchemy_database_api")
|
||||
|
||||
# 模型映射表,用于通过名称获取模型类
|
||||
MODEL_MAPPING = {
|
||||
'Messages': Messages,
|
||||
'ActionRecords': ActionRecords,
|
||||
'PersonInfo': PersonInfo,
|
||||
'ChatStreams': ChatStreams,
|
||||
'LLMUsage': LLMUsage,
|
||||
'Emoji': Emoji,
|
||||
'Images': Images,
|
||||
'ImageDescriptions': ImageDescriptions,
|
||||
'OnlineTime': OnlineTime,
|
||||
'Memory': Memory,
|
||||
'Expression': Expression,
|
||||
'ThinkingLog': ThinkingLog,
|
||||
'GraphNodes': GraphNodes,
|
||||
'GraphEdges': GraphEdges,
|
||||
}
|
||||
|
||||
|
||||
@contextmanager
def get_db_session():
    """Database session context manager: commit on success, rollback on error.

    Transient connection errors are retried while *acquiring* the session.

    Fix: the original retried around the ``yield`` itself, but a
    @contextmanager generator may only yield once — a retry after the body
    had started raised ``RuntimeError: generator didn't stop``.  It also
    closed the session twice on the retry path.  Connection errors raised
    during the body or commit are now rolled back and re-raised instead of
    re-yielding.

    NOTE(review): this definition shadows the ``get_db_session`` imported
    from sqlalchemy_models at the top of this file — confirm which one
    callers are meant to use.
    """
    max_retries = 3
    retry_delay = 1.0

    session = None
    for attempt in range(max_retries):
        try:
            session = get_session()
            break
        except (DisconnectionError, OperationalError) as e:
            logger.warning(f"数据库连接错误 (尝试 {attempt + 1}/{max_retries}): {e}")
            if attempt < max_retries - 1:
                # Linear backoff before retrying the acquisition.
                time.sleep(retry_delay * (attempt + 1))
            else:
                raise

    try:
        yield session
        session.commit()
    except Exception:
        session.rollback()
        raise
    finally:
        session.close()
|
||||
|
||||
|
||||
def build_filters(session: "Session", model_class: "Type[Base]", filters: Dict[str, Any]):
    """Translate a MongoDB-style filter dict into SQLAlchemy filter conditions.

    Args:
        session: unused; kept for interface compatibility with existing callers.
        model_class: model class whose attributes are filtered on.
        filters: mapping of field name -> plain value (equality) or a
            ``{"$op": value}`` operator document.

    Returns:
        List of boolean expressions (callers AND-combine them).  Unknown
        fields and unknown operators are logged and skipped.
    """
    # Dispatch table replaces the previous if/elif operator ladder.
    operator_map = {
        "$gt": lambda field, v: field > v,
        "$lt": lambda field, v: field < v,
        "$gte": lambda field, v: field >= v,
        "$lte": lambda field, v: field <= v,
        "$ne": lambda field, v: field != v,
        "$in": lambda field, v: field.in_(v),
        "$nin": lambda field, v: ~field.in_(v),
    }

    conditions = []
    for field_name, value in filters.items():
        if not hasattr(model_class, field_name):
            logger.warning(f"模型 {model_class.__name__} 中不存在字段 '{field_name}'")
            continue

        field = getattr(model_class, field_name)

        if isinstance(value, dict):
            # MongoDB-style operator document.
            for op, op_value in value.items():
                handler = operator_map.get(op)
                if handler is None:
                    logger.warning(f"未知操作符 '{op}' (字段: '{field_name}')")
                else:
                    conditions.append(handler(field, op_value))
        else:
            # Plain equality comparison.
            conditions.append(field == value)

    return conditions
|
||||
|
||||
|
||||
async def db_query(
    model_class: Type[Base],
    data: Optional[Dict[str, Any]] = None,
    query_type: Optional[str] = "get",
    filters: Optional[Dict[str, Any]] = None,
    limit: Optional[int] = None,
    order_by: Optional[List[str]] = None,
    single_result: Optional[bool] = False,
) -> Union[List[Dict[str, Any]], Dict[str, Any], None]:
    """Execute a database operation.

    Args:
        model_class: SQLAlchemy model class.
        data: data dict for create/update operations.
        query_type: one of "get", "create", "update", "delete", "count".
        filters: filter-condition dict (MongoDB-style, see build_filters).
        limit: maximum number of rows for "get".
        order_by: sort fields; a '-' prefix means descending.
        single_result: for "get", return one dict instead of a list.

    Returns:
        "get": list of row dicts (or single dict / None when single_result);
        "create": dict of the new row; "update"/"delete": affected row count;
        "count": integer.  On database errors, logs and returns a
        type-appropriate default ([] / None) instead of raising.

    Raises:
        ValueError: for an unknown query_type or missing ``data``.
    """
    try:
        if query_type not in ["get", "create", "update", "delete", "count"]:
            raise ValueError("query_type must be 'get', 'create', 'update', 'delete' or 'count'")

        with get_db_session() as session:
            if query_type == "get":
                query = session.query(model_class)

                # Apply filter conditions.
                if filters:
                    conditions = build_filters(session, model_class, filters)
                    if conditions:
                        query = query.filter(and_(*conditions))

                # Apply ordering ('-' prefix selects descending order).
                if order_by:
                    for field_name in order_by:
                        if field_name.startswith("-"):
                            field_name = field_name[1:]
                            if hasattr(model_class, field_name):
                                query = query.order_by(desc(getattr(model_class, field_name)))
                        else:
                            if hasattr(model_class, field_name):
                                query = query.order_by(asc(getattr(model_class, field_name)))

                # Apply row limit.
                if limit and limit > 0:
                    query = query.limit(limit)

                # Execute the query.
                results = query.all()

                # Convert ORM rows to plain dicts (one key per column).
                result_dicts = []
                for result in results:
                    result_dict = {}
                    for column in result.__table__.columns:
                        result_dict[column.name] = getattr(result, column.name)
                    result_dicts.append(result_dict)

                if single_result:
                    return result_dicts[0] if result_dicts else None
                return result_dicts

            elif query_type == "create":
                if not data:
                    raise ValueError("创建记录需要提供data参数")

                # Create the new row.
                new_record = model_class(**data)
                session.add(new_record)
                session.flush()  # flush to obtain the auto-generated ID

                # Convert to a dict for the caller.
                result_dict = {}
                for column in new_record.__table__.columns:
                    result_dict[column.name] = getattr(new_record, column.name)
                return result_dict

            elif query_type == "update":
                if not data:
                    raise ValueError("更新记录需要提供data参数")

                query = session.query(model_class)

                # Apply filter conditions.
                if filters:
                    conditions = build_filters(session, model_class, filters)
                    if conditions:
                        query = query.filter(and_(*conditions))

                # Execute the bulk update; returns the affected row count.
                affected_rows = query.update(data)
                return affected_rows

            elif query_type == "delete":
                query = session.query(model_class)

                # Apply filter conditions.
                if filters:
                    conditions = build_filters(session, model_class, filters)
                    if conditions:
                        query = query.filter(and_(*conditions))

                # Execute the bulk delete; returns the affected row count.
                affected_rows = query.delete()
                return affected_rows

            elif query_type == "count":
                query = session.query(func.count(model_class.id))

                # With filters, count over a filtered subquery instead.
                if filters:
                    base_query = session.query(model_class)
                    conditions = build_filters(session, model_class, filters)
                    if conditions:
                        base_query = base_query.filter(and_(*conditions))
                    query = session.query(func.count()).select_from(base_query.subquery())

                return query.scalar()

    except SQLAlchemyError as e:
        logger.error(f"[SQLAlchemy] 数据库操作出错: {e}")
        traceback.print_exc()

        # Return a type-appropriate default for the failed operation.
        if query_type == "get":
            return None if single_result else []
        elif query_type in ["create", "update", "delete", "count"]:
            return None
        return None

    except Exception as e:
        logger.error(f"[SQLAlchemy] 意外错误: {e}")
        traceback.print_exc()

        if query_type == "get":
            return None if single_result else []
        return None
|
||||
|
||||
|
||||
async def db_save(
    model_class: Type[Base],
    data: Dict[str, Any],
    key_field: Optional[str] = None,
    key_value: Optional[Any] = None
) -> Optional[Dict[str, Any]]:
    """Save data to the database (create-or-update / upsert).

    Args:
        model_class: SQLAlchemy model class.
        data: field values to save.
        key_field: field name used to look up an existing row.
        key_value: value of that field for the lookup.

    Returns:
        Dict of the saved row's columns, or None on error.
    """
    try:
        with get_db_session() as session:
            # When a lookup key is supplied, try to update an existing row first.
            if key_field and key_value is not None:
                if hasattr(model_class, key_field):
                    existing_record = session.query(model_class).filter(
                        getattr(model_class, key_field) == key_value
                    ).first()

                    if existing_record:
                        # Update the existing row in place; unknown fields
                        # in `data` are silently skipped.
                        for field, value in data.items():
                            if hasattr(existing_record, field):
                                setattr(existing_record, field, value)

                        session.flush()

                        # Convert the updated row to a dict for the caller.
                        result_dict = {}
                        for column in existing_record.__table__.columns:
                            result_dict[column.name] = getattr(existing_record, column.name)
                        return result_dict

            # No matching row (or no key given): create a new record.
            new_record = model_class(**data)
            session.add(new_record)
            session.flush()

            # Convert the new row to a dict for the caller.
            result_dict = {}
            for column in new_record.__table__.columns:
                result_dict[column.name] = getattr(new_record, column.name)
            return result_dict

    except SQLAlchemyError as e:
        logger.error(f"[SQLAlchemy] 保存数据库记录出错: {e}")
        traceback.print_exc()
        return None
    except Exception as e:
        logger.error(f"[SQLAlchemy] 保存时意外错误: {e}")
        traceback.print_exc()
        return None
|
||||
|
||||
|
||||
async def db_get(
    model_class: Type[Base],
    filters: Optional[Dict[str, Any]] = None,
    limit: Optional[int] = None,
    order_by: Optional[str] = None,
    single_result: Optional[bool] = False,
) -> Union[List[Dict[str, Any]], Dict[str, Any], None]:
    """Fetch records — thin wrapper around db_query(query_type="get").

    Args:
        model_class: SQLAlchemy model class.
        filters: filter-condition dict.
        limit: maximum number of rows.
        order_by: sort field; a '-' prefix means descending.
        single_result: return a single dict instead of a list.

    Returns:
        Record data or None.
    """
    return await db_query(
        model_class=model_class,
        query_type="get",
        filters=filters,
        limit=limit,
        order_by=[order_by] if order_by else None,
        single_result=single_result,
    )
|
||||
|
||||
|
||||
async def store_action_info(
    chat_stream=None,
    action_build_into_prompt: bool = False,
    action_prompt_display: str = "",
    action_done: bool = True,
    thinking_id: str = "",
    action_data: Optional[dict] = None,
    action_name: str = "",
) -> Optional[Dict[str, Any]]:
    """Persist one action record.

    Args:
        chat_stream: chat stream object (stream_id/platform are read from it).
        action_build_into_prompt: whether the action is built into the prompt.
        action_prompt_display: display text for the action prompt.
        action_done: whether the action has completed.
        thinking_id: associated thinking ID (also used as the record key).
        action_data: action payload dict (JSON-serialized for storage).
        action_name: name of the action.

    Returns:
        The saved record dict, or None on failure.
    """
    try:
        import json

        # Fall back to a microsecond-timestamp ID when no thinking_id is given.
        action_id = thinking_id or str(int(time.time() * 1000000))

        # Chat fields default to "" when no chat_stream is supplied.
        stream_id = getattr(chat_stream, "stream_id", "") if chat_stream else ""
        platform = getattr(chat_stream, "platform", "") if chat_stream else ""

        record_data = {
            "action_id": action_id,
            "time": time.time(),
            "action_name": action_name,
            "action_data": json.dumps(action_data or {}, ensure_ascii=False),
            "action_done": action_done,
            "action_build_into_prompt": action_build_into_prompt,
            "action_prompt_display": action_prompt_display,
            "chat_id": stream_id,
            "chat_info_stream_id": stream_id,
            "chat_info_platform": platform,
        }

        # Upsert keyed on action_id.
        saved_record = await db_save(
            ActionRecords,
            data=record_data,
            key_field="action_id",
            key_value=record_data["action_id"],
        )

        if saved_record:
            logger.debug(f"[SQLAlchemy] 成功存储动作信息: {action_name} (ID: {record_data['action_id']})")
        else:
            logger.error(f"[SQLAlchemy] 存储动作信息失败: {action_name}")

        return saved_record

    except Exception as e:
        logger.error(f"[SQLAlchemy] 存储动作信息时发生错误: {e}")
        traceback.print_exc()
        return None
|
||||
|
||||
|
||||
# 兼容性函数,方便从Peewee迁移
|
||||
def get_model_class(model_name: str) -> Optional[Type[Base]]:
    """Look up a model class by its registered name; None when unknown (Peewee-migration helper)."""
    try:
        return MODEL_MAPPING[model_name]
    except KeyError:
        return None
|
||||
158
src/common/database/sqlalchemy_init.py
Normal file
158
src/common/database/sqlalchemy_init.py
Normal file
@@ -0,0 +1,158 @@
|
||||
"""SQLAlchemy数据库初始化模块
|
||||
|
||||
替换Peewee的数据库初始化逻辑
|
||||
提供统一的数据库初始化接口
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
from src.common.logger import get_logger
|
||||
from src.common.database.sqlalchemy_models import (
|
||||
Base, get_engine, get_session, initialize_database
|
||||
)
|
||||
|
||||
logger = get_logger("sqlalchemy_init")
|
||||
|
||||
|
||||
def initialize_sqlalchemy_database() -> bool:
    """
    Initialize the SQLAlchemy database (engine + session factory).

    Returns:
        bool: True when initialization succeeded.
    """
    try:
        logger.info("开始初始化SQLAlchemy数据库...")

        # Only the engine is checked here; the session factory returned by
        # initialize_database() was previously bound to an unused local.
        engine, _ = initialize_database()

        if engine is None:
            logger.error("数据库引擎初始化失败")
            return False

        logger.info("SQLAlchemy数据库初始化成功")
        return True

    except SQLAlchemyError as e:
        logger.error(f"SQLAlchemy数据库初始化失败: {e}")
        return False
    except Exception as e:
        logger.error(f"数据库初始化过程中发生未知错误: {e}")
        return False
|
||||
|
||||
|
||||
def create_all_tables() -> bool:
    """
    Create every table declared on Base.metadata.

    Returns:
        bool: True when creation succeeded.
    """
    try:
        logger.info("开始创建数据库表...")

        engine = get_engine()
        if engine is not None:
            # Issue CREATE TABLE for all declared models.
            Base.metadata.create_all(bind=engine)
            logger.info("数据库表创建成功")
            return True

        logger.error("无法获取数据库引擎")
        return False

    except SQLAlchemyError as e:
        logger.error(f"创建数据库表失败: {e}")
        return False
    except Exception as e:
        logger.error(f"创建数据库表过程中发生未知错误: {e}")
        return False
|
||||
|
||||
|
||||
def check_database_connection() -> bool:
    """
    Check that a database session can be obtained.

    Fix: removed a second, byte-identical ``if session is None`` check that
    was dead code (the first check already returned False).

    NOTE(review): this only verifies that a session object is returned; no
    query is executed, so a dead server might still pass — confirm whether a
    real round-trip (e.g. SELECT 1) is wanted here.

    Returns:
        bool: True when a session was obtained and closed cleanly.
    """
    try:
        session = get_session()
        if session is None:
            logger.error("无法获取数据库会话")
            return False

        session.close()

        logger.info("数据库连接检查通过")
        return True

    except SQLAlchemyError as e:
        logger.error(f"数据库连接检查失败: {e}")
        return False
    except Exception as e:
        logger.error(f"数据库连接检查过程中发生未知错误: {e}")
        return False
|
||||
|
||||
|
||||
def get_database_info() -> Optional[dict]:
    """
    Collect basic information about the active database engine.

    Fix: the previous password masking did
    ``str(url).replace(url.password or '', '***')`` — when the password was
    None/empty, ``replace('', '***')`` inserted ``***`` between every
    character of the URL.  The password is now only masked when non-empty.

    Returns:
        dict with engine name, driver, masked URL and pool settings, or None
        when no engine is available / on error.
    """
    try:
        engine = get_engine()
        if engine is None:
            return None

        url_text = str(engine.url)
        password = engine.url.password
        if password:
            # Hide the password in the rendered URL.
            url_text = url_text.replace(password, '***')

        info = {
            'engine_name': engine.name,
            'driver': engine.driver,
            'url': url_text,
            'pool_size': getattr(engine.pool, 'size', None),
            'max_overflow': getattr(engine.pool, 'max_overflow', None),
        }

        return info

    except Exception as e:
        logger.error(f"获取数据库信息失败: {e}")
        return None
|
||||
|
||||
|
||||
_database_initialized = False
|
||||
|
||||
def initialize_database_compat() -> bool:
    """
    Compatibility database initializer (replacement for the Peewee setup).

    Runs engine init, table creation and a connection check in order,
    stopping at the first failure; subsequent calls short-circuit once
    initialization has succeeded.

    Returns:
        bool: True when all steps succeeded.
    """
    global _database_initialized

    if _database_initialized:
        return True

    # Each step runs only if the previous one succeeded (all() short-circuits).
    steps = (
        initialize_sqlalchemy_database,
        create_all_tables,
        check_database_connection,
    )
    success = all(step() for step in steps)

    if success:
        _database_initialized = True

    return success
|
||||
555
src/common/database/sqlalchemy_models.py
Normal file
555
src/common/database/sqlalchemy_models.py
Normal file
@@ -0,0 +1,555 @@
|
||||
"""SQLAlchemy数据库模型定义
|
||||
|
||||
替换Peewee ORM,使用SQLAlchemy提供更好的连接池管理和错误恢复能力
|
||||
"""
|
||||
|
||||
from sqlalchemy import Column, String, Float, Integer, Boolean, Text, Index, create_engine, DateTime
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from sqlalchemy.pool import QueuePool
|
||||
import os
|
||||
import datetime
|
||||
from src.config.config import global_config
|
||||
from src.common.logger import get_logger
|
||||
import threading
|
||||
from contextlib import contextmanager
|
||||
|
||||
logger = get_logger("sqlalchemy_models")
|
||||
|
||||
# 创建基类
|
||||
Base = declarative_base()
|
||||
|
||||
# MySQL兼容的字段类型辅助函数
|
||||
def get_string_field(max_length=255, **kwargs):
    """
    Return a string column type suited to the configured backend:
    MySQL needs a sized VARCHAR for indexable columns, while SQLite (and
    others) can simply use TEXT.
    """
    backend = global_config.database.database_type
    return String(max_length, **kwargs) if backend == "mysql" else Text(**kwargs)
|
||||
|
||||
class SessionProxy:
    """Thread-safe session proxy that manages session lifecycle per thread.

    Each thread gets its own lazily-created session (via ``threading.local``);
    attribute access is forwarded to that session, with wrappers around a few
    key methods to recover from broken sessions.

    Fixes: bare ``except:`` clauses (which also swallowed KeyboardInterrupt/
    SystemExit) narrowed to ``except Exception``; removed a dead
    ``if name == 'commit': pass`` branch; dropped unused exception bindings.
    """

    def __init__(self):
        self._local = threading.local()

    def _get_current_session(self):
        """Return this thread's session, creating one if needed."""
        if not hasattr(self._local, 'session') or self._local.session is None:
            _, SessionLocal = initialize_database()
            self._local.session = SessionLocal()
        return self._local.session

    def _close_current_session(self):
        """Close and discard this thread's session (best-effort)."""
        if hasattr(self._local, 'session') and self._local.session is not None:
            try:
                self._local.session.close()
            except Exception:
                pass
            finally:
                self._local.session = None

    def __getattr__(self, name):
        """Forward attribute access to the current thread's session."""
        session = self._get_current_session()
        attr = getattr(session, name)

        # Wrap a few key methods to recover from session failures.
        if callable(attr):
            if name in ['commit', 'rollback']:
                def wrapper(*args, **kwargs):
                    try:
                        # The session stays alive after a successful commit.
                        return attr(*args, **kwargs)
                    except Exception:
                        try:
                            if session and hasattr(session, 'rollback'):
                                session.rollback()
                        except Exception:
                            pass
                        # On failure, discard the session so the next use
                        # starts from a fresh one.
                        self._close_current_session()
                        raise
                return wrapper
            elif name == 'close':
                def wrapper(*args, **kwargs):
                    result = attr(*args, **kwargs)
                    self._close_current_session()
                    return result
                return wrapper
            elif name in ['execute', 'query', 'add', 'delete', 'merge']:
                def wrapper(*args, **kwargs):
                    try:
                        return attr(*args, **kwargs)
                    except Exception as e:
                        # On session-binding/connection problems, rebuild the
                        # session and retry the call once.
                        if "not bound to a Session" in str(e) or "provisioning a new connection" in str(e):
                            logger.warning(f"Session问题,重新创建session: {e}")
                            self._close_current_session()
                            new_session = self._get_current_session()
                            new_attr = getattr(new_session, name)
                            return new_attr(*args, **kwargs)
                        raise
                return wrapper

        return attr

    def new_session(self):
        """Force-create a new session (close the current one first)."""
        self._close_current_session()
        return self._get_current_session()

    def ensure_fresh_session(self):
        """Return a usable session, recreating it if the current one is broken."""
        if hasattr(self._local, 'session') and self._local.session is not None:
            try:
                # Probe the session with a trivial statement.
                # NOTE(review): a raw string here fails on SQLAlchemy 2.x
                # (requires sqlalchemy.text("SELECT 1")), which would make
                # every probe "fail" and recreate the session — confirm the
                # SQLAlchemy version in use.
                self._local.session.execute("SELECT 1")
            except Exception:
                # Session is unusable: rebuild it.
                self._close_current_session()
        return self._get_current_session()
|
||||
|
||||
# 创建全局session代理实例
|
||||
_global_session_proxy = SessionProxy()
|
||||
|
||||
def get_session():
    """Return the global thread-safe session proxy (per-thread lifecycle is managed by the proxy)."""
    return _global_session_proxy
|
||||
|
||||
|
||||
class ChatStreams(Base):
    """Chat stream model — one row per conversation stream (user/group on a platform)."""
    __tablename__ = 'chat_streams'

    id = Column(Integer, primary_key=True, autoincrement=True)
    # Unique external identifier for the stream.
    stream_id = Column(get_string_field(64), nullable=False, unique=True, index=True)
    create_time = Column(Float, nullable=False)
    group_platform = Column(Text, nullable=True)
    group_id = Column(get_string_field(100), nullable=True, index=True)
    group_name = Column(Text, nullable=True)
    last_active_time = Column(Float, nullable=False)
    platform = Column(Text, nullable=False)
    user_platform = Column(Text, nullable=False)
    user_id = Column(get_string_field(100), nullable=False, index=True)
    user_nickname = Column(Text, nullable=False)
    user_cardname = Column(Text, nullable=True)

    __table_args__ = (
        Index('idx_chatstreams_stream_id', 'stream_id'),
        Index('idx_chatstreams_user_id', 'user_id'),
        Index('idx_chatstreams_group_id', 'group_id'),
    )
|
||||
|
||||
|
||||
class LLMUsage(Base):
    """LLM usage record model — token counts and cost per model request."""
    __tablename__ = 'llm_usage'

    id = Column(Integer, primary_key=True, autoincrement=True)
    model_name = Column(get_string_field(100), nullable=False, index=True)
    user_id = Column(get_string_field(50), nullable=False, index=True)
    request_type = Column(get_string_field(50), nullable=False, index=True)
    endpoint = Column(Text, nullable=False)
    prompt_tokens = Column(Integer, nullable=False)
    completion_tokens = Column(Integer, nullable=False)
    total_tokens = Column(Integer, nullable=False)
    cost = Column(Float, nullable=False)
    status = Column(Text, nullable=False)
    timestamp = Column(DateTime, nullable=False, index=True, default=datetime.datetime.now)

    __table_args__ = (
        Index('idx_llmusage_model_name', 'model_name'),
        Index('idx_llmusage_user_id', 'user_id'),
        Index('idx_llmusage_request_type', 'request_type'),
        Index('idx_llmusage_timestamp', 'timestamp'),
    )
|
||||
|
||||
|
||||
class Emoji(Base):
    """Emoji/sticker model — registered images with usage statistics."""
    __tablename__ = 'emoji'

    id = Column(Integer, primary_key=True, autoincrement=True)
    full_path = Column(get_string_field(500), nullable=False, unique=True, index=True)
    format = Column(Text, nullable=False)
    # Content hash used for de-duplication lookups.
    emoji_hash = Column(get_string_field(64), nullable=False, index=True)
    description = Column(Text, nullable=False)
    query_count = Column(Integer, nullable=False, default=0)
    is_registered = Column(Boolean, nullable=False, default=False)
    is_banned = Column(Boolean, nullable=False, default=False)
    emotion = Column(Text, nullable=True)
    record_time = Column(Float, nullable=False)
    register_time = Column(Float, nullable=True)
    usage_count = Column(Integer, nullable=False, default=0)
    last_used_time = Column(Float, nullable=True)

    __table_args__ = (
        Index('idx_emoji_full_path', 'full_path'),
        Index('idx_emoji_hash', 'emoji_hash'),
    )
|
||||
|
||||
|
||||
class Messages(Base):
    """Message model — one row per chat message, with chat/user info flattened in."""
    __tablename__ = 'messages'

    id = Column(Integer, primary_key=True, autoincrement=True)
    message_id = Column(get_string_field(100), nullable=False, index=True)
    time = Column(Float, nullable=False)
    chat_id = Column(get_string_field(64), nullable=False, index=True)
    reply_to = Column(Text, nullable=True)
    interest_value = Column(Float, nullable=True)
    is_mentioned = Column(Boolean, nullable=True)

    # Fields flattened from chat_info.
    chat_info_stream_id = Column(Text, nullable=False)
    chat_info_platform = Column(Text, nullable=False)
    chat_info_user_platform = Column(Text, nullable=False)
    chat_info_user_id = Column(Text, nullable=False)
    chat_info_user_nickname = Column(Text, nullable=False)
    chat_info_user_cardname = Column(Text, nullable=True)
    chat_info_group_platform = Column(Text, nullable=True)
    chat_info_group_id = Column(Text, nullable=True)
    chat_info_group_name = Column(Text, nullable=True)
    chat_info_create_time = Column(Float, nullable=False)
    chat_info_last_active_time = Column(Float, nullable=False)

    # Fields flattened from the top-level user_info.
    user_platform = Column(Text, nullable=True)
    user_id = Column(get_string_field(100), nullable=True, index=True)
    user_nickname = Column(Text, nullable=True)
    user_cardname = Column(Text, nullable=True)

    processed_plain_text = Column(Text, nullable=True)
    display_message = Column(Text, nullable=True)
    memorized_times = Column(Integer, nullable=False, default=0)
    priority_mode = Column(Text, nullable=True)
    priority_info = Column(Text, nullable=True)
    additional_config = Column(Text, nullable=True)
    is_emoji = Column(Boolean, nullable=False, default=False)
    is_picid = Column(Boolean, nullable=False, default=False)
    is_command = Column(Boolean, nullable=False, default=False)
    is_notify = Column(Boolean, nullable=False, default=False)

    __table_args__ = (
        Index('idx_messages_message_id', 'message_id'),
        Index('idx_messages_chat_id', 'chat_id'),
        Index('idx_messages_time', 'time'),
        Index('idx_messages_user_id', 'user_id'),
    )
|
||||
|
||||
|
||||
class ActionRecords(Base):
    """ORM model recording actions the bot performed in a chat."""

    __tablename__ = 'action_records'

    id = Column(Integer, primary_key=True, autoincrement=True)
    # Identity and timing of the action.
    action_id = Column(get_string_field(100), nullable=False, index=True)
    time = Column(Float, nullable=False)
    action_name = Column(Text, nullable=False)
    action_data = Column(Text, nullable=False)
    # Lifecycle flags and the text shown when building prompts.
    action_done = Column(Boolean, nullable=False, default=False)
    action_build_into_prompt = Column(Boolean, nullable=False, default=False)
    action_prompt_display = Column(Text, nullable=False)
    # Chat context the action belongs to.
    chat_id = Column(get_string_field(64), nullable=False, index=True)
    chat_info_stream_id = Column(Text, nullable=False)
    chat_info_platform = Column(Text, nullable=False)

    __table_args__ = (
        Index('idx_actionrecords_action_id', 'action_id'),
        Index('idx_actionrecords_chat_id', 'chat_id'),
        Index('idx_actionrecords_time', 'time'),
    )
|
||||
|
||||
|
||||
class Images(Base):
    """ORM model for images the bot has encountered."""

    __tablename__ = 'images'

    id = Column(Integer, primary_key=True, autoincrement=True)
    image_id = Column(Text, nullable=False, default="")
    emoji_hash = Column(get_string_field(64), nullable=False, index=True)
    description = Column(Text, nullable=True)
    path = Column(get_string_field(500), nullable=False, unique=True)  # unique on-disk path
    count = Column(Integer, nullable=False, default=1)  # presumably an occurrence counter — confirm
    timestamp = Column(Float, nullable=False)
    type = Column(Text, nullable=False)
    vlm_processed = Column(Boolean, nullable=False, default=False)  # whether a VLM pass ran

    __table_args__ = (
        Index('idx_images_emoji_hash', 'emoji_hash'),
        Index('idx_images_path', 'path'),
    )
|
||||
|
||||
|
||||
class ImageDescriptions(Base):
    """ORM model caching generated image descriptions, keyed by hash."""

    __tablename__ = 'image_descriptions'

    id = Column(Integer, primary_key=True, autoincrement=True)
    type = Column(Text, nullable=False)
    image_description_hash = Column(get_string_field(64), nullable=False, index=True)
    description = Column(Text, nullable=False)
    timestamp = Column(Float, nullable=False)

    __table_args__ = (
        Index('idx_imagedesc_hash', 'image_description_hash'),
    )
|
||||
|
||||
|
||||
class OnlineTime(Base):
    """ORM model recording spans of online time."""

    __tablename__ = 'online_time'

    id = Column(Integer, primary_key=True, autoincrement=True)
    # BUGFIX: the original used default=str(datetime.datetime.now), which
    # evaluates once at class-definition time and stores the *function
    # object's* repr (e.g. "<built-in method now ...>") in every row.
    # A callable default produces a fresh timestamp string per insert.
    timestamp = Column(Text, nullable=False, default=lambda: str(datetime.datetime.now()))
    duration = Column(Integer, nullable=False)  # length of the span; unit defined by callers — confirm
    start_timestamp = Column(DateTime, nullable=False, default=datetime.datetime.now)
    end_timestamp = Column(DateTime, nullable=False, index=True)

    __table_args__ = (
        Index('idx_onlinetime_end_timestamp', 'end_timestamp'),
    )
|
||||
|
||||
|
||||
class PersonInfo(Base):
    """ORM model holding per-person profile and impression data."""

    __tablename__ = 'person_info'

    id = Column(Integer, primary_key=True, autoincrement=True)
    # Identity fields.
    person_id = Column(get_string_field(100), nullable=False, unique=True, index=True)
    person_name = Column(Text, nullable=True)
    name_reason = Column(Text, nullable=True)
    platform = Column(Text, nullable=False)
    user_id = Column(get_string_field(50), nullable=False, index=True)
    nickname = Column(Text, nullable=True)
    # Impression payloads (serialized text blobs).
    impression = Column(Text, nullable=True)
    short_impression = Column(Text, nullable=True)
    points = Column(Text, nullable=True)
    forgotten_points = Column(Text, nullable=True)
    info_list = Column(Text, nullable=True)
    # Acquaintance statistics.
    know_times = Column(Float, nullable=True)
    know_since = Column(Float, nullable=True)
    last_know = Column(Float, nullable=True)
    attitude = Column(Integer, nullable=True, default=50)  # defaults to a neutral 50; range not shown here — confirm

    __table_args__ = (
        Index('idx_personinfo_person_id', 'person_id'),
        Index('idx_personinfo_user_id', 'user_id'),
    )
|
||||
|
||||
|
||||
class Memory(Base):
    """ORM model for a single stored memory item."""

    __tablename__ = 'memory'

    id = Column(Integer, primary_key=True, autoincrement=True)
    memory_id = Column(get_string_field(64), nullable=False, index=True)
    chat_id = Column(Text, nullable=True)
    memory_text = Column(Text, nullable=True)
    keywords = Column(Text, nullable=True)
    create_time = Column(Float, nullable=True)
    last_view_time = Column(Float, nullable=True)

    __table_args__ = (
        Index('idx_memory_memory_id', 'memory_id'),
    )
|
||||
|
||||
|
||||
class Expression(Base):
    """ORM model describing a learned expression style for a chat."""

    __tablename__ = 'expression'

    id = Column(Integer, primary_key=True, autoincrement=True)
    situation = Column(Text, nullable=False)
    style = Column(Text, nullable=False)
    count = Column(Float, nullable=False)
    last_active_time = Column(Float, nullable=False)
    chat_id = Column(get_string_field(64), nullable=False, index=True)
    type = Column(Text, nullable=False)
    create_date = Column(Float, nullable=True)

    __table_args__ = (
        Index('idx_expression_chat_id', 'chat_id'),
    )
|
||||
|
||||
|
||||
class ThinkingLog(Base):
    """ORM model capturing one thinking/response cycle for debugging."""

    __tablename__ = 'thinking_logs'

    id = Column(Integer, primary_key=True, autoincrement=True)
    chat_id = Column(get_string_field(64), nullable=False, index=True)
    trigger_text = Column(Text, nullable=True)
    response_text = Column(Text, nullable=True)
    # Serialized payloads (the *_json suffix suggests JSON text — confirm).
    trigger_info_json = Column(Text, nullable=True)
    response_info_json = Column(Text, nullable=True)
    timing_results_json = Column(Text, nullable=True)
    chat_history_json = Column(Text, nullable=True)
    chat_history_in_thinking_json = Column(Text, nullable=True)
    chat_history_after_response_json = Column(Text, nullable=True)
    heartflow_data_json = Column(Text, nullable=True)
    reasoning_data_json = Column(Text, nullable=True)
    created_at = Column(DateTime, nullable=False, default=datetime.datetime.now)

    __table_args__ = (
        Index('idx_thinkinglog_chat_id', 'chat_id'),
    )
|
||||
|
||||
|
||||
class GraphNodes(Base):
    """ORM model for a node in the memory graph."""

    __tablename__ = 'graph_nodes'

    id = Column(Integer, primary_key=True, autoincrement=True)
    concept = Column(get_string_field(255), nullable=False, unique=True, index=True)
    memory_items = Column(Text, nullable=False)
    hash = Column(Text, nullable=False)
    created_time = Column(Float, nullable=False)
    last_modified = Column(Float, nullable=False)

    __table_args__ = (
        Index('idx_graphnodes_concept', 'concept'),
    )
|
||||
|
||||
|
||||
class GraphEdges(Base):
    """ORM model for a directed edge in the memory graph."""

    __tablename__ = 'graph_edges'

    id = Column(Integer, primary_key=True, autoincrement=True)
    source = Column(get_string_field(255), nullable=False, index=True)
    target = Column(get_string_field(255), nullable=False, index=True)
    strength = Column(Integer, nullable=False)
    hash = Column(Text, nullable=False)
    created_time = Column(Float, nullable=False)
    last_modified = Column(Float, nullable=False)

    __table_args__ = (
        Index('idx_graphedges_source', 'source'),
        Index('idx_graphedges_target', 'target'),
    )
|
||||
|
||||
|
||||
# Database engine and session management.
# Module-level singletons, created lazily by initialize_database().
_engine = None
_SessionLocal = None
|
||||
|
||||
|
||||
def get_database_url():
    """Build the SQLAlchemy connection URL from ``global_config.database``.

    Returns a MySQL URL (with URL-encoded credentials) when
    ``database_type == "mysql"``; otherwise a SQLite URL whose parent
    directory is created on demand.
    """
    config = global_config.database

    if config.database_type == "mysql":
        # URL-encode credentials so special characters survive the DSN.
        from urllib.parse import quote_plus
        encoded_user = quote_plus(config.mysql_user)
        encoded_password = quote_plus(config.mysql_password)

        return (
            f"mysql+pymysql://{encoded_user}:{encoded_password}"
            f"@{config.mysql_host}:{config.mysql_port}/{config.mysql_database}"
            f"?charset={config.mysql_charset}"
        )
    else:  # SQLite
        # Resolve a relative sqlite_path against the project root.
        if not os.path.isabs(config.sqlite_path):
            root_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", ".."))
            db_path = os.path.join(root_path, config.sqlite_path)
        else:
            db_path = config.sqlite_path

        # Ensure the database directory exists.
        # BUGFIX: guard against an empty dirname (a bare filename),
        # where os.makedirs("") raises FileNotFoundError.
        db_dir = os.path.dirname(db_path)
        if db_dir:
            os.makedirs(db_dir, exist_ok=True)

        return f"sqlite:///{db_path}"
|
||||
|
||||
|
||||
def initialize_database():
    """Create (once) and return the SQLAlchemy engine and session factory.

    Idempotent: subsequent calls return the engine/factory created by the
    first call. Also issues ``create_all`` so every declared table exists.
    """
    global _engine, _SessionLocal

    # Fast path: already initialized.
    if _engine is not None:
        return _engine, _SessionLocal

    database_url = get_database_url()
    config = global_config.database

    # Common engine options.
    engine_kwargs = {
        'echo': False,   # keep SQL logging off in production
        'future': True,
    }

    if config.database_type == "mysql":
        # MySQL connection-pool configuration.
        engine_kwargs.update({
            'poolclass': QueuePool,
            'pool_size': config.connection_pool_size,
            'max_overflow': config.connection_pool_size * 2,
            'pool_timeout': config.connection_timeout,
            'pool_recycle': 3600,     # recycle connections hourly
            'pool_pre_ping': True,    # validate connections before use
            'connect_args': {
                'autocommit': config.mysql_autocommit,
                'charset': config.mysql_charset,
                'connect_timeout': config.connection_timeout,
                'read_timeout': 30,
                'write_timeout': 30,
            },
        })
    else:
        # SQLite: give it a real pool too, to avoid connection exhaustion.
        engine_kwargs.update({
            'poolclass': QueuePool,
            'pool_size': 20,
            'max_overflow': 30,
            'pool_timeout': 60,
            'pool_recycle': 3600,
            'pool_pre_ping': True,
            'connect_args': {
                'check_same_thread': False,  # sessions may hop threads
                'timeout': 30,
            },
        })

    _engine = create_engine(database_url, **engine_kwargs)
    _SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=_engine)

    # Create all declared tables if they do not exist yet.
    Base.metadata.create_all(bind=_engine)

    logger.info(f"SQLAlchemy数据库初始化成功: {config.database_type}")
    return _engine, _SessionLocal
|
||||
|
||||
|
||||
@contextmanager
def get_db_session():
    """Context manager yielding a database session.

    Commits on success, rolls back on any exception (then re-raises), and
    always closes the session. Prefer this over ``get_session()``.
    """
    session = None
    try:
        _, SessionLocal = initialize_database()
        session = SessionLocal()
        yield session
        session.commit()
    except Exception:
        # IDIOM: the bound exception name was unused; roll back whatever
        # the failed unit of work did, then re-raise unchanged.
        if session is not None:
            session.rollback()
        raise
    finally:
        if session is not None:
            session.close()
|
||||
|
||||
|
||||
def get_engine():
    """Return the shared SQLAlchemy engine, initializing it on first use."""
    return initialize_database()[0]
|
||||
808
src/common/logger.py
Normal file
808
src/common/logger.py
Normal file
@@ -0,0 +1,808 @@
|
||||
# 使用基于时间戳的文件处理器,简单的轮转份数限制
|
||||
|
||||
import logging
|
||||
import json
|
||||
import threading
|
||||
import time
|
||||
import structlog
|
||||
import tomlkit
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Callable, Optional
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
# 创建logs目录
|
||||
LOG_DIR = Path("logs")
|
||||
LOG_DIR.mkdir(exist_ok=True)
|
||||
|
||||
# 全局handler实例,避免重复创建
|
||||
_file_handler = None
|
||||
_console_handler = None
|
||||
|
||||
|
||||
def get_file_handler():
    """Return the process-wide file handler singleton, creating it lazily."""
    global _file_handler
    if _file_handler is not None:
        return _file_handler

    # Make sure the log directory exists.
    LOG_DIR.mkdir(exist_ok=True)

    # Reuse a timestamped handler already attached to the root logger
    # instead of creating a duplicate.
    for existing in logging.getLogger().handlers:
        if isinstance(existing, TimestampedFileHandler):
            _file_handler = existing
            return _file_handler

    # Timestamp-named files with a simple rotation count limit.
    _file_handler = TimestampedFileHandler(
        log_dir=LOG_DIR,
        max_bytes=5 * 1024 * 1024,  # 5MB
        backup_count=30,
        encoding="utf-8",
    )
    # Level for file output (falls back to the global level).
    level_name = LOG_CONFIG.get("file_log_level", LOG_CONFIG.get("log_level", "INFO"))
    _file_handler.setLevel(getattr(logging, level_name.upper(), logging.INFO))
    return _file_handler
|
||||
|
||||
|
||||
def get_console_handler():
    """Return the process-wide console (stream) handler singleton."""
    global _console_handler
    if _console_handler is None:
        handler = logging.StreamHandler()
        # Level for console output (falls back to the global level).
        level_name = LOG_CONFIG.get("console_log_level", LOG_CONFIG.get("log_level", "INFO"))
        handler.setLevel(getattr(logging, level_name.upper(), logging.INFO))
        _console_handler = handler
    return _console_handler
|
||||
|
||||
|
||||
class TimestampedFileHandler(logging.Handler):
    """Handler writing to timestamp-named files with simple count-based rotation.

    When the active file reaches ``max_bytes`` a new timestamp-named file is
    opened and only the newest ``backup_count`` files are kept on disk.
    """

    def __init__(self, log_dir, max_bytes=5 * 1024 * 1024, backup_count=30, encoding="utf-8"):
        super().__init__()
        self.log_dir = Path(log_dir)
        self.log_dir.mkdir(exist_ok=True)
        self.max_bytes = max_bytes
        self.backup_count = backup_count
        self.encoding = encoding
        self._lock = threading.Lock()  # guards rollover and writes

        # Currently active log file and its open stream.
        self.current_file = None
        self.current_stream = None
        self._init_current_file()

    def _init_current_file(self):
        """Open a fresh, timestamp-named log file as the active target."""
        stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        self.current_file = self.log_dir / f"app_{stamp}.log.jsonl"
        self.current_stream = open(self.current_file, "a", encoding=self.encoding)

    def _should_rollover(self):
        """Return True when the active file has reached the size limit."""
        current = self.current_file
        return bool(current and current.exists() and current.stat().st_size >= self.max_bytes)

    def _do_rollover(self):
        """Close the active file, prune old files, then open a new one."""
        if self.current_stream:
            self.current_stream.close()
        self._cleanup_old_files()
        self._init_current_file()

    def _cleanup_old_files(self):
        """Delete the oldest log files beyond ``backup_count``."""
        try:
            # Newest first by mtime; everything past backup_count goes.
            candidates = sorted(
                self.log_dir.glob("app_*.log.jsonl"),
                key=lambda f: f.stat().st_mtime,
                reverse=True,
            )
            for stale in candidates[self.backup_count:]:
                try:
                    stale.unlink()
                    print(f"[日志清理] 删除旧文件: {stale.name}")
                except Exception as e:
                    print(f"[日志清理] 删除失败 {stale}: {e}")
        except Exception as e:
            print(f"[日志清理] 清理过程出错: {e}")

    def emit(self, record):
        """Format and write one record, rolling over first if needed."""
        try:
            with self._lock:
                if self._should_rollover():
                    self._do_rollover()
                if self.current_stream:
                    self.current_stream.write(self.format(record) + "\n")
                    self.current_stream.flush()
        except Exception:
            self.handleError(record)

    def close(self):
        """Close the active stream, then the handler itself."""
        with self._lock:
            if self.current_stream:
                self.current_stream.close()
                self.current_stream = None
        super().close()
|
||||
|
||||
|
||||
# 旧的轮转文件处理器已移除,现在使用基于时间戳的处理器
|
||||
|
||||
|
||||
def close_handlers():
    """Close and forget both singleton handlers (file and console)."""
    global _file_handler, _console_handler

    if _file_handler is not None:
        _file_handler.close()
        _file_handler = None

    if _console_handler is not None:
        _console_handler.close()
        _console_handler = None
|
||||
|
||||
|
||||
def remove_duplicate_handlers():
    """Keep a single timestamped file handler on the root logger.

    If several TimestampedFileHandler instances ended up attached, the
    first is kept and the rest are detached and closed.
    """
    root_logger = logging.getLogger()

    file_handlers = [
        h for h in root_logger.handlers if isinstance(h, TimestampedFileHandler)
    ]

    if len(file_handlers) > 1:
        print(f"[日志系统] 检测到 {len(file_handlers)} 个重复的文件handler,正在清理...")
        for i, extra in enumerate(file_handlers[1:], 1):
            print(f"[日志系统] 关闭重复的文件handler {i}")
            root_logger.removeHandler(extra)
            extra.close()

        # Point the module-level singleton at the survivor.
        global _file_handler
        _file_handler = file_handlers[0]
|
||||
|
||||
|
||||
# 读取日志配置
|
||||
def load_log_config():
    """Load the [log] section from ``config/bot_config.toml``.

    Falls back to built-in defaults when the file is missing, unreadable,
    or has no [log] section. Config problems must never break logging.
    """
    config_path = Path("config/bot_config.toml")
    default_config = {
        "date_style": "m-d H:i:s",
        "log_level_style": "lite",
        "color_text": "full",
        "log_level": "INFO",            # global level (backward compatible)
        "console_log_level": "INFO",    # console handler level
        "file_log_level": "DEBUG",      # file handler level
        "suppress_libraries": ["faiss","httpx", "urllib3", "asyncio", "websockets", "httpcore", "requests", "peewee", "openai","uvicorn","jieba"],
        "library_log_levels": { "aiohttp": "WARNING"},
    }

    try:
        if config_path.exists():
            with open(config_path, "r", encoding="utf-8") as f:
                config = tomlkit.load(f)
            return config.get("log", default_config)
    except Exception as e:
        # IDIOM: removed the dead `pass` that followed this print.
        print(f"[日志系统] 加载日志配置失败: {e}")

    return default_config
|
||||
|
||||
|
||||
# Loaded once at import time; refreshed again by initialize_logging().
LOG_CONFIG = load_log_config()
|
||||
|
||||
|
||||
def get_timestamp_format():
    """Translate the PHP-style date format from config into strftime format."""
    date_style = LOG_CONFIG.get("date_style", "Y-m-d H:i:s")
    # PHP-style token -> strftime directive.
    translation = {
        "Y": "%Y",  # 4-digit year
        "m": "%m",  # month (01-12)
        "d": "%d",  # day (01-31)
        "H": "%H",  # hour (00-23)
        "i": "%M",  # minute (00-59)
        "s": "%S",  # second (00-59)
    }

    result = date_style
    for php_token, strftime_token in translation.items():
        result = result.replace(php_token, strftime_token)

    return result
|
||||
|
||||
|
||||
def configure_third_party_loggers():
    """Apply per-library log levels and set the root logger level."""
    # The root level must be the minimum of the two handler levels so
    # every record either handler wants can reach it.
    console_level = LOG_CONFIG.get("console_log_level", LOG_CONFIG.get("log_level", "INFO"))
    file_level = LOG_CONFIG.get("file_log_level", LOG_CONFIG.get("log_level", "INFO"))
    min_level = min(
        getattr(logging, console_level.upper(), logging.INFO),
        getattr(logging, file_level.upper(), logging.INFO),
    )
    logging.getLogger().setLevel(min_level)

    # Libraries silenced entirely: level above CRITICAL plus no propagation.
    for lib_name in LOG_CONFIG.get("suppress_libraries", []):
        lib_logger = logging.getLogger(lib_name)
        lib_logger.setLevel(logging.CRITICAL + 1)
        lib_logger.propagate = False

    # Libraries pinned to an explicit level.
    for lib_name, level_name in LOG_CONFIG.get("library_log_levels", {}).items():
        logging.getLogger(lib_name).setLevel(
            getattr(logging, level_name.upper(), logging.WARNING)
        )
|
||||
|
||||
|
||||
def reconfigure_existing_loggers():
    """Re-apply formatting and level rules to every logger created so far.

    Fixes import-order problems where modules obtained loggers before
    this logging setup ran.
    """
    root_logger = logging.getLogger()

    # Refresh the formatters on the root logger's handlers.
    for handler in root_logger.handlers:
        if isinstance(handler, TimestampedFileHandler):
            handler.setFormatter(file_formatter)
        elif isinstance(handler, logging.StreamHandler):
            handler.setFormatter(console_formatter)

    # Loop-invariant config lookups, hoisted.
    suppress_libraries = LOG_CONFIG.get("suppress_libraries", [])
    library_log_levels = LOG_CONFIG.get("library_log_levels", {})

    # Walk every logger known to the logging manager.
    for name, logger_obj in logging.getLogger().manager.loggerDict.items():
        if not isinstance(logger_obj, logging.Logger):
            continue  # skip PlaceHolder entries

        # Fully suppressed third-party library?
        if any(name.startswith(lib) for lib in suppress_libraries):
            logger_obj.setLevel(logging.CRITICAL + 1)
            logger_obj.propagate = False
            continue

        # Library pinned to an explicit level?
        for lib_name, level_name in library_log_levels.items():
            if name.startswith(lib_name):
                logger_obj.setLevel(getattr(logging, level_name.upper(), logging.WARNING))
                break

        # Detach (and close) every handler the logger owns so that it
        # cannot produce duplicate output.
        original_handlers = logger_obj.handlers[:]
        for handler in original_handlers:
            if hasattr(handler, "close"):
                handler.close()
            logger_obj.removeHandler(handler)

        # With no own handlers left, fall back to the root's handlers.
        if not logger_obj.handlers:
            logger_obj.propagate = True

        # Re-attach only stream handlers (with the console formatter);
        # file handlers are dropped in favor of the root's single one.
        for handler in original_handlers:
            if isinstance(handler, TimestampedFileHandler):
                continue
            elif isinstance(handler, logging.StreamHandler):
                handler.setFormatter(console_formatter)
                logger_obj.addHandler(handler)
|
||||
|
||||
|
||||
# ANSI color codes per logger name, used by ModuleColoredConsoleRenderer.
MODULE_COLORS = {
    # Core modules
    "main": "\033[1;97m",  # bright white + bold (main program)
    "api": "\033[92m",  # bright green
    "emoji": "\033[38;5;214m",  # orange-yellow, distinct from replyer/action_manager
    "chat": "\033[92m",  # bright green (original comment said bright blue)
    "config": "\033[93m",  # bright yellow
    "common": "\033[95m",  # bright magenta
    "tools": "\033[96m",  # bright cyan
    "lpmm": "\033[96m",
    "plugin_system": "\033[91m",  # bright red
    "person_info": "\033[32m",  # green
    "individuality": "\033[94m",  # prominent bright blue
    "manager": "\033[35m",  # magenta
    "llm_models": "\033[36m",  # cyan
    "remote": "\033[38;5;242m",  # dark gray, unobtrusive
    "planner": "\033[36m",
    "memory": "\033[38;5;117m",  # sky blue
    "hfc": "\033[38;5;81m",  # slightly darker cyan, still readable
    "action_manager": "\033[38;5;208m",  # orange, not shared with replyer
    # Relationship system
    "relation": "\033[38;5;139m",  # soft purple
    # Chat-related modules
    "normal_chat": "\033[38;5;81m",  # bright blue-green
    "heartflow": "\033[38;5;175m",  # soft pink
    "sub_heartflow": "\033[38;5;207m",  # pink-purple
    "subheartflow_manager": "\033[38;5;201m",  # deep pink
    "background_tasks": "\033[38;5;240m",  # gray
    "chat_message": "\033[38;5;45m",  # cyan
    "chat_stream": "\033[38;5;51m",  # bright cyan
    "sender": "\033[38;5;67m",  # muted blue
    "message_storage": "\033[38;5;33m",  # deep blue
    "expressor": "\033[38;5;166m",  # orange
    # Focused-chat modules
    "replyer": "\033[38;5;166m",  # orange
    "memory_activator": "\033[38;5;117m",  # sky blue
    # Plugin system
    "plugins": "\033[31m",  # red
    "plugin_api": "\033[33m",  # yellow
    "plugin_manager": "\033[38;5;208m",  # orange (original comment said red)
    "base_plugin": "\033[38;5;202m",  # orange-red
    "send_api": "\033[38;5;208m",  # orange
    "base_command": "\033[38;5;208m",  # orange
    "component_registry": "\033[38;5;214m",  # orange-yellow
    "stream_api": "\033[38;5;220m",  # yellow
    "plugin_hot_reload": "\033[38;5;226m",  # bright yellow (original comment said magenta)
    "config_api": "\033[38;5;226m",  # bright yellow
    "heartflow_api": "\033[38;5;154m",  # yellow-green
    "action_apis": "\033[38;5;118m",  # green
    "independent_apis": "\033[38;5;82m",  # green
    "llm_api": "\033[38;5;46m",  # bright green
    "database_api": "\033[38;5;10m",  # green
    "utils_api": "\033[38;5;14m",  # cyan
    "message_api": "\033[38;5;6m",  # cyan
    # Manager modules
    "async_task_manager": "\033[38;5;129m",  # purple
    "mood": "\033[38;5;135m",  # purple-red
    "local_storage": "\033[38;5;141m",  # purple
    "willing": "\033[38;5;147m",  # light purple
    # Tool modules
    "tool_use": "\033[38;5;172m",  # orange-brown
    "tool_executor": "\033[38;5;172m",  # orange-brown
    "base_tool": "\033[38;5;178m",  # golden yellow
    # Utility modules
    "prompt_build": "\033[38;5;105m",  # purple
    "chat_utils": "\033[38;5;111m",  # blue
    "chat_image": "\033[38;5;117m",  # light blue
    "maibot_statistic": "\033[38;5;129m",  # purple
    # Special-purpose plugins
    "mute_plugin": "\033[38;5;240m",  # gray
    "core_actions": "\033[38;5;117m",  # sky blue (original comment said dark red)
    "tts_action": "\033[38;5;58m",  # dark yellow
    "doubao_pic_plugin": "\033[38;5;64m",  # dark green
    # Action components
    "no_reply_action": "\033[38;5;214m",  # bright orange, visible but not alarming
    "reply_action": "\033[38;5;46m",  # bright green
    "base_action": "\033[38;5;250m",  # light gray
    # Database and messaging
    "database_model": "\033[38;5;94m",  # orange-brown
    "database": "\033[38;5;46m",  # bright green (original comment said orange-brown)
    "maim_message": "\033[38;5;140m",  # purple-brown
    # Logging system
    "logger": "\033[38;5;8m",  # dark gray
    "confirm": "\033[1;93m",  # yellow + bold
    # Model-related
    "model_utils": "\033[38;5;164m",  # purple-red
    "relationship_fetcher": "\033[38;5;170m",  # light purple
    "relationship_builder": "\033[38;5;93m",  # light blue

    # s4u
    "context_web_api": "\033[38;5;240m",  # dark gray
    "S4U_chat": "\033[92m",  # bright green (original comment said dark gray)
}
|
||||
|
||||
# Display aliases: maps real logger names to the (Chinese) labels shown
# in console output. Values are runtime data — do not translate.
MODULE_ALIASES = {
    "individuality": "人格特质",
    "emoji": "表情包",
    "no_reply_action": "摸鱼",
    "reply_action": "回复",
    "action_manager": "动作",
    "memory_activator": "记忆",
    "tool_use": "工具",
    "expressor": "表达方式",
    "plugin_hot_reload": "热重载",
    "database": "数据库",
    "database_model": "数据库",
    "mood": "情绪",
    "memory": "记忆",
    "tool_executor": "工具",
    "hfc": "聊天节奏",
    "chat": "所见",
    "plugin_manager": "插件",
    "relationship_builder": "关系",
    "llm_models": "模型",
    "person_info": "人物",
    "chat_stream": "聊天流",
    "planner": "规划器",
    "replyer": "言语",
    "config": "配置",
    "main": "主程序",
}
|
||||
|
||||
# ANSI sequence resetting all color/style attributes.
RESET_COLOR = "\033[0m"
|
||||
|
||||
|
||||
class ModuleColoredConsoleRenderer:
    """structlog console renderer that colors output per source module."""

    def __init__(self, colors=True):
        """Read color settings from LOG_CONFIG and prepare level colors.

        Args:
            colors: master switch for any ANSI coloring.
        """
        self._colors = colors
        self._config = LOG_CONFIG

        # ANSI colors per log level.
        self._level_colors = {
            "debug": "\033[38;5;208m",   # orange
            "info": "\033[38;5;117m",    # sky blue
            "success": "\033[32m",       # green
            "warning": "\033[33m",       # yellow
            "error": "\033[31m",         # red
            "critical": "\033[35m",      # purple
        }

        # Decide which parts get colored from the "color_text" setting.
        # ROBUSTNESS FIX: all three flags are now initialized on every
        # branch. Previously the "none" branch left them undefined and
        # __call__ only avoided an AttributeError because self._colors
        # short-circuited the checks.
        color_text = self._config.get("color_text", "title")
        if color_text == "none":
            self._colors = False
            self._enable_module_colors = False
            self._enable_level_colors = False
            self._enable_full_content_colors = False
        elif color_text == "full":
            self._enable_module_colors = True
            self._enable_level_colors = True
            self._enable_full_content_colors = True
        else:
            # "title" and any unknown value: color module names only.
            self._enable_module_colors = True
            self._enable_level_colors = False
            self._enable_full_content_colors = False

    def __call__(self, logger, method_name, event_dict):
        """Render one structlog event dict into a single output line."""
        timestamp = event_dict.get("timestamp", "")
        level = event_dict.get("level", "info")
        logger_name = event_dict.get("logger_name", "")
        event = event_dict.get("event", "")

        parts = []

        log_level_style = self._config.get("log_level_style", "lite")
        level_color = self._level_colors.get(level.lower(), "") if self._colors else ""

        # Timestamp (colored by level in "lite" style).
        if timestamp:
            if log_level_style == "lite" and level_color:
                parts.append(f"{level_color}{timestamp}{RESET_COLOR}")
            else:
                parts.append(timestamp)

        # Level marker, depending on the configured style.
        if log_level_style == "full":
            # Full level name, right-padded to 8, colored if possible.
            level_text = level.upper()
            if level_color:
                parts.append(f"{level_color}[{level_text:>8}]{RESET_COLOR}")
            else:
                parts.append(f"[{level_text:>8}]")
        elif log_level_style == "compact":
            # First letter only. NOTE(review): still padded to 8 columns,
            # which looks unintended for a "compact" style — confirm.
            level_text = level.upper()[0]
            if level_color:
                parts.append(f"{level_color}[{level_text:>8}]{RESET_COLOR}")
            else:
                parts.append(f"[{level_text:>8}]")
        # "lite" shows no level marker; only the timestamp is colored.

        # Module color, also reused for full-content coloring below.
        module_color = ""
        if self._colors and self._enable_module_colors and logger_name:
            module_color = MODULE_COLORS.get(logger_name, "")

        # Module name (alias-aware, optionally colored).
        if logger_name:
            display_name = MODULE_ALIASES.get(logger_name, logger_name)
            if self._colors and self._enable_module_colors and module_color:
                parts.append(f"{module_color}[{display_name}]{RESET_COLOR}")
            else:
                parts.append(f"[{display_name}]")

        # Event payload, coerced to a string.
        if isinstance(event, str):
            event_content = event
        elif isinstance(event, dict):
            # Render dict payloads as compact JSON when possible.
            try:
                event_content = json.dumps(event, ensure_ascii=False, indent=None)
            except (TypeError, ValueError):
                event_content = str(event)
        else:
            event_content = str(event)

        # In "full" mode the message body takes the module color too.
        if self._colors and self._enable_full_content_colors and module_color:
            event_content = f"{module_color}{event_content}{RESET_COLOR}"
        parts.append(event_content)

        # Remaining key=value fields.
        extras = []
        for key, value in event_dict.items():
            if key in ("timestamp", "level", "logger_name", "event"):
                continue
            if isinstance(value, (dict, list)):
                try:
                    value_str = json.dumps(value, ensure_ascii=False, indent=None)
                except (TypeError, ValueError):
                    value_str = str(value)
            else:
                value_str = str(value)

            extra_field = f"{key}={value_str}"
            if self._colors and self._enable_full_content_colors and module_color:
                extra_field = f"{module_color}{extra_field}{RESET_COLOR}"
            extras.append(extra_field)

        if extras:
            parts.append(" ".join(extras))

        return " ".join(parts)
|
||||
|
||||
|
||||
# Bootstrap stdlib logging with file + console output at import time.
# The singleton accessors avoid creating duplicate handlers.
file_handler = get_file_handler()
console_handler = get_console_handler()

logging.basicConfig(
    level=logging.INFO,
    format="%(message)s",
    handlers=[file_handler, console_handler],
)
|
||||
|
||||
|
||||
def configure_structlog():
    """Install the project-wide structlog processor pipeline."""
    structlog.configure(
        processors=[
            structlog.contextvars.merge_contextvars,
            structlog.processors.add_log_level,
            structlog.processors.StackInfoRenderer(),
            structlog.dev.set_exc_info,
            structlog.processors.TimeStamper(fmt=get_timestamp_format(), utc=False),
            # Hand off to stdlib logging; per-handler formatters then pick
            # the final rendering (colored console vs JSON file).
            structlog.stdlib.ProcessorFormatter.wrap_for_formatter,
        ],
        wrapper_class=structlog.stdlib.BoundLogger,
        context_class=dict,
        logger_factory=structlog.stdlib.LoggerFactory(),
        cache_logger_on_first_use=True,
    )
|
||||
|
||||
|
||||
# Configure structlog once at import time.
configure_structlog()

# JSON formatter for the file handler: one JSON object per line.
file_formatter = structlog.stdlib.ProcessorFormatter(
    processor=structlog.processors.JSONRenderer(ensure_ascii=False),
    foreign_pre_chain=[
        structlog.stdlib.add_logger_name,
        structlog.stdlib.add_log_level,
        structlog.stdlib.PositionalArgumentsFormatter(),
        structlog.processors.TimeStamper(fmt="iso"),
        structlog.processors.StackInfoRenderer(),
        structlog.processors.format_exc_info,
    ],
)

# Human-readable, per-module-colored formatter for the console handler.
console_formatter = structlog.stdlib.ProcessorFormatter(
    processor=ModuleColoredConsoleRenderer(colors=True),
    foreign_pre_chain=[
        structlog.stdlib.add_logger_name,
        structlog.stdlib.add_log_level,
        structlog.stdlib.PositionalArgumentsFormatter(),
        structlog.processors.TimeStamper(fmt=get_timestamp_format(), utc=False),
        structlog.processors.StackInfoRenderer(),
        structlog.processors.format_exc_info,
    ],
)
|
||||
|
||||
# Attach the structlog formatters to the root logger's handlers:
# JSON for the timestamped file handler, colored text for everything else.
root_logger = logging.getLogger()
for handler in root_logger.handlers:
    if isinstance(handler, TimestampedFileHandler):
        handler.setFormatter(file_formatter)
    else:
        handler.setFormatter(console_formatter)
|
||||
|
||||
|
||||
# 立即配置日志系统,确保最早期的日志也使用正确格式
|
||||
def _immediate_setup():
    """Set up the logging system immediately, taking effect at module import time.

    Rebuilds the root logger's handlers from scratch so that even the very
    first log lines emitted during startup use the unified formatters.
    """
    # Re-apply the structlog pipeline configuration.
    configure_structlog()

    # Drop every pre-existing handler; iterate over a copy because
    # removeHandler mutates the list we would otherwise be iterating.
    root_logger = logging.getLogger()
    for handler in root_logger.handlers[:]:
        root_logger.removeHandler(handler)

    # Use the singleton handlers to avoid creating duplicate file/console handlers.
    file_handler = get_file_handler()
    console_handler = get_console_handler()

    # Re-attach the freshly configured handlers.
    root_logger.addHandler(file_handler)
    root_logger.addHandler(console_handler)

    # Wire up the module-level formatters (file -> JSON, console -> colored).
    file_handler.setFormatter(file_formatter)
    console_handler.setFormatter(console_formatter)

    # Remove any duplicated handlers left over elsewhere in the hierarchy.
    remove_duplicate_handlers()

    # Tame noisy third-party library loggers.
    configure_third_party_loggers()

    # Re-route loggers that were created before this setup ran.
    reconfigure_existing_loggers()


# Run the setup at import time so early log output is already well-formatted.
_immediate_setup()
|
||||
|
||||
# Unnamed fallback logger returned when get_logger(None) is called.
raw_logger: structlog.stdlib.BoundLogger = structlog.get_logger()

# Cache of name -> bound logger so each named logger is created only once.
binds: dict[str, structlog.stdlib.BoundLogger] = {}
|
||||
|
||||
|
||||
def get_logger(name: Optional[str]) -> structlog.stdlib.BoundLogger:
    """Return a logger bound to *name*, caching one instance per name.

    Passing ``None`` yields the shared unnamed logger.
    """
    if name is None:
        return raw_logger
    # Create and memoize the named logger on first request.
    if name not in binds:
        binds[name] = structlog.get_logger(name).bind(logger_name=name)
    return binds[name]
|
||||
|
||||
|
||||
def initialize_logging():
    """Manually initialize the logging system with the loaded configuration.

    Call this early in application startup so every module shares the same
    logging configuration. Reloads LOG_CONFIG, reconfigures third-party and
    pre-existing loggers, starts the background log-cleanup task, and logs a
    summary of the effective levels.
    """
    global LOG_CONFIG
    LOG_CONFIG = load_log_config()
    configure_third_party_loggers()
    reconfigure_existing_loggers()

    # Start the periodic log-file cleanup thread.
    start_log_cleanup_task()

    # Report the effective configuration; per-sink levels fall back to the
    # shared "log_level" and finally to "INFO".
    logger = get_logger("logger")
    console_level = LOG_CONFIG.get("console_log_level", LOG_CONFIG.get("log_level", "INFO"))
    file_level = LOG_CONFIG.get("file_log_level", LOG_CONFIG.get("log_level", "INFO"))

    logger.info("日志系统已初始化:")
    logger.info(f" - 控制台级别: {console_level}")
    logger.info(f" - 文件级别: {file_level}")
    logger.info(" - 轮转份数: 30个文件|自动清理: 30天前的日志")
|
||||
|
||||
|
||||
def cleanup_old_logs():
    """Delete log files in LOG_DIR older than the 30-day retention window.

    Best-effort: a failure on one file is logged and does not stop the scan.
    Logs a summary (count and reclaimed size) when anything was deleted.
    """
    # Hoisted: get_logger caches per name, but fetching it once up front avoids
    # redundant lookups inside the loop and the error branches.
    logger = get_logger("logger")
    try:
        cleanup_days = 30  # retention window is intentionally hard-coded
        cutoff_date = datetime.now() - timedelta(days=cleanup_days)
        deleted_count = 0
        deleted_size = 0

        # Matches both plain "*.log" files and rotated variants ("*.log.1", ...).
        for log_file in LOG_DIR.glob("*.log*"):
            try:
                # Single stat() per file: mtime and size from the same snapshot
                # (the original called stat() twice, racing against rotation).
                st = log_file.stat()
                if datetime.fromtimestamp(st.st_mtime) < cutoff_date:
                    log_file.unlink()
                    deleted_count += 1
                    deleted_size += st.st_size
            except Exception as e:
                logger.warning(f"清理日志文件 {log_file} 时出错: {e}")

        if deleted_count > 0:
            logger.info(f"清理了 {deleted_count} 个过期日志文件,释放空间 {deleted_size / 1024 / 1024:.2f} MB")

    except Exception as e:
        logger.error(f"清理旧日志文件时出错: {e}")
|
||||
|
||||
|
||||
def start_log_cleanup_task():
    """Spawn a daemon thread that purges expired log files once a day."""

    def _daily_cleanup_loop():
        # Sleep first so startup is not delayed by an immediate sweep.
        while True:
            time.sleep(24 * 60 * 60)  # run once every 24 hours
            cleanup_old_logs()

    threading.Thread(target=_daily_cleanup_loop, daemon=True).start()

    get_logger("logger").info(
        "已启动日志清理任务,将自动清理30天前的日志文件(轮转份数限制: 30个文件)"
    )
|
||||
|
||||
|
||||
def shutdown_logging():
    """Gracefully shut down the logging system, releasing all file handles.

    Closes and detaches every handler on the root logger, the shared global
    handlers, and handlers on all other registered loggers.
    """
    logger = get_logger("logger")
    logger.info("正在关闭日志系统...")

    # Close all root-logger handlers; iterate a copy since we mutate the list.
    root_logger = logging.getLogger()
    for handler in root_logger.handlers[:]:
        if hasattr(handler, "close"):
            handler.close()
        root_logger.removeHandler(handler)

    # Close the module-level singleton handlers.
    close_handlers()

    # Close handlers attached to every other logger in the manager's registry.
    # NOTE: loggerDict may also contain PlaceHolder objects, hence the isinstance check.
    logger_dict = logging.getLogger().manager.loggerDict
    for _name, logger_obj in logger_dict.items():
        if isinstance(logger_obj, logging.Logger):
            for handler in logger_obj.handlers[:]:
                if hasattr(handler, "close"):
                    handler.close()
                logger_obj.removeHandler(handler)

    # NOTE(review): this final message is emitted after handlers are closed —
    # it presumably reaches nothing; confirm whether it should precede teardown.
    logger.info("日志系统已关闭")
|
||||
10
src/common/message/__init__.py
Normal file
10
src/common/message/__init__.py
Normal file
@@ -0,0 +1,10 @@
|
||||
"""Maim Message - A message handling library"""
|
||||
|
||||
__version__ = "0.1.0"
|
||||
|
||||
from .api import get_global_api
|
||||
|
||||
|
||||
__all__ = [
|
||||
"get_global_api",
|
||||
]
|
||||
59
src/common/message/api.py
Normal file
59
src/common/message/api.py
Normal file
@@ -0,0 +1,59 @@
|
||||
from src.common.server import get_global_server
|
||||
import os
|
||||
import importlib.metadata
|
||||
from maim_message import MessageServer
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
|
||||
# Lazily constructed process-wide MessageServer singleton.
global_api = None


def get_global_api() -> MessageServer:  # sourcery skip: extract-method
    """Return the global MessageServer instance, creating it on first call.

    Builds the server on top of the global FastAPI app by default; when the
    installed maim_message is recent enough, additionally enables the custom
    logger, token auth, and (optionally) a standalone custom/WSS server.
    """
    global global_api
    if global_api is None:
        # Determine whether the installed maim_message supports the advanced
        # keyword arguments used below.
        # NOTE(review): the comparison gates on version >= 0.3.3, but a comment
        # below says ">= 0.3.0" — confirm which threshold is intended.
        try:
            maim_message_version = importlib.metadata.version("maim_message")
            version_compatible = [int(x) for x in maim_message_version.split(".")] >= [0, 3, 3]
        except (importlib.metadata.PackageNotFoundError, ValueError):
            # Missing package or non-numeric version component: assume old API.
            version_compatible = False

        # Project-level maim_message configuration section.
        maim_message_config = global_config.maim_message

        # Base constructor arguments: mount onto the shared FastAPI app.
        # NOTE(review): raises KeyError if HOST/PORT env vars are unset — confirm
        # they are guaranteed by the launcher.
        kwargs = {
            "host": os.environ["HOST"],
            "port": int(os.environ["PORT"]),
            "app": get_global_server().get_app(),
        }

        # Advanced features are only passed on compatible versions (>= 0.3.0 per
        # the original comment; see NOTE above).
        if version_compatible:
            # Route maim_message's own logging through our logger.
            maim_message_logger = get_logger("maim_message")
            kwargs["custom_logger"] = maim_message_logger

            # Enable token authentication when any auth tokens are configured.
            if maim_message_config.auth_token and len(maim_message_config.auth_token) > 0:
                kwargs["enable_token"] = True

            if maim_message_config.use_custom:
                # Standalone server mode: detach from the shared FastAPI app and
                # listen on the configured host/port, optionally over WSS.
                del kwargs["app"]
                kwargs["host"] = maim_message_config.host
                kwargs["port"] = maim_message_config.port
                kwargs["mode"] = maim_message_config.mode
                if maim_message_config.use_wss:
                    if maim_message_config.cert_file:
                        kwargs["ssl_certfile"] = maim_message_config.cert_file
                    if maim_message_config.key_file:
                        kwargs["ssl_keyfile"] = maim_message_config.key_file
                kwargs["enable_custom_uvicorn_logger"] = False

        global_api = MessageServer(**kwargs)
        # Register each configured token as valid (new-API versions only).
        if version_compatible and maim_message_config.auth_token:
            for token in maim_message_config.auth_token:
                global_api.add_valid_token(token)
    return global_api
|
||||
203
src/common/message_repository.py
Normal file
203
src/common/message_repository.py
Normal file
@@ -0,0 +1,203 @@
|
||||
import traceback
|
||||
|
||||
from typing import List, Optional, Any, Dict
|
||||
from sqlalchemy import not_, select, func
|
||||
|
||||
from sqlalchemy.orm import DeclarativeBase
|
||||
from src.config.config import global_config
|
||||
|
||||
# from src.common.database.database_model import Messages
|
||||
from src.common.database.sqlalchemy_models import Messages
|
||||
from src.common.database.sqlalchemy_database_api import get_session
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
class Base(DeclarativeBase):
    # NOTE(review): this local declarative base is used only as a type
    # annotation in _model_to_dict; it is NOT the base class of Messages,
    # which derives from the base declared in sqlalchemy_models.
    pass
|
||||
|
||||
def _model_to_dict(instance: Base) -> Dict[str, Any]:
|
||||
"""
|
||||
将 SQLAlchemy 模型实例转换为字典。
|
||||
"""
|
||||
return {col.name: getattr(instance, col.name) for col in instance.__table__.columns}
|
||||
|
||||
|
||||
def find_messages(
    message_filter: dict[str, Any],
    sort: Optional[List[tuple[str, int]]] = None,
    limit: int = 0,
    limit_mode: str = "latest",
    filter_bot=False,
    filter_command=False,
) -> List[dict[str, Any]]:
    """Find messages by filter, sort and limit conditions.

    Args:
        message_filter: Filter dict; keys are model field names, values are
            either a literal (equality) or a MongoDB-style operator dict
            (e.g. {'$gt': value}). Unknown fields/operators are skipped with a warning.
        sort: Sort spec, e.g. [('time', 1)] (1 = asc, -1 = desc). Only applied
            when limit == 0.
        limit: Max number of rows; 0 means unlimited.
        limit_mode: Only used when limit > 0. 'earliest' takes the oldest rows;
            'latest' takes the newest rows but still returns them in ascending
            time order.
        filter_bot: Exclude messages sent by the bot account itself.
        filter_command: Exclude command messages.

    Returns:
        List of message dicts; empty list on error.
    """
    try:
        session = get_session()
        # NOTE(review): the session is never closed/returned here — confirm
        # get_session() manages lifetime (e.g. scoped session) upstream.
        query = select(Messages)

        # Translate the MongoDB-style filter dict into SQLAlchemy conditions.
        if message_filter:
            conditions = []
            for key, value in message_filter.items():
                if hasattr(Messages, key):
                    field = getattr(Messages, key)
                    if isinstance(value, dict):
                        # Operator dict: each entry maps a Mongo operator to a value.
                        for op, op_value in value.items():
                            if op == "$gt":
                                conditions.append(field > op_value)
                            elif op == "$lt":
                                conditions.append(field < op_value)
                            elif op == "$gte":
                                conditions.append(field >= op_value)
                            elif op == "$lte":
                                conditions.append(field <= op_value)
                            elif op == "$ne":
                                conditions.append(field != op_value)
                            elif op == "$in":
                                conditions.append(field.in_(op_value))
                            elif op == "$nin":
                                conditions.append(field.not_in(op_value))
                            else:
                                logger.warning(f"过滤器中遇到未知操作符 '{op}' (字段: '{key}')。将跳过此操作符。")
                    else:
                        # Plain value: equality comparison.
                        conditions.append(field == value)
                else:
                    logger.warning(f"过滤器键 '{key}' 在 Messages 模型中未找到。将跳过此条件。")
            if conditions:
                query = query.where(*conditions)

        # Optionally exclude the bot's own messages.
        if filter_bot:
            query = query.where(Messages.user_id != global_config.bot.qq_account)

        # Optionally exclude command messages.
        if filter_command:
            query = query.where(not_(Messages.is_command))

        if limit > 0:
            # Normalize limit to a positive integer.
            limit = max(1, int(limit))

            if limit_mode == "earliest":
                # Oldest `limit` rows, already in ascending order.
                query = query.order_by(Messages.time.asc()).limit(limit)
                try:
                    results = session.execute(query).scalars().all()
                except Exception as e:
                    logger.error(f"执行earliest查询失败: {e}")
                    results = []
            else:  # default: 'latest'
                # Newest `limit` rows, fetched descending then re-sorted ascending
                # so callers always receive chronological order.
                query = query.order_by(Messages.time.desc()).limit(limit)
                try:
                    latest_results = session.execute(query).scalars().all()
                    results = sorted(latest_results, key=lambda msg: msg.time)
                except Exception as e:
                    logger.error(f"执行latest查询失败: {e}")
                    results = []
        else:
            # limit == 0: honor the caller-provided sort spec, if any.
            if sort:
                sort_terms = []
                for field_name, direction in sort:
                    if hasattr(Messages, field_name):
                        field = getattr(Messages, field_name)
                        if direction == 1:  # ASC
                            sort_terms.append(field.asc())
                        elif direction == -1:  # DESC
                            sort_terms.append(field.desc())
                        else:
                            logger.warning(f"字段 '{field_name}' 的排序方向 '{direction}' 无效。将跳过此排序条件。")
                    else:
                        logger.warning(f"排序字段 '{field_name}' 在 Messages 模型中未找到。将跳过此排序条件。")
                if sort_terms:
                    query = query.order_by(*sort_terms)
            try:
                results = session.execute(query).scalars().all()
            except Exception as e:
                logger.error(f"执行无限制查询失败: {e}")
                results = []

        return [_model_to_dict(msg) for msg in results]
    except Exception as e:
        log_message = (
            f"使用 SQLAlchemy 查找消息失败 (filter={message_filter}, sort={sort}, limit={limit}, limit_mode={limit_mode}): {e}\n"
            + traceback.format_exc()
        )
        logger.error(log_message)
        return []
|
||||
|
||||
|
||||
def count_messages(message_filter: dict[str, Any]) -> int:
    """Count messages matching the given filter.

    Args:
        message_filter: Filter dict; keys are model field names, values are
            either a literal (equality) or a MongoDB-style operator dict
            (e.g. {'$gt': value}). Unknown fields/operators are skipped with a warning.

    Returns:
        Number of matching messages, or 0 on error.
    """
    # Map each supported MongoDB-style operator to a SQLAlchemy condition builder.
    operator_builders = {
        "$gt": lambda f, v: f > v,
        "$lt": lambda f, v: f < v,
        "$gte": lambda f, v: f >= v,
        "$lte": lambda f, v: f <= v,
        "$ne": lambda f, v: f != v,
        "$in": lambda f, v: f.in_(v),
        "$nin": lambda f, v: f.not_in(v),
    }
    try:
        session = get_session()
        query = select(func.count(Messages.id))

        if message_filter:
            conditions = []
            for key, value in message_filter.items():
                if not hasattr(Messages, key):
                    logger.warning(f"计数时,过滤器键 '{key}' 在 Messages 模型中未找到。将跳过此条件。")
                    continue
                field = getattr(Messages, key)
                if not isinstance(value, dict):
                    # Plain value: equality comparison.
                    conditions.append(field == value)
                    continue
                # Operator dict: translate each entry via the dispatch table.
                for op, op_value in value.items():
                    build = operator_builders.get(op)
                    if build is None:
                        logger.warning(
                            f"计数时,过滤器中遇到未知操作符 '{op}' (字段: '{key}')。将跳过此操作符。"
                        )
                    else:
                        conditions.append(build(field, op_value))
            if conditions:
                query = query.where(*conditions)

        count = session.execute(query).scalar()
        return count or 0
    except Exception as e:
        log_message = f"使用 SQLAlchemy 计数消息失败 (message_filter={message_filter}): {e}\n{traceback.format_exc()}"
        logger.error(log_message)
        return 0
|
||||
|
||||
|
||||
# 你可以在这里添加更多与 messages 集合相关的数据库操作函数,例如 find_one_message, insert_message 等。
|
||||
# 注意:对于 SQLAlchemy,插入操作通常是使用 session.add() 和 session.commit()。
|
||||
# 查找单个消息可以使用 session.execute(select(Messages).where(...)).scalar_one_or_none()。
|
||||
165
src/common/remote.py
Normal file
165
src/common/remote.py
Normal file
@@ -0,0 +1,165 @@
|
||||
import asyncio
|
||||
|
||||
import aiohttp
|
||||
import platform
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.common.tcp_connector import get_tcp_connector
|
||||
from src.config.config import global_config
|
||||
from src.manager.async_task_manager import AsyncTask
|
||||
from src.manager.local_store_manager import local_storage
|
||||
|
||||
logger = get_logger("remote")
|
||||
|
||||
TELEMETRY_SERVER_URL = "http://hyybuth.xyz:10058"
"""Telemetry service base URL"""
|
||||
|
||||
|
||||
class TelemetryHeartBeatTask(AsyncTask):
    """Periodic telemetry heartbeat: registers this client (UUID) and reports
    basic system info to the telemetry server every HEARTBEAT_INTERVAL seconds.
    """

    # Heartbeat period in seconds.
    HEARTBEAT_INTERVAL = 300

    def __init__(self):
        super().__init__(task_name="Telemetry Heart Beat Task", run_interval=self.HEARTBEAT_INTERVAL)
        self.server_url = TELEMETRY_SERVER_URL
        """Telemetry service base URL"""

        self.client_uuid: str | None = local_storage["mmc_uuid"] if "mmc_uuid" in local_storage else None  # type: ignore
        """Client UUID (None until registered with the server)"""

        self.info_dict = self._get_sys_info()
        """Static system-info payload sent with every heartbeat"""

    @staticmethod
    def _get_sys_info() -> dict[str, str]:
        """Collect OS type, Python version and MMC version for the heartbeat payload."""
        info_dict = {
            "os_type": "Unknown",
            "py_version": platform.python_version(),
            "mmc_version": global_config.MMC_VERSION,
        }

        # Normalize platform.system() output to a fixed label set.
        match platform.system():
            case "Windows":
                info_dict["os_type"] = "Windows"
            case "Linux":
                info_dict["os_type"] = "Linux"
            case "Darwin":
                info_dict["os_type"] = "macOS"
            case _:
                info_dict["os_type"] = "Unknown"

        return info_dict

    async def _req_uuid(self) -> bool:
        """Request a client UUID from the server (registers this client).

        Must not be called when a UUID already exists — it would overwrite it.
        Retries up to 3 times with exponential backoff (4**n seconds).
        Returns True on success, False on permanent failure.
        """

        if "deploy_time" not in local_storage:
            logger.error("本地存储中缺少部署时间,无法请求UUID")
            return False

        try_count: int = 0
        while True:
            # No UUID yet: ask the server to register this client and issue one.
            logger.info("正在向遥测服务端请求UUID...")

            try:
                async with aiohttp.ClientSession(connector=await get_tcp_connector()) as session:
                    async with session.post(
                        f"{TELEMETRY_SERVER_URL}/stat/reg_client",
                        json={"deploy_time": local_storage["deploy_time"]},
                        timeout=aiohttp.ClientTimeout(total=5),  # 5-second request timeout
                    ) as response:
                        logger.debug(f"{TELEMETRY_SERVER_URL}/stat/reg_client")
                        logger.debug(local_storage["deploy_time"])  # type: ignore
                        logger.debug(f"Response status: {response.status}")

                        if response.status == 200:
                            data = await response.json()
                            if client_id := data.get("mmc_uuid"):
                                # Persist the UUID locally for future runs.
                                local_storage["mmc_uuid"] = client_id
                                self.client_uuid = client_id
                                logger.info(f"成功获取UUID: {self.client_uuid}")
                                return True  # UUID obtained successfully
                            else:
                                logger.error("无效的服务端响应")
                        else:
                            response_text = await response.text()
                            logger.error(
                                f"请求UUID失败,不过你还是可以正常使用麦麦,状态码: {response.status}, 响应内容: {response_text}"
                            )
            except Exception as e:
                import traceback

                error_msg = str(e) or "未知错误"
                logger.warning(
                    f"请求UUID出错,不过你还是可以正常使用麦麦: {type(e).__name__}: {error_msg}"
                )  # likely a network problem
                logger.debug(f"完整错误信息: {traceback.format_exc()}")

            # Request failed: bump the retry counter.
            try_count += 1
            if try_count > 3:
                # Give up after 3 failed attempts.
                logger.error("获取UUID失败,请检查网络连接或服务端状态")
                return False
            else:
                # Retry after an exponential backoff delay.
                logger.info(f"获取UUID失败,将于 {4**try_count} 秒后重试...")
                await asyncio.sleep(4**try_count)

    async def _send_heartbeat(self):
        """Send one heartbeat (system info) to the server.

        On 403 the local UUID is discarded so the next heartbeat re-registers.
        All failures are logged as warnings and never raised.
        """
        headers = {
            "Client-UUID": self.client_uuid,
            "User-Agent": f"HeartbeatClient/{self.client_uuid[:8]}",  # type: ignore
        }

        logger.debug(f"正在发送心跳到服务器: {self.server_url}")
        logger.debug(str(headers))

        try:
            async with aiohttp.ClientSession(connector=await get_tcp_connector()) as session:
                async with session.post(
                    f"{self.server_url}/stat/client_heartbeat",
                    headers=headers,
                    json=self.info_dict,
                    timeout=aiohttp.ClientTimeout(total=5),  # 5-second request timeout
                ) as response:
                    logger.debug(f"Response status: {response.status}")

                    # Handle the response by status class.
                    if 200 <= response.status < 300:
                        # Success.
                        logger.debug(f"心跳发送成功,状态码: {response.status}")
                    elif response.status == 403:
                        # 403 Forbidden: UUID likely invalid/unregistered.
                        # Reset it so the next heartbeat attempts re-registration.
                        logger.warning(
                            "(此消息不会影响正常使用)心跳发送失败,403 Forbidden: 可能是UUID无效或未注册。"
                            "处理措施:重置UUID,下次发送心跳时将尝试重新注册。"
                        )
                        self.client_uuid = None
                        del local_storage["mmc_uuid"]  # drop the locally stored UUID
                    else:
                        # Any other error status.
                        response_text = await response.text()
                        logger.warning(
                            f"(此消息不会影响正常使用)状态未发送,状态码: {response.status}, 响应内容: {response_text}"
                        )
        except Exception as e:
            import traceback

            error_msg = str(e) or "未知错误"
            # NOTE(review): "状态未发生" below looks like a typo of "状态未发送" — confirm.
            logger.warning(f"(此消息不会影响正常使用)状态未发生: {type(e).__name__}: {error_msg}")
            logger.debug(f"完整错误信息: {traceback.format_exc()}")

    async def run(self):
        """One scheduled tick: ensure a UUID exists, then send a heartbeat.

        No-op when telemetry is disabled in the configuration.
        """
        if global_config.telemetry.enable:
            if self.client_uuid is None and not await self._req_uuid():
                logger.warning("获取UUID失败,跳过此次心跳")
                return

            await self._send_heartbeat()
|
||||
101
src/common/server.py
Normal file
101
src/common/server.py
Normal file
@@ -0,0 +1,101 @@
|
||||
from fastapi import FastAPI, APIRouter
|
||||
from fastapi.middleware.cors import CORSMiddleware # 新增导入
|
||||
from typing import Optional
|
||||
from uvicorn import Config, Server as UvicornServer
|
||||
import os
|
||||
from rich.traceback import install
|
||||
|
||||
install(extra_lines=3)
|
||||
|
||||
|
||||
class Server:
    """Thin wrapper around FastAPI + uvicorn: holds the app, lets callers
    register routers, and runs/stops the HTTP server asynchronously."""

    def __init__(self, host: Optional[str] = None, port: Optional[int] = None, app_name: str = "MaiMCore"):
        self.app = FastAPI(title=app_name)
        # Defaults used when host/port are not provided.
        self._host: str = "127.0.0.1"
        self._port: int = 8080
        self._server: Optional[UvicornServer] = None
        self.set_address(host, port)

        # CORS configuration: allowed frontend origins.
        origins = [
            "http://localhost:3000",  # allowed frontend origin
            "http://127.0.0.1:3000",
            # In production, add the actual frontend domain(s) here.
        ]

        self.app.add_middleware(
            CORSMiddleware,
            allow_origins=origins,
            allow_credentials=True,  # allow cookies
            allow_methods=["*"],  # allow all HTTP methods
            allow_headers=["*"],  # allow all HTTP request headers
        )

    def register_router(self, router: APIRouter, prefix: str = ""):
        """Register a router.

        APIRouter groups related endpoints for modular management:
        1. Related endpoints can be organized together for easier maintenance
        2. Supports a shared route prefix
        3. Common dependencies, tags, etc. can be attached to a whole group

        Example:
            router = APIRouter()

            @router.get("/users")
            def get_users():
                return {"users": [...]}

            @router.post("/users")
            def create_user():
                return {"msg": "user created"}

            # Register the router with prefix "/api/v1"
            server.register_router(router, prefix="/api/v1")
        """
        self.app.include_router(router, prefix=prefix)

    def set_address(self, host: Optional[str] = None, port: Optional[int] = None):
        """Set the server bind address and port (None keeps the current value)."""
        if host:
            self._host = host
        if port:
            self._port = port

    async def run(self):
        """Start the server and block until it stops.

        Raises:
            RuntimeError: wrapped around any unexpected server error.
        """
        # Disable uvicorn's default log config and access log; our own
        # logging system handles output formatting.
        config = Config(app=self.app, host=self._host, port=self._port, log_config=None, access_log=False)
        self._server = UvicornServer(config=config)
        try:
            await self._server.serve()
        except KeyboardInterrupt:
            await self.shutdown()
            raise
        except Exception as e:
            await self.shutdown()
            raise RuntimeError(f"服务器运行错误: {str(e)}") from e
        finally:
            # Idempotent: shutdown() is a no-op once _server is cleared.
            await self.shutdown()

    async def shutdown(self):
        """Safely stop the server; safe to call multiple times."""
        if self._server:
            self._server.should_exit = True
            await self._server.shutdown()
            self._server = None

    def get_app(self) -> FastAPI:
        """Return the underlying FastAPI instance."""
        return self.app
|
||||
|
||||
|
||||
# Process-wide Server singleton, created lazily on first access.
global_server = None


def get_global_server() -> Server:
    """Return the global Server instance, creating it from HOST/PORT env vars on first use."""
    global global_server
    if global_server is not None:
        return global_server
    host = os.environ["HOST"]
    port = int(os.environ["PORT"])
    global_server = Server(host=host, port=port)
    return global_server
|
||||
9
src/common/tcp_connector.py
Normal file
9
src/common/tcp_connector.py
Normal file
@@ -0,0 +1,9 @@
|
||||
import ssl
|
||||
import certifi
|
||||
import aiohttp
|
||||
|
||||
# Shared TLS context backed by certifi's CA bundle, built once at import time.
ssl_context = ssl.create_default_context(cafile=certifi.where())


async def get_tcp_connector():
    """Return a fresh aiohttp TCPConnector using the shared certifi-based SSL context.

    A new connector is created per call because aiohttp connectors are bound to
    the ClientSession that owns them and are closed with it.
    """
    return aiohttp.TCPConnector(ssl=ssl_context)
|
||||
Reference in New Issue
Block a user