初始化
This commit is contained in:
@@ -12,9 +12,10 @@ import binascii
|
||||
from typing import Optional, Tuple, List, Any
|
||||
from PIL import Image
|
||||
from rich.traceback import install
|
||||
|
||||
from src.common.database.database_model import Emoji
|
||||
from src.common.database.database import db as peewee_db
|
||||
from sqlalchemy import select
|
||||
from src.common.database.database import db
|
||||
from src.common.database.sqlalchemy_database_api import get_session
|
||||
from src.common.database.sqlalchemy_models import Emoji, Images
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config, model_config
|
||||
from src.chat.utils.utils_image import image_path_to_base64, get_image_manager
|
||||
@@ -29,6 +30,8 @@ EMOJI_DIR = os.path.join(BASE_DIR, "emoji") # 表情包存储目录
|
||||
EMOJI_REGISTERED_DIR = os.path.join(BASE_DIR, "emoji_registed") # 已注册的表情包注册目录
|
||||
MAX_EMOJI_FOR_PROMPT = 20 # 最大允许的表情包描述数量于图片替换的 prompt 中
|
||||
|
||||
session = get_session()
|
||||
|
||||
"""
|
||||
还没经过测试,有些地方数据库和内存数据同步可能不完全
|
||||
|
||||
@@ -151,7 +154,7 @@ class MaiEmoji:
|
||||
# 准备数据库记录 for emoji collection
|
||||
emotion_str = ",".join(self.emotion) if self.emotion else ""
|
||||
|
||||
Emoji.create(
|
||||
emoji = Emoji(
|
||||
emoji_hash=self.hash,
|
||||
full_path=self.full_path,
|
||||
format=self.format,
|
||||
@@ -165,6 +168,8 @@ class MaiEmoji:
|
||||
usage_count=self.usage_count,
|
||||
last_used_time=self.last_used_time,
|
||||
)
|
||||
session.add(emoji)
|
||||
session.commit()
|
||||
|
||||
logger.info(f"[注册] 表情包信息保存到数据库: {self.filename} ({self.emotion})")
|
||||
|
||||
@@ -200,7 +205,7 @@ class MaiEmoji:
|
||||
|
||||
# 2. 删除数据库记录
|
||||
try:
|
||||
will_delete_emoji = Emoji.get(Emoji.emoji_hash == self.hash)
|
||||
will_delete_emoji = session.execute(select(Emoji).where(Emoji.emoji_hash == self.hash)).scalar_one_or_none()
|
||||
result = will_delete_emoji.delete_instance() # Returns the number of rows deleted.
|
||||
except Emoji.DoesNotExist: # type: ignore
|
||||
logger.warning(f"[删除] 数据库中未找到哈希值为 {self.hash} 的表情包记录。")
|
||||
@@ -248,7 +253,6 @@ def _emoji_objects_to_readable_list(emoji_objects: List["MaiEmoji"]) -> List[str
|
||||
def _to_emoji_objects(data: Any) -> Tuple[List["MaiEmoji"], int]:
|
||||
emoji_objects = []
|
||||
load_errors = 0
|
||||
# data is now an iterable of Peewee Emoji model instances
|
||||
emoji_data_list = list(data)
|
||||
|
||||
for emoji_data in emoji_data_list: # emoji_data is an Emoji model instance
|
||||
@@ -393,12 +397,17 @@ class EmojiManager:
|
||||
|
||||
def initialize(self) -> None:
|
||||
"""初始化数据库连接和表情目录"""
|
||||
peewee_db.connect(reuse_if_open=True)
|
||||
if peewee_db.is_closed():
|
||||
raise RuntimeError("数据库连接失败")
|
||||
_ensure_emoji_dir()
|
||||
Emoji.create_table(safe=True) # Ensures table exists
|
||||
self._initialized = True
|
||||
try:
|
||||
db.connect(reuse_if_open=True)
|
||||
if db.is_closed():
|
||||
raise RuntimeError("数据库连接失败")
|
||||
_ensure_emoji_dir()
|
||||
self._initialized = True # 标记为已初始化
|
||||
logger.info("EmojiManager初始化成功")
|
||||
except Exception as e:
|
||||
logger.error(f"EmojiManager初始化失败: {e}")
|
||||
self._initialized = False
|
||||
raise
|
||||
|
||||
def _ensure_db(self) -> None:
|
||||
"""确保数据库已初始化"""
|
||||
@@ -410,7 +419,7 @@ class EmojiManager:
|
||||
def record_usage(self, emoji_hash: str) -> None:
|
||||
"""记录表情使用次数"""
|
||||
try:
|
||||
emoji_update = Emoji.get(Emoji.emoji_hash == emoji_hash)
|
||||
emoji_update = session.execute(select(Emoji).where(Emoji.emoji_hash == emoji_hash)).scalar_one_or_none()
|
||||
emoji_update.usage_count += 1
|
||||
emoji_update.last_used_time = time.time() # Update last used time
|
||||
emoji_update.save() # Persist changes to DB
|
||||
@@ -644,10 +653,10 @@ class EmojiManager:
|
||||
"""获取所有表情包并初始化为MaiEmoji类对象,更新 self.emoji_objects"""
|
||||
try:
|
||||
self._ensure_db()
|
||||
logger.debug("[数据库] 开始加载所有表情包记录 (Peewee)...")
|
||||
logger.debug("[数据库] 开始加载所有表情包记录 ...")
|
||||
|
||||
emoji_peewee_instances = Emoji.select()
|
||||
emoji_objects, load_errors = _to_emoji_objects(emoji_peewee_instances)
|
||||
emoji_instances = session.execute(stmt = select(Emoji)).scalars().all()
|
||||
emoji_objects, load_errors = _to_emoji_objects(emoji_instances)
|
||||
|
||||
# 更新内存中的列表和数量
|
||||
self.emoji_objects = emoji_objects
|
||||
@@ -675,15 +684,15 @@ class EmojiManager:
|
||||
self._ensure_db()
|
||||
|
||||
if emoji_hash:
|
||||
query = Emoji.select().where(Emoji.emoji_hash == emoji_hash)
|
||||
session.execute(select(Emoji).where(Emoji.emoji_hash == emoji_hash)).scalars().all()
|
||||
else:
|
||||
logger.warning(
|
||||
"[查询] 未提供 hash,将尝试加载所有表情包,建议使用 get_all_emoji_from_db 更新管理器状态。"
|
||||
)
|
||||
query = Emoji.select()
|
||||
query = session.execute(select(Emoji)).scalars().all()
|
||||
|
||||
emoji_peewee_instances = query
|
||||
emoji_objects, load_errors = _to_emoji_objects(emoji_peewee_instances)
|
||||
emoji_instances = query
|
||||
emoji_objects, load_errors = _to_emoji_objects(emoji_instances)
|
||||
|
||||
if load_errors > 0:
|
||||
logger.warning(f"[查询] 加载过程中出现 {load_errors} 个错误。")
|
||||
@@ -760,7 +769,7 @@ class EmojiManager:
|
||||
# 如果内存中没有,从数据库查找
|
||||
self._ensure_db()
|
||||
try:
|
||||
emoji_record = Emoji.get_or_none(Emoji.emoji_hash == emoji_hash)
|
||||
emoji_record = session.execute(select(Emoji).where(Emoji.emoji_hash == emoji_hash)).scalar_one_or_none()
|
||||
if emoji_record and emoji_record.description:
|
||||
logger.info(f"[缓存命中] 从数据库获取表情包描述: {emoji_record.description[:50]}...")
|
||||
return emoji_record.description
|
||||
@@ -921,9 +930,10 @@ class EmojiManager:
|
||||
# 尝试从Images表获取已有的详细描述(可能在收到表情包时已生成)
|
||||
existing_description = None
|
||||
try:
|
||||
from src.common.database.database_model import Images
|
||||
# from src.common.database.database_model_compat import Images
|
||||
|
||||
existing_image = Images.get_or_none((Images.emoji_hash == image_hash) & (Images.type == "emoji"))
|
||||
stmt = select(Images).where((Images.emoji_hash == image_hash) & (Images.type == "emoji"))
|
||||
existing_image = session.execute(stmt).scalar_one_or_none()
|
||||
if existing_image and existing_image.description:
|
||||
existing_description = existing_image.description
|
||||
logger.info(f"[复用描述] 找到已有详细描述: {existing_description[:50]}...")
|
||||
|
||||
@@ -7,7 +7,9 @@ from datetime import datetime
|
||||
from typing import List, Dict, Optional, Any, Tuple
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.common.database.database_model import Expression
|
||||
from src.common.database.sqlalchemy_database_api import get_session
|
||||
from sqlalchemy import select
|
||||
from src.common.database.sqlalchemy_models import Expression
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import model_config, global_config
|
||||
from src.chat.utils.chat_message_builder import get_raw_msg_by_timestamp_with_chat_inclusive, build_anonymous_messages
|
||||
@@ -20,7 +22,7 @@ DECAY_DAYS = 30 # 30天衰减到0.01
|
||||
DECAY_MIN = 0.01 # 最小衰减值
|
||||
|
||||
logger = get_logger("expressor")
|
||||
|
||||
session = get_session()
|
||||
|
||||
def format_create_date(timestamp: float) -> str:
|
||||
"""
|
||||
@@ -168,30 +170,50 @@ class ExpressionLearner:
|
||||
logger.error(f"为聊天流 {self.chat_name} 触发学习失败: {e}")
|
||||
return False
|
||||
|
||||
# def get_expression_by_chat_id(self) -> Tuple[List[Dict[str, float]], List[Dict[str, float]]]:
|
||||
# """
|
||||
# 获取指定chat_id的style表达方式(已禁用grammar的获取)
|
||||
# 返回的每个表达方式字典中都包含了source_id, 用于后续的更新操作
|
||||
# """
|
||||
# learnt_style_expressions = []
|
||||
def get_expression_by_chat_id(self) -> Tuple[List[Dict[str, float]], List[Dict[str, float]]]:
|
||||
"""
|
||||
获取指定chat_id的style和grammar表达方式
|
||||
返回的每个表达方式字典中都包含了source_id, 用于后续的更新操作
|
||||
"""
|
||||
learnt_style_expressions = []
|
||||
learnt_grammar_expressions = []
|
||||
|
||||
# 直接从数据库查询
|
||||
style_query = session.execute(select(Expression).where((Expression.chat_id == self.chat_id) & (Expression.type == "style")))
|
||||
for expr in style_query.scalars():
|
||||
# 确保create_date存在,如果不存在则使用last_active_time
|
||||
create_date = expr.create_date if expr.create_date is not None else expr.last_active_time
|
||||
learnt_style_expressions.append(
|
||||
{
|
||||
"situation": expr.situation,
|
||||
"style": expr.style,
|
||||
"count": expr.count,
|
||||
"last_active_time": expr.last_active_time,
|
||||
"source_id": self.chat_id,
|
||||
"type": "style",
|
||||
"create_date": create_date,
|
||||
}
|
||||
)
|
||||
grammar_query = session.execute(select(Expression).where((Expression.chat_id == self.chat_id) & (Expression.type == "grammar")))
|
||||
for expr in grammar_query.scalars():
|
||||
# 确保create_date存在,如果不存在则使用last_active_time
|
||||
create_date = expr.create_date if expr.create_date is not None else expr.last_active_time
|
||||
learnt_grammar_expressions.append(
|
||||
{
|
||||
"situation": expr.situation,
|
||||
"style": expr.style,
|
||||
"count": expr.count,
|
||||
"last_active_time": expr.last_active_time,
|
||||
"source_id": self.chat_id,
|
||||
"type": "grammar",
|
||||
"create_date": create_date,
|
||||
}
|
||||
)
|
||||
return learnt_style_expressions, learnt_grammar_expressions
|
||||
|
||||
|
||||
|
||||
|
||||
# # 直接从数据库查询
|
||||
# style_query = Expression.select().where((Expression.chat_id == self.chat_id) & (Expression.type == "style"))
|
||||
# for expr in style_query:
|
||||
# # 确保create_date存在,如果不存在则使用last_active_time
|
||||
# create_date = expr.create_date if expr.create_date is not None else expr.last_active_time
|
||||
# learnt_style_expressions.append(
|
||||
# {
|
||||
# "situation": expr.situation,
|
||||
# "style": expr.style,
|
||||
# "count": expr.count,
|
||||
# "last_active_time": expr.last_active_time,
|
||||
# "source_id": self.chat_id,
|
||||
# "type": "style",
|
||||
# "create_date": create_date,
|
||||
# }
|
||||
# )
|
||||
# return learnt_style_expressions
|
||||
|
||||
|
||||
|
||||
@@ -201,7 +223,7 @@ class ExpressionLearner:
|
||||
"""
|
||||
try:
|
||||
# 获取所有表达方式
|
||||
all_expressions = Expression.select()
|
||||
all_expressions = session.execute(select(Expression)).scalars()
|
||||
|
||||
updated_count = 0
|
||||
deleted_count = 0
|
||||
@@ -217,18 +239,20 @@ class ExpressionLearner:
|
||||
|
||||
if new_count <= 0.01:
|
||||
# 如果count太小,删除这个表达方式
|
||||
expr.delete_instance()
|
||||
session.delete(expr)
|
||||
deleted_count += 1
|
||||
else:
|
||||
# 更新count
|
||||
expr.count = new_count
|
||||
expr.save()
|
||||
updated_count += 1
|
||||
|
||||
session.commit()
|
||||
|
||||
if updated_count > 0 or deleted_count > 0:
|
||||
logger.info(f"全局衰减完成:更新了 {updated_count} 个表达方式,删除了 {deleted_count} 个表达方式")
|
||||
|
||||
except Exception as e:
|
||||
session.rollback()
|
||||
logger.error(f"数据库全局衰减失败: {e}")
|
||||
|
||||
def calculate_decay_factor(self, time_diff_days: float) -> float:
|
||||
@@ -297,23 +321,22 @@ class ExpressionLearner:
|
||||
for chat_id, expr_list in chat_dict.items():
|
||||
for new_expr in expr_list:
|
||||
# 查找是否已存在相似表达方式
|
||||
query = Expression.select().where(
|
||||
query = session.execute(select(Expression).where(
|
||||
(Expression.chat_id == chat_id)
|
||||
& (Expression.type == "style")
|
||||
& (Expression.situation == new_expr["situation"])
|
||||
& (Expression.style == new_expr["style"])
|
||||
)
|
||||
if query.exists():
|
||||
expr_obj = query.get()
|
||||
)).scalar()
|
||||
if query:
|
||||
expr_obj = query
|
||||
# 50%概率替换内容
|
||||
if random.random() < 0.5:
|
||||
expr_obj.situation = new_expr["situation"]
|
||||
expr_obj.style = new_expr["style"]
|
||||
expr_obj.count = expr_obj.count + 1
|
||||
expr_obj.last_active_time = current_time
|
||||
expr_obj.save()
|
||||
else:
|
||||
Expression.create(
|
||||
new_expression = Expression(
|
||||
situation=new_expr["situation"],
|
||||
style=new_expr["style"],
|
||||
count=1,
|
||||
@@ -322,16 +345,18 @@ class ExpressionLearner:
|
||||
type="style",
|
||||
create_date=current_time, # 手动设置创建日期
|
||||
)
|
||||
session.add(new_expression)
|
||||
# 限制最大数量
|
||||
exprs = list(
|
||||
Expression.select()
|
||||
.where((Expression.chat_id == chat_id) & (Expression.type == "style"))
|
||||
.order_by(Expression.count.asc())
|
||||
session.execute(select(Expression)
|
||||
.where((Expression.chat_id == chat_id) & (Expression.type == type))
|
||||
.order_by(Expression.count.asc())).scalars()
|
||||
)
|
||||
if len(exprs) > MAX_EXPRESSION_COUNT:
|
||||
# 删除count最小的多余表达方式
|
||||
for expr in exprs[: len(exprs) - MAX_EXPRESSION_COUNT]:
|
||||
expr.delete_instance()
|
||||
session.delete(expr)
|
||||
session.commit()
|
||||
return learnt_expressions
|
||||
|
||||
async def learn_expression(self, num: int = 10) -> Optional[Tuple[List[Tuple[str, str, str]], str]]:
|
||||
@@ -509,54 +534,35 @@ class ExpressionLearnerManager:
|
||||
logger.warning(f"表达方式缺少必要字段,跳过: {expr}")
|
||||
continue
|
||||
|
||||
# 查重:同chat_id+type+situation+style
|
||||
from src.common.database.database_model import Expression
|
||||
# 查重:同chat_id+type+situation+style
|
||||
|
||||
query = Expression.select().where(
|
||||
(Expression.chat_id == chat_id)
|
||||
& (Expression.type == type_str)
|
||||
& (Expression.situation == situation)
|
||||
& (Expression.style == style_val)
|
||||
query = session.execute(select(Expression).where(
|
||||
(Expression.chat_id == chat_id)
|
||||
& (Expression.type == type_str)
|
||||
& (Expression.situation == situation)
|
||||
& (Expression.style == style_val)
|
||||
)).scalar()
|
||||
if query:
|
||||
expr_obj = query
|
||||
expr_obj.count = max(expr_obj.count, count)
|
||||
expr_obj.last_active_time = max(expr_obj.last_active_time, last_active_time)
|
||||
else:
|
||||
new_expression = Expression(
|
||||
situation=situation,
|
||||
style=style_val,
|
||||
count=count,
|
||||
last_active_time=last_active_time,
|
||||
chat_id=chat_id,
|
||||
type=type_str,
|
||||
create_date=last_active_time, # 迁移时使用last_active_time作为创建时间
|
||||
)
|
||||
if query.exists():
|
||||
expr_obj = query.get()
|
||||
expr_obj.count = max(expr_obj.count, count)
|
||||
expr_obj.last_active_time = max(expr_obj.last_active_time, last_active_time)
|
||||
expr_obj.save()
|
||||
else:
|
||||
Expression.create(
|
||||
situation=situation,
|
||||
style=style_val,
|
||||
count=count,
|
||||
last_active_time=last_active_time,
|
||||
chat_id=chat_id,
|
||||
type=type_str,
|
||||
create_date=last_active_time, # 迁移时使用last_active_time作为创建时间
|
||||
)
|
||||
migrated_count += 1
|
||||
logger.info(f"已迁移 {expr_file} 到数据库,包含 {len(expressions)} 个表达方式")
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(f"JSON解析失败 {expr_file}: {e}")
|
||||
except Exception as e:
|
||||
logger.error(f"迁移表达方式 {expr_file} 失败: {e}")
|
||||
|
||||
# 标记迁移完成
|
||||
try:
|
||||
# 确保done.done文件的父目录存在
|
||||
done_parent_dir = os.path.dirname(done_flag)
|
||||
if not os.path.exists(done_parent_dir):
|
||||
os.makedirs(done_parent_dir, exist_ok=True)
|
||||
logger.debug(f"为done.done创建父目录: {done_parent_dir}")
|
||||
|
||||
with open(done_flag, "w", encoding="utf-8") as f:
|
||||
f.write("done\n")
|
||||
logger.info(f"表达方式JSON迁移已完成,共迁移 {migrated_count} 个表达方式,已写入done.done标记文件")
|
||||
except PermissionError as e:
|
||||
logger.error(f"权限不足,无法写入done.done标记文件: {e}")
|
||||
except OSError as e:
|
||||
logger.error(f"文件系统错误,无法写入done.done标记文件: {e}")
|
||||
except Exception as e:
|
||||
logger.error(f"写入done.done标记文件失败: {e}")
|
||||
session.add(new_expression)
|
||||
migrated_count += 1
|
||||
logger.info(f"已迁移 {expr_file} 到数据库,包含 {len(expressions)} 个表达方式")
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(f"JSON解析失败 {expr_file}: {e}")
|
||||
except Exception as e:
|
||||
logger.error(f"迁移表达方式 {expr_file} 失败: {e}")
|
||||
|
||||
# 检查并处理grammar表达删除
|
||||
if not os.path.exists(done_flag2):
|
||||
@@ -581,18 +587,20 @@ class ExpressionLearnerManager:
|
||||
"""
|
||||
try:
|
||||
# 查找所有create_date为空的表达方式
|
||||
old_expressions = Expression.select().where(Expression.create_date.is_null())
|
||||
old_expressions = session.execute(select(Expression).where(Expression.create_date.is_(None))).scalars()
|
||||
updated_count = 0
|
||||
|
||||
for expr in old_expressions:
|
||||
# 使用last_active_time作为create_date
|
||||
expr.create_date = expr.last_active_time
|
||||
expr.save()
|
||||
updated_count += 1
|
||||
|
||||
session.commit()
|
||||
|
||||
if updated_count > 0:
|
||||
logger.info(f"已为 {updated_count} 个老的表达方式设置创建日期")
|
||||
except Exception as e:
|
||||
session.rollback()
|
||||
logger.error(f"迁移老数据创建日期失败: {e}")
|
||||
|
||||
def delete_all_grammar_expressions(self) -> int:
|
||||
|
||||
@@ -9,8 +9,11 @@ from json_repair import repair_json
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import global_config, model_config
|
||||
from src.common.logger import get_logger
|
||||
from src.common.database.database_model import Expression
|
||||
from sqlalchemy import select
|
||||
from src.common.database.sqlalchemy_models import Expression
|
||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||
from src.common.database.sqlalchemy_database_api import get_session
|
||||
session = get_session()
|
||||
|
||||
logger = get_logger("expression_selector")
|
||||
|
||||
@@ -131,9 +134,12 @@ class ExpressionSelector:
|
||||
related_chat_ids = self.get_related_chat_ids(chat_id)
|
||||
|
||||
# 优化:一次性查询所有相关chat_id的表达方式
|
||||
style_query = Expression.select().where(
|
||||
style_query = session.execute(select(Expression).where(
|
||||
(Expression.chat_id.in_(related_chat_ids)) & (Expression.type == "style")
|
||||
)
|
||||
))
|
||||
grammar_query = session.execute(select(Expression).where(
|
||||
(Expression.chat_id.in_(related_chat_ids)) & (Expression.type == "grammar")
|
||||
))
|
||||
|
||||
style_exprs = [
|
||||
{
|
||||
@@ -146,9 +152,24 @@ class ExpressionSelector:
|
||||
"type": "style",
|
||||
"create_date": expr.create_date if expr.create_date is not None else expr.last_active_time,
|
||||
}
|
||||
for expr in style_query
|
||||
for expr in style_query.scalars()
|
||||
]
|
||||
|
||||
grammar_exprs = [
|
||||
{
|
||||
"situation": expr.situation,
|
||||
"style": expr.style,
|
||||
"count": expr.count,
|
||||
"last_active_time": expr.last_active_time,
|
||||
"source_id": expr.chat_id,
|
||||
"type": "grammar",
|
||||
"create_date": expr.create_date if expr.create_date is not None else expr.last_active_time,
|
||||
}
|
||||
for expr in grammar_query.scalars()
|
||||
]
|
||||
|
||||
style_num = int(total_num * style_percentage)
|
||||
grammar_num = int(total_num * grammar_percentage)
|
||||
# 按权重抽样(使用count作为权重)
|
||||
if style_exprs:
|
||||
style_weights = [expr.get("count", 1) for expr in style_exprs]
|
||||
@@ -174,19 +195,19 @@ class ExpressionSelector:
|
||||
if key not in updates_by_key:
|
||||
updates_by_key[key] = expr
|
||||
for chat_id, expr_type, situation, style in updates_by_key:
|
||||
query = Expression.select().where(
|
||||
query = session.execute(select(Expression).where(
|
||||
(Expression.chat_id == chat_id)
|
||||
& (Expression.type == expr_type)
|
||||
& (Expression.situation == situation)
|
||||
& (Expression.style == style)
|
||||
)
|
||||
if query.exists():
|
||||
expr_obj = query.get()
|
||||
)).scalar()
|
||||
if query:
|
||||
expr_obj = query
|
||||
current_count = expr_obj.count
|
||||
new_count = min(current_count + increment, 5.0)
|
||||
expr_obj.count = new_count
|
||||
expr_obj.last_active_time = time.time()
|
||||
expr_obj.save()
|
||||
session.commit()
|
||||
logger.debug(
|
||||
f"表达方式激活: 原count={current_count:.3f}, 增量={increment}, 新count={new_count:.3f} in db"
|
||||
)
|
||||
|
||||
@@ -1,33 +0,0 @@
|
||||
raise DeprecationWarning("MemoryActiveManager is not used yet, please do not import it")
|
||||
from .lpmmconfig import global_config
|
||||
from .embedding_store import EmbeddingManager
|
||||
from .llm_client import LLMClient
|
||||
from .utils.dyn_topk import dyn_select_top_k
|
||||
|
||||
|
||||
class MemoryActiveManager:
|
||||
def __init__(
|
||||
self,
|
||||
embed_manager: EmbeddingManager,
|
||||
llm_client_embedding: LLMClient,
|
||||
):
|
||||
self.embed_manager = embed_manager
|
||||
self.embedding_client = llm_client_embedding
|
||||
|
||||
def get_activation(self, question: str) -> float:
|
||||
"""获取记忆激活度"""
|
||||
# 生成问题的Embedding
|
||||
question_embedding = self.embedding_client.send_embedding_request("text-embedding", question)
|
||||
# 查询关系库中的相似度
|
||||
rel_search_res = self.embed_manager.relation_embedding_store.search_top_k(question_embedding, 10)
|
||||
|
||||
# 动态过滤阈值
|
||||
rel_scores = dyn_select_top_k(rel_search_res, 0.5, 1.0)
|
||||
if rel_scores[0][1] < global_config["qa"]["params"]["relation_threshold"]:
|
||||
# 未找到相关关系
|
||||
return 0.0
|
||||
|
||||
# 计算激活度
|
||||
activation = sum([item[2] for item in rel_scores]) * 10
|
||||
|
||||
return activation
|
||||
@@ -16,8 +16,10 @@ from rich.traceback import install
|
||||
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import global_config, model_config
|
||||
from src.common.database.database_model import Messages, GraphNodes, GraphEdges # Peewee Models导入
|
||||
from sqlalchemy import select,insert,update,text,delete
|
||||
from src.common.database.sqlalchemy_models import Messages, GraphNodes, GraphEdges # SQLAlchemy Models导入
|
||||
from src.common.logger import get_logger
|
||||
from src.common.database.sqlalchemy_database_api import get_session
|
||||
from src.chat.memory_system.sample_distribution import MemoryBuildScheduler # 分布生成器
|
||||
from src.chat.utils.chat_message_builder import (
|
||||
get_raw_msg_by_timestamp,
|
||||
@@ -37,7 +39,7 @@ def cosine_similarity(v1, v2):
|
||||
|
||||
|
||||
install(extra_lines=3)
|
||||
|
||||
session = get_session()
|
||||
|
||||
def calculate_information_content(text):
|
||||
"""计算文本的信息量(熵)"""
|
||||
@@ -731,13 +733,14 @@ class Hippocampus:
|
||||
memory_items = node_data.get("memory_items", "")
|
||||
# 直接使用完整的记忆内容
|
||||
if memory_items:
|
||||
logger.debug("节点包含完整记忆")
|
||||
# 计算记忆与关键词的相似度
|
||||
memory_words = set(jieba.cut(memory_items))
|
||||
text_words = set(keywords)
|
||||
all_words = memory_words | text_words
|
||||
if all_words:
|
||||
# 计算相似度(虽然这里没有使用,但保持逻辑一致性)
|
||||
logger.debug(f"节点包含 {len(memory_items)} 条记忆")
|
||||
# 计算每条记忆与输入文本的相似度
|
||||
memory_similarities = []
|
||||
for memory in memory_items:
|
||||
# 计算与输入文本的相似度
|
||||
memory_words = set(jieba.cut(memory))
|
||||
text_words = set(jieba.cut(text))
|
||||
all_words = memory_words | text_words
|
||||
v1 = [1 if word in memory_words else 0 for word in all_words]
|
||||
v2 = [1 if word in text_words else 0 for word in all_words]
|
||||
_ = cosine_similarity(v1, v2) # 计算但不使用,用_表示
|
||||
@@ -844,11 +847,6 @@ class Hippocampus:
|
||||
else:
|
||||
activate_map[node] = activation_value
|
||||
|
||||
# 输出激活映射
|
||||
# logger.info("激活映射统计:")
|
||||
# for node, total_activation in sorted(activate_map.items(), key=lambda x: x[1], reverse=True):
|
||||
# logger.info(f"节点 '{node}': 累计激活值 = {total_activation:.2f}")
|
||||
|
||||
# 计算激活节点数与总节点数的比值
|
||||
total_activation = sum(activate_map.values())
|
||||
# logger.debug(f"总激活值: {total_activation:.2f}")
|
||||
@@ -942,10 +940,13 @@ class EntorhinalCortex:
|
||||
for message in messages:
|
||||
# 确保在更新前获取最新的 memorized_times
|
||||
current_memorized_times = message.get("memorized_times", 0)
|
||||
# 使用 Peewee 更新记录
|
||||
Messages.update(memorized_times=current_memorized_times + 1).where(
|
||||
Messages.message_id == message["message_id"]
|
||||
).execute()
|
||||
# 使用 SQLAlchemy 2.0 更新记录
|
||||
session.execute(
|
||||
update(Messages)
|
||||
.where(Messages.message_id == message["message_id"])
|
||||
.values(memorized_times=current_memorized_times + 1)
|
||||
)
|
||||
session.commit()
|
||||
return messages # 直接返回原始的消息列表
|
||||
|
||||
target_timestamp -= 120 # 如果第一次尝试失败,稍微向前调整时间戳再试
|
||||
@@ -959,7 +960,7 @@ class EntorhinalCortex:
|
||||
current_time = datetime.datetime.now().timestamp()
|
||||
|
||||
# 获取数据库中所有节点和内存中所有节点
|
||||
db_nodes = {node.concept: node for node in GraphNodes.select()}
|
||||
db_nodes = {node.concept: node for node in session.execute(select(GraphNodes)).scalars()}
|
||||
memory_nodes = list(self.memory_graph.G.nodes(data=True))
|
||||
|
||||
# 批量准备节点数据
|
||||
@@ -1025,22 +1026,27 @@ class EntorhinalCortex:
|
||||
batch_size = 100
|
||||
for i in range(0, len(nodes_to_create), batch_size):
|
||||
batch = nodes_to_create[i : i + batch_size]
|
||||
GraphNodes.insert_many(batch).execute()
|
||||
session.execute(insert(GraphNodes), batch)
|
||||
session.commit()
|
||||
|
||||
if nodes_to_update:
|
||||
batch_size = 100
|
||||
for i in range(0, len(nodes_to_update), batch_size):
|
||||
batch = nodes_to_update[i : i + batch_size]
|
||||
for node_data in batch:
|
||||
GraphNodes.update(**{k: v for k, v in node_data.items() if k != "concept"}).where(
|
||||
GraphNodes.concept == node_data["concept"]
|
||||
).execute()
|
||||
session.execute(
|
||||
update(GraphNodes)
|
||||
.where(GraphNodes.concept == node_data["concept"])
|
||||
.values(**{k: v for k, v in node_data.items() if k != "concept"})
|
||||
)
|
||||
session.commit()
|
||||
|
||||
if nodes_to_delete:
|
||||
GraphNodes.delete().where(GraphNodes.concept.in_(nodes_to_delete)).execute() # type: ignore
|
||||
session.execute(delete(GraphNodes).where(GraphNodes.concept.in_(nodes_to_delete)))
|
||||
session.commit()
|
||||
|
||||
# 处理边的信息
|
||||
db_edges = list(GraphEdges.select())
|
||||
db_edges = list(session.execute(select(GraphEdges)).scalars())
|
||||
memory_edges = list(self.memory_graph.G.edges(data=True))
|
||||
|
||||
# 创建边的哈希值字典
|
||||
@@ -1092,20 +1098,29 @@ class EntorhinalCortex:
|
||||
batch_size = 100
|
||||
for i in range(0, len(edges_to_create), batch_size):
|
||||
batch = edges_to_create[i : i + batch_size]
|
||||
GraphEdges.insert_many(batch).execute()
|
||||
session.execute(insert(GraphEdges), batch)
|
||||
session.commit()
|
||||
|
||||
if edges_to_update:
|
||||
batch_size = 100
|
||||
for i in range(0, len(edges_to_update), batch_size):
|
||||
batch = edges_to_update[i : i + batch_size]
|
||||
for edge_data in batch:
|
||||
GraphEdges.update(**{k: v for k, v in edge_data.items() if k not in ["source", "target"]}).where(
|
||||
(GraphEdges.source == edge_data["source"]) & (GraphEdges.target == edge_data["target"])
|
||||
).execute()
|
||||
session.execute(
|
||||
update(GraphEdges)
|
||||
.where(
|
||||
(GraphEdges.source == edge_data["source"]) & (GraphEdges.target == edge_data["target"])
|
||||
)
|
||||
.values(**{k: v for k, v in edge_data.items() if k not in ["source", "target"]})
|
||||
)
|
||||
session.commit()
|
||||
|
||||
if edges_to_delete:
|
||||
for source, target in edges_to_delete:
|
||||
GraphEdges.delete().where((GraphEdges.source == source) & (GraphEdges.target == target)).execute()
|
||||
session.execute(
|
||||
delete(GraphEdges).where((GraphEdges.source == source) & (GraphEdges.target == target))
|
||||
)
|
||||
session.commit()
|
||||
|
||||
end_time = time.time()
|
||||
logger.info(f"[数据库] 同步完成,总耗时: {end_time - start_time:.2f}秒")
|
||||
@@ -1118,8 +1133,9 @@ class EntorhinalCortex:
|
||||
|
||||
# 清空数据库
|
||||
clear_start = time.time()
|
||||
GraphNodes.delete().execute()
|
||||
GraphEdges.delete().execute()
|
||||
session.execute(delete(GraphNodes))
|
||||
session.execute(delete(GraphEdges))
|
||||
session.commit()
|
||||
clear_end = time.time()
|
||||
logger.info(f"[数据库] 清空数据库耗时: {clear_end - clear_start:.2f}秒")
|
||||
|
||||
@@ -1186,12 +1202,27 @@ class EntorhinalCortex:
|
||||
logger.error(f"准备边 {source}-{target} 数据时发生错误: {e}")
|
||||
continue
|
||||
|
||||
# 批量插入边
|
||||
# 批量写入节点
|
||||
node_start = time.time()
|
||||
if nodes_data:
|
||||
batch_size = 500 # 增加批量大小
|
||||
for i in range(0, len(nodes_data), batch_size):
|
||||
batch = nodes_data[i : i + batch_size]
|
||||
session.execute(insert(GraphNodes), batch)
|
||||
session.commit()
|
||||
node_end = time.time()
|
||||
logger.info(f"[数据库] 写入 {len(nodes_data)} 个节点耗时: {node_end - node_start:.2f}秒")
|
||||
|
||||
# 批量写入边
|
||||
edge_start = time.time()
|
||||
if edges_data:
|
||||
batch_size = 100
|
||||
batch_size = 500 # 增加批量大小
|
||||
for i in range(0, len(edges_data), batch_size):
|
||||
batch = edges_data[i : i + batch_size]
|
||||
GraphEdges.insert_many(batch).execute()
|
||||
session.execute(insert(GraphEdges), batch)
|
||||
session.commit()
|
||||
edge_end = time.time()
|
||||
logger.info(f"[数据库] 写入 {len(edges_data)} 条边耗时: {edge_end - edge_start:.2f}秒")
|
||||
|
||||
end_time = time.time()
|
||||
logger.info(f"[数据库] 重新同步完成,总耗时: {end_time - start_time:.2f}秒")
|
||||
@@ -1211,9 +1242,7 @@ class EntorhinalCortex:
|
||||
skipped_nodes = 0
|
||||
|
||||
# 从数据库加载所有节点
|
||||
nodes = list(GraphNodes.select())
|
||||
total_nodes = len(nodes)
|
||||
|
||||
nodes = list(session.execute(select(GraphNodes)).scalars())
|
||||
for node in nodes:
|
||||
concept = node.concept
|
||||
try:
|
||||
@@ -1235,8 +1264,10 @@ class EntorhinalCortex:
|
||||
if not node.last_modified:
|
||||
update_data["last_modified"] = current_time
|
||||
|
||||
if update_data:
|
||||
GraphNodes.update(**update_data).where(GraphNodes.concept == concept).execute()
|
||||
session.execute(
|
||||
update(GraphNodes).where(GraphNodes.concept == concept).values(**update_data)
|
||||
)
|
||||
session.commit()
|
||||
|
||||
# 获取时间信息(如果不存在则使用当前时间)
|
||||
created_time = node.created_time or current_time
|
||||
@@ -1256,7 +1287,7 @@ class EntorhinalCortex:
|
||||
continue
|
||||
|
||||
# 从数据库加载所有边
|
||||
edges = list(GraphEdges.select())
|
||||
edges = list(session.execute(select(GraphEdges)).scalars())
|
||||
for edge in edges:
|
||||
source = edge.source
|
||||
target = edge.target
|
||||
@@ -1272,9 +1303,12 @@ class EntorhinalCortex:
|
||||
if not edge.last_modified:
|
||||
update_data["last_modified"] = current_time
|
||||
|
||||
GraphEdges.update(**update_data).where(
|
||||
(GraphEdges.source == source) & (GraphEdges.target == target)
|
||||
).execute()
|
||||
session.execute(
|
||||
update(GraphEdges)
|
||||
.where((GraphEdges.source == source) & (GraphEdges.target == target))
|
||||
.values(**update_data)
|
||||
)
|
||||
session.commit()
|
||||
|
||||
# 获取时间信息(如果不存在则使用当前时间)
|
||||
created_time = edge.created_time or current_time
|
||||
@@ -1398,7 +1432,6 @@ class ParahippocampalGyrus:
|
||||
all_words = topic_words | existing_words
|
||||
v1 = [1 if word in topic_words else 0 for word in all_words]
|
||||
v2 = [1 if word in existing_words else 0 for word in all_words]
|
||||
|
||||
similarity = cosine_similarity(v1, v2)
|
||||
|
||||
if similarity >= 0.7:
|
||||
@@ -1502,7 +1535,7 @@ class ParahippocampalGyrus:
|
||||
check_nodes_count = max(1, min(len(all_nodes), int(len(all_nodes) * percentage)))
|
||||
check_edges_count = max(1, min(len(all_edges), int(len(all_edges) * percentage)))
|
||||
|
||||
# 只有在有足够的节点和边时才进行采样
|
||||
# 只有在有足够的节点和边时进行采样
|
||||
if len(all_nodes) >= check_nodes_count and len(all_edges) >= check_edges_count:
|
||||
try:
|
||||
nodes_to_check = random.sample(all_nodes, check_nodes_count)
|
||||
@@ -1548,6 +1581,11 @@ class ParahippocampalGyrus:
|
||||
|
||||
logger.info("[遗忘] 开始检查节点...")
|
||||
node_check_start = time.time()
|
||||
|
||||
# 初始化整合相关变量
|
||||
merged_count = 0
|
||||
nodes_modified = set()
|
||||
|
||||
for node in nodes_to_check:
|
||||
# 检查节点是否存在,以防在迭代中被移除(例如边移除导致)
|
||||
if node not in self.memory_graph.G:
|
||||
@@ -1567,64 +1605,91 @@ class ParahippocampalGyrus:
|
||||
logger.warning(f"[遗忘] 移除空节点 {node} 时发生错误(可能已被移除): {e}")
|
||||
continue # 处理下一个节点
|
||||
|
||||
# --- 如果节点不为空,则执行原来的不活跃检查和随机移除逻辑 ---
|
||||
# 检查节点的最后修改时间,如果太旧则尝试遗忘
|
||||
last_modified = node_data.get("last_modified", current_time)
|
||||
node_weight = node_data.get("weight", 1.0)
|
||||
|
||||
# 条件1:检查是否长时间未修改 (使用配置的遗忘时间)
|
||||
time_threshold = 3600 * global_config.memory.memory_forget_time
|
||||
|
||||
# 基于权重调整遗忘阈值:权重越高,需要更长时间才能被遗忘
|
||||
# 权重为1时使用默认阈值,权重越高阈值越大(越难遗忘)
|
||||
adjusted_threshold = time_threshold * node_weight
|
||||
|
||||
if current_time - last_modified > adjusted_threshold and memory_items:
|
||||
# 既然每个节点现在是完整记忆,直接删除整个节点
|
||||
try:
|
||||
self.memory_graph.G.remove_node(node)
|
||||
node_changes["removed"].append(f"{node}(长时间未修改,权重{node_weight:.1f})")
|
||||
logger.debug(f"[遗忘] 移除了长时间未修改的节点: {node} (权重: {node_weight:.1f})")
|
||||
except nx.NetworkXError as e:
|
||||
logger.warning(f"[遗忘] 移除节点 {node} 时发生错误(可能已被移除): {e}")
|
||||
continue
|
||||
if current_time - last_modified > 3600 * global_config.memory.memory_forget_time:
|
||||
# 随机遗忘一条记忆
|
||||
if len(memory_items) > 1:
|
||||
removed_item = self.memory_graph.forget_topic(node)
|
||||
if removed_item:
|
||||
node_changes["reduced"].append(f"{node} (移除: {removed_item[:50]}...)")
|
||||
elif len(memory_items) == 1:
|
||||
# 如果只有一条记忆,检查是否应该完全移除节点
|
||||
try:
|
||||
self.memory_graph.G.remove_node(node)
|
||||
node_changes["removed"].append(f"{node} (最后记忆)")
|
||||
except nx.NetworkXError as e:
|
||||
logger.warning(f"[遗忘] 移除节点 {node} 时发生错误: {e}")
|
||||
|
||||
# 检查节点内是否有相似的记忆项需要整合
|
||||
if len(memory_items) > 1:
|
||||
merged_in_this_node = False
|
||||
items_to_remove = []
|
||||
|
||||
for i in range(len(memory_items)):
|
||||
for j in range(i + 1, len(memory_items)):
|
||||
similarity = self._calculate_item_similarity(memory_items[i], memory_items[j])
|
||||
if similarity > 0.8: # 相似度阈值
|
||||
# 合并相似记忆项
|
||||
longer_item = memory_items[i] if len(memory_items[i]) > len(memory_items[j]) else memory_items[j]
|
||||
shorter_item = memory_items[j] if len(memory_items[i]) > len(memory_items[j]) else memory_items[i]
|
||||
|
||||
# 保留更长的记忆项,标记短的用于删除
|
||||
if shorter_item not in items_to_remove:
|
||||
items_to_remove.append(shorter_item)
|
||||
merged_count += 1
|
||||
merged_in_this_node = True
|
||||
logger.debug(f"[整合] 在节点 {node} 中合并相似记忆: {shorter_item[:30]}... -> {longer_item[:30]}...")
|
||||
|
||||
# 移除被合并的记忆项
|
||||
if items_to_remove:
|
||||
for item in items_to_remove:
|
||||
if item in memory_items:
|
||||
memory_items.remove(item)
|
||||
nodes_modified.add(node)
|
||||
# 更新节点的记忆项
|
||||
self.memory_graph.G.nodes[node]["memory_items"] = memory_items
|
||||
self.memory_graph.G.nodes[node]["last_modified"] = current_time
|
||||
|
||||
node_check_end = time.time()
|
||||
logger.info(f"[遗忘] 节点检查耗时: {node_check_end - node_check_start:.2f}秒")
|
||||
|
||||
if any(edge_changes.values()) or any(node_changes.values()):
|
||||
# 输出变化统计
|
||||
if edge_changes["weakened"]:
|
||||
logger.info(f"[遗忘] 减弱了 {len(edge_changes['weakened'])} 个连接")
|
||||
if edge_changes["removed"]:
|
||||
logger.info(f"[遗忘] 移除了 {len(edge_changes['removed'])} 个连接")
|
||||
if node_changes["reduced"]:
|
||||
logger.info(f"[遗忘] 减少了 {len(node_changes['reduced'])} 个节点的记忆")
|
||||
if node_changes["removed"]:
|
||||
logger.info(f"[遗忘] 移除了 {len(node_changes['removed'])} 个节点")
|
||||
|
||||
# 检查是否有变化需要同步到数据库
|
||||
has_changes = (
|
||||
edge_changes["weakened"] or
|
||||
edge_changes["removed"] or
|
||||
node_changes["reduced"] or
|
||||
node_changes["removed"] or
|
||||
merged_count > 0
|
||||
)
|
||||
|
||||
if has_changes:
|
||||
logger.info("[遗忘] 开始将变更同步到数据库...")
|
||||
sync_start = time.time()
|
||||
|
||||
await self.hippocampus.entorhinal_cortex.resync_memory_to_db()
|
||||
|
||||
await self.hippocampus.entorhinal_cortex.sync_memory_to_db()
|
||||
sync_end = time.time()
|
||||
logger.info(f"[遗忘] 数据库同步耗时: {sync_end - sync_start:.2f}秒")
|
||||
|
||||
# 汇总输出所有变化
|
||||
logger.info("[遗忘] 遗忘操作统计:")
|
||||
if edge_changes["weakened"]:
|
||||
logger.info(
|
||||
f"[遗忘] 减弱的连接 ({len(edge_changes['weakened'])}个): {', '.join(edge_changes['weakened'])}"
|
||||
)
|
||||
|
||||
if edge_changes["removed"]:
|
||||
logger.info(
|
||||
f"[遗忘] 移除的连接 ({len(edge_changes['removed'])}个): {', '.join(edge_changes['removed'])}"
|
||||
)
|
||||
|
||||
if node_changes["reduced"]:
|
||||
logger.info(
|
||||
f"[遗忘] 减少记忆的节点 ({len(node_changes['reduced'])}个): {', '.join(node_changes['reduced'])}"
|
||||
)
|
||||
|
||||
if node_changes["removed"]:
|
||||
logger.info(
|
||||
f"[遗忘] 移除的节点 ({len(node_changes['removed'])}个): {', '.join(node_changes['removed'])}"
|
||||
)
|
||||
if merged_count > 0:
|
||||
logger.info(f"[整合] 共合并了 {merged_count} 对相似记忆项,分布在 {len(nodes_modified)} 个节点中。")
|
||||
sync_start = time.time()
|
||||
logger.info("[整合] 开始将变更同步到数据库...")
|
||||
# 使用 resync 更安全地处理删除和添加
|
||||
await self.hippocampus.entorhinal_cortex.resync_memory_to_db()
|
||||
sync_end = time.time()
|
||||
logger.info(f"[整合] 数据库同步耗时: {sync_end - sync_start:.2f}秒")
|
||||
else:
|
||||
logger.info("[遗忘] 本次检查没有节点或连接满足遗忘条件")
|
||||
|
||||
end_time = time.time()
|
||||
logger.info(f"[遗忘] 总耗时: {end_time - start_time:.2f}秒")
|
||||
|
||||
logger.info("[整合] 本次检查未发现需要合并的记忆项。")
|
||||
|
||||
|
||||
|
||||
@@ -1734,10 +1799,7 @@ class HippocampusManager:
|
||||
"""获取所有节点名称的公共接口"""
|
||||
if not self._initialized:
|
||||
raise RuntimeError("HippocampusManager 尚未初始化,请先调用 initialize 方法")
|
||||
return self._hippocampus.get_all_node_names()
|
||||
|
||||
|
||||
# 创建全局实例
|
||||
hippocampus_manager = HippocampusManager()
|
||||
|
||||
|
||||
|
||||
@@ -10,12 +10,13 @@ from datetime import datetime, timedelta
|
||||
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.common.logger import get_logger
|
||||
from src.common.database.database_model import Memory # Peewee Models导入
|
||||
from src.common.database.sqlalchemy_models import Memory # SQLAlchemy Models导入
|
||||
from src.common.database.sqlalchemy_database_api import get_session
|
||||
from src.config.config import model_config
|
||||
|
||||
|
||||
from sqlalchemy import select
|
||||
logger = get_logger(__name__)
|
||||
|
||||
session = get_session()
|
||||
|
||||
class MemoryItem:
|
||||
def __init__(self, memory_id: str, chat_id: str, memory_text: str, keywords: list[str]):
|
||||
@@ -120,7 +121,8 @@ class InstantMemory:
|
||||
create_time=memory_item.create_time,
|
||||
last_view_time=memory_item.last_view_time,
|
||||
)
|
||||
memory.save()
|
||||
session.add(memory)
|
||||
session.commit()
|
||||
|
||||
async def get_memory(self, target: str):
|
||||
from json_repair import repair_json
|
||||
@@ -166,13 +168,13 @@ class InstantMemory:
|
||||
if start_time and end_time:
|
||||
start_ts = start_time.timestamp()
|
||||
end_ts = end_time.timestamp()
|
||||
query = Memory.select().where(
|
||||
query = session.execute(select(Memory).where(
|
||||
(Memory.chat_id == self.chat_id)
|
||||
& (Memory.create_time >= start_ts) # type: ignore
|
||||
& (Memory.create_time < end_ts) # type: ignore
|
||||
)
|
||||
& (Memory.create_time >= start_ts)
|
||||
& (Memory.create_time < end_ts)
|
||||
)).scalars()
|
||||
else:
|
||||
query = Memory.select().where(Memory.chat_id == self.chat_id)
|
||||
query = session.execute(select(Memory).where(Memory.chat_id == self.chat_id)).scalars()
|
||||
|
||||
for mem in query:
|
||||
# 对每条记忆
|
||||
|
||||
@@ -8,8 +8,12 @@ from maim_message import GroupInfo, UserInfo
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.common.database.database import db
|
||||
from src.common.database.database_model import ChatStreams # 新增导入
|
||||
|
||||
from sqlalchemy import select, text
|
||||
from sqlalchemy.dialects.sqlite import insert as sqlite_insert
|
||||
from sqlalchemy.dialects.mysql import insert as mysql_insert
|
||||
from src.common.database.sqlalchemy_models import ChatStreams # 新增导入
|
||||
from src.common.database.sqlalchemy_database_api import get_session
|
||||
from src.config.config import global_config # 新增导入
|
||||
# 避免循环导入,使用TYPE_CHECKING进行类型提示
|
||||
if TYPE_CHECKING:
|
||||
from .message import MessageRecv
|
||||
@@ -19,7 +23,7 @@ install(extra_lines=3)
|
||||
|
||||
|
||||
logger = get_logger("chat_stream")
|
||||
|
||||
session = get_session()
|
||||
|
||||
class ChatMessageContext:
|
||||
"""聊天消息上下文,存储消息的上下文信息"""
|
||||
@@ -131,7 +135,8 @@ class ChatManager:
|
||||
try:
|
||||
db.connect(reuse_if_open=True)
|
||||
# 确保 ChatStreams 表存在
|
||||
db.create_tables([ChatStreams], safe=True)
|
||||
session.execute(text("CREATE TABLE IF NOT EXISTS chat_streams (stream_id TEXT PRIMARY KEY, platform TEXT, create_time REAL, last_active_time REAL, user_platform TEXT, user_id TEXT, user_nickname TEXT, user_cardname TEXT, group_platform TEXT, group_id TEXT, group_name TEXT)"))
|
||||
session.commit()
|
||||
except Exception as e:
|
||||
logger.error(f"数据库连接或 ChatStreams 表创建失败: {e}")
|
||||
|
||||
@@ -231,7 +236,7 @@ class ChatManager:
|
||||
|
||||
# 检查数据库中是否存在
|
||||
def _db_find_stream_sync(s_id: str):
|
||||
return ChatStreams.get_or_none(ChatStreams.stream_id == s_id)
|
||||
return session.execute(select(ChatStreams).where(ChatStreams.stream_id == s_id)).scalar()
|
||||
|
||||
model_instance = await asyncio.to_thread(_db_find_stream_sync, stream_id)
|
||||
|
||||
@@ -342,7 +347,28 @@ class ChatManager:
|
||||
"group_name": group_info_d["group_name"] if group_info_d else "",
|
||||
}
|
||||
|
||||
ChatStreams.replace(stream_id=s_data_dict["stream_id"], **fields_to_save).execute()
|
||||
# 根据数据库类型选择插入语句
|
||||
if global_config.database.database_type == "sqlite":
|
||||
stmt = sqlite_insert(ChatStreams).values(stream_id=s_data_dict["stream_id"], **fields_to_save)
|
||||
stmt = stmt.on_conflict_do_update(
|
||||
index_elements=['stream_id'],
|
||||
set_=fields_to_save
|
||||
)
|
||||
elif global_config.database.database_type == "mysql":
|
||||
stmt = mysql_insert(ChatStreams).values(stream_id=s_data_dict["stream_id"], **fields_to_save)
|
||||
stmt = stmt.on_duplicate_key_update(
|
||||
**{key: value for key, value in fields_to_save.items() if key != "stream_id"}
|
||||
)
|
||||
else:
|
||||
# 默认使用通用插入,尝试SQLite语法
|
||||
stmt = sqlite_insert(ChatStreams).values(stream_id=s_data_dict["stream_id"], **fields_to_save)
|
||||
stmt = stmt.on_conflict_do_update(
|
||||
index_elements=['stream_id'],
|
||||
set_=fields_to_save
|
||||
)
|
||||
|
||||
session.execute(stmt)
|
||||
session.commit()
|
||||
|
||||
try:
|
||||
await asyncio.to_thread(_db_save_stream_sync, stream_data_dict)
|
||||
@@ -361,7 +387,7 @@ class ChatManager:
|
||||
|
||||
def _db_load_all_streams_sync():
|
||||
loaded_streams_data = []
|
||||
for model_instance in ChatStreams.select():
|
||||
for model_instance in session.execute(select(ChatStreams)).scalars():
|
||||
user_info_data = {
|
||||
"platform": model_instance.user_platform,
|
||||
"user_id": model_instance.user_id,
|
||||
|
||||
@@ -1,16 +1,18 @@
|
||||
import re
|
||||
import json
|
||||
import traceback
|
||||
import json
|
||||
from typing import Union
|
||||
|
||||
from src.common.database.database_model import Messages, Images
|
||||
from src.common.database.sqlalchemy_models import Messages, Images
|
||||
from src.common.logger import get_logger
|
||||
from .chat_stream import ChatStream
|
||||
from .message import MessageSending, MessageRecv
|
||||
from src.common.database.sqlalchemy_database_api import get_session
|
||||
from sqlalchemy import select, update, desc
|
||||
|
||||
logger = get_logger("message_storage")
|
||||
|
||||
|
||||
class MessageStorage:
|
||||
@staticmethod
|
||||
def _serialize_keywords(keywords) -> str:
|
||||
@@ -33,15 +35,11 @@ class MessageStorage:
|
||||
async def store_message(message: Union[MessageSending, MessageRecv], chat_stream: ChatStream) -> None:
|
||||
"""存储消息到数据库"""
|
||||
try:
|
||||
# 莫越权 救世啊
|
||||
# 过滤敏感信息的正则模式
|
||||
pattern = r"<MainRule>.*?</MainRule>|<schedule>.*?</schedule>|<UserMessage>.*?</UserMessage>"
|
||||
|
||||
# print(message)
|
||||
|
||||
processed_plain_text = message.processed_plain_text
|
||||
|
||||
# print(processed_plain_text)
|
||||
|
||||
if processed_plain_text:
|
||||
processed_plain_text = MessageStorage.replace_image_descriptions(processed_plain_text)
|
||||
filtered_processed_plain_text = re.sub(pattern, "", processed_plain_text, flags=re.DOTALL)
|
||||
@@ -93,11 +91,16 @@ class MessageStorage:
|
||||
# 安全地获取 user_info, 如果为 None 则视为空字典 (以防万一)
|
||||
user_info_from_chat = chat_info_dict.get("user_info") or {}
|
||||
|
||||
Messages.create(
|
||||
# 将priority_info字典序列化为JSON字符串,以便存储到数据库的Text字段
|
||||
priority_info_json = json.dumps(priority_info) if priority_info else None
|
||||
|
||||
# 获取数据库会话
|
||||
session = get_session()
|
||||
|
||||
new_message = Messages(
|
||||
message_id=msg_id,
|
||||
time=float(message.message_info.time), # type: ignore
|
||||
time=float(message.message_info.time),
|
||||
chat_id=chat_stream.stream_id,
|
||||
# Flattened chat_info
|
||||
reply_to=reply_to,
|
||||
is_mentioned=is_mentioned,
|
||||
chat_info_stream_id=chat_info_dict.get("stream_id"),
|
||||
@@ -111,18 +114,16 @@ class MessageStorage:
|
||||
chat_info_group_name=group_info_from_chat.get("group_name"),
|
||||
chat_info_create_time=float(chat_info_dict.get("create_time", 0.0)),
|
||||
chat_info_last_active_time=float(chat_info_dict.get("last_active_time", 0.0)),
|
||||
# Flattened user_info (message sender)
|
||||
user_platform=user_info_dict.get("platform"),
|
||||
user_id=user_info_dict.get("user_id"),
|
||||
user_nickname=user_info_dict.get("user_nickname"),
|
||||
user_cardname=user_info_dict.get("user_cardname"),
|
||||
# Text content
|
||||
processed_plain_text=filtered_processed_plain_text,
|
||||
display_message=filtered_display_message,
|
||||
memorized_times=message.memorized_times,
|
||||
interest_value=interest_value,
|
||||
priority_mode=priority_mode,
|
||||
priority_info=priority_info,
|
||||
priority_info=priority_info_json,
|
||||
is_emoji=is_emoji,
|
||||
is_picid=is_picid,
|
||||
is_notify=is_notify,
|
||||
@@ -131,35 +132,44 @@ class MessageStorage:
|
||||
key_words_lite=key_words_lite,
|
||||
selected_expressions=selected_expressions,
|
||||
)
|
||||
session.add(new_message)
|
||||
session.commit()
|
||||
except Exception:
|
||||
logger.exception("存储消息失败")
|
||||
logger.error(f"消息:{message}")
|
||||
traceback.print_exc()
|
||||
|
||||
# 如果需要其他存储相关的函数,可以在这里添加
|
||||
@staticmethod
|
||||
async def update_message(
|
||||
message: MessageRecv,
|
||||
) -> None: # 用于实时更新数据库的自身发送消息ID,目前能处理text,reply,image和emoji
|
||||
"""更新最新一条匹配消息的message_id"""
|
||||
async def update_message(message):
|
||||
"""更新消息ID"""
|
||||
try:
|
||||
if message.message_segment.type == "notify":
|
||||
mmc_message_id = message.message_segment.data.get("echo") # type: ignore
|
||||
qq_message_id = message.message_segment.data.get("actual_id") # type: ignore
|
||||
mmc_message_id = message.message_info.message_id # 修复:正确访问message_id
|
||||
if message.message_segment.type == "text":
|
||||
qq_message_id = message.message_segment.data.get("id")
|
||||
elif message.message_segment.type == "reply":
|
||||
qq_message_id = message.message_segment.data.get("id")
|
||||
else:
|
||||
logger.info(f"更新消息ID错误,seg类型为{message.message_segment.type}")
|
||||
return
|
||||
if not qq_message_id:
|
||||
logger.info("消息不存在message_id,无法更新")
|
||||
return
|
||||
if matched_message := (
|
||||
Messages.select().where((Messages.message_id == mmc_message_id)).order_by(Messages.time.desc()).first()
|
||||
):
|
||||
# 更新找到的消息记录
|
||||
Messages.update(message_id=qq_message_id).where(Messages.id == matched_message.id).execute() # type: ignore
|
||||
logger.debug(f"更新消息ID成功: {matched_message.message_id} -> {qq_message_id}")
|
||||
else:
|
||||
logger.debug("未找到匹配的消息")
|
||||
|
||||
# 使用上下文管理器确保session正确管理
|
||||
from src.common.database.sqlalchemy_models import get_db_session
|
||||
with get_db_session() as session:
|
||||
matched_message = session.execute(
|
||||
select(Messages).where(Messages.message_id == mmc_message_id).order_by(desc(Messages.time))
|
||||
).scalar()
|
||||
|
||||
if matched_message:
|
||||
session.execute(
|
||||
update(Messages).where(Messages.id == matched_message.id).values(message_id=qq_message_id)
|
||||
)
|
||||
# session.commit() 会在上下文管理器中自动调用
|
||||
logger.debug(f"更新消息ID成功: {matched_message.message_id} -> {qq_message_id}")
|
||||
else:
|
||||
logger.debug("未找到匹配的消息")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"更新消息ID失败: {e}")
|
||||
@@ -178,10 +188,12 @@ class MessageStorage:
|
||||
def replace_match(match):
|
||||
description = match.group(1).strip()
|
||||
try:
|
||||
image_record = (
|
||||
Images.select().where(Images.description == description).order_by(Images.timestamp.desc()).first()
|
||||
)
|
||||
return f"[picid:{image_record.image_id}]" if image_record else match.group(0)
|
||||
from src.common.database.sqlalchemy_models import get_db_session
|
||||
with get_db_session() as session:
|
||||
image_record = session.execute(
|
||||
select(Images).where(Images.description == description).order_by(desc(Images.timestamp))
|
||||
).scalar()
|
||||
return f"[picid:{image_record.image_id}]" if image_record else match.group(0)
|
||||
except Exception:
|
||||
return match.group(0)
|
||||
|
||||
|
||||
@@ -7,13 +7,14 @@ from rich.traceback import install
|
||||
|
||||
from src.config.config import global_config
|
||||
from src.common.message_repository import find_messages, count_messages
|
||||
from src.common.database.database_model import ActionRecords
|
||||
from src.common.database.database_model import Images
|
||||
from src.person_info.person_info import Person,get_person_id
|
||||
from src.common.database.sqlalchemy_models import ActionRecords, Images
|
||||
from src.person_info.person_info import PersonInfoManager, get_person_info_manager
|
||||
from src.chat.utils.utils import translate_timestamp_to_human_readable, assign_message_ids
|
||||
from src.common.database.sqlalchemy_database_api import get_session
|
||||
from sqlalchemy import select, and_
|
||||
|
||||
install(extra_lines=3)
|
||||
|
||||
session = get_session()
|
||||
|
||||
def replace_user_references_sync(
|
||||
content: str,
|
||||
@@ -254,50 +255,90 @@ def get_actions_by_timestamp_with_chat(
|
||||
limit_mode: str = "latest",
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""获取在特定聊天从指定时间戳到指定时间戳的动作记录,按时间升序排序,返回动作记录列表"""
|
||||
query = ActionRecords.select().where(
|
||||
(ActionRecords.chat_id == chat_id)
|
||||
& (ActionRecords.time > timestamp_start) # type: ignore
|
||||
& (ActionRecords.time < timestamp_end) # type: ignore
|
||||
)
|
||||
query = session.execute(select(ActionRecords).where(
|
||||
and_(
|
||||
ActionRecords.chat_id == chat_id,
|
||||
ActionRecords.time > timestamp_start,
|
||||
ActionRecords.time < timestamp_end
|
||||
)
|
||||
))
|
||||
|
||||
if limit > 0:
|
||||
if limit_mode == "latest":
|
||||
query = query.order_by(ActionRecords.time.desc()).limit(limit)
|
||||
query = session.execute(select(ActionRecords).where(
|
||||
and_(
|
||||
ActionRecords.chat_id == chat_id,
|
||||
ActionRecords.time > timestamp_start,
|
||||
ActionRecords.time < timestamp_end
|
||||
)
|
||||
).order_by(ActionRecords.time.desc()).limit(limit))
|
||||
# 获取后需要反转列表,以保持最终输出为时间升序
|
||||
actions = list(query)
|
||||
return [action.__data__ for action in reversed(actions)]
|
||||
actions = list(query.scalars())
|
||||
return [action.__dict__ for action in reversed(actions)]
|
||||
else: # earliest
|
||||
query = query.order_by(ActionRecords.time.asc()).limit(limit)
|
||||
query = session.execute(select(ActionRecords).where(
|
||||
and_(
|
||||
ActionRecords.chat_id == chat_id,
|
||||
ActionRecords.time > timestamp_start,
|
||||
ActionRecords.time < timestamp_end
|
||||
)
|
||||
).order_by(ActionRecords.time.asc()).limit(limit))
|
||||
else:
|
||||
query = query.order_by(ActionRecords.time.asc())
|
||||
query = session.execute(select(ActionRecords).where(
|
||||
and_(
|
||||
ActionRecords.chat_id == chat_id,
|
||||
ActionRecords.time > timestamp_start,
|
||||
ActionRecords.time < timestamp_end
|
||||
)
|
||||
).order_by(ActionRecords.time.asc()))
|
||||
|
||||
actions = list(query)
|
||||
return [action.__data__ for action in actions]
|
||||
actions = list(query.scalars())
|
||||
return [action.__dict__ for action in actions]
|
||||
|
||||
|
||||
def get_actions_by_timestamp_with_chat_inclusive(
|
||||
chat_id: str, timestamp_start: float, timestamp_end: float, limit: int = 0, limit_mode: str = "latest"
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""获取在特定聊天从指定时间戳到指定时间戳的动作记录(包含边界),按时间升序排序,返回动作记录列表"""
|
||||
query = ActionRecords.select().where(
|
||||
(ActionRecords.chat_id == chat_id)
|
||||
& (ActionRecords.time >= timestamp_start) # type: ignore
|
||||
& (ActionRecords.time <= timestamp_end) # type: ignore
|
||||
)
|
||||
query = session.execute(select(ActionRecords).where(
|
||||
and_(
|
||||
ActionRecords.chat_id == chat_id,
|
||||
ActionRecords.time >= timestamp_start,
|
||||
ActionRecords.time <= timestamp_end
|
||||
)
|
||||
))
|
||||
|
||||
if limit > 0:
|
||||
if limit_mode == "latest":
|
||||
query = query.order_by(ActionRecords.time.desc()).limit(limit)
|
||||
query = session.execute(select(ActionRecords).where(
|
||||
and_(
|
||||
ActionRecords.chat_id == chat_id,
|
||||
ActionRecords.time >= timestamp_start,
|
||||
ActionRecords.time <= timestamp_end
|
||||
)
|
||||
).order_by(ActionRecords.time.desc()).limit(limit))
|
||||
# 获取后需要反转列表,以保持最终输出为时间升序
|
||||
actions = list(query)
|
||||
return [action.__data__ for action in reversed(actions)]
|
||||
actions = list(query.scalars())
|
||||
return [action.__dict__ for action in reversed(actions)]
|
||||
else: # earliest
|
||||
query = query.order_by(ActionRecords.time.asc()).limit(limit)
|
||||
query = session.execute(select(ActionRecords).where(
|
||||
and_(
|
||||
ActionRecords.chat_id == chat_id,
|
||||
ActionRecords.time >= timestamp_start,
|
||||
ActionRecords.time <= timestamp_end
|
||||
)
|
||||
).order_by(ActionRecords.time.asc()).limit(limit))
|
||||
else:
|
||||
query = query.order_by(ActionRecords.time.asc())
|
||||
query = session.execute(select(ActionRecords).where(
|
||||
and_(
|
||||
ActionRecords.chat_id == chat_id,
|
||||
ActionRecords.time >= timestamp_start,
|
||||
ActionRecords.time <= timestamp_end
|
||||
)
|
||||
).order_by(ActionRecords.time.asc()))
|
||||
|
||||
actions = list(query)
|
||||
return [action.__data__ for action in actions]
|
||||
actions = list(query.scalars())
|
||||
return [action.__dict__ for action in actions]
|
||||
|
||||
|
||||
def get_raw_msg_by_timestamp_random(
|
||||
@@ -700,7 +741,7 @@ def build_pic_mapping_info(pic_id_mapping: Dict[str, str]) -> str:
|
||||
# 从数据库中获取图片描述
|
||||
description = "内容正在阅读,请稍等"
|
||||
try:
|
||||
image = Images.get_or_none(Images.image_id == pic_id)
|
||||
image = session.execute(select(Images).where(Images.image_id == pic_id)).scalar()
|
||||
if image and image.description:
|
||||
description = image.description
|
||||
except Exception:
|
||||
@@ -813,7 +854,7 @@ def build_readable_messages(
|
||||
timestamp_mode: str = "relative",
|
||||
read_mark: float = 0.0,
|
||||
truncate: bool = False,
|
||||
show_actions: bool = False,
|
||||
show_actions: bool = True,
|
||||
show_pic: bool = True,
|
||||
message_id_list: Optional[List[Dict[str, Any]]] = None,
|
||||
) -> str: # sourcery skip: extract-method
|
||||
@@ -846,21 +887,21 @@ def build_readable_messages(
|
||||
chat_id = copy_messages[0].get("chat_id") if copy_messages else None
|
||||
|
||||
# 获取这个时间范围内的动作记录,并匹配chat_id
|
||||
actions_in_range = (
|
||||
ActionRecords.select()
|
||||
.where(
|
||||
(ActionRecords.time >= min_time) & (ActionRecords.time <= max_time) & (ActionRecords.chat_id == chat_id)
|
||||
actions_in_range = session.execute(select(ActionRecords).where(
|
||||
and_(
|
||||
ActionRecords.time >= min_time,
|
||||
ActionRecords.time <= max_time,
|
||||
ActionRecords.chat_id == chat_id
|
||||
)
|
||||
.order_by(ActionRecords.time)
|
||||
)
|
||||
).order_by(ActionRecords.time)).scalars()
|
||||
|
||||
# 获取最新消息之后的第一个动作记录
|
||||
action_after_latest = (
|
||||
ActionRecords.select()
|
||||
.where((ActionRecords.time > max_time) & (ActionRecords.chat_id == chat_id))
|
||||
.order_by(ActionRecords.time)
|
||||
.limit(1)
|
||||
)
|
||||
action_after_latest = session.execute(select(ActionRecords).where(
|
||||
and_(
|
||||
ActionRecords.time > max_time,
|
||||
ActionRecords.chat_id == chat_id
|
||||
)
|
||||
).order_by(ActionRecords.time).limit(1)).scalars()
|
||||
|
||||
# 合并两部分动作记录
|
||||
actions = list(actions_in_range) + list(action_after_latest)
|
||||
|
||||
@@ -6,13 +6,52 @@ from datetime import datetime, timedelta
|
||||
from typing import Any, Dict, Tuple, List
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.common.database.database import db
|
||||
from src.common.database.database_model import OnlineTime, LLMUsage, Messages
|
||||
from src.common.database.sqlalchemy_models import OnlineTime, LLMUsage, Messages
|
||||
from src.common.database.sqlalchemy_database_api import get_db_session, db_query, db_save, db_get
|
||||
from src.manager.async_task_manager import AsyncTask
|
||||
from src.manager.local_store_manager import local_storage
|
||||
|
||||
logger = get_logger("maibot_statistic")
|
||||
|
||||
# 同步包装器函数,用于在非异步环境中调用异步数据库API
|
||||
def _sync_db_get(model_class, filters=None, order_by=None, limit=None, single_result=False):
|
||||
"""同步版本的db_get,用于在线程池中调用"""
|
||||
import asyncio
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
if loop.is_running():
|
||||
# 如果事件循环正在运行,创建新的事件循环
|
||||
import threading
|
||||
result = None
|
||||
exception = None
|
||||
|
||||
def run_in_thread():
|
||||
nonlocal result, exception
|
||||
try:
|
||||
new_loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(new_loop)
|
||||
result = new_loop.run_until_complete(
|
||||
db_get(model_class, filters, limit, order_by, single_result)
|
||||
)
|
||||
new_loop.close()
|
||||
except Exception as e:
|
||||
exception = e
|
||||
|
||||
thread = threading.Thread(target=run_in_thread)
|
||||
thread.start()
|
||||
thread.join()
|
||||
|
||||
if exception:
|
||||
raise exception
|
||||
return result
|
||||
else:
|
||||
return loop.run_until_complete(
|
||||
db_get(model_class, filters, limit, order_by, single_result)
|
||||
)
|
||||
except RuntimeError:
|
||||
# 没有事件循环,创建一个新的
|
||||
return asyncio.run(db_get(model_class, filters, limit, order_by, single_result))
|
||||
|
||||
# 统计数据的键
|
||||
TOTAL_REQ_CNT = "total_requests"
|
||||
TOTAL_COST = "total_cost"
|
||||
@@ -59,17 +98,9 @@ class OnlineTimeRecordTask(AsyncTask):
|
||||
def __init__(self):
|
||||
super().__init__(task_name="Online Time Record Task", run_interval=60)
|
||||
|
||||
self.record_id: int | None = None # Changed to int for Peewee's default ID
|
||||
self.record_id: int | None = None
|
||||
"""记录ID"""
|
||||
|
||||
self._init_database() # 初始化数据库
|
||||
|
||||
@staticmethod
|
||||
def _init_database():
|
||||
"""初始化数据库"""
|
||||
with db.atomic(): # Use atomic operations for schema changes
|
||||
OnlineTime.create_table(safe=True) # Creates table if it doesn't exist, Peewee handles indexes from model
|
||||
|
||||
async def run(self): # sourcery skip: use-named-expression
|
||||
try:
|
||||
current_time = datetime.now()
|
||||
@@ -77,36 +108,50 @@ class OnlineTimeRecordTask(AsyncTask):
|
||||
|
||||
if self.record_id:
|
||||
# 如果有记录,则更新结束时间
|
||||
query = OnlineTime.update(end_timestamp=extended_end_time).where(OnlineTime.id == self.record_id) # type: ignore
|
||||
updated_rows = query.execute()
|
||||
updated_rows = await db_query(
|
||||
model_class=OnlineTime,
|
||||
query_type="update",
|
||||
filters={"id": self.record_id},
|
||||
data={"end_timestamp": extended_end_time}
|
||||
)
|
||||
if updated_rows == 0:
|
||||
# Record might have been deleted or ID is stale, try to find/create
|
||||
self.record_id = None # Reset record_id to trigger find/create logic below
|
||||
self.record_id = None
|
||||
|
||||
if not self.record_id: # Check again if record_id was reset or initially None
|
||||
# 如果没有记录,检查一分钟以内是否已有记录
|
||||
# Look for a record whose end_timestamp is recent enough to be considered ongoing
|
||||
recent_record = (
|
||||
OnlineTime.select()
|
||||
.where(OnlineTime.end_timestamp >= (current_time - timedelta(minutes=1))) # type: ignore
|
||||
.order_by(OnlineTime.end_timestamp.desc())
|
||||
.first()
|
||||
if not self.record_id:
|
||||
# 查找最近一分钟内的记录
|
||||
recent_threshold = current_time - timedelta(minutes=1)
|
||||
recent_records = await db_get(
|
||||
model_class=OnlineTime,
|
||||
filters={"end_timestamp": {"$gte": recent_threshold}},
|
||||
order_by="-end_timestamp",
|
||||
limit=1,
|
||||
single_result=True
|
||||
)
|
||||
|
||||
if recent_record:
|
||||
# 如果有记录,则更新结束时间
|
||||
self.record_id = recent_record.id
|
||||
recent_record.end_timestamp = extended_end_time
|
||||
recent_record.save()
|
||||
else:
|
||||
# 若没有记录,则插入新的在线时间记录
|
||||
new_record = OnlineTime.create(
|
||||
timestamp=current_time.timestamp(), # 添加此行
|
||||
start_timestamp=current_time,
|
||||
end_timestamp=extended_end_time,
|
||||
duration=5, # 初始时长为5分钟
|
||||
|
||||
if recent_records:
|
||||
# 找到近期记录,更新它
|
||||
self.record_id = recent_records['id']
|
||||
await db_query(
|
||||
model_class=OnlineTime,
|
||||
query_type="update",
|
||||
filters={"id": self.record_id},
|
||||
data={"end_timestamp": extended_end_time}
|
||||
)
|
||||
self.record_id = new_record.id
|
||||
else:
|
||||
# 创建新记录
|
||||
new_record = await db_save(
|
||||
model_class=OnlineTime,
|
||||
data={
|
||||
"timestamp": str(current_time),
|
||||
"duration": 5, # 初始时长为5分钟
|
||||
"start_timestamp": current_time,
|
||||
"end_timestamp": extended_end_time,
|
||||
}
|
||||
)
|
||||
if new_record:
|
||||
self.record_id = new_record['id']
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"在线时间记录失败,错误信息:{e}")
|
||||
|
||||
@@ -322,18 +367,23 @@ class StatisticOutputTask(AsyncTask):
|
||||
}
|
||||
|
||||
# 以最早的时间戳为起始时间获取记录
|
||||
# Assuming LLMUsage.timestamp is a DateTimeField
|
||||
query_start_time = collect_period[-1][1]
|
||||
for record in LLMUsage.select().where(LLMUsage.timestamp >= query_start_time): # type: ignore
|
||||
record_timestamp = record.timestamp # This is already a datetime object
|
||||
records = _sync_db_get(
|
||||
model_class=LLMUsage,
|
||||
filters={"timestamp": {"$gte": query_start_time}},
|
||||
order_by="-timestamp"
|
||||
)
|
||||
|
||||
for record in records:
|
||||
record_timestamp = record['timestamp'] # 从字典中获取
|
||||
for idx, (_, period_start) in enumerate(collect_period):
|
||||
if record_timestamp >= period_start:
|
||||
for period_key, _ in collect_period[idx:]:
|
||||
stats[period_key][TOTAL_REQ_CNT] += 1
|
||||
|
||||
request_type = record.request_type or "unknown"
|
||||
user_id = record.user_id or "unknown" # user_id is TextField, already string
|
||||
model_name = record.model_name or "unknown"
|
||||
request_type = record.get('request_type') or "unknown"
|
||||
user_id = record.get('user_id') or "unknown"
|
||||
model_name = record.get('model_name') or "unknown"
|
||||
|
||||
# 提取模块名:如果请求类型包含".",取第一个"."之前的部分
|
||||
module_name = request_type.split(".")[0] if "." in request_type else request_type
|
||||
@@ -343,8 +393,8 @@ class StatisticOutputTask(AsyncTask):
|
||||
stats[period_key][REQ_CNT_BY_MODEL][model_name] += 1
|
||||
stats[period_key][REQ_CNT_BY_MODULE][module_name] += 1
|
||||
|
||||
prompt_tokens = record.prompt_tokens or 0
|
||||
completion_tokens = record.completion_tokens or 0
|
||||
prompt_tokens = record.get('prompt_tokens') or 0
|
||||
completion_tokens = record.get('completion_tokens') or 0
|
||||
total_tokens = prompt_tokens + completion_tokens
|
||||
|
||||
stats[period_key][IN_TOK_BY_TYPE][request_type] += prompt_tokens
|
||||
@@ -362,7 +412,7 @@ class StatisticOutputTask(AsyncTask):
|
||||
stats[period_key][TOTAL_TOK_BY_MODEL][model_name] += total_tokens
|
||||
stats[period_key][TOTAL_TOK_BY_MODULE][module_name] += total_tokens
|
||||
|
||||
cost = record.cost or 0.0
|
||||
cost = record.get('cost') or 0.0
|
||||
stats[period_key][TOTAL_COST] += cost
|
||||
stats[period_key][COST_BY_TYPE][request_type] += cost
|
||||
stats[period_key][COST_BY_USER][user_id] += cost
|
||||
@@ -425,11 +475,15 @@ class StatisticOutputTask(AsyncTask):
|
||||
}
|
||||
|
||||
query_start_time = collect_period[-1][1]
|
||||
# Assuming OnlineTime.end_timestamp is a DateTimeField
|
||||
for record in OnlineTime.select().where(OnlineTime.end_timestamp >= query_start_time): # type: ignore
|
||||
# record.end_timestamp and record.start_timestamp are datetime objects
|
||||
record_end_timestamp = record.end_timestamp
|
||||
record_start_timestamp = record.start_timestamp
|
||||
records = _sync_db_get(
|
||||
model_class=OnlineTime,
|
||||
filters={"end_timestamp": {"$gte": query_start_time}},
|
||||
order_by="-end_timestamp"
|
||||
)
|
||||
|
||||
for record in records:
|
||||
record_end_timestamp = record['end_timestamp']
|
||||
record_start_timestamp = record['start_timestamp']
|
||||
|
||||
for idx, (_, period_boundary_start) in enumerate(collect_period):
|
||||
if record_end_timestamp >= period_boundary_start:
|
||||
@@ -466,24 +520,30 @@ class StatisticOutputTask(AsyncTask):
|
||||
}
|
||||
|
||||
query_start_timestamp = collect_period[-1][1].timestamp() # Messages.time is a DoubleField (timestamp)
|
||||
for message in Messages.select().where(Messages.time >= query_start_timestamp): # type: ignore
|
||||
message_time_ts = message.time # This is a float timestamp
|
||||
records = _sync_db_get(
|
||||
model_class=Messages,
|
||||
filters={"time": {"$gte": query_start_timestamp}},
|
||||
order_by="-time"
|
||||
)
|
||||
|
||||
for message in records:
|
||||
message_time_ts = message['time'] # This is a float timestamp
|
||||
|
||||
chat_id = None
|
||||
chat_name = None
|
||||
|
||||
# Logic based on Peewee model structure, aiming to replicate original intent
|
||||
if message.chat_info_group_id:
|
||||
chat_id = f"g{message.chat_info_group_id}"
|
||||
chat_name = message.chat_info_group_name or f"群{message.chat_info_group_id}"
|
||||
elif message.user_id: # Fallback to sender's info for chat_id if not a group_info based chat
|
||||
# Logic based on SQLAlchemy model structure, aiming to replicate original intent
|
||||
if message.get('chat_info_group_id'):
|
||||
chat_id = f"g{message['chat_info_group_id']}"
|
||||
chat_name = message.get('chat_info_group_name') or f"群{message['chat_info_group_id']}"
|
||||
elif message.get('user_id'): # Fallback to sender's info for chat_id if not a group_info based chat
|
||||
# This uses the message SENDER's ID as per original logic's fallback
|
||||
chat_id = f"u{message.user_id}" # SENDER's user_id
|
||||
chat_name = message.user_nickname # SENDER's nickname
|
||||
chat_id = f"u{message['user_id']}" # SENDER's user_id
|
||||
chat_name = message.get('user_nickname') # SENDER's nickname
|
||||
else:
|
||||
# If neither group_id nor sender_id is available for chat identification
|
||||
logger.warning(
|
||||
f"Message (PK: {message.id if hasattr(message, 'id') else 'N/A'}) lacks group_id and user_id for chat stats."
|
||||
f"Message (PK: {message.get('id', 'N/A')}) lacks group_id and user_id for chat stats."
|
||||
)
|
||||
continue
|
||||
|
||||
@@ -1025,8 +1085,14 @@ class StatisticOutputTask(AsyncTask):
|
||||
|
||||
# 查询LLM使用记录
|
||||
query_start_time = start_time
|
||||
for record in LLMUsage.select().where(LLMUsage.timestamp >= query_start_time): # type: ignore
|
||||
record_time = record.timestamp
|
||||
records = _sync_db_get(
|
||||
model_class=LLMUsage,
|
||||
filters={"timestamp": {"$gte": query_start_time}},
|
||||
order_by="-timestamp"
|
||||
)
|
||||
|
||||
for record in records:
|
||||
record_time = record['timestamp']
|
||||
|
||||
# 找到对应的时间间隔索引
|
||||
time_diff = (record_time - start_time).total_seconds()
|
||||
@@ -1034,17 +1100,17 @@ class StatisticOutputTask(AsyncTask):
|
||||
|
||||
if 0 <= interval_index < len(time_points):
|
||||
# 累加总花费数据
|
||||
cost = record.cost or 0.0
|
||||
cost = record.get('cost') or 0.0
|
||||
total_cost_data[interval_index] += cost # type: ignore
|
||||
|
||||
# 累加按模型分类的花费
|
||||
model_name = record.model_name or "unknown"
|
||||
model_name = record.get('model_name') or "unknown"
|
||||
if model_name not in cost_by_model:
|
||||
cost_by_model[model_name] = [0] * len(time_points)
|
||||
cost_by_model[model_name][interval_index] += cost
|
||||
|
||||
# 累加按模块分类的花费
|
||||
request_type = record.request_type or "unknown"
|
||||
request_type = record.get('request_type') or "unknown"
|
||||
module_name = request_type.split(".")[0] if "." in request_type else request_type
|
||||
if module_name not in cost_by_module:
|
||||
cost_by_module[module_name] = [0] * len(time_points)
|
||||
@@ -1052,8 +1118,14 @@ class StatisticOutputTask(AsyncTask):
|
||||
|
||||
# 查询消息记录
|
||||
query_start_timestamp = start_time.timestamp()
|
||||
for message in Messages.select().where(Messages.time >= query_start_timestamp): # type: ignore
|
||||
message_time_ts = message.time
|
||||
records = _sync_db_get(
|
||||
model_class=Messages,
|
||||
filters={"time": {"$gte": query_start_timestamp}},
|
||||
order_by="-time"
|
||||
)
|
||||
|
||||
for message in records:
|
||||
message_time_ts = message['time']
|
||||
|
||||
# 找到对应的时间间隔索引
|
||||
time_diff = message_time_ts - query_start_timestamp
|
||||
@@ -1062,10 +1134,10 @@ class StatisticOutputTask(AsyncTask):
|
||||
if 0 <= interval_index < len(time_points):
|
||||
# 确定聊天流名称
|
||||
chat_name = None
|
||||
if message.chat_info_group_id:
|
||||
chat_name = message.chat_info_group_name or f"群{message.chat_info_group_id}"
|
||||
elif message.user_id:
|
||||
chat_name = message.user_nickname or f"用户{message.user_id}"
|
||||
if message.get('chat_info_group_id'):
|
||||
chat_name = message.get('chat_info_group_name') or f"群{message['chat_info_group_id']}"
|
||||
elif message.get('user_id'):
|
||||
chat_name = message.get('user_nickname') or f"用户{message['user_id']}"
|
||||
else:
|
||||
continue
|
||||
|
||||
|
||||
@@ -13,10 +13,12 @@ from rich.traceback import install
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.common.database.database import db
|
||||
from src.common.database.database_model import Images, ImageDescriptions
|
||||
from src.common.database.sqlalchemy_models import Images, ImageDescriptions
|
||||
from src.config.config import global_config, model_config
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.common.database.sqlalchemy_models import get_db_session
|
||||
|
||||
from sqlalchemy import select, and_
|
||||
install(extra_lines=3)
|
||||
|
||||
logger = get_logger("chat_image")
|
||||
@@ -41,9 +43,10 @@ class ImageManager:
|
||||
|
||||
try:
|
||||
db.connect(reuse_if_open=True)
|
||||
db.create_tables([Images, ImageDescriptions], safe=True)
|
||||
# 使用SQLAlchemy创建表已在初始化时完成
|
||||
logger.debug("使用SQLAlchemy进行表管理")
|
||||
except Exception as e:
|
||||
logger.error(f"数据库连接或表创建失败: {e}")
|
||||
logger.error(f"数据库连接失败: {e}")
|
||||
|
||||
self._initialized = True
|
||||
|
||||
@@ -63,12 +66,13 @@ class ImageManager:
|
||||
Optional[str]: 描述文本,如果不存在则返回None
|
||||
"""
|
||||
try:
|
||||
record = ImageDescriptions.get_or_none(
|
||||
(ImageDescriptions.image_description_hash == image_hash) & (ImageDescriptions.type == description_type)
|
||||
)
|
||||
return record.description if record else None
|
||||
with get_db_session() as session:
|
||||
record = session.execute(select(ImageDescriptions).where(
|
||||
and_(ImageDescriptions.image_description_hash == image_hash, ImageDescriptions.type == description_type)
|
||||
)).scalar()
|
||||
return record.description if record else None
|
||||
except Exception as e:
|
||||
logger.error(f"从数据库获取描述失败 (Peewee): {str(e)}")
|
||||
logger.error(f"从数据库获取描述失败 (SQLAlchemy): {str(e)}")
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
@@ -82,16 +86,28 @@ class ImageManager:
|
||||
"""
|
||||
try:
|
||||
current_timestamp = time.time()
|
||||
defaults = {"description": description, "timestamp": current_timestamp}
|
||||
desc_obj, created = ImageDescriptions.get_or_create(
|
||||
image_description_hash=image_hash, type=description_type, defaults=defaults
|
||||
)
|
||||
if not created: # 如果记录已存在,则更新
|
||||
desc_obj.description = description
|
||||
desc_obj.timestamp = current_timestamp
|
||||
desc_obj.save()
|
||||
with get_db_session() as session:
|
||||
# 查找现有记录
|
||||
existing = session.execute(select(ImageDescriptions).where(
|
||||
and_(ImageDescriptions.image_description_hash == image_hash, ImageDescriptions.type == description_type)
|
||||
)).scalar()
|
||||
|
||||
if existing:
|
||||
# 更新现有记录
|
||||
existing.description = description
|
||||
existing.timestamp = current_timestamp
|
||||
else:
|
||||
# 创建新记录
|
||||
new_desc = ImageDescriptions(
|
||||
image_description_hash=image_hash,
|
||||
type=description_type,
|
||||
description=description,
|
||||
timestamp=current_timestamp
|
||||
)
|
||||
session.add(new_desc)
|
||||
# session.commit() 会在上下文管理器中自动调用
|
||||
except Exception as e:
|
||||
logger.error(f"保存描述到数据库失败 (Peewee): {str(e)}")
|
||||
logger.error(f"保存描述到数据库失败 (SQLAlchemy): {str(e)}")
|
||||
|
||||
async def get_emoji_tag(self, image_base64: str) -> str:
|
||||
from src.chat.emoji_system.emoji_manager import get_emoji_manager
|
||||
@@ -214,19 +230,29 @@ class ImageManager:
|
||||
|
||||
# 保存到数据库 (Images表) - 包含详细描述用于可能的注册流程
|
||||
try:
|
||||
img_obj = Images.get((Images.emoji_hash == image_hash) & (Images.type == "emoji"))
|
||||
img_obj.path = file_path
|
||||
img_obj.description = detailed_description # 保存详细描述
|
||||
img_obj.timestamp = current_timestamp
|
||||
img_obj.save()
|
||||
except Images.DoesNotExist: # type: ignore
|
||||
Images.create(
|
||||
emoji_hash=image_hash,
|
||||
path=file_path,
|
||||
type="emoji",
|
||||
description=detailed_description, # 保存详细描述
|
||||
timestamp=current_timestamp,
|
||||
)
|
||||
from src.common.database.sqlalchemy_models import get_db_session
|
||||
with get_db_session() as session:
|
||||
existing_img = session.execute(select(Images).where(
|
||||
and_(Images.emoji_hash == image_hash, Images.type == "emoji")
|
||||
)).scalar()
|
||||
|
||||
if existing_img:
|
||||
existing_img.path = file_path
|
||||
existing_img.description = detailed_description # 保存详细描述
|
||||
existing_img.timestamp = current_timestamp
|
||||
else:
|
||||
new_img = Images(
|
||||
emoji_hash=image_hash,
|
||||
path=file_path,
|
||||
type="emoji",
|
||||
description=detailed_description, # 保存详细描述
|
||||
timestamp=current_timestamp,
|
||||
)
|
||||
session.add(new_img)
|
||||
# session.commit() 会在上下文管理器中自动调用
|
||||
except Exception as e:
|
||||
logger.error(f"保存到Images表失败: {str(e)}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"保存表情包文件或元数据失败: {str(e)}")
|
||||
|
||||
@@ -249,19 +275,19 @@ class ImageManager:
|
||||
image_hash = hashlib.md5(image_bytes).hexdigest()
|
||||
|
||||
# 优先检查Images表中是否已有完整的描述
|
||||
existing_image = Images.get_or_none(Images.emoji_hash == image_hash)
|
||||
if existing_image:
|
||||
# 更新计数
|
||||
if hasattr(existing_image, "count") and existing_image.count is not None:
|
||||
existing_image.count += 1
|
||||
else:
|
||||
existing_image.count = 1
|
||||
existing_image.save()
|
||||
with get_db_session() as session:
|
||||
existing_image = session.execute(select(Images).where(Images.emoji_hash == image_hash)).scalar()
|
||||
if existing_image:
|
||||
# 更新计数
|
||||
if hasattr(existing_image, "count") and existing_image.count is not None:
|
||||
existing_image.count += 1
|
||||
else:
|
||||
existing_image.count = 1
|
||||
|
||||
# 如果已有描述,直接返回
|
||||
if existing_image.description:
|
||||
logger.debug(f"[缓存命中] 使用Images表中的图片描述: {existing_image.description[:50]}...")
|
||||
return f"[图片:{existing_image.description}]"
|
||||
# 如果已有描述,直接返回
|
||||
if existing_image.description:
|
||||
logger.debug(f"[缓存命中] 使用Images表中的图片描述: {existing_image.description[:50]}...")
|
||||
return f"[图片:{existing_image.description}]"
|
||||
|
||||
if cached_description := self._get_description_from_db(image_hash, "image"):
|
||||
logger.debug(f"[缓存命中] 使用ImageDescriptions表中的描述: {cached_description[:50]}...")
|
||||
@@ -300,10 +326,10 @@ class ImageManager:
|
||||
existing_image.image_id = str(uuid.uuid4())
|
||||
if not hasattr(existing_image, "vlm_processed") or existing_image.vlm_processed is None:
|
||||
existing_image.vlm_processed = True
|
||||
existing_image.save()
|
||||
session.commit()
|
||||
logger.debug(f"[数据库] 更新已有图片记录: {image_hash[:8]}...")
|
||||
else:
|
||||
Images.create(
|
||||
new_img = Images(
|
||||
image_id=str(uuid.uuid4()),
|
||||
emoji_hash=image_hash,
|
||||
path=file_path,
|
||||
@@ -313,6 +339,8 @@ class ImageManager:
|
||||
vlm_processed=True,
|
||||
count=1,
|
||||
)
|
||||
session.add(new_img)
|
||||
session.commit()
|
||||
logger.debug(f"[数据库] 创建新图片记录: {image_hash[:8]}...")
|
||||
except Exception as e:
|
||||
logger.error(f"保存图片文件或元数据失败: {str(e)}")
|
||||
@@ -465,31 +493,32 @@ class ImageManager:
|
||||
image_base64 = image_base64.encode("ascii", errors="ignore").decode("ascii")
|
||||
image_bytes = base64.b64decode(image_base64)
|
||||
image_hash = hashlib.md5(image_bytes).hexdigest()
|
||||
with get_db_session() as session:
|
||||
existing_image = session.execute(select(Images).where(Images.emoji_hash == image_hash)).scalar()
|
||||
if existing_image:
|
||||
# 检查是否缺少必要字段,如果缺少则创建新记录
|
||||
if (
|
||||
not hasattr(existing_image, "image_id")
|
||||
or not existing_image.image_id
|
||||
or not hasattr(existing_image, "count")
|
||||
or existing_image.count is None
|
||||
or not hasattr(existing_image, "vlm_processed")
|
||||
or existing_image.vlm_processed is None
|
||||
):
|
||||
logger.debug(f"图片记录缺少必要字段,补全旧记录: {image_hash}")
|
||||
if not existing_image.image_id:
|
||||
existing_image.image_id = str(uuid.uuid4())
|
||||
if existing_image.count is None:
|
||||
existing_image.count = 0
|
||||
if existing_image.vlm_processed is None:
|
||||
existing_image.vlm_processed = False
|
||||
|
||||
if existing_image := Images.get_or_none(Images.emoji_hash == image_hash):
|
||||
# 检查是否缺少必要字段,如果缺少则创建新记录
|
||||
if (
|
||||
not hasattr(existing_image, "image_id")
|
||||
or not existing_image.image_id
|
||||
or not hasattr(existing_image, "count")
|
||||
or existing_image.count is None
|
||||
or not hasattr(existing_image, "vlm_processed")
|
||||
or existing_image.vlm_processed is None
|
||||
):
|
||||
logger.debug(f"图片记录缺少必要字段,补全旧记录: {image_hash}")
|
||||
if not existing_image.image_id:
|
||||
existing_image.image_id = str(uuid.uuid4())
|
||||
if existing_image.count is None:
|
||||
existing_image.count = 0
|
||||
if existing_image.vlm_processed is None:
|
||||
existing_image.vlm_processed = False
|
||||
existing_image.count += 1
|
||||
session.commit()
|
||||
return existing_image.image_id, f"[picid:{existing_image.image_id}]"
|
||||
|
||||
existing_image.count += 1
|
||||
existing_image.save()
|
||||
return existing_image.image_id, f"[picid:{existing_image.image_id}]"
|
||||
else:
|
||||
# print(f"图片不存在: {image_hash}")
|
||||
image_id = str(uuid.uuid4())
|
||||
# print(f"图片不存在: {image_hash}")
|
||||
image_id = str(uuid.uuid4())
|
||||
|
||||
# 保存新图片
|
||||
current_timestamp = time.time()
|
||||
@@ -503,7 +532,7 @@ class ImageManager:
|
||||
f.write(image_bytes)
|
||||
|
||||
# 保存到数据库
|
||||
Images.create(
|
||||
new_img = Images(
|
||||
image_id=image_id,
|
||||
emoji_hash=image_hash,
|
||||
path=file_path,
|
||||
@@ -512,6 +541,8 @@ class ImageManager:
|
||||
vlm_processed=False,
|
||||
count=1,
|
||||
)
|
||||
session.add(new_img)
|
||||
session.commit()
|
||||
|
||||
# 启动异步VLM处理
|
||||
asyncio.create_task(self._process_image_with_vlm(image_id, image_base64))
|
||||
@@ -536,60 +567,64 @@ class ImageManager:
|
||||
image_base64 = image_base64.encode("ascii", errors="ignore").decode("ascii")
|
||||
image_bytes = base64.b64decode(image_base64)
|
||||
image_hash = hashlib.md5(image_bytes).hexdigest()
|
||||
with get_db_session() as session:
|
||||
# 获取当前图片记录
|
||||
image = session.execute(select(Images).where(Images.image_id == image_id)).scalar()
|
||||
|
||||
# 获取当前图片记录
|
||||
image = Images.get(Images.image_id == image_id)
|
||||
# 优先检查是否已有其他相同哈希的图片记录包含描述
|
||||
existing_with_description = session.execute(select(Images).where(
|
||||
and_(
|
||||
Images.emoji_hash == image_hash,
|
||||
Images.description.isnot(None),
|
||||
Images.description != "",
|
||||
Images.id != image.id
|
||||
)
|
||||
)).scalar()
|
||||
if existing_with_description:
|
||||
logger.debug(f"[缓存复用] 从其他相同图片记录复用描述: {existing_with_description.description[:50]}...")
|
||||
image.description = existing_with_description.description
|
||||
image.vlm_processed = True
|
||||
session.commit()
|
||||
# 同时保存到ImageDescriptions表作为备用缓存
|
||||
self._save_description_to_db(image_hash, existing_with_description.description, "image")
|
||||
return
|
||||
|
||||
# 优先检查是否已有其他相同哈希的图片记录包含描述
|
||||
existing_with_description = Images.get_or_none(
|
||||
(Images.emoji_hash == image_hash) & (Images.description.is_null(False)) & (Images.description != "")
|
||||
)
|
||||
if existing_with_description and existing_with_description.id != image.id:
|
||||
logger.debug(f"[缓存复用] 从其他相同图片记录复用描述: {existing_with_description.description[:50]}...")
|
||||
image.description = existing_with_description.description
|
||||
# 检查ImageDescriptions表的缓存描述
|
||||
if cached_description := self._get_description_from_db(image_hash, "image"):
|
||||
logger.debug(f"[缓存复用] 从ImageDescriptions表复用描述: {cached_description[:50]}...")
|
||||
image.description = cached_description
|
||||
image.vlm_processed = True
|
||||
session.commit()
|
||||
return
|
||||
|
||||
# 获取图片格式
|
||||
image_format = Image.open(io.BytesIO(image_bytes)).format.lower() # type: ignore
|
||||
|
||||
# 构建prompt
|
||||
prompt = global_config.custom_prompt.image_prompt
|
||||
|
||||
# 获取VLM描述
|
||||
logger.info(f"[VLM异步调用] 为图片生成描述 (ID: {image_id}, Hash: {image_hash[:8]}...)")
|
||||
description, _ = await self.vlm.generate_response_for_image(
|
||||
prompt, image_base64, image_format, temperature=0.4, max_tokens=300
|
||||
)
|
||||
|
||||
if description is None:
|
||||
logger.warning("VLM未能生成图片描述")
|
||||
description = "无法生成描述"
|
||||
|
||||
if cached_description := self._get_description_from_db(image_hash, "image"):
|
||||
logger.warning(f"虽然生成了描述,但是找到缓存图片描述: {cached_description}")
|
||||
description = cached_description
|
||||
|
||||
# 更新数据库
|
||||
image.description = description
|
||||
image.vlm_processed = True
|
||||
image.save()
|
||||
# 同时保存到ImageDescriptions表作为备用缓存
|
||||
self._save_description_to_db(image_hash, existing_with_description.description, "image")
|
||||
return
|
||||
|
||||
# 检查ImageDescriptions表的缓存描述
|
||||
if cached_description := self._get_description_from_db(image_hash, "image"):
|
||||
logger.debug(f"[缓存复用] 从ImageDescriptions表复用描述: {cached_description[:50]}...")
|
||||
image.description = cached_description
|
||||
image.vlm_processed = True
|
||||
image.save()
|
||||
return
|
||||
# 保存描述到ImageDescriptions表作为备用缓存
|
||||
self._save_description_to_db(image_hash, description, "image")
|
||||
|
||||
# 获取图片格式
|
||||
image_format = Image.open(io.BytesIO(image_bytes)).format.lower() # type: ignore
|
||||
|
||||
# 构建prompt
|
||||
prompt = global_config.custom_prompt.image_prompt
|
||||
|
||||
# 获取VLM描述
|
||||
logger.info(f"[VLM异步调用] 为图片生成描述 (ID: {image_id}, Hash: {image_hash[:8]}...)")
|
||||
description, _ = await self.vlm.generate_response_for_image(
|
||||
prompt, image_base64, image_format, temperature=0.4, max_tokens=300
|
||||
)
|
||||
|
||||
if description is None:
|
||||
logger.warning("VLM未能生成图片描述")
|
||||
description = "无法生成描述"
|
||||
|
||||
if cached_description := self._get_description_from_db(image_hash, "image"):
|
||||
logger.warning(f"虽然生成了描述,但是找到缓存图片描述: {cached_description}")
|
||||
description = cached_description
|
||||
|
||||
# 更新数据库
|
||||
image.description = description
|
||||
image.vlm_processed = True
|
||||
image.save()
|
||||
|
||||
# 保存描述到ImageDescriptions表作为备用缓存
|
||||
self._save_description_to_db(image_hash, description, "image")
|
||||
|
||||
logger.info(f"[VLM异步完成] 图片描述生成: {description[:50]}...")
|
||||
logger.info(f"[VLM异步完成] 图片描述生成: {description[:50]}...")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"VLM处理图片失败: {str(e)}")
|
||||
|
||||
@@ -1,14 +1,103 @@
|
||||
import os
|
||||
from pymongo import MongoClient
|
||||
from peewee import SqliteDatabase
|
||||
from pymongo.database import Database
|
||||
from rich.traceback import install
|
||||
from src.common.logger import get_logger
|
||||
|
||||
# SQLAlchemy相关导入
|
||||
from src.common.database.sqlalchemy_init import initialize_database_compat
|
||||
from src.common.database.sqlalchemy_models import get_engine, get_session
|
||||
|
||||
install(extra_lines=3)
|
||||
|
||||
_client = None
|
||||
_db = None
|
||||
_sql_engine = None
|
||||
|
||||
logger = get_logger("database")
|
||||
|
||||
# 兼容性:为了不破坏现有代码,保留db变量但指向SQLAlchemy
|
||||
class DatabaseProxy:
|
||||
"""数据库代理类,提供Peewee到SQLAlchemy的兼容性接口"""
|
||||
|
||||
def __init__(self):
|
||||
self._engine = None
|
||||
self._session = None
|
||||
|
||||
def initialize(self, *args, **kwargs):
|
||||
"""初始化数据库连接"""
|
||||
return initialize_database_compat()
|
||||
|
||||
def connect(self, reuse_if_open=True):
|
||||
"""连接数据库(兼容性方法)"""
|
||||
try:
|
||||
self._engine = get_engine()
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"数据库连接失败: {e}")
|
||||
return False
|
||||
|
||||
def is_closed(self):
|
||||
"""检查数据库是否关闭(兼容性方法)"""
|
||||
return self._engine is None
|
||||
|
||||
def create_tables(self, models, safe=True):
|
||||
"""创建表(兼容性方法)"""
|
||||
try:
|
||||
from src.common.database.sqlalchemy_models import Base
|
||||
engine = get_engine()
|
||||
Base.metadata.create_all(bind=engine)
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"创建表失败: {e}")
|
||||
return False
|
||||
|
||||
def table_exists(self, model):
|
||||
"""检查表是否存在(兼容性方法)"""
|
||||
try:
|
||||
from sqlalchemy import inspect
|
||||
engine = get_engine()
|
||||
inspector = inspect(engine)
|
||||
table_name = getattr(model, '_meta', {}).get('table_name', model.__name__.lower())
|
||||
return table_name in inspector.get_table_names()
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def execute_sql(self, sql):
|
||||
"""执行SQL(兼容性方法)"""
|
||||
try:
|
||||
from sqlalchemy import text
|
||||
session = get_session()
|
||||
result = session.execute(text(sql))
|
||||
session.close()
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error(f"执行SQL失败: {e}")
|
||||
raise
|
||||
|
||||
def atomic(self):
|
||||
"""事务上下文管理器(兼容性方法)"""
|
||||
return SQLAlchemyTransaction()
|
||||
|
||||
class SQLAlchemyTransaction:
|
||||
"""SQLAlchemy事务上下文管理器"""
|
||||
|
||||
def __init__(self):
|
||||
self.session = None
|
||||
|
||||
def __enter__(self):
|
||||
self.session = get_session()
|
||||
return self.session
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
if exc_type is None:
|
||||
self.session.commit()
|
||||
else:
|
||||
self.session.rollback()
|
||||
self.session.close()
|
||||
|
||||
# 创建全局数据库代理实例
|
||||
db = DatabaseProxy()
|
||||
|
||||
def __create_database_instance():
|
||||
uri = os.getenv("MONGODB_URI")
|
||||
@@ -39,7 +128,7 @@ def __create_database_instance():
|
||||
|
||||
|
||||
def get_db():
|
||||
"""获取数据库连接实例,延迟初始化。"""
|
||||
"""获取MongoDB连接实例,延迟初始化。"""
|
||||
global _client, _db
|
||||
if _client is None:
|
||||
_client = __create_database_instance()
|
||||
@@ -47,6 +136,47 @@ def get_db():
|
||||
return _db
|
||||
|
||||
|
||||
def initialize_sql_database(database_config):
|
||||
"""
|
||||
根据配置初始化SQL数据库连接(SQLAlchemy版本)
|
||||
|
||||
Args:
|
||||
database_config: DatabaseConfig对象
|
||||
"""
|
||||
global _sql_engine
|
||||
|
||||
try:
|
||||
logger.info("使用SQLAlchemy初始化SQL数据库...")
|
||||
|
||||
# 记录数据库配置信息
|
||||
if database_config.database_type == "mysql":
|
||||
connection_info = f"{database_config.mysql_user}@{database_config.mysql_host}:{database_config.mysql_port}/{database_config.mysql_database}"
|
||||
logger.info("MySQL数据库连接配置:")
|
||||
logger.info(f" 连接信息: {connection_info}")
|
||||
logger.info(f" 字符集: {database_config.mysql_charset}")
|
||||
else:
|
||||
ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", ".."))
|
||||
if not os.path.isabs(database_config.sqlite_path):
|
||||
db_path = os.path.join(ROOT_PATH, database_config.sqlite_path)
|
||||
else:
|
||||
db_path = database_config.sqlite_path
|
||||
logger.info("SQLite数据库连接配置:")
|
||||
logger.info(f" 数据库文件: {db_path}")
|
||||
|
||||
# 使用SQLAlchemy初始化
|
||||
success = initialize_database_compat()
|
||||
if success:
|
||||
_sql_engine = get_engine()
|
||||
logger.info("SQLAlchemy数据库初始化成功")
|
||||
else:
|
||||
logger.error("SQLAlchemy数据库初始化失败")
|
||||
|
||||
return _sql_engine
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"初始化SQL数据库失败: {e}")
|
||||
return None
|
||||
|
||||
class DBWrapper:
|
||||
"""数据库代理类,保持接口兼容性同时实现懒加载。"""
|
||||
|
||||
@@ -57,26 +187,6 @@ class DBWrapper:
|
||||
return get_db()[key] # type: ignore
|
||||
|
||||
|
||||
# 全局数据库访问点
|
||||
# 全局MongoDB数据库访问点
|
||||
memory_db: Database = DBWrapper() # type: ignore
|
||||
|
||||
# 定义数据库文件路径
|
||||
ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", ".."))
|
||||
_DB_DIR = os.path.join(ROOT_PATH, "data")
|
||||
_DB_FILE = os.path.join(_DB_DIR, "MaiBot.db")
|
||||
|
||||
# 确保数据库目录存在
|
||||
os.makedirs(_DB_DIR, exist_ok=True)
|
||||
|
||||
# 全局 Peewee SQLite 数据库访问点
|
||||
db = SqliteDatabase(
|
||||
_DB_FILE,
|
||||
pragmas={
|
||||
"journal_mode": "wal", # WAL模式提高并发性能
|
||||
"cache_size": -64 * 1000, # 64MB缓存
|
||||
"foreign_keys": 1,
|
||||
"ignore_check_constraints": 0,
|
||||
"synchronous": 0, # 异步写入提高性能
|
||||
"busy_timeout": 1000, # 1秒超时而不是3秒
|
||||
},
|
||||
)
|
||||
|
||||
420
src/common/database/sqlalchemy_database_api.py
Normal file
420
src/common/database/sqlalchemy_database_api.py
Normal file
@@ -0,0 +1,420 @@
|
||||
"""SQLAlchemy数据库API模块
|
||||
|
||||
提供基于SQLAlchemy的数据库操作,替换Peewee以解决MySQL连接问题
|
||||
支持自动重连、连接池管理和更好的错误处理
|
||||
"""
|
||||
|
||||
import traceback
|
||||
import time
|
||||
from typing import Dict, List, Any, Union, Type, Optional
|
||||
from contextlib import contextmanager
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.exc import SQLAlchemyError, DisconnectionError, OperationalError
|
||||
from sqlalchemy import desc, asc, func, and_, or_
|
||||
from src.common.logger import get_logger
|
||||
from src.common.database.sqlalchemy_models import (
|
||||
Base, get_db_session, Messages, ActionRecords, PersonInfo, ChatStreams,
|
||||
LLMUsage, Emoji, Images, ImageDescriptions, OnlineTime, Memory,
|
||||
Expression, ThinkingLog, GraphNodes, GraphEdges,get_session
|
||||
)
|
||||
|
||||
logger = get_logger("sqlalchemy_database_api")
|
||||
|
||||
# 模型映射表,用于通过名称获取模型类
|
||||
MODEL_MAPPING = {
|
||||
'Messages': Messages,
|
||||
'ActionRecords': ActionRecords,
|
||||
'PersonInfo': PersonInfo,
|
||||
'ChatStreams': ChatStreams,
|
||||
'LLMUsage': LLMUsage,
|
||||
'Emoji': Emoji,
|
||||
'Images': Images,
|
||||
'ImageDescriptions': ImageDescriptions,
|
||||
'OnlineTime': OnlineTime,
|
||||
'Memory': Memory,
|
||||
'Expression': Expression,
|
||||
'ThinkingLog': ThinkingLog,
|
||||
'GraphNodes': GraphNodes,
|
||||
'GraphEdges': GraphEdges,
|
||||
}
|
||||
|
||||
|
||||
@contextmanager
|
||||
def get_db_session():
|
||||
"""数据库会话上下文管理器,自动处理事务和连接错误"""
|
||||
session = None
|
||||
max_retries = 3
|
||||
retry_delay = 1.0
|
||||
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
session = get_session()
|
||||
yield session
|
||||
session.commit()
|
||||
break
|
||||
except (DisconnectionError, OperationalError) as e:
|
||||
logger.warning(f"数据库连接错误 (尝试 {attempt + 1}/{max_retries}): {e}")
|
||||
if session:
|
||||
session.rollback()
|
||||
session.close()
|
||||
if attempt < max_retries - 1:
|
||||
time.sleep(retry_delay * (attempt + 1))
|
||||
else:
|
||||
raise
|
||||
except Exception as e:
|
||||
if session:
|
||||
session.rollback()
|
||||
raise
|
||||
finally:
|
||||
if session:
|
||||
session.close()
|
||||
|
||||
|
||||
def build_filters(session: Session, model_class: Type[Base], filters: Dict[str, Any]):
|
||||
"""构建查询过滤条件"""
|
||||
conditions = []
|
||||
|
||||
for field_name, value in filters.items():
|
||||
if not hasattr(model_class, field_name):
|
||||
logger.warning(f"模型 {model_class.__name__} 中不存在字段 '{field_name}'")
|
||||
continue
|
||||
|
||||
field = getattr(model_class, field_name)
|
||||
|
||||
if isinstance(value, dict):
|
||||
# 处理 MongoDB 风格的操作符
|
||||
for op, op_value in value.items():
|
||||
if op == "$gt":
|
||||
conditions.append(field > op_value)
|
||||
elif op == "$lt":
|
||||
conditions.append(field < op_value)
|
||||
elif op == "$gte":
|
||||
conditions.append(field >= op_value)
|
||||
elif op == "$lte":
|
||||
conditions.append(field <= op_value)
|
||||
elif op == "$ne":
|
||||
conditions.append(field != op_value)
|
||||
elif op == "$in":
|
||||
conditions.append(field.in_(op_value))
|
||||
elif op == "$nin":
|
||||
conditions.append(~field.in_(op_value))
|
||||
else:
|
||||
logger.warning(f"未知操作符 '{op}' (字段: '{field_name}')")
|
||||
else:
|
||||
# 直接相等比较
|
||||
conditions.append(field == value)
|
||||
|
||||
return conditions
|
||||
|
||||
|
||||
async def db_query(
|
||||
model_class: Type[Base],
|
||||
data: Optional[Dict[str, Any]] = None,
|
||||
query_type: Optional[str] = "get",
|
||||
filters: Optional[Dict[str, Any]] = None,
|
||||
limit: Optional[int] = None,
|
||||
order_by: Optional[List[str]] = None,
|
||||
single_result: Optional[bool] = False,
|
||||
) -> Union[List[Dict[str, Any]], Dict[str, Any], None]:
|
||||
"""执行数据库查询操作
|
||||
|
||||
Args:
|
||||
model_class: SQLAlchemy模型类
|
||||
data: 用于创建或更新的数据字典
|
||||
query_type: 查询类型 ("get", "create", "update", "delete", "count")
|
||||
filters: 过滤条件字典
|
||||
limit: 限制结果数量
|
||||
order_by: 排序字段,前缀'-'表示降序
|
||||
single_result: 是否只返回单个结果
|
||||
|
||||
Returns:
|
||||
根据查询类型返回相应结果
|
||||
"""
|
||||
try:
|
||||
if query_type not in ["get", "create", "update", "delete", "count"]:
|
||||
raise ValueError("query_type must be 'get', 'create', 'update', 'delete' or 'count'")
|
||||
|
||||
with get_db_session() as session:
|
||||
if query_type == "get":
|
||||
query = session.query(model_class)
|
||||
|
||||
# 应用过滤条件
|
||||
if filters:
|
||||
conditions = build_filters(session, model_class, filters)
|
||||
if conditions:
|
||||
query = query.filter(and_(*conditions))
|
||||
|
||||
# 应用排序
|
||||
if order_by:
|
||||
for field_name in order_by:
|
||||
if field_name.startswith("-"):
|
||||
field_name = field_name[1:]
|
||||
if hasattr(model_class, field_name):
|
||||
query = query.order_by(desc(getattr(model_class, field_name)))
|
||||
else:
|
||||
if hasattr(model_class, field_name):
|
||||
query = query.order_by(asc(getattr(model_class, field_name)))
|
||||
|
||||
# 应用限制
|
||||
if limit and limit > 0:
|
||||
query = query.limit(limit)
|
||||
|
||||
# 执行查询
|
||||
results = query.all()
|
||||
|
||||
# 转换为字典格式
|
||||
result_dicts = []
|
||||
for result in results:
|
||||
result_dict = {}
|
||||
for column in result.__table__.columns:
|
||||
result_dict[column.name] = getattr(result, column.name)
|
||||
result_dicts.append(result_dict)
|
||||
|
||||
if single_result:
|
||||
return result_dicts[0] if result_dicts else None
|
||||
return result_dicts
|
||||
|
||||
elif query_type == "create":
|
||||
if not data:
|
||||
raise ValueError("创建记录需要提供data参数")
|
||||
|
||||
# 创建新记录
|
||||
new_record = model_class(**data)
|
||||
session.add(new_record)
|
||||
session.flush() # 获取自动生成的ID
|
||||
|
||||
# 转换为字典格式返回
|
||||
result_dict = {}
|
||||
for column in new_record.__table__.columns:
|
||||
result_dict[column.name] = getattr(new_record, column.name)
|
||||
return result_dict
|
||||
|
||||
elif query_type == "update":
|
||||
if not data:
|
||||
raise ValueError("更新记录需要提供data参数")
|
||||
|
||||
query = session.query(model_class)
|
||||
|
||||
# 应用过滤条件
|
||||
if filters:
|
||||
conditions = build_filters(session, model_class, filters)
|
||||
if conditions:
|
||||
query = query.filter(and_(*conditions))
|
||||
|
||||
# 执行更新
|
||||
affected_rows = query.update(data)
|
||||
return affected_rows
|
||||
|
||||
elif query_type == "delete":
|
||||
query = session.query(model_class)
|
||||
|
||||
# 应用过滤条件
|
||||
if filters:
|
||||
conditions = build_filters(session, model_class, filters)
|
||||
if conditions:
|
||||
query = query.filter(and_(*conditions))
|
||||
|
||||
# 执行删除
|
||||
affected_rows = query.delete()
|
||||
return affected_rows
|
||||
|
||||
elif query_type == "count":
|
||||
query = session.query(func.count(model_class.id))
|
||||
|
||||
# 应用过滤条件
|
||||
if filters:
|
||||
base_query = session.query(model_class)
|
||||
conditions = build_filters(session, model_class, filters)
|
||||
if conditions:
|
||||
base_query = base_query.filter(and_(*conditions))
|
||||
query = session.query(func.count()).select_from(base_query.subquery())
|
||||
|
||||
return query.scalar()
|
||||
|
||||
except SQLAlchemyError as e:
|
||||
logger.error(f"[SQLAlchemy] 数据库操作出错: {e}")
|
||||
traceback.print_exc()
|
||||
|
||||
# 根据查询类型返回合适的默认值
|
||||
if query_type == "get":
|
||||
return None if single_result else []
|
||||
elif query_type in ["create", "update", "delete", "count"]:
|
||||
return None
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[SQLAlchemy] 意外错误: {e}")
|
||||
traceback.print_exc()
|
||||
|
||||
if query_type == "get":
|
||||
return None if single_result else []
|
||||
return None
|
||||
|
||||
|
||||
async def db_save(
|
||||
model_class: Type[Base],
|
||||
data: Dict[str, Any],
|
||||
key_field: Optional[str] = None,
|
||||
key_value: Optional[Any] = None
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""保存数据到数据库(创建或更新)
|
||||
|
||||
Args:
|
||||
model_class: SQLAlchemy模型类
|
||||
data: 要保存的数据字典
|
||||
key_field: 用于查找现有记录的字段名
|
||||
key_value: 用于查找现有记录的字段值
|
||||
|
||||
Returns:
|
||||
保存后的记录数据或None
|
||||
"""
|
||||
try:
|
||||
with get_db_session() as session:
|
||||
# 如果提供了key_field和key_value,尝试更新现有记录
|
||||
if key_field and key_value is not None:
|
||||
if hasattr(model_class, key_field):
|
||||
existing_record = session.query(model_class).filter(
|
||||
getattr(model_class, key_field) == key_value
|
||||
).first()
|
||||
|
||||
if existing_record:
|
||||
# 更新现有记录
|
||||
for field, value in data.items():
|
||||
if hasattr(existing_record, field):
|
||||
setattr(existing_record, field, value)
|
||||
|
||||
session.flush()
|
||||
|
||||
# 转换为字典格式返回
|
||||
result_dict = {}
|
||||
for column in existing_record.__table__.columns:
|
||||
result_dict[column.name] = getattr(existing_record, column.name)
|
||||
return result_dict
|
||||
|
||||
# 创建新记录
|
||||
new_record = model_class(**data)
|
||||
session.add(new_record)
|
||||
session.flush()
|
||||
|
||||
# 转换为字典格式返回
|
||||
result_dict = {}
|
||||
for column in new_record.__table__.columns:
|
||||
result_dict[column.name] = getattr(new_record, column.name)
|
||||
return result_dict
|
||||
|
||||
except SQLAlchemyError as e:
|
||||
logger.error(f"[SQLAlchemy] 保存数据库记录出错: {e}")
|
||||
traceback.print_exc()
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"[SQLAlchemy] 保存时意外错误: {e}")
|
||||
traceback.print_exc()
|
||||
return None
|
||||
|
||||
|
||||
async def db_get(
|
||||
model_class: Type[Base],
|
||||
filters: Optional[Dict[str, Any]] = None,
|
||||
limit: Optional[int] = None,
|
||||
order_by: Optional[str] = None,
|
||||
single_result: Optional[bool] = False,
|
||||
) -> Union[List[Dict[str, Any]], Dict[str, Any], None]:
|
||||
"""从数据库获取记录
|
||||
|
||||
Args:
|
||||
model_class: SQLAlchemy模型类
|
||||
filters: 过滤条件
|
||||
limit: 结果数量限制
|
||||
order_by: 排序字段,前缀'-'表示降序
|
||||
single_result: 是否只返回单个结果
|
||||
|
||||
Returns:
|
||||
记录数据或None
|
||||
"""
|
||||
order_by_list = [order_by] if order_by else None
|
||||
return await db_query(
|
||||
model_class=model_class,
|
||||
query_type="get",
|
||||
filters=filters,
|
||||
limit=limit,
|
||||
order_by=order_by_list,
|
||||
single_result=single_result
|
||||
)
|
||||
|
||||
|
||||
async def store_action_info(
|
||||
chat_stream=None,
|
||||
action_build_into_prompt: bool = False,
|
||||
action_prompt_display: str = "",
|
||||
action_done: bool = True,
|
||||
thinking_id: str = "",
|
||||
action_data: Optional[dict] = None,
|
||||
action_name: str = "",
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""存储动作信息到数据库
|
||||
|
||||
Args:
|
||||
chat_stream: 聊天流对象
|
||||
action_build_into_prompt: 是否将此动作构建到提示中
|
||||
action_prompt_display: 动作的提示显示文本
|
||||
action_done: 动作是否完成
|
||||
thinking_id: 关联的思考ID
|
||||
action_data: 动作数据字典
|
||||
action_name: 动作名称
|
||||
|
||||
Returns:
|
||||
保存的记录数据或None
|
||||
"""
|
||||
try:
|
||||
import json
|
||||
|
||||
# 构建动作记录数据
|
||||
record_data = {
|
||||
"action_id": thinking_id or str(int(time.time() * 1000000)),
|
||||
"time": time.time(),
|
||||
"action_name": action_name,
|
||||
"action_data": json.dumps(action_data or {}, ensure_ascii=False),
|
||||
"action_done": action_done,
|
||||
"action_build_into_prompt": action_build_into_prompt,
|
||||
"action_prompt_display": action_prompt_display,
|
||||
}
|
||||
|
||||
# 从chat_stream获取聊天信息
|
||||
if chat_stream:
|
||||
record_data.update({
|
||||
"chat_id": getattr(chat_stream, "stream_id", ""),
|
||||
"chat_info_stream_id": getattr(chat_stream, "stream_id", ""),
|
||||
"chat_info_platform": getattr(chat_stream, "platform", ""),
|
||||
})
|
||||
else:
|
||||
record_data.update({
|
||||
"chat_id": "",
|
||||
"chat_info_stream_id": "",
|
||||
"chat_info_platform": "",
|
||||
})
|
||||
|
||||
# 保存记录
|
||||
saved_record = await db_save(
|
||||
ActionRecords,
|
||||
data=record_data,
|
||||
key_field="action_id",
|
||||
key_value=record_data["action_id"]
|
||||
)
|
||||
|
||||
if saved_record:
|
||||
logger.debug(f"[SQLAlchemy] 成功存储动作信息: {action_name} (ID: {record_data['action_id']})")
|
||||
else:
|
||||
logger.error(f"[SQLAlchemy] 存储动作信息失败: {action_name}")
|
||||
|
||||
return saved_record
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[SQLAlchemy] 存储动作信息时发生错误: {e}")
|
||||
traceback.print_exc()
|
||||
return None
|
||||
|
||||
|
||||
# 兼容性函数,方便从Peewee迁移
|
||||
def get_model_class(model_name: str) -> Optional[Type[Base]]:
|
||||
"""根据模型名称获取模型类"""
|
||||
return MODEL_MAPPING.get(model_name)
|
||||
158
src/common/database/sqlalchemy_init.py
Normal file
158
src/common/database/sqlalchemy_init.py
Normal file
@@ -0,0 +1,158 @@
|
||||
"""SQLAlchemy数据库初始化模块
|
||||
|
||||
替换Peewee的数据库初始化逻辑
|
||||
提供统一的数据库初始化接口
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
from src.common.logger import get_logger
|
||||
from src.common.database.sqlalchemy_models import (
|
||||
Base, get_engine, get_session, initialize_database
|
||||
)
|
||||
|
||||
logger = get_logger("sqlalchemy_init")
|
||||
|
||||
|
||||
def initialize_sqlalchemy_database() -> bool:
|
||||
"""
|
||||
初始化SQLAlchemy数据库
|
||||
创建所有表结构
|
||||
|
||||
Returns:
|
||||
bool: 初始化是否成功
|
||||
"""
|
||||
try:
|
||||
logger.info("开始初始化SQLAlchemy数据库...")
|
||||
|
||||
# 初始化数据库引擎和会话
|
||||
engine, session_local = initialize_database()
|
||||
|
||||
if engine is None:
|
||||
logger.error("数据库引擎初始化失败")
|
||||
return False
|
||||
|
||||
logger.info("SQLAlchemy数据库初始化成功")
|
||||
return True
|
||||
|
||||
except SQLAlchemyError as e:
|
||||
logger.error(f"SQLAlchemy数据库初始化失败: {e}")
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(f"数据库初始化过程中发生未知错误: {e}")
|
||||
return False
|
||||
|
||||
|
||||
def create_all_tables() -> bool:
|
||||
"""
|
||||
创建所有数据库表
|
||||
|
||||
Returns:
|
||||
bool: 创建是否成功
|
||||
"""
|
||||
try:
|
||||
logger.info("开始创建数据库表...")
|
||||
|
||||
engine = get_engine()
|
||||
if engine is None:
|
||||
logger.error("无法获取数据库引擎")
|
||||
return False
|
||||
|
||||
# 创建所有表
|
||||
Base.metadata.create_all(bind=engine)
|
||||
|
||||
logger.info("数据库表创建成功")
|
||||
return True
|
||||
|
||||
except SQLAlchemyError as e:
|
||||
logger.error(f"创建数据库表失败: {e}")
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(f"创建数据库表过程中发生未知错误: {e}")
|
||||
return False
|
||||
|
||||
|
||||
def check_database_connection() -> bool:
|
||||
"""
|
||||
检查数据库连接是否正常
|
||||
|
||||
Returns:
|
||||
bool: 连接是否正常
|
||||
"""
|
||||
try:
|
||||
session = get_session()
|
||||
if session is None:
|
||||
logger.error("无法获取数据库会话")
|
||||
return False
|
||||
|
||||
# 检查会话是否可用(如果能获取到会话说明连接正常)
|
||||
if session is None:
|
||||
logger.error("数据库会话无效")
|
||||
return False
|
||||
|
||||
session.close()
|
||||
|
||||
logger.info("数据库连接检查通过")
|
||||
return True
|
||||
|
||||
except SQLAlchemyError as e:
|
||||
logger.error(f"数据库连接检查失败: {e}")
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(f"数据库连接检查过程中发生未知错误: {e}")
|
||||
return False
|
||||
|
||||
|
||||
def get_database_info() -> Optional[dict]:
|
||||
"""
|
||||
获取数据库信息
|
||||
|
||||
Returns:
|
||||
dict: 数据库信息字典,包含引擎信息等
|
||||
"""
|
||||
try:
|
||||
engine = get_engine()
|
||||
if engine is None:
|
||||
return None
|
||||
|
||||
info = {
|
||||
'engine_name': engine.name,
|
||||
'driver': engine.driver,
|
||||
'url': str(engine.url).replace(engine.url.password or '', '***'), # 隐藏密码
|
||||
'pool_size': getattr(engine.pool, 'size', None),
|
||||
'max_overflow': getattr(engine.pool, 'max_overflow', None),
|
||||
}
|
||||
|
||||
return info
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取数据库信息失败: {e}")
|
||||
return None
|
||||
|
||||
|
||||
_database_initialized = False
|
||||
|
||||
def initialize_database_compat() -> bool:
|
||||
"""
|
||||
兼容性数据库初始化函数
|
||||
用于替换原有的Peewee初始化代码
|
||||
|
||||
Returns:
|
||||
bool: 初始化是否成功
|
||||
"""
|
||||
global _database_initialized
|
||||
|
||||
if _database_initialized:
|
||||
return True
|
||||
|
||||
success = initialize_sqlalchemy_database()
|
||||
if success:
|
||||
success = create_all_tables()
|
||||
|
||||
if success:
|
||||
success = check_database_connection()
|
||||
|
||||
if success:
|
||||
_database_initialized = True
|
||||
|
||||
return success
|
||||
555
src/common/database/sqlalchemy_models.py
Normal file
555
src/common/database/sqlalchemy_models.py
Normal file
@@ -0,0 +1,555 @@
|
||||
"""SQLAlchemy数据库模型定义
|
||||
|
||||
替换Peewee ORM,使用SQLAlchemy提供更好的连接池管理和错误恢复能力
|
||||
"""
|
||||
|
||||
from sqlalchemy import Column, String, Float, Integer, Boolean, Text, Index, create_engine, DateTime
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from sqlalchemy.pool import QueuePool
|
||||
import os
|
||||
import datetime
|
||||
from src.config.config import global_config
|
||||
from src.common.logger import get_logger
|
||||
import threading
|
||||
from contextlib import contextmanager
|
||||
|
||||
logger = get_logger("sqlalchemy_models")
|
||||
|
||||
# 创建基类
|
||||
Base = declarative_base()
|
||||
|
||||
# MySQL兼容的字段类型辅助函数
|
||||
def get_string_field(max_length=255, **kwargs):
|
||||
"""
|
||||
根据数据库类型返回合适的字符串字段
|
||||
MySQL需要指定长度的VARCHAR用于索引,SQLite可以使用Text
|
||||
"""
|
||||
if global_config.database.database_type == "mysql":
|
||||
return String(max_length, **kwargs)
|
||||
else:
|
||||
return Text(**kwargs)
|
||||
|
||||
class SessionProxy:
|
||||
"""线程安全的Session代理类,自动管理session生命周期"""
|
||||
|
||||
def __init__(self):
|
||||
self._local = threading.local()
|
||||
|
||||
def _get_current_session(self):
|
||||
"""获取当前线程的session,如果没有则创建新的"""
|
||||
if not hasattr(self._local, 'session') or self._local.session is None:
|
||||
_, SessionLocal = initialize_database()
|
||||
self._local.session = SessionLocal()
|
||||
return self._local.session
|
||||
|
||||
def _close_current_session(self):
|
||||
"""关闭当前线程的session"""
|
||||
if hasattr(self._local, 'session') and self._local.session is not None:
|
||||
try:
|
||||
self._local.session.close()
|
||||
except:
|
||||
pass
|
||||
finally:
|
||||
self._local.session = None
|
||||
|
||||
def __getattr__(self, name):
|
||||
"""代理所有session方法"""
|
||||
session = self._get_current_session()
|
||||
attr = getattr(session, name)
|
||||
|
||||
# 如果是方法,需要特殊处理一些关键方法
|
||||
if callable(attr):
|
||||
if name in ['commit', 'rollback']:
|
||||
def wrapper(*args, **kwargs):
|
||||
try:
|
||||
result = attr(*args, **kwargs)
|
||||
if name == 'commit':
|
||||
# commit后不要清除session,只是刷新状态
|
||||
pass # 保持session活跃
|
||||
return result
|
||||
except Exception as e:
|
||||
try:
|
||||
if session and hasattr(session, 'rollback'):
|
||||
session.rollback()
|
||||
except:
|
||||
pass
|
||||
# 发生错误时重新创建session
|
||||
self._close_current_session()
|
||||
raise
|
||||
return wrapper
|
||||
elif name == 'close':
|
||||
def wrapper(*args, **kwargs):
|
||||
result = attr(*args, **kwargs)
|
||||
self._close_current_session()
|
||||
return result
|
||||
return wrapper
|
||||
elif name in ['execute', 'query', 'add', 'delete', 'merge']:
|
||||
def wrapper(*args, **kwargs):
|
||||
try:
|
||||
return attr(*args, **kwargs)
|
||||
except Exception as e:
|
||||
# 如果是连接相关错误,重新创建session再试一次
|
||||
if "not bound to a Session" in str(e) or "provisioning a new connection" in str(e):
|
||||
logger.warning(f"Session问题,重新创建session: {e}")
|
||||
self._close_current_session()
|
||||
new_session = self._get_current_session()
|
||||
new_attr = getattr(new_session, name)
|
||||
return new_attr(*args, **kwargs)
|
||||
raise
|
||||
return wrapper
|
||||
|
||||
return attr
|
||||
|
||||
def new_session(self):
|
||||
"""强制创建新的session(关闭当前的,创建新的)"""
|
||||
self._close_current_session()
|
||||
return self._get_current_session()
|
||||
|
||||
def ensure_fresh_session(self):
|
||||
"""确保使用新鲜的session(如果当前session有问题则重新创建)"""
|
||||
if hasattr(self._local, 'session') and self._local.session is not None:
|
||||
try:
|
||||
# 测试session是否还可用
|
||||
self._local.session.execute("SELECT 1")
|
||||
except Exception:
|
||||
# session有问题,重新创建
|
||||
self._close_current_session()
|
||||
return self._get_current_session()
|
||||
|
||||
# 创建全局session代理实例
|
||||
_global_session_proxy = SessionProxy()
|
||||
|
||||
def get_session():
|
||||
"""返回线程安全的session代理,自动管理生命周期"""
|
||||
return _global_session_proxy
|
||||
|
||||
|
||||
class ChatStreams(Base):
|
||||
"""聊天流模型"""
|
||||
__tablename__ = 'chat_streams'
|
||||
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
stream_id = Column(get_string_field(64), nullable=False, unique=True, index=True)
|
||||
create_time = Column(Float, nullable=False)
|
||||
group_platform = Column(Text, nullable=True)
|
||||
group_id = Column(get_string_field(100), nullable=True, index=True)
|
||||
group_name = Column(Text, nullable=True)
|
||||
last_active_time = Column(Float, nullable=False)
|
||||
platform = Column(Text, nullable=False)
|
||||
user_platform = Column(Text, nullable=False)
|
||||
user_id = Column(get_string_field(100), nullable=False, index=True)
|
||||
user_nickname = Column(Text, nullable=False)
|
||||
user_cardname = Column(Text, nullable=True)
|
||||
|
||||
__table_args__ = (
|
||||
Index('idx_chatstreams_stream_id', 'stream_id'),
|
||||
Index('idx_chatstreams_user_id', 'user_id'),
|
||||
Index('idx_chatstreams_group_id', 'group_id'),
|
||||
)
|
||||
|
||||
|
||||
class LLMUsage(Base):
|
||||
"""LLM使用记录模型"""
|
||||
__tablename__ = 'llm_usage'
|
||||
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
model_name = Column(get_string_field(100), nullable=False, index=True)
|
||||
user_id = Column(get_string_field(50), nullable=False, index=True)
|
||||
request_type = Column(get_string_field(50), nullable=False, index=True)
|
||||
endpoint = Column(Text, nullable=False)
|
||||
prompt_tokens = Column(Integer, nullable=False)
|
||||
completion_tokens = Column(Integer, nullable=False)
|
||||
total_tokens = Column(Integer, nullable=False)
|
||||
cost = Column(Float, nullable=False)
|
||||
status = Column(Text, nullable=False)
|
||||
timestamp = Column(DateTime, nullable=False, index=True, default=datetime.datetime.now)
|
||||
|
||||
__table_args__ = (
|
||||
Index('idx_llmusage_model_name', 'model_name'),
|
||||
Index('idx_llmusage_user_id', 'user_id'),
|
||||
Index('idx_llmusage_request_type', 'request_type'),
|
||||
Index('idx_llmusage_timestamp', 'timestamp'),
|
||||
)
|
||||
|
||||
|
||||
class Emoji(Base):
|
||||
"""表情包模型"""
|
||||
__tablename__ = 'emoji'
|
||||
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
full_path = Column(get_string_field(500), nullable=False, unique=True, index=True)
|
||||
format = Column(Text, nullable=False)
|
||||
emoji_hash = Column(get_string_field(64), nullable=False, index=True)
|
||||
description = Column(Text, nullable=False)
|
||||
query_count = Column(Integer, nullable=False, default=0)
|
||||
is_registered = Column(Boolean, nullable=False, default=False)
|
||||
is_banned = Column(Boolean, nullable=False, default=False)
|
||||
emotion = Column(Text, nullable=True)
|
||||
record_time = Column(Float, nullable=False)
|
||||
register_time = Column(Float, nullable=True)
|
||||
usage_count = Column(Integer, nullable=False, default=0)
|
||||
last_used_time = Column(Float, nullable=True)
|
||||
|
||||
__table_args__ = (
|
||||
Index('idx_emoji_full_path', 'full_path'),
|
||||
Index('idx_emoji_hash', 'emoji_hash'),
|
||||
)
|
||||
|
||||
|
||||
class Messages(Base):
|
||||
"""消息模型"""
|
||||
__tablename__ = 'messages'
|
||||
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
message_id = Column(get_string_field(100), nullable=False, index=True)
|
||||
time = Column(Float, nullable=False)
|
||||
chat_id = Column(get_string_field(64), nullable=False, index=True)
|
||||
reply_to = Column(Text, nullable=True)
|
||||
interest_value = Column(Float, nullable=True)
|
||||
is_mentioned = Column(Boolean, nullable=True)
|
||||
|
||||
# 从 chat_info 扁平化而来的字段
|
||||
chat_info_stream_id = Column(Text, nullable=False)
|
||||
chat_info_platform = Column(Text, nullable=False)
|
||||
chat_info_user_platform = Column(Text, nullable=False)
|
||||
chat_info_user_id = Column(Text, nullable=False)
|
||||
chat_info_user_nickname = Column(Text, nullable=False)
|
||||
chat_info_user_cardname = Column(Text, nullable=True)
|
||||
chat_info_group_platform = Column(Text, nullable=True)
|
||||
chat_info_group_id = Column(Text, nullable=True)
|
||||
chat_info_group_name = Column(Text, nullable=True)
|
||||
chat_info_create_time = Column(Float, nullable=False)
|
||||
chat_info_last_active_time = Column(Float, nullable=False)
|
||||
|
||||
# 从顶层 user_info 扁平化而来的字段
|
||||
user_platform = Column(Text, nullable=True)
|
||||
user_id = Column(get_string_field(100), nullable=True, index=True)
|
||||
user_nickname = Column(Text, nullable=True)
|
||||
user_cardname = Column(Text, nullable=True)
|
||||
|
||||
processed_plain_text = Column(Text, nullable=True)
|
||||
display_message = Column(Text, nullable=True)
|
||||
memorized_times = Column(Integer, nullable=False, default=0)
|
||||
priority_mode = Column(Text, nullable=True)
|
||||
priority_info = Column(Text, nullable=True)
|
||||
additional_config = Column(Text, nullable=True)
|
||||
is_emoji = Column(Boolean, nullable=False, default=False)
|
||||
is_picid = Column(Boolean, nullable=False, default=False)
|
||||
is_command = Column(Boolean, nullable=False, default=False)
|
||||
is_notify = Column(Boolean, nullable=False, default=False)
|
||||
|
||||
__table_args__ = (
|
||||
Index('idx_messages_message_id', 'message_id'),
|
||||
Index('idx_messages_chat_id', 'chat_id'),
|
||||
Index('idx_messages_time', 'time'),
|
||||
Index('idx_messages_user_id', 'user_id'),
|
||||
)
|
||||
|
||||
|
||||
class ActionRecords(Base):
|
||||
"""动作记录模型"""
|
||||
__tablename__ = 'action_records'
|
||||
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
action_id = Column(get_string_field(100), nullable=False, index=True)
|
||||
time = Column(Float, nullable=False)
|
||||
action_name = Column(Text, nullable=False)
|
||||
action_data = Column(Text, nullable=False)
|
||||
action_done = Column(Boolean, nullable=False, default=False)
|
||||
action_build_into_prompt = Column(Boolean, nullable=False, default=False)
|
||||
action_prompt_display = Column(Text, nullable=False)
|
||||
chat_id = Column(get_string_field(64), nullable=False, index=True)
|
||||
chat_info_stream_id = Column(Text, nullable=False)
|
||||
chat_info_platform = Column(Text, nullable=False)
|
||||
|
||||
__table_args__ = (
|
||||
Index('idx_actionrecords_action_id', 'action_id'),
|
||||
Index('idx_actionrecords_chat_id', 'chat_id'),
|
||||
Index('idx_actionrecords_time', 'time'),
|
||||
)
|
||||
|
||||
|
||||
class Images(Base):
|
||||
"""图像信息模型"""
|
||||
__tablename__ = 'images'
|
||||
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
image_id = Column(Text, nullable=False, default="")
|
||||
emoji_hash = Column(get_string_field(64), nullable=False, index=True)
|
||||
description = Column(Text, nullable=True)
|
||||
path = Column(get_string_field(500), nullable=False, unique=True)
|
||||
count = Column(Integer, nullable=False, default=1)
|
||||
timestamp = Column(Float, nullable=False)
|
||||
type = Column(Text, nullable=False)
|
||||
vlm_processed = Column(Boolean, nullable=False, default=False)
|
||||
|
||||
__table_args__ = (
|
||||
Index('idx_images_emoji_hash', 'emoji_hash'),
|
||||
Index('idx_images_path', 'path'),
|
||||
)
|
||||
|
||||
|
||||
class ImageDescriptions(Base):
|
||||
"""图像描述信息模型"""
|
||||
__tablename__ = 'image_descriptions'
|
||||
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
type = Column(Text, nullable=False)
|
||||
image_description_hash = Column(get_string_field(64), nullable=False, index=True)
|
||||
description = Column(Text, nullable=False)
|
||||
timestamp = Column(Float, nullable=False)
|
||||
|
||||
__table_args__ = (
|
||||
Index('idx_imagedesc_hash', 'image_description_hash'),
|
||||
)
|
||||
|
||||
|
||||
class OnlineTime(Base):
|
||||
"""在线时长记录模型"""
|
||||
__tablename__ = 'online_time'
|
||||
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
timestamp = Column(Text, nullable=False, default=str(datetime.datetime.now))
|
||||
duration = Column(Integer, nullable=False)
|
||||
start_timestamp = Column(DateTime, nullable=False, default=datetime.datetime.now)
|
||||
end_timestamp = Column(DateTime, nullable=False, index=True)
|
||||
|
||||
__table_args__ = (
|
||||
Index('idx_onlinetime_end_timestamp', 'end_timestamp'),
|
||||
)
|
||||
|
||||
|
||||
class PersonInfo(Base):
|
||||
"""人物信息模型"""
|
||||
__tablename__ = 'person_info'
|
||||
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
person_id = Column(get_string_field(100), nullable=False, unique=True, index=True)
|
||||
person_name = Column(Text, nullable=True)
|
||||
name_reason = Column(Text, nullable=True)
|
||||
platform = Column(Text, nullable=False)
|
||||
user_id = Column(get_string_field(50), nullable=False, index=True)
|
||||
nickname = Column(Text, nullable=True)
|
||||
impression = Column(Text, nullable=True)
|
||||
short_impression = Column(Text, nullable=True)
|
||||
points = Column(Text, nullable=True)
|
||||
forgotten_points = Column(Text, nullable=True)
|
||||
info_list = Column(Text, nullable=True)
|
||||
know_times = Column(Float, nullable=True)
|
||||
know_since = Column(Float, nullable=True)
|
||||
last_know = Column(Float, nullable=True)
|
||||
attitude = Column(Integer, nullable=True, default=50)
|
||||
|
||||
__table_args__ = (
|
||||
Index('idx_personinfo_person_id', 'person_id'),
|
||||
Index('idx_personinfo_user_id', 'user_id'),
|
||||
)
|
||||
|
||||
|
||||
class Memory(Base):
|
||||
"""记忆模型"""
|
||||
__tablename__ = 'memory'
|
||||
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
memory_id = Column(get_string_field(64), nullable=False, index=True)
|
||||
chat_id = Column(Text, nullable=True)
|
||||
memory_text = Column(Text, nullable=True)
|
||||
keywords = Column(Text, nullable=True)
|
||||
create_time = Column(Float, nullable=True)
|
||||
last_view_time = Column(Float, nullable=True)
|
||||
|
||||
__table_args__ = (
|
||||
Index('idx_memory_memory_id', 'memory_id'),
|
||||
)
|
||||
|
||||
|
||||
class Expression(Base):
|
||||
"""表达风格模型"""
|
||||
__tablename__ = 'expression'
|
||||
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
situation = Column(Text, nullable=False)
|
||||
style = Column(Text, nullable=False)
|
||||
count = Column(Float, nullable=False)
|
||||
last_active_time = Column(Float, nullable=False)
|
||||
chat_id = Column(get_string_field(64), nullable=False, index=True)
|
||||
type = Column(Text, nullable=False)
|
||||
create_date = Column(Float, nullable=True)
|
||||
|
||||
__table_args__ = (
|
||||
Index('idx_expression_chat_id', 'chat_id'),
|
||||
)
|
||||
|
||||
|
||||
class ThinkingLog(Base):
|
||||
"""思考日志模型"""
|
||||
__tablename__ = 'thinking_logs'
|
||||
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
chat_id = Column(get_string_field(64), nullable=False, index=True)
|
||||
trigger_text = Column(Text, nullable=True)
|
||||
response_text = Column(Text, nullable=True)
|
||||
trigger_info_json = Column(Text, nullable=True)
|
||||
response_info_json = Column(Text, nullable=True)
|
||||
timing_results_json = Column(Text, nullable=True)
|
||||
chat_history_json = Column(Text, nullable=True)
|
||||
chat_history_in_thinking_json = Column(Text, nullable=True)
|
||||
chat_history_after_response_json = Column(Text, nullable=True)
|
||||
heartflow_data_json = Column(Text, nullable=True)
|
||||
reasoning_data_json = Column(Text, nullable=True)
|
||||
created_at = Column(DateTime, nullable=False, default=datetime.datetime.now)
|
||||
|
||||
__table_args__ = (
|
||||
Index('idx_thinkinglog_chat_id', 'chat_id'),
|
||||
)
|
||||
|
||||
|
||||
class GraphNodes(Base):
|
||||
"""记忆图节点模型"""
|
||||
__tablename__ = 'graph_nodes'
|
||||
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
concept = Column(get_string_field(255), nullable=False, unique=True, index=True)
|
||||
memory_items = Column(Text, nullable=False)
|
||||
hash = Column(Text, nullable=False)
|
||||
created_time = Column(Float, nullable=False)
|
||||
last_modified = Column(Float, nullable=False)
|
||||
|
||||
__table_args__ = (
|
||||
Index('idx_graphnodes_concept', 'concept'),
|
||||
)
|
||||
|
||||
|
||||
class GraphEdges(Base):
|
||||
"""记忆图边模型"""
|
||||
__tablename__ = 'graph_edges'
|
||||
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
source = Column(get_string_field(255), nullable=False, index=True)
|
||||
target = Column(get_string_field(255), nullable=False, index=True)
|
||||
strength = Column(Integer, nullable=False)
|
||||
hash = Column(Text, nullable=False)
|
||||
created_time = Column(Float, nullable=False)
|
||||
last_modified = Column(Float, nullable=False)
|
||||
|
||||
__table_args__ = (
|
||||
Index('idx_graphedges_source', 'source'),
|
||||
Index('idx_graphedges_target', 'target'),
|
||||
)
|
||||
|
||||
|
||||
# 数据库引擎和会话管理
|
||||
_engine = None
|
||||
_SessionLocal = None
|
||||
|
||||
|
||||
def get_database_url():
|
||||
"""获取数据库连接URL"""
|
||||
config = global_config.database
|
||||
|
||||
if config.database_type == "mysql":
|
||||
# 对用户名和密码进行URL编码,处理特殊字符
|
||||
from urllib.parse import quote_plus
|
||||
encoded_user = quote_plus(config.mysql_user)
|
||||
encoded_password = quote_plus(config.mysql_password)
|
||||
|
||||
return (
|
||||
f"mysql+pymysql://{encoded_user}:{encoded_password}"
|
||||
f"@{config.mysql_host}:{config.mysql_port}/{config.mysql_database}"
|
||||
f"?charset={config.mysql_charset}"
|
||||
)
|
||||
else: # SQLite
|
||||
# 如果是相对路径,则相对于项目根目录
|
||||
if not os.path.isabs(config.sqlite_path):
|
||||
ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", ".."))
|
||||
db_path = os.path.join(ROOT_PATH, config.sqlite_path)
|
||||
else:
|
||||
db_path = config.sqlite_path
|
||||
|
||||
# 确保数据库目录存在
|
||||
os.makedirs(os.path.dirname(db_path), exist_ok=True)
|
||||
|
||||
return f"sqlite:///{db_path}"
|
||||
|
||||
|
||||
def initialize_database():
|
||||
"""初始化数据库引擎和会话"""
|
||||
global _engine, _SessionLocal
|
||||
|
||||
if _engine is not None:
|
||||
return _engine, _SessionLocal
|
||||
|
||||
database_url = get_database_url()
|
||||
config = global_config.database
|
||||
|
||||
# 配置引擎参数
|
||||
engine_kwargs = {
|
||||
'echo': False, # 生产环境关闭SQL日志
|
||||
'future': True,
|
||||
}
|
||||
|
||||
if config.database_type == "mysql":
|
||||
# MySQL连接池配置
|
||||
engine_kwargs.update({
|
||||
'poolclass': QueuePool,
|
||||
'pool_size': config.connection_pool_size,
|
||||
'max_overflow': config.connection_pool_size * 2,
|
||||
'pool_timeout': config.connection_timeout,
|
||||
'pool_recycle': 3600, # 1小时回收连接
|
||||
'pool_pre_ping': True, # 连接前ping检查
|
||||
'connect_args': {
|
||||
'autocommit': config.mysql_autocommit,
|
||||
'charset': config.mysql_charset,
|
||||
'connect_timeout': config.connection_timeout,
|
||||
'read_timeout': 30,
|
||||
'write_timeout': 30,
|
||||
}
|
||||
})
|
||||
else:
|
||||
# SQLite配置 - 添加连接池设置以避免连接耗尽
|
||||
engine_kwargs.update({
|
||||
'poolclass': QueuePool,
|
||||
'pool_size': 20, # 增加池大小
|
||||
'max_overflow': 30, # 增加溢出连接数
|
||||
'pool_timeout': 60, # 增加超时时间
|
||||
'pool_recycle': 3600, # 1小时回收连接
|
||||
'pool_pre_ping': True, # 连接前ping检查
|
||||
'connect_args': {
|
||||
'check_same_thread': False,
|
||||
'timeout': 30,
|
||||
}
|
||||
})
|
||||
|
||||
_engine = create_engine(database_url, **engine_kwargs)
|
||||
_SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=_engine)
|
||||
|
||||
# 创建所有表
|
||||
Base.metadata.create_all(bind=_engine)
|
||||
|
||||
logger.info(f"SQLAlchemy数据库初始化成功: {config.database_type}")
|
||||
return _engine, _SessionLocal
|
||||
|
||||
|
||||
@contextmanager
|
||||
def get_db_session():
|
||||
"""数据库会话上下文管理器 - 推荐使用这个而不是get_session()"""
|
||||
session = None
|
||||
try:
|
||||
_, SessionLocal = initialize_database()
|
||||
session = SessionLocal()
|
||||
yield session
|
||||
session.commit()
|
||||
except Exception as e:
|
||||
if session:
|
||||
session.rollback()
|
||||
raise
|
||||
finally:
|
||||
if session:
|
||||
session.close()
|
||||
|
||||
|
||||
def get_engine():
|
||||
"""获取数据库引擎"""
|
||||
engine, _ = initialize_database()
|
||||
return engine
|
||||
@@ -373,6 +373,7 @@ MODULE_COLORS = {
|
||||
"base_command": "\033[38;5;208m", # 橙色
|
||||
"component_registry": "\033[38;5;214m", # 橙黄色
|
||||
"stream_api": "\033[38;5;220m", # 黄色
|
||||
"plugin_hot_reload": "\033[38;5;226m", #品红色
|
||||
"config_api": "\033[38;5;226m", # 亮黄色
|
||||
"heartflow_api": "\033[38;5;154m", # 黄绿色
|
||||
"action_apis": "\033[38;5;118m", # 绿色
|
||||
@@ -406,6 +407,7 @@ MODULE_COLORS = {
|
||||
"base_action": "\033[38;5;250m", # 浅灰色
|
||||
# 数据库和消息
|
||||
"database_model": "\033[38;5;94m", # 橙褐色
|
||||
"database": "\033[38;5;46m", # 橙褐色
|
||||
"maim_message": "\033[38;5;140m", # 紫褐色
|
||||
# 日志系统
|
||||
"logger": "\033[38;5;8m", # 深灰色
|
||||
@@ -430,6 +432,8 @@ MODULE_ALIASES = {
|
||||
"memory_activator": "记忆",
|
||||
"tool_use": "工具",
|
||||
"expressor": "表达方式",
|
||||
"plugin_hot_reload": "热重载",
|
||||
"database": "数据库",
|
||||
"database_model": "数据库",
|
||||
"mood": "情绪",
|
||||
"memory": "记忆",
|
||||
|
||||
@@ -1,20 +1,26 @@
|
||||
import traceback
|
||||
|
||||
from typing import List, Any, Optional
|
||||
from peewee import Model # 添加 Peewee Model 导入
|
||||
from typing import List, Optional, Any, Dict
|
||||
from sqlalchemy import not_, select, func
|
||||
|
||||
from sqlalchemy.orm import DeclarativeBase
|
||||
from src.config.config import global_config
|
||||
|
||||
from src.common.database.database_model import Messages
|
||||
# from src.common.database.database_model import Messages
|
||||
from src.common.database.sqlalchemy_models import Messages
|
||||
from src.common.database.sqlalchemy_database_api import get_session
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
class Base(DeclarativeBase):
|
||||
pass
|
||||
|
||||
def _model_to_dict(model_instance: Model) -> dict[str, Any]:
|
||||
def _model_to_dict(instance: Base) -> Dict[str, Any]:
|
||||
"""
|
||||
将 Peewee 模型实例转换为字典。
|
||||
将 SQLAlchemy 模型实例转换为字典。
|
||||
"""
|
||||
return model_instance.__data__
|
||||
return {col.name: getattr(instance, col.name) for col in instance.__table__.columns}
|
||||
|
||||
|
||||
def find_messages(
|
||||
@@ -38,7 +44,8 @@ def find_messages(
|
||||
消息字典列表,如果出错则返回空列表。
|
||||
"""
|
||||
try:
|
||||
query = Messages.select()
|
||||
session = get_session()
|
||||
query = select(Messages)
|
||||
|
||||
# 应用过滤器
|
||||
if message_filter:
|
||||
@@ -77,42 +84,57 @@ def find_messages(
|
||||
query = query.where(Messages.user_id != global_config.bot.qq_account)
|
||||
|
||||
if filter_command:
|
||||
query = query.where(not Messages.is_command)
|
||||
query = query.where(not_(Messages.is_command))
|
||||
|
||||
if limit > 0:
|
||||
# 确保limit是正整数
|
||||
limit = max(1, int(limit))
|
||||
|
||||
if limit_mode == "earliest":
|
||||
# 获取时间最早的 limit 条记录,已经是正序
|
||||
query = query.order_by(Messages.time.asc()).limit(limit)
|
||||
peewee_results = list(query)
|
||||
try:
|
||||
results = session.execute(query).scalars().all()
|
||||
except Exception as e:
|
||||
logger.error(f"执行earliest查询失败: {e}")
|
||||
results = []
|
||||
else: # 默认为 'latest'
|
||||
# 获取时间最晚的 limit 条记录
|
||||
query = query.order_by(Messages.time.desc()).limit(limit)
|
||||
latest_results_peewee = list(query)
|
||||
# 将结果按时间正序排列
|
||||
peewee_results = sorted(latest_results_peewee, key=lambda msg: msg.time)
|
||||
try:
|
||||
latest_results = session.execute(query).scalars().all()
|
||||
# 将结果按时间正序排列
|
||||
results = sorted(latest_results, key=lambda msg: msg.time)
|
||||
except Exception as e:
|
||||
logger.error(f"执行latest查询失败: {e}")
|
||||
results = []
|
||||
else:
|
||||
# limit 为 0 时,应用传入的 sort 参数
|
||||
if sort:
|
||||
peewee_sort_terms = []
|
||||
sort_terms = []
|
||||
for field_name, direction in sort:
|
||||
if hasattr(Messages, field_name):
|
||||
field = getattr(Messages, field_name)
|
||||
if direction == 1: # ASC
|
||||
peewee_sort_terms.append(field.asc())
|
||||
sort_terms.append(field.asc())
|
||||
elif direction == -1: # DESC
|
||||
peewee_sort_terms.append(field.desc())
|
||||
sort_terms.append(field.desc())
|
||||
else:
|
||||
logger.warning(f"字段 '{field_name}' 的排序方向 '{direction}' 无效。将跳过此排序条件。")
|
||||
else:
|
||||
logger.warning(f"排序字段 '{field_name}' 在 Messages 模型中未找到。将跳过此排序条件。")
|
||||
if peewee_sort_terms:
|
||||
query = query.order_by(*peewee_sort_terms)
|
||||
peewee_results = list(query)
|
||||
if sort_terms:
|
||||
query = query.order_by(*sort_terms)
|
||||
try:
|
||||
results = session.execute(query).scalars().all()
|
||||
except Exception as e:
|
||||
logger.error(f"执行无限制查询失败: {e}")
|
||||
results = []
|
||||
|
||||
return [_model_to_dict(msg) for msg in peewee_results]
|
||||
return [_model_to_dict(msg) for msg in results]
|
||||
except Exception as e:
|
||||
log_message = (
|
||||
f"使用 Peewee 查找消息失败 (filter={message_filter}, sort={sort}, limit={limit}, limit_mode={limit_mode}): {e}\n"
|
||||
f"使用 SQLAlchemy 查找消息失败 (filter={message_filter}, sort={sort}, limit={limit}, limit_mode={limit_mode}): {e}\n"
|
||||
+ traceback.format_exc()
|
||||
)
|
||||
logger.error(log_message)
|
||||
@@ -130,7 +152,8 @@ def count_messages(message_filter: dict[str, Any]) -> int:
|
||||
符合条件的消息数量,如果出错则返回 0。
|
||||
"""
|
||||
try:
|
||||
query = Messages.select()
|
||||
session = get_session()
|
||||
query = select(func.count(Messages.id))
|
||||
|
||||
# 应用过滤器
|
||||
if message_filter:
|
||||
@@ -167,14 +190,14 @@ def count_messages(message_filter: dict[str, Any]) -> int:
|
||||
if conditions:
|
||||
query = query.where(*conditions)
|
||||
|
||||
count = query.count()
|
||||
return count
|
||||
count = session.execute(query).scalar()
|
||||
return count or 0
|
||||
except Exception as e:
|
||||
log_message = f"使用 Peewee 计数消息失败 (message_filter={message_filter}): {e}\n{traceback.format_exc()}"
|
||||
log_message = f"使用 SQLAlchemy 计数消息失败 (message_filter={message_filter}): {e}\n{traceback.format_exc()}"
|
||||
logger.error(log_message)
|
||||
return 0
|
||||
|
||||
|
||||
# 你可以在这里添加更多与 messages 集合相关的数据库操作函数,例如 find_one_message, insert_message 等。
|
||||
# 注意:对于 Peewee,插入操作通常是 Messages.create(...) 或 instance.save()。
|
||||
# 查找单个消息可以是 Messages.get_or_none(...) 或 query.first()。
|
||||
# 注意:对于 SQLAlchemy,插入操作通常是使用 session.add() 和 session.commit()。
|
||||
# 查找单个消息可以使用 session.execute(select(Messages).where(...)).scalar_one_or_none()。
|
||||
|
||||
@@ -13,6 +13,7 @@ from typing import List, Optional
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config_base import ConfigBase
|
||||
from src.config.official_configs import (
|
||||
DatabaseConfig,
|
||||
BotConfig,
|
||||
PersonalityConfig,
|
||||
ExpressionConfig,
|
||||
@@ -340,6 +341,7 @@ class Config(ConfigBase):
|
||||
|
||||
MMC_VERSION: str = field(default=MMC_VERSION, repr=False, init=False) # 硬编码的版本信息
|
||||
|
||||
database: DatabaseConfig
|
||||
bot: BotConfig
|
||||
personality: PersonalityConfig
|
||||
relationship: RelationshipConfig
|
||||
@@ -466,4 +468,25 @@ update_model_config()
|
||||
logger.info("正在品鉴配置文件...")
|
||||
global_config = load_config(config_path=os.path.join(CONFIG_DIR, "bot_config.toml"))
|
||||
model_config = api_ada_load_config(config_path=os.path.join(CONFIG_DIR, "model_config.toml"))
|
||||
|
||||
# 初始化数据库连接
|
||||
logger.info("正在初始化数据库连接...")
|
||||
from src.common.database.database import initialize_sql_database
|
||||
try:
|
||||
initialize_sql_database(global_config.database)
|
||||
logger.info(f"数据库连接初始化成功,使用 {global_config.database.database_type} 数据库")
|
||||
except Exception as e:
|
||||
logger.error(f"数据库连接初始化失败: {e}")
|
||||
raise e
|
||||
|
||||
# 初始化数据库表结构
|
||||
logger.info("正在初始化数据库表结构...")
|
||||
from src.common.database.sqlalchemy_models import initialize_database as init_db
|
||||
try:
|
||||
init_db()
|
||||
logger.info("数据库表结构初始化完成")
|
||||
except Exception as e:
|
||||
logger.error(f"数据库表结构初始化失败: {e}")
|
||||
raise e
|
||||
|
||||
logger.info("非常的新鲜,非常的美味!")
|
||||
|
||||
@@ -13,6 +13,65 @@ from src.config.config_base import ConfigBase
|
||||
4. 对于新增的字段,若为可选项,则应在其后添加field()并设置default_factory或default
|
||||
"""
|
||||
|
||||
@dataclass
|
||||
class DatabaseConfig(ConfigBase):
|
||||
"""数据库配置类"""
|
||||
|
||||
database_type: Literal["sqlite", "mysql"] = "sqlite"
|
||||
"""数据库类型,支持 sqlite 或 mysql"""
|
||||
|
||||
# SQLite 配置
|
||||
sqlite_path: str = "data/MaiBot.db"
|
||||
"""SQLite数据库文件路径"""
|
||||
|
||||
# MySQL 配置
|
||||
mysql_host: str = "localhost"
|
||||
"""MySQL服务器地址"""
|
||||
|
||||
mysql_port: int = 3306
|
||||
"""MySQL服务器端口"""
|
||||
|
||||
mysql_database: str = "maibot"
|
||||
"""MySQL数据库名"""
|
||||
|
||||
mysql_user: str = "root"
|
||||
"""MySQL用户名"""
|
||||
|
||||
mysql_password: str = ""
|
||||
"""MySQL密码"""
|
||||
|
||||
mysql_charset: str = "utf8mb4"
|
||||
"""MySQL字符集"""
|
||||
|
||||
mysql_unix_socket: str = ""
|
||||
"""MySQL Unix套接字路径(可选,用于本地连接,优先于host/port)"""
|
||||
|
||||
# MySQL SSL 配置
|
||||
mysql_ssl_mode: str = "DISABLED"
|
||||
"""SSL模式: DISABLED, PREFERRED, REQUIRED, VERIFY_CA, VERIFY_IDENTITY"""
|
||||
|
||||
mysql_ssl_ca: str = ""
|
||||
"""SSL CA证书路径"""
|
||||
|
||||
mysql_ssl_cert: str = ""
|
||||
"""SSL客户端证书路径"""
|
||||
|
||||
mysql_ssl_key: str = ""
|
||||
"""SSL客户端密钥路径"""
|
||||
|
||||
# MySQL 高级配置
|
||||
mysql_autocommit: bool = True
|
||||
"""自动提交事务"""
|
||||
|
||||
mysql_sql_mode: str = "TRADITIONAL"
|
||||
"""SQL模式"""
|
||||
|
||||
# 连接池配置
|
||||
connection_pool_size: int = 10
|
||||
"""连接池大小(仅MySQL有效)"""
|
||||
|
||||
connection_timeout: int = 10
|
||||
"""连接超时时间(秒)"""
|
||||
|
||||
@dataclass
|
||||
class BotConfig(ConfigBase):
|
||||
@@ -72,6 +131,19 @@ class ChatConfig(ConfigBase):
|
||||
max_context_size: int = 18
|
||||
"""上下文长度"""
|
||||
|
||||
|
||||
replyer_random_probability: float = 0.5
|
||||
"""
|
||||
发言时选择推理模型的概率(0-1之间)
|
||||
选择普通模型的概率为 1 - reasoning_normal_model_probability
|
||||
"""
|
||||
|
||||
thinking_timeout: int = 40
|
||||
"""麦麦最长思考规划时间,超过这个时间的思考会放弃(往往是api反应太慢)"""
|
||||
|
||||
talk_frequency: float = 1
|
||||
"""回复频率阈值"""
|
||||
|
||||
mentioned_bot_inevitable_reply: bool = False
|
||||
"""提及 bot 必然回复"""
|
||||
|
||||
@@ -93,17 +165,17 @@ class ChatConfig(ConfigBase):
|
||||
"""
|
||||
统一的活跃度和专注度配置
|
||||
格式:[["platform:chat_id:type", "HH:MM,frequency", "HH:MM,frequency", ...], ...]
|
||||
|
||||
|
||||
全局配置示例:
|
||||
[["", "8:00,1", "12:00,2", "18:00,1.5", "00:00,0.5"]]
|
||||
|
||||
|
||||
特定聊天流配置示例:
|
||||
[
|
||||
["", "8:00,1", "12:00,1.2", "18:00,1.5", "01:00,0.6"], # 全局默认配置
|
||||
["qq:1026294844:group", "12:20,1", "16:10,2", "20:10,1", "00:10,0.3"], # 特定群聊配置
|
||||
["qq:729957033:private", "8:20,1", "12:10,2", "20:10,1.5", "00:10,0.2"] # 特定私聊配置
|
||||
]
|
||||
|
||||
|
||||
说明:
|
||||
- 当第一个元素为空字符串""时,表示全局默认配置
|
||||
- 当第一个元素为"platform:id:type"格式时,表示特定聊天流配置
|
||||
@@ -155,72 +227,11 @@ class ChatConfig(ConfigBase):
|
||||
|
||||
# 检查全局时段配置(第一个元素为空字符串的配置)
|
||||
global_frequency = self._get_global_frequency()
|
||||
return self.talk_frequency if global_frequency is None else global_frequency
|
||||
|
||||
def _get_global_focus_value(self) -> Optional[float]:
|
||||
"""
|
||||
获取全局默认专注度配置
|
||||
if global_frequency is not None:
|
||||
return global_frequency
|
||||
|
||||
Returns:
|
||||
float: 专注度值,如果没有配置则返回 None
|
||||
"""
|
||||
for config_item in self.focus_value_adjust:
|
||||
if not config_item or len(config_item) < 2:
|
||||
continue
|
||||
|
||||
# 检查是否为全局默认配置(第一个元素为空字符串)
|
||||
if config_item[0] == "":
|
||||
return self._get_time_based_focus_value(config_item[1:])
|
||||
|
||||
return None
|
||||
|
||||
def _get_time_based_focus_value(self, time_focus_list: list[str]) -> Optional[float]:
|
||||
"""
|
||||
根据时间配置列表获取当前时段的专注度
|
||||
|
||||
Args:
|
||||
time_focus_list: 时间专注度配置列表,格式为 ["HH:MM,focus_value", ...]
|
||||
|
||||
Returns:
|
||||
float: 专注度值,如果没有配置则返回 None
|
||||
"""
|
||||
from datetime import datetime
|
||||
|
||||
current_time = datetime.now().strftime("%H:%M")
|
||||
current_hour, current_minute = map(int, current_time.split(":"))
|
||||
current_minutes = current_hour * 60 + current_minute
|
||||
|
||||
# 解析时间专注度配置
|
||||
time_focus_pairs = []
|
||||
for time_focus_str in time_focus_list:
|
||||
try:
|
||||
time_str, focus_str = time_focus_str.split(",")
|
||||
hour, minute = map(int, time_str.split(":"))
|
||||
focus_value = float(focus_str)
|
||||
minutes = hour * 60 + minute
|
||||
time_focus_pairs.append((minutes, focus_value))
|
||||
except (ValueError, IndexError):
|
||||
continue
|
||||
|
||||
if not time_focus_pairs:
|
||||
return None
|
||||
|
||||
# 按时间排序
|
||||
time_focus_pairs.sort(key=lambda x: x[0])
|
||||
|
||||
# 查找当前时间对应的专注度
|
||||
current_focus_value = None
|
||||
for minutes, focus_value in time_focus_pairs:
|
||||
if current_minutes >= minutes:
|
||||
current_focus_value = focus_value
|
||||
else:
|
||||
break
|
||||
|
||||
# 如果当前时间在所有配置时间之前,使用最后一个时间段的专注度(跨天逻辑)
|
||||
if current_focus_value is None and time_focus_pairs:
|
||||
current_focus_value = time_focus_pairs[-1][1]
|
||||
|
||||
return current_focus_value
|
||||
# 如果都没有匹配,返回默认值
|
||||
return self.talk_frequency
|
||||
|
||||
def _get_time_based_frequency(self, time_freq_list: list[str]) -> Optional[float]:
|
||||
"""
|
||||
@@ -395,6 +406,14 @@ class MessageReceiveConfig(ConfigBase):
|
||||
ban_msgs_regex: set[str] = field(default_factory=lambda: set())
|
||||
"""过滤正则表达式列表"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class NormalChatConfig(ConfigBase):
|
||||
"""普通聊天配置类"""
|
||||
|
||||
willing_mode: str = "classical"
|
||||
"""意愿模式"""
|
||||
|
||||
@dataclass
|
||||
class ExpressionConfig(ConfigBase):
|
||||
"""表达配置类"""
|
||||
@@ -403,14 +422,14 @@ class ExpressionConfig(ConfigBase):
|
||||
"""
|
||||
表达学习配置列表,支持按聊天流配置
|
||||
格式: [["chat_stream_id", "use_expression", "enable_learning", learning_intensity], ...]
|
||||
|
||||
|
||||
示例:
|
||||
[
|
||||
["", "enable", "enable", 1.0], # 全局配置:使用表达,启用学习,学习强度1.0
|
||||
["qq:1919810:private", "enable", "enable", 1.5], # 特定私聊配置:使用表达,启用学习,学习强度1.5
|
||||
["qq:114514:private", "enable", "disable", 0.5], # 特定私聊配置:使用表达,禁用学习,学习强度0.5
|
||||
]
|
||||
|
||||
|
||||
说明:
|
||||
- 第一位: chat_stream_id,空字符串表示全局配置
|
||||
- 第二位: 是否使用学到的表达 ("enable"/"disable")
|
||||
@@ -475,14 +494,14 @@ class ExpressionConfig(ConfigBase):
|
||||
|
||||
# 优先检查聊天流特定的配置
|
||||
if chat_stream_id:
|
||||
specific_expression_config = self._get_stream_specific_config(chat_stream_id)
|
||||
if specific_expression_config is not None:
|
||||
return specific_expression_config
|
||||
specific_config = self._get_stream_specific_config(chat_stream_id)
|
||||
if specific_config is not None:
|
||||
return specific_config
|
||||
|
||||
# 检查全局配置(第一个元素为空字符串的配置)
|
||||
global_expression_config = self._get_global_config()
|
||||
if global_expression_config is not None:
|
||||
return global_expression_config
|
||||
global_config = self._get_global_config()
|
||||
if global_config is not None:
|
||||
return global_config
|
||||
|
||||
# 如果都没有匹配,返回默认值
|
||||
return True, True, 300
|
||||
@@ -518,10 +537,10 @@ class ExpressionConfig(ConfigBase):
|
||||
|
||||
# 解析配置
|
||||
try:
|
||||
use_expression: bool = config_item[1].lower() == "enable"
|
||||
enable_learning: bool = config_item[2].lower() == "enable"
|
||||
learning_intensity: float = float(config_item[3])
|
||||
return use_expression, enable_learning, learning_intensity # type: ignore
|
||||
use_expression = config_item[1].lower() == "enable"
|
||||
enable_learning = config_item[2].lower() == "enable"
|
||||
learning_intensity = float(config_item[3])
|
||||
return use_expression, enable_learning, learning_intensity
|
||||
except (ValueError, IndexError):
|
||||
continue
|
||||
|
||||
@@ -541,10 +560,10 @@ class ExpressionConfig(ConfigBase):
|
||||
# 检查是否为全局配置(第一个元素为空字符串)
|
||||
if config_item[0] == "":
|
||||
try:
|
||||
use_expression: bool = config_item[1].lower() == "enable"
|
||||
enable_learning: bool = config_item[2].lower() == "enable"
|
||||
use_expression = config_item[1].lower() == "enable"
|
||||
enable_learning = config_item[2].lower() == "enable"
|
||||
learning_intensity = float(config_item[3])
|
||||
return use_expression, enable_learning, learning_intensity # type: ignore
|
||||
return use_expression, enable_learning, learning_intensity
|
||||
except (ValueError, IndexError):
|
||||
continue
|
||||
|
||||
@@ -558,7 +577,6 @@ class ToolConfig(ConfigBase):
|
||||
enable_tool: bool = False
|
||||
"""是否在聊天中启用工具"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class VoiceConfig(ConfigBase):
|
||||
"""语音识别配置类"""
|
||||
@@ -703,7 +721,6 @@ class KeywordReactionConfig(ConfigBase):
|
||||
if not isinstance(rule, KeywordRuleConfig):
|
||||
raise ValueError(f"规则必须是KeywordRuleConfig类型,而不是{type(rule).__name__}")
|
||||
|
||||
|
||||
@dataclass
|
||||
class CustomPromptConfig(ConfigBase):
|
||||
"""自定义提示词配置类"""
|
||||
@@ -852,3 +869,4 @@ class LPMMKnowledgeConfig(ConfigBase):
|
||||
|
||||
embedding_dimension: int = 1024
|
||||
"""嵌入向量维度,应该与模型的输出维度一致"""
|
||||
|
||||
|
||||
@@ -5,8 +5,7 @@ from PIL import Image
|
||||
from datetime import datetime
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.common.database.database import db # 确保 db 被导入用于 create_tables
|
||||
from src.common.database.database_model import LLMUsage
|
||||
from src.common.database.sqlalchemy_models import LLMUsage, get_session
|
||||
from src.config.api_ada_configs import ModelInfo
|
||||
from .payload_content.message import Message, MessageBuilder
|
||||
from .model_client.base_client import UsageRecord
|
||||
@@ -143,16 +142,9 @@ def compress_messages(messages: list[Message], img_target_size: int = 1 * 1024 *
|
||||
|
||||
class LLMUsageRecorder:
|
||||
"""
|
||||
LLM使用情况记录器
|
||||
LLM使用情况记录器(SQLAlchemy版本)
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
try:
|
||||
# 使用 Peewee 创建表,safe=True 表示如果表已存在则不会抛出错误
|
||||
db.create_tables([LLMUsage], safe=True)
|
||||
# logger.debug("LLMUsage 表已初始化/确保存在。")
|
||||
except Exception as e:
|
||||
logger.error(f"创建 LLMUsage 表失败: {str(e)}")
|
||||
|
||||
def record_usage_to_database(
|
||||
self, model_info: ModelInfo, model_usage: UsageRecord, user_id: str, request_type: str, endpoint: str, time_cost: float = 0.0
|
||||
@@ -160,9 +152,13 @@ class LLMUsageRecorder:
|
||||
input_cost = (model_usage.prompt_tokens / 1000000) * model_info.price_in
|
||||
output_cost = (model_usage.completion_tokens / 1000000) * model_info.price_out
|
||||
total_cost = round(input_cost + output_cost, 6)
|
||||
|
||||
session = None
|
||||
try:
|
||||
# 使用 Peewee 模型创建记录
|
||||
LLMUsage.create(
|
||||
# 使用 SQLAlchemy 会话创建记录
|
||||
session = get_session()
|
||||
|
||||
usage_record = LLMUsage(
|
||||
model_name=model_info.model_identifier,
|
||||
model_assign_name=model_info.name,
|
||||
model_api_provider=model_info.api_provider,
|
||||
@@ -175,8 +171,12 @@ class LLMUsageRecorder:
|
||||
cost=total_cost or 0.0,
|
||||
time_cost = round(time_cost or 0.0, 3),
|
||||
status="success",
|
||||
timestamp=datetime.now(), # Peewee 会处理 DateTimeField
|
||||
timestamp=datetime.now(), # SQLAlchemy 会处理 DateTime 字段
|
||||
)
|
||||
|
||||
session.add(usage_record)
|
||||
session.commit()
|
||||
|
||||
logger.debug(
|
||||
f"Token使用情况 - 模型: {model_usage.model_name}, "
|
||||
f"用户: {user_id}, 类型: {request_type}, "
|
||||
@@ -184,6 +184,11 @@ class LLMUsageRecorder:
|
||||
f"总计: {model_usage.total_tokens}"
|
||||
)
|
||||
except Exception as e:
|
||||
if session:
|
||||
session.rollback()
|
||||
logger.error(f"记录token使用情况失败: {str(e)}")
|
||||
finally:
|
||||
if session:
|
||||
session.close()
|
||||
|
||||
llm_usage_recorder = LLMUsageRecorder()
|
||||
40
src/main.py
40
src/main.py
@@ -1,5 +1,7 @@
|
||||
import asyncio
|
||||
import time
|
||||
import signal
|
||||
import sys
|
||||
from maim_message import MessageServer
|
||||
|
||||
from src.common.remote import TelemetryHeartBeatTask
|
||||
@@ -17,8 +19,9 @@ from rich.traceback import install
|
||||
from src.migrate_helper.migrate import check_and_run_migrations
|
||||
# from src.api.main import start_api_server
|
||||
|
||||
# 导入新的插件管理器
|
||||
# 导入新的插件管理器和热重载管理器
|
||||
from src.plugin_system.core.plugin_manager import plugin_manager
|
||||
from src.plugin_system.core.plugin_hot_reload import hot_reload_manager
|
||||
|
||||
# 导入消息API和traceback模块
|
||||
from src.common.message import get_global_api
|
||||
@@ -48,6 +51,28 @@ class MainSystem:
|
||||
self.app: MessageServer = get_global_api()
|
||||
self.server: Server = get_global_server()
|
||||
|
||||
# 设置信号处理器用于优雅退出
|
||||
self._setup_signal_handlers()
|
||||
|
||||
def _setup_signal_handlers(self):
|
||||
"""设置信号处理器"""
|
||||
def signal_handler(signum, frame):
|
||||
logger.info("收到退出信号,正在优雅关闭系统...")
|
||||
self._cleanup()
|
||||
sys.exit(0)
|
||||
|
||||
signal.signal(signal.SIGINT, signal_handler)
|
||||
signal.signal(signal.SIGTERM, signal_handler)
|
||||
|
||||
def _cleanup(self):
|
||||
"""清理资源"""
|
||||
try:
|
||||
# 停止插件热重载系统
|
||||
hot_reload_manager.stop()
|
||||
logger.info("🛑 插件热重载系统已停止")
|
||||
except Exception as e:
|
||||
logger.error(f"停止热重载系统时出错: {e}")
|
||||
|
||||
async def initialize(self):
|
||||
"""初始化系统组件"""
|
||||
logger.info(f"正在唤醒{global_config.bot.nickname}......")
|
||||
@@ -58,14 +83,7 @@ class MainSystem:
|
||||
logger.info(f"""
|
||||
--------------------------------
|
||||
全部系统初始化完成,{global_config.bot.nickname}已成功唤醒
|
||||
--------------------------------
|
||||
如果想要自定义{global_config.bot.nickname}的功能,请查阅:https://docs.mai-mai.org/manual/usage/
|
||||
或者遇到了问题,请访问我们的文档:https://docs.mai-mai.org/
|
||||
--------------------------------
|
||||
如果你想要编写或了解插件相关内容,请访问开发文档https://docs.mai-mai.org/develop/
|
||||
--------------------------------
|
||||
如果你需要查阅模型的消耗以及麦麦的统计数据,请访问根目录的maibot_statistics.html文件
|
||||
""")
|
||||
--------------------------------""")
|
||||
|
||||
async def _init_components(self):
|
||||
"""初始化其他组件"""
|
||||
@@ -87,6 +105,10 @@ class MainSystem:
|
||||
# 加载所有actions,包括默认的和插件的
|
||||
plugin_manager.load_all_plugins()
|
||||
|
||||
# 启动插件热重载系统
|
||||
|
||||
hot_reload_manager.start()
|
||||
|
||||
# 初始化表情管理器
|
||||
get_emoji_manager().initialize()
|
||||
logger.info("表情包管理器初始化成功")
|
||||
|
||||
0
src/person_info/fix_session.py
Normal file
0
src/person_info/fix_session.py
Normal file
@@ -2,17 +2,17 @@ import hashlib
|
||||
import asyncio
|
||||
import json
|
||||
import time
|
||||
import random
|
||||
|
||||
from json_repair import repair_json
|
||||
from typing import Union
|
||||
|
||||
from typing import Any, Callable, Dict, Union, Optional
|
||||
from sqlalchemy import select
|
||||
from src.common.logger import get_logger
|
||||
from src.common.database.database import db
|
||||
from src.common.database.database_model import PersonInfo
|
||||
from src.common.database.sqlalchemy_models import PersonInfo
|
||||
from src.common.database.sqlalchemy_database_api import get_session
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import global_config, model_config
|
||||
|
||||
session = get_session()
|
||||
|
||||
logger = get_logger("person_info")
|
||||
|
||||
@@ -380,36 +380,282 @@ class Person:
|
||||
|
||||
return relation_info
|
||||
|
||||
# 统一的会话管理函数
|
||||
def with_session(func):
|
||||
"""装饰器:为函数自动注入session参数"""
|
||||
if asyncio.iscoroutinefunction(func):
|
||||
async def async_wrapper(*args, **kwargs):
|
||||
|
||||
return await func(session, *args, **kwargs)
|
||||
return async_wrapper
|
||||
else:
|
||||
def sync_wrapper(*args, **kwargs):
|
||||
|
||||
return func(session, *args, **kwargs)
|
||||
return sync_wrapper
|
||||
|
||||
# 全局会话获取函数,用于替换所有裸露的session使用
|
||||
def _get_session():
|
||||
"""获取数据库会话的统一函数"""
|
||||
return get_session()
|
||||
|
||||
|
||||
class PersonInfoManager:
|
||||
def __init__(self):
|
||||
|
||||
"""初始化PersonInfoManager"""
|
||||
from src.common.database.sqlalchemy_models import PersonInfo
|
||||
self.person_name_list = {}
|
||||
self.qv_name_llm = LLMRequest(model_set=model_config.model_task_config.utils, request_type="relation.qv_name")
|
||||
try:
|
||||
db.connect(reuse_if_open=True)
|
||||
# 设置连接池参数
|
||||
# 设置连接池参数(仅对SQLite有效)
|
||||
if hasattr(db, "execute_sql"):
|
||||
# 设置SQLite优化参数
|
||||
db.execute_sql("PRAGMA cache_size = -64000") # 64MB缓存
|
||||
db.execute_sql("PRAGMA temp_store = memory") # 临时存储在内存中
|
||||
db.execute_sql("PRAGMA mmap_size = 268435456") # 256MB内存映射
|
||||
# 检查数据库类型,只对SQLite执行PRAGMA语句
|
||||
if global_config.database.database_type == "sqlite":
|
||||
# 设置SQLite优化参数
|
||||
db.execute_sql("PRAGMA cache_size = -64000") # 64MB缓存
|
||||
db.execute_sql("PRAGMA temp_store = memory") # 临时存储在内存中
|
||||
db.execute_sql("PRAGMA mmap_size = 268435456") # 256MB内存映射
|
||||
db.create_tables([PersonInfo], safe=True)
|
||||
except Exception as e:
|
||||
logger.error(f"数据库连接或 PersonInfo 表创建失败: {e}")
|
||||
|
||||
# 初始化时读取所有person_name
|
||||
try:
|
||||
for record in PersonInfo.select(PersonInfo.person_id, PersonInfo.person_name).where(
|
||||
PersonInfo.person_name.is_null(False)
|
||||
):
|
||||
from src.common.database.sqlalchemy_models import PersonInfo
|
||||
# 在这里获取会话
|
||||
for record in session.execute(select(PersonInfo.person_id, PersonInfo.person_name).where(
|
||||
PersonInfo.person_name.is_not(None)
|
||||
)).fetchall():
|
||||
if record.person_name:
|
||||
self.person_name_list[record.person_id] = record.person_name
|
||||
logger.debug(f"已加载 {len(self.person_name_list)} 个用户名称 (Peewee)")
|
||||
logger.debug(f"已加载 {len(self.person_name_list)} 个用户名称 (SQLAlchemy)")
|
||||
except Exception as e:
|
||||
logger.error(f"从 Peewee 加载 person_name_list 失败: {e}")
|
||||
|
||||
logger.error(f"从 SQLAlchemy 加载 person_name_list 失败: {e}")
|
||||
|
||||
@staticmethod
|
||||
def get_person_id(platform: str, user_id: Union[int, str]) -> str:
|
||||
"""获取唯一id"""
|
||||
if "-" in platform:
|
||||
platform = platform.split("-")[1]
|
||||
|
||||
components = [platform, str(user_id)]
|
||||
key = "_".join(components)
|
||||
return hashlib.md5(key.encode()).hexdigest()
|
||||
|
||||
async def is_person_known(self, platform: str, user_id: int):
|
||||
"""判断是否认识某人"""
|
||||
person_id = self.get_person_id(platform, user_id)
|
||||
|
||||
def _db_check_known_sync(p_id: str):
|
||||
# 在需要时获取会话
|
||||
return session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id)).scalar() is not None
|
||||
|
||||
try:
|
||||
return await asyncio.to_thread(_db_check_known_sync, person_id)
|
||||
except Exception as e:
|
||||
logger.error(f"检查用户 {person_id} 是否已知时出错 (SQLAlchemy): {e}")
|
||||
return False
|
||||
|
||||
def get_person_id_by_person_name(self, person_name: str) -> str:
|
||||
"""根据用户名获取用户ID"""
|
||||
try:
|
||||
# 在需要时获取会话
|
||||
record = session.execute(select(PersonInfo).where(PersonInfo.person_name == person_name)).scalar()
|
||||
return record.person_id if record else ""
|
||||
except Exception as e:
|
||||
logger.error(f"根据用户名 {person_name} 获取用户ID时出错 (SQLAlchemy): {e}")
|
||||
return ""
|
||||
|
||||
@staticmethod
|
||||
async def create_person_info(person_id: str, data: Optional[dict] = None):
|
||||
"""创建一个项"""
|
||||
if not person_id:
|
||||
logger.debug("创建失败,person_id不存在")
|
||||
return
|
||||
|
||||
_person_info_default = copy.deepcopy(person_info_default)
|
||||
# 获取 SQLAlchemy 模型的所有字段名
|
||||
model_fields = [column.name for column in PersonInfo.__table__.columns]
|
||||
|
||||
final_data = {"person_id": person_id}
|
||||
|
||||
# Start with defaults for all model fields
|
||||
for key, default_value in _person_info_default.items():
|
||||
if key in model_fields:
|
||||
final_data[key] = default_value
|
||||
|
||||
# Override with provided data
|
||||
if data:
|
||||
for key, value in data.items():
|
||||
if key in model_fields:
|
||||
final_data[key] = value
|
||||
|
||||
# Ensure person_id is correctly set from the argument
|
||||
final_data["person_id"] = person_id
|
||||
|
||||
# Serialize JSON fields
|
||||
for key in JSON_SERIALIZED_FIELDS:
|
||||
if key in final_data:
|
||||
if isinstance(final_data[key], (list, dict)):
|
||||
final_data[key] = json.dumps(final_data[key], ensure_ascii=False)
|
||||
elif final_data[key] is None: # Default for lists is [], store as "[]"
|
||||
final_data[key] = json.dumps([], ensure_ascii=False)
|
||||
# If it's already a string, assume it's valid JSON or a non-JSON string field
|
||||
|
||||
def _db_create_sync(p_data: dict):
|
||||
try:
|
||||
new_person = PersonInfo(**p_data)
|
||||
session.add(new_person)
|
||||
session.commit()
|
||||
return True
|
||||
except Exception as e:
|
||||
session.rollback()
|
||||
logger.error(f"创建 PersonInfo 记录 {p_data.get('person_id')} 失败 (SQLAlchemy): {e}")
|
||||
return False
|
||||
|
||||
await asyncio.to_thread(_db_create_sync, final_data)
|
||||
|
||||
async def _safe_create_person_info(self, person_id: str, data: Optional[dict] = None):
|
||||
"""安全地创建用户信息,处理竞态条件"""
|
||||
if not person_id:
|
||||
logger.debug("创建失败,person_id不存在")
|
||||
return
|
||||
|
||||
_person_info_default = copy.deepcopy(person_info_default)
|
||||
# 获取 SQLAlchemy 模型的所有字段名
|
||||
model_fields = [column.name for column in PersonInfo.__table__.columns]
|
||||
|
||||
final_data = {"person_id": person_id}
|
||||
|
||||
# Start with defaults for all model fields
|
||||
for key, default_value in _person_info_default.items():
|
||||
if key in model_fields:
|
||||
final_data[key] = default_value
|
||||
|
||||
# Override with provided data
|
||||
if data:
|
||||
for key, value in data.items():
|
||||
if key in model_fields:
|
||||
final_data[key] = value
|
||||
|
||||
# Ensure person_id is correctly set from the argument
|
||||
final_data["person_id"] = person_id
|
||||
|
||||
# Serialize JSON fields
|
||||
for key in JSON_SERIALIZED_FIELDS:
|
||||
if key in final_data:
|
||||
if isinstance(final_data[key], (list, dict)):
|
||||
final_data[key] = json.dumps(final_data[key], ensure_ascii=False)
|
||||
elif final_data[key] is None: # Default for lists is [], store as "[]"
|
||||
final_data[key] = json.dumps([], ensure_ascii=False)
|
||||
|
||||
def _db_safe_create_sync(p_data: dict):
|
||||
try:
|
||||
existing = session.execute(select(PersonInfo).where(PersonInfo.person_id == p_data["person_id"])).scalar()
|
||||
if existing:
|
||||
logger.debug(f"用户 {p_data['person_id']} 已存在,跳过创建")
|
||||
return True
|
||||
|
||||
# 尝试创建
|
||||
new_person = PersonInfo(**p_data)
|
||||
session.add(new_person)
|
||||
session.commit()
|
||||
return True
|
||||
except Exception as e:
|
||||
session.rollback()
|
||||
if "UNIQUE constraint failed" in str(e):
|
||||
logger.debug(f"检测到并发创建用户 {p_data.get('person_id')},跳过错误")
|
||||
return True # 其他协程已创建,视为成功
|
||||
else:
|
||||
logger.error(f"创建 PersonInfo 记录 {p_data.get('person_id')} 失败 (SQLAlchemy): {e}")
|
||||
return False
|
||||
|
||||
await asyncio.to_thread(_db_safe_create_sync, final_data)
|
||||
|
||||
async def update_one_field(self, person_id: str, field_name: str, value, data: Optional[Dict] = None):
|
||||
"""更新某一个字段,会补全"""
|
||||
# 获取 SQLAlchemy 模型的所有字段名
|
||||
model_fields = [column.name for column in PersonInfo.__table__.columns]
|
||||
if field_name not in model_fields:
|
||||
logger.debug(f"更新'{field_name}'失败,未在 PersonInfo SQLAlchemy 模型中定义的字段。")
|
||||
return
|
||||
|
||||
processed_value = value
|
||||
if field_name in JSON_SERIALIZED_FIELDS:
|
||||
if isinstance(value, (list, dict)):
|
||||
processed_value = json.dumps(value, ensure_ascii=False, indent=None)
|
||||
elif value is None: # Store None as "[]" for JSON list fields
|
||||
processed_value = json.dumps([], ensure_ascii=False, indent=None)
|
||||
|
||||
def _db_update_sync(p_id: str, f_name: str, val_to_set):
|
||||
|
||||
start_time = time.time()
|
||||
try:
|
||||
record = session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id)).scalar()
|
||||
query_time = time.time()
|
||||
|
||||
if record:
|
||||
setattr(record, f_name, val_to_set)
|
||||
session.commit()
|
||||
save_time = time.time()
|
||||
|
||||
total_time = save_time - start_time
|
||||
if total_time > 0.5: # 如果超过500ms就记录日志
|
||||
logger.warning(
|
||||
f"数据库更新操作耗时 {total_time:.3f}秒 (查询: {query_time - start_time:.3f}s, 保存: {save_time - query_time:.3f}s) person_id={p_id}, field={f_name}"
|
||||
)
|
||||
|
||||
return True, False # Found and updated, no creation needed
|
||||
else:
|
||||
total_time = time.time() - start_time
|
||||
if total_time > 0.5:
|
||||
logger.warning(f"数据库查询操作耗时 {total_time:.3f}秒 person_id={p_id}, field={f_name}")
|
||||
return False, True # Not found, needs creation
|
||||
except Exception as e:
|
||||
session.rollback()
|
||||
total_time = time.time() - start_time
|
||||
logger.error(f"数据库操作异常,耗时 {total_time:.3f}秒: {e}")
|
||||
raise
|
||||
|
||||
found, needs_creation = await asyncio.to_thread(_db_update_sync, person_id, field_name, processed_value)
|
||||
|
||||
if needs_creation:
|
||||
logger.info(f"{person_id} 不存在,将新建。")
|
||||
creation_data = data if data is not None else {}
|
||||
# Ensure platform and user_id are present for context if available from 'data'
|
||||
# but primarily, set the field that triggered the update.
|
||||
# The create_person_info will handle defaults and serialization.
|
||||
creation_data[field_name] = value # Pass original value to create_person_info
|
||||
|
||||
# Ensure platform and user_id are in creation_data if available,
|
||||
# otherwise create_person_info will use defaults.
|
||||
if data and "platform" in data:
|
||||
creation_data["platform"] = data["platform"]
|
||||
if data and "user_id" in data:
|
||||
creation_data["user_id"] = data["user_id"]
|
||||
|
||||
# 使用安全的创建方法,处理竞态条件
|
||||
await self._safe_create_person_info(person_id, creation_data)
|
||||
|
||||
@staticmethod
|
||||
async def has_one_field(person_id: str, field_name: str):
|
||||
"""判断是否存在某一个字段"""
|
||||
# 获取 SQLAlchemy 模型的所有字段名
|
||||
model_fields = [column.name for column in PersonInfo.__table__.columns]
|
||||
if field_name not in model_fields:
|
||||
logger.debug(f"检查字段'{field_name}'失败,未在 PersonInfo SQLAlchemy 模型中定义。")
|
||||
return False
|
||||
|
||||
def _db_has_field_sync(p_id: str, f_name: str):
|
||||
record = session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id)).scalar()
|
||||
return bool(record)
|
||||
|
||||
try:
|
||||
return await asyncio.to_thread(_db_has_field_sync, person_id, field_name)
|
||||
except Exception as e:
|
||||
logger.error(f"检查字段 {field_name} for {person_id} 时出错 (SQLAlchemy): {e}")
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def _extract_json_from_text(text: str) -> dict:
|
||||
@@ -513,12 +759,13 @@ class PersonInfoManager:
|
||||
else:
|
||||
|
||||
def _db_check_name_exists_sync(name_to_check):
|
||||
return PersonInfo.select().where(PersonInfo.person_name == name_to_check).exists()
|
||||
return session.execute(select(PersonInfo).where(PersonInfo.person_name == name_to_check)).scalar() is not None
|
||||
|
||||
if await asyncio.to_thread(_db_check_name_exists_sync, generated_nickname):
|
||||
is_duplicate = True
|
||||
current_name_set.add(generated_nickname)
|
||||
|
||||
|
||||
if not is_duplicate:
|
||||
person.person_name = generated_nickname
|
||||
person.name_reason = result.get("reason", "未提供理由")
|
||||
@@ -547,4 +794,304 @@ class PersonInfoManager:
|
||||
return {"nickname": unique_nickname, "reason": "使用用户原始昵称作为默认值"}
|
||||
|
||||
|
||||
person_info_manager = PersonInfoManager()
|
||||
@staticmethod
|
||||
async def del_one_document(person_id: str):
|
||||
"""删除指定 person_id 的文档"""
|
||||
if not person_id:
|
||||
logger.debug("删除失败:person_id 不能为空")
|
||||
return
|
||||
|
||||
def _db_delete_sync(p_id: str):
|
||||
try:
|
||||
record = session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id)).scalar()
|
||||
if record:
|
||||
session.delete(record)
|
||||
session.commit()
|
||||
return 1
|
||||
return 0
|
||||
except Exception as e:
|
||||
session.rollback()
|
||||
logger.error(f"删除 PersonInfo {p_id} 失败 (SQLAlchemy): {e}")
|
||||
return 0
|
||||
|
||||
deleted_count = await asyncio.to_thread(_db_delete_sync, person_id)
|
||||
|
||||
if deleted_count > 0:
|
||||
logger.debug(f"删除成功:person_id={person_id} (Peewee)")
|
||||
else:
|
||||
logger.debug(f"删除失败:未找到 person_id={person_id} 或删除未影响行 (Peewee)")
|
||||
|
||||
@staticmethod
|
||||
async def get_value(person_id: str, field_name: str):
|
||||
"""获取指定用户指定字段的值"""
|
||||
default_value_for_field = person_info_default.get(field_name)
|
||||
if field_name in JSON_SERIALIZED_FIELDS and default_value_for_field is None:
|
||||
default_value_for_field = [] # Ensure JSON fields default to [] if not in DB
|
||||
|
||||
def _db_get_value_sync(p_id: str, f_name: str):
|
||||
record = session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id)).scalar()
|
||||
if record:
|
||||
val = getattr(record, f_name, None)
|
||||
if f_name in JSON_SERIALIZED_FIELDS:
|
||||
if isinstance(val, str):
|
||||
try:
|
||||
return json.loads(val)
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(f"字段 {f_name} for {p_id} 包含无效JSON: {val}. 返回默认值.")
|
||||
return [] # Default for JSON fields on error
|
||||
elif val is None: # Field exists in DB but is None
|
||||
return [] # Default for JSON fields
|
||||
# If val is already a list/dict (e.g. if somehow set without serialization)
|
||||
return val # Should ideally not happen if update_one_field is always used
|
||||
return val
|
||||
return None # Record not found
|
||||
|
||||
try:
|
||||
value_from_db = await asyncio.to_thread(_db_get_value_sync, person_id, field_name)
|
||||
if value_from_db is not None:
|
||||
return value_from_db
|
||||
if field_name in person_info_default:
|
||||
return default_value_for_field
|
||||
logger.warning(f"字段 {field_name} 在 person_info_default 中未定义,且在数据库中未找到。")
|
||||
return None # Ultimate fallback
|
||||
except Exception as e:
|
||||
logger.error(f"获取字段 {field_name} for {person_id} 时出错 (Peewee): {e}")
|
||||
# Fallback to default in case of any error during DB access
|
||||
return default_value_for_field if field_name in person_info_default else None
|
||||
|
||||
@staticmethod
|
||||
def get_value_sync(person_id: str, field_name: str):
|
||||
"""同步获取指定用户指定字段的值"""
|
||||
default_value_for_field = person_info_default.get(field_name)
|
||||
if field_name in JSON_SERIALIZED_FIELDS and default_value_for_field is None:
|
||||
default_value_for_field = []
|
||||
|
||||
if record := session.execute(select(PersonInfo).where(PersonInfo.person_id == person_id)).scalar():
|
||||
val = getattr(record, field_name, None)
|
||||
if field_name in JSON_SERIALIZED_FIELDS:
|
||||
if isinstance(val, str):
|
||||
try:
|
||||
return json.loads(val)
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(f"字段 {field_name} for {person_id} 包含无效JSON: {val}. 返回默认值.")
|
||||
return []
|
||||
elif val is None:
|
||||
return []
|
||||
return val
|
||||
return val
|
||||
|
||||
if field_name in person_info_default:
|
||||
return default_value_for_field
|
||||
logger.warning(f"字段 {field_name} 在 person_info_default 中未定义,且在数据库中未找到。")
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
async def get_values(person_id: str, field_names: list) -> dict:
|
||||
"""获取指定person_id文档的多个字段值,若不存在该字段,则返回该字段的全局默认值"""
|
||||
if not person_id:
|
||||
logger.debug("get_values获取失败:person_id不能为空")
|
||||
return {}
|
||||
|
||||
result = {}
|
||||
|
||||
def _db_get_record_sync(p_id: str):
|
||||
return session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id)).scalar()
|
||||
|
||||
record = await asyncio.to_thread(_db_get_record_sync, person_id)
|
||||
|
||||
# 获取 SQLAlchemy 模型的所有字段名
|
||||
model_fields = [column.name for column in PersonInfo.__table__.columns]
|
||||
|
||||
for field_name in field_names:
|
||||
if field_name not in model_fields:
|
||||
if field_name in person_info_default:
|
||||
result[field_name] = copy.deepcopy(person_info_default[field_name])
|
||||
logger.debug(f"字段'{field_name}'不在SQLAlchemy模型中,使用默认配置值。")
|
||||
else:
|
||||
logger.debug(f"get_values查询失败:字段'{field_name}'未在SQLAlchemy模型和默认配置中定义。")
|
||||
result[field_name] = None
|
||||
continue
|
||||
|
||||
if record:
|
||||
value = getattr(record, field_name)
|
||||
if value is not None:
|
||||
result[field_name] = value
|
||||
else:
|
||||
result[field_name] = copy.deepcopy(person_info_default.get(field_name))
|
||||
else:
|
||||
result[field_name] = copy.deepcopy(person_info_default.get(field_name))
|
||||
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
async def get_specific_value_list(
|
||||
field_name: str,
|
||||
way: Callable[[Any], bool],
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
获取满足条件的字段值字典
|
||||
"""
|
||||
# 获取 SQLAlchemy 模型的所有字段名
|
||||
model_fields = [column.name for column in PersonInfo.__table__.columns]
|
||||
if field_name not in model_fields:
|
||||
logger.error(f"字段检查失败:'{field_name}'未在 PersonInfo SQLAlchemy 模 modelo中定义")
|
||||
return {}
|
||||
|
||||
def _db_get_specific_sync(f_name: str):
|
||||
found_results = {}
|
||||
try:
|
||||
for record in session.execute(select(PersonInfo.person_id, getattr(PersonInfo, f_name))).fetchall():
|
||||
value = getattr(record, f_name)
|
||||
if way(value):
|
||||
found_results[record.person_id] = value
|
||||
except Exception as e_query:
|
||||
logger.error(f"数据库查询失败 (SQLAlchemy specific_value_list for {f_name}): {str(e_query)}", exc_info=True)
|
||||
return found_results
|
||||
|
||||
try:
|
||||
return await asyncio.to_thread(_db_get_specific_sync, field_name)
|
||||
except Exception as e:
|
||||
logger.error(f"执行 get_specific_value_list 线程时出错: {str(e)}", exc_info=True)
|
||||
return {}
|
||||
|
||||
async def get_or_create_person(
|
||||
self, platform: str, user_id: int, nickname: str, user_cardname: str, user_avatar: Optional[str] = None
|
||||
) -> str:
|
||||
"""
|
||||
根据 platform 和 user_id 获取 person_id。
|
||||
如果对应的用户不存在,则使用提供的可选信息创建新用户。
|
||||
使用try-except处理竞态条件,避免重复创建错误。
|
||||
"""
|
||||
person_id = self.get_person_id(platform, user_id)
|
||||
|
||||
def _db_get_or_create_sync(p_id: str, init_data: dict):
|
||||
"""原子性的获取或创建操作"""
|
||||
# 首先尝试获取现有记录
|
||||
record = session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id)).scalar()
|
||||
if record:
|
||||
return record, False # 记录存在,未创建
|
||||
|
||||
# 记录不存在,尝试创建
|
||||
try:
|
||||
new_person = PersonInfo(**init_data)
|
||||
session.add(new_person)
|
||||
session.commit()
|
||||
return session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id)).scalar(), True # 创建成功
|
||||
except Exception as e:
|
||||
session.rollback()
|
||||
# 如果创建失败(可能是因为竞态条件),再次尝试获取
|
||||
if "UNIQUE constraint failed" in str(e):
|
||||
logger.debug(f"检测到并发创建用户 {p_id},获取现有记录")
|
||||
record = session.execute(select(PersonInfo).where(PersonInfo.person_id == p_id)).scalar()
|
||||
if record:
|
||||
return record, False # 其他协程已创建,返回现有记录
|
||||
# 如果仍然失败,重新抛出异常
|
||||
raise e
|
||||
|
||||
unique_nickname = await self._generate_unique_person_name(nickname)
|
||||
initial_data = {
|
||||
"person_id": person_id,
|
||||
"platform": platform,
|
||||
"user_id": str(user_id),
|
||||
"nickname": nickname,
|
||||
"person_name": unique_nickname, # 使用群昵称作为person_name
|
||||
"name_reason": "从群昵称获取",
|
||||
"know_times": 0,
|
||||
"know_since": int(datetime.datetime.now().timestamp()),
|
||||
"last_know": int(datetime.datetime.now().timestamp()),
|
||||
"impression": None,
|
||||
"points": [],
|
||||
"forgotten_points": [],
|
||||
}
|
||||
|
||||
# 序列化JSON字段
|
||||
for key in JSON_SERIALIZED_FIELDS:
|
||||
if key in initial_data:
|
||||
if isinstance(initial_data[key], (list, dict)):
|
||||
initial_data[key] = json.dumps(initial_data[key], ensure_ascii=False)
|
||||
elif initial_data[key] is None:
|
||||
initial_data[key] = json.dumps([], ensure_ascii=False)
|
||||
|
||||
# 获取 SQLAlchemy 模odel的所有字段名
|
||||
model_fields = [column.name for column in PersonInfo.__table__.columns]
|
||||
filtered_initial_data = {k: v for k, v in initial_data.items() if v is not None and k in model_fields}
|
||||
|
||||
record, was_created = await asyncio.to_thread(_db_get_or_create_sync, person_id, filtered_initial_data)
|
||||
|
||||
if was_created:
|
||||
logger.info(f"用户 {platform}:{user_id} (person_id: {person_id}) 不存在,将创建新记录 (Peewee)。")
|
||||
logger.info(f"已为 {person_id} 创建新记录,初始数据 (filtered for model): {filtered_initial_data}")
|
||||
else:
|
||||
logger.debug(f"用户 {platform}:{user_id} (person_id: {person_id}) 已存在,返回现有记录。")
|
||||
|
||||
return person_id
|
||||
|
||||
async def get_person_info_by_name(self, person_name: str) -> dict | None:
|
||||
"""根据 person_name 查找用户并返回基本信息 (如果找到)"""
|
||||
if not person_name:
|
||||
logger.debug("get_person_info_by_name 获取失败:person_name 不能为空")
|
||||
return None
|
||||
|
||||
found_person_id = None
|
||||
for pid, name_in_cache in self.person_name_list.items():
|
||||
if name_in_cache == person_name:
|
||||
found_person_id = pid
|
||||
break
|
||||
|
||||
if not found_person_id:
|
||||
|
||||
def _db_find_by_name_sync(p_name_to_find: str):
|
||||
return session.execute(select(PersonInfo).where(PersonInfo.person_name == p_name_to_find)).scalar()
|
||||
|
||||
record = await asyncio.to_thread(_db_find_by_name_sync, person_name)
|
||||
if record:
|
||||
found_person_id = record.person_id
|
||||
if (
|
||||
found_person_id not in self.person_name_list
|
||||
or self.person_name_list[found_person_id] != person_name
|
||||
):
|
||||
self.person_name_list[found_person_id] = person_name
|
||||
else:
|
||||
logger.debug(f"数据库中也未找到名为 '{person_name}' 的用户 (Peewee)")
|
||||
return None
|
||||
|
||||
if found_person_id:
|
||||
required_fields = [
|
||||
"person_id",
|
||||
"platform",
|
||||
"user_id",
|
||||
"nickname",
|
||||
"user_cardname",
|
||||
"user_avatar",
|
||||
"person_name",
|
||||
"name_reason",
|
||||
]
|
||||
# 获取 SQLAlchemy 模型的所有字段名
|
||||
model_fields = [column.name for column in PersonInfo.__table__.columns]
|
||||
valid_fields_to_get = [
|
||||
f
|
||||
for f in required_fields
|
||||
if f in model_fields or f in person_info_default
|
||||
]
|
||||
|
||||
person_data = await self.get_values(found_person_id, valid_fields_to_get)
|
||||
|
||||
if person_data:
|
||||
final_result = {key: person_data.get(key) for key in required_fields}
|
||||
return final_result
|
||||
else:
|
||||
logger.warning(f"找到了 person_id '{found_person_id}' 但 get_values 返回空 (Peewee)")
|
||||
return None
|
||||
|
||||
logger.error(f"逻辑错误:未能为 '{person_name}' 确定 person_id (Peewee)")
|
||||
return None
|
||||
|
||||
|
||||
person_info_manager = None
|
||||
|
||||
|
||||
def get_person_info_manager():
|
||||
global person_info_manager
|
||||
if person_info_manager is None:
|
||||
person_info_manager = PersonInfoManager()
|
||||
return person_info_manager
|
||||
|
||||
@@ -5,385 +5,25 @@
|
||||
from src.plugin_system.apis import database_api
|
||||
records = await database_api.db_query(ActionRecords, query_type="get")
|
||||
record = await database_api.db_save(ActionRecords, data={"action_id": "123"})
|
||||
|
||||
注意:此模块现在使用SQLAlchemy实现,提供更好的连接管理和错误处理
|
||||
"""
|
||||
|
||||
import traceback
|
||||
from typing import Dict, List, Any, Union, Type, Optional
|
||||
from src.common.logger import get_logger
|
||||
from peewee import Model, DoesNotExist
|
||||
from src.common.database.sqlalchemy_database_api import (
|
||||
db_query,
|
||||
db_save,
|
||||
db_get,
|
||||
store_action_info,
|
||||
get_model_class,
|
||||
MODEL_MAPPING
|
||||
)
|
||||
|
||||
logger = get_logger("database_api")
|
||||
|
||||
# =============================================================================
|
||||
# 通用数据库查询API函数
|
||||
# =============================================================================
|
||||
|
||||
|
||||
async def db_query(
|
||||
model_class: Type[Model],
|
||||
data: Optional[Dict[str, Any]] = None,
|
||||
query_type: Optional[str] = "get",
|
||||
filters: Optional[Dict[str, Any]] = None,
|
||||
limit: Optional[int] = None,
|
||||
order_by: Optional[List[str]] = None,
|
||||
single_result: Optional[bool] = False,
|
||||
) -> Union[List[Dict[str, Any]], Dict[str, Any], None]:
|
||||
"""执行数据库查询操作
|
||||
|
||||
这个方法提供了一个通用接口来执行数据库操作,包括查询、创建、更新和删除记录。
|
||||
|
||||
Args:
|
||||
model_class: Peewee 模型类,例如 ActionRecords, Messages 等
|
||||
data: 用于创建或更新的数据字典
|
||||
query_type: 查询类型,可选值: "get", "create", "update", "delete", "count"
|
||||
filters: 过滤条件字典,键为字段名,值为要匹配的值
|
||||
limit: 限制结果数量
|
||||
order_by: 排序字段,前缀'-'表示降序,例如'-time'表示按时间字段(即time字段)降序
|
||||
single_result: 是否只返回单个结果
|
||||
|
||||
Returns:
|
||||
根据查询类型返回不同的结果:
|
||||
- "get": 返回查询结果列表或单个结果(如果 single_result=True)
|
||||
- "create": 返回创建的记录
|
||||
- "update": 返回受影响的行数
|
||||
- "delete": 返回受影响的行数
|
||||
- "count": 返回记录数量
|
||||
"""
|
||||
"""
|
||||
示例:
|
||||
# 查询最近10条消息
|
||||
messages = await database_api.db_query(
|
||||
Messages,
|
||||
query_type="get",
|
||||
filters={"chat_id": chat_stream.stream_id},
|
||||
limit=10,
|
||||
order_by=["-time"]
|
||||
)
|
||||
|
||||
# 创建一条记录
|
||||
new_record = await database_api.db_query(
|
||||
ActionRecords,
|
||||
data={"action_id": "123", "time": time.time(), "action_name": "TestAction"},
|
||||
query_type="create",
|
||||
)
|
||||
|
||||
# 更新记录
|
||||
updated_count = await database_api.db_query(
|
||||
ActionRecords,
|
||||
data={"action_done": True},
|
||||
query_type="update",
|
||||
filters={"action_id": "123"},
|
||||
)
|
||||
|
||||
# 删除记录
|
||||
deleted_count = await database_api.db_query(
|
||||
ActionRecords,
|
||||
query_type="delete",
|
||||
filters={"action_id": "123"}
|
||||
)
|
||||
|
||||
# 计数
|
||||
count = await database_api.db_query(
|
||||
Messages,
|
||||
query_type="count",
|
||||
filters={"chat_id": chat_stream.stream_id}
|
||||
)
|
||||
"""
|
||||
try:
|
||||
if query_type not in ["get", "create", "update", "delete", "count"]:
|
||||
raise ValueError("query_type must be 'get' or 'create' or 'update' or 'delete' or 'count'")
|
||||
# 构建基本查询
|
||||
if query_type in ["get", "update", "delete", "count"]:
|
||||
query = model_class.select()
|
||||
|
||||
# 应用过滤条件
|
||||
if filters:
|
||||
for field, value in filters.items():
|
||||
query = query.where(getattr(model_class, field) == value)
|
||||
|
||||
# 执行查询
|
||||
if query_type == "get":
|
||||
# 应用排序
|
||||
if order_by:
|
||||
for field in order_by:
|
||||
if field.startswith("-"):
|
||||
query = query.order_by(getattr(model_class, field[1:]).desc())
|
||||
else:
|
||||
query = query.order_by(getattr(model_class, field))
|
||||
|
||||
# 应用限制
|
||||
if limit:
|
||||
query = query.limit(limit)
|
||||
|
||||
# 执行查询
|
||||
results = list(query.dicts())
|
||||
|
||||
# 返回结果
|
||||
if single_result:
|
||||
return results[0] if results else None
|
||||
return results
|
||||
|
||||
elif query_type == "create":
|
||||
if not data:
|
||||
raise ValueError("创建记录需要提供data参数")
|
||||
|
||||
# 创建记录
|
||||
record = model_class.create(**data)
|
||||
# 返回创建的记录
|
||||
return model_class.select().where(model_class.id == record.id).dicts().get() # type: ignore
|
||||
|
||||
elif query_type == "update":
|
||||
if not data:
|
||||
raise ValueError("更新记录需要提供data参数")
|
||||
|
||||
# 更新记录
|
||||
return query.update(**data).execute()
|
||||
|
||||
elif query_type == "delete":
|
||||
# 删除记录
|
||||
return query.delete().execute()
|
||||
|
||||
elif query_type == "count":
|
||||
# 计数
|
||||
return query.count()
|
||||
|
||||
else:
|
||||
raise ValueError(f"不支持的查询类型: {query_type}")
|
||||
|
||||
except DoesNotExist:
|
||||
# 记录不存在
|
||||
return None if query_type == "get" and single_result else []
|
||||
except Exception as e:
|
||||
logger.error(f"[DatabaseAPI] 数据库操作出错: {e}")
|
||||
traceback.print_exc()
|
||||
|
||||
# 根据查询类型返回合适的默认值
|
||||
if query_type == "get":
|
||||
return None if single_result else []
|
||||
elif query_type in ["create", "update", "delete", "count"]:
|
||||
return None
|
||||
return None
|
||||
|
||||
|
||||
async def db_save(
|
||||
model_class: Type[Model], data: Dict[str, Any], key_field: Optional[str] = None, key_value: Optional[Any] = None
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
# sourcery skip: inline-immediately-returned-variable
|
||||
"""保存数据到数据库(创建或更新)
|
||||
|
||||
如果提供了key_field和key_value,会先尝试查找匹配的记录进行更新;
|
||||
如果没有找到匹配记录,或未提供key_field和key_value,则创建新记录。
|
||||
|
||||
Args:
|
||||
model_class: Peewee模型类,如ActionRecords, Messages等
|
||||
data: 要保存的数据字典
|
||||
key_field: 用于查找现有记录的字段名,例如"action_id"
|
||||
key_value: 用于查找现有记录的字段值
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: 保存后的记录数据
|
||||
None: 如果操作失败
|
||||
|
||||
示例:
|
||||
# 创建或更新一条记录
|
||||
record = await database_api.db_save(
|
||||
ActionRecords,
|
||||
{
|
||||
"action_id": "123",
|
||||
"time": time.time(),
|
||||
"action_name": "TestAction",
|
||||
"action_done": True
|
||||
},
|
||||
key_field="action_id",
|
||||
key_value="123"
|
||||
)
|
||||
"""
|
||||
try:
|
||||
# 如果提供了key_field和key_value,尝试更新现有记录
|
||||
if key_field and key_value is not None:
|
||||
if existing_records := list(
|
||||
model_class.select().where(getattr(model_class, key_field) == key_value).limit(1)
|
||||
):
|
||||
# 更新现有记录
|
||||
existing_record = existing_records[0]
|
||||
for field, value in data.items():
|
||||
setattr(existing_record, field, value)
|
||||
existing_record.save()
|
||||
|
||||
# 返回更新后的记录
|
||||
updated_record = model_class.select().where(model_class.id == existing_record.id).dicts().get() # type: ignore
|
||||
return updated_record
|
||||
|
||||
# 如果没有找到现有记录或未提供key_field和key_value,创建新记录
|
||||
new_record = model_class.create(**data)
|
||||
|
||||
# 返回创建的记录
|
||||
created_record = model_class.select().where(model_class.id == new_record.id).dicts().get() # type: ignore
|
||||
return created_record
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[DatabaseAPI] 保存数据库记录出错: {e}")
|
||||
traceback.print_exc()
|
||||
return None
|
||||
|
||||
|
||||
async def db_get(
|
||||
model_class: Type[Model],
|
||||
filters: Optional[Dict[str, Any]] = None,
|
||||
limit: Optional[int] = None,
|
||||
order_by: Optional[str] = None,
|
||||
single_result: Optional[bool] = False,
|
||||
) -> Union[List[Dict[str, Any]], Dict[str, Any], None]:
|
||||
"""从数据库获取记录
|
||||
|
||||
这是db_query方法的简化版本,专注于数据检索操作。
|
||||
|
||||
Args:
|
||||
model_class: Peewee模型类
|
||||
filters: 过滤条件,字段名和值的字典
|
||||
limit: 结果数量限制
|
||||
order_by: 排序字段,前缀'-'表示降序,例如'-time'表示按时间字段(即time字段)降序
|
||||
single_result: 是否只返回单个结果,如果为True,则返回单个记录字典或None;否则返回记录字典列表或空列表
|
||||
|
||||
Returns:
|
||||
如果single_result为True,返回单个记录字典或None;
|
||||
否则返回记录字典列表或空列表。
|
||||
|
||||
示例:
|
||||
# 获取单个记录
|
||||
record = await database_api.db_get(
|
||||
ActionRecords,
|
||||
filters={"action_id": "123"},
|
||||
limit=1
|
||||
)
|
||||
|
||||
# 获取最近10条记录
|
||||
records = await database_api.db_get(
|
||||
Messages,
|
||||
filters={"chat_id": chat_stream.stream_id},
|
||||
limit=10,
|
||||
order_by="-time",
|
||||
)
|
||||
"""
|
||||
try:
|
||||
# 构建查询
|
||||
query = model_class.select()
|
||||
|
||||
# 应用过滤条件
|
||||
if filters:
|
||||
for field, value in filters.items():
|
||||
query = query.where(getattr(model_class, field) == value)
|
||||
|
||||
# 应用排序
|
||||
if order_by:
|
||||
if order_by.startswith("-"):
|
||||
query = query.order_by(getattr(model_class, order_by[1:]).desc())
|
||||
else:
|
||||
query = query.order_by(getattr(model_class, order_by))
|
||||
|
||||
# 应用限制
|
||||
if limit:
|
||||
query = query.limit(limit)
|
||||
|
||||
# 执行查询
|
||||
results = list(query.dicts())
|
||||
|
||||
# 返回结果
|
||||
if single_result:
|
||||
return results[0] if results else None
|
||||
return results
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[DatabaseAPI] 获取数据库记录出错: {e}")
|
||||
traceback.print_exc()
|
||||
return None if single_result else []
|
||||
|
||||
|
||||
async def store_action_info(
|
||||
chat_stream=None,
|
||||
action_build_into_prompt: bool = False,
|
||||
action_prompt_display: str = "",
|
||||
action_done: bool = True,
|
||||
thinking_id: str = "",
|
||||
action_data: Optional[dict] = None,
|
||||
action_name: str = "",
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""存储动作信息到数据库
|
||||
|
||||
将Action执行的相关信息保存到ActionRecords表中,用于后续的记忆和上下文构建。
|
||||
|
||||
Args:
|
||||
chat_stream: 聊天流对象,包含聊天相关信息
|
||||
action_build_into_prompt: 是否将此动作构建到提示中
|
||||
action_prompt_display: 动作的提示显示文本
|
||||
action_done: 动作是否完成
|
||||
thinking_id: 关联的思考ID
|
||||
action_data: 动作数据字典
|
||||
action_name: 动作名称
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: 保存的记录数据
|
||||
None: 如果保存失败
|
||||
|
||||
示例:
|
||||
record = await database_api.store_action_info(
|
||||
chat_stream=chat_stream,
|
||||
action_build_into_prompt=True,
|
||||
action_prompt_display="执行了回复动作",
|
||||
action_done=True,
|
||||
thinking_id="thinking_123",
|
||||
action_data={"content": "Hello"},
|
||||
action_name="reply_action"
|
||||
)
|
||||
"""
|
||||
try:
|
||||
import time
|
||||
import json
|
||||
from src.common.database.database_model import ActionRecords
|
||||
|
||||
# 构建动作记录数据
|
||||
record_data = {
|
||||
"action_id": thinking_id or str(int(time.time() * 1000000)), # 使用thinking_id或生成唯一ID
|
||||
"time": time.time(),
|
||||
"action_name": action_name,
|
||||
"action_data": json.dumps(action_data or {}, ensure_ascii=False),
|
||||
"action_done": action_done,
|
||||
"action_build_into_prompt": action_build_into_prompt,
|
||||
"action_prompt_display": action_prompt_display,
|
||||
}
|
||||
|
||||
# 从chat_stream获取聊天信息
|
||||
if chat_stream:
|
||||
record_data.update(
|
||||
{
|
||||
"chat_id": getattr(chat_stream, "stream_id", ""),
|
||||
"chat_info_stream_id": getattr(chat_stream, "stream_id", ""),
|
||||
"chat_info_platform": getattr(chat_stream, "platform", ""),
|
||||
}
|
||||
)
|
||||
else:
|
||||
# 如果没有chat_stream,设置默认值
|
||||
record_data.update(
|
||||
{
|
||||
"chat_id": "",
|
||||
"chat_info_stream_id": "",
|
||||
"chat_info_platform": "",
|
||||
}
|
||||
)
|
||||
|
||||
# 使用已有的db_save函数保存记录
|
||||
saved_record = await db_save(
|
||||
ActionRecords, data=record_data, key_field="action_id", key_value=record_data["action_id"]
|
||||
)
|
||||
|
||||
if saved_record:
|
||||
logger.debug(f"[DatabaseAPI] 成功存储动作信息: {action_name} (ID: {record_data['action_id']})")
|
||||
else:
|
||||
logger.error(f"[DatabaseAPI] 存储动作信息失败: {action_name}")
|
||||
|
||||
return saved_record
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[DatabaseAPI] 存储动作信息时发生错误: {e}")
|
||||
traceback.print_exc()
|
||||
return None
|
||||
# 保持向后兼容性
|
||||
__all__ = [
|
||||
'db_query',
|
||||
'db_save',
|
||||
'db_get',
|
||||
'store_action_info',
|
||||
'get_model_class',
|
||||
'MODEL_MAPPING'
|
||||
]
|
||||
|
||||
@@ -8,10 +8,12 @@ from src.plugin_system.core.plugin_manager import plugin_manager
|
||||
from src.plugin_system.core.component_registry import component_registry
|
||||
from src.plugin_system.core.events_manager import events_manager
|
||||
from src.plugin_system.core.global_announcement_manager import global_announcement_manager
|
||||
from src.plugin_system.core.plugin_hot_reload import hot_reload_manager
|
||||
|
||||
__all__ = [
|
||||
"plugin_manager",
|
||||
"component_registry",
|
||||
"events_manager",
|
||||
"global_announcement_manager",
|
||||
"hot_reload_manager",
|
||||
]
|
||||
|
||||
@@ -237,35 +237,55 @@ class ComponentRegistry:
|
||||
logger.warning(f"组件 {component_name} 未注册,无法移除")
|
||||
return False
|
||||
try:
|
||||
# 根据组件类型进行特定的清理操作
|
||||
match component_type:
|
||||
case ComponentType.ACTION:
|
||||
self._action_registry.pop(component_name)
|
||||
self._default_actions.pop(component_name)
|
||||
# 移除Action注册
|
||||
self._action_registry.pop(component_name, None)
|
||||
self._default_actions.pop(component_name, None)
|
||||
logger.debug(f"已移除Action组件: {component_name}")
|
||||
|
||||
case ComponentType.COMMAND:
|
||||
self._command_registry.pop(component_name)
|
||||
# 移除Command注册和模式
|
||||
self._command_registry.pop(component_name, None)
|
||||
keys_to_remove = [k for k, v in self._command_patterns.items() if v == component_name]
|
||||
for key in keys_to_remove:
|
||||
self._command_patterns.pop(key)
|
||||
self._command_patterns.pop(key, None)
|
||||
logger.debug(f"已移除Command组件: {component_name} (清理了 {len(keys_to_remove)} 个模式)")
|
||||
|
||||
case ComponentType.TOOL:
|
||||
self._tool_registry.pop(component_name)
|
||||
self._llm_available_tools.pop(component_name)
|
||||
# 移除Tool注册
|
||||
self._tool_registry.pop(component_name, None)
|
||||
self._llm_available_tools.pop(component_name, None)
|
||||
logger.debug(f"已移除Tool组件: {component_name}")
|
||||
|
||||
case ComponentType.EVENT_HANDLER:
|
||||
# 移除EventHandler注册和事件订阅
|
||||
from .events_manager import events_manager # 延迟导入防止循环导入问题
|
||||
|
||||
self._event_handler_registry.pop(component_name)
|
||||
self._enabled_event_handlers.pop(component_name)
|
||||
await events_manager.unregister_event_subscriber(component_name)
|
||||
self._event_handler_registry.pop(component_name, None)
|
||||
self._enabled_event_handlers.pop(component_name, None)
|
||||
try:
|
||||
await events_manager.unregister_event_subscriber(component_name)
|
||||
logger.debug(f"已移除EventHandler组件: {component_name}")
|
||||
except Exception as e:
|
||||
logger.warning(f"移除EventHandler事件订阅时出错: {e}")
|
||||
|
||||
case _:
|
||||
logger.warning(f"未知的组件类型: {component_type}")
|
||||
return False
|
||||
|
||||
# 移除通用注册信息
|
||||
namespaced_name = f"{component_type}.{component_name}"
|
||||
self._components.pop(namespaced_name)
|
||||
self._components_by_type[component_type].pop(component_name)
|
||||
self._components_classes.pop(namespaced_name)
|
||||
logger.info(f"组件 {component_name} 已移除")
|
||||
self._components.pop(namespaced_name, None)
|
||||
self._components_by_type[component_type].pop(component_name, None)
|
||||
self._components_classes.pop(namespaced_name, None)
|
||||
|
||||
logger.info(f"组件 {component_name} ({component_type}) 已完全移除")
|
||||
return True
|
||||
except KeyError as e:
|
||||
logger.warning(f"移除组件时未找到组件: {component_name}, 发生错误: {e}")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"移除组件 {component_name} 时发生错误: {e}")
|
||||
logger.error(f"移除组件 {component_name} ({component_type}) 时发生错误: {e}")
|
||||
return False
|
||||
|
||||
def remove_plugin_registry(self, plugin_name: str) -> bool:
|
||||
@@ -615,5 +635,54 @@ class ComponentRegistry:
|
||||
"enabled_plugins": len([p for p in self._plugins.values() if p.enabled]),
|
||||
}
|
||||
|
||||
# === 组件移除相关 ===
|
||||
|
||||
async def unregister_plugin(self, plugin_name: str) -> bool:
|
||||
"""卸载插件及其所有组件
|
||||
|
||||
Args:
|
||||
plugin_name: 插件名称
|
||||
|
||||
Returns:
|
||||
bool: 是否成功卸载
|
||||
"""
|
||||
plugin_info = self.get_plugin_info(plugin_name)
|
||||
if not plugin_info:
|
||||
logger.warning(f"插件 {plugin_name} 未注册,无法卸载")
|
||||
return False
|
||||
|
||||
logger.info(f"开始卸载插件: {plugin_name}")
|
||||
|
||||
# 记录卸载失败的组件
|
||||
failed_components = []
|
||||
|
||||
# 逐个移除插件的所有组件
|
||||
for component_info in plugin_info.components:
|
||||
try:
|
||||
success = await self.remove_component(
|
||||
component_info.name,
|
||||
component_info.component_type,
|
||||
plugin_name,
|
||||
)
|
||||
if not success:
|
||||
failed_components.append(f"{component_info.component_type}.{component_info.name}")
|
||||
except Exception as e:
|
||||
logger.error(f"移除组件 {component_info.name} 时发生异常: {e}")
|
||||
failed_components.append(f"{component_info.component_type}.{component_info.name}")
|
||||
|
||||
# 移除插件注册信息
|
||||
plugin_removed = self.remove_plugin_registry(plugin_name)
|
||||
|
||||
if failed_components:
|
||||
logger.warning(f"插件 {plugin_name} 部分组件卸载失败: {failed_components}")
|
||||
return False
|
||||
elif not plugin_removed:
|
||||
logger.error(f"插件 {plugin_name} 注册信息移除失败")
|
||||
return False
|
||||
else:
|
||||
logger.info(f"插件 {plugin_name} 卸载成功")
|
||||
return True
|
||||
|
||||
|
||||
# 创建全局组件注册中心实例
|
||||
component_registry = ComponentRegistry()
|
||||
|
||||
@@ -33,7 +33,7 @@ class EventsManager:
|
||||
|
||||
if handler_name in self._handler_mapping:
|
||||
logger.warning(f"事件处理器 {handler_name} 已存在,跳过注册")
|
||||
return False
|
||||
return True
|
||||
|
||||
if not issubclass(handler_class, BaseEventHandler):
|
||||
logger.error(f"类 {handler_class.__name__} 不是 BaseEventHandler 的子类")
|
||||
|
||||
242
src/plugin_system/core/plugin_hot_reload.py
Normal file
242
src/plugin_system/core/plugin_hot_reload.py
Normal file
@@ -0,0 +1,242 @@
|
||||
"""
|
||||
插件热重载模块
|
||||
|
||||
使用 Watchdog 监听插件目录变化,自动重载插件
|
||||
"""
|
||||
|
||||
import os
|
||||
import time
|
||||
from pathlib import Path
|
||||
from threading import Thread
|
||||
from typing import Dict, Set
|
||||
|
||||
from watchdog.observers import Observer
|
||||
from watchdog.events import FileSystemEventHandler
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from .plugin_manager import plugin_manager
|
||||
|
||||
logger = get_logger("plugin_hot_reload")
|
||||
|
||||
|
||||
class PluginFileHandler(FileSystemEventHandler):
|
||||
"""插件文件变化处理器"""
|
||||
|
||||
def __init__(self, hot_reload_manager):
|
||||
super().__init__()
|
||||
self.hot_reload_manager = hot_reload_manager
|
||||
self.pending_reloads: Set[str] = set() # 待重载的插件名称
|
||||
self.last_reload_time: Dict[str, float] = {} # 上次重载时间
|
||||
self.debounce_delay = 1.0 # 防抖延迟(秒)
|
||||
|
||||
def on_modified(self, event):
|
||||
"""文件修改事件"""
|
||||
if not event.is_directory and (event.src_path.endswith('.py') or event.src_path.endswith('.toml')):
|
||||
self._handle_file_change(event.src_path, "modified")
|
||||
|
||||
def on_created(self, event):
|
||||
"""文件创建事件"""
|
||||
if not event.is_directory and (event.src_path.endswith('.py') or event.src_path.endswith('.toml')):
|
||||
self._handle_file_change(event.src_path, "created")
|
||||
|
||||
def on_deleted(self, event):
|
||||
"""文件删除事件"""
|
||||
if not event.is_directory and (event.src_path.endswith('.py') or event.src_path.endswith('.toml')):
|
||||
self._handle_file_change(event.src_path, "deleted")
|
||||
|
||||
def _handle_file_change(self, file_path: str, change_type: str):
|
||||
"""处理文件变化"""
|
||||
try:
|
||||
# 获取插件名称
|
||||
plugin_name = self._get_plugin_name_from_path(file_path)
|
||||
if not plugin_name:
|
||||
return
|
||||
|
||||
current_time = time.time()
|
||||
last_time = self.last_reload_time.get(plugin_name, 0)
|
||||
|
||||
# 防抖处理,避免频繁重载
|
||||
if current_time - last_time < self.debounce_delay:
|
||||
return
|
||||
|
||||
file_name = Path(file_path).name
|
||||
logger.info(f"📁 检测到插件文件变化: {file_name} ({change_type})")
|
||||
|
||||
# 如果是删除事件,处理关键文件删除
|
||||
if change_type == "deleted":
|
||||
if file_name == "plugin.py":
|
||||
if plugin_name in plugin_manager.loaded_plugins:
|
||||
logger.info(f"🗑️ 插件主文件被删除,卸载插件: {plugin_name}")
|
||||
self.hot_reload_manager._unload_plugin(plugin_name)
|
||||
return
|
||||
elif file_name == "manifest.toml":
|
||||
if plugin_name in plugin_manager.loaded_plugins:
|
||||
logger.info(f"🗑️ 插件配置文件被删除,卸载插件: {plugin_name}")
|
||||
self.hot_reload_manager._unload_plugin(plugin_name)
|
||||
return
|
||||
|
||||
# 对于修改和创建事件,都进行重载
|
||||
# 添加到待重载列表
|
||||
self.pending_reloads.add(plugin_name)
|
||||
self.last_reload_time[plugin_name] = current_time
|
||||
|
||||
# 延迟重载,避免文件正在写入时重载
|
||||
reload_thread = Thread(
|
||||
target=self._delayed_reload,
|
||||
args=(plugin_name,),
|
||||
daemon=True
|
||||
)
|
||||
reload_thread.start()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 处理文件变化时发生错误: {e}")
|
||||
|
||||
def _delayed_reload(self, plugin_name: str):
|
||||
"""延迟重载插件"""
|
||||
try:
|
||||
time.sleep(self.debounce_delay)
|
||||
|
||||
if plugin_name in self.pending_reloads:
|
||||
self.pending_reloads.remove(plugin_name)
|
||||
self.hot_reload_manager._reload_plugin(plugin_name)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 延迟重载插件 {plugin_name} 时发生错误: {e}")
|
||||
|
||||
def _get_plugin_name_from_path(self, file_path: str) -> str:
|
||||
"""从文件路径获取插件名称"""
|
||||
try:
|
||||
path = Path(file_path)
|
||||
|
||||
# 检查是否在监听的插件目录中
|
||||
plugin_root = Path(self.hot_reload_manager.watch_directory)
|
||||
if not path.is_relative_to(plugin_root):
|
||||
return ""
|
||||
|
||||
# 获取插件目录名(插件名)
|
||||
relative_path = path.relative_to(plugin_root)
|
||||
plugin_name = relative_path.parts[0]
|
||||
|
||||
# 确认这是一个有效的插件目录(检查是否有 plugin.py 或 manifest.toml)
|
||||
plugin_dir = plugin_root / plugin_name
|
||||
if plugin_dir.is_dir() and ((plugin_dir / "plugin.py").exists() or (plugin_dir / "manifest.toml").exists()):
|
||||
return plugin_name
|
||||
|
||||
return ""
|
||||
|
||||
except Exception:
|
||||
return ""
|
||||
|
||||
|
||||
class PluginHotReloadManager:
|
||||
"""插件热重载管理器"""
|
||||
|
||||
def __init__(self, watch_directory: str = None):
|
||||
print("fuck")
|
||||
print(os.getcwd())
|
||||
self.watch_directory = os.path.join(os.getcwd(), "plugins")
|
||||
self.observer = None
|
||||
self.file_handler = None
|
||||
self.is_running = False
|
||||
|
||||
# 确保监听目录存在
|
||||
if not os.path.exists(self.watch_directory):
|
||||
os.makedirs(self.watch_directory, exist_ok=True)
|
||||
logger.info(f"创建插件监听目录: {self.watch_directory}")
|
||||
|
||||
def start(self):
|
||||
"""启动热重载监听"""
|
||||
if self.is_running:
|
||||
logger.warning("插件热重载已经在运行中")
|
||||
return
|
||||
|
||||
try:
|
||||
self.observer = Observer()
|
||||
self.file_handler = PluginFileHandler(self)
|
||||
|
||||
self.observer.schedule(
|
||||
self.file_handler,
|
||||
self.watch_directory,
|
||||
recursive=True
|
||||
)
|
||||
|
||||
self.observer.start()
|
||||
self.is_running = True
|
||||
|
||||
logger.info("🚀 插件热重载已启动,监听目录: plugins")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 启动插件热重载失败: {e}")
|
||||
self.is_running = False
|
||||
|
||||
def stop(self):
|
||||
"""停止热重载监听"""
|
||||
if not self.is_running:
|
||||
return
|
||||
|
||||
if self.observer:
|
||||
self.observer.stop()
|
||||
self.observer.join()
|
||||
|
||||
self.is_running = False
|
||||
|
||||
def _reload_plugin(self, plugin_name: str):
|
||||
"""重载指定插件"""
|
||||
try:
|
||||
logger.info(f"🔄 开始重载插件: {plugin_name}")
|
||||
|
||||
if plugin_manager.reload_plugin(plugin_name):
|
||||
logger.info(f"✅ 插件重载成功: {plugin_name}")
|
||||
else:
|
||||
logger.error(f"❌ 插件重载失败: {plugin_name}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 重载插件 {plugin_name} 时发生错误: {e}")
|
||||
|
||||
def _unload_plugin(self, plugin_name: str):
|
||||
"""卸载指定插件"""
|
||||
try:
|
||||
logger.info(f"🗑️ 开始卸载插件: {plugin_name}")
|
||||
|
||||
if plugin_manager.unload_plugin(plugin_name):
|
||||
logger.info(f"✅ 插件卸载成功: {plugin_name}")
|
||||
else:
|
||||
logger.error(f"❌ 插件卸载失败: {plugin_name}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 卸载插件 {plugin_name} 时发生错误: {e}")
|
||||
|
||||
def reload_all_plugins(self):
|
||||
"""重载所有插件"""
|
||||
try:
|
||||
logger.info("🔄 开始重载所有插件...")
|
||||
|
||||
# 获取当前已加载的插件列表
|
||||
loaded_plugins = list(plugin_manager.loaded_plugins.keys())
|
||||
|
||||
success_count = 0
|
||||
fail_count = 0
|
||||
|
||||
for plugin_name in loaded_plugins:
|
||||
if plugin_manager.reload_plugin(plugin_name):
|
||||
success_count += 1
|
||||
else:
|
||||
fail_count += 1
|
||||
|
||||
logger.info(f"✅ 插件重载完成: 成功 {success_count} 个,失败 {fail_count} 个")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 重载所有插件时发生错误: {e}")
|
||||
|
||||
def get_status(self) -> dict:
|
||||
"""获取热重载状态"""
|
||||
return {
|
||||
"is_running": self.is_running,
|
||||
"watch_directory": self.watch_directory,
|
||||
"loaded_plugins": len(plugin_manager.loaded_plugins),
|
||||
"failed_plugins": len(plugin_manager.failed_plugins),
|
||||
}
|
||||
|
||||
|
||||
# 全局热重载管理器实例
|
||||
hot_reload_manager = PluginHotReloadManager()
|
||||
@@ -1,5 +1,6 @@
|
||||
import os
|
||||
import traceback
|
||||
import sys
|
||||
|
||||
from typing import Dict, List, Optional, Tuple, Type, Any
|
||||
from importlib.util import spec_from_file_location, module_from_spec
|
||||
@@ -488,6 +489,105 @@ class PluginManager:
|
||||
else:
|
||||
logger.info(f"✅ 插件加载成功: {plugin_name}")
|
||||
|
||||
# === 插件卸载和重载管理 ===
|
||||
|
||||
def unload_plugin(self, plugin_name: str) -> bool:
|
||||
"""卸载指定插件
|
||||
|
||||
Args:
|
||||
plugin_name: 插件名称
|
||||
|
||||
Returns:
|
||||
bool: 卸载是否成功
|
||||
"""
|
||||
if plugin_name not in self.loaded_plugins:
|
||||
logger.warning(f"插件 {plugin_name} 未加载,无需卸载")
|
||||
return False
|
||||
|
||||
try:
|
||||
# 获取插件实例
|
||||
plugin_instance = self.loaded_plugins[plugin_name]
|
||||
|
||||
# 调用插件的清理方法(如果有的话)
|
||||
if hasattr(plugin_instance, 'on_unload'):
|
||||
plugin_instance.on_unload()
|
||||
|
||||
# 从组件注册表中移除插件的所有组件
|
||||
component_registry.unregister_plugin(plugin_name)
|
||||
|
||||
# 从已加载插件中移除
|
||||
del self.loaded_plugins[plugin_name]
|
||||
|
||||
# 从失败列表中移除(如果存在)
|
||||
if plugin_name in self.failed_plugins:
|
||||
del self.failed_plugins[plugin_name]
|
||||
|
||||
logger.info(f"✅ 插件卸载成功: {plugin_name}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 插件卸载失败: {plugin_name} - {str(e)}")
|
||||
return False
|
||||
|
||||
def reload_plugin(self, plugin_name: str) -> bool:
|
||||
"""重载指定插件
|
||||
|
||||
Args:
|
||||
plugin_name: 插件名称
|
||||
|
||||
Returns:
|
||||
bool: 重载是否成功
|
||||
"""
|
||||
try:
|
||||
# 先卸载插件
|
||||
if plugin_name in self.loaded_plugins:
|
||||
self.unload_plugin(plugin_name)
|
||||
|
||||
# 清除Python模块缓存
|
||||
plugin_path = self.plugin_paths.get(plugin_name)
|
||||
if plugin_path:
|
||||
plugin_file = os.path.join(plugin_path, "plugin.py")
|
||||
if os.path.exists(plugin_file):
|
||||
# 从sys.modules中移除相关模块
|
||||
modules_to_remove = []
|
||||
plugin_module_prefix = ".".join(Path(plugin_file).parent.parts)
|
||||
|
||||
for module_name in sys.modules:
|
||||
if module_name.startswith(plugin_module_prefix):
|
||||
modules_to_remove.append(module_name)
|
||||
|
||||
for module_name in modules_to_remove:
|
||||
del sys.modules[module_name]
|
||||
|
||||
# 从插件类注册表中移除
|
||||
if plugin_name in self.plugin_classes:
|
||||
del self.plugin_classes[plugin_name]
|
||||
|
||||
# 重新加载插件模块
|
||||
if self._load_plugin_module_file(plugin_file):
|
||||
# 重新加载插件实例
|
||||
success, _ = self.load_registered_plugin_classes(plugin_name)
|
||||
if success:
|
||||
logger.info(f"🔄 插件重载成功: {plugin_name}")
|
||||
return True
|
||||
else:
|
||||
logger.error(f"❌ 插件重载失败: {plugin_name} - 实例化失败")
|
||||
return False
|
||||
else:
|
||||
logger.error(f"❌ 插件重载失败: {plugin_name} - 模块加载失败")
|
||||
return False
|
||||
else:
|
||||
logger.error(f"❌ 插件重载失败: {plugin_name} - 插件文件不存在")
|
||||
return False
|
||||
else:
|
||||
logger.error(f"❌ 插件重载失败: {plugin_name} - 插件路径未知")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 插件重载失败: {plugin_name} - {str(e)}")
|
||||
logger.debug("详细错误信息: ", exc_info=True)
|
||||
return False
|
||||
|
||||
|
||||
# 全局插件管理器实例
|
||||
plugin_manager = PluginManager()
|
||||
|
||||
@@ -1,149 +1,149 @@
|
||||
from src.plugin_system.apis.plugin_register_api import register_plugin
|
||||
from src.plugin_system.base.base_plugin import BasePlugin
|
||||
from src.plugin_system.base.component_types import ComponentInfo
|
||||
from src.common.logger import get_logger
|
||||
from src.plugin_system.base.base_action import BaseAction, ActionActivationType, ChatMode
|
||||
from src.plugin_system.base.config_types import ConfigField
|
||||
from typing import Tuple, List, Type
|
||||
|
||||
logger = get_logger("tts")
|
||||
|
||||
|
||||
class TTSAction(BaseAction):
|
||||
"""TTS语音转换动作处理类"""
|
||||
|
||||
# 激活设置
|
||||
focus_activation_type = ActionActivationType.LLM_JUDGE
|
||||
normal_activation_type = ActionActivationType.KEYWORD
|
||||
mode_enable = ChatMode.ALL
|
||||
parallel_action = False
|
||||
|
||||
# 动作基本信息
|
||||
action_name = "tts_action"
|
||||
action_description = "将文本转换为语音进行播放,适用于需要语音输出的场景"
|
||||
|
||||
# 关键词配置 - Normal模式下使用关键词触发
|
||||
activation_keywords = ["语音", "tts", "播报", "读出来", "语音播放", "听", "朗读"]
|
||||
keyword_case_sensitive = False
|
||||
|
||||
# 动作参数定义
|
||||
action_parameters = {
|
||||
"text": "需要转换为语音的文本内容,必填,内容应当适合语音播报,语句流畅、清晰",
|
||||
}
|
||||
|
||||
# 动作使用场景
|
||||
action_require = [
|
||||
"当需要发送语音信息时使用",
|
||||
"当用户明确要求使用语音功能时使用",
|
||||
"当表达内容更适合用语音而不是文字传达时使用",
|
||||
"当用户想听到语音回答而非阅读文本时使用",
|
||||
]
|
||||
|
||||
# 关联类型
|
||||
associated_types = ["tts_text"]
|
||||
|
||||
async def execute(self) -> Tuple[bool, str]:
|
||||
"""处理TTS文本转语音动作"""
|
||||
logger.info(f"{self.log_prefix} 执行TTS动作: {self.reasoning}")
|
||||
|
||||
# 获取要转换的文本
|
||||
text = self.action_data.get("text")
|
||||
|
||||
if not text:
|
||||
logger.error(f"{self.log_prefix} 执行TTS动作时未提供文本内容")
|
||||
return False, "执行TTS动作失败:未提供文本内容"
|
||||
|
||||
# 确保文本适合TTS使用
|
||||
processed_text = self._process_text_for_tts(text)
|
||||
|
||||
try:
|
||||
# 发送TTS消息
|
||||
await self.send_custom(message_type="tts_text", content=processed_text)
|
||||
|
||||
# 记录动作信息
|
||||
await self.store_action_info(
|
||||
action_build_into_prompt=True, action_prompt_display="已经发送了语音消息。", action_done=True
|
||||
)
|
||||
|
||||
logger.info(f"{self.log_prefix} TTS动作执行成功,文本长度: {len(processed_text)}")
|
||||
return True, "TTS动作执行成功"
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 执行TTS动作时出错: {e}")
|
||||
return False, f"执行TTS动作时出错: {e}"
|
||||
|
||||
def _process_text_for_tts(self, text: str) -> str:
|
||||
"""
|
||||
处理文本使其更适合TTS使用
|
||||
- 移除不必要的特殊字符和表情符号
|
||||
- 修正标点符号以提高语音质量
|
||||
- 优化文本结构使语音更流畅
|
||||
"""
|
||||
# 这里可以添加文本处理逻辑
|
||||
# 例如:移除多余的标点、表情符号,优化语句结构等
|
||||
|
||||
# 简单示例实现
|
||||
processed_text = text
|
||||
|
||||
# 移除多余的标点符号
|
||||
import re
|
||||
|
||||
processed_text = re.sub(r"([!?,.;:。!?,、;:])\1+", r"\1", processed_text)
|
||||
|
||||
# 确保句子结尾有合适的标点
|
||||
if not any(processed_text.endswith(end) for end in [".", "?", "!", "。", "!", "?"]):
|
||||
processed_text = f"{processed_text}。"
|
||||
|
||||
return processed_text
|
||||
|
||||
|
||||
@register_plugin
|
||||
class TTSPlugin(BasePlugin):
|
||||
"""TTS插件
|
||||
- 这是文字转语音插件
|
||||
- Normal模式下依靠关键词触发
|
||||
- Focus模式下由LLM判断触发
|
||||
- 具有一定的文本预处理能力
|
||||
"""
|
||||
|
||||
# 插件基本信息
|
||||
plugin_name: str = "tts_plugin" # 内部标识符
|
||||
enable_plugin: bool = True
|
||||
dependencies: list[str] = [] # 插件依赖列表
|
||||
python_dependencies: list[str] = [] # Python包依赖列表
|
||||
config_file_name: str = "config.toml"
|
||||
|
||||
# 配置节描述
|
||||
config_section_descriptions = {
|
||||
"plugin": "插件基本信息配置",
|
||||
"components": "组件启用控制",
|
||||
"logging": "日志记录相关配置",
|
||||
}
|
||||
|
||||
# 配置Schema定义
|
||||
config_schema: dict = {
|
||||
"plugin": {
|
||||
"name": ConfigField(type=str, default="tts_plugin", description="插件名称", required=True),
|
||||
"version": ConfigField(type=str, default="0.1.0", description="插件版本号"),
|
||||
"enabled": ConfigField(type=bool, default=True, description="是否启用插件"),
|
||||
"description": ConfigField(type=str, default="文字转语音插件", description="插件描述", required=True),
|
||||
},
|
||||
"components": {"enable_tts": ConfigField(type=bool, default=True, description="是否启用TTS Action")},
|
||||
"logging": {
|
||||
"level": ConfigField(
|
||||
type=str, default="INFO", description="日志记录级别", choices=["DEBUG", "INFO", "WARNING", "ERROR"]
|
||||
),
|
||||
"prefix": ConfigField(type=str, default="[TTS]", description="日志记录前缀"),
|
||||
},
|
||||
}
|
||||
|
||||
def get_plugin_components(self) -> List[Tuple[ComponentInfo, Type]]:
|
||||
"""返回插件包含的组件列表"""
|
||||
|
||||
# 从配置获取组件启用状态
|
||||
enable_tts = self.get_config("components.enable_tts", True)
|
||||
components = [] # 添加Action组件
|
||||
if enable_tts:
|
||||
components.append((TTSAction.get_action_info(), TTSAction))
|
||||
|
||||
return components
|
||||
from src.plugin_system.apis.plugin_register_api import register_plugin
|
||||
from src.plugin_system.base.base_plugin import BasePlugin
|
||||
from src.plugin_system.base.component_types import ComponentInfo
|
||||
from src.common.logger import get_logger
|
||||
from src.plugin_system.base.base_action import BaseAction, ActionActivationType, ChatMode
|
||||
from src.plugin_system.base.config_types import ConfigField
|
||||
from typing import Tuple, List, Type
|
||||
|
||||
logger = get_logger("tts")
|
||||
|
||||
|
||||
class TTSAction(BaseAction):
|
||||
"""TTS语音转换动作处理类"""
|
||||
|
||||
# 激活设置
|
||||
focus_activation_type = ActionActivationType.LLM_JUDGE
|
||||
normal_activation_type = ActionActivationType.KEYWORD
|
||||
mode_enable = ChatMode.ALL
|
||||
parallel_action = False
|
||||
|
||||
# 动作基本信息
|
||||
action_name = "tts_action"
|
||||
action_description = "将文本转换为语音进行播放,适用于需要语音输出的场景"
|
||||
|
||||
# 关键词配置 - Normal模式下使用关键词触发
|
||||
activation_keywords = ["语音", "tts", "播报", "读出来", "语音播放", "听", "朗读"]
|
||||
keyword_case_sensitive = False
|
||||
|
||||
# 动作参数定义
|
||||
action_parameters = {
|
||||
"text": "需要转换为语音的文本内容,必填,内容应当适合语音播报,语句流畅、清晰",
|
||||
}
|
||||
|
||||
# 动作使用场景
|
||||
action_require = [
|
||||
"当需要发送语音信息时使用",
|
||||
"当用户明确要求使用语音功能时使用",
|
||||
"当表达内容更适合用语音而不是文字传达时使用",
|
||||
"当用户想听到语音回答而非阅读文本时使用",
|
||||
]
|
||||
|
||||
# 关联类型
|
||||
associated_types = ["tts_text"]
|
||||
|
||||
async def execute(self) -> Tuple[bool, str]:
|
||||
"""处理TTS文本转语音动作"""
|
||||
logger.info(f"{self.log_prefix} 执行TTS动作: {self.reasoning}")
|
||||
|
||||
# 获取要转换的文本
|
||||
text = self.action_data.get("text")
|
||||
|
||||
if not text:
|
||||
logger.error(f"{self.log_prefix} 执行TTS动作时未提供文本内容")
|
||||
return False, "执行TTS动作失败:未提供文本内容"
|
||||
|
||||
# 确保文本适合TTS使用
|
||||
processed_text = self._process_text_for_tts(text)
|
||||
|
||||
try:
|
||||
# 发送TTS消息
|
||||
await self.send_custom(message_type="tts_text", content=processed_text)
|
||||
|
||||
# 记录动作信息
|
||||
await self.store_action_info(
|
||||
action_build_into_prompt=True, action_prompt_display="已经发送了语音消息。", action_done=True
|
||||
)
|
||||
|
||||
logger.info(f"{self.log_prefix} TTS动作执行成功,文本长度: {len(processed_text)}")
|
||||
return True, "TTS动作执行成功"
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 执行TTS动作时出错: {e}")
|
||||
return False, f"执行TTS动作时出错: {e}"
|
||||
|
||||
def _process_text_for_tts(self, text: str) -> str:
|
||||
"""
|
||||
处理文本使其更适合TTS使用
|
||||
- 移除不必要的特殊字符和表情符号
|
||||
- 修正标点符号以提高语音质量
|
||||
- 优化文本结构使语音更流畅
|
||||
"""
|
||||
# 这里可以添加文本处理逻辑
|
||||
# 例如:移除多余的标点、表情符号,优化语句结构等
|
||||
|
||||
# 简单示例实现
|
||||
processed_text = text
|
||||
|
||||
# 移除多余的标点符号
|
||||
import re
|
||||
|
||||
processed_text = re.sub(r"([!?,.;:。!?,、;:])\1+", r"\1", processed_text)
|
||||
|
||||
# 确保句子结尾有合适的标点
|
||||
if not any(processed_text.endswith(end) for end in [".", "?", "!", "。", "!", "?"]):
|
||||
processed_text = f"{processed_text}。"
|
||||
|
||||
return processed_text
|
||||
|
||||
|
||||
@register_plugin
|
||||
class TTSPlugin(BasePlugin):
|
||||
"""TTS插件
|
||||
- 这是文字转语音插件
|
||||
- Normal模式下依靠关键词触发
|
||||
- Focus模式下由LLM判断触发
|
||||
- 具有一定的文本预处理能力
|
||||
"""
|
||||
|
||||
# 插件基本信息
|
||||
plugin_name: str = "tts_plugin" # 内部标识符
|
||||
enable_plugin: bool = True
|
||||
dependencies: list[str] = [] # 插件依赖列表
|
||||
python_dependencies: list[str] = [] # Python包依赖列表
|
||||
config_file_name: str = "config.toml"
|
||||
|
||||
# 配置节描述
|
||||
config_section_descriptions = {
|
||||
"plugin": "插件基本信息配置",
|
||||
"components": "组件启用控制",
|
||||
"logging": "日志记录相关配置",
|
||||
}
|
||||
|
||||
# 配置Schema定义
|
||||
config_schema: dict = {
|
||||
"plugin": {
|
||||
"name": ConfigField(type=str, default="tts_plugin", description="插件名称", required=True),
|
||||
"version": ConfigField(type=str, default="0.1.0", description="插件版本号"),
|
||||
"enabled": ConfigField(type=bool, default=True, description="是否启用插件"),
|
||||
"description": ConfigField(type=str, default="文字转语音插件", description="插件描述", required=True),
|
||||
},
|
||||
"components": {"enable_tts": ConfigField(type=bool, default=True, description="是否启用TTS Action")},
|
||||
"logging": {
|
||||
"level": ConfigField(
|
||||
type=str, default="INFO", description="日志记录级别", choices=["DEBUG", "INFO", "WARNING", "ERROR"]
|
||||
),
|
||||
"prefix": ConfigField(type=str, default="[TTS]", description="日志记录前缀"),
|
||||
},
|
||||
}
|
||||
|
||||
def get_plugin_components(self) -> List[Tuple[ComponentInfo, Type]]:
|
||||
"""返回插件包含的组件列表"""
|
||||
|
||||
# 从配置获取组件启用状态
|
||||
enable_tts = self.get_config("components.enable_tts", True)
|
||||
components = [] # 添加Action组件
|
||||
if enable_tts:
|
||||
components.append((TTSAction.get_action_info(), TTSAction))
|
||||
|
||||
return components
|
||||
|
||||
Reference in New Issue
Block a user